checker-smtp/checker/rule_test.go

294 lines
9.5 KiB
Go

package checker
import (
"context"
"encoding/json"
"errors"
"reflect"
"testing"
"time"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// mustJSONForRule encodes v as JSON and returns the raw bytes. If encoding
// fails, the calling test is aborted with t.Fatalf.
func mustJSONForRule(t *testing.T, v any) json.RawMessage {
	t.Helper()
	raw, err := json.Marshal(v)
	if err == nil {
		return raw
	}
	t.Fatalf("marshal: %v", err)
	return nil // not reached: Fatalf stops the test goroutine
}
// stubObs is a minimal sdk.ObservationGetter for the rule tests. A single
// instance serves both Get and any number of GetRelated lookups: Get ignores
// the requested key and always answers from data/getErr, while GetRelated
// looks the key up in the related map.
type stubObs struct {
// data is the observation Get copies into its destination; when nil
// (and getErr is nil) Get fails with a "no data" error.
data *SMTPData
// getErr, when non-nil, is returned by Get as-is, before data is
// consulted.
getErr error
// related holds the observations returned by GetRelated, indexed by the
// key being queried; a missing key yields a nil slice.
related map[sdk.ObservationKey][]sdk.RelatedObservation
}
// Get fills dest with a copy of the stubbed SMTPData by round-tripping it
// through JSON. The context and observation key are ignored. It returns the
// configured getErr if one is set, a "no data" error when no data is
// stubbed, or any JSON encoding/decoding error.
func (s *stubObs) Get(_ context.Context, _ sdk.ObservationKey, dest any) error {
	switch {
	case s.getErr != nil:
		return s.getErr
	case s.data == nil:
		return errors.New("no data")
	}

	encoded, err := json.Marshal(s.data)
	if err != nil {
		return err
	}
	return json.Unmarshal(encoded, dest)
}
// GetRelated returns the related observations stubbed for key. It never
// fails; an unknown key (or a nil map) yields a nil slice.
func (s *stubObs) GetRelated(_ context.Context, key sdk.ObservationKey) ([]sdk.RelatedObservation, error) {
	observations := s.related[key]
	return observations, nil
}
// TestSeverityToStatus checks the severity-to-status mapping, including that
// empty and unrecognised severities map to sdk.StatusOK.
func TestSeverityToStatus(t *testing.T) {
	type testCase struct {
		sev  string
		want sdk.Status
	}
	table := []testCase{
		{sev: SeverityCrit, want: sdk.StatusCrit},
		{sev: SeverityWarn, want: sdk.StatusWarn},
		{sev: SeverityInfo, want: sdk.StatusInfo},
		{sev: "", want: sdk.StatusOK},
		{sev: "bogus", want: sdk.StatusOK},
	}

	for _, tc := range table {
		got := severityToStatus(tc.sev)
		if got == tc.want {
			continue
		}
		t.Errorf("%q → %v, want %v", tc.sev, got, tc.want)
	}
}
// TestPassAndNotTestedStates checks the two state constructors: passState
// must produce an OK state carrying the given code and message, and
// notTestedState must produce an Unknown state carrying the given code.
func TestPassAndNotTestedStates(t *testing.T) {
	passed := passState("c.ok", "fine")
	passedOK := passed.Status == sdk.StatusOK &&
		passed.Code == "c.ok" &&
		passed.Message == "fine"
	if !passedOK {
		t.Errorf("passState: %+v", passed)
	}

	skipped := notTestedState("c.skip", "n/a")
	skippedOK := skipped.Status == sdk.StatusUnknown && skipped.Code == "c.skip"
	if !skippedOK {
		t.Errorf("notTestedState: %+v", skipped)
	}
}
// TestIssueToState checks the conversion of a fully populated Issue: the
// severity becomes the status, the endpoint is used as the subject even
// though a target is also set, and fix/endpoint/target are copied into Meta.
func TestIssueToState(t *testing.T) {
	issue := Issue{
		Code:     "x",
		Severity: SeverityWarn,
		Message:  "m",
		Fix:      "do",
		Endpoint: "1.2.3.4:25",
		Target:   "mx",
	}

	state := issueToState(issue)

	if got := state.Status; got != sdk.StatusWarn {
		t.Errorf("status: %v", got)
	}
	if got := state.Subject; got != "1.2.3.4:25" {
		t.Errorf("subject (endpoint preferred): %q", got)
	}

	meta := state.Meta
	metaOK := meta["fix"] == "do" &&
		meta["endpoint"] == "1.2.3.4:25" &&
		meta["target"] == "mx"
	if !metaOK {
		t.Errorf("meta: %+v", meta)
	}
}
func TestIssueToState_TargetFallbackSubject(t *testing.T) {
is := Issue{Code: "x", Severity: SeverityCrit, Target: "mx"}
st := issueToState(is)
if st.Subject != "mx" {
t.Errorf("expected target subject, got %q", st.Subject)
}
}
func TestIssueToState_NoMeta(t *testing.T) {
is := Issue{Code: "x", Severity: SeverityInfo}
st := issueToState(is)
if st.Meta != nil {
t.Errorf("meta should be nil when no fields set, got %+v", st.Meta)
}
}
// TestStatesFromIssues checks that statesFromIssues yields one state per
// issue and keeps the input order.
func TestStatesFromIssues(t *testing.T) {
	input := []Issue{
		{Code: "a", Severity: SeverityCrit},
		{Code: "b", Severity: SeverityWarn},
	}

	out := statesFromIssues(input)
	if n := len(out); n != 2 {
		t.Fatalf("got %d", n)
	}

	inOrder := out[0].Code == "a" && out[1].Code == "b"
	if !inOrder {
		t.Errorf("order not preserved: %+v", out)
	}
}
// TestIssuesByCodes_FiltersAndKeepsOrder feeds issuesByCodes a domain whose
// only endpoint failed with an error and asks for two expected codes plus an
// unknown one. It expects exactly two issues back, with CodeTCPUnreachable
// before CodeAllEndpointsDown.
func TestIssuesByCodes_FiltersAndKeepsOrder(t *testing.T) {
	data := &SMTPData{
		Domain: "x",
		MX: MXLookup{
			Records: []MXRecord{
				{Target: "mx.x", IPv4: []string{"1.2.3.4"}},
			},
		},
		Endpoints: []EndpointProbe{
			{Target: "mx.x", IP: "1.2.3.4", Address: "1.2.3.4:25", Error: "boom"},
		},
	}

	issues := issuesByCodes(data, CodeTCPUnreachable, CodeAllEndpointsDown, "smtp.does-not-exist")
	if len(issues) != 2 {
		t.Fatalf("want 2 issues, got %d", len(issues))
	}

	gotCodes := make([]string, 0, len(issues))
	for _, issue := range issues {
		gotCodes = append(gotCodes, issue.Code)
	}
	wantCodes := []string{CodeTCPUnreachable, CodeAllEndpointsDown}
	if !reflect.DeepEqual(gotCodes, wantCodes) {
		t.Errorf("order: got %v, want %v", gotCodes, wantCodes)
	}
}
// TestIssuesByCodes_EmptyCodes checks that calling issuesByCodes without any
// codes returns nil.
func TestIssuesByCodes_EmptyCodes(t *testing.T) {
	got := issuesByCodes(&SMTPData{})
	if got == nil {
		return
	}
	t.Errorf("expected nil, got %+v", got)
}
// TestRules_ContainsAllExpectedNames checks that Rules() includes every
// expected rule name and that each returned rule has a non-empty description.
func TestRules_ContainsAllExpectedNames(t *testing.T) {
	registered := make(map[string]struct{})
	for _, rule := range Rules() {
		name := rule.Name()
		registered[name] = struct{}{}
		if rule.Description() == "" {
			t.Errorf("%s: empty description", name)
		}
	}

	expected := []string{
		"smtp.null_mx",
		"smtp.mx_present",
		"smtp.mx_sanity",
		"smtp.endpoint_reachable",
		"smtp.banner_sanity",
		"smtp.ehlo_supported",
		"smtp.starttls_offered",
		"smtp.starttls_handshake",
		"smtp.auth_posture",
		"smtp.reverse_dns",
		"smtp.null_sender",
		"smtp.postmaster",
		"smtp.open_relay",
		"smtp.extension_posture",
		"smtp.ipv6_reachable",
		"smtp.tls_quality",
	}
	for _, name := range expected {
		if _, ok := registered[name]; !ok {
			t.Errorf("missing rule %q", name)
		}
	}
}
// TestNullMXRule_Detected checks that a null-MX observation makes nullMXRule
// emit exactly one info-level state with code CodeNullMX.
func TestNullMXRule_Detected(t *testing.T) {
	obs := &stubObs{
		data: &SMTPData{MX: MXLookup{NullMX: true}},
	}
	rule := &nullMXRule{}

	states := rule.Evaluate(context.Background(), obs, nil)

	detected := len(states) == 1 &&
		states[0].Status == sdk.StatusInfo &&
		states[0].Code == CodeNullMX
	if !detected {
		t.Errorf("got %+v", states)
	}
}
// TestNullMXRule_NotNull checks that a domain with a regular MX record makes
// nullMXRule emit a single OK state.
func TestNullMXRule_NotNull(t *testing.T) {
	obs := &stubObs{
		data: &SMTPData{
			Domain: "x",
			MX: MXLookup{
				Records: []MXRecord{
					{Target: "mx.x", IPv4: []string{"1.2.3.4"}},
				},
			},
		},
	}
	rule := &nullMXRule{}

	states := rule.Evaluate(context.Background(), obs, nil)

	passed := len(states) == 1 && states[0].Status == sdk.StatusOK
	if !passed {
		t.Errorf("expected pass, got %+v", states)
	}
}
// TestNullMXRule_LoadError checks that a failing observation lookup makes
// nullMXRule emit a single error state.
func TestNullMXRule_LoadError(t *testing.T) {
	obs := &stubObs{getErr: errors.New("boom")}
	rule := &nullMXRule{}

	states := rule.Evaluate(context.Background(), obs, nil)

	failed := len(states) == 1 && states[0].Status == sdk.StatusError
	if !failed {
		t.Errorf("expected error, got %+v", states)
	}
}
func TestSimpleConcernRule_PassWhenNoMatchingIssues(t *testing.T) {
yes := true
obs := &stubObs{data: &SMTPData{
Domain: "x",
MX: MXLookup{Records: []MXRecord{{Target: "mx.x", IPv4: []string{"1.2.3.4"}}}},
Endpoints: []EndpointProbe{{Target: "mx.x", IP: "1.2.3.4", Address: "1.2.3.4:25", TCPConnected: true, BannerReceived: true, BannerCode: 220, EHLOReceived: true, STARTTLSOffered: true, STARTTLSUpgraded: true, NullSenderAccepted: &yes, PostmasterAccepted: &yes, PTR: "mx.x", FCrDNSPass: true, HasPipelining: true, Has8BITMIME: true}},
}}
r := &simpleConcernRule{name: "smtp.endpoint_reachable", codes: []string{CodeTCPUnreachable}, passCode: "smtp.endpoint_reachable.ok", passMessage: "ok"}
st := r.Evaluate(context.Background(), obs, nil)
if len(st) != 1 || st[0].Status != sdk.StatusOK {
t.Errorf("expected single pass state, got %+v", st)
}
}
// TestSimpleConcernRule_EmitsMatchingIssues evaluates a rule watching
// CodeTCPUnreachable and CodeAllEndpointsDown against a domain whose only
// endpoint failed with an error. It expects two states, the first critical.
func TestSimpleConcernRule_EmitsMatchingIssues(t *testing.T) {
	obs := &stubObs{
		data: &SMTPData{
			Domain: "x",
			MX: MXLookup{
				Records: []MXRecord{
					{Target: "mx.x", IPv4: []string{"1.2.3.4"}},
				},
			},
			Endpoints: []EndpointProbe{
				{Target: "mx.x", IP: "1.2.3.4", Address: "1.2.3.4:25", Error: "boom"},
			},
		},
	}
	rule := &simpleConcernRule{
		name:        "smtp.endpoint_reachable",
		codes:       []string{CodeTCPUnreachable, CodeAllEndpointsDown},
		passCode:    "smtp.endpoint_reachable.ok",
		passMessage: "ok",
	}

	states := rule.Evaluate(context.Background(), obs, nil)

	if n := len(states); n != 2 {
		t.Fatalf("want 2 states, got %d (%+v)", n, states)
	}
	if got := states[0].Status; got != sdk.StatusCrit {
		t.Errorf("expected crit status, got %v", got)
	}
}
// TestSimpleConcernRule_NullMXSkipped checks that a simpleConcernRule
// evaluated against a null-MX observation emits a single Unknown state.
func TestSimpleConcernRule_NullMXSkipped(t *testing.T) {
	obs := &stubObs{
		data: &SMTPData{MX: MXLookup{NullMX: true}},
	}
	rule := &simpleConcernRule{
		name:     "smtp.starttls_offered",
		codes:    []string{CodeSTARTTLSMissing},
		passCode: "smtp.starttls_offered.ok",
	}

	states := rule.Evaluate(context.Background(), obs, nil)

	skipped := len(states) == 1 && states[0].Status == sdk.StatusUnknown
	if !skipped {
		t.Errorf("null MX should yield not-tested, got %+v", states)
	}
}
// TestSimpleConcernRule_LoadError checks that a failing observation lookup
// makes a simpleConcernRule emit a single error state.
func TestSimpleConcernRule_LoadError(t *testing.T) {
	obs := &stubObs{getErr: errors.New("nope")}
	rule := &simpleConcernRule{
		name:     "smtp.x",
		codes:    []string{CodeTCPUnreachable},
		passCode: "ok",
	}

	states := rule.Evaluate(context.Background(), obs, nil)

	failed := len(states) == 1 && states[0].Status == sdk.StatusError
	if !failed {
		t.Errorf("got %+v", states)
	}
}
// TestTLSQualityRule_NoRelated checks that tlsQualityRule emits a single
// Unknown state when the stub has no related observations.
func TestTLSQualityRule_NoRelated(t *testing.T) {
	obs := &stubObs{
		data: &SMTPData{
			Domain: "x",
			MX: MXLookup{
				Records: []MXRecord{
					{Target: "mx.x", IPv4: []string{"1.2.3.4"}},
				},
			},
		},
	}
	rule := &tlsQualityRule{}

	states := rule.Evaluate(context.Background(), obs, nil)

	skipped := len(states) == 1 && states[0].Status == sdk.StatusUnknown
	if !skipped {
		t.Errorf("expected not-tested, got %+v", states)
	}
}
// TestTLSQualityRule_NullMXSkipped checks that tlsQualityRule emits a single
// Unknown state for a null-MX observation.
func TestTLSQualityRule_NullMXSkipped(t *testing.T) {
	obs := &stubObs{
		data: &SMTPData{MX: MXLookup{NullMX: true}},
	}
	rule := &tlsQualityRule{}

	states := rule.Evaluate(context.Background(), obs, nil)

	skipped := len(states) == 1 && states[0].Status == sdk.StatusUnknown
	if !skipped {
		t.Errorf("got %+v", states)
	}
}
func TestTLSQualityRule_PassWhenRelatedClean(t *testing.T) {
yes := true
notAfter := time.Now().Add(365 * 24 * time.Hour)
payload := map[string]any{"host": "mx.x", "port": 25, "chain_valid": yes, "hostname_match": yes, "not_after": notAfter}
related := map[sdk.ObservationKey][]sdk.RelatedObservation{
TLSRelatedKey: {{Data: mustJSONForRule(t, payload)}},
}
obs := &stubObs{
data: &SMTPData{Domain: "x", MX: MXLookup{Records: []MXRecord{{Target: "mx.x", IPv4: []string{"1.2.3.4"}}}}},
related: related,
}
st := (&tlsQualityRule{}).Evaluate(context.Background(), obs, nil)
if len(st) != 1 || st[0].Status != sdk.StatusOK {
t.Errorf("expected ok pass, got %+v", st)
}
}
func TestTLSQualityRule_RelatedIssuesFlow(t *testing.T) {
payload := map[string]any{
"host": "mx.x", "port": 25,
"issues": []map[string]any{{"code": "cert.expired", "severity": "crit", "message": "expired"}},
}
related := map[sdk.ObservationKey][]sdk.RelatedObservation{
TLSRelatedKey: {{Data: mustJSONForRule(t, payload)}},
}
obs := &stubObs{
data: &SMTPData{Domain: "x", MX: MXLookup{Records: []MXRecord{{Target: "mx.x", IPv4: []string{"1.2.3.4"}}}}},
related: related,
}
st := (&tlsQualityRule{}).Evaluate(context.Background(), obs, nil)
if len(st) == 0 || st[0].Status != sdk.StatusCrit {
t.Errorf("expected crit, got %+v", st)
}
}