package checker

import (
	"context"
	"net"
	"strings"
	"testing"
	"time"

	"github.com/miekg/dns"
)

// mustRR parses s as a zone-file resource record and fails the test on error.
// It may call t.Fatalf, so it must only run on the test goroutine — never
// inside a DNS handler, which the server invokes on its own goroutine.
func mustRR(t *testing.T, s string) dns.RR {
	t.Helper()
	rr, err := dns.NewRR(s)
	if err != nil {
		t.Fatalf("dns.NewRR(%q): %v", s, err)
	}
	return rr
}

// TestCanonicalRR checks that canonicalRR returns "" for a nil RR and renders
// only the RDATA portion, with domain names lowercased.
func TestCanonicalRR(t *testing.T) {
	if got := canonicalRR(nil); got != "" {
		t.Errorf("nil RR: want empty, got %q", got)
	}
	cases := []struct {
		rr   string
		want string
	}{
		{"example.com. 300 IN A 192.0.2.1", "192.0.2.1"},
		{"Example.Com. 300 IN NS Ns1.Example.Com.", "ns1.example.com."},
		{"example.com. 60 IN MX 10 mail.example.com.", "10 mail.example.com."},
		{"example.com. 30 IN TXT \"v=spf1 -all\"", "\"v=spf1 -all\""},
	}
	for _, c := range cases {
		if got := canonicalRR(mustRR(t, c.rr)); got != c.want {
			t.Errorf("canonicalRR(%q) = %q, want %q", c.rr, got, c.want)
		}
	}
}

// TestSignatureFromRRs checks that signatureFromRRs keeps only RRs matching
// the owner (case-insensitively) and type, sorts the rendered records, joins
// them with "|", and reports the minimum TTL of the kept records.
func TestSignatureFromRRs(t *testing.T) {
	rrs := []dns.RR{
		mustRR(t, "example.com. 300 IN A 192.0.2.2"),
		mustRR(t, "example.com. 60 IN A 192.0.2.1"),
		mustRR(t, "example.com. 300 IN AAAA 2001:db8::1"),        // wrong type
		mustRR(t, "other.example.com. 300 IN A 198.51.100.1"), // wrong owner
	}
	sig, recs, ttl := signatureFromRRs(rrs, "example.com", dns.TypeA)
	if sig != "192.0.2.1|192.0.2.2" {
		t.Errorf("sig = %q", sig)
	}
	if len(recs) != 2 || recs[0] != "192.0.2.1" || recs[1] != "192.0.2.2" {
		t.Errorf("records = %v", recs)
	}
	if ttl != 60 {
		t.Errorf("minTTL = %d, want 60", ttl)
	}

	// Owner case-insensitivity.
	sig2, _, _ := signatureFromRRs(rrs, "EXAMPLE.com.", dns.TypeA)
	if sig2 != sig {
		t.Errorf("owner case sensitivity: %q vs %q", sig2, sig)
	}

	// Empty input.
	if s, r, emptyTTL := signatureFromRRs(nil, "x", dns.TypeA); s != "" || r != nil || emptyTTL != 0 {
		t.Errorf("empty input: %q %v %d", s, r, emptyTTL)
	}
}

// TestSignatureDeterministic checks that the signature does not depend on the
// order in which RRs arrive.
func TestSignatureDeterministic(t *testing.T) {
	a := []dns.RR{
		mustRR(t, "x. 30 IN A 1.1.1.1"),
		mustRR(t, "x. 30 IN A 2.2.2.2"),
	}
	b := []dns.RR{
		mustRR(t, "x. 30 IN A 2.2.2.2"),
		mustRR(t, "x. 30 IN A 1.1.1.1"),
	}
	sa, _, _ := signatureFromRRs(a, "x", dns.TypeA)
	sb, _, _ := signatureFromRRs(b, "x", dns.TypeA)
	if sa != sb {
		t.Errorf("ordering changed sig: %q vs %q", sa, sb)
	}
}

// TestRcodeToString checks the mnemonic for known rcodes and the "RCODE<n>"
// fallback for unknown ones.
func TestRcodeToString(t *testing.T) {
	cases := []struct {
		in   int
		want string
	}{
		{dns.RcodeSuccess, "NOERROR"},
		{dns.RcodeNameError, "NXDOMAIN"},
		{dns.RcodeServerFailure, "SERVFAIL"},
		{42, "RCODE42"},
	}
	for _, c := range cases {
		if got := rcodeToString(c.in); got != c.want {
			t.Errorf("rcodeToString(%d) = %q, want %q", c.in, got, c.want)
		}
	}
}

