package checker import ( "bytes" "context" "crypto/tls" "encoding/base64" "fmt" "io" "net" "net/http" "net/url" "sort" "strings" "time" "github.com/miekg/dns" ) // Slower than this, a public resolver is unreachable or too flaky to be useful. const dnsTimeout = 5 * time.Second // 4096 is the de-facto ceiling for unfragmented EDNS0 responses on the public Internet. const ednsUDPSize = 4096 // Bound DoH reads so a hostile server can't stream junk indefinitely. const maxDoHResponseBytes = 64 * 1024 // Shared so concurrent probes reuse connections and TLS state. var dohClient = &http.Client{ Timeout: dnsTimeout + 2*time.Second, Transport: &http.Transport{ TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, }, TLSHandshakeTimeout: dnsTimeout, ResponseHeaderTimeout: dnsTimeout, ExpectContinueTimeout: 1 * time.Second, DisableKeepAlives: false, MaxIdleConnsPerHost: 4, }, } // Flatter than *dns.Msg so the collector stays protocol-agnostic. type queryResult struct { Rcode int Answer []dns.RR AD bool Latency time.Duration } // Forces RD=1 (recurse), CD=0 (let resolver validate DNSSEC), AD=1 (signal validation back). func queryResolver(ctx context.Context, r Resolver, tr Transport, name string, qtype uint16) (*queryResult, error) { q := dns.Question{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET} m := new(dns.Msg) m.Id = dns.Id() m.Question = []dns.Question{q} m.RecursionDesired = true m.CheckingDisabled = false m.AuthenticatedData = true m.SetEdns0(ednsUDPSize, true) switch tr { case TransportUDP: return exchangeUDPOrTCP(ctx, m, r.IP+":53", "udp") case TransportTCP: return exchangeUDPOrTCP(ctx, m, r.IP+":53", "tcp") case TransportDoT: if r.DoTHost == "" { return nil, fmt.Errorf("no DoT endpoint for %s", r.ID) } return exchangeDoT(ctx, m, r.IP, r.DoTHost) case TransportDoH: if r.DoHURL == "" { return nil, fmt.Errorf("no DoH endpoint for %s", r.ID) } return exchangeDoH(ctx, m, r.DoHURL) default: return nil, fmt.Errorf("unknown transport %q", tr) } } func exchangeUDPOrTCP(ctx context.Context, m *dns.Msg, server, proto string) (*queryResult, error) { client := dns.Client{Net: proto, Timeout: dnsTimeout} if deadline, ok := ctx.Deadline(); ok { if d := time.Until(deadline); d > 0 && d < client.Timeout { client.Timeout = d } } r, rtt, err := client.ExchangeContext(ctx, m, server) if err != nil { return nil, err } if r == nil { return nil, fmt.Errorf("nil response from %s", server) } // Truncated UDP answers force a retry over TCP per RFC 5966. if proto == "udp" && r.Truncated { tcpClient := dns.Client{Net: "tcp", Timeout: dnsTimeout} if r2, rtt2, err2 := tcpClient.ExchangeContext(ctx, m, server); err2 == nil && r2 != nil { return &queryResult{ Rcode: r2.Rcode, Answer: r2.Answer, AD: r2.AuthenticatedData, Latency: rtt2, }, nil } } return &queryResult{ Rcode: r.Rcode, Answer: r.Answer, AD: r.AuthenticatedData, Latency: rtt, }, nil } // sni validates the certificate; the IP is what we actually dial. func exchangeDoT(ctx context.Context, m *dns.Msg, ip, sni string) (*queryResult, error) { client := dns.Client{ Net: "tcp-tls", Timeout: dnsTimeout, TLSConfig: &tls.Config{ ServerName: sni, MinVersion: tls.VersionTLS12, }, } if deadline, ok := ctx.Deadline(); ok { if d := time.Until(deadline); d > 0 && d < client.Timeout { client.Timeout = d } } r, rtt, err := client.ExchangeContext(ctx, m, net.JoinHostPort(ip, "853")) if err != nil { return nil, err } if r == nil { return nil, fmt.Errorf("nil response from %s", ip) } return &queryResult{ Rcode: r.Rcode, Answer: r.Answer, AD: r.AuthenticatedData, Latency: rtt, }, nil } // GET (per RFC 8484) so HTTP caches can merge equivalent queries. func exchangeDoH(ctx context.Context, m *dns.Msg, endpoint string) (*queryResult, error) { // Id=0 lets HTTP caches merge equivalent queries. m.Id = 0 packed, err := m.Pack() if err != nil { return nil, fmt.Errorf("packing message: %w", err) } u, err := url.Parse(endpoint) if err != nil { return nil, fmt.Errorf("invalid DoH endpoint %q: %w", endpoint, err) } q := u.Query() q.Set("dns", base64.RawURLEncoding.EncodeToString(packed)) u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { return nil, err } req.Header.Set("Accept", "application/dns-message") req.Header.Set("User-Agent", "happyDomain-checker-resolver-propagation/"+Version) start := time.Now() resp, err := dohClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("DoH HTTP %d", resp.StatusCode) } ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/dns-message") { return nil, fmt.Errorf("DoH unexpected content-type %q", ct) } var buf bytes.Buffer if _, err := io.Copy(&buf, io.LimitReader(resp.Body, maxDoHResponseBytes)); err != nil { return nil, err } latency := time.Since(start) r := new(dns.Msg) if err := r.Unpack(buf.Bytes()); err != nil { return nil, fmt.Errorf("unpacking DoH response: %w", err) } return &queryResult{ Rcode: r.Rcode, Answer: r.Answer, AD: r.AuthenticatedData, Latency: latency, }, nil } // Strips the "owner TTL class type" header from miekg's zone-file form to leave RDATA. func canonicalRR(rr dns.RR) string { if rr == nil { return "" } fields := strings.Fields(rr.String()) if len(fields) <= 4 { return "" } rdata := strings.Join(fields[4:], " ") // Lowercase so case-only drift in hostnames doesn't read as disagreement. return strings.ToLower(strings.TrimSpace(rdata)) } // Deterministic signature for cross-resolver comparison; sort-then-join keeps RRset order irrelevant. func signatureFromRRs(rrs []dns.RR, owner string, qtype uint16) (sig string, records []string, minTTL uint32) { ownerL := strings.ToLower(dns.Fqdn(owner)) for _, rr := range rrs { h := rr.Header() if h == nil { continue } if !strings.EqualFold(dns.Fqdn(h.Name), ownerL) { continue } if h.Rrtype != qtype { continue } if c := canonicalRR(rr); c != "" { records = append(records, c) if minTTL == 0 || h.Ttl < minTTL { minTTL = h.Ttl } } } sort.Strings(records) sig = strings.Join(records, "|") return sig, records, minTTL } func rcodeToString(c int) string { if s, ok := dns.RcodeToString[c]; ok { return s } return fmt.Sprintf("RCODE%d", c) }