package checker

import (
	"context"
	"fmt"
	"net"
	"strings"
	"time"

	"github.com/miekg/dns"
)

// dnsTimeout is the default upper bound for a single DNS exchange. A nearer
// context deadline shortens it for that exchange.
const dnsTimeout = 5 * time.Second

// dnsExchange sends a single query to an authoritative server (no RD).
//
// proto is used as the dns.Client network ("" selects the library default,
// "tcp" forces TCP). When edns is true an OPT record advertising a 4096-byte
// UDP payload with the DO bit set is attached. The call fails fast when ctx is
// already done and is aborted when ctx's deadline passes.
func dnsExchange(ctx context.Context, proto, server string, q dns.Question, edns bool) (*dns.Msg, error) {
	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{q}
	m.RecursionDesired = false
	if edns {
		m.SetEdns0(4096, true)
	}
	return exchangeWithContext(ctx, &dns.Client{Net: proto, Timeout: dnsTimeout}, m, server)
}

// recursiveExchange sends a query via a recursive resolver (RD=1). Used for
// fallbacks: resolving NS addresses, following chains across foreign zones.
//
// The query always carries an EDNS0 OPT record (4096-byte UDP payload, DO
// set). Context handling is the same as for dnsExchange.
func recursiveExchange(ctx context.Context, server string, q dns.Question) (*dns.Msg, error) {
	m := new(dns.Msg)
	m.Id = dns.Id()
	m.Question = []dns.Question{q}
	m.RecursionDesired = true
	m.SetEdns0(4096, true)
	return exchangeWithContext(ctx, &dns.Client{Timeout: dnsTimeout}, m, server)
}

// exchangeWithContext performs one exchange of m with server using client,
// honouring ctx. It returns ctx.Err() without touching the network when ctx is
// already cancelled or expired, and otherwise returns either a non-nil
// response or an error.
func exchangeWithContext(ctx context.Context, client *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
	// Without this guard an expired context would still cost a full timeout
	// per server, because a non-positive remaining duration never shortens
	// client.Timeout below.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	if deadline, ok := ctx.Deadline(); ok {
		if d := time.Until(deadline); d > 0 && d < client.Timeout {
			client.Timeout = d
		}
	}
	r, _, err := client.ExchangeContext(ctx, m, server)
	if err != nil {
		return nil, err
	}
	if r == nil {
		return nil, fmt.Errorf("nil response from %s", server)
	}
	return r, nil
}

// systemResolver returns the first configured resolver of the local system,
// falling back to a public one if none is configured.
//
// The configuration is read from /etc/resolv.conf; the fallback is
// 1.1.1.1:53. The result is always in "host:port" form.
func systemResolver() string {
	cfg, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	if err != nil || len(cfg.Servers) == 0 {
		return net.JoinHostPort("1.1.1.1", "53")
	}
	return net.JoinHostPort(cfg.Servers[0], cfg.Port)
}

// hostPort returns "host:port", correctly bracketing IPv6 literals and
// stripping the trailing dot from FQDNs.
func hostPort(host, port string) string { if ip := net.ParseIP(host); ip != nil && ip.To4() == nil { return "[" + host + "]:" + port } return strings.TrimSuffix(host, ".") + ":" + port } // findApex walks up the labels of fqdn until it finds a zone cut (SOA), using // the system resolver. Returns the apex FQDN and the list of "host:53" // authoritative servers for that zone. func findApex(ctx context.Context, fqdn string) (apex string, servers []string, err error) { resolver := systemResolver() labels := dns.SplitDomainName(fqdn) for i := 0; i < len(labels); i++ { candidate := dns.Fqdn(strings.Join(labels[i:], ".")) q := dns.Question{Name: candidate, Qtype: dns.TypeSOA, Qclass: dns.ClassINET} r, rerr := recursiveExchange(ctx, resolver, q) if rerr != nil { continue } if r.Rcode != dns.RcodeSuccess { continue } hasSOA := false for _, rr := range r.Answer { if _, ok := rr.(*dns.SOA); ok { hasSOA = true break } } if !hasSOA { continue } apex = candidate servers, err = resolveZoneNSAddrs(ctx, apex) if err != nil { return "", nil, err } if len(servers) == 0 { return "", nil, fmt.Errorf("apex %s has no resolvable NS", apex) } return apex, servers, nil } return "", nil, fmt.Errorf("could not locate apex of %s", fqdn) } // resolveZoneNSAddrs returns "host:53" entries for every NS of the zone. func resolveZoneNSAddrs(ctx context.Context, zone string) ([]string, error) { var resolver net.Resolver nss, err := resolver.LookupNS(ctx, strings.TrimSuffix(zone, ".")) if err != nil { return nil, err } var out []string for _, ns := range nss { addrs, err := resolver.LookupHost(ctx, strings.TrimSuffix(ns.Host, ".")) if err != nil || len(addrs) == 0 { continue } for _, a := range addrs { out = append(out, hostPort(a, "53")) } } return out, nil } // pickServer returns the first usable server from list (helper for deterministic picking). 
func pickServer(list []string) string { if len(list) == 0 { return "" } return list[0] } // queryAtAuth sends a query to the first reachable server of list. func queryAtAuth(ctx context.Context, servers []string, q dns.Question) (*dns.Msg, string, error) { var lastErr error for _, s := range servers { r, err := dnsExchange(ctx, "", s, q, true) if err != nil { lastErr = err continue } return r, s, nil } if lastErr == nil { lastErr = fmt.Errorf("no servers provided") } return nil, "", lastErr } // queryAtAuthTCP sends a query over TCP to the first reachable server. func queryAtAuthTCP(ctx context.Context, servers []string, q dns.Question) (*dns.Msg, string, error) { var lastErr error for _, s := range servers { r, err := dnsExchange(ctx, "tcp", s, q, true) if err != nil { lastErr = err continue } return r, s, nil } if lastErr == nil { lastErr = fmt.Errorf("no servers provided") } return nil, "", lastErr } // rcodeText returns the textual name of an rcode or a fallback string. func rcodeText(r int) string { if s, ok := dns.RcodeToString[r]; ok { return s } return fmt.Sprintf("RCODE(%d)", r) } // isSubdomain reports whether child is equal to or sits under parent. func isSubdomain(child, parent string) bool { child = strings.ToLower(dns.Fqdn(child)) parent = strings.ToLower(dns.Fqdn(parent)) return child == parent || strings.HasSuffix(child, "."+parent) } // parentOf returns the parent zone of name (one label up), or "." for TLDs. func parentOf(name string) string { labels := dns.SplitDomainName(name) if len(labels) <= 1 { return "." } return dns.Fqdn(strings.Join(labels[1:], ".")) } // lowerFQDN returns the canonical lowercase FQDN form of name. func lowerFQDN(name string) string { return strings.ToLower(dns.Fqdn(name)) }