package checker import ( "context" "encoding/json" "fmt" "log" "net" "strconv" "strings" "sync" "time" "github.com/miekg/dns" sdk "git.happydns.org/checker-sdk-go/checker" ) // Collect gathers raw DNS answers from each selected public resolver plus the // zone's own authoritative ground-truth. It performs no judgement: rules // derive consensus, drift, splits, latency, and DNSSEC verdicts from the // observation. func (p *resolverPropagationProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) { svc, err := loadService(opts) if err != nil { return nil, err } zone, err := loadZone(opts, svc) if err != nil { return nil, err } includeFiltered := sdk.GetBoolOption(opts, "includeFiltered", false) region := getStringOpt(opts, "region", "all") transportsOpt := getStringOpt(opts, "transports", "udp") recordTypesOpt := getStringOpt(opts, "recordTypes", "SOA,NS,A,AAAA,MX,TXT,CAA") subdomainsOpt := getStringOpt(opts, "subdomains", "") runTimeoutS := sdk.GetIntOption(opts, "runTimeoutSeconds", 30) allowlistOpt := getStringOpt(opts, "resolverAllowlist", "") // Parse options. transports := parseCSV(transportsOpt) if len(transports) == 0 { transports = []string{string(TransportUDP)} } qtypes := parseQTypes(recordTypesOpt) if len(qtypes) == 0 { return nil, fmt.Errorf("no valid record types in %q", recordTypesOpt) } extraNames := parseCSV(subdomainsOpt) allowlist := parseCSV(allowlistOpt) // Build the list of owner names to probe. names := []string{dns.Fqdn(zone)} seenName := map[string]bool{names[0]: true} for _, sd := range extraNames { full := joinSubdomain(sd, zone) if !seenName[full] { seenName[full] = true names = append(names, full) } } resolvers := selectedResolvers(includeFiltered, region, allowlist) data := &ResolverPropagationData{ Zone: dns.Fqdn(zone), Names: names, Types: qtypeNames(qtypes), Resolvers: map[string]*ResolverView{}, RRsets: map[string]*RRsetView{}, } if svc.SOA != nil { data.DeclaredSerial = svc.SOA.Serial } // If the selection matches no resolvers, simply return the (empty) // payload. Rules classify "no resolvers matched" as their own concern. if len(resolvers) == 0 { data.Stats = computeBasicStats(data) return data, nil } runCtx, cancel := context.WithTimeout(ctx, time.Duration(runTimeoutS)*time.Second) defer cancel() started := time.Now() // Ground truth from the zone's own authoritative servers. expected := collectExpected(runCtx, zone, svc, names, qtypes) for _, n := range names { for _, qt := range qtypes { key := rrsetKey(n, dns.TypeToString[qt]) v := &RRsetView{ Name: strings.ToLower(dns.Fqdn(n)), Type: dns.TypeToString[qt], } if e, ok := expected[key]; ok { v.Expected = e.sig v.ExpectedRecords = e.records } data.RRsets[key] = v } } // Fan out probes across resolvers × transports × RRsets. type probeJob struct { r Resolver tr Transport } var jobs []probeJob for _, r := range resolvers { for _, tname := range transports { tr := Transport(strings.ToLower(strings.TrimSpace(tname))) switch tr { case TransportUDP, TransportTCP: jobs = append(jobs, probeJob{r: r, tr: tr}) case TransportDoT: if r.DoTHost != "" { jobs = append(jobs, probeJob{r: r, tr: tr}) } case TransportDoH: if r.DoHURL != "" { jobs = append(jobs, probeJob{r: r, tr: tr}) } } } } const maxConcurrent = 32 sem := make(chan struct{}, maxConcurrent) var wg sync.WaitGroup var mu sync.Mutex for _, job := range jobs { job := job wg.Add(1) sem <- struct{}{} go func() { defer wg.Done() defer func() { <-sem }() rid := job.r.ID if job.tr != TransportUDP { rid = fmt.Sprintf("%s|%s", job.r.ID, job.tr) } view := &ResolverView{ ID: rid, Name: job.r.Name, IP: job.r.IP, Region: job.r.Region, Filtered: job.r.Filtered, Transport: job.tr, Probes: map[string]*RRProbe{}, } for _, n := range names { for _, qt := range qtypes { probe := runProbe(runCtx, job.r, job.tr, n, qt) key := rrsetKey(n, dns.TypeToString[qt]) view.Probes[key] = probe if probe.Error == "" { view.Reachable = true } } } mu.Lock() data.Resolvers[rid] = view mu.Unlock() }() } wg.Wait() data.RunDurationMs = time.Since(started).Milliseconds() data.Stats = computeBasicStats(data) return data, nil } func runProbe(ctx context.Context, r Resolver, tr Transport, name string, qtype uint16) *RRProbe { p := &RRProbe{Transport: tr} res, err := queryResolver(ctx, r, tr, name, qtype) if err != nil { p.Error = err.Error() return p } p.Rcode = rcodeToString(res.Rcode) p.AD = res.AD p.LatencyMs = res.Latency.Milliseconds() if res.Rcode == dns.RcodeSuccess { sig, recs, ttl := signatureFromRRs(res.Answer, name, qtype) p.Signature = sig p.Records = recs p.MinTTL = ttl } return p } type expectedEntry struct { sig string records []string } func collectExpected(ctx context.Context, zone string, svc *originService, names []string, qtypes []uint16) map[string]*expectedEntry { out := map[string]*expectedEntry{} var nsHosts []string for _, n := range svc.NameServers { if n == nil { continue } nsHosts = append(nsHosts, strings.ToLower(dns.Fqdn(n.Ns))) } if len(nsHosts) == 0 { var resolver net.Resolver nss, err := resolver.LookupNS(ctx, strings.TrimSuffix(zone, ".")) if err != nil { log.Printf("collectExpected: NS lookup failed for %q: %v", zone, err) return out } for _, ns := range nss { nsHosts = append(nsHosts, strings.ToLower(dns.Fqdn(ns.Host))) } } var resolver net.Resolver var authAddrs []string for _, ns := range nsHosts { addrs, err := resolver.LookupHost(ctx, strings.TrimSuffix(ns, ".")) if err != nil { continue } for _, a := range addrs { authAddrs = append(authAddrs, net.JoinHostPort(a, "53")) } } if len(authAddrs) == 0 { return out } for _, n := range names { for _, qt := range qtypes { key := rrsetKey(n, dns.TypeToString[qt]) if e := queryAuthoritative(ctx, authAddrs, n, qt); e != nil { out[key] = e } } } return out } func queryAuthoritative(ctx context.Context, servers []string, name string, qtype uint16) *expectedEntry { q := dns.Question{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET} m := new(dns.Msg) m.Id = dns.Id() m.Question = []dns.Question{q} m.RecursionDesired = false m.SetEdns0(ednsUDPSize, false) client := dns.Client{Timeout: dnsTimeout} for _, srv := range servers { r, _, err := client.ExchangeContext(ctx, m, srv) if err != nil || r == nil { continue } if !r.Authoritative { continue } if r.Rcode != dns.RcodeSuccess { return &expectedEntry{} } sig, recs, _ := signatureFromRRs(r.Answer, name, qtype) return &expectedEntry{sig: sig, records: recs} } return nil } // computeBasicStats returns the raw rollup that Collect can produce without // judgement: simple counts. "Agreement" (UnfilteredAgreeing) is a derived // metric computed by deriveView once consensus has been established. func computeBasicStats(data *ResolverPropagationData) Stats { s := Stats{TotalResolvers: len(data.Resolvers)} regions := map[string]bool{} for _, rv := range data.Resolvers { if rv.Reachable { s.ReachableResolvers++ } if rv.Filtered { s.FilteredProbed++ } else { s.UnfilteredProbed++ } regions[rv.Region] = true } s.CountriesCovered = len(regions) return s } func loadService(opts sdk.CheckerOptions) (*originService, error) { svc, ok := sdk.GetOption[serviceMessage](opts, "service") if !ok { // Standalone / interactive use: no service was attached. Fall back // to an empty payload; collectExpected will look up NS via the // system resolver. return &originService{}, nil } switch svc.Type { case "", "abstract.Origin", "abstract.NSOnlyOrigin": default: return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", svc.Type) } var d originService if err := json.Unmarshal(svc.Service, &d); err != nil { return nil, fmt.Errorf("decoding origin service: %w", err) } return &d, nil } func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) { if v, ok := sdk.GetOption[string](opts, "domain_name"); ok && v != "" { return dns.Fqdn(v), nil } if svc.SOA != nil && svc.SOA.Header().Name != "" { return dns.Fqdn(svc.SOA.Header().Name), nil } return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)") } func getStringOpt(opts sdk.CheckerOptions, key, dflt string) string { if v, ok := sdk.GetOption[string](opts, key); ok && v != "" { return v } return dflt } func parseCSV(s string) []string { if s == "" { return nil } parts := strings.Split(s, ",") out := make([]string, 0, len(parts)) for _, p := range parts { p = strings.TrimSpace(p) if p != "" { out = append(out, p) } } return out } func parseQTypes(s string) []uint16 { seen := map[uint16]bool{} var out []uint16 for _, t := range parseCSV(s) { if q, ok := dns.StringToType[strings.ToUpper(t)]; ok && !seen[q] { seen[q] = true out = append(out, q) } } return out } func qtypeNames(qtypes []uint16) []string { out := make([]string, len(qtypes)) for i, q := range qtypes { out[i] = dns.TypeToString[q] } return out } func joinSubdomain(sd, zone string) string { sd = strings.TrimSpace(sd) zone = dns.Fqdn(zone) if sd == "" || sd == "@" { return zone } if strings.HasSuffix(sd, ".") { return strings.ToLower(sd) } return strings.ToLower(sd + "." + zone) } func extractSerial(records []string) uint32 { if len(records) == 0 { return 0 } fields := strings.Fields(records[0]) if len(fields) < 7 { return 0 } s, err := strconv.ParseUint(fields[2], 10, 32) if err != nil { return 0 } return uint32(s) } // Hardcoded allowlist; only these resolvers' AD bit is trustworthy. func isValidatingResolver(id string) bool { switch strings.SplitN(id, "|", 2)[0] { case "cloudflare", "cloudflare-malware", "cloudflare-family", "google", "quad9", "quad9-unfiltered", "adguard", "adguard-unfiltered", "adguard-family", "cleanbrowsing-family", "cleanbrowsing-adult": return true } return false } // firstN returns a short "x, y, z (+N more)" display list. func firstN(items []string, n int) string { if len(items) <= n { return strings.Join(items, ", ") } return strings.Join(items[:n], ", ") + fmt.Sprintf(" (+%d more)", len(items)-n) }