checker: harden HTTP client, cap response size, drop dead legacy rule

This commit is contained in:
nemunaire 2026-04-26 17:12:01 +07:00
commit 2710dfb459
9 changed files with 407 additions and 118 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
checker-zonemaster
checker-zonemaster.so

View file

@ -55,6 +55,20 @@ the running checker-zonemaster server (e.g.,
`http://checker-zonemaster:8080`). happyDomain will delegate observation
collection to this endpoint.
### Deployment
The `/collect` endpoint has no built-in authentication and will issue
JSON-RPC calls to whatever Zonemaster API URL is configured via the
`zonemasterAPIURL` admin option (defaulting to the official public API
at `https://zonemaster.net/api`). Operators should point this option
only at trusted Zonemaster instances; pointing it at an untrusted host
turns the checker into an SSRF vector, since responses are parsed and
surfaced back to the caller. The checker itself is meant to run on a
trusted network, reachable only by the happyDomain instance that drives
it. Restrict access via a reverse proxy with authentication, a network
ACL, or by binding the listener to a private interface; do not expose
it directly to the public internet.
## Options
| Scope | Id | Description |

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
@ -13,6 +14,37 @@ import (
sdk "git.happydns.org/checker-sdk-go/checker"
)
// Limits applied while talking to the Zonemaster API.
const (
	// maxResponseBytes is the largest response body we are willing to read
	// from the Zonemaster API. Genuine result payloads range from tens to a
	// few hundred KB, so 8 MiB leaves plenty of headroom while still keeping
	// a misbehaving or hostile endpoint from exhausting memory.
	maxResponseBytes = 8 << 20

	// maxCollectDuration bounds one whole collection run (start + poll +
	// fetch). A tighter deadline on the caller's context still takes
	// precedence.
	maxCollectDuration = 15 * time.Minute

	// pollInterval is the delay between two progress requests sent to the
	// Zonemaster API.
	pollInterval = 2 * time.Second
)
// zmHTTPClient is the shared HTTP client for every Zonemaster API call.
// Each phase of a request carries its own timeout, so an endpoint that
// stalls cannot block us forever, even when the caller's context has no
// deadline.
var zmHTTPClient = &http.Client{
	Transport: &http.Transport{
		// Honour the usual HTTP(S)_PROXY / NO_PROXY environment variables.
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			KeepAlive: 30 * time.Second,
			Timeout:   10 * time.Second,
		}).DialContext,
		// Per-phase limits: TLS handshake, wait for response headers,
		// wait for "100 Continue", and lifetime of idle connections.
		TLSHandshakeTimeout:   10 * time.Second,
		ResponseHeaderTimeout: 30 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		IdleConnTimeout:       90 * time.Second,
	},
	// Overall cap for a single request.
	Timeout: 60 * time.Second,
}
func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
domainName, ok := opts["domainName"].(string)
if !ok || domainName == "" {
@ -36,6 +68,11 @@ func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOption
profile = prof
}
// Cap the total collection time even when the caller's context has no
// deadline. The caller's deadline still wins if it's tighter.
ctx, cancel := context.WithTimeout(ctx, maxCollectDuration)
defer cancel()
// Step 1: start the test.
startResult, err := zmCallJSONRPC(ctx, apiURL, "start_domain_test", zmStartTestParams{
Domain: domainName,
@ -56,9 +93,10 @@ func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOption
}
// Step 2: poll for completion.
ticker := time.NewTicker(2 * time.Second)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
poll:
for {
select {
case <-ctx.Done():
@ -75,12 +113,11 @@ func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOption
}
if progress >= 100 {
goto testComplete
break poll
}
}
}
testComplete:
// Step 3: fetch results.
rawResults, err := zmCallJSONRPC(ctx, apiURL, "get_test_results", zmGetResultsParams{
ID: testID,
@ -117,19 +154,37 @@ func zmCallJSONRPC(ctx context.Context, apiURL, method string, params any) (json
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := zmHTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call API: %w", err)
}
defer resp.Body.Close()
// Cap the body we'll ever read so a misbehaving endpoint can't exhaust
// memory. +1 lets us detect that the cap was hit.
limited := io.LimitReader(resp.Body, maxResponseBytes+1)
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
b, readErr := io.ReadAll(limited)
if readErr != nil {
return nil, fmt.Errorf("API returned status %d (failed to read body: %v)", resp.StatusCode, readErr)
}
if len(b) > maxResponseBytes {
b = b[:maxResponseBytes]
}
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(b))
}
body, readErr := io.ReadAll(limited)
if readErr != nil {
return nil, fmt.Errorf("failed to read response: %w", readErr)
}
if len(body) > maxResponseBytes {
return nil, fmt.Errorf("API response exceeds %d bytes", maxResponseBytes)
}
var rpcResp zmJSONRPCResponse
if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil {
if err := json.Unmarshal(body, &rpcResp); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}

