package checker import ( "context" "encoding/json" "fmt" "sort" "strings" "sync" "github.com/miekg/dns" sdk "git.happydns.org/checker-sdk-go/checker" ) // Gathers raw per-NS DNS answers. No severity or pass/fail is decided here; // rules turn the resulting ObservationData into CheckStates. func (p *authoritativeConsistencyProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) { svc, err := loadService(opts) if err != nil { return nil, err } zone, err := loadZone(opts, svc) if err != nil { return nil, err } checkEDNS := sdk.GetBoolOption(opts, "checkEDNS", true) useParentNS := sdk.GetBoolOption(opts, "useParentNS", true) data := &ObservationData{ Zone: dns.Fqdn(zone), HasSOA: svc.SOA != nil, DeclaredNS: normalizeNSList(svc.NameServers), Results: map[string]*NSResult{}, } if svc.SOA != nil { data.DeclaredSerial = svc.SOA.Serial } if useParentNS { parentNS, perr := parentReferral(ctx, data.Zone) if perr != nil { data.ParentQueryError = perr.Error() } else { data.ParentNS = parentNS } } data.Probed = unionStrings(data.DeclaredNS, data.ParentNS) if len(data.Probed) == 0 { return data, nil } // Cap fan-out: an unbounded Origin NS list would otherwise spawn one // goroutine and a fresh batch of UDP/TCP sockets per name. const maxConcurrentProbes = 16 sem := make(chan struct{}, maxConcurrentProbes) var wg sync.WaitGroup var mu sync.Mutex for _, nsName := range data.Probed { nsName := nsName wg.Add(1) sem <- struct{}{} go func() { defer wg.Done() defer func() { <-sem }() res := probeNS(ctx, data.Zone, nsName, checkEDNS) mu.Lock() data.Results[nsName] = res mu.Unlock() }() } wg.Wait() return data, nil } // First authoritative answer wins as the canonical view of this NS; // subsequent addresses only contribute reachability/error state. Avoids // dual-homed servers appearing twice in the drift matrix while still // surfacing IPv4/IPv6-specific failures. func probeNS(ctx context.Context, zone, nsName string, checkEDNS bool) *NSResult { res := &NSResult{Name: nsName} addrs, err := resolveHost(ctx, nsName) if err != nil { res.ResolveError = err.Error() return res } if len(addrs) == 0 { res.ResolveError = "no A/AAAA records" return res } res.Addresses = addrs for _, addr := range addrs { srv := hostPort(addr, "53") soa, aa, rtt, qerr := querySOA(ctx, "", srv, zone) if qerr != nil { res.appendError("UDP %s: %v", addr, qerr) continue } res.UDPReachable = true if res.LatencyMs == 0 { res.LatencyMs = rtt.Milliseconds() } if aa { res.Authoritative = true } if soa != nil && res.SOA == nil { res.SOA = soa res.Serial = soa.Serial } if _, _, _, terr := querySOA(ctx, "tcp", srv, zone); terr != nil { res.appendError("TCP %s: %v", addr, terr) } else { res.TCPReachable = true } if checkEDNS { if eerr := probeEDNS0(ctx, srv, zone); eerr != nil { res.appendError("EDNS0 %s: %v", addr, eerr) } else { res.EDNSSupported = true } } if nss, nerr := queryNSAt(ctx, srv, zone); nerr == nil && len(res.NSRRset) == 0 { sort.Strings(nss) res.NSRRset = nss } } return res } func loadService(opts sdk.CheckerOptions) (*originService, error) { svc, ok := sdk.GetOption[serviceMessage](opts, "service") if !ok { return nil, fmt.Errorf("missing 'service' option") } switch svc.Type { case "", "abstract.Origin", "abstract.NSOnlyOrigin": default: return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", svc.Type) } var d originService if err := json.Unmarshal(svc.Service, &d); err != nil { return nil, fmt.Errorf("decoding origin service: %w", err) } return &d, nil } // Falls back to the service's SOA owner name when domain_name is unset. func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) { if v, ok := sdk.GetOption[string](opts, "domain_name"); ok && v != "" { return dns.Fqdn(v), nil } if svc.SOA != nil && svc.SOA.Header().Name != "" { return dns.Fqdn(svc.SOA.Header().Name), nil } return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)") } func normalizeNSList(ns []*dns.NS) []string { out := make([]string, 0, len(ns)) for _, n := range ns { if n == nil { continue } out = append(out, strings.ToLower(dns.Fqdn(n.Ns))) } sort.Strings(out) return out } func unionStrings(a, b []string) []string { seen := map[string]bool{} var out []string for _, s := range a { if !seen[s] { seen[s] = true out = append(out, s) } } for _, s := range b { if !seen[s] { seen[s] = true out = append(out, s) } } sort.Strings(out) return out } func diffStringSets(want, got []string) (missing, extra []string) { w := map[string]bool{} for _, v := range want { w[strings.ToLower(strings.TrimSuffix(v, "."))] = true } g := map[string]bool{} for _, v := range got { g[strings.ToLower(strings.TrimSuffix(v, "."))] = true } for k := range w { if !g[k] { missing = append(missing, k) } } for k := range g { if !w[k] { extra = append(extra, k) } } sort.Strings(missing) sort.Strings(extra) return } // RFC 1982 serial-number arithmetic (handles wraparound). func serialLess(a, b uint32) bool { diff := b - a return diff != 0 && diff < (1<<31) }