243 lines
7.1 KiB
Go
243 lines
7.1 KiB
Go
package checker
|
|
|
|
import (
|
|
"encoding/json"
|
|
"reflect"
|
|
"sort"
|
|
"testing"
|
|
|
|
"github.com/miekg/dns"
|
|
|
|
sdk "git.happydns.org/checker-sdk-go/checker"
|
|
)
|
|
|
|
func TestParseCSV(t *testing.T) {
|
|
cases := []struct {
|
|
in string
|
|
want []string
|
|
}{
|
|
{"", nil},
|
|
{"a", []string{"a"}},
|
|
{"a,b,c", []string{"a", "b", "c"}},
|
|
{" a , ,b ,", []string{"a", "b"}},
|
|
{",,,", []string{}},
|
|
}
|
|
for _, c := range cases {
|
|
got := parseCSV(c.in)
|
|
if len(got) != len(c.want) {
|
|
t.Errorf("parseCSV(%q) len = %d, want %d", c.in, len(got), len(c.want))
|
|
continue
|
|
}
|
|
for i := range got {
|
|
if got[i] != c.want[i] {
|
|
t.Errorf("parseCSV(%q)[%d] = %q, want %q", c.in, i, got[i], c.want[i])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseQTypes(t *testing.T) {
|
|
got := parseQTypes("a,aaaa,MX,TxT,bogus,A") // A duplicated; bogus skipped
|
|
want := []uint16{dns.TypeA, dns.TypeAAAA, dns.TypeMX, dns.TypeTXT}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("parseQTypes = %v, want %v", got, want)
|
|
}
|
|
|
|
if got := parseQTypes(""); got != nil {
|
|
t.Errorf("parseQTypes(\"\") = %v, want nil", got)
|
|
}
|
|
if got := parseQTypes("nope,onlybad"); got != nil {
|
|
t.Errorf("parseQTypes(bad) = %v, want nil", got)
|
|
}
|
|
}
|
|
|
|
func TestQtypeNames(t *testing.T) {
|
|
got := qtypeNames([]uint16{dns.TypeA, dns.TypeMX})
|
|
want := []string{"A", "MX"}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("qtypeNames = %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestJoinSubdomain(t *testing.T) {
|
|
cases := []struct {
|
|
sd, zone, want string
|
|
}{
|
|
{"", "example.com", "example.com."},
|
|
{"@", "example.com.", "example.com."},
|
|
{"www", "example.com", "www.example.com."},
|
|
{"WWW", "Example.Com", "www.example.com."},
|
|
{"foo.example.org.", "example.com", "foo.example.org."}, // already FQDN: used as-is
|
|
{" www ", "example.com", "www.example.com."},
|
|
}
|
|
for _, c := range cases {
|
|
if got := joinSubdomain(c.sd, c.zone); got != c.want {
|
|
t.Errorf("joinSubdomain(%q,%q) = %q, want %q", c.sd, c.zone, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExtractSerial(t *testing.T) {
|
|
cases := []struct {
|
|
in []string
|
|
want uint32
|
|
}{
|
|
{nil, 0},
|
|
{[]string{"ns. hostmaster. 2024010101 7200 3600 1209600 3600"}, 2024010101},
|
|
{[]string{"too few fields"}, 0},
|
|
{[]string{"ns. hm. notanumber 1 2 3 4"}, 0},
|
|
{[]string{"ns. hm. 99999999999999999 1 2 3 4"}, 0}, // overflow uint32
|
|
}
|
|
for _, c := range cases {
|
|
if got := extractSerial(c.in); got != c.want {
|
|
t.Errorf("extractSerial(%v) = %d, want %d", c.in, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFirstN(t *testing.T) {
|
|
if got := firstN([]string{"a", "b"}, 5); got != "a, b" {
|
|
t.Errorf("under: %q", got)
|
|
}
|
|
if got := firstN([]string{"a", "b", "c", "d"}, 2); got != "a, b (+2 more)" {
|
|
t.Errorf("over: %q", got)
|
|
}
|
|
if got := firstN(nil, 3); got != "" {
|
|
t.Errorf("nil: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestIsValidatingResolver(t *testing.T) {
|
|
for _, id := range []string{"cloudflare", "google", "quad9", "adguard"} {
|
|
if !isValidatingResolver(id) {
|
|
t.Errorf("%s should validate", id)
|
|
}
|
|
}
|
|
for _, id := range []string{"opendns", "yandex", "ntt-jp", ""} {
|
|
if isValidatingResolver(id) {
|
|
t.Errorf("%s should NOT validate", id)
|
|
}
|
|
}
|
|
// transport-suffixed IDs (e.g. "cloudflare|tcp") should still match.
|
|
if !isValidatingResolver("cloudflare|tcp") {
|
|
t.Errorf("transport-suffixed ID should still validate")
|
|
}
|
|
}
|
|
|
|
func TestComputeBasicStats(t *testing.T) {
|
|
data := &ResolverPropagationData{
|
|
Resolvers: map[string]*ResolverView{
|
|
"a": {Region: "eu", Reachable: true},
|
|
"b": {Region: "eu", Reachable: false, Filtered: true},
|
|
"c": {Region: "global", Reachable: true},
|
|
"d": {Region: "na", Reachable: true, Filtered: true},
|
|
},
|
|
}
|
|
s := computeBasicStats(data)
|
|
if s.TotalResolvers != 4 {
|
|
t.Errorf("total = %d", s.TotalResolvers)
|
|
}
|
|
if s.ReachableResolvers != 3 {
|
|
t.Errorf("reachable = %d", s.ReachableResolvers)
|
|
}
|
|
if s.FilteredProbed != 2 || s.UnfilteredProbed != 2 {
|
|
t.Errorf("split filtered=%d unfiltered=%d", s.FilteredProbed, s.UnfilteredProbed)
|
|
}
|
|
if s.CountriesCovered != 3 {
|
|
t.Errorf("regions = %d", s.CountriesCovered)
|
|
}
|
|
}
|
|
|
|
func TestGetStringOpt(t *testing.T) {
|
|
opts := sdk.CheckerOptions{"a": "x", "b": ""}
|
|
if got := getStringOpt(opts, "a", "d"); got != "x" {
|
|
t.Errorf("a = %q", got)
|
|
}
|
|
if got := getStringOpt(opts, "b", "d"); got != "d" {
|
|
t.Errorf("b = %q", got)
|
|
}
|
|
if got := getStringOpt(opts, "missing", "d"); got != "d" {
|
|
t.Errorf("missing = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestLoadService(t *testing.T) {
|
|
// Missing service: tolerated (standalone / interactive use). Returns
|
|
// an empty payload so collectExpected falls back to the system resolver.
|
|
if svc, err := loadService(sdk.CheckerOptions{}); err != nil {
|
|
t.Errorf("unexpected error for missing service: %v", err)
|
|
} else if svc == nil || svc.SOA != nil || len(svc.NameServers) != 0 {
|
|
t.Errorf("want empty service, got %+v", svc)
|
|
}
|
|
|
|
// Wrong type.
|
|
bad := serviceMessage{Type: "abstract.NotOrigin", Service: json.RawMessage(`{}`)}
|
|
if _, err := loadService(sdk.CheckerOptions{"service": bad}); err == nil {
|
|
t.Errorf("want error for wrong service type")
|
|
}
|
|
|
|
// Valid Origin payload.
|
|
msg := serviceMessage{
|
|
Type: "abstract.Origin",
|
|
Service: json.RawMessage(`{"soa":{"Hdr":{"Name":"example.com.","Rrtype":6,"Class":1,"Ttl":3600},"Ns":"ns.example.com.","Mbox":"hm.example.com.","Serial":42,"Refresh":3600,"Retry":600,"Expire":86400,"Minttl":300}}`),
|
|
}
|
|
svc, err := loadService(sdk.CheckerOptions{"service": msg})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if svc.SOA == nil || svc.SOA.Serial != 42 {
|
|
t.Errorf("got SOA = %+v", svc.SOA)
|
|
}
|
|
|
|
// Empty type is accepted.
|
|
emptyType := serviceMessage{Type: "", Service: json.RawMessage(`{}`)}
|
|
if _, err := loadService(sdk.CheckerOptions{"service": emptyType}); err != nil {
|
|
t.Errorf("empty type should be allowed: %v", err)
|
|
}
|
|
|
|
// Malformed JSON in Service.
|
|
bad2 := serviceMessage{Type: "abstract.Origin", Service: json.RawMessage(`not-json`)}
|
|
if _, err := loadService(sdk.CheckerOptions{"service": bad2}); err == nil {
|
|
t.Errorf("want decode error")
|
|
}
|
|
}
|
|
|
|
func TestLoadZone(t *testing.T) {
|
|
// From explicit option.
|
|
z, err := loadZone(sdk.CheckerOptions{"domain_name": "example.com"}, &originService{})
|
|
if err != nil || z != "example.com." {
|
|
t.Errorf("explicit: %q %v", z, err)
|
|
}
|
|
|
|
// Fallback to SOA header.
|
|
soa := &dns.SOA{Hdr: dns.RR_Header{Name: "fallback.test."}}
|
|
z, err = loadZone(sdk.CheckerOptions{}, &originService{SOA: soa})
|
|
if err != nil || z != "fallback.test." {
|
|
t.Errorf("fallback: %q %v", z, err)
|
|
}
|
|
|
|
// No source available.
|
|
if _, err := loadZone(sdk.CheckerOptions{}, &originService{}); err == nil {
|
|
t.Errorf("want error when nothing supplies a zone")
|
|
}
|
|
}
|
|
|
|
func TestNamesAreDeduplicated(t *testing.T) {
|
|
// Smoke test for the dedup loop in Collect: build the same names slice
|
|
// the way Collect does and confirm extras don't double-up.
|
|
zone := dns.Fqdn("example.com")
|
|
names := []string{zone}
|
|
seen := map[string]bool{names[0]: true}
|
|
for _, sd := range []string{"@", "www", "www", "mail"} {
|
|
full := joinSubdomain(sd, zone)
|
|
if !seen[full] {
|
|
seen[full] = true
|
|
names = append(names, full)
|
|
}
|
|
}
|
|
sort.Strings(names)
|
|
want := []string{"example.com.", "mail.example.com.", "www.example.com."}
|
|
if !reflect.DeepEqual(names, want) {
|
|
t.Errorf("names = %v, want %v", names, want)
|
|
}
|
|
}
|