// serve runs srv on a background goroutine and blocks until the server
// reports, via NotifyStartedFunc, that it is serving. Waiting on that signal
// replaces a fixed sleep, which could race with the first query or with
// Shutdown. The returned func shuts srv down and waits for its serve
// goroutine to exit.
func serve(t *testing.T, srv *dns.Server) func() {
	t.Helper()
	started := make(chan struct{})
	srv.NotifyStartedFunc = func() { close(started) }
	errc := make(chan error, 1)
	go func() { errc <- srv.ActivateAndServe() }()
	select {
	case <-started:
	case err := <-errc:
		t.Fatalf("dns server exited before starting: %v", err)
	case <-time.After(2 * time.Second):
		t.Fatal("dns server did not start within 2s")
	}
	return func() {
		_ = srv.Shutdown()
		<-errc
	}
}

// startUDPServer brings up a tiny miekg/dns UDP server bound to a free
// loopback port, returning its address and a stop func. It returns only once
// the server is serving. The handler is called for every query and decides
// what to write back. It runs on a server goroutine, so it must not call
// t.Fatal or mustRR.
func startUDPServer(t *testing.T, handler dns.HandlerFunc) (string, func()) {
	t.Helper()
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	stop := serve(t, &dns.Server{PacketConn: pc, Handler: handler})
	return pc.LocalAddr().String(), stop
}

// listenUDPTCPPair binds a UDP socket and a TCP listener on the same loopback
// port. The kernel picks a free UDP port, but that port may already be in use
// for TCP, so a few attempts are made before failing the test.
func listenUDPTCPPair(t *testing.T) (net.PacketConn, net.Listener) {
	t.Helper()
	var lastErr error
	for attempt := 0; attempt < 10; attempt++ {
		pc, err := net.ListenPacket("udp", "127.0.0.1:0")
		if err != nil {
			t.Fatalf("udp listen: %v", err)
		}
		l, err := net.Listen("tcp", pc.LocalAddr().String())
		if err == nil {
			return pc, l
		}
		lastErr = err
		_ = pc.Close()
	}
	t.Fatalf("tcp listen on matching port: %v", lastErr)
	return nil, nil
}

// TestExchangeUDPOrTCP_Success checks a plain UDP exchange against a local
// server that answers with one A record.
func TestExchangeUDPOrTCP_Success(t *testing.T) {
	// Built here, not in the handler: mustRR may call t.Fatalf, which is only
	// legal on the test goroutine.
	answer := mustRR(t, "example.com. 60 IN A 192.0.2.10")
	addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(m)
		resp.Authoritative = true
		resp.Answer = []dns.RR{answer}
		_ = w.WriteMsg(resp)
	})
	defer stop()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	m := new(dns.Msg)
	m.SetQuestion("example.com.", dns.TypeA)
	res, err := exchangeUDPOrTCP(ctx, m, addr, "udp")
	if err != nil {
		t.Fatalf("exchange: %v", err)
	}
	if res.Rcode != dns.RcodeSuccess {
		t.Errorf("rcode = %d", res.Rcode)
	}
	if len(res.Answer) != 1 {
		t.Fatalf("answers: %v", res.Answer)
	}
}

// TestQueryResolver_UnknownTransport checks that an unrecognised transport is
// rejected with an "unknown transport" error.
func TestQueryResolver_UnknownTransport(t *testing.T) {
	_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, Transport("xyz"), "x.", dns.TypeA)
	if err == nil || !strings.Contains(err.Error(), "unknown transport") {
		t.Errorf("want unknown transport error, got %v", err)
	}
}

// TestQueryResolver_MissingDoTEndpoint checks that DoT against a resolver with
// no DoT endpoint configured fails with "no DoT endpoint".
func TestQueryResolver_MissingDoTEndpoint(t *testing.T) {
	_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoT, "x.", dns.TypeA)
	if err == nil || !strings.Contains(err.Error(), "no DoT endpoint") {
		t.Errorf("want missing DoT err, got %v", err)
	}
}

