Initial commit

This commit is contained in:
nemunaire 2026-04-26 11:49:13 +07:00
commit 2d98ed1b5d
33 changed files with 4644 additions and 0 deletions

422
checker/collect.go Normal file
View file

@ -0,0 +1,422 @@
package checker
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// Collect gathers raw DNS answers from each selected public resolver plus the
// zone's own authoritative ground-truth. It performs no judgement: rules
// derive consensus, drift, splits, latency, and DNSSEC verdicts from the
// observation.
func (p *resolverPropagationProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
svc, err := loadService(opts)
if err != nil {
return nil, err
}
zone, err := loadZone(opts, svc)
if err != nil {
return nil, err
}
includeFiltered := sdk.GetBoolOption(opts, "includeFiltered", false)
region := getStringOpt(opts, "region", "all")
transportsOpt := getStringOpt(opts, "transports", "udp")
recordTypesOpt := getStringOpt(opts, "recordTypes", "SOA,NS,A,AAAA,MX,TXT,CAA")
subdomainsOpt := getStringOpt(opts, "subdomains", "")
runTimeoutS := sdk.GetIntOption(opts, "runTimeoutSeconds", 30)
allowlistOpt := getStringOpt(opts, "resolverAllowlist", "")
// Parse options.
transports := parseCSV(transportsOpt)
if len(transports) == 0 {
transports = []string{string(TransportUDP)}
}
qtypes := parseQTypes(recordTypesOpt)
if len(qtypes) == 0 {
return nil, fmt.Errorf("no valid record types in %q", recordTypesOpt)
}
extraNames := parseCSV(subdomainsOpt)
allowlist := parseCSV(allowlistOpt)
// Build the list of owner names to probe.
names := []string{dns.Fqdn(zone)}
seenName := map[string]bool{names[0]: true}
for _, sd := range extraNames {
full := joinSubdomain(sd, zone)
if !seenName[full] {
seenName[full] = true
names = append(names, full)
}
}
resolvers := selectedResolvers(includeFiltered, region, allowlist)
data := &ResolverPropagationData{
Zone: dns.Fqdn(zone),
Names: names,
Types: qtypeNames(qtypes),
Resolvers: map[string]*ResolverView{},
RRsets: map[string]*RRsetView{},
}
if svc.SOA != nil {
data.DeclaredSerial = svc.SOA.Serial
}
// If the selection matches no resolvers, simply return the (empty)
// payload. Rules classify "no resolvers matched" as their own concern.
if len(resolvers) == 0 {
data.Stats = computeBasicStats(data)
return data, nil
}
runCtx, cancel := context.WithTimeout(ctx, time.Duration(runTimeoutS)*time.Second)
defer cancel()
started := time.Now()
// Ground truth from the zone's own authoritative servers.
expected := collectExpected(runCtx, zone, svc, names, qtypes)
for _, n := range names {
for _, qt := range qtypes {
key := rrsetKey(n, dns.TypeToString[qt])
v := &RRsetView{
Name: strings.ToLower(dns.Fqdn(n)),
Type: dns.TypeToString[qt],
}
if e, ok := expected[key]; ok {
v.Expected = e.sig
v.ExpectedRecords = e.records
}
data.RRsets[key] = v
}
}
// Fan out probes across resolvers × transports × RRsets.
type probeJob struct {
r Resolver
tr Transport
}
var jobs []probeJob
for _, r := range resolvers {
for _, tname := range transports {
tr := Transport(strings.ToLower(strings.TrimSpace(tname)))
switch tr {
case TransportUDP, TransportTCP:
jobs = append(jobs, probeJob{r: r, tr: tr})
case TransportDoT:
if r.DoTHost != "" {
jobs = append(jobs, probeJob{r: r, tr: tr})
}
case TransportDoH:
if r.DoHURL != "" {
jobs = append(jobs, probeJob{r: r, tr: tr})
}
}
}
}
const maxConcurrent = 32
sem := make(chan struct{}, maxConcurrent)
var wg sync.WaitGroup
var mu sync.Mutex
for _, job := range jobs {
job := job
wg.Add(1)
sem <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-sem }()
rid := job.r.ID
if job.tr != TransportUDP {
rid = fmt.Sprintf("%s|%s", job.r.ID, job.tr)
}
view := &ResolverView{
ID: rid,
Name: job.r.Name,
IP: job.r.IP,
Region: job.r.Region,
Filtered: job.r.Filtered,
Transport: job.tr,
Probes: map[string]*RRProbe{},
}
for _, n := range names {
for _, qt := range qtypes {
probe := runProbe(runCtx, job.r, job.tr, n, qt)
key := rrsetKey(n, dns.TypeToString[qt])
view.Probes[key] = probe
if probe.Error == "" {
view.Reachable = true
}
}
}
mu.Lock()
data.Resolvers[rid] = view
mu.Unlock()
}()
}
wg.Wait()
data.RunDurationMs = time.Since(started).Milliseconds()
data.Stats = computeBasicStats(data)
return data, nil
}
func runProbe(ctx context.Context, r Resolver, tr Transport, name string, qtype uint16) *RRProbe {
p := &RRProbe{Transport: tr}
res, err := queryResolver(ctx, r, tr, name, qtype)
if err != nil {
p.Error = err.Error()
return p
}
p.Rcode = rcodeToString(res.Rcode)
p.AD = res.AD
p.LatencyMs = res.Latency.Milliseconds()
if res.Rcode == dns.RcodeSuccess {
sig, recs, ttl := signatureFromRRs(res.Answer, name, qtype)
p.Signature = sig
p.Records = recs
p.MinTTL = ttl
}
return p
}
type expectedEntry struct {
sig string
records []string
}
func collectExpected(ctx context.Context, zone string, svc *originService, names []string, qtypes []uint16) map[string]*expectedEntry {
out := map[string]*expectedEntry{}
var nsHosts []string
for _, n := range svc.NameServers {
if n == nil {
continue
}
nsHosts = append(nsHosts, strings.ToLower(dns.Fqdn(n.Ns)))
}
if len(nsHosts) == 0 {
var resolver net.Resolver
nss, err := resolver.LookupNS(ctx, strings.TrimSuffix(zone, "."))
if err != nil {
log.Printf("collectExpected: NS lookup failed for %q: %v", zone, err)
return out
}
for _, ns := range nss {
nsHosts = append(nsHosts, strings.ToLower(dns.Fqdn(ns.Host)))
}
}
var resolver net.Resolver
var authAddrs []string
for _, ns := range nsHosts {
addrs, err := resolver.LookupHost(ctx, strings.TrimSuffix(ns, "."))
if err != nil {
continue
}
for _, a := range addrs {
authAddrs = append(authAddrs, net.JoinHostPort(a, "53"))
}
}
if len(authAddrs) == 0 {
return out
}
for _, n := range names {
for _, qt := range qtypes {
key := rrsetKey(n, dns.TypeToString[qt])
if e := queryAuthoritative(ctx, authAddrs, n, qt); e != nil {
out[key] = e
}
}
}
return out
}
func queryAuthoritative(ctx context.Context, servers []string, name string, qtype uint16) *expectedEntry {
q := dns.Question{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET}
m := new(dns.Msg)
m.Id = dns.Id()
m.Question = []dns.Question{q}
m.RecursionDesired = false
m.SetEdns0(ednsUDPSize, false)
client := dns.Client{Timeout: dnsTimeout}
for _, srv := range servers {
r, _, err := client.ExchangeContext(ctx, m, srv)
if err != nil || r == nil {
continue
}
if !r.Authoritative {
continue
}
if r.Rcode != dns.RcodeSuccess {
return &expectedEntry{}
}
sig, recs, _ := signatureFromRRs(r.Answer, name, qtype)
return &expectedEntry{sig: sig, records: recs}
}
return nil
}
// computeBasicStats returns the raw rollup that Collect can produce without
// judgement: simple counts. "Agreement" (UnfilteredAgreeing) is a derived
// metric computed by deriveView once consensus has been established.
func computeBasicStats(data *ResolverPropagationData) Stats {
s := Stats{TotalResolvers: len(data.Resolvers)}
regions := map[string]bool{}
for _, rv := range data.Resolvers {
if rv.Reachable {
s.ReachableResolvers++
}
if rv.Filtered {
s.FilteredProbed++
} else {
s.UnfilteredProbed++
}
regions[rv.Region] = true
}
s.CountriesCovered = len(regions)
return s
}
func loadService(opts sdk.CheckerOptions) (*originService, error) {
svc, ok := sdk.GetOption[serviceMessage](opts, "service")
if !ok {
// Standalone / interactive use: no service was attached. Fall back
// to an empty payload; collectExpected will look up NS via the
// system resolver.
return &originService{}, nil
}
switch svc.Type {
case "", "abstract.Origin", "abstract.NSOnlyOrigin":
default:
return nil, fmt.Errorf("service is %s, expected abstract.Origin or abstract.NSOnlyOrigin", svc.Type)
}
var d originService
if err := json.Unmarshal(svc.Service, &d); err != nil {
return nil, fmt.Errorf("decoding origin service: %w", err)
}
return &d, nil
}
func loadZone(opts sdk.CheckerOptions, svc *originService) (string, error) {
if v, ok := sdk.GetOption[string](opts, "domain_name"); ok && v != "" {
return dns.Fqdn(v), nil
}
if svc.SOA != nil && svc.SOA.Header().Name != "" {
return dns.Fqdn(svc.SOA.Header().Name), nil
}
return "", fmt.Errorf("no zone name provided (missing 'domain_name' option and SOA header)")
}
func getStringOpt(opts sdk.CheckerOptions, key, dflt string) string {
if v, ok := sdk.GetOption[string](opts, key); ok && v != "" {
return v
}
return dflt
}
func parseCSV(s string) []string {
if s == "" {
return nil
}
parts := strings.Split(s, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func parseQTypes(s string) []uint16 {
seen := map[uint16]bool{}
var out []uint16
for _, t := range parseCSV(s) {
if q, ok := dns.StringToType[strings.ToUpper(t)]; ok && !seen[q] {
seen[q] = true
out = append(out, q)
}
}
return out
}
func qtypeNames(qtypes []uint16) []string {
out := make([]string, len(qtypes))
for i, q := range qtypes {
out[i] = dns.TypeToString[q]
}
return out
}
func joinSubdomain(sd, zone string) string {
sd = strings.TrimSpace(sd)
zone = dns.Fqdn(zone)
if sd == "" || sd == "@" {
return zone
}
if strings.HasSuffix(sd, ".") {
return strings.ToLower(sd)
}
return strings.ToLower(sd + "." + zone)
}
func extractSerial(records []string) uint32 {
if len(records) == 0 {
return 0
}
fields := strings.Fields(records[0])
if len(fields) < 7 {
return 0
}
s, err := strconv.ParseUint(fields[2], 10, 32)
if err != nil {
return 0
}
return uint32(s)
}
// Hardcoded allowlist; only these resolvers' AD bit is trustworthy.
func isValidatingResolver(id string) bool {
switch strings.SplitN(id, "|", 2)[0] {
case "cloudflare", "cloudflare-malware", "cloudflare-family",
"google", "quad9", "quad9-unfiltered",
"adguard", "adguard-unfiltered", "adguard-family",
"cleanbrowsing-family", "cleanbrowsing-adult":
return true
}
return false
}
// firstN returns a short "x, y, z (+N more)" display list.
func firstN(items []string, n int) string {
if len(items) <= n {
return strings.Join(items, ", ")
}
return strings.Join(items[:n], ", ") + fmt.Sprintf(" (+%d more)", len(items)-n)
}