checker-resolver-propagation/checker/collect_test.go

243 lines
7.1 KiB
Go

package checker
import (
"encoding/json"
"reflect"
"sort"
"testing"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// TestParseCSV checks that parseCSV splits its input on commas, trims the
// whitespace around each item, and drops empty items.
//
// Only the length and element values are compared. A nil result and an empty
// non-nil slice are therefore treated as equivalent, which is why the "" and
// ",,," cases may use different representations of "no items".
func TestParseCSV(t *testing.T) {
	cases := []struct {
		in   string
		want []string
	}{
		{"", nil},
		{"a", []string{"a"}},
		{"a,b,c", []string{"a", "b", "c"}},
		{" a , ,b ,", []string{"a", "b"}},
		{",,,", []string{}},
	}
	for _, c := range cases {
		got := parseCSV(c.in)
		same := len(got) == len(c.want)
		for i := 0; same && i < len(got); i++ {
			same = got[i] == c.want[i]
		}
		// Report the full slices on any mismatch. A bare length or a single
		// index alone does not show what parseCSV actually produced.
		if !same {
			t.Errorf("parseCSV(%q) = %q, want %q", c.in, got, c.want)
		}
	}
}
// TestParseQTypes verifies how parseQTypes maps a comma-separated list of
// record-type names to DNS type codes:
//   - names are matched case-insensitively;
//   - unknown names and repeats are dropped, and first-seen order is kept;
//   - nil is returned when nothing valid remains.
func TestParseQTypes(t *testing.T) {
	// "A" appears twice and "bogus" is not a type; neither may show up
	// in the result.
	mixed := parseQTypes("a,aaaa,MX,TxT,bogus,A")
	expected := []uint16{dns.TypeA, dns.TypeAAAA, dns.TypeMX, dns.TypeTXT}
	if !reflect.DeepEqual(mixed, expected) {
		t.Errorf("parseQTypes = %v, want %v", mixed, expected)
	}

	// Inputs with no valid type must yield a nil slice. label is how the
	// input is shown in the failure message.
	nilCases := []struct{ in, label string }{
		{"", `""`},
		{"nope,onlybad", "bad"},
	}
	for _, nc := range nilCases {
		if out := parseQTypes(nc.in); out != nil {
			t.Errorf("parseQTypes(%s) = %v, want nil", nc.label, out)
		}
	}
}
func TestQtypeNames(t *testing.T) {
got := qtypeNames([]uint16{dns.TypeA, dns.TypeMX})
want := []string{"A", "MX"}
if !reflect.DeepEqual(got, want) {
t.Errorf("qtypeNames = %v, want %v", got, want)
}
}
// TestJoinSubdomain covers how joinSubdomain builds a lower-cased,
// fully-qualified name from a subdomain label and a zone:
//   - an empty label or "@" means the zone apex;
//   - surrounding whitespace is trimmed;
//   - a label that is already fully qualified is used as-is and the zone is
//     ignored.
func TestJoinSubdomain(t *testing.T) {
	tests := [...]struct {
		label, origin, expected string
	}{
		{"", "example.com", "example.com."},
		{"@", "example.com.", "example.com."},
		{"www", "example.com", "www.example.com."},
		{"WWW", "Example.Com", "www.example.com."},
		{"foo.example.org.", "example.com", "foo.example.org."}, // already FQDN: zone ignored
		{" www ", "example.com", "www.example.com."},
	}
	for i := range tests {
		tc := &tests[i]
		got := joinSubdomain(tc.label, tc.origin)
		if got == tc.expected {
			continue
		}
		t.Errorf("joinSubdomain(%q,%q) = %q, want %q", tc.label, tc.origin, got, tc.expected)
	}
}
// TestExtractSerial checks that extractSerial reads the serial (the third
// whitespace-separated field) out of SOA rdata text. It must yield 0 for nil
// input, malformed text, a non-numeric serial, or a serial that overflows
// uint32.
func TestExtractSerial(t *testing.T) {
	type serialCase struct {
		rdata  []string
		serial uint32
	}
	for _, sc := range []serialCase{
		{nil, 0},
		{[]string{"ns. hostmaster. 2024010101 7200 3600 1209600 3600"}, 2024010101},
		{[]string{"too few fields"}, 0},
		{[]string{"ns. hm. notanumber 1 2 3 4"}, 0},
		{[]string{"ns. hm. 99999999999999999 1 2 3 4"}, 0}, // larger than uint32 max
	} {
		if got := extractSerial(sc.rdata); got != sc.serial {
			t.Errorf("extractSerial(%v) = %d, want %d", sc.rdata, got, sc.serial)
		}
	}
}
// TestFirstN checks that firstN joins at most n items with ", ". When items
// are left out, it must append a " (+K more)" suffix counting them. A nil
// slice yields the empty string.
//
// Each failure reports both the produced and the expected string, so it can
// be diagnosed without reading the test source.
func TestFirstN(t *testing.T) {
	if got, want := firstN([]string{"a", "b"}, 5), "a, b"; got != want {
		t.Errorf("under: firstN([a b], 5) = %q, want %q", got, want)
	}
	if got, want := firstN([]string{"a", "b", "c", "d"}, 2), "a, b (+2 more)"; got != want {
		t.Errorf("over: firstN([a b c d], 2) = %q, want %q", got, want)
	}
	if got, want := firstN(nil, 3), ""; got != want {
		t.Errorf("nil: firstN(nil, 3) = %q, want %q", got, want)
	}
}
// TestIsValidatingResolver checks which resolver IDs isValidatingResolver
// reports as validating. It also checks that an ID carrying a transport
// suffix (for example "cloudflare|tcp") is still recognised.
func TestIsValidatingResolver(t *testing.T) {
	expectations := []struct {
		id        string
		validates bool
	}{
		{"cloudflare", true},
		{"google", true},
		{"quad9", true},
		{"adguard", true},
		{"opendns", false},
		{"yandex", false},
		{"ntt-jp", false},
		{"", false},
	}
	for _, e := range expectations {
		switch got := isValidatingResolver(e.id); {
		case e.validates && !got:
			t.Errorf("%s should validate", e.id)
		case !e.validates && got:
			t.Errorf("%s should NOT validate", e.id)
		}
	}

	// The "|tcp" transport suffix must not hide a validating base resolver.
	if ok := isValidatingResolver("cloudflare|tcp"); !ok {
		t.Errorf("transport-suffixed ID should still validate")
	}
}
// TestComputeBasicStats checks the summary counters that computeBasicStats
// derives from a set of resolver views:
//   - the total number of resolvers;
//   - how many are reachable;
//   - the filtered/unfiltered split, which includes unreachable resolvers
//     ("b" is unreachable but filtered);
//   - the number of distinct regions.
func TestComputeBasicStats(t *testing.T) {
	data := &ResolverPropagationData{
		Resolvers: map[string]*ResolverView{
			"a": {Region: "eu", Reachable: true},
			"b": {Region: "eu", Reachable: false, Filtered: true},
			"c": {Region: "global", Reachable: true},
			"d": {Region: "na", Reachable: true, Filtered: true},
		},
	}
	s := computeBasicStats(data)
	if s.TotalResolvers != 4 {
		t.Errorf("TotalResolvers = %d, want 4", s.TotalResolvers)
	}
	if s.ReachableResolvers != 3 {
		t.Errorf("ReachableResolvers = %d, want 3", s.ReachableResolvers)
	}
	if s.FilteredProbed != 2 || s.UnfilteredProbed != 2 {
		t.Errorf("FilteredProbed = %d, UnfilteredProbed = %d, want 2 and 2",
			s.FilteredProbed, s.UnfilteredProbed)
	}
	// Three distinct regions are present: eu (twice), global, and na.
	if s.CountriesCovered != 3 {
		t.Errorf("CountriesCovered = %d, want 3 (regions eu, global, na)", s.CountriesCovered)
	}
}
// TestGetStringOpt checks that getStringOpt returns the option's value when
// it is set to a non-empty string. It must return the supplied default when
// the option is empty or missing.
func TestGetStringOpt(t *testing.T) {
	opts := sdk.CheckerOptions{"a": "x", "b": ""}
	if got := getStringOpt(opts, "a", "d"); got != "x" {
		t.Errorf(`getStringOpt(opts, "a", "d") = %q, want "x"`, got)
	}
	// An explicitly empty value is treated like an absent one.
	if got := getStringOpt(opts, "b", "d"); got != "d" {
		t.Errorf(`getStringOpt(opts, "b", "d") = %q, want "d" (default for empty value)`, got)
	}
	if got := getStringOpt(opts, "missing", "d"); got != "d" {
		t.Errorf(`getStringOpt(opts, "missing", "d") = %q, want "d" (default for missing key)`, got)
	}
}
// TestLoadService covers how loadService handles the "service" option:
//   - absent: no error, and an empty service (nil SOA, no name servers);
//   - Type other than "abstract.Origin": error;
//   - valid "abstract.Origin" payload: the SOA is decoded;
//   - empty Type: accepted;
//   - malformed Service JSON: error.
//
// All cases are independent, so failures are reported with t.Errorf and
// every case always runs.
func TestLoadService(t *testing.T) {
	// Missing service: tolerated (standalone / interactive use). Returns
	// an empty payload so collectExpected falls back to the system resolver.
	if svc, err := loadService(sdk.CheckerOptions{}); err != nil {
		t.Errorf("unexpected error for missing service: %v", err)
	} else if svc == nil || svc.SOA != nil || len(svc.NameServers) != 0 {
		t.Errorf("want empty service, got %+v", svc)
	}

	// Wrong type.
	bad := serviceMessage{Type: "abstract.NotOrigin", Service: json.RawMessage(`{}`)}
	if _, err := loadService(sdk.CheckerOptions{"service": bad}); err == nil {
		t.Errorf("want error for wrong service type")
	}

	// Valid Origin payload.
	msg := serviceMessage{
		Type:    "abstract.Origin",
		Service: json.RawMessage(`{"soa":{"Hdr":{"Name":"example.com.","Rrtype":6,"Class":1,"Ttl":3600},"Ns":"ns.example.com.","Mbox":"hm.example.com.","Serial":42,"Refresh":3600,"Retry":600,"Expire":86400,"Minttl":300}}`),
	}
	svc, err := loadService(sdk.CheckerOptions{"service": msg})
	switch {
	case err != nil:
		t.Errorf("unexpected error: %v", err)
	case svc == nil:
		// Without this guard, a nil service with a nil error would panic
		// on the SOA dereference below and hide the remaining cases.
		t.Errorf("want decoded service, got nil with no error")
	case svc.SOA == nil || svc.SOA.Serial != 42:
		t.Errorf("got SOA = %+v", svc.SOA)
	}

	// Empty type is accepted.
	emptyType := serviceMessage{Type: "", Service: json.RawMessage(`{}`)}
	if _, err := loadService(sdk.CheckerOptions{"service": emptyType}); err != nil {
		t.Errorf("empty type should be allowed: %v", err)
	}

	// Malformed JSON in Service.
	bad2 := serviceMessage{Type: "abstract.Origin", Service: json.RawMessage(`not-json`)}
	if _, err := loadService(sdk.CheckerOptions{"service": bad2}); err == nil {
		t.Errorf("want decode error")
	}
}
func TestLoadZone(t *testing.T) {
// From explicit option.
z, err := loadZone(sdk.CheckerOptions{"domain_name": "example.com"}, &originService{})
if err != nil || z != "example.com." {
t.Errorf("explicit: %q %v", z, err)
}
// Fallback to SOA header.
soa := &dns.SOA{Hdr: dns.RR_Header{Name: "fallback.test."}}
z, err = loadZone(sdk.CheckerOptions{}, &originService{SOA: soa})
if err != nil || z != "fallback.test." {
t.Errorf("fallback: %q %v", z, err)
}
// No source available.
if _, err := loadZone(sdk.CheckerOptions{}, &originService{}); err == nil {
t.Errorf("want error when nothing supplies a zone")
}
}
// TestNamesAreDeduplicated is a smoke test for the name-deduplication loop in
// Collect. It does not call Collect. Instead it rebuilds the names slice the
// way Collect is said to (zone apex first, then each extra subdomain joined
// onto the zone, skipping repeats). Keep it in sync with Collect.
// It checks that "@" (the apex) and a repeated "www" add no duplicates.
func TestNamesAreDeduplicated(t *testing.T) {
	apex := dns.Fqdn("example.com")
	names := []string{apex}
	already := map[string]struct{}{apex: {}}
	for _, sub := range []string{"@", "www", "www", "mail"} {
		fqdn := joinSubdomain(sub, apex)
		if _, dup := already[fqdn]; dup {
			continue
		}
		already[fqdn] = struct{}{}
		names = append(names, fqdn)
	}
	sort.Strings(names)
	if want := []string{"example.com.", "mail.example.com.", "www.example.com."}; !reflect.DeepEqual(names, want) {
		t.Errorf("names = %v, want %v", names, want)
	}
}