// TestQueryResolver_MissingDoHEndpoint checks that DoH against a resolver with
// no DoH endpoint configured fails with "no DoH endpoint".
func TestQueryResolver_MissingDoHEndpoint(t *testing.T) {
	_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoH, "x.", dns.TypeA)
	if err == nil || !strings.Contains(err.Error(), "no DoH endpoint") {
		t.Errorf("want missing DoH err, got %v", err)
	}
}

// TestRunProbe_TransportError checks that a transport-level error ends up in
// RRProbe.Error and that the probe records the requested transport.
func TestRunProbe_TransportError(t *testing.T) {
	// Missing DoT host on the resolver: queryResolver returns an error,
	// runProbe converts it into RRProbe.Error.
	p := runProbe(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoT, "ex.", dns.TypeA)
	if p.Error == "" {
		t.Errorf("expected error for missing DoT host")
	}
	if p.Transport != TransportDoT {
		t.Errorf("transport = %v", p.Transport)
	}
}

// TestQueryAuthoritative checks that an authoritative answer yields an entry
// whose signature is the answer's RDATA.
func TestQueryAuthoritative(t *testing.T) {
	answer := mustRR(t, "ex. 60 IN A 5.6.7.8")
	addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(m)
		resp.Authoritative = true
		resp.Answer = []dns.RR{answer}
		_ = w.WriteMsg(resp)
	})
	defer stop()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA)
	if e == nil {
		t.Fatal("nil entry")
	}
	if e.sig != "5.6.7.8" {
		t.Errorf("sig = %q", e.sig)
	}
}

// TestQueryAuthoritative_NotAuthoritative checks that a response without the
// AA bit is ignored (nil entry).
func TestQueryAuthoritative_NotAuthoritative(t *testing.T) {
	answer := mustRR(t, "ex. 60 IN A 5.6.7.8")
	addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(m)
		resp.Authoritative = false
		resp.Answer = []dns.RR{answer}
		_ = w.WriteMsg(resp)
	})
	defer stop()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA); e != nil {
		t.Errorf("non-authoritative answer should be ignored, got %+v", e)
	}
}

// TestQueryAuthoritative_NXDOMAIN checks that an authoritative NXDOMAIN yields
// a non-nil entry with an empty signature.
func TestQueryAuthoritative_NXDOMAIN(t *testing.T) {
	addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(m)
		resp.Authoritative = true
		resp.Rcode = dns.RcodeNameError
		_ = w.WriteMsg(resp)
	})
	defer stop()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA)
	if e == nil {
		t.Fatal("want non-nil entry for NXDOMAIN")
	}
	if e.sig != "" {
		t.Errorf("NXDOMAIN should give empty sig: %q", e.sig)
	}
}

// TestExchangeUDP_TruncationFallsBackToTCP checks that a truncated UDP reply
// causes exchangeUDPOrTCP to retry over TCP at the same address.
func TestExchangeUDP_TruncationFallsBackToTCP(t *testing.T) {
	// UDP returns a truncated, empty reply; TCP on the same port returns the
	// full answer. A dns.Server serves a single network, so one UDP and one
	// TCP server are run, sharing the port.
	pcUDP, tcpL := listenUDPTCPPair(t)
	// Backstop closes in case a server fails to start. Shutdown also closes
	// these; the duplicate-close error is deliberately ignored.
	defer pcUDP.Close()
	defer tcpL.Close()
	addr := pcUDP.LocalAddr().String()

	answer := mustRR(t, "ex. 60 IN A 1.2.3.4")
	udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(m)
		resp.Truncated = true
		_ = w.WriteMsg(resp)
	})
	tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
		resp := new(dns.Msg)
		resp.SetReply(m)
		resp.Answer = []dns.RR{answer}
		_ = w.WriteMsg(resp)
	})

	stopUDP := serve(t, &dns.Server{PacketConn: pcUDP, Handler: udpHandler})
	defer stopUDP()
	stopTCP := serve(t, &dns.Server{Listener: tcpL, Handler: tcpHandler})
	defer stopTCP()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	m := new(dns.Msg)
	m.SetQuestion("ex.", dns.TypeA)
	res, err := exchangeUDPOrTCP(ctx, m, addr, "udp")
	if err != nil {
		t.Fatalf("exchange: %v", err)
	}
	if len(res.Answer) != 1 {
		t.Fatalf("expected TCP fallback to populate answer, got %v", res.Answer)
	}
}