checker-authoritative-consi.../checker/collect.go

229 lines
5.3 KiB
Go

package checker
import (
"context"
"encoding/json"
"fmt"
"sort"
"strings"
"sync"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// Collect gathers raw per-nameserver DNS answers for the zone described by
// opts. It deliberately makes no severity or pass/fail decision: rules turn
// the returned *ObservationData into CheckStates.
//
// The probed set is the union of the service's declared NS names and, when
// the "useParentNS" option is enabled (default true), the NS names returned
// by the parent referral. A failed parent lookup is recorded in
// ParentQueryError instead of aborting collection. An error is returned only
// when the service or the zone name cannot be loaded from opts.
func (p *authoritativeConsistencyProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	svc, err := loadService(opts)
	if err != nil {
		return nil, err
	}
	zone, err := loadZone(opts, svc)
	if err != nil {
		return nil, err
	}

	obs := &ObservationData{
		Zone:       dns.Fqdn(zone),
		HasSOA:     svc.SOA != nil,
		DeclaredNS: normalizeNSList(svc.NameServers),
		Results:    map[string]*NSResult{},
	}
	if obs.HasSOA {
		obs.DeclaredSerial = svc.SOA.Serial
	}

	withEDNS := sdk.GetBoolOption(opts, "checkEDNS", true)
	if sdk.GetBoolOption(opts, "useParentNS", true) {
		if parentNS, perr := parentReferral(ctx, obs.Zone); perr == nil {
			obs.ParentNS = parentNS
		} else {
			obs.ParentQueryError = perr.Error()
		}
	}

	obs.Probed = unionStrings(obs.DeclaredNS, obs.ParentNS)
	if len(obs.Probed) == 0 {
		return obs, nil
	}

	// Bound the fan-out: an unbounded Origin NS list would otherwise spawn
	// one goroutine, plus a fresh batch of UDP/TCP sockets, per name.
	const maxConcurrentProbes = 16
	slots := make(chan struct{}, maxConcurrentProbes)
	var (
		wg    sync.WaitGroup
		resMu sync.Mutex
	)
	for _, name := range obs.Probed {
		wg.Add(1)
		slots <- struct{}{} // blocks until a probe slot is free
		go func(ns string) {
			defer func() {
				<-slots
				wg.Done()
			}()
			result := probeNS(ctx, obs.Zone, ns, withEDNS)
			resMu.Lock()
			obs.Results[ns] = result
			resMu.Unlock()
		}(name)
	}
	wg.Wait()
	return obs, nil
}
// probeNS resolves nsName to its addresses. For each address it queries the
// zone's SOA over UDP and TCP, optionally probes EDNS0 support, and reads the
// NS RRset the server returns. Every outcome is recorded on the returned
// NSResult, which is never nil. Failures land in ResolveError or in the
// per-address error list rather than being returned.
//
// A single NS may have several addresses (A and AAAA). Only one SOA is kept
// as the canonical view of the server: the first authoritative (AA) answer
// wins, and a non-authoritative SOA is held only as a fallback until an
// authoritative one arrives. Other addresses contribute only
// reachability/error state. This keeps dual-homed servers from appearing
// twice in the drift matrix while still surfacing IPv4/IPv6-specific
// failures.
func probeNS(ctx context.Context, zone, nsName string, checkEDNS bool) *NSResult {
	res := &NSResult{Name: nsName}
	addrs, err := resolveHost(ctx, nsName)
	if err != nil {
		res.ResolveError = err.Error()
		return res
	}
	if len(addrs) == 0 {
		res.ResolveError = "no A/AAAA records"
		return res
	}
	res.Addresses = addrs

	// soaIsAuthoritative records whether res.SOA came from an AA answer, so a
	// later authoritative reply can replace an earlier non-authoritative one.
	soaIsAuthoritative := false
	for _, addr := range addrs {
		srv := hostPort(addr, "53")
		soa, aa, rtt, qerr := querySOA(ctx, "", srv, zone)
		if qerr != nil {
			res.appendError("UDP %s: %v", addr, qerr)
			continue
		}
		// Latency is that of the first address reachable over UDP. Keying on
		// UDPReachable (not LatencyMs == 0) keeps a sub-millisecond first
		// answer, which rounds to 0 ms, from being overwritten by a slower one.
		if !res.UDPReachable {
			res.LatencyMs = rtt.Milliseconds()
		}
		res.UDPReachable = true
		if aa {
			res.Authoritative = true
		}
		if soa != nil && (res.SOA == nil || (aa && !soaIsAuthoritative)) {
			res.SOA = soa
			res.Serial = soa.Serial
			soaIsAuthoritative = aa
		}
		// The query above (network "") is reported as UDP; repeat it over TCP
		// to check that transport independently.
		if _, _, _, terr := querySOA(ctx, "tcp", srv, zone); terr != nil {
			res.appendError("TCP %s: %v", addr, terr)
		} else {
			res.TCPReachable = true
		}
		if checkEDNS {
			if eerr := probeEDNS0(ctx, srv, zone); eerr != nil {
				res.appendError("EDNS0 %s: %v", addr, eerr)
			} else {
				res.EDNSSupported = true
			}
		}
		// Only the first non-empty NS RRset is kept and lookup errors are not
		// recorded, so skip the extra round trip once one has been captured.
		if len(res.NSRRset) == 0 {
			if nss, nerr := queryNSAt(ctx, srv, zone); nerr == nil {
				sort.Strings(nss)
				res.NSRRset = nss
			}
		}
	}
	return res
}
// loadService reads the "service" option and decodes its JSON payload into
// an originService.
//
// Services typed "abstract.Origin" or "abstract.NSOnlyOrigin" are accepted,
// as is an empty type. An error is returned when the option is absent, the
// type is anything else, or the payload fails to decode.
func loadService(opts sdk.CheckerOptions) (*originService, error) {
	msg, found := sdk.GetOption[serviceMessage](opts, "service")
	if !found {
		return nil, fmt.Errorf("missing 'service' option")
	}
	if t := msg.Type; t != "" && t != "abstract.Origin" && t != "abstract.NSOnlyOrigin" {
		return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", t)
	}
	decoded := new(originService)
	if err := json.Unmarshal(msg.Service, decoded); err != nil {
		return nil, fmt.Errorf("decoding origin service: %w", err)
	}
	return decoded, nil
}
// loadZone determines the zone to probe, returned as a fully-qualified name.
//
// A non-empty "domain_name" option takes precedence. Otherwise the owner
// name of the service's SOA record is used. An error is returned when
// neither source yields a name.
func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) {
	if name, ok := sdk.GetOption[string](opts, "domain_name"); ok && name != "" {
		return dns.Fqdn(name), nil
	}
	if soa := svc.SOA; soa != nil {
		if owner := soa.Header().Name; owner != "" {
			return dns.Fqdn(owner), nil
		}
	}
	return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)")
}
// normalizeNSList maps NS records to a sorted list of lower-cased,
// fully-qualified target host names. Nil records are skipped; duplicate
// names are kept as-is.
func normalizeNSList(ns []*dns.NS) []string {
	names := make([]string, 0, len(ns))
	for _, rr := range ns {
		if rr != nil {
			names = append(names, strings.ToLower(dns.Fqdn(rr.Ns)))
		}
	}
	sort.Strings(names)
	return names
}
// unionStrings returns the sorted, de-duplicated union of a and b. Values
// are compared exactly (no case folding or dot trimming). The result is nil
// when both inputs are empty.
func unionStrings(a, b []string) []string {
	present := make(map[string]struct{}, len(a)+len(b))
	var merged []string
	for _, list := range [][]string{a, b} {
		for _, s := range list {
			if _, dup := present[s]; dup {
				continue
			}
			present[s] = struct{}{}
			merged = append(merged, s)
		}
	}
	sort.Strings(merged)
	return merged
}
// diffStringSets compares two name lists as case-insensitive sets, ignoring
// a single trailing dot. missing holds names in want that are absent from
// got; extra holds names in got that are absent from want. Both are sorted,
// hold the normalized form (lower-cased, trailing dot removed), and are nil
// when empty.
func diffStringSets(want, got []string) (missing, extra []string) {
	canon := func(list []string) map[string]bool {
		set := make(map[string]bool, len(list))
		for _, v := range list {
			set[strings.ToLower(strings.TrimSuffix(v, "."))] = true
		}
		return set
	}
	wantSet, gotSet := canon(want), canon(got)
	for name := range wantSet {
		if !gotSet[name] {
			missing = append(missing, name)
		}
	}
	for name := range gotSet {
		if !wantSet[name] {
			extra = append(extra, name)
		}
	}
	sort.Strings(missing)
	sort.Strings(extra)
	return missing, extra
}
// serialLess reports whether serial a precedes serial b under RFC 1982
// serial-number arithmetic, so comparisons stay correct across uint32
// wraparound. It reports false for equal serials and for the undefined case
// where the two are exactly 2^31 apart.
func serialLess(a, b uint32) bool {
	const half uint32 = 1 << 31
	if a == b {
		return false
	}
	// Unsigned subtraction wraps modulo 2^32, giving the forward distance
	// from a to b.
	return b-a < half
}