package checker import ( "context" "net" "testing" "github.com/miekg/dns" ) func TestIsTransientRcode(t *testing.T) { transient := []int{dns.RcodeServerFailure, dns.RcodeRefused} for _, rc := range transient { if !isTransientRcode(rc) { t.Errorf("rcode %s should be transient", rcodeText(rc)) } } final := []int{dns.RcodeSuccess, dns.RcodeNameError, dns.RcodeNotImplemented} for _, rc := range final { if isTransientRcode(rc) { t.Errorf("rcode %s should not be transient", rcodeText(rc)) } } } // startTestServer spins up a UDP DNS server that answers every query with the // given handler, returning its address and a shutdown func. func startTestServer(t *testing.T, handler dns.HandlerFunc) (string, func()) { t.Helper() mux := dns.NewServeMux() mux.HandleFunc(".", handler) pc, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } srv := &dns.Server{PacketConn: pc, Handler: mux} go srv.ActivateAndServe() return pc.LocalAddr().String(), func() { srv.Shutdown() } } func answerWith(rcode int, withAnswer bool) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Rcode = rcode if withAnswer && len(r.Question) > 0 { rr, _ := dns.NewRR(r.Question[0].Name + " 300 IN CNAME target.example.") if rr != nil { m.Answer = append(m.Answer, rr) } } w.WriteMsg(m) } } func TestQueryAtAuthFailsOverTransientRcode(t *testing.T) { q := dns.Question{Name: "example.com.", Qtype: dns.TypeCNAME, Qclass: dns.ClassINET} t.Run("prefers definitive answer over SERVFAIL", func(t *testing.T) { bad, stopBad := startTestServer(t, answerWith(dns.RcodeServerFailure, false)) defer stopBad() good, stopGood := startTestServer(t, answerWith(dns.RcodeSuccess, true)) defer stopGood() r, server, err := queryAtAuth(context.Background(), "", []string{bad, good}, q, false) if err != nil { t.Fatalf("unexpected error: %v", err) } if r.Rcode != dns.RcodeSuccess { t.Fatalf("got rcode %s, want NOERROR", rcodeText(r.Rcode)) } if server != good { t.Fatalf("answered by %s, want the healthy server %s", server, good) } }) t.Run("returns transient response when every server fails", func(t *testing.T) { s1, stop1 := startTestServer(t, answerWith(dns.RcodeServerFailure, false)) defer stop1() s2, stop2 := startTestServer(t, answerWith(dns.RcodeRefused, false)) defer stop2() r, _, err := queryAtAuth(context.Background(), "", []string{s1, s2}, q, false) if err != nil { t.Fatalf("unexpected error: %v", err) } if !isTransientRcode(r.Rcode) { t.Fatalf("got rcode %s, want a transient rcode preserved", rcodeText(r.Rcode)) } }) }