package checker

import (
	"context"
	"encoding/json"
	"errors"
	"testing"

	sdk "git.happydns.org/checker-sdk-go/checker"
	tls "git.happydns.org/checker-tls/checker"
	tlscontract "git.happydns.org/checker-tls/contract"
)

// mockObs is a lightweight ObservationGetter for rule unit tests.
//
// dane/daneErr drive Get; probes/relatedErr drive GetRelated. A non-nil
// error field always takes precedence over the corresponding payload.
type mockObs struct {
	dane       *DANEData
	daneErr    error
	probes     map[string]tls.TLSProbe
	relatedErr error
}

// Get returns daneErr when it is set. Otherwise it fails with "not found"
// unless key is ObservationKeyDANE and a DANE payload is configured, in which
// case the payload is round-tripped through JSON into dest.
func (m *mockObs) Get(_ context.Context, key sdk.ObservationKey, dest any) error {
	if m.daneErr != nil {
		return m.daneErr
	}
	if key != ObservationKeyDANE || m.dane == nil {
		return errors.New("not found")
	}
	b, err := json.Marshal(m.dane)
	if err != nil {
		return err
	}
	return json.Unmarshal(b, dest)
}

// GetRelated returns relatedErr when it is set. For any key other than
// tls.ObservationKeyTLSProbes, or when no probes map is configured, it returns
// (nil, nil). Otherwise it returns a single related observation attributed to
// the "tls" checker whose Data is the JSON object {"probes": ...}.
func (m *mockObs) GetRelated(_ context.Context, key sdk.ObservationKey) ([]sdk.RelatedObservation, error) {
	if m.relatedErr != nil {
		return nil, m.relatedErr
	}
	if key != tls.ObservationKeyTLSProbes || m.probes == nil {
		return nil, nil
	}
	payload := struct {
		Probes map[string]tls.TLSProbe `json:"probes"`
	}{Probes: m.probes}
	// Surface encoding failures instead of handing the rule a nil payload,
	// consistent with how Get treats its own Marshal call.
	b, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return []sdk.RelatedObservation{{
		CheckerID: "tls",
		Key:       tls.ObservationKeyTLSProbes,
		Data:      b,
	}}, nil
}

// makeTarget builds a TCP TargetResult for host:port carrying recs. The Ref is
// computed by tlscontract.Ref from a TLSEndpoint whose SNI equals host, so
// tests can key their probe maps on t.Ref.
func makeTarget(host string, port uint16, recs []TLSARecord) TargetResult {
	t := TargetResult{
		Owner:   tlsaOwnerName(port, "tcp", host),
		Host:    host,
		Port:    port,
		Proto:   "tcp",
		Records: recs,
	}
	t.Ref = tlscontract.Ref(tlscontract.TLSEndpoint{Host: host, Port: port, SNI: host})
	return t
}

// TestHasRecordsRule checks the state codes emitted by hasRecordsRule for an
// empty observation, a target with records, invalid-only records, and an
// observation read error.
func TestHasRecordsRule(t *testing.T) {
	t.Parallel()
	r := &hasRecordsRule{}

	// No records, no invalid → unknown
	obs := &mockObs{dane: &DANEData{}}
	st := r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_no_records" {
		t.Errorf("no records: %+v", st)
	}

	// Records present → ok
	obs = &mockObs{dane: &DANEData{Targets: []TargetResult{makeTarget("a.example.com", 443, []TLSARecord{{}})}}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_has_records_ok" {
		t.Errorf("ok: %+v", st)
	}

	// Invalid records, no targets → error states
	obs = &mockObs{dane: &DANEData{Invalid: []InvalidRecord{{Owner: "_x._tcp", Reason: "bad port"}}}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) < 2 {
		t.Fatalf("expected per-record + aggregate, got %+v", st)
	}
	if st[0].Code != "dane_invalid_owner" || st[len(st)-1].Code != "dane_no_usable_records" {
		t.Errorf("invalid only: %+v", st)
	}

	// Observation read error
	obs = &mockObs{daneErr: errors.New("boom")}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_observation_error" {
		t.Errorf("err: %+v", st)
	}
}

// TestProbeAvailableRule checks probeAvailableRule when a probe exists for the
// target, when it is missing, when there are no targets, and when fetching
// related observations fails.
func TestProbeAvailableRule(t *testing.T) {
	t.Parallel()
	r := &probeAvailableRule{}
	tgt := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})

	// Probe present
	leaf := fakeCert([]byte("l"), []byte("s"))
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{tgt}},
		probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}},
	}
	st := r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_probe_available_ok" {
		t.Errorf("ok: %+v", st)
	}

	// Probe absent
	obs.probes = map[string]tls.TLSProbe{}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_no_probe" {
		t.Errorf("missing: %+v", st)
	}

	// No targets at all
	obs = &mockObs{dane: &DANEData{}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_probe_available_skipped" {
		t.Errorf("empty: %+v", st)
	}

	// Related-fetch error surfaces as warning state.
	obs = &mockObs{dane: &DANEData{Targets: []TargetResult{tgt}}, relatedErr: errors.New("upstream down")}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_observation_warning" {
		t.Errorf("relatedErr: %+v", st)
	}
}

