package checker

import (
	"context"
	"encoding/json"
	"testing"

	sdk "git.happydns.org/checker-sdk-go/checker"
)

// makeOpts builds CheckerOptions carrying a TLSA service for apex. The
// service map mirrors the serialized form Collect reads: "_svctype",
// "_domain" and a "Service" object whose "tlsa" key holds records. The
// subdomain and STARTTLS override map are only set when non-empty/non-nil.
func makeOpts(t *testing.T, apex, subdomain string, records []map[string]any, starttls map[string]string) sdk.CheckerOptions {
	t.Helper()
	svc := map[string]any{
		"_svctype": serviceType,
		"_domain":  apex,
		"Service":  map[string]any{"tlsa": records},
	}
	opts := sdk.CheckerOptions{
		OptionDomain:  apex,
		OptionService: svc,
	}
	if subdomain != "" {
		opts[OptionSubdomain] = subdomain
	}
	if starttls != nil {
		opts[OptionSTARTTLS] = starttls
	}
	return opts
}

// tlsaRR returns a map shaped like a serialized TLSA resource record with
// the given owner name, usage, selector, matching type and hex certificate.
func tlsaRR(owner string, usage, selector, mtype int, cert string) map[string]any {
	return map[string]any{
		"Hdr":          map[string]any{"Name": owner},
		"Usage":        usage,
		"Selector":     selector,
		"MatchingType": mtype,
		"Certificate":  cert,
	}
}

// mustCollectDANE runs Collect on a fresh daneProvider and returns the result
// as *DANEData. It fails the test immediately on an error, an unexpected
// result type, or a nil result, rather than panicking on an unchecked type
// assertion later in the caller.
func mustCollectDANE(t *testing.T, opts sdk.CheckerOptions) *DANEData {
	t.Helper()
	out, err := (&daneProvider{}).Collect(context.Background(), opts)
	if err != nil {
		t.Fatalf("Collect: %v", err)
	}
	d, ok := out.(*DANEData)
	if !ok || d == nil {
		t.Fatalf("Collect returned %T (%v), want non-nil *DANEData", out, out)
	}
	return d
}

// TestCollect_GroupsByEndpoint checks that records sharing an owner are
// grouped into one target, targets are sorted by base name, and certificate
// hex is lowercased.
func TestCollect_GroupsByEndpoint(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com.", "", []map[string]any{
		tlsaRR("_443._tcp.example.com.", 3, 1, 1, "AABB"),
		tlsaRR("_443._tcp.example.com.", 3, 1, 1, "CCDD"),
		tlsaRR("_25._tcp.mail.example.com.", 3, 1, 1, "EEFF"),
	}, nil)
	d := mustCollectDANE(t, opts)
	if len(d.Targets) != 2 {
		t.Fatalf("targets=%d want 2", len(d.Targets))
	}
	// Sorted by base alphabetically: example.com < mail.example.com.
	if d.Targets[0].Host != "example.com" || d.Targets[0].Port != 443 {
		t.Errorf("sort[0]: %+v", d.Targets[0])
	}
	if d.Targets[1].Host != "mail.example.com" || d.Targets[1].Port != 25 {
		t.Errorf("sort[1]: %+v", d.Targets[1])
	}
	// Two records on the 443 endpoint. Fatal so that indexing Records[0]
	// below cannot panic on an empty slice.
	if len(d.Targets[0].Records) != 2 {
		t.Fatalf("443 records=%d want 2", len(d.Targets[0].Records))
	}
	// Certificate hex was lowercased
	if d.Targets[0].Records[0].Certificate != "aabb" {
		t.Errorf("expected lowercased cert, got %q", d.Targets[0].Records[0].Certificate)
	}
}

// TestCollect_DefaultSTARTTLS checks the STARTTLS protocol inferred from the
// port when no override is supplied.
func TestCollect_DefaultSTARTTLS(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_25._tcp.mail.example.com", 3, 1, 1, "00"),
		tlsaRR("_443._tcp.example.com", 3, 1, 1, "00"),
		tlsaRR("_587._tcp.mail.example.com", 3, 1, 1, "00"),
	}, nil)
	d := mustCollectDANE(t, opts)
	if len(d.Targets) != 3 {
		t.Fatalf("targets=%d want 3 (%+v)", len(d.Targets), d.Targets)
	}
	got := map[uint16]string{}
	// Named tgt (not t) so the *testing.T stays reachable inside the loop.
	for _, tgt := range d.Targets {
		got[tgt.Port] = tgt.STARTTLS
	}
	if got[25] != "smtp" {
		t.Errorf("port 25 starttls=%q want smtp", got[25])
	}
	if got[443] != "" {
		t.Errorf("port 443 starttls=%q want empty (direct TLS)", got[443])
	}
	if got[587] != "submission" {
		t.Errorf("port 587 starttls=%q want submission", got[587])
	}
}

