package checker import ( "context" "encoding/json" "fmt" "sort" "strings" "sync" "github.com/miekg/dns" sdk "git.happydns.org/checker-sdk-go/checker" ) // Collect gathers raw per-authoritative-NS DNS answers for the zone. It does // NOT judge: no severity, no pass/fail, no pre-derived findings. Rules in // rules.go translate the resulting ObservationData into CheckStates. func (p *authoritativeConsistencyProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) { svc, err := loadService(opts) if err != nil { return nil, err } zone, err := loadZone(opts, svc) if err != nil { return nil, err } checkEDNS := sdk.GetBoolOption(opts, "checkEDNS", true) useParentNS := sdk.GetBoolOption(opts, "useParentNS", true) data := &ObservationData{ Zone: dns.Fqdn(zone), HasSOA: svc.SOA != nil, DeclaredNS: normalizeNSList(svc.NameServers), Results: map[string]*NSResult{}, } if svc.SOA != nil { data.DeclaredSerial = svc.SOA.Serial } // Parent referral probe (raw). if useParentNS { parentNS, perr := parentReferral(ctx, data.Zone) if perr != nil { data.ParentQueryError = perr.Error() } else { data.ParentNS = parentNS } } // Union of every NS name we intend to probe. data.Probed = unionStrings(data.DeclaredNS, data.ParentNS) if len(data.Probed) == 0 { // Nothing to probe. Rules will turn this into a finding. return data, nil } // Per-NS probes (concurrent, bounded). The cap protects the checker // from a malicious or misconfigured Origin declaring an unbounded NS // list, which would otherwise spawn one goroutine and a fresh batch of // UDP/TCP sockets per name. 
const maxConcurrentProbes = 16 sem := make(chan struct{}, maxConcurrentProbes) var wg sync.WaitGroup var mu sync.Mutex for _, nsName := range data.Probed { nsName := nsName wg.Add(1) sem <- struct{}{} go func() { defer wg.Done() defer func() { <-sem }() res := probeNS(ctx, data.Zone, nsName, checkEDNS) mu.Lock() data.Results[nsName] = res mu.Unlock() }() } wg.Wait() return data, nil } // probeNS performs every probe against a single NS hostname. It resolves the // name, then iterates over its addresses. For consistency, the "canonical" // view returned by the NS is the first address that provided an // authoritative answer; subsequent addresses only update reachability and // error state. This avoids dual-homed servers appearing twice in the drift // matrix while still catching IPv4/IPv6-specific failures. func probeNS(ctx context.Context, zone, nsName string, checkEDNS bool) *NSResult { res := &NSResult{Name: nsName} addrs, err := resolveHost(ctx, nsName) if err != nil { res.ResolveError = err.Error() return res } if len(addrs) == 0 { res.ResolveError = "no A/AAAA records" return res } res.Addresses = addrs for _, addr := range addrs { srv := hostPort(addr, "53") soa, aa, rtt, qerr := querySOA(ctx, "", srv, zone) if qerr != nil { res.appendError("UDP %s: %v", addr, qerr) continue } res.UDPReachable = true if res.LatencyMs == 0 { res.LatencyMs = rtt.Milliseconds() } if aa { res.Authoritative = true } // First authoritative answer wins; that is the canonical view of // this NS. Subsequent addresses only contribute reachability/errors. if soa != nil && res.SOA == nil { res.SOA = soa res.Serial = soa.Serial } // TCP probe against the same address. if _, _, _, terr := querySOA(ctx, "tcp", srv, zone); terr != nil { res.appendError("TCP %s: %v", addr, terr) } else { res.TCPReachable = true } // EDNS0 probe against the same address. 
if checkEDNS { if eerr := probeEDNS0(ctx, srv, zone); eerr != nil { res.appendError("EDNS0 %s: %v", addr, eerr) } else { res.EDNSSupported = true } } // NS RRset as seen by this server. if nss, nerr := queryNSAt(ctx, srv, zone); nerr == nil && len(res.NSRRset) == 0 { sort.Strings(nss) res.NSRRset = nss } } return res } // loadService extracts the abstract.Origin / abstract.NSOnlyOrigin payload // from the auto-filled "service" option. func loadService(opts sdk.CheckerOptions) (*originService, error) { svc, ok := sdk.GetOption[serviceMessage](opts, "service") if !ok { return nil, fmt.Errorf("missing 'service' option") } switch svc.Type { case "", "abstract.Origin", "abstract.NSOnlyOrigin": default: return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", svc.Type) } var d originService if err := json.Unmarshal(svc.Service, &d); err != nil { return nil, fmt.Errorf("decoding origin service: %w", err) } return &d, nil } // loadZone picks the zone name from the "domain_name" option or falls back // to the service's SOA owner name. func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) { if v, ok := sdk.GetOption[string](opts, "domain_name"); ok && v != "" { return dns.Fqdn(v), nil } if svc.SOA != nil && svc.SOA.Header().Name != "" { return dns.Fqdn(svc.SOA.Header().Name), nil } return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)") } // normalizeNSList lowercases and FQDN-normalizes a list of NS records. func normalizeNSList(ns []*dns.NS) []string { out := make([]string, 0, len(ns)) for _, n := range ns { if n == nil { continue } out = append(out, strings.ToLower(dns.Fqdn(n.Ns))) } sort.Strings(out) return out } // unionStrings returns the sorted union of two string slices, de-duplicated. 
func unionStrings(a, b []string) []string {
	// merged is left nil when neither input contributes an element.
	var merged []string
	known := map[string]bool{}
	for _, list := range [][]string{a, b} {
		for _, item := range list {
			if known[item] {
				continue
			}
			known[item] = true
			merged = append(merged, item)
		}
	}
	sort.Strings(merged)
	return merged
}

// diffStringSets compares two lists of names as sets. Names are compared in
// lowercase with one trailing dot removed, and are returned in that
// normalized form.
//
// missing holds the names of want that do not occur in got; extra holds the
// names of got that do not occur in want. Both results are sorted, and each
// is nil when it has no element.
func diffStringSets(want, got []string) (missing, extra []string) {
	toSet := func(names []string) map[string]bool {
		set := map[string]bool{}
		for _, name := range names {
			set[strings.ToLower(strings.TrimSuffix(name, "."))] = true
		}
		return set
	}
	wanted, present := toSet(want), toSet(got)

	for name := range wanted {
		if !present[name] {
			missing = append(missing, name)
		}
	}
	for name := range present {
		if !wanted[name] {
			extra = append(extra, name)
		}
	}

	sort.Strings(missing)
	sort.Strings(extra)
	return missing, extra
}

// serialLess reports whether serial a comes before serial b according to
// RFC 1982 serial number arithmetic, which makes the comparison work across
// a uint32 wraparound. Equal serials, and serials exactly 2^31 apart, yield
// false.
func serialLess(a, b uint32) bool {
	if a == b {
		return false
	}
	// Unsigned subtraction wraps, giving the forward distance from a to b.
	const halfRange = 1 << 31
	return b-a < halfRange
}