checker-resolver-propagation/checker/dns_test.go

305 lines
8.5 KiB
Go

package checker
import (
"context"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
)
// mustRR parses s, a resource record in zone-file presentation format, via
// dns.NewRR. If parsing fails it aborts the test with t.Fatalf, so it must
// only be called from the goroutine running the test.
func mustRR(t *testing.T, s string) dns.RR {
	t.Helper()
	parsed, parseErr := dns.NewRR(s)
	if parseErr == nil {
		return parsed
	}
	t.Fatalf("dns.NewRR(%q): %v", s, parseErr)
	return nil // unreachable: Fatalf stops the test goroutine
}
// TestCanonicalRR checks that canonicalRR returns "" for a nil record and the
// expected rdata string (with domain names lower-cased) for A, NS, MX and TXT
// records.
func TestCanonicalRR(t *testing.T) {
	if out := canonicalRR(nil); len(out) != 0 {
		t.Errorf("nil RR: want empty, got %q", out)
	}

	table := [...]struct{ input, expected string }{
		{input: "example.com. 300 IN A 192.0.2.1", expected: "192.0.2.1"},
		{input: "Example.Com. 300 IN NS Ns1.Example.Com.", expected: "ns1.example.com."},
		{input: "example.com. 60 IN MX 10 mail.example.com.", expected: "10 mail.example.com."},
		{input: `example.com. 30 IN TXT "v=spf1 -all"`, expected: `"v=spf1 -all"`},
	}
	for i := range table {
		tc := table[i]
		out := canonicalRR(mustRR(t, tc.input))
		if out == tc.expected {
			continue
		}
		t.Errorf("canonicalRR(%q) = %q, want %q", tc.input, out, tc.expected)
	}
}
// TestSignatureFromRRs checks that signatureFromRRs keeps only records that
// match the requested owner and type, returns them sorted and joined with
// "|", reports the minimum TTL, matches the owner case-insensitively, and
// returns zero values for empty input.
func TestSignatureFromRRs(t *testing.T) {
	input := []dns.RR{
		mustRR(t, "example.com. 300 IN A 192.0.2.2"),
		mustRR(t, "example.com. 60 IN A 192.0.2.1"),
		mustRR(t, "example.com. 300 IN AAAA 2001:db8::1"),     // filtered out: type mismatch
		mustRR(t, "other.example.com. 300 IN A 198.51.100.1"), // filtered out: owner mismatch
	}

	gotSig, gotRecs, gotTTL := signatureFromRRs(input, "example.com", dns.TypeA)
	const wantSig = "192.0.2.1|192.0.2.2"
	if gotSig != wantSig {
		t.Errorf("sig = %q", gotSig)
	}
	recsOK := len(gotRecs) == 2 && gotRecs[0] == "192.0.2.1" && gotRecs[1] == "192.0.2.2"
	if !recsOK {
		t.Errorf("records = %v", gotRecs)
	}
	if gotTTL != 60 {
		t.Errorf("minTTL = %d, want 60", gotTTL)
	}

	// The owner name is expected to match regardless of case and trailing dot.
	upperSig, _, _ := signatureFromRRs(input, "EXAMPLE.com.", dns.TypeA)
	if upperSig != gotSig {
		t.Errorf("owner case sensitivity: %q vs %q", upperSig, gotSig)
	}

	// No records in, zero values out.
	emptySig, emptyRecs, emptyTTL := signatureFromRRs(nil, "x", dns.TypeA)
	if emptySig != "" || emptyRecs != nil || emptyTTL != 0 {
		t.Errorf("empty input: %q %v %d", emptySig, emptyRecs, emptyTTL)
	}
}
// TestSignatureDeterministic checks that the signature does not depend on the
// order in which matching records are supplied.
func TestSignatureDeterministic(t *testing.T) {
	lines := []string{"x. 30 IN A 1.1.1.1", "x. 30 IN A 2.2.2.2"}
	parseAll := func(in ...string) []dns.RR {
		out := make([]dns.RR, 0, len(in))
		for _, l := range in {
			out = append(out, mustRR(t, l))
		}
		return out
	}

	first, _, _ := signatureFromRRs(parseAll(lines[0], lines[1]), "x", dns.TypeA)
	second, _, _ := signatureFromRRs(parseAll(lines[1], lines[0]), "x", dns.TypeA)
	if first != second {
		t.Errorf("ordering changed sig: %q vs %q", first, second)
	}
}
// TestRcodeToString checks the mnemonic names for known rcodes and the
// "RCODE<n>" fallback for an unassigned value.
func TestRcodeToString(t *testing.T) {
	expectations := []struct {
		code int
		name string
	}{
		{code: dns.RcodeSuccess, name: "NOERROR"},
		{code: dns.RcodeNameError, name: "NXDOMAIN"},
		{code: dns.RcodeServerFailure, name: "SERVFAIL"},
		{code: 42, name: "RCODE42"},
	}
	for _, e := range expectations {
		got := rcodeToString(e.code)
		if got == e.name {
			continue
		}
		t.Errorf("rcodeToString(%d) = %q, want %q", e.code, got, e.name)
	}
}
// startUDPServer brings up a tiny miekg/dns UDP server bound to a free port,
// returning its address and a stop func. The handler is called for every
// query and decides what to write back.
// startUDPServer brings up a tiny miekg/dns UDP server on a free loopback
// port and returns its address plus a stop func. handler is invoked for every
// query and decides what to write back.
//
// It blocks until the server reports (via NotifyStartedFunc) that it is
// serving, instead of sleeping for a fixed time. A fixed sleep can lose the
// race: Shutdown on a not-yet-started server fails, and waiting for the serve
// goroutine then hangs forever. If the server exits or does not start within
// 2s, the test fails.
//
// handler runs on the server's goroutines, so it must not call t.Fatal*
// (including via mustRR).
func startUDPServer(t *testing.T, handler dns.HandlerFunc) (string, func()) {
	t.Helper()
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}

	started := make(chan struct{})
	srv := &dns.Server{
		PacketConn:        pc,
		Handler:           handler,
		NotifyStartedFunc: func() { close(started) },
	}
	// Buffered so the serve goroutine can always deliver its result and exit,
	// even if nobody is waiting any more.
	served := make(chan error, 1)
	go func() { served <- srv.ActivateAndServe() }()

	select {
	case <-started:
	case err := <-served:
		_ = pc.Close()
		t.Fatalf("dns server exited before starting: %v", err)
	case <-time.After(2 * time.Second):
		_ = pc.Close()
		t.Fatal("dns server did not start within 2s")
	}

	return pc.LocalAddr().String(), func() {
		if err := srv.Shutdown(); err != nil {
			// The serve goroutine was not stopped; waiting on it would hang.
			t.Errorf("dns server shutdown: %v", err)
			return
		}
		<-served
	}
}
func TestExchangeUDPOrTCP_Success(t *testing.T) {
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(m)
resp.Authoritative = true
resp.Answer = []dns.RR{mustRR(t, "example.com. 60 IN A 192.0.2.10")}
_ = w.WriteMsg(resp)
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
res, err := exchangeUDPOrTCP(ctx, m, addr, "udp")
if err != nil {
t.Fatalf("exchange: %v", err)
}
if res.Rcode != dns.RcodeSuccess {
t.Errorf("rcode = %d", res.Rcode)
}
if len(res.Answer) != 1 {
t.Fatalf("answers: %v", res.Answer)
}
}
func TestQueryResolver_UnknownTransport(t *testing.T) {
_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, Transport("xyz"), "x.", dns.TypeA)
if err == nil || !strings.Contains(err.Error(), "unknown transport") {
t.Errorf("want unknown transport error, got %v", err)
}
}
func TestQueryResolver_MissingDoTEndpoint(t *testing.T) {
_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoT, "x.", dns.TypeA)
if err == nil || !strings.Contains(err.Error(), "no DoT endpoint") {
t.Errorf("want missing DoT err, got %v", err)
}
}
func TestQueryResolver_MissingDoHEndpoint(t *testing.T) {
_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoH, "x.", dns.TypeA)
if err == nil || !strings.Contains(err.Error(), "no DoH endpoint") {
t.Errorf("want missing DoH err, got %v", err)
}
}
func TestRunProbe_TransportError(t *testing.T) {
// Missing DoT host on the resolver: queryResolver returns an error,
// runProbe converts it into RRProbe.Error.
p := runProbe(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoT, "ex.", dns.TypeA)
if p.Error == "" {
t.Errorf("expected error for missing DoT host")
}
if p.Transport != TransportDoT {
t.Errorf("transport = %v", p.Transport)
}
}
func TestQueryAuthoritative(t *testing.T) {
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(m)
resp.Authoritative = true
resp.Answer = []dns.RR{mustRR(t, "ex. 60 IN A 5.6.7.8")}
_ = w.WriteMsg(resp)
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA)
if e == nil {
t.Fatal("nil entry")
}
if e.sig != "5.6.7.8" {
t.Errorf("sig = %q", e.sig)
}
}
func TestQueryAuthoritative_NotAuthoritative(t *testing.T) {
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(m)
resp.Authoritative = false
resp.Answer = []dns.RR{mustRR(t, "ex. 60 IN A 5.6.7.8")}
_ = w.WriteMsg(resp)
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA); e != nil {
t.Errorf("non-authoritative answer should be ignored, got %+v", e)
}
}
func TestQueryAuthoritative_NXDOMAIN(t *testing.T) {
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(m)
resp.Authoritative = true
resp.Rcode = dns.RcodeNameError
_ = w.WriteMsg(resp)
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA)
if e == nil {
t.Fatal("want non-nil entry for NXDOMAIN")
}
if e.sig != "" {
t.Errorf("NXDOMAIN should give empty sig: %q", e.sig)
}
}
func TestExchangeUDP_TruncationFallsBackToTCP(t *testing.T) {
// UDP returns truncated; we also start a TCP listener that returns the full
// answer. miekg/dns ServeMux supports both via a single Server, but we
// keep it explicit here.
pcUDP, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("udp listen: %v", err)
}
defer pcUDP.Close()
addr := pcUDP.LocalAddr().String()
host, port, err := net.SplitHostPort(addr)
if err != nil {
t.Fatalf("split: %v", err)
}
// TCP needs to share the same port; bind a TCP listener on it.
tcpL, err := net.Listen("tcp", net.JoinHostPort(host, port))
if err != nil {
t.Fatalf("tcp listen: %v", err)
}
defer tcpL.Close()
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(m)
resp.Truncated = true
_ = w.WriteMsg(resp)
})
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(m)
resp.Answer = []dns.RR{mustRR(t, "ex. 60 IN A 1.2.3.4")}
_ = w.WriteMsg(resp)
})
udpSrv := &dns.Server{PacketConn: pcUDP, Handler: udpHandler}
tcpSrv := &dns.Server{Listener: tcpL, Handler: tcpHandler}
var wg sync.WaitGroup
wg.Add(2)
go func() { defer wg.Done(); _ = udpSrv.ActivateAndServe() }()
go func() { defer wg.Done(); _ = tcpSrv.ActivateAndServe() }()
defer func() {
_ = udpSrv.Shutdown()
_ = tcpSrv.Shutdown()
wg.Wait()
}()
time.Sleep(30 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
m := new(dns.Msg)
m.SetQuestion("ex.", dns.TypeA)
res, err := exchangeUDPOrTCP(ctx, m, addr, "udp")
if err != nil {
t.Fatalf("exchange: %v", err)
}
if len(res.Answer) != 1 {
t.Fatalf("expected TCP fallback to populate answer, got %v", res.Answer)
}
}