checker-smtp/checker/collect_test.go

307 lines
8.8 KiB
Go

package checker
import (
"context"
"encoding/json"
"net"
"strings"
"testing"
"time"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// TestIsValidHostname runs isValidHostname over a fixed table of inputs and
// reports every entry whose verdict differs from the expected one.
func TestIsValidHostname(t *testing.T) {
	type hostCase struct {
		host  string
		valid bool
	}
	table := []hostCase{
		// Inputs that must be accepted.
		{host: "example.com", valid: true},
		{host: "mx-1.example.com", valid: true},
		{host: "a.b.c.d", valid: true},
		{host: "MX.EXAMPLE.COM", valid: true},
		{host: "1.2.3.4", valid: true},
		// Inputs that must be rejected.
		{host: "", valid: false},
		{host: "a b.com", valid: false},
		{host: "a\r\nb.com", valid: false},
		{host: "a\nb.com", valid: false},
		{host: "<bracket>.com", valid: false},
		{host: "under_score.com", valid: false},
		{host: "a@b.com", valid: false},
		{host: "spaces .com", valid: false},
		// 254 consecutive 'a' characters.
		{host: strings.Repeat("a", 254), valid: false},
	}

	for _, tc := range table {
		got := isValidHostname(tc.host)
		switch {
		case tc.valid && !got:
			t.Errorf("expected %q valid", tc.host)
		case !tc.valid && got:
			t.Errorf("expected %q invalid", tc.host)
		}
	}
}
// TestIsValidMailbox checks that isValidMailbox accepts a few ordinary
// addresses and rejects a list of malformed ones.
func TestIsValidMailbox(t *testing.T) {
	accepted := []string{
		"a@b.com",
		"user.name+tag@mx.example.com",
		"postmaster@example.org",
	}
	for i := range accepted {
		if addr := accepted[i]; !isValidMailbox(addr) {
			t.Errorf("expected %q valid", addr)
		}
	}

	rejected := []string{
		"",
		"@example.com",
		"a@",
		"a b@example.com",        // space in local
		"a\r\n@example.com",      // CRLF
		"a<b>@example.com",       // bracket in local
		"a,b@example.com",        // comma in local
		"\"quoted\"@example.com", // quoted local
		"a@with space.com",
		"a@<bracket>.com",
		strings.Repeat("a", 65) + "@example.com", // too-long local
	}
	for i := range rejected {
		if addr := rejected[i]; isValidMailbox(addr) {
			t.Errorf("expected %q invalid", addr)
		}
	}
}
func TestCollect_RejectsInvalidDomain(t *testing.T) {
p := &smtpProvider{}
_, err := p.Collect(context.Background(), sdk.CheckerOptions{"domain": "evil\r\nMAIL FROM:<>"})
if err == nil || !strings.Contains(err.Error(), "invalid domain") {
t.Errorf("expected invalid-domain error, got %v", err)
}
}
func TestCollect_RejectsInvalidHELO(t *testing.T) {
p := &smtpProvider{}
_, err := p.Collect(context.Background(), sdk.CheckerOptions{
"domain": "example.com",
"helo_name": "evil\r\nRSET",
})
if err == nil || !strings.Contains(err.Error(), "invalid helo_name") {
t.Errorf("expected invalid-helo error, got %v", err)
}
}
// TestCollect_RewritesInvalidProbeAddress calls Collect with a
// "test_probe_address" option that contains CRLF followed by an SMTP command
// and expects the call to succeed. The only assertion made on the result is
// that the null MX flag is set.
func TestCollect_RewritesInvalidProbeAddress(t *testing.T) {
	p := &smtpProvider{}
	// Single MX record with target "." (null MX); per the original author's
	// note, Collect returns immediately on this path.
	body := json.RawMessage(`{"mx":[{"Preference":0,"Mx":"."}]}`)
	out, err := p.Collect(context.Background(), sdk.CheckerOptions{
		"domain":             "example.com",
		"service":            body,
		"timeout":            1.0,
		"test_probe_address": "evil\r\nMAIL FROM:<x>",
	})
	if err != nil {
		t.Fatalf("collect: %v", err)
	}
	// Use the checked assertion so an unexpected result type fails the test
	// with a message instead of panicking.
	d, ok := out.(*SMTPData)
	if !ok {
		t.Fatalf("type: %T", out)
	}
	if !d.MX.NullMX {
		t.Error("expected null MX path")
	}
}
// TestSplitMail_MoreCases compares the three results of splitMail against the
// expected values for each input in the table below.
func TestSplitMail_MoreCases(t *testing.T) {
	type expectation struct {
		input      string
		wantOK     bool
		wantDomain string
		wantLocal  string
	}
	table := []expectation{
		{input: "a@b.com", wantOK: true, wantDomain: "b.com", wantLocal: "a"},
		{input: "a@b@c.com", wantOK: true, wantDomain: "c.com", wantLocal: "a@b"}, // last @ wins
		{input: "", wantOK: false, wantDomain: "", wantLocal: ""},
		{input: "@", wantOK: false, wantDomain: "", wantLocal: ""},
	}

	for _, tc := range table {
		gotDomain, gotLocal, gotOK := splitMail(tc.input)
		if gotOK == tc.wantOK && gotDomain == tc.wantDomain && gotLocal == tc.wantLocal {
			continue
		}
		t.Errorf("splitMail(%q) = (%q,%q,%v), want (%q,%q,%v)", tc.input, gotDomain, gotLocal, gotOK, tc.wantDomain, tc.wantLocal, tc.wantOK)
	}
}
// TestComputeCoverage_Empty runs computeCoverage on a zero-valued SMTPData and
// expects Coverage.AnyReachable to stay false.
func TestComputeCoverage_Empty(t *testing.T) {
	data := new(SMTPData)
	computeCoverage(data)

	if reachable := data.Coverage.AnyReachable; reachable {
		t.Errorf("empty endpoints should not be reachable")
	}
}
func TestComputeCoverage_AllPath(t *testing.T) {
yes := true
d := &SMTPData{
Endpoints: []EndpointProbe{
{IP: "1.2.3.4", IsIPv6: false, TCPConnected: true, BannerReceived: true, EHLOReceived: true, STARTTLSUpgraded: true, NullSenderAccepted: &yes, PostmasterAccepted: &yes},
{IP: "2001:db8::1", IsIPv6: true, TCPConnected: true, BannerReceived: true, EHLOReceived: true, STARTTLSUpgraded: true, NullSenderAccepted: &yes, PostmasterAccepted: &yes},
},
}
computeCoverage(d)
c := d.Coverage
if !c.HasIPv4 || !c.HasIPv6 || !c.AnyReachable || !c.AnyBanner || !c.AnyEHLO || !c.AnySTARTTLS || !c.AllSTARTTLS || !c.AllAcceptMail {
t.Errorf("expected all coverage flags set, got %+v", c)
}
}
func TestComputeCoverage_PartialSTARTTLS(t *testing.T) {
yes := true
d := &SMTPData{
Endpoints: []EndpointProbe{
{IP: "1.2.3.4", TCPConnected: true, BannerReceived: true, EHLOReceived: true, STARTTLSUpgraded: true, NullSenderAccepted: &yes, PostmasterAccepted: &yes},
{IP: "1.2.3.5", TCPConnected: true, BannerReceived: true, EHLOReceived: true, STARTTLSUpgraded: false, NullSenderAccepted: &yes, PostmasterAccepted: &yes},
},
}
computeCoverage(d)
if !d.Coverage.AnySTARTTLS {
t.Error("any STARTTLS expected")
}
if d.Coverage.AllSTARTTLS {
t.Error("not all STARTTLS")
}
}
func TestComputeCoverage_RejectedMailFlipsAccept(t *testing.T) {
no := false
yes := true
d := &SMTPData{
Endpoints: []EndpointProbe{
{IP: "1.2.3.4", TCPConnected: true, BannerReceived: true, EHLOReceived: true, STARTTLSUpgraded: true, NullSenderAccepted: &no, PostmasterAccepted: &yes},
},
}
computeCoverage(d)
if d.Coverage.AllAcceptMail {
t.Error("AllAcceptMail should be false when null sender rejected")
}
}
func TestComputeCoverage_NoEHLOFlipsAccept(t *testing.T) {
d := &SMTPData{
Endpoints: []EndpointProbe{
{IP: "1.2.3.4", TCPConnected: true, BannerReceived: true, EHLOReceived: false},
},
}
computeCoverage(d)
if d.Coverage.AllAcceptMail {
t.Error("no EHLO must drop AllAcceptMail")
}
}
// TestParseServiceBody_FullEnvelope builds a JSON envelope with a "_svctype"
// field and a nested "Service" body holding two MX entries, then checks that
// parseServiceBody returns two records and that the first one carries the
// expected preference and target.
func TestParseServiceBody_FullEnvelope(t *testing.T) {
	type mxitem struct {
		Hdr        struct{ Name string } `json:"Hdr"`
		Preference uint16
		Mx         string
	}
	body := struct {
		MXs []mxitem `json:"mx"`
	}{
		MXs: []mxitem{
			{Preference: 10, Mx: "mx1.example.com."},
			{Preference: 20, Mx: "mx2.example.com."},
		},
	}

	// Fail loudly if the fixture itself cannot be encoded, rather than
	// feeding parseServiceBody empty input and misreporting the cause.
	inner, err := json.Marshal(body)
	if err != nil {
		t.Fatalf("marshal body: %v", err)
	}
	envelope := struct {
		Type    string          `json:"_svctype"`
		Service json.RawMessage `json:"Service"`
	}{Type: "svcs.MXs", Service: inner}
	raw, err := json.Marshal(envelope)
	if err != nil {
		t.Fatalf("marshal envelope: %v", err)
	}

	out := parseServiceBody(raw)
	if len(out) != 2 {
		t.Fatalf("got %d", len(out))
	}
	if out[0].Preference != 10 || out[0].Target != "mx1.example.com." {
		t.Errorf("[0]: %+v", out[0])
	}
}
// TestParseServiceBody_BareBody passes a body with a top-level "mx" array (no
// surrounding envelope) and expects exactly one record with preference 5.
func TestParseServiceBody_BareBody(t *testing.T) {
	got := parseServiceBody(json.RawMessage(`{"mx":[{"Preference":5,"Mx":"mx.example.com."}]}`))
	if len(got) == 1 && got[0].Preference == 5 {
		return
	}
	t.Errorf("got %+v", got)
}
// TestParseServiceBody_BadJSON expects parseServiceBody to return nil when the
// input is not valid JSON.
func TestParseServiceBody_BadJSON(t *testing.T) {
	malformed := json.RawMessage(`not json`)
	got := parseServiceBody(malformed)
	if got == nil {
		return
	}
	t.Errorf("expected nil, got %+v", got)
}
// TestParseServiceBody_NoMX passes a body whose "mx" array is empty and
// expects a non-nil result of length zero.
func TestParseServiceBody_NoMX(t *testing.T) {
	got := parseServiceBody(json.RawMessage(`{"mx":[]}`))

	// The two checks are independent: nil-ness first, then length.
	if got == nil {
		t.Errorf("empty array should yield empty slice, not nil")
	}
	if n := len(got); n != 0 {
		t.Errorf("got %+v", got)
	}
}
// TestLookupMX_NXDomainBecomesEmpty looks up MX records for
// invalid.example.invalid through a default net.Resolver and expects lookupMX
// to return a nil result. If lookupMX returns any error, the test is skipped
// rather than failed, since that indicates the local resolver did not report
// a plain not-found.
func TestLookupMX_NXDomainBecomesEmpty(t *testing.T) {
	// The lookup goes through the system resolver, so bound it: without a
	// deadline an unresponsive resolver could stall the whole test run. A
	// timeout surfaces as an error and takes the skip path below.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	r := &net.Resolver{}
	out, err := lookupMX(ctx, r, "invalid.example.invalid")
	if err != nil {
		t.Skipf("resolver returned a different error: %v", err)
	}
	if out != nil {
		t.Errorf("expected nil for NXDOMAIN, got %+v", out)
	}
}
func TestCollect_RejectsEmptyDomain(t *testing.T) {
p := &smtpProvider{}
_, err := p.Collect(context.Background(), sdk.CheckerOptions{"domain": " "})
if err == nil || !strings.Contains(err.Error(), "domain is required") {
t.Errorf("expected domain-required error, got %v", err)
}
}
func TestCollect_NullMXFromService(t *testing.T) {
p := &smtpProvider{}
body := json.RawMessage(`{"mx":[{"Preference":0,"Mx":"."}]}`)
out, err := p.Collect(context.Background(), sdk.CheckerOptions{
"domain": "example.com",
"service": body,
"timeout": 1.0,
})
if err != nil {
t.Fatalf("collect: %v", err)
}
d, ok := out.(*SMTPData)
if !ok {
t.Fatalf("type: %T", out)
}
if !d.MX.NullMX {
t.Errorf("expected NullMX=true, got %+v", d.MX)
}
if len(d.Endpoints) != 0 {
t.Errorf("null MX should not probe, got %d endpoints", len(d.Endpoints))
}
}
// TestCollect_RewritesProbeToAvoidLocalDomain calls Collect with a
// "test_probe_address" in the same domain as the one being checked and asserts
// that no returned endpoint has an OpenRelayRecipient containing that domain.
//
// NOTE(review): "test_open_relay" is false here; if OpenRelayRecipient is only
// populated by the open-relay probe, the loop below may have nothing to
// assert on. Confirm against Collect.
func TestCollect_RewritesProbeToAvoidLocalDomain(t *testing.T) {
	// Use a tiny timeout so the probe attempt against the bogus IP
	// fails fast; we only assert on the rewriting behavior, which
	// happens before any network call.
	p := &smtpProvider{}
	// 127.255.255.255 is unlikely to be an MX target.
	body := json.RawMessage(`{"mx":[{"Preference":10,"Mx":"127.255.255.255"}]}`)
	opts := sdk.CheckerOptions{
		"domain":  "example.com",
		"service": body,
		"timeout": 1.0,
		// Same domain as "domain" above, so it must be rewritten.
		"test_probe_address": "victim@example.com",
		"test_open_relay":    false,
		"test_null_sender":   false,
		"test_postmaster":    false,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	out, err := p.Collect(ctx, opts)
	if err != nil {
		t.Fatalf("collect: %v", err)
	}
	// Use the checked assertion so an unexpected result type fails the test
	// with a message instead of panicking.
	d, ok := out.(*SMTPData)
	if !ok {
		t.Fatalf("type: %T", out)
	}
	for _, ep := range d.Endpoints {
		if strings.Contains(ep.OpenRelayRecipient, "example.com") {
			t.Errorf("probe recipient leaked into the local domain: %q", ep.OpenRelayRecipient)
		}
	}
}