// TestCollect_STARTTLSOverride checks that an explicit "port/proto" override
// replaces the port-based STARTTLS default.
func TestCollect_STARTTLSOverride(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_25._tcp.mail.example.com", 3, 1, 1, "00"),
	}, map[string]string{"25/tcp": "lmtp"})
	d := mustCollectDANE(t, opts)
	if len(d.Targets) != 1 {
		t.Fatalf("targets=%d want 1", len(d.Targets))
	}
	if d.Targets[0].STARTTLS != "lmtp" {
		t.Errorf("override: starttls=%q want lmtp", d.Targets[0].STARTTLS)
	}
}

// TestCollect_MalformedOwnerSurfaced checks that unparsable owners are
// reported in Invalid instead of aborting collection.
func TestCollect_MalformedOwnerSurfaced(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("totally-invalid", 3, 1, 1, "00"),
		tlsaRR("_99999._tcp.example.com", 3, 1, 1, "00"), // port > 65535
		tlsaRR("_443._tcp.example.com", 3, 1, 1, "AA"),
	}, nil)
	d := mustCollectDANE(t, opts)
	if len(d.Targets) != 1 {
		t.Errorf("expected one well-formed target, got %d", len(d.Targets))
	}
	if len(d.Invalid) != 2 {
		t.Errorf("expected 2 invalid entries, got %d (%+v)", len(d.Invalid), d.Invalid)
	}
}

// TestCollect_BaseRelativeToSubdomain checks that an owner without a base
// name resolves against the configured subdomain.
func TestCollect_BaseRelativeToSubdomain(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "mail", []map[string]any{
		// Owner has no base, so the records live on the subdomain itself.
		tlsaRR("_25._tcp", 3, 1, 1, "AA"),
	}, nil)
	d := mustCollectDANE(t, opts)
	if len(d.Targets) != 1 {
		t.Fatalf("targets=%d", len(d.Targets))
	}
	if d.Targets[0].Host != "mail.example.com" {
		t.Errorf("host=%q want mail.example.com", d.Targets[0].Host)
	}
	if d.Targets[0].Owner != "_25._tcp.mail.example.com" {
		t.Errorf("owner=%q", d.Targets[0].Owner)
	}
}

// TestCollect_WrongServiceType checks that a service of another type is
// rejected with an error.
func TestCollect_WrongServiceType(t *testing.T) {
	t.Parallel()
	svc := map[string]any{
		"_svctype": "svcs.NotTLSAs",
		"Service":  map[string]any{"tlsa": []any{}},
	}
	opts := sdk.CheckerOptions{OptionDomain: "example.com", OptionService: svc}
	if _, err := (&daneProvider{}).Collect(context.Background(), opts); err == nil {
		t.Error("expected error on wrong service type")
	}
}

// TestCollect_MissingService checks that omitting the service option is an
// error.
func TestCollect_MissingService(t *testing.T) {
	t.Parallel()
	opts := sdk.CheckerOptions{OptionDomain: "example.com"}
	if _, err := (&daneProvider{}).Collect(context.Background(), opts); err == nil {
		t.Error("expected error on missing service")
	}
}

// TestCollect_DiscoverEntries checks that DiscoverEntries yields one entry
// per collected target and tolerates nil or foreign input.
func TestCollect_DiscoverEntries(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_443._tcp.example.com", 3, 1, 1, "AA"),
		tlsaRR("_25._tcp.mail.example.com", 3, 1, 1, "BB"),
	}, nil)
	p := &daneProvider{}
	data, err := p.Collect(context.Background(), opts)
	if err != nil {
		t.Fatalf("Collect: %v", err)
	}
	entries, err := p.DiscoverEntries(data)
	if err != nil {
		t.Fatalf("err=%v", err)
	}
	if len(entries) != 2 {
		t.Errorf("entries=%d want 2", len(entries))
	}
	// Nil/wrong type returns nil, nil (defensive).
	if got, err := p.DiscoverEntries(nil); err != nil || got != nil {
		t.Errorf("nil: got=%v err=%v", got, err)
	}
	if got, err := p.DiscoverEntries("not a *DANEData"); err != nil || got != nil {
		t.Errorf("wrong type: got=%v err=%v", got, err)
	}
}

// TestCollect_DeterministicOutput checks that repeated collections over the
// same input serialize identical Targets.
func TestCollect_DeterministicOutput(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_25._tcp.b.example.com", 3, 1, 1, "AA"),
		tlsaRR("_25._tcp.a.example.com", 3, 1, 1, "BB"),
		tlsaRR("_443._tcp.a.example.com", 3, 1, 1, "CC"),
	}, nil)
	var prev []byte
	for i := range 3 {
		d := mustCollectDANE(t, opts)
		// Compare only Targets: CollectedAt is a wall-clock timestamp.
		b, err := json.Marshal(d.Targets)
		if err != nil {
			t.Fatalf("marshal targets: %v", err)
		}
		if i > 0 && string(b) != string(prev) {
			t.Errorf("non-deterministic targets:\n%s\nvs\n%s", prev, b)
		}
		prev = b
	}
}