422 lines
10 KiB
Go
422 lines
10 KiB
Go
package checker
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/miekg/dns"
|
||
|
||
sdk "git.happydns.org/checker-sdk-go/checker"
|
||
)
|
||
|
||
// Collect gathers raw DNS answers from each selected public resolver plus the
|
||
// zone's own authoritative ground-truth. It performs no judgement: rules
|
||
// derive consensus, drift, splits, latency, and DNSSEC verdicts from the
|
||
// observation.
|
||
func (p *resolverPropagationProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
|
||
svc, err := loadService(opts)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
zone, err := loadZone(opts, svc)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
includeFiltered := sdk.GetBoolOption(opts, "includeFiltered", false)
|
||
region := getStringOpt(opts, "region", "all")
|
||
transportsOpt := getStringOpt(opts, "transports", "udp")
|
||
recordTypesOpt := getStringOpt(opts, "recordTypes", "SOA,NS,A,AAAA,MX,TXT,CAA")
|
||
subdomainsOpt := getStringOpt(opts, "subdomains", "")
|
||
runTimeoutS := sdk.GetIntOption(opts, "runTimeoutSeconds", 30)
|
||
allowlistOpt := getStringOpt(opts, "resolverAllowlist", "")
|
||
|
||
// Parse options.
|
||
transports := parseCSV(transportsOpt)
|
||
if len(transports) == 0 {
|
||
transports = []string{string(TransportUDP)}
|
||
}
|
||
qtypes := parseQTypes(recordTypesOpt)
|
||
if len(qtypes) == 0 {
|
||
return nil, fmt.Errorf("no valid record types in %q", recordTypesOpt)
|
||
}
|
||
extraNames := parseCSV(subdomainsOpt)
|
||
allowlist := parseCSV(allowlistOpt)
|
||
|
||
// Build the list of owner names to probe.
|
||
names := []string{dns.Fqdn(zone)}
|
||
seenName := map[string]bool{names[0]: true}
|
||
for _, sd := range extraNames {
|
||
full := joinSubdomain(sd, zone)
|
||
if !seenName[full] {
|
||
seenName[full] = true
|
||
names = append(names, full)
|
||
}
|
||
}
|
||
|
||
resolvers := selectedResolvers(includeFiltered, region, allowlist)
|
||
|
||
data := &ResolverPropagationData{
|
||
Zone: dns.Fqdn(zone),
|
||
Names: names,
|
||
Types: qtypeNames(qtypes),
|
||
Resolvers: map[string]*ResolverView{},
|
||
RRsets: map[string]*RRsetView{},
|
||
}
|
||
if svc.SOA != nil {
|
||
data.DeclaredSerial = svc.SOA.Serial
|
||
}
|
||
|
||
// If the selection matches no resolvers, simply return the (empty)
|
||
// payload. Rules classify "no resolvers matched" as their own concern.
|
||
if len(resolvers) == 0 {
|
||
data.Stats = computeBasicStats(data)
|
||
return data, nil
|
||
}
|
||
|
||
runCtx, cancel := context.WithTimeout(ctx, time.Duration(runTimeoutS)*time.Second)
|
||
defer cancel()
|
||
|
||
started := time.Now()
|
||
|
||
// Ground truth from the zone's own authoritative servers.
|
||
expected := collectExpected(runCtx, zone, svc, names, qtypes)
|
||
|
||
for _, n := range names {
|
||
for _, qt := range qtypes {
|
||
key := rrsetKey(n, dns.TypeToString[qt])
|
||
v := &RRsetView{
|
||
Name: strings.ToLower(dns.Fqdn(n)),
|
||
Type: dns.TypeToString[qt],
|
||
}
|
||
if e, ok := expected[key]; ok {
|
||
v.Expected = e.sig
|
||
v.ExpectedRecords = e.records
|
||
}
|
||
data.RRsets[key] = v
|
||
}
|
||
}
|
||
|
||
// Fan out probes across resolvers × transports × RRsets.
|
||
type probeJob struct {
|
||
r Resolver
|
||
tr Transport
|
||
}
|
||
var jobs []probeJob
|
||
for _, r := range resolvers {
|
||
for _, tname := range transports {
|
||
tr := Transport(strings.ToLower(strings.TrimSpace(tname)))
|
||
switch tr {
|
||
case TransportUDP, TransportTCP:
|
||
jobs = append(jobs, probeJob{r: r, tr: tr})
|
||
case TransportDoT:
|
||
if r.DoTHost != "" {
|
||
jobs = append(jobs, probeJob{r: r, tr: tr})
|
||
}
|
||
case TransportDoH:
|
||
if r.DoHURL != "" {
|
||
jobs = append(jobs, probeJob{r: r, tr: tr})
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
const maxConcurrent = 32
|
||
sem := make(chan struct{}, maxConcurrent)
|
||
|
||
var wg sync.WaitGroup
|
||
var mu sync.Mutex
|
||
for _, job := range jobs {
|
||
job := job
|
||
wg.Add(1)
|
||
sem <- struct{}{}
|
||
go func() {
|
||
defer wg.Done()
|
||
defer func() { <-sem }()
|
||
|
||
rid := job.r.ID
|
||
if job.tr != TransportUDP {
|
||
rid = fmt.Sprintf("%s|%s", job.r.ID, job.tr)
|
||
}
|
||
|
||
view := &ResolverView{
|
||
ID: rid,
|
||
Name: job.r.Name,
|
||
IP: job.r.IP,
|
||
Region: job.r.Region,
|
||
Filtered: job.r.Filtered,
|
||
Transport: job.tr,
|
||
Probes: map[string]*RRProbe{},
|
||
}
|
||
|
||
for _, n := range names {
|
||
for _, qt := range qtypes {
|
||
probe := runProbe(runCtx, job.r, job.tr, n, qt)
|
||
key := rrsetKey(n, dns.TypeToString[qt])
|
||
view.Probes[key] = probe
|
||
if probe.Error == "" {
|
||
view.Reachable = true
|
||
}
|
||
}
|
||
}
|
||
|
||
mu.Lock()
|
||
data.Resolvers[rid] = view
|
||
mu.Unlock()
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
|
||
data.RunDurationMs = time.Since(started).Milliseconds()
|
||
data.Stats = computeBasicStats(data)
|
||
|
||
return data, nil
|
||
}
|
||
|
||
func runProbe(ctx context.Context, r Resolver, tr Transport, name string, qtype uint16) *RRProbe {
|
||
p := &RRProbe{Transport: tr}
|
||
|
||
res, err := queryResolver(ctx, r, tr, name, qtype)
|
||
if err != nil {
|
||
p.Error = err.Error()
|
||
return p
|
||
}
|
||
p.Rcode = rcodeToString(res.Rcode)
|
||
p.AD = res.AD
|
||
p.LatencyMs = res.Latency.Milliseconds()
|
||
|
||
if res.Rcode == dns.RcodeSuccess {
|
||
sig, recs, ttl := signatureFromRRs(res.Answer, name, qtype)
|
||
p.Signature = sig
|
||
p.Records = recs
|
||
p.MinTTL = ttl
|
||
}
|
||
return p
|
||
}
|
||
|
||
type (
	// expectedEntry is the authoritative ground truth for one (name, type)
	// RRset. queryAuthoritative fills it from signatureFromRRs, and returns a
	// zero-value entry when the authoritative reply had a non-NOERROR rcode.
	expectedEntry struct {
		sig     string   // signature string returned by signatureFromRRs
		records []string // record strings returned by signatureFromRRs
	}
)
|
||
|
||
func collectExpected(ctx context.Context, zone string, svc *originService, names []string, qtypes []uint16) map[string]*expectedEntry {
|
||
out := map[string]*expectedEntry{}
|
||
|
||
var nsHosts []string
|
||
for _, n := range svc.NameServers {
|
||
if n == nil {
|
||
continue
|
||
}
|
||
nsHosts = append(nsHosts, strings.ToLower(dns.Fqdn(n.Ns)))
|
||
}
|
||
if len(nsHosts) == 0 {
|
||
var resolver net.Resolver
|
||
nss, err := resolver.LookupNS(ctx, strings.TrimSuffix(zone, "."))
|
||
if err != nil {
|
||
log.Printf("collectExpected: NS lookup failed for %q: %v", zone, err)
|
||
return out
|
||
}
|
||
for _, ns := range nss {
|
||
nsHosts = append(nsHosts, strings.ToLower(dns.Fqdn(ns.Host)))
|
||
}
|
||
}
|
||
|
||
var resolver net.Resolver
|
||
var authAddrs []string
|
||
for _, ns := range nsHosts {
|
||
addrs, err := resolver.LookupHost(ctx, strings.TrimSuffix(ns, "."))
|
||
if err != nil {
|
||
continue
|
||
}
|
||
for _, a := range addrs {
|
||
authAddrs = append(authAddrs, net.JoinHostPort(a, "53"))
|
||
}
|
||
}
|
||
if len(authAddrs) == 0 {
|
||
return out
|
||
}
|
||
|
||
for _, n := range names {
|
||
for _, qt := range qtypes {
|
||
key := rrsetKey(n, dns.TypeToString[qt])
|
||
if e := queryAuthoritative(ctx, authAddrs, n, qt); e != nil {
|
||
out[key] = e
|
||
}
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func queryAuthoritative(ctx context.Context, servers []string, name string, qtype uint16) *expectedEntry {
|
||
q := dns.Question{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET}
|
||
m := new(dns.Msg)
|
||
m.Id = dns.Id()
|
||
m.Question = []dns.Question{q}
|
||
m.RecursionDesired = false
|
||
m.SetEdns0(ednsUDPSize, false)
|
||
|
||
client := dns.Client{Timeout: dnsTimeout}
|
||
for _, srv := range servers {
|
||
r, _, err := client.ExchangeContext(ctx, m, srv)
|
||
if err != nil || r == nil {
|
||
continue
|
||
}
|
||
if !r.Authoritative {
|
||
continue
|
||
}
|
||
if r.Rcode != dns.RcodeSuccess {
|
||
return &expectedEntry{}
|
||
}
|
||
sig, recs, _ := signatureFromRRs(r.Answer, name, qtype)
|
||
return &expectedEntry{sig: sig, records: recs}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// computeBasicStats returns the raw rollup that Collect can produce without
|
||
// judgement: simple counts. "Agreement" (UnfilteredAgreeing) is a derived
|
||
// metric computed by deriveView once consensus has been established.
|
||
func computeBasicStats(data *ResolverPropagationData) Stats {
|
||
s := Stats{TotalResolvers: len(data.Resolvers)}
|
||
regions := map[string]bool{}
|
||
for _, rv := range data.Resolvers {
|
||
if rv.Reachable {
|
||
s.ReachableResolvers++
|
||
}
|
||
if rv.Filtered {
|
||
s.FilteredProbed++
|
||
} else {
|
||
s.UnfilteredProbed++
|
||
}
|
||
regions[rv.Region] = true
|
||
}
|
||
s.CountriesCovered = len(regions)
|
||
return s
|
||
}
|
||
|
||
func loadService(opts sdk.CheckerOptions) (*originService, error) {
|
||
svc, ok := sdk.GetOption[serviceMessage](opts, "service")
|
||
if !ok {
|
||
// Standalone / interactive use: no service was attached. Fall back
|
||
// to an empty payload; collectExpected will look up NS via the
|
||
// system resolver.
|
||
return &originService{}, nil
|
||
}
|
||
switch svc.Type {
|
||
case "", "abstract.Origin", "abstract.NSOnlyOrigin":
|
||
default:
|
||
return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", svc.Type)
|
||
}
|
||
var d originService
|
||
if err := json.Unmarshal(svc.Service, &d); err != nil {
|
||
return nil, fmt.Errorf("decoding origin service: %w", err)
|
||
}
|
||
return &d, nil
|
||
}
|
||
|
||
func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) {
|
||
if v, ok := sdk.GetOption[string](opts, "domain_name"); ok && v != "" {
|
||
return dns.Fqdn(v), nil
|
||
}
|
||
if svc.SOA != nil && svc.SOA.Header().Name != "" {
|
||
return dns.Fqdn(svc.SOA.Header().Name), nil
|
||
}
|
||
return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)")
|
||
}
|
||
|
||
func getStringOpt(opts sdk.CheckerOptions, key, dflt string) string {
|
||
if v, ok := sdk.GetOption[string](opts, key); ok && v != "" {
|
||
return v
|
||
}
|
||
return dflt
|
||
}
|
||
|
||
// parseCSV splits a comma-separated option value into trimmed, non-empty
// items. An empty input yields nil. An input made only of separators and
// whitespace yields a non-nil empty slice.
func parseCSV(s string) []string {
	if len(s) == 0 {
		return nil
	}
	fields := strings.Split(s, ",")
	result := make([]string, 0, len(fields))
	for i := range fields {
		if item := strings.TrimSpace(fields[i]); len(item) > 0 {
			result = append(result, item)
		}
	}
	return result
}
|
||
|
||
func parseQTypes(s string) []uint16 {
|
||
seen := map[uint16]bool{}
|
||
var out []uint16
|
||
for _, t := range parseCSV(s) {
|
||
if q, ok := dns.StringToType[strings.ToUpper(t)]; ok && !seen[q] {
|
||
seen[q] = true
|
||
out = append(out, q)
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func qtypeNames(qtypes []uint16) []string {
|
||
out := make([]string, len(qtypes))
|
||
for i, q := range qtypes {
|
||
out[i] = dns.TypeToString[q]
|
||
}
|
||
return out
|
||
}
|
||
|
||
func joinSubdomain(sd, zone string) string {
|
||
sd = strings.TrimSpace(sd)
|
||
zone = dns.Fqdn(zone)
|
||
if sd == "" || sd == "@" {
|
||
return zone
|
||
}
|
||
if strings.HasSuffix(sd, ".") {
|
||
return strings.ToLower(sd)
|
||
}
|
||
return strings.ToLower(sd + "." + zone)
|
||
}
|
||
|
||
// extractSerial reads the serial number from the first record string.
//
// It expects at least 7 whitespace-separated fields with the serial as the
// third one, i.e. SOA RDATA in "mname rname serial refresh retry expire
// minimum" order. That is presumably the format produced by signatureFromRRs
// (TODO confirm). Any missing, short, non-numeric or out-of-range input
// yields 0.
func extractSerial(records []string) uint32 {
	if len(records) > 0 {
		if fields := strings.Fields(records[0]); len(fields) >= 7 {
			if serial, err := strconv.ParseUint(fields[2], 10, 32); err == nil {
				return uint32(serial)
			}
		}
	}
	return 0
}
|
||
|
||
// isValidatingResolver reports whether the resolver behind id is on the
// hard-coded list of resolvers whose AD bit is trusted. Any "|transport"
// suffix on id is ignored, so "google|dot" matches "google".
func isValidatingResolver(id string) bool {
	base, _, _ := strings.Cut(id, "|")
	switch base {
	case "cloudflare", "cloudflare-malware", "cloudflare-family",
		"google", "quad9", "quad9-unfiltered",
		"adguard", "adguard-unfiltered", "adguard-family",
		"cleanbrowsing-family", "cleanbrowsing-adult":
		return true
	default:
		return false
	}
}
|
||
|
||
// firstN renders items as a short comma-separated display list. It shows at
// most n entries and summarises the remainder, e.g.
// firstN([]string{"a", "b", "c"}, 2) == "a, b (+1 more)".
//
// A negative n is treated as 0 rather than panicking on the slice
// expression. When no item is shown, only "(+N more)" is returned, without
// a leading space.
func firstN(items []string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(items) <= n {
		return strings.Join(items, ", ")
	}
	more := fmt.Sprintf("(+%d more)", len(items)-n)
	if n == 0 {
		return more
	}
	return strings.Join(items[:n], ", ") + " " + more
}
|