// TestHandshakeOKRule checks handshakeOKRule for a probe that carries a chain
// and for a probe that reports a handshake error.
func TestHandshakeOKRule(t *testing.T) {
	t.Parallel()
	r := &handshakeOKRule{}
	tgt := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})
	leaf := fakeCert([]byte("l"), []byte("s"))

	// All good.
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{tgt}},
		probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}},
	}
	st := r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_handshake_ok" {
		t.Errorf("ok: %+v", st)
	}

	// Handshake failed.
	obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Error: "tls: bad cert"}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_handshake_failed" {
		t.Errorf("failed: %+v", st)
	}
}

// TestRecordsMatchChainRule checks recordsMatchChainRule for a TLSA record
// whose hash equals the leaf's SPKI SHA-256, for a record with a different
// hash, and for the case where no probe is available.
func TestRecordsMatchChainRule(t *testing.T) {
	t.Parallel()
	r := &recordsMatchChainRule{}
	leaf := fakeCert([]byte("leaf"), []byte("ls"))

	tgt := makeTarget("a.example.com", 443, []TLSARecord{
		{Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256, Certificate: leaf.SPKISHA256},
	})
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{tgt}},
		probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}},
	}
	st := r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_match_ok" {
		t.Errorf("match ok: %+v", st)
	}

	// Same target, wrong cert hash → no match (crit).
	tgt.Records[0].Certificate = "deadbeef"
	obs.dane = &DANEData{Targets: []TargetResult{tgt}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_no_match" {
		t.Errorf("no match: %+v", st)
	}

	// No probe usable → skipped.
	obs.probes = map[string]tls.TLSProbe{}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_records_match_chain_skipped" {
		t.Errorf("skipped: %+v", st)
	}
}

// TestPKIXChainValidRule checks pkixChainValidRule for a PKIX-usage record
// with ChainValid true, with ChainValid false, and for a target that only
// carries a DANE-EE record.
func TestPKIXChainValidRule(t *testing.T) {
	t.Parallel()
	r := &pkixChainValidRule{}
	leaf := fakeCert([]byte("l"), []byte("s"))
	// ChainValid is a *bool, so addressable locals are needed.
	bTrue, bFalse := true, false

	// PKIX usage + valid chain → ok.
	tgt := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsagePKIXEE}})
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{tgt}},
		probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}, ChainValid: &bTrue}},
	}
	st := r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_pkix_chain_valid_ok" {
		t.Errorf("ok: %+v", st)
	}

	// PKIX usage + invalid chain → crit.
	obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}, ChainValid: &bFalse}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_pkix_chain_invalid" {
		t.Errorf("invalid: %+v", st)
	}

	// DANE-only usages → skipped (rule does not apply).
	tgt.Records = []TLSARecord{{Usage: UsageDANEEE}}
	obs.dane = &DANEData{Targets: []TargetResult{tgt}}
	obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_pkix_chain_valid_skipped" {
		t.Errorf("skipped: %+v", st)
	}
}

// TestUsageCoherentRule checks usageCoherentRule for a DANE-EE record whose
// hash equals the second chain certificate's SPKI SHA-256, for one that equals
// the leaf's, and for a chain containing only the leaf.
func TestUsageCoherentRule(t *testing.T) {
	t.Parallel()
	r := &usageCoherentRule{}
	leaf := fakeCert([]byte("l"), []byte("ls"))
	mid := fakeCert([]byte("m"), []byte("ms"))

	// EE record whose hash matches the intermediate → warn.
	tgt := makeTarget("a.example.com", 443, []TLSARecord{{
		Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256,
		Certificate: mid.SPKISHA256,
	}})
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{tgt}},
		probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf, mid}}},
	}
	st := r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_usage_incoherent" {
		t.Errorf("incoherent: %+v", st)
	}

	// EE matching leaf → ok.
	tgt.Records[0].Certificate = leaf.SPKISHA256
	obs.dane = &DANEData{Targets: []TargetResult{tgt}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_usage_coherent_ok" {
		t.Errorf("coherent ok: %+v", st)
	}

	// Single-cert chain → skipped.
	obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}}
	st = r.Evaluate(context.Background(), obs, nil)
	if len(st) != 1 || st[0].Code != "dane_usage_coherent_skipped" {
		t.Errorf("skipped: %+v", st)
	}
}

// TestRules_ObservationError verifies that every rule returned by Rules()
// reports "dane_observation_error" as its first state when reading the DANE
// observation fails.
func TestRules_ObservationError(t *testing.T) {
	t.Parallel()
	obs := &mockObs{daneErr: errors.New("read failed")}
	for _, rule := range Rules() {
		st := rule.Evaluate(context.Background(), obs, nil)
		if len(st) == 0 || st[0].Code != "dane_observation_error" {
			t.Errorf("%s: expected observation_error, got %+v", rule.Name(), st)
		}
	}
}