157 lines
4.3 KiB
Go
157 lines
4.3 KiB
Go
package checker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// dnsTimeout is the default timeout given to the DNS client for one exchange;
// dnsExchange shortens it when the context deadline is sooner.
const dnsTimeout = 5 * time.Second
|
|
|
|
// dnsExchange sends a single query to an authoritative server (no RD).
|
|
func dnsExchange(ctx context.Context, proto, server string, q dns.Question, rd, edns bool) (*dns.Msg, error) {
|
|
client := dns.Client{Net: proto, Timeout: dnsTimeout}
|
|
|
|
m := new(dns.Msg)
|
|
m.Id = dns.Id()
|
|
m.Question = []dns.Question{q}
|
|
m.RecursionDesired = rd
|
|
if edns {
|
|
m.SetEdns0(4096, true)
|
|
}
|
|
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
if d := time.Until(deadline); d > 0 && d < client.Timeout {
|
|
client.Timeout = d
|
|
}
|
|
}
|
|
|
|
r, _, err := client.Exchange(m, server)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if r == nil {
|
|
return nil, fmt.Errorf("nil response from %s", server)
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
// recursiveExchange sends a query via a recursive resolver (RD=1). Used for
|
|
// fallbacks: resolving NS addresses, following chains across foreign zones.
|
|
func recursiveExchange(ctx context.Context, server string, q dns.Question) (*dns.Msg, error) {
|
|
return dnsExchange(ctx, "", server, q, true, true)
|
|
}
|
|
|
|
// systemResolver returns the first configured resolver of the local system,
|
|
// falling back to a public one if none is configured.
|
|
func systemResolver() string {
|
|
cfg, err := dns.ClientConfigFromFile("/etc/resolv.conf")
|
|
if err != nil || len(cfg.Servers) == 0 {
|
|
return net.JoinHostPort("1.1.1.1", "53")
|
|
}
|
|
return net.JoinHostPort(cfg.Servers[0], cfg.Port)
|
|
}
|
|
|
|
// hostPort joins host and port into a dialable "host:port" address. A single
// trailing dot (as found on FQDNs) is dropped from host first; IPv6 literals
// are bracketed by net.JoinHostPort.
func hostPort(host, port string) string {
	if strings.HasSuffix(host, ".") {
		host = host[:len(host)-1]
	}
	return net.JoinHostPort(host, port)
}
|
|
|
|
// findApex walks up the labels of fqdn until it finds a zone cut (SOA), using
|
|
// the system resolver. Returns the apex FQDN and the list of "host:53"
|
|
// authoritative servers for that zone.
|
|
func findApex(ctx context.Context, fqdn, resolver string) (apex string, servers []string, err error) {
|
|
labels := dns.SplitDomainName(fqdn)
|
|
for i := range labels {
|
|
candidate := dns.Fqdn(strings.Join(labels[i:], "."))
|
|
q := dns.Question{Name: candidate, Qtype: dns.TypeSOA, Qclass: dns.ClassINET}
|
|
r, rerr := recursiveExchange(ctx, resolver, q)
|
|
if rerr != nil {
|
|
continue
|
|
}
|
|
if r.Rcode != dns.RcodeSuccess {
|
|
continue
|
|
}
|
|
hasSOA := false
|
|
for _, rr := range r.Answer {
|
|
if _, ok := rr.(*dns.SOA); ok {
|
|
hasSOA = true
|
|
break
|
|
}
|
|
}
|
|
if !hasSOA {
|
|
continue
|
|
}
|
|
apex = candidate
|
|
servers, err = resolveZoneNSAddrs(ctx, apex)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
if len(servers) == 0 {
|
|
return "", nil, fmt.Errorf("apex %s has no resolvable NS", apex)
|
|
}
|
|
return apex, servers, nil
|
|
}
|
|
return "", nil, fmt.Errorf("could not locate apex of %s", fqdn)
|
|
}
|
|
|
|
// resolveZoneNSAddrs returns "host:53" entries for every NS of the zone.
|
|
func resolveZoneNSAddrs(ctx context.Context, zone string) ([]string, error) {
|
|
var resolver net.Resolver
|
|
nss, err := resolver.LookupNS(ctx, strings.TrimSuffix(zone, "."))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var out []string
|
|
for _, ns := range nss {
|
|
addrs, err := resolver.LookupHost(ctx, strings.TrimSuffix(ns.Host, "."))
|
|
if err != nil || len(addrs) == 0 {
|
|
continue
|
|
}
|
|
for _, a := range addrs {
|
|
out = append(out, hostPort(a, "53"))
|
|
}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// queryAtAuth sends a query to the first reachable server of list.
|
|
func queryAtAuth(ctx context.Context, proto string, servers []string, q dns.Question) (*dns.Msg, string, error) {
|
|
var lastErr error
|
|
for _, s := range servers {
|
|
r, err := dnsExchange(ctx, proto, s, q, false, true)
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
return r, s, nil
|
|
}
|
|
if lastErr == nil {
|
|
lastErr = fmt.Errorf("no servers provided")
|
|
}
|
|
return nil, "", lastErr
|
|
}
|
|
|
|
// rcodeText returns the textual name of an rcode or a fallback string.
|
|
func rcodeText(r int) string {
|
|
if s, ok := dns.RcodeToString[r]; ok {
|
|
return s
|
|
}
|
|
return fmt.Sprintf("RCODE(%d)", r)
|
|
}
|
|
|
|
// isSubdomain reports whether child is equal to or sits under parent.
|
|
func isSubdomain(child, parent string) bool {
|
|
child = strings.ToLower(dns.Fqdn(child))
|
|
parent = strings.ToLower(dns.Fqdn(parent))
|
|
return child == parent || strings.HasSuffix(child, "."+parent)
|
|
}
|
|
|
|
// lowerFQDN returns the canonical lowercase FQDN form of name.
|
|
func lowerFQDN(name string) string {
|
|
return strings.ToLower(dns.Fqdn(name))
|
|
}
|