305 lines
8.5 KiB
Go
305 lines
8.5 KiB
Go
package checker
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
func mustRR(t *testing.T, s string) dns.RR {
|
|
t.Helper()
|
|
rr, err := dns.NewRR(s)
|
|
if err != nil {
|
|
t.Fatalf("dns.NewRR(%q): %v", s, err)
|
|
}
|
|
return rr
|
|
}
|
|
|
|
func TestCanonicalRR(t *testing.T) {
|
|
if got := canonicalRR(nil); got != "" {
|
|
t.Errorf("nil RR: want empty, got %q", got)
|
|
}
|
|
|
|
cases := []struct {
|
|
rr string
|
|
want string
|
|
}{
|
|
{"example.com. 300 IN A 192.0.2.1", "192.0.2.1"},
|
|
{"Example.Com. 300 IN NS Ns1.Example.Com.", "ns1.example.com."},
|
|
{"example.com. 60 IN MX 10 mail.example.com.", "10 mail.example.com."},
|
|
{"example.com. 30 IN TXT \"v=spf1 -all\"", "\"v=spf1 -all\""},
|
|
}
|
|
for _, c := range cases {
|
|
if got := canonicalRR(mustRR(t, c.rr)); got != c.want {
|
|
t.Errorf("canonicalRR(%q) = %q, want %q", c.rr, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSignatureFromRRs(t *testing.T) {
|
|
rrs := []dns.RR{
|
|
mustRR(t, "example.com. 300 IN A 192.0.2.2"),
|
|
mustRR(t, "example.com. 60 IN A 192.0.2.1"),
|
|
mustRR(t, "example.com. 300 IN AAAA 2001:db8::1"), // wrong type
|
|
mustRR(t, "other.example.com. 300 IN A 198.51.100.1"), // wrong owner
|
|
}
|
|
sig, recs, ttl := signatureFromRRs(rrs, "example.com", dns.TypeA)
|
|
if sig != "192.0.2.1|192.0.2.2" {
|
|
t.Errorf("sig = %q", sig)
|
|
}
|
|
if len(recs) != 2 || recs[0] != "192.0.2.1" || recs[1] != "192.0.2.2" {
|
|
t.Errorf("records = %v", recs)
|
|
}
|
|
if ttl != 60 {
|
|
t.Errorf("minTTL = %d, want 60", ttl)
|
|
}
|
|
|
|
// Owner case-insensitivity.
|
|
sig2, _, _ := signatureFromRRs(rrs, "EXAMPLE.com.", dns.TypeA)
|
|
if sig2 != sig {
|
|
t.Errorf("owner case sensitivity: %q vs %q", sig2, sig)
|
|
}
|
|
|
|
// Empty input.
|
|
if s, r, ttl := signatureFromRRs(nil, "x", dns.TypeA); s != "" || r != nil || ttl != 0 {
|
|
t.Errorf("empty input: %q %v %d", s, r, ttl)
|
|
}
|
|
}
|
|
|
|
func TestSignatureDeterministic(t *testing.T) {
|
|
a := []dns.RR{
|
|
mustRR(t, "x. 30 IN A 1.1.1.1"),
|
|
mustRR(t, "x. 30 IN A 2.2.2.2"),
|
|
}
|
|
b := []dns.RR{
|
|
mustRR(t, "x. 30 IN A 2.2.2.2"),
|
|
mustRR(t, "x. 30 IN A 1.1.1.1"),
|
|
}
|
|
sa, _, _ := signatureFromRRs(a, "x", dns.TypeA)
|
|
sb, _, _ := signatureFromRRs(b, "x", dns.TypeA)
|
|
if sa != sb {
|
|
t.Errorf("ordering changed sig: %q vs %q", sa, sb)
|
|
}
|
|
}
|
|
|
|
func TestRcodeToString(t *testing.T) {
|
|
cases := []struct {
|
|
in int
|
|
want string
|
|
}{
|
|
{dns.RcodeSuccess, "NOERROR"},
|
|
{dns.RcodeNameError, "NXDOMAIN"},
|
|
{dns.RcodeServerFailure, "SERVFAIL"},
|
|
{42, "RCODE42"},
|
|
}
|
|
for _, c := range cases {
|
|
if got := rcodeToString(c.in); got != c.want {
|
|
t.Errorf("rcodeToString(%d) = %q, want %q", c.in, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// startUDPServer brings up a tiny miekg/dns UDP server bound to a free port,
|
|
// returning its address and a stop func. The handler is called for every
|
|
// query and decides what to write back.
|
|
func startUDPServer(t *testing.T, handler dns.HandlerFunc) (string, func()) {
|
|
t.Helper()
|
|
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
srv := &dns.Server{PacketConn: pc, Handler: handler}
|
|
done := make(chan struct{})
|
|
go func() {
|
|
_ = srv.ActivateAndServe()
|
|
close(done)
|
|
}()
|
|
// give the server a moment
|
|
time.Sleep(20 * time.Millisecond)
|
|
return pc.LocalAddr().String(), func() {
|
|
_ = srv.Shutdown()
|
|
<-done
|
|
}
|
|
}
|
|
|
|
func TestExchangeUDPOrTCP_Success(t *testing.T) {
|
|
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
|
|
resp := new(dns.Msg)
|
|
resp.SetReply(m)
|
|
resp.Authoritative = true
|
|
resp.Answer = []dns.RR{mustRR(t, "example.com. 60 IN A 192.0.2.10")}
|
|
_ = w.WriteMsg(resp)
|
|
})
|
|
defer stop()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
res, err := exchangeUDPOrTCP(ctx, m, addr, "udp")
|
|
if err != nil {
|
|
t.Fatalf("exchange: %v", err)
|
|
}
|
|
if res.Rcode != dns.RcodeSuccess {
|
|
t.Errorf("rcode = %d", res.Rcode)
|
|
}
|
|
if len(res.Answer) != 1 {
|
|
t.Fatalf("answers: %v", res.Answer)
|
|
}
|
|
}
|
|
|
|
func TestQueryResolver_UnknownTransport(t *testing.T) {
|
|
_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, Transport("xyz"), "x.", dns.TypeA)
|
|
if err == nil || !strings.Contains(err.Error(), "unknown transport") {
|
|
t.Errorf("want unknown transport error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestQueryResolver_MissingDoTEndpoint(t *testing.T) {
|
|
_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoT, "x.", dns.TypeA)
|
|
if err == nil || !strings.Contains(err.Error(), "no DoT endpoint") {
|
|
t.Errorf("want missing DoT err, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestQueryResolver_MissingDoHEndpoint(t *testing.T) {
|
|
_, err := queryResolver(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoH, "x.", dns.TypeA)
|
|
if err == nil || !strings.Contains(err.Error(), "no DoH endpoint") {
|
|
t.Errorf("want missing DoH err, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRunProbe_TransportError(t *testing.T) {
|
|
// Missing DoT host on the resolver: queryResolver returns an error,
|
|
// runProbe converts it into RRProbe.Error.
|
|
p := runProbe(context.Background(), Resolver{ID: "x", IP: "127.0.0.1"}, TransportDoT, "ex.", dns.TypeA)
|
|
if p.Error == "" {
|
|
t.Errorf("expected error for missing DoT host")
|
|
}
|
|
if p.Transport != TransportDoT {
|
|
t.Errorf("transport = %v", p.Transport)
|
|
}
|
|
}
|
|
|
|
func TestQueryAuthoritative(t *testing.T) {
|
|
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
|
|
resp := new(dns.Msg)
|
|
resp.SetReply(m)
|
|
resp.Authoritative = true
|
|
resp.Answer = []dns.RR{mustRR(t, "ex. 60 IN A 5.6.7.8")}
|
|
_ = w.WriteMsg(resp)
|
|
})
|
|
defer stop()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA)
|
|
if e == nil {
|
|
t.Fatal("nil entry")
|
|
}
|
|
if e.sig != "5.6.7.8" {
|
|
t.Errorf("sig = %q", e.sig)
|
|
}
|
|
}
|
|
|
|
func TestQueryAuthoritative_NotAuthoritative(t *testing.T) {
|
|
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
|
|
resp := new(dns.Msg)
|
|
resp.SetReply(m)
|
|
resp.Authoritative = false
|
|
resp.Answer = []dns.RR{mustRR(t, "ex. 60 IN A 5.6.7.8")}
|
|
_ = w.WriteMsg(resp)
|
|
})
|
|
defer stop()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
if e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA); e != nil {
|
|
t.Errorf("non-authoritative answer should be ignored, got %+v", e)
|
|
}
|
|
}
|
|
|
|
func TestQueryAuthoritative_NXDOMAIN(t *testing.T) {
|
|
addr, stop := startUDPServer(t, func(w dns.ResponseWriter, m *dns.Msg) {
|
|
resp := new(dns.Msg)
|
|
resp.SetReply(m)
|
|
resp.Authoritative = true
|
|
resp.Rcode = dns.RcodeNameError
|
|
_ = w.WriteMsg(resp)
|
|
})
|
|
defer stop()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
e := queryAuthoritative(ctx, []string{addr}, "ex.", dns.TypeA)
|
|
if e == nil {
|
|
t.Fatal("want non-nil entry for NXDOMAIN")
|
|
}
|
|
if e.sig != "" {
|
|
t.Errorf("NXDOMAIN should give empty sig: %q", e.sig)
|
|
}
|
|
}
|
|
|
|
func TestExchangeUDP_TruncationFallsBackToTCP(t *testing.T) {
|
|
// UDP returns truncated; we also start a TCP listener that returns the full
|
|
// answer. miekg/dns ServeMux supports both via a single Server, but we
|
|
// keep it explicit here.
|
|
pcUDP, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("udp listen: %v", err)
|
|
}
|
|
defer pcUDP.Close()
|
|
addr := pcUDP.LocalAddr().String()
|
|
host, port, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
t.Fatalf("split: %v", err)
|
|
}
|
|
// TCP needs to share the same port; bind a TCP listener on it.
|
|
tcpL, err := net.Listen("tcp", net.JoinHostPort(host, port))
|
|
if err != nil {
|
|
t.Fatalf("tcp listen: %v", err)
|
|
}
|
|
defer tcpL.Close()
|
|
|
|
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
|
resp := new(dns.Msg)
|
|
resp.SetReply(m)
|
|
resp.Truncated = true
|
|
_ = w.WriteMsg(resp)
|
|
})
|
|
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
|
resp := new(dns.Msg)
|
|
resp.SetReply(m)
|
|
resp.Answer = []dns.RR{mustRR(t, "ex. 60 IN A 1.2.3.4")}
|
|
_ = w.WriteMsg(resp)
|
|
})
|
|
|
|
udpSrv := &dns.Server{PacketConn: pcUDP, Handler: udpHandler}
|
|
tcpSrv := &dns.Server{Listener: tcpL, Handler: tcpHandler}
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() { defer wg.Done(); _ = udpSrv.ActivateAndServe() }()
|
|
go func() { defer wg.Done(); _ = tcpSrv.ActivateAndServe() }()
|
|
defer func() {
|
|
_ = udpSrv.Shutdown()
|
|
_ = tcpSrv.Shutdown()
|
|
wg.Wait()
|
|
}()
|
|
time.Sleep(30 * time.Millisecond)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("ex.", dns.TypeA)
|
|
res, err := exchangeUDPOrTCP(ctx, m, addr, "udp")
|
|
if err != nil {
|
|
t.Fatalf("exchange: %v", err)
|
|
}
|
|
if len(res.Answer) != 1 {
|
|
t.Fatalf("expected TCP fallback to populate answer, got %v", res.Answer)
|
|
}
|
|
}
|