diff --git a/checker/dns.go b/checker/dns.go
index 7270da1..876fde1 100644
--- a/checker/dns.go
+++ b/checker/dns.go
@@ -128,24 +128,55 @@ func resolveZoneNSAddrs(ctx context.Context, zone string) ([]string, error) {
 	return out, nil
 }
 
-// queryAtAuth tries each server in order and returns the first usable answer.
-// dnssec=true sets the DO bit; only the DNSSEC probes need it.
+// queryAtAuth tries each server in order and returns the first definitive
+// answer. Transport errors and transient failures (SERVFAIL/REFUSED) make it
+// fail over to the next server so a single flaky auth server cannot decide the
+// verdict; a definitive response (NOERROR/NXDOMAIN/...) is returned at once.
+// A nil response returned without an error is treated as a failed server
+// rather than dereferenced. If every server fails it returns the last
+// transient response when there was one (so callers can still inspect the
+// rcode), otherwise the last error. dnssec=true sets the DO bit; only the
+// DNSSEC probes need it.
 func queryAtAuth(ctx context.Context, proto string, servers []string, q dns.Question, dnssec bool) (*dns.Msg, string, error) {
 	var lastErr error
+	var transientMsg *dns.Msg
+	var transientServer string
 	for _, s := range servers {
 		r, err := dnsExchange(ctx, proto, s, q, false, dnssec)
 		if err != nil {
 			lastErr = err
 			continue
 		}
+		if r == nil {
+			// The rcode check below dereferences r, so count a missing
+			// reply as a failed server instead of panicking.
+			lastErr = fmt.Errorf("nil response from %s", s)
+			continue
+		}
+		if isTransientRcode(r.Rcode) {
+			// Keep the reply for the all-servers-failed case and try the next server.
+			transientMsg, transientServer = r, s
+			continue
+		}
 		return r, s, nil
 	}
+	if transientMsg != nil {
+		return transientMsg, transientServer, nil
+	}
 	if lastErr == nil {
 		lastErr = fmt.Errorf("no servers provided")
 	}
 	return nil, "", lastErr
 }
 
+// isTransientRcode reports whether an rcode is worth retrying against another
+// auth server rather than treating as the zone's final answer. SERVFAIL and
+// REFUSED are typically per-server faults (backend down, zone not yet loaded),
+// unlike NXDOMAIN which is an authoritative negative answer.
+func isTransientRcode(rcode int) bool {
+	return rcode == dns.RcodeServerFailure || rcode == dns.RcodeRefused
+}
+
 func rcodeText(r int) string {
 	if s, ok := dns.RcodeToString[r]; ok {
 		return s
diff --git a/checker/dns_test.go b/checker/dns_test.go
new file mode 100644
index 0000000..35fce22
--- /dev/null
+++ b/checker/dns_test.go
@@ -0,0 +1,107 @@
+package checker
+
+import (
+	"context"
+	"net"
+	"testing"
+
+	"github.com/miekg/dns"
+)
+
+// TestIsTransientRcode pins which rcodes trigger failover to the next server.
+func TestIsTransientRcode(t *testing.T) {
+	transient := []int{dns.RcodeServerFailure, dns.RcodeRefused}
+	for _, rc := range transient {
+		if !isTransientRcode(rc) {
+			t.Errorf("rcode %s should be transient", rcodeText(rc))
+		}
+	}
+	final := []int{dns.RcodeSuccess, dns.RcodeNameError, dns.RcodeNotImplemented}
+	for _, rc := range final {
+		if isTransientRcode(rc) {
+			t.Errorf("rcode %s should not be transient", rcodeText(rc))
+		}
+	}
+}
+
+// startTestServer spins up a UDP DNS server that answers every query with the
+// given handler, returning its address and a shutdown func.
+func startTestServer(t *testing.T, handler dns.HandlerFunc) (string, func()) {
+	t.Helper()
+	mux := dns.NewServeMux()
+	mux.HandleFunc(".", handler)
+	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("listen: %v", err)
+	}
+	srv := &dns.Server{PacketConn: pc, Handler: mux}
+	go srv.ActivateAndServe()
+	return pc.LocalAddr().String(), func() { srv.Shutdown() }
+}
+
+// answerWith returns a handler that replies with the given rcode and, when
+// withAnswer is set, a single CNAME record for the queried name.
+func answerWith(rcode int, withAnswer bool) dns.HandlerFunc {
+	return func(w dns.ResponseWriter, r *dns.Msg) {
+		m := new(dns.Msg)
+		m.SetReply(r)
+		m.Rcode = rcode
+		if withAnswer && len(r.Question) > 0 {
+			rr, _ := dns.NewRR(r.Question[0].Name + " 300 IN CNAME target.example.")
+			if rr != nil {
+				m.Answer = append(m.Answer, rr)
+			}
+		}
+		w.WriteMsg(m)
+	}
+}
+
+// TestQueryAtAuthFailsOverTransientRcode exercises queryAtAuth against local
+// UDP test servers that answer with fixed rcodes.
+func TestQueryAtAuthFailsOverTransientRcode(t *testing.T) {
+	q := dns.Question{Name: "example.com.", Qtype: dns.TypeCNAME, Qclass: dns.ClassINET}
+
+	t.Run("prefers definitive answer over SERVFAIL", func(t *testing.T) {
+		bad, stopBad := startTestServer(t, answerWith(dns.RcodeServerFailure, false))
+		defer stopBad()
+		good, stopGood := startTestServer(t, answerWith(dns.RcodeSuccess, true))
+		defer stopGood()
+
+		r, server, err := queryAtAuth(context.Background(), "", []string{bad, good}, q, false)
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if r.Rcode != dns.RcodeSuccess {
+			t.Fatalf("got rcode %s, want NOERROR", rcodeText(r.Rcode))
+		}
+		if server != good {
+			t.Fatalf("answered by %s, want the healthy server %s", server, good)
+		}
+	})
+
+	t.Run("returns transient response when every server fails", func(t *testing.T) {
+		s1, stop1 := startTestServer(t, answerWith(dns.RcodeServerFailure, false))
+		defer stop1()
+		s2, stop2 := startTestServer(t, answerWith(dns.RcodeRefused, false))
+		defer stop2()
+
+		r, server, err := queryAtAuth(context.Background(), "", []string{s1, s2}, q, false)
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if !isTransientRcode(r.Rcode) {
+			t.Fatalf("got rcode %s, want a transient rcode preserved", rcodeText(r.Rcode))
+		}
+		// The documented contract is the last transient response, so the reply
+		// must come from the final server in the list.
+		if server != s2 {
+			t.Fatalf("answered by %s, want the last server %s", server, s2)
+		}
+	})
+
+	t.Run("errors when no servers are provided", func(t *testing.T) {
+		if _, _, err := queryAtAuth(context.Background(), "", nil, q, false); err == nil {
+			t.Fatal("expected an error for an empty server list")
+		}
+	})
+}