checker-resolver-propagation/checker/collect.go

422 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package checker
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// Collect gathers raw DNS answers from each selected public resolver plus the
// zone's own authoritative ground-truth. It performs no judgement: rules
// derive consensus, drift, splits, latency, and DNSSEC verdicts from the
// observation.
//
// It reads these options:
//   - includeFiltered
//   - region
//   - transports
//   - recordTypes
//   - subdomains
//   - runTimeoutSeconds
//   - resolverAllowlist
//
// A non-positive runTimeoutSeconds falls back to the 30-second default.
//
// An error is returned only when the service or zone cannot be loaded, or
// when no valid record type was requested. Individual probe failures are
// recorded on the probes themselves.
func (p *resolverPropagationProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	const defaultRunTimeoutS = 30

	svc, err := loadService(opts)
	if err != nil {
		return nil, err
	}
	zone, err := loadZone(opts, svc)
	if err != nil {
		return nil, err
	}

	includeFiltered := sdk.GetBoolOption(opts, "includeFiltered", false)
	region := getStringOpt(opts, "region", "all")
	transportsOpt := getStringOpt(opts, "transports", "udp")
	recordTypesOpt := getStringOpt(opts, "recordTypes", "SOA,NS,A,AAAA,MX,TXT,CAA")
	subdomainsOpt := getStringOpt(opts, "subdomains", "")
	runTimeoutS := sdk.GetIntOption(opts, "runTimeoutSeconds", defaultRunTimeoutS)
	if runTimeoutS <= 0 {
		// A zero or negative budget would produce an already-expired context
		// and make every probe fail before it starts.
		runTimeoutS = defaultRunTimeoutS
	}
	allowlistOpt := getStringOpt(opts, "resolverAllowlist", "")

	// Parse options.
	transports := parseCSV(transportsOpt)
	if len(transports) == 0 {
		transports = []string{string(TransportUDP)}
	}
	qtypes := parseQTypes(recordTypesOpt)
	if len(qtypes) == 0 {
		return nil, fmt.Errorf("no valid record types in %q", recordTypesOpt)
	}
	extraNames := parseCSV(subdomainsOpt)
	allowlist := parseCSV(allowlistOpt)

	// Normalise and deduplicate the requested transports up front. Otherwise
	// a value like "udp,UDP" would probe every resolver twice under the same
	// resolver ID.
	var trs []Transport
	seenTr := map[Transport]bool{}
	for _, tname := range transports {
		tr := Transport(strings.ToLower(strings.TrimSpace(tname)))
		if !seenTr[tr] {
			seenTr[tr] = true
			trs = append(trs, tr)
		}
	}

	// Build the list of owner names to probe. DNS names are case-insensitive,
	// so duplicates are detected on the lowercased form.
	names := []string{dns.Fqdn(zone)}
	seenName := map[string]bool{strings.ToLower(names[0]): true}
	for _, sd := range extraNames {
		full := joinSubdomain(sd, zone)
		if k := strings.ToLower(full); !seenName[k] {
			seenName[k] = true
			names = append(names, full)
		}
	}

	resolvers := selectedResolvers(includeFiltered, region, allowlist)
	data := &ResolverPropagationData{
		Zone:      dns.Fqdn(zone),
		Names:     names,
		Types:     qtypeNames(qtypes),
		Resolvers: map[string]*ResolverView{},
		RRsets:    map[string]*RRsetView{},
	}
	if svc.SOA != nil {
		data.DeclaredSerial = svc.SOA.Serial
	}

	// If the selection matches no resolvers, simply return the (empty)
	// payload. Rules classify "no resolvers matched" as their own concern.
	if len(resolvers) == 0 {
		data.Stats = computeBasicStats(data)
		return data, nil
	}

	runCtx, cancel := context.WithTimeout(ctx, time.Duration(runTimeoutS)*time.Second)
	defer cancel()
	started := time.Now()

	// Ground truth from the zone's own authoritative servers.
	expected := collectExpected(runCtx, zone, svc, names, qtypes)
	for _, n := range names {
		for _, qt := range qtypes {
			key := rrsetKey(n, dns.TypeToString[qt])
			v := &RRsetView{
				Name: strings.ToLower(dns.Fqdn(n)),
				Type: dns.TypeToString[qt],
			}
			if e, ok := expected[key]; ok {
				v.Expected = e.sig
				v.ExpectedRecords = e.records
			}
			data.RRsets[key] = v
		}
	}

	// Build one job per (resolver, transport) pair. Each job probes every
	// RRset sequentially. Unknown transports are ignored, and DoT/DoH are
	// used only when the resolver advertises an endpoint for them.
	type probeJob struct {
		r  Resolver
		tr Transport
	}
	var jobs []probeJob
	for _, r := range resolvers {
		for _, tr := range trs {
			switch tr {
			case TransportUDP, TransportTCP:
				jobs = append(jobs, probeJob{r: r, tr: tr})
			case TransportDoT:
				if r.DoTHost != "" {
					jobs = append(jobs, probeJob{r: r, tr: tr})
				}
			case TransportDoH:
				if r.DoHURL != "" {
					jobs = append(jobs, probeJob{r: r, tr: tr})
				}
			}
		}
	}

	// At most maxConcurrent jobs run at once; the semaphore is acquired
	// before each goroutine starts.
	const maxConcurrent = 32
	sem := make(chan struct{}, maxConcurrent)
	var wg sync.WaitGroup
	var mu sync.Mutex
	for _, job := range jobs {
		job := job
		wg.Add(1)
		sem <- struct{}{}
		go func() {
			defer wg.Done()
			defer func() { <-sem }()
			// UDP views keep the bare resolver ID. Other transports get a
			// "|transport" suffix so that they do not collide in the map.
			rid := job.r.ID
			if job.tr != TransportUDP {
				rid = fmt.Sprintf("%s|%s", job.r.ID, job.tr)
			}
			view := &ResolverView{
				ID:        rid,
				Name:      job.r.Name,
				IP:        job.r.IP,
				Region:    job.r.Region,
				Filtered:  job.r.Filtered,
				Transport: job.tr,
				Probes:    map[string]*RRProbe{},
			}
			for _, n := range names {
				for _, qt := range qtypes {
					probe := runProbe(runCtx, job.r, job.tr, n, qt)
					key := rrsetKey(n, dns.TypeToString[qt])
					view.Probes[key] = probe
					if probe.Error == "" {
						view.Reachable = true
					}
				}
			}
			mu.Lock()
			data.Resolvers[rid] = view
			mu.Unlock()
		}()
	}
	wg.Wait()

	data.RunDurationMs = time.Since(started).Milliseconds()
	data.Stats = computeBasicStats(data)
	return data, nil
}
// runProbe sends one query for name/qtype to resolver r over transport tr
// and captures the outcome as an RRProbe.
//
// If queryResolver fails, only Transport and Error are set. Otherwise the
// rcode, AD bit and latency are recorded. For NOERROR replies, it also
// records the signature, records and TTL that signatureFromRRs derives from
// the answer section; the TTL is stored as MinTTL.
func runProbe(ctx context.Context, r Resolver, tr Transport, name string, qtype uint16) *RRProbe {
	res, err := queryResolver(ctx, r, tr, name, qtype)
	if err != nil {
		return &RRProbe{Transport: tr, Error: err.Error()}
	}
	out := &RRProbe{
		Transport: tr,
		Rcode:     rcodeToString(res.Rcode),
		AD:        res.AD,
		LatencyMs: res.Latency.Milliseconds(),
	}
	if res.Rcode != dns.RcodeSuccess {
		return out
	}
	out.Signature, out.Records, out.MinTTL = signatureFromRRs(res.Answer, name, qtype)
	return out
}
// expectedEntry holds the ground-truth answer for one (name, type) pair as
// obtained from the zone's authoritative servers by queryAuthoritative.
type expectedEntry struct {
	// sig is the answer signature computed by signatureFromRRs. It is left
	// empty when the authoritative reply carried an error rcode instead of
	// an answer.
	sig string
	// records are the individual answer records as returned by
	// signatureFromRRs.
	records []string
}
// collectExpected queries the zone's authoritative servers directly for every
// (name, qtype) pair. It returns the answers keyed by rrsetKey.
//
// Nameserver hosts come from svc.NameServers. When there are none, they come
// from an NS lookup through the system resolver. Hosts and resolved server
// addresses are deduplicated, so that shared addresses are not retried
// repeatedly.
//
// The function is best-effort. An NS lookup failure is logged, and
// unresolvable hosts are skipped. In both cases the result is simply a
// smaller (possibly empty) map; no error is returned. Once ctx is done, no
// further queries are issued.
func collectExpected(ctx context.Context, zone string, svc *originService, names []string, qtypes []uint16) map[string]*expectedEntry {
	out := map[string]*expectedEntry{}
	var resolver net.Resolver

	// Distinct, lowercased, fully-qualified NS host names.
	var nsHosts []string
	seenHost := map[string]bool{}
	addHost := func(h string) {
		h = strings.ToLower(dns.Fqdn(h))
		if !seenHost[h] {
			seenHost[h] = true
			nsHosts = append(nsHosts, h)
		}
	}
	for _, n := range svc.NameServers {
		if n == nil {
			continue
		}
		addHost(n.Ns)
	}
	if len(nsHosts) == 0 {
		nss, err := resolver.LookupNS(ctx, strings.TrimSuffix(zone, "."))
		if err != nil {
			log.Printf("collectExpected: NS lookup failed for %q: %v", zone, err)
			return out
		}
		for _, ns := range nss {
			addHost(ns.Host)
		}
	}

	// Distinct host:53 addresses of those nameservers.
	var authAddrs []string
	seenAddr := map[string]bool{}
	for _, ns := range nsHosts {
		addrs, err := resolver.LookupHost(ctx, strings.TrimSuffix(ns, "."))
		if err != nil {
			// Best-effort: an unresolvable NS host is skipped.
			continue
		}
		for _, a := range addrs {
			hp := net.JoinHostPort(a, "53")
			if !seenAddr[hp] {
				seenAddr[hp] = true
				authAddrs = append(authAddrs, hp)
			}
		}
	}
	if len(authAddrs) == 0 {
		return out
	}

	for _, n := range names {
		for _, qt := range qtypes {
			if ctx.Err() != nil {
				// The run budget is exhausted, so every remaining exchange
				// would fail anyway.
				return out
			}
			key := rrsetKey(n, dns.TypeToString[qt])
			if e := queryAuthoritative(ctx, authAddrs, n, qt); e != nil {
				out[key] = e
			}
		}
	}
	return out
}
// queryAuthoritative asks each server in servers, in order, for name/qtype,
// with recursion disabled and EDNS0 enabled. It returns the first usable
// authoritative (AA) answer as an expectedEntry.
//
// Transport errors, non-authoritative replies and error rcodes other than
// NXDOMAIN (for example SERVFAIL or REFUSED) move on to the next server. This
// prevents one misbehaving nameserver from defining the ground truth.
//
// NXDOMAIN yields an empty entry. A truncated (TC) UDP reply is retried over
// TCP, so that a partial answer is never recorded as ground truth.
//
// It returns nil when no server produced a usable answer, or when ctx is
// done.
func queryAuthoritative(ctx context.Context, servers []string, name string, qtype uint16) *expectedEntry {
	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET}}
	m.RecursionDesired = false
	m.SetEdns0(ednsUDPSize, false)

	udp := dns.Client{Timeout: dnsTimeout}
	tcp := dns.Client{Net: "tcp", Timeout: dnsTimeout}
	for _, srv := range servers {
		if ctx.Err() != nil {
			return nil
		}
		r, _, err := udp.ExchangeContext(ctx, m, srv)
		if err == nil && r != nil && r.Truncated {
			// The UDP answer is incomplete; fetch the full RRset over TCP.
			r, _, err = tcp.ExchangeContext(ctx, m, srv)
		}
		if err != nil || r == nil || !r.Authoritative {
			continue
		}
		switch r.Rcode {
		case dns.RcodeSuccess:
			sig, recs, _ := signatureFromRRs(r.Answer, name, qtype)
			return &expectedEntry{sig: sig, records: recs}
		case dns.RcodeNameError:
			// An authoritative NXDOMAIN: the name does not exist.
			return &expectedEntry{}
		}
		// Any other rcode: try the next nameserver.
	}
	return nil
}
// computeBasicStats tallies the raw counts Collect can report without any
// judgement:
//   - the total number of resolver views
//   - how many views were reachable
//   - how many views were filtered and how many unfiltered
//   - the number of distinct Region values seen (stored as CountriesCovered)
//
// The derived "Agreement" figure (UnfilteredAgreeing) is not computed here;
// per the original design it is filled in later by deriveView, once
// consensus is known.
func computeBasicStats(data *ResolverPropagationData) Stats {
	stats := Stats{TotalResolvers: len(data.Resolvers)}
	seenRegions := make(map[string]struct{})
	for _, view := range data.Resolvers {
		seenRegions[view.Region] = struct{}{}
		if view.Reachable {
			stats.ReachableResolvers++
		}
		switch {
		case view.Filtered:
			stats.FilteredProbed++
		default:
			stats.UnfilteredProbed++
		}
	}
	stats.CountriesCovered = len(seenRegions)
	return stats
}
// loadService decodes the "service" option into an originService.
//
// When no service is attached (standalone or interactive use), an empty
// originService is returned, so that collectExpected falls back to an NS
// lookup through the system resolver.
//
// Only payloads typed abstract.Origin or abstract.NSOnlyOrigin, or untyped
// payloads, are accepted. Any other type, or a JSON decoding failure, is
// reported as an error.
func loadService(opts sdk.CheckerOptions) (*originService, error) {
	msg, found := sdk.GetOption[serviceMessage](opts, "service")
	if !found {
		return &originService{}, nil
	}
	if t := msg.Type; t != "" && t != "abstract.Origin" && t != "abstract.NSOnlyOrigin" {
		return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", msg.Type)
	}
	decoded := new(originService)
	if err := json.Unmarshal(msg.Service, decoded); err != nil {
		return nil, fmt.Errorf("decoding origin service: %w", err)
	}
	return decoded, nil
}
// loadZone determines the zone apex to check and returns it fully qualified.
//
// A non-empty "domain_name" option takes precedence. Otherwise the owner name
// of the service's SOA record is used. An error is returned when neither is
// available.
func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) {
	var name string
	if v, ok := sdk.GetOption[string](opts, "domain_name"); ok {
		name = v
	}
	if name == "" && svc.SOA != nil {
		name = svc.SOA.Header().Name
	}
	if name == "" {
		return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)")
	}
	return dns.Fqdn(name), nil
}
// getStringOpt returns the string option stored under key. It falls back to
// dflt when sdk.GetOption reports the option as absent, or when the value is
// the empty string.
func getStringOpt(opts sdk.CheckerOptions, key, dflt string) string {
	v, ok := sdk.GetOption[string](opts, key)
	if !ok || v == "" {
		return dflt
	}
	return v
}
// parseCSV splits a comma-separated option value into its items, with
// surrounding whitespace trimmed and empty items dropped.
//
// An empty input yields nil. Any other input yields a non-nil slice, which
// may have length zero (for example " , ").
func parseCSV(s string) []string {
	if len(s) == 0 {
		return nil
	}
	fields := strings.FieldsFunc(s, func(r rune) bool { return r == ',' })
	items := make([]string, 0, len(fields))
	for _, f := range fields {
		if t := strings.TrimSpace(f); t != "" {
			items = append(items, t)
		}
	}
	return items
}
// parseQTypes converts a comma-separated list of RR type mnemonics into their
// numeric codes, in order of first appearance. Matching is case-insensitive,
// so "a,AAAA,mx" is accepted. Unknown mnemonics and repeated types are
// dropped silently.
func parseQTypes(s string) []uint16 {
	var codes []uint16
	have := make(map[uint16]struct{})
	for _, mnemonic := range parseCSV(s) {
		code, known := dns.StringToType[strings.ToUpper(mnemonic)]
		if !known {
			continue
		}
		if _, dup := have[code]; dup {
			continue
		}
		have[code] = struct{}{}
		codes = append(codes, code)
	}
	return codes
}
// qtypeNames maps numeric RR type codes back to their mnemonics, preserving
// order. A code missing from dns.TypeToString maps to "". The result is
// always non-nil.
func qtypeNames(qtypes []uint16) []string {
	names := make([]string, 0, len(qtypes))
	for _, code := range qtypes {
		names = append(names, dns.TypeToString[code])
	}
	return names
}
// joinSubdomain turns a user-supplied subdomain into an owner name under
// zone.
//
//   - "" or "@" (after trimming) denotes the apex. It returns dns.Fqdn(zone)
//     with the zone's original case.
//   - A value ending in "." is treated as already absolute and is
//     lowercased.
//   - Anything else is prefixed to the zone and the result is lowercased.
func joinSubdomain(sd, zone string) string {
	label := strings.TrimSpace(sd)
	apex := dns.Fqdn(zone)
	switch {
	case label == "" || label == "@":
		return apex
	case strings.HasSuffix(label, "."):
		return strings.ToLower(label)
	default:
		return strings.ToLower(label + "." + apex)
	}
}
// extractSerial returns the third whitespace-separated field of the first
// record, parsed as a uint32. It presumably receives SOA RDATA in
// presentation form ("mname rname serial refresh retry expire minimum"); this
// is inferred from the 7-field check and should be confirmed against
// signatureFromRRs.
//
// It returns 0 when records is empty, when the first record has fewer than
// seven fields, or when the field is not a valid uint32.
func extractSerial(records []string) uint32 {
	if len(records) > 0 {
		if f := strings.Fields(records[0]); len(f) >= 7 {
			if serial, err := strconv.ParseUint(f[2], 10, 32); err == nil {
				return uint32(serial)
			}
		}
	}
	return 0
}
// isValidatingResolver reports whether the resolver identified by id is on
// the hard-coded list of resolvers whose AD bit is treated as trustworthy.
// Any "|transport" suffix on id is ignored, so "google|dot" matches
// "google".
func isValidatingResolver(id string) bool {
	base, _, _ := strings.Cut(id, "|")
	switch base {
	case "cloudflare", "cloudflare-malware", "cloudflare-family",
		"google", "quad9", "quad9-unfiltered",
		"adguard", "adguard-unfiltered", "adguard-family",
		"cleanbrowsing-family", "cleanbrowsing-adult":
		return true
	default:
		return false
	}
}
// firstN renders items as a short display list: the first n entries joined
// by ", ", followed by " (+K more)" when K further items were omitted.
//
// A negative n is treated as 0 rather than panicking on the slice
// expression. With n == 0 only the "(+K more)" marker is returned, without a
// dangling leading space. When len(items) <= n, all items are joined.
func firstN(items []string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(items) <= n {
		return strings.Join(items, ", ")
	}
	more := fmt.Sprintf("(+%d more)", len(items)-n)
	if n == 0 {
		return more
	}
	return strings.Join(items[:n], ", ") + " " + more
}