// Source: checker-dane/checker/collect_test.go (226 lines, 6.5 KiB, Go).

package checker
import (
"context"
"encoding/json"
"testing"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// makeOpts assembles the CheckerOptions handed to Collect in these tests.
//
// The service payload is a generic map tagged with serviceType whose
// "Service" entry carries the given TLSA records under "tlsa". The subdomain
// option is set only when subdomain is non-empty, and the STARTTLS override
// map only when starttls is non-nil.
func makeOpts(t *testing.T, apex, subdomain string, records []map[string]any, starttls map[string]string) sdk.CheckerOptions {
	t.Helper()
	payload := map[string]any{
		"_svctype": serviceType,
		"_domain":  apex,
		"Service":  map[string]any{"tlsa": records},
	}
	result := sdk.CheckerOptions{
		OptionDomain:  apex,
		OptionService: payload,
	}
	if len(subdomain) > 0 {
		result[OptionSubdomain] = subdomain
	}
	// A nil check rather than len(): an explicitly empty override map is
	// still passed through.
	if starttls == nil {
		return result
	}
	result[OptionSTARTTLS] = starttls
	return result
}
// tlsaRR returns one TLSA record encoded as a generic map, the form the
// service payload built by makeOpts carries. The owner name goes under
// Hdr.Name. The remaining arguments map 1:1 onto Usage, Selector,
// MatchingType and Certificate.
func tlsaRR(owner string, usage, selector, mtype int, cert string) map[string]any {
	rr := make(map[string]any, 5)
	rr["Hdr"] = map[string]any{"Name": owner}
	rr["Usage"] = usage
	rr["Selector"] = selector
	rr["MatchingType"] = mtype
	rr["Certificate"] = cert
	return rr
}
// TestCollect_GroupsByEndpoint checks three things about Collect:
//   - TLSA records sharing a _port._proto owner are grouped into one target.
//   - Targets are ordered by host.
//   - Certificate association data is lowercased.
func TestCollect_GroupsByEndpoint(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com.", "", []map[string]any{
		tlsaRR("_443._tcp.example.com.", 3, 1, 1, "AABB"),
		tlsaRR("_443._tcp.example.com.", 3, 1, 1, "CCDD"),
		tlsaRR("_25._tcp.mail.example.com.", 3, 1, 1, "EEFF"),
	}, nil)
	p := &daneProvider{}
	out, err := p.Collect(context.Background(), opts)
	if err != nil {
		t.Fatalf("err=%v", err)
	}
	// Checked assertion: a panic here would abort the whole test binary
	// rather than fail just this test.
	d, ok := out.(*DANEData)
	if !ok {
		t.Fatalf("Collect returned %T, want *DANEData", out)
	}
	if len(d.Targets) != 2 {
		t.Fatalf("targets=%d want 2", len(d.Targets))
	}
	// Sorted by base alphabetically: example.com < mail.example.com.
	if d.Targets[0].Host != "example.com" || d.Targets[0].Port != 443 {
		t.Errorf("sort[0]: %+v", d.Targets[0])
	}
	if d.Targets[1].Host != "mail.example.com" || d.Targets[1].Port != 25 {
		t.Errorf("sort[1]: %+v", d.Targets[1])
	}
	// Two records on the 443 endpoint. This must be fatal: the lowercase
	// check below indexes Records[0], which panics on an empty slice.
	if len(d.Targets[0].Records) != 2 {
		t.Fatalf("443 records=%d want 2", len(d.Targets[0].Records))
	}
	// Certificate hex was lowercased.
	if d.Targets[0].Records[0].Certificate != "aabb" {
		t.Errorf("expected lowercased cert, got %q", d.Targets[0].Records[0].Certificate)
	}
}
// TestCollect_DefaultSTARTTLS checks the STARTTLS mode Collect assigns per
// port when no override is given:
//   - 25 → "smtp"
//   - 587 → "submission"
//   - 443 → "" (direct TLS)
func TestCollect_DefaultSTARTTLS(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_25._tcp.mail.example.com", 3, 1, 1, "00"),
		tlsaRR("_443._tcp.example.com", 3, 1, 1, "00"),
		tlsaRR("_587._tcp.mail.example.com", 3, 1, 1, "00"),
	}, nil)
	out, err := (&daneProvider{}).Collect(context.Background(), opts)
	if err != nil {
		t.Fatal(err)
	}
	d, ok := out.(*DANEData)
	if !ok {
		t.Fatalf("Collect returned %T, want *DANEData", out)
	}
	// Index modes by port. The loop variable is tgt so it does not shadow
	// the *testing.T.
	got := make(map[uint16]string, len(d.Targets))
	for _, tgt := range d.Targets {
		got[tgt.Port] = tgt.STARTTLS
	}
	cases := []struct {
		port uint16
		want string
	}{
		{25, "smtp"},
		{443, ""}, // direct TLS: no STARTTLS upgrade
		{587, "submission"},
	}
	for _, tc := range cases {
		// Check presence first. Otherwise a missing 443 target reads back
		// as "" and passes the direct-TLS check vacuously.
		mode, found := got[tc.port]
		if !found {
			t.Errorf("port %d: no target collected", tc.port)
			continue
		}
		if mode != tc.want {
			t.Errorf("port %d starttls=%q want %q", tc.port, mode, tc.want)
		}
	}
}
// TestCollect_STARTTLSOverride checks that an OptionSTARTTLS entry keyed
// "port/proto" replaces the default mode for that endpoint. Port 25 would
// otherwise default to "smtp".
func TestCollect_STARTTLSOverride(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_25._tcp.mail.example.com", 3, 1, 1, "00"),
	}, map[string]string{"25/tcp": "lmtp"})
	out, err := (&daneProvider{}).Collect(context.Background(), opts)
	if err != nil {
		t.Fatal(err)
	}
	d, ok := out.(*DANEData)
	if !ok {
		t.Fatalf("Collect returned %T, want *DANEData", out)
	}
	// Guard the index below: an empty Targets slice would otherwise panic
	// and abort the whole test binary instead of failing this test.
	if len(d.Targets) != 1 {
		t.Fatalf("targets=%d want 1", len(d.Targets))
	}
	if d.Targets[0].STARTTLS != "lmtp" {
		t.Errorf("override: starttls=%q want lmtp", d.Targets[0].STARTTLS)
	}
}
// TestCollect_MalformedOwnerSurfaced checks how Collect handles TLSA records
// whose owner name cannot be parsed into an endpoint. Such records are
// reported in Invalid rather than dropped, and they do not prevent
// well-formed records from becoming targets.
func TestCollect_MalformedOwnerSurfaced(t *testing.T) {
	t.Parallel()
	records := []map[string]any{
		tlsaRR("totally-invalid", 3, 1, 1, "00"),         // no _port._proto labels at all
		tlsaRR("_99999._tcp.example.com", 3, 1, 1, "00"), // port > 65535
		tlsaRR("_443._tcp.example.com", 3, 1, 1, "AA"),
	}
	out, err := (&daneProvider{}).Collect(context.Background(), makeOpts(t, "example.com", "", records, nil))
	if err != nil {
		t.Fatal(err)
	}
	data := out.(*DANEData)
	if n := len(data.Targets); n != 1 {
		t.Errorf("expected one well-formed target, got %d", n)
	}
	if n := len(data.Invalid); n != 2 {
		t.Errorf("expected 2 invalid entries, got %d (%+v)", n, data.Invalid)
	}
}
// TestCollect_BaseRelativeToSubdomain checks that a TLSA owner made only of
// _port._proto labels is resolved against the service subdomain. Both the
// target host and the owner name must therefore include "mail".
func TestCollect_BaseRelativeToSubdomain(t *testing.T) {
	t.Parallel()
	records := []map[string]any{
		// Owner has no base, so the records live on the subdomain itself.
		tlsaRR("_25._tcp", 3, 1, 1, "AA"),
	}
	out, err := (&daneProvider{}).Collect(context.Background(), makeOpts(t, "example.com", "mail", records, nil))
	if err != nil {
		t.Fatal(err)
	}
	targets := out.(*DANEData).Targets
	if len(targets) != 1 {
		t.Fatalf("targets=%d", len(targets))
	}
	if host := targets[0].Host; host != "mail.example.com" {
		t.Errorf("host=%q want mail.example.com", host)
	}
	if owner := targets[0].Owner; owner != "_25._tcp.mail.example.com" {
		t.Errorf("owner=%q", owner)
	}
}
// TestCollect_WrongServiceType checks that Collect rejects a service payload
// whose _svctype is not the TLSA service type.
func TestCollect_WrongServiceType(t *testing.T) {
	t.Parallel()
	opts := sdk.CheckerOptions{
		OptionDomain: "example.com",
		OptionService: map[string]any{
			"_svctype": "svcs.NotTLSAs",
			"Service":  map[string]any{"tlsa": []any{}},
		},
	}
	_, err := (&daneProvider{}).Collect(context.Background(), opts)
	if err == nil {
		t.Error("expected error on wrong service type")
	}
}
// TestCollect_MissingService checks that Collect returns an error when the
// options carry a domain but no service payload.
func TestCollect_MissingService(t *testing.T) {
	t.Parallel()
	provider := &daneProvider{}
	_, err := provider.Collect(context.Background(), sdk.CheckerOptions{OptionDomain: "example.com"})
	if err == nil {
		t.Error("expected error on missing service")
	}
}
// TestCollect_DiscoverEntries checks that DiscoverEntries yields one entry per
// collected target. It also checks that DiscoverEntries returns (nil, nil)
// for input that is not a *DANEData, including a nil input.
func TestCollect_DiscoverEntries(t *testing.T) {
	t.Parallel()
	p := &daneProvider{}
	records := []map[string]any{
		tlsaRR("_443._tcp.example.com", 3, 1, 1, "AA"),
		tlsaRR("_25._tcp.mail.example.com", 3, 1, 1, "BB"),
	}
	collected, err := p.Collect(context.Background(), makeOpts(t, "example.com", "", records, nil))
	if err != nil {
		t.Fatal(err)
	}
	entries, err := p.DiscoverEntries(collected)
	if err != nil {
		t.Fatalf("err=%v", err)
	}
	if n := len(entries); n != 2 {
		t.Errorf("entries=%d want 2", n)
	}
	// Nil/wrong type returns nil, nil (defensive).
	got, err := p.DiscoverEntries(nil)
	if err != nil || got != nil {
		t.Errorf("nil: got=%v err=%v", got, err)
	}
	got, err = p.DiscoverEntries("not a *DANEData")
	if err != nil || got != nil {
		t.Errorf("wrong type: got=%v err=%v", got, err)
	}
}
// TestCollect_DeterministicOutput runs Collect several times on the same input
// and checks that the JSON encoding of the resulting Targets never changes.
// The records are given out of host order on purpose, so any dependence on
// map iteration order has a chance to show up.
func TestCollect_DeterministicOutput(t *testing.T) {
	t.Parallel()
	opts := makeOpts(t, "example.com", "", []map[string]any{
		tlsaRR("_25._tcp.b.example.com", 3, 1, 1, "AA"),
		tlsaRR("_25._tcp.a.example.com", 3, 1, 1, "BB"),
		tlsaRR("_443._tcp.a.example.com", 3, 1, 1, "CC"),
	}, nil)
	var prev []byte
	for i := range 3 {
		out, err := (&daneProvider{}).Collect(context.Background(), opts)
		if err != nil {
			t.Fatal(err)
		}
		d, ok := out.(*DANEData)
		if !ok {
			t.Fatalf("Collect returned %T, want *DANEData", out)
		}
		// Compare only Targets: CollectedAt is a wall-clock timestamp.
		// Check the marshal error; otherwise two failed encodings (both nil)
		// would compare equal and hide the problem.
		b, err := json.Marshal(d.Targets)
		if err != nil {
			t.Fatalf("marshal targets (run %d): %v", i, err)
		}
		if i > 0 && string(b) != string(prev) {
			t.Errorf("non-deterministic targets:\n%s\nvs\n%s", prev, b)
		}
		prev = b
	}
}