checker-dane/checker/rules_test.go

276 lines
8.4 KiB
Go

package checker
import (
"context"
"encoding/json"
"errors"
"testing"
sdk "git.happydns.org/checker-sdk-go/checker"
tls "git.happydns.org/checker-tls/checker"
tlscontract "git.happydns.org/checker-tls/contract"
)
// mockObs is a lightweight ObservationGetter for rule unit tests.
// Each field steers one path through Get / GetRelated.
type mockObs struct {
	// dane is served by Get for ObservationKeyDANE (copied via a JSON round-trip).
	dane *DANEData
	// daneErr, when non-nil, is returned by Get regardless of the key.
	daneErr error
	// probes is served by GetRelated for tls.ObservationKeyTLSProbes; the
	// tests key it by TargetResult.Ref.
	probes map[string]tls.TLSProbe
	// relatedErr, when non-nil, is returned by GetRelated regardless of the key.
	relatedErr error
}
// Get returns m.daneErr when it is set. It reports "not found" for any key
// other than ObservationKeyDANE, and also when no DANE data is configured.
// Otherwise it copies m.dane into dest by marshalling it to JSON and
// unmarshalling the result into dest.
func (m *mockObs) Get(_ context.Context, key sdk.ObservationKey, dest any) error {
	switch {
	case m.daneErr != nil:
		return m.daneErr
	case key != ObservationKeyDANE, m.dane == nil:
		return errors.New("not found")
	}
	raw, err := json.Marshal(m.dane)
	if err != nil {
		return err
	}
	return json.Unmarshal(raw, dest)
}
// GetRelated returns m.relatedErr when it is set. It returns no observations
// for any key other than tls.ObservationKeyTLSProbes, and also when m.probes
// is nil. Otherwise it returns a single RelatedObservation from the "tls"
// checker. That observation's Data is the JSON encoding of {"probes": m.probes}.
func (m *mockObs) GetRelated(_ context.Context, key sdk.ObservationKey) ([]sdk.RelatedObservation, error) {
	if m.relatedErr != nil {
		return nil, m.relatedErr
	}
	if key != tls.ObservationKeyTLSProbes || m.probes == nil {
		return nil, nil
	}
	payload := struct {
		Probes map[string]tls.TLSProbe `json:"probes"`
	}{Probes: m.probes}
	// Propagate marshal failures, consistent with Get. Otherwise the rule
	// under test would silently receive a nil payload.
	b, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return []sdk.RelatedObservation{{
		CheckerID: "tls",
		Key:       tls.ObservationKeyTLSProbes,
		Data:      b,
	}}, nil
}
// makeTarget builds a TCP TargetResult for host:port that carries recs.
// Owner is derived with tlsaOwnerName. Ref identifies the TLS endpoint
// host:port with SNI set to host.
func makeTarget(host string, port uint16, recs []TLSARecord) TargetResult {
	endpoint := tlscontract.TLSEndpoint{Host: host, Port: port, SNI: host}
	return TargetResult{
		Owner:   tlsaOwnerName(port, "tcp", host),
		Host:    host,
		Port:    port,
		Proto:   "tcp",
		Records: recs,
		Ref:     tlscontract.Ref(endpoint),
	}
}
// TestHasRecordsRule exercises hasRecordsRule on four inputs: empty DANE data,
// one target with records, invalid records only, and a failing observation read.
func TestHasRecordsRule(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	rule := &hasRecordsRule{}

	// Neither targets nor invalid records (the original notes this as "unknown").
	if got := rule.Evaluate(ctx, &mockObs{dane: &DANEData{}}, nil); len(got) != 1 || got[0].Code != "dane_no_records" {
		t.Errorf("no records: %+v", got)
	}

	// One target with a record: expect the ok state.
	withRecords := &mockObs{dane: &DANEData{
		Targets: []TargetResult{makeTarget("a.example.com", 443, []TLSARecord{{}})},
	}}
	if got := rule.Evaluate(ctx, withRecords, nil); len(got) != 1 || got[0].Code != "dane_has_records_ok" {
		t.Errorf("ok: %+v", got)
	}

	// Only invalid records: expect one state per record, then an aggregate state last.
	invalidOnly := &mockObs{dane: &DANEData{
		Invalid: []InvalidRecord{{Owner: "_x._tcp", Reason: "bad port"}},
	}}
	got := rule.Evaluate(ctx, invalidOnly, nil)
	if len(got) < 2 {
		t.Fatalf("expected per-record + aggregate, got %+v", got)
	}
	if first, last := got[0].Code, got[len(got)-1].Code; first != "dane_invalid_owner" || last != "dane_no_usable_records" {
		t.Errorf("invalid only: %+v", got)
	}

	// Reading the observation fails.
	if got := rule.Evaluate(ctx, &mockObs{daneErr: errors.New("boom")}, nil); len(got) != 1 || got[0].Code != "dane_observation_error" {
		t.Errorf("err: %+v", got)
	}
}
// TestProbeAvailableRule checks probeAvailableRule in four cases: a probe is
// present, a probe is absent, there are no targets, and fetching related
// observations fails.
func TestProbeAvailableRule(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	rule := &probeAvailableRule{}
	target := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})
	leaf := fakeCert([]byte("l"), []byte("s"))

	// A probe is recorded under the target's Ref.
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{target}},
		probes: map[string]tls.TLSProbe{target.Ref: {Chain: []tls.CertInfo{leaf}}},
	}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_probe_available_ok" {
		t.Errorf("ok: %+v", got)
	}

	// The probe map is non-nil but holds no entry for the target.
	obs.probes = map[string]tls.TLSProbe{}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_no_probe" {
		t.Errorf("missing: %+v", got)
	}

	// There are no targets to check.
	if got := rule.Evaluate(ctx, &mockObs{dane: &DANEData{}}, nil); len(got) != 1 || got[0].Code != "dane_probe_available_skipped" {
		t.Errorf("empty: %+v", got)
	}

	// A related-fetch error is reported as a warning state.
	failing := &mockObs{
		dane:       &DANEData{Targets: []TargetResult{target}},
		relatedErr: errors.New("upstream down"),
	}
	if got := rule.Evaluate(ctx, failing, nil); len(got) != 1 || got[0].Code != "dane_observation_warning" {
		t.Errorf("relatedErr: %+v", got)
	}
}
// TestHandshakeOKRule checks handshakeOKRule in two cases: a probe that has a
// certificate chain, and a probe that records a handshake error.
func TestHandshakeOKRule(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	rule := &handshakeOKRule{}
	target := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})
	leaf := fakeCert([]byte("l"), []byte("s"))

	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{target}},
		probes: map[string]tls.TLSProbe{target.Ref: {Chain: []tls.CertInfo{leaf}}},
	}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_handshake_ok" {
		t.Errorf("ok: %+v", got)
	}

	// The probe carries an error string instead of a chain.
	obs.probes = map[string]tls.TLSProbe{target.Ref: {Error: "tls: bad cert"}}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_handshake_failed" {
		t.Errorf("failed: %+v", got)
	}
}
// TestRecordsMatchChainRule checks recordsMatchChainRule in three cases: a
// DANE-EE SPKI/SHA-256 record that matches the leaf, one with a wrong hash,
// and one with no usable probe.
func TestRecordsMatchChainRule(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	rule := &recordsMatchChainRule{}
	leaf := fakeCert([]byte("leaf"), []byte("ls"))

	matching := makeTarget("a.example.com", 443, []TLSARecord{
		{Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256, Certificate: leaf.SPKISHA256},
	})
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{matching}},
		probes: map[string]tls.TLSProbe{matching.Ref: {Chain: []tls.CertInfo{leaf}}},
	}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_match_ok" {
		t.Errorf("match ok: %+v", got)
	}

	// Same endpoint (so the same Ref), but the record hash does not match anything.
	mismatched := makeTarget("a.example.com", 443, []TLSARecord{
		{Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256, Certificate: "deadbeef"},
	})
	obs.dane = &DANEData{Targets: []TargetResult{mismatched}}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_no_match" {
		t.Errorf("no match: %+v", got)
	}

	// No probe is available for the target.
	obs.probes = map[string]tls.TLSProbe{}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_records_match_chain_skipped" {
		t.Errorf("skipped: %+v", got)
	}
}
// TestPKIXChainValidRule checks pkixChainValidRule in three cases: a PKIX
// usage with a valid chain, a PKIX usage with an invalid chain, and DANE-only
// usage, to which the rule does not apply.
func TestPKIXChainValidRule(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	rule := &pkixChainValidRule{}
	leaf := fakeCert([]byte("l"), []byte("s"))
	chainOK, chainBad := true, false

	pkix := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsagePKIXEE}})
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{pkix}},
		probes: map[string]tls.TLSProbe{pkix.Ref: {Chain: []tls.CertInfo{leaf}, ChainValid: &chainOK}},
	}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_pkix_chain_valid_ok" {
		t.Errorf("ok: %+v", got)
	}

	// The probe reports ChainValid == false.
	obs.probes = map[string]tls.TLSProbe{pkix.Ref: {Chain: []tls.CertInfo{leaf}, ChainValid: &chainBad}}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_pkix_chain_invalid" {
		t.Errorf("invalid: %+v", got)
	}

	// Same endpoint, but with a DANE-EE record only. The probe has no ChainValid.
	daneOnly := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})
	obs.dane = &DANEData{Targets: []TargetResult{daneOnly}}
	obs.probes = map[string]tls.TLSProbe{daneOnly.Ref: {Chain: []tls.CertInfo{leaf}}}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_pkix_chain_valid_skipped" {
		t.Errorf("skipped: %+v", got)
	}
}
// TestUsageCoherentRule checks usageCoherentRule in three cases: an EE record
// whose hash matches the intermediate, an EE record that matches the leaf, and
// a single-certificate chain.
func TestUsageCoherentRule(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	rule := &usageCoherentRule{}
	leaf := fakeCert([]byte("l"), []byte("ls"))
	mid := fakeCert([]byte("m"), []byte("ms"))

	// An EE record that pins the intermediate is reported as incoherent.
	pinsMid := makeTarget("a.example.com", 443, []TLSARecord{{
		Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256,
		Certificate: mid.SPKISHA256,
	}})
	obs := &mockObs{
		dane:   &DANEData{Targets: []TargetResult{pinsMid}},
		probes: map[string]tls.TLSProbe{pinsMid.Ref: {Chain: []tls.CertInfo{leaf, mid}}},
	}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_usage_incoherent" {
		t.Errorf("incoherent: %+v", got)
	}

	// Same endpoint, with an EE record that pins the leaf.
	pinsLeaf := makeTarget("a.example.com", 443, []TLSARecord{{
		Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256,
		Certificate: leaf.SPKISHA256,
	}})
	obs.dane = &DANEData{Targets: []TargetResult{pinsLeaf}}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_usage_coherent_ok" {
		t.Errorf("coherent ok: %+v", got)
	}

	// The chain holds only the leaf.
	obs.probes = map[string]tls.TLSProbe{pinsLeaf.Ref: {Chain: []tls.CertInfo{leaf}}}
	if got := rule.Evaluate(ctx, obs, nil); len(got) != 1 || got[0].Code != "dane_usage_coherent_skipped" {
		t.Errorf("skipped: %+v", got)
	}
}
func TestRules_ObservationError(t *testing.T) {
t.Parallel()
obs := &mockObs{daneErr: errors.New("read failed")}
for _, rule := range Rules() {
st := rule.Evaluate(context.Background(), obs, nil)
if len(st) == 0 || st[0].Code != "dane_observation_error" {
t.Errorf("%s: expected observation_error, got %+v", rule.Name(), st)
}
}
}