checker-resolver-propagation/checker/dns.go

245 lines
6.4 KiB
Go

package checker
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sort"
"strings"
"time"
"github.com/miekg/dns"
)
// Tunables shared by every transport in this file.
const (
	// dnsTimeout bounds a single exchange: a public resolver slower than this is
	// unreachable or too flaky to be useful.
	dnsTimeout = 5 * time.Second

	// ednsUDPSize is the advertised EDNS0 buffer size; 4096 is the de-facto ceiling
	// for unfragmented EDNS0 responses on the public Internet.
	ednsUDPSize = 4096

	// maxDoHResponseBytes bounds DoH body reads so a hostile server can't stream
	// junk indefinitely.
	maxDoHResponseBytes = 64 * 1024
)
// dohClient is the HTTP client used for every DoH exchange. It is shared so
// concurrent probes reuse connections and TLS state.
var dohClient = &http.Client{
	// Overall per-request cap, slightly above the per-phase timeouts below.
	Timeout: dnsTimeout + 2*time.Second,
	Transport: &http.Transport{
		TLSClientConfig: &tls.Config{
			MinVersion: tls.VersionTLS12,
		},
		// Supplying a TLSClientConfig makes net/http conservatively disable HTTP/2.
		// Opt back in: RFC 8484 names HTTP/2 as the minimum recommended version for DoH.
		ForceAttemptHTTP2:     true,
		TLSHandshakeTimeout:   dnsTimeout,
		ResponseHeaderTimeout: dnsTimeout,
		ExpectContinueTimeout: 1 * time.Second,
		DisableKeepAlives:     false,
		MaxIdleConnsPerHost:   4,
	},
}
// queryResult is the outcome of one DNS exchange, whatever transport carried it.
// Flatter than *dns.Msg so the collector stays protocol-agnostic.
type queryResult struct {
	// Rcode is the response code copied from the reply header.
	Rcode int
	// Answer is the answer section of the reply, as returned by the resolver.
	Answer []dns.RR
	// AD is the Authenticated Data flag of the reply.
	AD bool
	// Latency is the measured duration of the exchange: the round-trip time reported
	// by the DNS client, or for DoH the time from sending the request until the body
	// has been read.
	Latency time.Duration
}
// queryResolver sends one question (name, qtype, class IN) to resolver r over the
// transport tr and returns the flattened reply.
//
// The query forces RD=1 (recurse), CD=0 (let the resolver validate DNSSEC) and AD=1
// (signal validation back), and carries an EDNS0 OPT record advertising ednsUDPSize
// with the DO flag set.
//
// An error is returned when tr is not a known transport, when r has no endpoint
// configured for the requested DoT/DoH transport, or when the exchange itself fails.
func queryResolver(ctx context.Context, r Resolver, tr Transport, name string, qtype uint16) (*queryResult, error) {
	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET}}
	m.RecursionDesired = true
	m.CheckingDisabled = false
	m.AuthenticatedData = true
	m.SetEdns0(ednsUDPSize, true)

	switch tr {
	case TransportUDP:
		// JoinHostPort brackets IPv6 literals; plain concatenation with ":53" would
		// produce an address that cannot be dialed for them.
		return exchangeUDPOrTCP(ctx, m, net.JoinHostPort(r.IP, "53"), "udp")
	case TransportTCP:
		return exchangeUDPOrTCP(ctx, m, net.JoinHostPort(r.IP, "53"), "tcp")
	case TransportDoT:
		if r.DoTHost == "" {
			return nil, fmt.Errorf("no DoT endpoint for %s", r.ID)
		}
		return exchangeDoT(ctx, m, r.IP, r.DoTHost)
	case TransportDoH:
		if r.DoHURL == "" {
			return nil, fmt.Errorf("no DoH endpoint for %s", r.ID)
		}
		return exchangeDoH(ctx, m, r.DoHURL)
	default:
		return nil, fmt.Errorf("unknown transport %q", tr)
	}
}
// exchangeUDPOrTCP sends m to server (a "host:port" address) over plain DNS using
// proto ("udp" or "tcp") and returns the flattened reply.
//
// The client timeout is dnsTimeout, shortened to the time remaining on ctx when
// that is smaller. A truncated UDP reply is retried once over TCP; if the retry
// fails, the truncated UDP reply is returned as-is rather than an error.
func exchangeUDPOrTCP(ctx context.Context, m *dns.Msg, server, proto string) (*queryResult, error) {
	timeout := dnsTimeout
	if deadline, ok := ctx.Deadline(); ok {
		if remaining := time.Until(deadline); remaining > 0 && remaining < timeout {
			timeout = remaining
		}
	}
	client := dns.Client{Net: proto, Timeout: timeout}

	resp, rtt, err := client.ExchangeContext(ctx, m, server)
	switch {
	case err != nil:
		return nil, err
	case resp == nil:
		return nil, fmt.Errorf("nil response from %s", server)
	}

	// Truncated UDP answers are retried over TCP (RFC 7766, which obsoletes RFC 5966).
	if proto == "udp" && resp.Truncated {
		retry := dns.Client{Net: "tcp", Timeout: dnsTimeout}
		full, tcpRTT, tcpErr := retry.ExchangeContext(ctx, m, server)
		if tcpErr == nil && full != nil {
			// Prefer the complete TCP reply and report its own round-trip time.
			resp, rtt = full, tcpRTT
		}
	}

	return &queryResult{
		Rcode:   resp.Rcode,
		Answer:  resp.Answer,
		AD:      resp.AuthenticatedData,
		Latency: rtt,
	}, nil
}
// exchangeDoT sends m over DNS-over-TLS to ip on port 853 and returns the
// flattened reply. sni validates the certificate; the IP is what we actually dial.
//
// The client timeout is dnsTimeout, shortened to the time remaining on ctx when
// that is smaller.
func exchangeDoT(ctx context.Context, m *dns.Msg, ip, sni string) (*queryResult, error) {
	timeout := dnsTimeout
	if deadline, ok := ctx.Deadline(); ok {
		if remaining := time.Until(deadline); remaining > 0 && remaining < timeout {
			timeout = remaining
		}
	}

	tlsCfg := &tls.Config{
		ServerName: sni,
		MinVersion: tls.VersionTLS12,
	}
	client := dns.Client{Net: "tcp-tls", Timeout: timeout, TLSConfig: tlsCfg}

	addr := net.JoinHostPort(ip, "853")
	resp, rtt, err := client.ExchangeContext(ctx, m, addr)
	switch {
	case err != nil:
		return nil, err
	case resp == nil:
		return nil, fmt.Errorf("nil response from %s", ip)
	}

	return &queryResult{
		Rcode:   resp.Rcode,
		Answer:  resp.Answer,
		AD:      resp.AuthenticatedData,
		Latency: rtt,
	}, nil
}
// exchangeDoH sends m to the DoH endpoint URL and returns the flattened reply.
// It uses GET (per RFC 8484) so HTTP caches can merge equivalent queries.
//
// Note that m is modified: its Id is reset to 0 before packing. Any query
// parameters already present on endpoint are preserved; the "dns" parameter is
// set to the base64url-encoded (unpadded) wire-format query.
//
// An error is returned when the request cannot be built or sent, the status is
// not 200, the media type is not application/dns-message, the body exceeds
// maxDoHResponseBytes, or the body is not a parseable DNS message. Latency covers
// sending the request through reading the whole body.
func exchangeDoH(ctx context.Context, m *dns.Msg, endpoint string) (*queryResult, error) {
	// Id=0 lets HTTP caches merge equivalent queries.
	m.Id = 0
	packed, err := m.Pack()
	if err != nil {
		return nil, fmt.Errorf("packing message: %w", err)
	}

	u, err := url.Parse(endpoint)
	if err != nil {
		return nil, fmt.Errorf("invalid DoH endpoint %q: %w", endpoint, err)
	}
	q := u.Query()
	q.Set("dns", base64.RawURLEncoding.EncodeToString(packed))
	u.RawQuery = q.Encode()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/dns-message")
	req.Header.Set("User-Agent", "happyDomain-checker-resolver-propagation/"+Version)

	start := time.Now()
	resp, err := dohClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("DoH HTTP %d", resp.StatusCode)
	}

	// Media types are case-insensitive and may carry parameters after ';', so
	// compare only the type/subtype part, ignoring case.
	ct := resp.Header.Get("Content-Type")
	mediaType := ct
	if i := strings.IndexByte(ct, ';'); i >= 0 {
		mediaType = ct[:i]
	}
	if !strings.EqualFold(strings.TrimSpace(mediaType), "application/dns-message") {
		return nil, fmt.Errorf("DoH unexpected content-type %q", ct)
	}

	// Read one byte past the cap so an oversized body is reported explicitly
	// instead of being silently truncated and then handed to the parser.
	var buf bytes.Buffer
	if _, err := io.Copy(&buf, io.LimitReader(resp.Body, maxDoHResponseBytes+1)); err != nil {
		return nil, err
	}
	if buf.Len() > maxDoHResponseBytes {
		return nil, fmt.Errorf("DoH response exceeds %d bytes", maxDoHResponseBytes)
	}
	latency := time.Since(start)

	r := new(dns.Msg)
	if err := r.Unpack(buf.Bytes()); err != nil {
		return nil, fmt.Errorf("unpacking DoH response: %w", err)
	}
	return &queryResult{
		Rcode:   r.Rcode,
		Answer:  r.Answer,
		AD:      r.AuthenticatedData,
		Latency: latency,
	}, nil
}
// canonicalRR reduces rr to a normalized RDATA string for comparison.
//
// It takes miekg's zone-file rendering, drops the leading "owner TTL class type"
// header tokens, joins what remains with single spaces and lowercases the result,
// so case-only drift in hostnames doesn't read as disagreement. Note that the
// whole RDATA is lowercased and runs of whitespace are collapsed.
//
// It returns "" for a nil rr or when nothing follows the four header tokens.
func canonicalRR(rr dns.RR) string {
	if rr == nil {
		return ""
	}

	// Number of whitespace-separated tokens making up the record header.
	const headerFields = 4

	tokens := strings.Fields(rr.String())
	if len(tokens) > headerFields {
		joined := strings.Join(tokens[headerFields:], " ")
		return strings.ToLower(strings.TrimSpace(joined))
	}
	return ""
}
// signatureFromRRs builds a deterministic signature of an RRset for cross-resolver
// comparison; sort-then-join keeps RRset order irrelevant.
//
// Only records whose owner name matches owner (case-insensitively, as FQDNs) and
// whose type equals qtype are considered; nil entries and records with an empty
// canonical form are skipped.
//
// It returns:
//   - sig: the sorted canonical records joined with "|" ("" when none matched),
//   - records: the sorted canonical records (nil when none matched),
//   - minTTL: the smallest TTL among the kept records (0 when none matched).
func signatureFromRRs(rrs []dns.RR, owner string, qtype uint16) (sig string, records []string, minTTL uint32) {
	ownerL := strings.ToLower(dns.Fqdn(owner))
	for _, rr := range rrs {
		// Calling Header on a nil interface value would panic.
		if rr == nil {
			continue
		}
		h := rr.Header()
		if h == nil || h.Rrtype != qtype {
			continue
		}
		if !strings.EqualFold(dns.Fqdn(h.Name), ownerL) {
			continue
		}
		c := canonicalRR(rr)
		if c == "" {
			continue
		}
		// Seed the minimum from the first kept record. Using "minTTL == 0" as the
		// unset marker would let a later, larger TTL overwrite a genuine TTL of 0.
		if len(records) == 0 || h.Ttl < minTTL {
			minTTL = h.Ttl
		}
		records = append(records, c)
	}
	sort.Strings(records)
	sig = strings.Join(records, "|")
	return sig, records, minTTL
}
// rcodeToString returns the mnemonic for DNS response code c as listed in
// dns.RcodeToString, or "RCODE<n>" when the code has no entry there.
func rcodeToString(c int) string {
	name, known := dns.RcodeToString[c]
	if !known {
		return fmt.Sprintf("RCODE%d", c)
	}
	return name
}