checker-alias/checker/dns.go

221 lines
5.8 KiB
Go

// This file is part of the happyDomain (R) project.
// Copyright (c) 2026 happyDomain
// Authors: Pierre-Olivier Mercier, et al.
package checker
import (
"context"
"fmt"
"net"
"strings"
"time"
"github.com/miekg/dns"
)
// dnsTimeout is the timeout given to every dns.Client built in this file.
// dnsExchange and recursiveExchange shorten it when the context deadline is
// nearer than that.
const dnsTimeout = 5 * time.Second
// dnsExchange sends a single non-recursive query (RD=0) for q to server and
// returns the reply.
//
// proto is stored in dns.Client.Net unchanged (callers in this file pass ""
// or "tcp"). When edns is true an EDNS0 OPT record is attached via
// SetEdns0(4096, true). The per-query timeout is dnsTimeout, reduced to the
// time left before ctx's deadline when that is sooner.
//
// It returns ctx.Err() without touching the network when ctx is already done,
// the client's error when the exchange fails, and a dedicated error when the
// client yields a nil message.
func dnsExchange(ctx context.Context, proto, server string, q dns.Question, edns bool) (*dns.Msg, error) {
	// Bail out early: an expired deadline would fail the "d > 0" clamp below
	// and the query would otherwise run with the full dnsTimeout.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	client := dns.Client{Net: proto, Timeout: dnsTimeout}
	if deadline, ok := ctx.Deadline(); ok {
		if d := time.Until(deadline); d > 0 && d < client.Timeout {
			client.Timeout = d
		}
	}

	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{q}
	m.RecursionDesired = false
	if edns {
		m.SetEdns0(4096, true)
	}

	// Hand ctx to the library in addition to the clamped timeout so that it
	// can take the context into account while dialing and exchanging.
	r, _, err := client.ExchangeContext(ctx, m, server)
	if err != nil {
		return nil, err
	}
	if r == nil {
		return nil, fmt.Errorf("nil response from %s", server)
	}
	return r, nil
}
// recursiveExchange sends q to a recursive resolver (RD=1) and returns the
// reply. It is used for fallbacks: resolving NS addresses and following
// chains across foreign zones.
//
// The query always carries an EDNS0 OPT record (SetEdns0(4096, true)) and
// uses the client's default transport. The timeout is dnsTimeout, reduced to
// the time left before ctx's deadline when that is sooner.
//
// It returns ctx.Err() without touching the network when ctx is already done,
// the client's error when the exchange fails, and a dedicated error when the
// client yields a nil message.
func recursiveExchange(ctx context.Context, server string, q dns.Question) (*dns.Msg, error) {
	// Bail out early: an expired deadline would fail the "d > 0" clamp below
	// and the query would otherwise run with the full dnsTimeout.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	client := dns.Client{Timeout: dnsTimeout}
	if deadline, ok := ctx.Deadline(); ok {
		if d := time.Until(deadline); d > 0 && d < client.Timeout {
			client.Timeout = d
		}
	}

	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{q}
	m.RecursionDesired = true
	m.SetEdns0(4096, true)

	// Hand ctx to the library in addition to the clamped timeout so that it
	// can take the context into account while dialing and exchanging.
	r, _, err := client.ExchangeContext(ctx, m, server)
	if err != nil {
		return nil, err
	}
	if r == nil {
		return nil, fmt.Errorf("nil response from %s", server)
	}
	return r, nil
}
// systemResolver returns the "host:port" address of the first nameserver
// listed in /etc/resolv.conf. When the file cannot be loaded or lists no
// server, it returns the public resolver 1.1.1.1 on port 53 instead.
func systemResolver() string {
	conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	if err == nil && len(conf.Servers) > 0 {
		return net.JoinHostPort(conf.Servers[0], conf.Port)
	}
	return net.JoinHostPort("1.1.1.1", "53")
}
// hostPort returns "host:port" for a hostname or IP literal.
//
// A single trailing dot is stripped from host (FQDN form). Bracketing is
// delegated to net.JoinHostPort, which wraps any host containing ':' or '%'
// in square brackets. That covers plain IPv6 literals as well as IPv4-mapped
// ("::ffff:192.0.2.1") and zoned ("fe80::1%eth0") ones, which a check based
// on net.ParseIP and To4 leaves unbracketed.
func hostPort(host, port string) string {
	return net.JoinHostPort(strings.TrimSuffix(host, "."), port)
}
// findApex walks up the labels of fqdn, asking the system resolver for an SOA
// at each candidate name, until one answers with an SOA record (a zone cut).
//
// It returns the apex as an FQDN together with the "address:53" list produced
// by resolveZoneNSAddrs for that zone.
//
// Errors:
//   - ctx.Err() when ctx is done before a candidate is queried;
//   - the NS resolution error, wrapped, when the apex is found but its
//     nameservers cannot be looked up;
//   - an error when the apex has no resolvable nameserver address;
//   - "could not locate apex" when no candidate yields an SOA; the last
//     resolver error, if any, is wrapped so the cause is not lost.
func findApex(ctx context.Context, fqdn string) (apex string, servers []string, err error) {
	resolver := systemResolver()
	labels := dns.SplitDomainName(fqdn)

	var lastErr error
	for i := range labels {
		// Stop climbing as soon as the caller has given up instead of issuing
		// one doomed query per remaining label.
		if cerr := ctx.Err(); cerr != nil {
			return "", nil, cerr
		}

		candidate := dns.Fqdn(strings.Join(labels[i:], "."))
		q := dns.Question{Name: candidate, Qtype: dns.TypeSOA, Qclass: dns.ClassINET}
		r, rerr := recursiveExchange(ctx, resolver, q)
		if rerr != nil {
			// Remember the failure but keep climbing: a parent may still answer.
			lastErr = rerr
			continue
		}
		if r.Rcode != dns.RcodeSuccess {
			continue
		}

		// Only an SOA in the answer section marks candidate as an apex.
		hasSOA := false
		for _, rr := range r.Answer {
			if _, ok := rr.(*dns.SOA); ok {
				hasSOA = true
				break
			}
		}
		if !hasSOA {
			continue
		}

		addrs, nsErr := resolveZoneNSAddrs(ctx, candidate)
		if nsErr != nil {
			return "", nil, fmt.Errorf("resolving NS of %s: %w", candidate, nsErr)
		}
		if len(addrs) == 0 {
			return "", nil, fmt.Errorf("apex %s has no resolvable NS", candidate)
		}
		return candidate, addrs, nil
	}

	if lastErr != nil {
		return "", nil, fmt.Errorf("could not locate apex of %s: %w", fqdn, lastErr)
	}
	return "", nil, fmt.Errorf("could not locate apex of %s", fqdn)
}
// resolveZoneNSAddrs looks up the NS records of zone and then the addresses
// of each nameserver, returning one "address:53" entry per address found.
//
// A failed NS lookup is returned as the error. A nameserver whose own address
// lookup fails is skipped, so the result may be empty with a nil error.
func resolveZoneNSAddrs(ctx context.Context, zone string) ([]string, error) {
	r := new(net.Resolver)

	records, err := r.LookupNS(ctx, strings.TrimSuffix(zone, "."))
	if err != nil {
		return nil, err
	}

	var endpoints []string
	for _, rec := range records {
		ips, lookupErr := r.LookupHost(ctx, strings.TrimSuffix(rec.Host, "."))
		if lookupErr != nil {
			// Best effort: one unresolvable nameserver must not hide the others.
			continue
		}
		for _, ip := range ips {
			endpoints = append(endpoints, hostPort(ip, "53"))
		}
	}
	return endpoints, nil
}
// pickServer returns the first entry of list, or "" when list is empty. It
// exists so that server selection is deterministic.
func pickServer(list []string) string {
	var first string
	if len(list) > 0 {
		first = list[0]
	}
	return first
}
// queryAtAuth sends q, with EDNS0 enabled and the client's default transport,
// to each entry of servers in order. It returns the first reply obtained
// together with the server that produced it.
//
// Any reply counts, whatever its rcode; only exchange errors move on to the
// next server. When every server fails the last error is returned, and an
// empty list yields "no servers provided". If ctx is done before a server is
// tried, ctx.Err() is returned immediately rather than spending a timeout on
// each remaining server.
func queryAtAuth(ctx context.Context, servers []string, q dns.Question) (*dns.Msg, string, error) {
	var lastErr error
	for _, s := range servers {
		if err := ctx.Err(); err != nil {
			return nil, "", err
		}
		r, err := dnsExchange(ctx, "", s, q, true)
		if err != nil {
			lastErr = err
			continue
		}
		return r, s, nil
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("no servers provided")
	}
	return nil, "", lastErr
}
// queryAtAuthTCP sends q over TCP, with EDNS0 enabled, to each entry of
// servers in order. It returns the first reply obtained together with the
// server that produced it.
//
// Any reply counts, whatever its rcode; only exchange errors move on to the
// next server. When every server fails the last error is returned, and an
// empty list yields "no servers provided". If ctx is done before a server is
// tried, ctx.Err() is returned immediately rather than spending a timeout on
// each remaining server.
func queryAtAuthTCP(ctx context.Context, servers []string, q dns.Question) (*dns.Msg, string, error) {
	var lastErr error
	for _, s := range servers {
		if err := ctx.Err(); err != nil {
			return nil, "", err
		}
		r, err := dnsExchange(ctx, "tcp", s, q, true)
		if err != nil {
			lastErr = err
			continue
		}
		return r, s, nil
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("no servers provided")
	}
	return nil, "", lastErr
}
// rcodeText maps an rcode to its name from dns.RcodeToString. Codes missing
// from that table are rendered as "RCODE(<n>)".
func rcodeText(r int) string {
	name, known := dns.RcodeToString[r]
	if !known {
		return fmt.Sprintf("RCODE(%d)", r)
	}
	return name
}
// isSubdomain reports whether child is equal to or sits under parent.
//
// Both names are converted to FQDN form with dns.Fqdn and lowercased before
// comparison, so the check ignores case and a missing trailing dot. Matching
// happens on whole labels: "badexample.com." is not under "example.com.".
//
// The root zone is handled explicitly. Every name sits under ".", but the
// generic suffix test would look for ".." and never match.
func isSubdomain(child, parent string) bool {
	child = strings.ToLower(dns.Fqdn(child))
	parent = strings.ToLower(dns.Fqdn(parent))
	if parent == "." {
		return true
	}
	return child == parent || strings.HasSuffix(child, "."+parent)
}
// parentOf drops the leftmost label of name and returns the remainder as an
// FQDN. Names with one label or none (a TLD, the root) yield ".".
func parentOf(name string) string {
	if parts := dns.SplitDomainName(name); len(parts) > 1 {
		return dns.Fqdn(strings.Join(parts[1:], "."))
	}
	return "."
}
// lowerFQDN converts name to FQDN form with dns.Fqdn and lowercases the
// result.
func lowerFQDN(name string) string {
	fq := dns.Fqdn(name)
	return strings.ToLower(fq)
}