checker-dnssec/checker/dns.go

153 lines
4.3 KiB
Go

package checker
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"net"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
)
// dnsTimeout is the upper bound on a single DNS exchange; dnsExchange lowers
// it further when the caller's context carries an earlier deadline.
const dnsTimeout time.Duration = 5 * time.Second
// dnsExchange sends one query for q to server ("host:port") and returns the
// response.
//
// proto is passed through as dns.Client.Net ("" selects UDP, "tcp" selects
// TCP, and so on). rd sets the RD (recursion desired) bit; pass false when
// querying an authoritative server directly. dnssec sets the EDNS0 DO bit so
// the server includes RRSIG / NSEC[3] records. An EDNS0 UDP buffer size of
// 4096 is always advertised.
//
// The exchange is bounded by dnsTimeout or by ctx's deadline, whichever is
// earlier, and is abandoned if ctx is cancelled. If a UDP response comes back
// truncated (TC bit set), the same query is retried once over TCP so callers
// see the complete answer, including the DNSSEC records that often push a
// response past the UDP size limit.
func dnsExchange(ctx context.Context, proto, server string, q dns.Question, rd, dnssec bool) (*dns.Msg, error) {
	// Fail fast on a cancelled or already-expired context; otherwise an elapsed
	// deadline would silently fall back to the full dnsTimeout.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	client := dns.Client{Net: proto, Timeout: dnsTimeout}
	if deadline, ok := ctx.Deadline(); ok {
		if d := time.Until(deadline); d > 0 && d < client.Timeout {
			client.Timeout = d
		}
	}

	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{q}
	m.RecursionDesired = rd
	m.SetEdns0(4096, dnssec)

	// exchange reads client at call time, so changing client.Net below
	// switches the transport for the retry.
	exchange := func() (*dns.Msg, error) {
		r, _, err := client.ExchangeContext(ctx, m, server)
		if err != nil {
			return nil, err
		}
		if r == nil {
			return nil, fmt.Errorf("nil response from %s", server)
		}
		return r, nil
	}

	r, err := exchange()
	if err != nil {
		return nil, err
	}
	if r.Truncated && (proto == "" || strings.HasPrefix(proto, "udp")) {
		// Map "" / "udp" -> "tcp", "udp4" -> "tcp4", "udp6" -> "tcp6".
		client.Net = "tcp" + strings.TrimPrefix(proto, "udp")
		return exchange()
	}
	return r, nil
}
// recursiveExchange sends q to the recursive resolver at server ("host:port")
// using dns.Client's default transport with the RD bit set. dnssec controls
// whether the EDNS0 DO bit is set on the query.
func recursiveExchange(ctx context.Context, server string, q dns.Question, dnssec bool) (*dns.Msg, error) {
	const (
		defaultProto     = "" // dns.Client default transport (UDP)
		recursionDesired = true
	)
	return dnsExchange(ctx, defaultProto, server, q, recursionDesired, dnssec)
}
// systemResolver returns "host:port" for the first nameserver listed in
// /etc/resolv.conf, using the port from the parsed configuration. When the
// file cannot be read or parsed, or lists no nameservers, it falls back to
// 1.1.1.1:53.
func systemResolver() string {
	cfg, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	if err == nil && len(cfg.Servers) > 0 {
		return net.JoinHostPort(cfg.Servers[0], cfg.Port)
	}
	return net.JoinHostPort("1.1.1.1", "53")
}
// hostPort joins host and port into a dialable address. A single trailing
// dot (as in a fully-qualified DNS name) is removed from host first, and IPv6
// literals are bracketed by net.JoinHostPort.
func hostPort(host, port string) string {
	bare := strings.TrimSuffix(host, ".")
	return net.JoinHostPort(bare, port)
}
// lowerFQDN returns name in fully-qualified form (trailing dot ensured via
// dns.Fqdn) and lower-cased, for case-insensitive name comparison.
func lowerFQDN(name string) string {
	qualified := dns.Fqdn(name)
	return strings.ToLower(qualified)
}
// resolveAuthNS returns the authoritative NS host names of zone together with
// the "host:53" addresses they resolve to.
//
// The NS RRset is fetched from resolver with a recursive query (DO bit off).
// Each NS host name is then resolved concurrently through
// net.DefaultResolver.
//
// Host names are lower-cased, stripped of the trailing dot, deduplicated and
// sorted. Each host's addresses are sorted as well. The result therefore does
// not depend on RRset rotation by the resolver, and the per-server section of
// the report stays stable from run to run. addrs is deduplicated and grouped
// by host, following the order of hosts.
//
// err is non-nil when any of these happens:
//   - the NS query fails;
//   - the NS query returns a non-NOERROR rcode;
//   - the answer contains no NS records.
//
// Per-host address lookup failures do not abort collection. They are returned
// as human-readable strings in nsErrors, in hosts order.
func resolveAuthNS(ctx context.Context, zone, resolver string) (hosts []string, addrs []string, nsErrors []string, err error) {
	q := dns.Question{Name: dns.Fqdn(zone), Qtype: dns.TypeNS, Qclass: dns.ClassINET}
	r, err := recursiveExchange(ctx, resolver, q, false)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("NS lookup for %s: %w", zone, err)
	}
	if r.Rcode != dns.RcodeSuccess {
		return nil, nil, nil, fmt.Errorf("NS lookup for %s: rcode %s", zone, dns.RcodeToString[r.Rcode])
	}

	seenHost := make(map[string]struct{})
	for _, rr := range r.Answer {
		ns, ok := rr.(*dns.NS)
		if !ok {
			continue
		}
		h := strings.ToLower(strings.TrimSuffix(ns.Ns, "."))
		if _, dup := seenHost[h]; dup {
			continue
		}
		seenHost[h] = struct{}{}
		hosts = append(hosts, h)
	}
	if len(hosts) == 0 {
		return nil, nil, nil, fmt.Errorf("no NS records for %s", zone)
	}
	// Resolvers may rotate RRset order between queries; sort for a
	// reproducible report.
	sort.Strings(hosts)

	// Each goroutine writes only to its own index, so no locking is needed.
	results := make([][]string, len(hosts))
	errs := make([]string, len(hosts))
	var wg sync.WaitGroup
	wg.Add(len(hosts))
	for i, host := range hosts {
		// Pass i and host explicitly so correctness does not depend on the
		// module's Go version (per-iteration loop variables arrived in 1.22).
		go func(i int, host string) {
			defer wg.Done()
			ips, lookupErr := net.DefaultResolver.LookupHost(ctx, host)
			if lookupErr != nil {
				errs[i] = fmt.Sprintf("address lookup for %s: %v", host, lookupErr)
				return
			}
			sort.Strings(ips)
			out := make([]string, 0, len(ips))
			for _, ip := range ips {
				out = append(out, hostPort(ip, "53"))
			}
			results[i] = out
		}(i, host)
	}
	wg.Wait()

	// Flatten in hosts order, dropping addresses shared by several NS hosts.
	seenAddr := make(map[string]struct{})
	for _, batch := range results {
		for _, a := range batch {
			if _, dup := seenAddr[a]; dup {
				continue
			}
			seenAddr[a] = struct{}{}
			addrs = append(addrs, a)
		}
	}
	for _, e := range errs {
		if e != "" {
			nsErrors = append(nsErrors, e)
		}
	}
	return hosts, addrs, nsErrors, nil
}
// hasParentDS reports whether the bootstrap resolver returns a DS record for
// zone, i.e. whether the parent zone appears to publish a secure delegation
// for it. The query is sent with the RD and DO bits set.
//
// Only DS records whose owner name equals zone (compared case-insensitively)
// count. A DS for some other name that ends up in the answer section is
// therefore not mistaken for zone's own.
//
// The result is a bare bool. Transport errors and non-NOERROR rcodes are
// folded into false, because absence of evidence is the practical fallback
// when the network is glitchy. As a consequence, callers cannot distinguish
// "no DS published" from "lookup failed".
func hasParentDS(ctx context.Context, zone, resolver string) bool {
	owner := dns.Fqdn(zone)
	q := dns.Question{Name: owner, Qtype: dns.TypeDS, Qclass: dns.ClassINET}
	r, err := recursiveExchange(ctx, resolver, q, true)
	if err != nil || r == nil || r.Rcode != dns.RcodeSuccess {
		return false
	}
	for _, rr := range r.Answer {
		if ds, ok := rr.(*dns.DS); ok && strings.EqualFold(ds.Hdr.Name, owner) {
			return true
		}
	}
	return false
}
// randomLabel returns a 32-character lowercase hex string, encoding 16 bytes
// (128 bits) from crypto/rand. It is used as the leftmost label of the
// NXDOMAIN probe name. With 128 bits of entropy the label will, for practical
// purposes, never coincide with a name that already exists in the zone.
//
// NOTE: randomness does not defeat wildcards. If the zone has a wildcard that
// covers the probe name (e.g. *.example.com), the server answers from that
// wildcard instead of returning NXDOMAIN. Callers must not assume the probe
// always yields NXDOMAIN.
func randomLabel() string {
	var b [16]byte
	// The error is deliberately ignored. crypto/rand.Read is documented never
	// to return an error as of Go 1.24. Before that, it failed only when the OS
	// entropy source was unavailable, in which case b is left partly or wholly
	// zero. The result is still a syntactically valid probe label.
	_, _ = rand.Read(b[:])
	return hex.EncodeToString(b[:])
}