276 lines
8.4 KiB
Go
276 lines
8.4 KiB
Go
package checker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"testing"
|
|
|
|
sdk "git.happydns.org/checker-sdk-go/checker"
|
|
tls "git.happydns.org/checker-tls/checker"
|
|
tlscontract "git.happydns.org/checker-tls/contract"
|
|
)
|
|
|
|
// mockObs is a lightweight ObservationGetter for rule unit tests.
|
|
type mockObs struct {
|
|
dane *DANEData
|
|
daneErr error
|
|
probes map[string]tls.TLSProbe
|
|
relatedErr error
|
|
}
|
|
|
|
func (m *mockObs) Get(_ context.Context, key sdk.ObservationKey, dest any) error {
|
|
if m.daneErr != nil {
|
|
return m.daneErr
|
|
}
|
|
if key != ObservationKeyDANE || m.dane == nil {
|
|
return errors.New("not found")
|
|
}
|
|
b, err := json.Marshal(m.dane)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return json.Unmarshal(b, dest)
|
|
}
|
|
|
|
func (m *mockObs) GetRelated(_ context.Context, key sdk.ObservationKey) ([]sdk.RelatedObservation, error) {
|
|
if m.relatedErr != nil {
|
|
return nil, m.relatedErr
|
|
}
|
|
if key != tls.ObservationKeyTLSProbes || m.probes == nil {
|
|
return nil, nil
|
|
}
|
|
payload := struct {
|
|
Probes map[string]tls.TLSProbe `json:"probes"`
|
|
}{Probes: m.probes}
|
|
b, _ := json.Marshal(payload)
|
|
return []sdk.RelatedObservation{{
|
|
CheckerID: "tls",
|
|
Key: tls.ObservationKeyTLSProbes,
|
|
Data: b,
|
|
}}, nil
|
|
}
|
|
|
|
func makeTarget(host string, port uint16, recs []TLSARecord) TargetResult {
|
|
t := TargetResult{
|
|
Owner: tlsaOwnerName(port, "tcp", host),
|
|
Host: host,
|
|
Port: port,
|
|
Proto: "tcp",
|
|
Records: recs,
|
|
}
|
|
t.Ref = tlscontract.Ref(tlscontract.TLSEndpoint{Host: host, Port: port, SNI: host})
|
|
return t
|
|
}
|
|
|
|
func TestHasRecordsRule(t *testing.T) {
|
|
t.Parallel()
|
|
r := &hasRecordsRule{}
|
|
|
|
// No records, no invalid → unknown
|
|
obs := &mockObs{dane: &DANEData{}}
|
|
st := r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_no_records" {
|
|
t.Errorf("no records: %+v", st)
|
|
}
|
|
|
|
// Records present → ok
|
|
obs = &mockObs{dane: &DANEData{Targets: []TargetResult{makeTarget("a.example.com", 443, []TLSARecord{{}})}}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_has_records_ok" {
|
|
t.Errorf("ok: %+v", st)
|
|
}
|
|
|
|
// Invalid records, no targets → error states
|
|
obs = &mockObs{dane: &DANEData{Invalid: []InvalidRecord{{Owner: "_x._tcp", Reason: "bad port"}}}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) < 2 {
|
|
t.Fatalf("expected per-record + aggregate, got %+v", st)
|
|
}
|
|
if st[0].Code != "dane_invalid_owner" || st[len(st)-1].Code != "dane_no_usable_records" {
|
|
t.Errorf("invalid only: %+v", st)
|
|
}
|
|
|
|
// Observation read error
|
|
obs = &mockObs{daneErr: errors.New("boom")}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_observation_error" {
|
|
t.Errorf("err: %+v", st)
|
|
}
|
|
}
|
|
|
|
func TestProbeAvailableRule(t *testing.T) {
|
|
t.Parallel()
|
|
r := &probeAvailableRule{}
|
|
tgt := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})
|
|
|
|
// Probe present
|
|
leaf := fakeCert([]byte("l"), []byte("s"))
|
|
obs := &mockObs{
|
|
dane: &DANEData{Targets: []TargetResult{tgt}},
|
|
probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}},
|
|
}
|
|
st := r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_probe_available_ok" {
|
|
t.Errorf("ok: %+v", st)
|
|
}
|
|
|
|
// Probe absent
|
|
obs.probes = map[string]tls.TLSProbe{}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_no_probe" {
|
|
t.Errorf("missing: %+v", st)
|
|
}
|
|
|
|
// No targets at all
|
|
obs = &mockObs{dane: &DANEData{}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_probe_available_skipped" {
|
|
t.Errorf("empty: %+v", st)
|
|
}
|
|
|
|
// Related-fetch error surfaces as warning state.
|
|
obs = &mockObs{dane: &DANEData{Targets: []TargetResult{tgt}}, relatedErr: errors.New("upstream down")}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_observation_warning" {
|
|
t.Errorf("relatedErr: %+v", st)
|
|
}
|
|
}
|
|
|
|
func TestHandshakeOKRule(t *testing.T) {
|
|
t.Parallel()
|
|
r := &handshakeOKRule{}
|
|
tgt := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsageDANEEE}})
|
|
leaf := fakeCert([]byte("l"), []byte("s"))
|
|
|
|
// All good.
|
|
obs := &mockObs{
|
|
dane: &DANEData{Targets: []TargetResult{tgt}},
|
|
probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}},
|
|
}
|
|
st := r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_handshake_ok" {
|
|
t.Errorf("ok: %+v", st)
|
|
}
|
|
|
|
// Handshake failed.
|
|
obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Error: "tls: bad cert"}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_handshake_failed" {
|
|
t.Errorf("failed: %+v", st)
|
|
}
|
|
}
|
|
|
|
func TestRecordsMatchChainRule(t *testing.T) {
|
|
t.Parallel()
|
|
r := &recordsMatchChainRule{}
|
|
leaf := fakeCert([]byte("leaf"), []byte("ls"))
|
|
|
|
tgt := makeTarget("a.example.com", 443, []TLSARecord{
|
|
{Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256, Certificate: leaf.SPKISHA256},
|
|
})
|
|
obs := &mockObs{
|
|
dane: &DANEData{Targets: []TargetResult{tgt}},
|
|
probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}},
|
|
}
|
|
st := r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_match_ok" {
|
|
t.Errorf("match ok: %+v", st)
|
|
}
|
|
|
|
// Same target, wrong cert hash → no match (crit).
|
|
tgt.Records[0].Certificate = "deadbeef"
|
|
obs.dane = &DANEData{Targets: []TargetResult{tgt}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_no_match" {
|
|
t.Errorf("no match: %+v", st)
|
|
}
|
|
|
|
// No probe usable → skipped.
|
|
obs.probes = map[string]tls.TLSProbe{}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_records_match_chain_skipped" {
|
|
t.Errorf("skipped: %+v", st)
|
|
}
|
|
}
|
|
|
|
func TestPKIXChainValidRule(t *testing.T) {
|
|
t.Parallel()
|
|
r := &pkixChainValidRule{}
|
|
leaf := fakeCert([]byte("l"), []byte("s"))
|
|
bTrue, bFalse := true, false
|
|
|
|
// PKIX usage + valid chain → ok.
|
|
tgt := makeTarget("a.example.com", 443, []TLSARecord{{Usage: UsagePKIXEE}})
|
|
obs := &mockObs{
|
|
dane: &DANEData{Targets: []TargetResult{tgt}},
|
|
probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}, ChainValid: &bTrue}},
|
|
}
|
|
st := r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_pkix_chain_valid_ok" {
|
|
t.Errorf("ok: %+v", st)
|
|
}
|
|
|
|
// PKIX usage + invalid chain → crit.
|
|
obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}, ChainValid: &bFalse}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_pkix_chain_invalid" {
|
|
t.Errorf("invalid: %+v", st)
|
|
}
|
|
|
|
// DANE-only usages → skipped (rule does not apply).
|
|
tgt.Records = []TLSARecord{{Usage: UsageDANEEE}}
|
|
obs.dane = &DANEData{Targets: []TargetResult{tgt}}
|
|
obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_pkix_chain_valid_skipped" {
|
|
t.Errorf("skipped: %+v", st)
|
|
}
|
|
}
|
|
|
|
func TestUsageCoherentRule(t *testing.T) {
|
|
t.Parallel()
|
|
r := &usageCoherentRule{}
|
|
leaf := fakeCert([]byte("l"), []byte("ls"))
|
|
mid := fakeCert([]byte("m"), []byte("ms"))
|
|
|
|
// EE record whose hash matches the intermediate → warn.
|
|
tgt := makeTarget("a.example.com", 443, []TLSARecord{{
|
|
Usage: UsageDANEEE, Selector: SelectorSPKI, MatchingType: MatchingSHA256,
|
|
Certificate: mid.SPKISHA256,
|
|
}})
|
|
obs := &mockObs{
|
|
dane: &DANEData{Targets: []TargetResult{tgt}},
|
|
probes: map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf, mid}}},
|
|
}
|
|
st := r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_usage_incoherent" {
|
|
t.Errorf("incoherent: %+v", st)
|
|
}
|
|
|
|
// EE matching leaf → ok.
|
|
tgt.Records[0].Certificate = leaf.SPKISHA256
|
|
obs.dane = &DANEData{Targets: []TargetResult{tgt}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_usage_coherent_ok" {
|
|
t.Errorf("coherent ok: %+v", st)
|
|
}
|
|
|
|
// Single-cert chain → skipped.
|
|
obs.probes = map[string]tls.TLSProbe{tgt.Ref: {Chain: []tls.CertInfo{leaf}}}
|
|
st = r.Evaluate(context.Background(), obs, nil)
|
|
if len(st) != 1 || st[0].Code != "dane_usage_coherent_skipped" {
|
|
t.Errorf("skipped: %+v", st)
|
|
}
|
|
}
|
|
|
|
func TestRules_ObservationError(t *testing.T) {
|
|
t.Parallel()
|
|
obs := &mockObs{daneErr: errors.New("read failed")}
|
|
for _, rule := range Rules() {
|
|
st := rule.Evaluate(context.Background(), obs, nil)
|
|
if len(st) == 0 || st[0].Code != "dane_observation_error" {
|
|
t.Errorf("%s: expected observation_error, got %+v", rule.Name(), st)
|
|
}
|
|
}
|
|
}
|