View file

@ -17,7 +17,7 @@ import (
var Version = "built-in"
// Definition returns the CheckerDefinition for the zonemaster checker.
func Definition() *sdk.CheckerDefinition {
func (p *zonemasterProvider) Definition() *sdk.CheckerDefinition {
return &sdk.CheckerDefinition{
ID: "zonemaster",
Name: "Zonemaster",
@ -75,13 +75,11 @@ func Definition() *sdk.CheckerDefinition {
},
},
},
Rules: []sdk.CheckRule{
Rule(),
},
Rules: Rules(),
Interval: &sdk.CheckIntervalSpec{
Min: 1 * time.Hour,
Max: 7 * 24 * time.Hour,
Default: 24 * time.Hour,
Min: 12 * time.Hour,
Max: 30 * 24 * time.Hour,
Default: 7 * 24 * time.Hour,
},
}
}

View file

@ -14,8 +14,3 @@ type zonemasterProvider struct{}
func (p *zonemasterProvider) Key() sdk.ObservationKey {
return ObservationKeyZonemaster
}
// Definition implements sdk.CheckerDefinitionProvider.
func (p *zonemasterProvider) Definition() *sdk.CheckerDefinition {
return Definition()
}

View file

@ -15,13 +15,14 @@ import (
// zmLevelDisplayOrder defines the severity order used for sorting and display.
var zmLevelDisplayOrder = []string{"CRITICAL", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}
var zmLevelRank = func() map[string]int {
m := make(map[string]int, len(zmLevelDisplayOrder))
for i, l := range zmLevelDisplayOrder {
m[l] = len(zmLevelDisplayOrder) - i
}
return m
}()
var zmLevelRank = map[string]int{
"CRITICAL": 6,
"ERROR": 5,
"WARNING": 4,
"NOTICE": 3,
"INFO": 2,
"DEBUG": 1,
}
type zmLevelCount struct {
Level string
@ -50,7 +51,7 @@ var zonemasterHTMLTemplate = template.Must(
template.New("zonemaster").
Funcs(template.FuncMap{
"badgeClass": func(level string) string {
switch strings.ToUpper(level) {
switch normLevel(level) {
case "CRITICAL":
return "badge-critical"
case "ERROR":
@ -222,7 +223,7 @@ func (p *zonemasterProvider) GetHTMLReport(ctx sdk.ReportContext) (string, error
rs := moduleMap[name]
counts := map[string]int{}
for _, r := range rs {
lvl := strings.ToUpper(r.Level)
lvl := normLevel(r.Level)
counts[lvl]++
totalCounts[lvl]++
}

View file

@ -30,32 +30,6 @@ func Rules() []sdk.CheckRule {
}
}
// Rule returns the legacy single-rule view of the Zonemaster checker.
//
// Deprecated: use Rules() for per-category CheckRules. This wrapper is kept
// so existing callers that only expect a single rule keep compiling.
func Rule() sdk.CheckRule { return &legacyRule{} }
type legacyRule struct{}
func (r *legacyRule) Name() string { return "zonemaster" }
func (r *legacyRule) Description() string {
return "Runs Zonemaster DNS validation tests against the zone (aggregate view)."
}
func (r *legacyRule) ValidateOptions(opts sdk.CheckerOptions) error {
return validateZonemasterOptions(opts)
}
func (r *legacyRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter, _ sdk.CheckerOptions) []sdk.CheckState {
data, errSt := loadZonemasterData(ctx, obs)
if errSt != nil {
return []sdk.CheckState{*errSt}
}
return []sdk.CheckState{summarizeAll(data)}
}
// ── shared helpers ────────────────────────────────────────────────────────────
// validateZonemasterOptions validates the options accepted by the Zonemaster
@ -96,18 +70,23 @@ func loadZonemasterData(ctx context.Context, obs sdk.ObservationGetter) (*Zonema
return &data, nil
}
// normLevel converts a Zonemaster severity string to its canonical,
// upper-case spelling. Call it wherever a severity is compared, looked up
// or used as a key, so that the normalisation rule lives in a single place.
func normLevel(level string) string {
	canonical := strings.ToUpper(level)
	return canonical
}
// levelToStatus maps a Zonemaster-returned severity to happyDomain's status.
// Zonemaster's own judgement is treated as raw input; this is happyDomain's
// own mapping onto the SDK status enum.
func levelToStatus(level string) sdk.Status {
switch strings.ToUpper(level) {
switch normLevel(level) {
case "CRITICAL", "ERROR":
return sdk.StatusCrit
case "WARNING":
return sdk.StatusWarn
case "NOTICE", "INFO":
return sdk.StatusInfo
case "DEBUG":
case "NOTICE", "INFO", "DEBUG":
return sdk.StatusInfo
default:
return sdk.StatusUnknown
@ -154,6 +133,10 @@ type categoryRule struct {
func (r *categoryRule) Name() string { return r.name }
func (r *categoryRule) Description() string { return r.description }
func (r *categoryRule) ValidateOptions(opts sdk.CheckerOptions) error {
return validateZonemasterOptions(opts)
}
func (r *categoryRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter, _ sdk.CheckerOptions) []sdk.CheckState {
data, errSt := loadZonemasterData(ctx, obs)
if errSt != nil {
@ -176,7 +159,7 @@ func (r *categoryRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter,
)
for _, res := range matched {
lvl := strings.ToUpper(res.Level)
lvl := normLevel(res.Level)
st := levelToStatus(lvl)
worst = worstStatus(worst, st)
@ -252,61 +235,3 @@ func filterByModules(results []ZonemasterTestResult, modules []string) []Zonemas
}
return out
}
// summarizeAll produces the legacy monolithic summary state. Preserved so
// Rule() keeps behaving as before for callers that still use it.
func summarizeAll(data *ZonemasterData) sdk.CheckState {
var errorCount, warningCount int
var criticalMsgs []string
for _, res := range data.Results {
switch strings.ToUpper(res.Level) {
case "CRITICAL", "ERROR":
errorCount++
if len(criticalMsgs) < 5 {
criticalMsgs = append(criticalMsgs, res.Message)
}
case "WARNING":
warningCount++
}
}
meta := map[string]any{
"errorCount": errorCount,
"warningCount": warningCount,
"totalChecks": len(data.Results),
"hashId": data.HashID,
"createdAt": data.CreatedAt,
}
if errorCount > 0 {
statusLine := fmt.Sprintf("%d error(s), %d warning(s) found", errorCount, warningCount)
if len(criticalMsgs) > 0 {
n := 2
if len(criticalMsgs) < n {
n = len(criticalMsgs)
}
statusLine += ": " + strings.Join(criticalMsgs[:n], "; ")
}
return sdk.CheckState{
Status: sdk.StatusCrit,
Message: statusLine,
Code: "zonemaster_errors",
Meta: meta,
}
}
if warningCount > 0 {
return sdk.CheckState{
Status: sdk.StatusWarn,
Message: fmt.Sprintf("%d warning(s) found", warningCount),
Code: "zonemaster_warnings",
Meta: meta,
}
}
return sdk.CheckState{
Status: sdk.StatusOK,
Message: fmt.Sprintf("All checks passed (%d checks)", len(data.Results)),
Code: "zonemaster_ok",
Meta: meta,
}
}

298
checker/rule_test.go Normal file
View file

@ -0,0 +1,298 @@
package checker
import (
"context"
"encoding/json"
"errors"
"strings"
"testing"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// fakeObs is a stub ObservationGetter for the tests in this file. When err
// is set, Get hands it back unchanged; otherwise data is encoded to JSON and
// decoded into the destination value.
type fakeObs struct {
	data any
	err  error
}

// Get returns f.err when it is non-nil, and otherwise copies f.data into
// dest through a JSON round trip.
func (f *fakeObs) Get(_ context.Context, _ sdk.ObservationKey, dest any) error {
	if f.err != nil {
		return f.err
	}

	encoded, marshalErr := json.Marshal(f.data)
	if marshalErr != nil {
		return marshalErr
	}
	return json.Unmarshal(encoded, dest)
}

// GetRelated always reports no related observations and no error.
func (f *fakeObs) GetRelated(_ context.Context, _ sdk.ObservationKey) ([]sdk.RelatedObservation, error) {
	return nil, nil
}
// TestLevelToStatus pins the mapping from Zonemaster severity strings to SDK
// statuses, including a lower-case input and two unrecognised inputs.
func TestLevelToStatus(t *testing.T) {
	type expectation struct {
		level string
		want  sdk.Status
	}

	table := []expectation{
		{level: "CRITICAL", want: sdk.StatusCrit},
		{level: "ERROR", want: sdk.StatusCrit},
		{level: "critical", want: sdk.StatusCrit}, // case-insensitive
		{level: "WARNING", want: sdk.StatusWarn},
		{level: "NOTICE", want: sdk.StatusInfo},
		{level: "INFO", want: sdk.StatusInfo},
		{level: "DEBUG", want: sdk.StatusInfo},
		{level: "", want: sdk.StatusUnknown},
		{level: "BANANA", want: sdk.StatusUnknown},
	}

	for _, tc := range table {
		t.Run(tc.level, func(t *testing.T) {
			got := levelToStatus(tc.level)
			if got == tc.want {
				return
			}
			t.Errorf("levelToStatus(%q) = %v, want %v", tc.level, got, tc.want)
		})
	}
}
// TestWorstStatus checks worstStatus against pairs of statuses. The rows
// below pin the ordering Error > Crit > Warn > Info > OK > Unknown.
func TestWorstStatus(t *testing.T) {
	table := []struct {
		a, b, want sdk.Status
	}{
		{a: sdk.StatusOK, b: sdk.StatusOK, want: sdk.StatusOK},
		{a: sdk.StatusOK, b: sdk.StatusInfo, want: sdk.StatusInfo},
		{a: sdk.StatusInfo, b: sdk.StatusWarn, want: sdk.StatusWarn},
		{a: sdk.StatusWarn, b: sdk.StatusCrit, want: sdk.StatusCrit},
		{a: sdk.StatusCrit, b: sdk.StatusError, want: sdk.StatusError},
		{a: sdk.StatusError, b: sdk.StatusCrit, want: sdk.StatusError},
		{a: sdk.StatusUnknown, b: sdk.StatusOK, want: sdk.StatusOK},
		{a: sdk.StatusUnknown, b: sdk.StatusUnknown, want: sdk.StatusUnknown},
	}

	for i := range table {
		row := table[i]
		got := worstStatus(row.a, row.b)
		if got == row.want {
			continue
		}
		t.Errorf("worstStatus(%v, %v) = %v, want %v", row.a, row.b, got, row.want)
	}
}
// TestFilterByModules exercises filterByModules on a small fixture: matching
// regardless of case, several modules at once, a nil module list, and a
// module that matches nothing.
func TestFilterByModules(t *testing.T) {
	fixture := []ZonemasterTestResult{
		{Module: "DNSSEC", Message: "a"},
		{Module: "Delegation", Message: "b"},
		{Module: "dnssec", Message: "c"},
		{Module: "Syntax", Message: "d"},
	}

	t.Run("matches case-insensitively", func(t *testing.T) {
		matched := filterByModules(fixture, []string{"dnssec"})
		if n := len(matched); n != 2 {
			t.Fatalf("got %d results, want 2: %+v", n, matched)
		}
		first, second := matched[0].Message, matched[1].Message
		if !(first == "a" && second == "c") {
			t.Errorf("unexpected results: %+v", matched)
		}
	})

	t.Run("multiple modules", func(t *testing.T) {
		matched := filterByModules(fixture, []string{"delegation", "syntax"})
		if n := len(matched); n != 2 {
			t.Errorf("got %d, want 2", n)
		}
	})

	t.Run("empty modules returns nil", func(t *testing.T) {
		matched := filterByModules(fixture, nil)
		if matched != nil {
			t.Errorf("got %+v, want nil", matched)
		}
	})

	t.Run("no match returns empty", func(t *testing.T) {
		matched := filterByModules(fixture, []string{"nope"})
		if len(matched) != 0 {
			t.Errorf("got %+v, want empty", matched)
		}
	})
}
// TestValidateZonemasterOptions feeds validateZonemasterOptions a range of
// zonemasterAPIURL values. An empty wantErr means the options must be
// accepted; otherwise the returned error must contain wantErr.
func TestValidateZonemasterOptions(t *testing.T) {
	type scenario struct {
		name    string
		opts    sdk.CheckerOptions
		wantErr string // substring; empty means no error expected
	}

	table := []scenario{
		{name: "empty opts", opts: sdk.CheckerOptions{}, wantErr: ""},
		{name: "empty url", opts: sdk.CheckerOptions{"zonemasterAPIURL": ""}, wantErr: ""},
		{name: "valid http", opts: sdk.CheckerOptions{"zonemasterAPIURL": "http://localhost:5000/api"}, wantErr: ""},
		{name: "valid https", opts: sdk.CheckerOptions{"zonemasterAPIURL": "https://zonemaster.net/api"}, wantErr: ""},
		{name: "non-string", opts: sdk.CheckerOptions{"zonemasterAPIURL": 42}, wantErr: "must be a string"},
		{name: "bad scheme", opts: sdk.CheckerOptions{"zonemasterAPIURL": "ftp://x/api"}, wantErr: "http or https"},
		{name: "no host", opts: sdk.CheckerOptions{"zonemasterAPIURL": "http:///api"}, wantErr: "must include a host"},
		{name: "unparseable", opts: sdk.CheckerOptions{"zonemasterAPIURL": "http://[::1"}, wantErr: "zonemasterAPIURL"},
	}

	for _, tc := range table {
		t.Run(tc.name, func(t *testing.T) {
			err := validateZonemasterOptions(tc.opts)

			// Success expected: any error is a failure.
			if tc.wantErr == "" {
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
				return
			}

			// Failure expected: there must be an error with the right text.
			if err == nil {
				t.Errorf("expected error containing %q, got nil", tc.wantErr)
				return
			}
			if !strings.Contains(err.Error(), tc.wantErr) {
				t.Errorf("error %q does not contain %q", err.Error(), tc.wantErr)
			}
		})
	}
}
// TestNormLevel verifies that normLevel upper-cases its input and leaves the
// empty string and already-canonical values untouched.
func TestNormLevel(t *testing.T) {
	expected := map[string]string{
		"":         "",
		"info":     "INFO",
		"WaRnInG":  "WARNING",
		"CRITICAL": "CRITICAL",
	}

	for input, want := range expected {
		got := normLevel(input)
		if got == want {
			continue
		}
		t.Errorf("normLevel(%q) = %q, want %q", input, got, want)
	}
}
// TestCategoryRuleEvaluate_NoData expects a single StatusUnknown state with
// the "not_tested" code when the observation holds no results at all.
func TestCategoryRuleEvaluate_NoData(t *testing.T) {
	rule := &categoryRule{name: "zonemaster.dnssec", modules: []string{"dnssec"}}
	getter := &fakeObs{data: ZonemasterData{Results: nil}}

	states := rule.Evaluate(context.Background(), getter, nil)
	if n := len(states); n != 1 {
		t.Fatalf("got %d states, want 1", n)
	}

	only := states[0]
	if only.Status != sdk.StatusUnknown {
		t.Errorf("status = %v, want StatusUnknown", only.Status)
	}
	if only.Code != "zonemaster.dnssec.not_tested" {
		t.Errorf("code = %q", only.Code)
	}
}
// TestCategoryRuleEvaluate_ObservationError expects a single StatusError
// state with the "observation_error" code when the observation getter fails.
func TestCategoryRuleEvaluate_ObservationError(t *testing.T) {
	rule := &categoryRule{name: "zonemaster.dnssec", modules: []string{"dnssec"}}
	getter := &fakeObs{err: errors.New("boom")}

	states := rule.Evaluate(context.Background(), getter, nil)
	if n := len(states); n != 1 {
		t.Fatalf("got %d states, want 1", n)
	}

	only := states[0]
	if only.Status != sdk.StatusError {
		t.Errorf("status = %v, want StatusError", only.Status)
	}
	if only.Code != "zonemaster.observation_error" {
		t.Errorf("code = %q", only.Code)
	}
}
// TestCategoryRuleEvaluate_AllOK expects only a StatusOK summary counting
// two results when the rule's module has just INFO/NOTICE findings; the
// ERROR result belongs to another module and must not be counted.
func TestCategoryRuleEvaluate_AllOK(t *testing.T) {
	rule := &categoryRule{name: "zonemaster.dnssec", modules: []string{"dnssec"}}
	getter := &fakeObs{data: ZonemasterData{Results: []ZonemasterTestResult{
		{Module: "dnssec", Level: "INFO", Message: "ok1"},
		{Module: "dnssec", Level: "NOTICE", Message: "ok2"},
		{Module: "delegation", Level: "ERROR", Message: "ignored, wrong module"},
	}}}

	states := rule.Evaluate(context.Background(), getter, nil)
	if n := len(states); n != 1 {
		t.Fatalf("got %d states, want 1 (summary only): %+v", n, states)
	}

	summary := states[0]
	if summary.Status != sdk.StatusOK {
		t.Errorf("status = %v, want StatusOK", summary.Status)
	}
	total, _ := summary.Meta["total"].(int)
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}
// TestCategoryRuleEvaluate_MixedSeverities feeds one result per severity
// (INFO, WARNING, ERROR, CRITICAL) and expects a StatusCrit summary with
// per-level counts, followed by one issue state for each of the warning,
// error and critical findings.
func TestCategoryRuleEvaluate_MixedSeverities(t *testing.T) {
	rule := &categoryRule{name: "zonemaster.dnssec", modules: []string{"dnssec"}}
	getter := &fakeObs{data: ZonemasterData{Results: []ZonemasterTestResult{
		{Module: "DNSSEC", Level: "INFO", Message: "i"},
		{Module: "dnssec", Level: "WARNING", Message: "w", Testcase: "tc-w"},
		{Module: "dnssec", Level: "ERROR", Message: "e", Testcase: "tc-e"},
		{Module: "dnssec", Level: "CRITICAL", Message: "c", Testcase: "tc-c"},
	}}}

	states := rule.Evaluate(context.Background(), getter, nil)

	// Expect 1 summary + 3 issue states (warning + error + critical).
	if n := len(states); n != 4 {
		t.Fatalf("got %d states, want 4: %+v", n, states)
	}

	summary, issues := states[0], states[1:]
	if summary.Status != sdk.StatusCrit {
		t.Errorf("summary status = %v, want StatusCrit", summary.Status)
	}

	// metaCount reads an int counter from the summary metadata; a missing
	// or non-int entry reads as zero.
	metaCount := func(key string) int {
		count, _ := summary.Meta[key].(int)
		return count
	}
	if got := metaCount("critical"); got != 1 {
		t.Errorf("critical = %d, want 1", got)
	}
	if got := metaCount("error"); got != 1 {
		t.Errorf("error = %d, want 1", got)
	}
	if got := metaCount("warning"); got != 1 {
		t.Errorf("warning = %d, want 1", got)
	}
	if got := metaCount("info"); got != 1 {
		t.Errorf("info = %d, want 1", got)
	}

	// Issue states: codes should be dotted, lowercased levels. Each entry
	// flips to true once the matching state has been observed.
	pending := map[string]bool{
		"zonemaster.dnssec.warning":  false,
		"zonemaster.dnssec.error":    false,
		"zonemaster.dnssec.critical": false,
	}
	for _, issue := range issues {
		if _, expected := pending[issue.Code]; !expected {
			t.Errorf("unexpected issue code: %q", issue.Code)
			continue
		}
		pending[issue.Code] = true
		if issue.Subject == "" {
			t.Errorf("issue state %q missing Subject", issue.Code)
		}
	}
	for code, observed := range pending {
		if !observed {
			t.Errorf("missing issue state for %q", code)
		}
	}
}
// TestRulesContainsAllCategories checks that Rules() returns exactly the
// expected number of rules, that every expected category name is present,
// and that each rule has a non-empty description.
func TestRulesContainsAllCategories(t *testing.T) {
	expected := []string{
		"zonemaster.dnssec",
		"zonemaster.delegation",
		"zonemaster.consistency",
		"zonemaster.connectivity",
		"zonemaster.nameserver",
		"zonemaster.syntax",
		"zonemaster.zone",
		"zonemaster.address",
		"zonemaster.basic",
	}

	rules := Rules()
	if len(rules) != len(expected) {
		t.Fatalf("Rules() returned %d rules, want %d", len(rules), len(expected))
	}

	present := map[string]bool{}
	for _, rule := range rules {
		present[rule.Name()] = true
		if rule.Description() == "" {
			t.Errorf("rule %q has empty description", rule.Name())
		}
	}

	for _, name := range expected {
		if present[name] {
			continue
		}
		t.Errorf("Rules() missing %q", name)
	}
}

View file

@ -5,8 +5,8 @@
package main
import (
zonemaster "git.happydns.org/checker-zonemaster/checker"
sdk "git.happydns.org/checker-sdk-go/checker"
zonemaster "git.happydns.org/checker-zonemaster/checker"
)
// Version is the plugin's version. It defaults to "custom-build" and is
@ -20,5 +20,6 @@ var Version = "custom-build"
// that the host will register in its global registries.
func NewCheckerPlugin() (*sdk.CheckerDefinition, sdk.ObservationProvider, error) {
zonemaster.Version = Version
return zonemaster.Definition(), zonemaster.Provider(), nil
prvd := zonemaster.Provider()
return prvd.(sdk.CheckerDefinitionProvider).Definition(), prvd, nil
}