checker-happydeliver/checker/collect_test.go

691 lines
20 KiB
Go

package checker
import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net"
	"net/http"
	"net/http/httptest"
	"net/mail"
	"net/url"
	"strconv"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	sdk "git.happydns.org/checker-sdk-go/checker"
)
// baseValidOptions builds a fresh option set that loadConfig accepts without
// error (see TestLoadConfigDefaults). Each call returns a new map, so tests
// may mutate it freely to probe a single validation rule.
func baseValidOptions() sdk.CheckerOptions {
	opts := sdk.CheckerOptions{}
	opts["happydeliver_url"] = "https://deliver.example.org"
	opts["happydeliver_token"] = "tok"
	opts["smtp_host"] = "smtp.example.org"
	opts["smtp_port"] = float64(587) // numeric options arrive as float64
	opts["smtp_tls"] = "starttls"
	opts["from_address"] = "test@example.org"
	return opts
}
// TestLoadConfigDefaults verifies that the minimal valid option set is
// accepted and that every optional field falls back to its default value.
func TestLoadConfigDefaults(t *testing.T) {
	cfg, err := loadConfig(baseValidOptions())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Connection settings passed through from baseValidOptions.
	if got := cfg.SMTPPort; got != 587 {
		t.Errorf("SMTPPort = %d, want 587", got)
	}
	if got := cfg.SMTPTLS; got != "starttls" {
		t.Errorf("SMTPTLS = %q, want starttls", got)
	}

	// Message content defaults.
	if got := cfg.Subject; got != defaultSubject {
		t.Errorf("Subject = %q, want default", got)
	}
	if cfg.BodyText != defaultBodyText {
		t.Errorf("BodyText not defaulted")
	}
	if cfg.BodyHTML != defaultBodyHTML {
		t.Errorf("BodyHTML not defaulted")
	}

	// Timing defaults.
	if got, want := cfg.WaitTimeout, 900*time.Second; got != want {
		t.Errorf("WaitTimeout = %v, want 900s", got)
	}
	if got, want := cfg.PollInterval, 5*time.Second; got != want {
		t.Errorf("PollInterval = %v, want 5s", got)
	}

	// Sender identity derived from from_address.
	if got := cfg.FromAddress; got != "test@example.org" {
		t.Errorf("FromAddress = %q", got)
	}
	if got := cfg.FromHeader; got != "<test@example.org>" {
		t.Errorf("FromHeader = %q", got)
	}
}
func TestLoadConfigPollClamping(t *testing.T) {
for _, tc := range []struct {
in, want int
}{
{0, 2}, {1, 2}, {2, 2}, {30, 30}, {60, 60}, {120, 60},
} {
opts := baseValidOptions()
opts["poll_interval"] = float64(tc.in)
cfg, err := loadConfig(opts)
if err != nil {
t.Fatalf("in=%d: %v", tc.in, err)
}
if cfg.PollInterval != time.Duration(tc.want)*time.Second {
t.Errorf("in=%d: got %v, want %ds", tc.in, cfg.PollInterval, tc.want)
}
}
}
func TestLoadConfigValidationErrors(t *testing.T) {
cases := []struct {
name string
mutate func(sdk.CheckerOptions)
want string
}{
{"missing url", func(o sdk.CheckerOptions) { delete(o, "happydeliver_url") }, "happydeliver_url is required"},
{"bad url", func(o sdk.CheckerOptions) { o["happydeliver_url"] = "not a url" }, "invalid happydeliver_url"},
{"non-http scheme", func(o sdk.CheckerOptions) { o["happydeliver_url"] = "ftp://x.test" }, "must use http or https"},
{"missing host", func(o sdk.CheckerOptions) { o["happydeliver_url"] = "https://" }, "missing a host"},
{"missing smtp_host", func(o sdk.CheckerOptions) { delete(o, "smtp_host") }, "smtp_host is required"},
{"missing from", func(o sdk.CheckerOptions) { delete(o, "from_address") }, "from_address is required"},
{"bad from", func(o sdk.CheckerOptions) { o["from_address"] = "not-an-address" }, "invalid from_address"},
{"bad tls mode", func(o sdk.CheckerOptions) { o["smtp_tls"] = "weird" }, "smtp_tls must be one of"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
opts := baseValidOptions()
tc.mutate(opts)
_, err := loadConfig(opts)
if err == nil {
t.Fatalf("expected error containing %q, got nil", tc.want)
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("err = %v, want substring %q", err, tc.want)
}
})
}
}
func TestLoadConfigAcceptsDisplayNameFrom(t *testing.T) {
opts := baseValidOptions()
opts["from_address"] = "Alice <alice@example.org>"
cfg, err := loadConfig(opts)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.FromAddress != "alice@example.org" {
t.Errorf("FromAddress should be the bare address, got %q", cfg.FromAddress)
}
if !strings.Contains(cfg.FromHeader, "<alice@example.org>") {
t.Errorf("FromHeader should keep display form, got %q", cfg.FromHeader)
}
}
func TestLoadConfigTrimsURL(t *testing.T) {
opts := baseValidOptions()
opts["happydeliver_url"] = " https://deliver.example.org "
cfg, err := loadConfig(opts)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.HappyDeliverURL != "https://deliver.example.org" {
t.Errorf("URL not trimmed: %q", cfg.HappyDeliverURL)
}
}
// TestJoinURL checks that joinURL collapses any trailing slashes on the base
// and preserves an existing base path.
func TestJoinURL(t *testing.T) {
	for _, tc := range []struct{ base, want string }{
		{"https://x", "https://x/api"},
		{"https://x/", "https://x/api"},
		{"https://x///", "https://x/api"},
		{"https://x/foo", "https://x/foo/api"},
	} {
		if got := joinURL(tc.base, "/api"); got != tc.want {
			t.Errorf("joinURL(%q) = %q, want %q", tc.base, got, tc.want)
		}
	}
}
// TestHostFromAddress checks that the domain after the last "@" is returned,
// and that "localhost" is used when there is no "@" at all.
func TestHostFromAddress(t *testing.T) {
	for _, tc := range []struct{ addr, want string }{
		{"user@example.org", "example.org"},
		{"a@b@example.org", "example.org"},
		{"no-at-sign", "localhost"},
		{"", "localhost"},
	} {
		if got := hostFromAddress(tc.addr); got != tc.want {
			t.Errorf("hostFromAddress(%q) = %q, want %q", tc.addr, got, tc.want)
		}
	}
}
// TestStringOpt checks that stringOpt returns string values as-is and yields
// "" for both absent keys and non-string values.
func TestStringOpt(t *testing.T) {
	opts := sdk.CheckerOptions{"k": "v", "n": float64(1)}
	checks := []struct{ key, want, failMsg string }{
		{"k", "v", "expected v"},
		{"missing", "", "missing key should give empty string"},
		{"n", "", "non-string value should give empty string"},
	}
	for _, c := range checks {
		if stringOpt(opts, c.key) != c.want {
			t.Error(c.failMsg)
		}
	}
}
// TestBuildMessageStructure parses buildMessage output as an RFC 5322 message
// and checks the From/To/Subject/MIME-Version/Message-ID headers plus the
// multipart/alternative body: exactly one text/plain and one text/html part,
// each sent with Content-Transfer-Encoding 8bit.
func TestBuildMessageStructure(t *testing.T) {
	cfg := &runConfig{
		FromAddress: "alice@example.org",
		FromHeader:  "Alice <alice@example.org>",
		Subject:     "Hello: accents",
		BodyText:    "plain body",
		BodyHTML:    "<p>html body</p>",
	}
	raw := buildMessage(cfg, "rcpt@deliver.test")
	msg, err := mail.ReadMessage(strings.NewReader(string(raw)))
	if err != nil {
		t.Fatalf("not a parseable RFC 5322 message: %v\n--\n%s", err, raw)
	}
	if got := msg.Header.Get("From"); !strings.Contains(got, "alice@example.org") {
		t.Errorf("From header = %q", got)
	}
	if got := msg.Header.Get("To"); got != "rcpt@deliver.test" {
		t.Errorf("To header = %q", got)
	}
	// The Subject may be RFC 2047 encoded-word encoded; decode it before
	// comparing so the check holds whether or not encoding was applied.
	dec := new(mime.WordDecoder)
	subj, err := dec.DecodeHeader(msg.Header.Get("Subject"))
	if err != nil {
		t.Fatalf("subject decode: %v", err)
	}
	if subj != "Hello: accents" {
		t.Errorf("decoded subject = %q", subj)
	}
	if msg.Header.Get("MIME-Version") != "1.0" {
		t.Errorf("missing MIME-Version")
	}
	// The Message-ID domain is expected to come from the sender's domain.
	if mid := msg.Header.Get("Message-ID"); !strings.Contains(mid, "@example.org>") {
		t.Errorf("Message-ID = %q", mid)
	}
	mediaType, params, err := mime.ParseMediaType(msg.Header.Get("Content-Type"))
	if err != nil {
		t.Fatalf("content-type: %v", err)
	}
	if mediaType != "multipart/alternative" {
		t.Errorf("media type = %q", mediaType)
	}
	if params["boundary"] == "" {
		t.Fatal("missing boundary")
	}
	mr := multipart.NewReader(msg.Body, params["boundary"])
	var seenText, seenHTML bool
	for {
		p, err := mr.NextPart()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			t.Fatalf("next part: %v", err)
		}
		if cte := p.Header.Get("Content-Transfer-Encoding"); cte != "8bit" {
			t.Errorf("part CTE = %q, want 8bit", cte)
		}
		// A failed read would otherwise surface as a misleading body mismatch.
		body, err := io.ReadAll(p)
		if err != nil {
			t.Fatalf("read part body: %v", err)
		}
		ct := p.Header.Get("Content-Type")
		switch {
		case strings.HasPrefix(ct, "text/plain"):
			seenText = true
			if !strings.Contains(string(body), "plain body") {
				t.Errorf("plain part body = %q", body)
			}
		case strings.HasPrefix(ct, "text/html"):
			seenHTML = true
			if !strings.Contains(string(body), "html body") {
				t.Errorf("html part body = %q", body)
			}
		default:
			t.Errorf("unexpected part Content-Type %q", ct)
		}
	}
	if !seenText || !seenHTML {
		t.Errorf("missing parts: text=%v html=%v", seenText, seenHTML)
	}
}
func TestBuildMessageBodyTextNormalisation(t *testing.T) {
cfg := &runConfig{
FromAddress: "a@b.test", FromHeader: "<a@b.test>",
Subject: "s", BodyText: "no newline", BodyHTML: "<p>x</p>",
}
raw := string(buildMessage(cfg, "r@x.test"))
// The plain body must be CRLF-terminated before the next boundary line.
if !strings.Contains(raw, "no newline\r\n--") {
t.Errorf("plain body was not CRLF-normalised before boundary:\n%s", raw)
}
}
// --- HTTP client tests ---------------------------------------------------
// TestAllocateTestSuccess checks that allocateTest POSTs to /api/test with a
// bearer token and decodes the returned id and email.
func TestAllocateTestSuccess(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost || r.URL.Path != "/api/test" {
			t.Errorf("unexpected request %s %s", r.Method, r.URL.Path)
		}
		if auth := r.Header.Get("Authorization"); auth != "Bearer hunter2" {
			t.Errorf("auth header = %q", auth)
		}
		_, _ = io.WriteString(w, `{"id":"abc","email":"abc@deliver.test"}`)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	cfg := &runConfig{HappyDeliverURL: srv.URL, HappyDeliverToken: "hunter2"}
	alloc, err := allocateTest(context.Background(), srv.Client(), cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if alloc.ID != "abc" || alloc.Email != "abc@deliver.test" {
		t.Errorf("got %+v", alloc)
	}
}
// TestAllocateTestNon2xx checks that a non-2xx reply becomes an error that
// mentions the HTTP status code.
func TestAllocateTestNon2xx(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "nope", http.StatusServiceUnavailable)
	}))
	defer srv.Close()

	cfg := &runConfig{HappyDeliverURL: srv.URL}
	_, err := allocateTest(context.Background(), srv.Client(), cfg)
	if err != nil && strings.Contains(err.Error(), "503") {
		return
	}
	t.Fatalf("expected 503 error, got %v", err)
}
// TestAllocateTestEmptyAllocation checks that a 200 reply carrying an empty
// id/email is rejected rather than returned as a usable allocation.
func TestAllocateTestEmptyAllocation(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, `{"id":"","email":""}`)
	}))
	defer srv.Close()

	cfg := &runConfig{HappyDeliverURL: srv.URL}
	_, err := allocateTest(context.Background(), srv.Client(), cfg)
	if err != nil && strings.Contains(err.Error(), "empty test allocation") {
		return
	}
	t.Fatalf("got %v", err)
}
// TestAllocateTestNoToken checks that no Authorization header is sent when
// no token is configured.
func TestAllocateTestNoToken(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if auth := r.Header.Get("Authorization"); auth != "" {
			t.Errorf("Authorization should be empty, got %q", auth)
		}
		_, _ = io.WriteString(w, `{"id":"x","email":"x@y"}`)
	}))
	defer srv.Close()

	cfg := &runConfig{HappyDeliverURL: srv.URL} // HappyDeliverToken left empty
	_, err := allocateTest(context.Background(), srv.Client(), cfg)
	if err != nil {
		t.Fatal(err)
	}
}
func TestGetTestStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/test/abc" {
t.Errorf("path = %q", r.URL.Path)
}
_, _ = io.WriteString(w, `{"status":"analyzed"}`)
}))
defer srv.Close()
st, err := getTestStatus(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}, "abc")
if err != nil {
t.Fatal(err)
}
if st != "analyzed" {
t.Errorf("status = %q", st)
}
}
func TestFetchReport(t *testing.T) {
body := `{"score":42}`
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/report/xyz" {
t.Errorf("path = %q", r.URL.Path)
}
_, _ = io.WriteString(w, body)
}))
defer srv.Close()
raw, err := fetchReport(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}, "xyz")
if err != nil {
t.Fatal(err)
}
if string(raw) != body {
t.Errorf("body = %q", raw)
}
}
// TestWaitForAnalysisPolls checks that waitForAnalysis keeps polling the
// status endpoint while it reports "pending" and returns once it reports
// "analyzed" (on the third request here).
func TestWaitForAnalysisPolls(t *testing.T) {
	// The handler runs on httptest server goroutines, while the final check
	// runs on the test goroutine. The counter is therefore accessed atomically
	// so the test stays clean under `go test -race`.
	var hits int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if atomic.AddInt32(&hits, 1) < 3 {
			_, _ = io.WriteString(w, `{"status":"pending"}`)
			return
		}
		_, _ = io.WriteString(w, `{"status":"analyzed"}`)
	}))
	defer srv.Close()
	cfg := &runConfig{HappyDeliverURL: srv.URL, WaitTimeout: 5 * time.Second, PollInterval: 10 * time.Millisecond}
	if err := waitForAnalysis(context.Background(), srv.Client(), cfg, "x"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n := atomic.LoadInt32(&hits); n < 3 {
		t.Errorf("expected at least 3 hits, got %d", n)
	}
}
// TestWaitForAnalysisTimeout checks that a status which never leaves
// "pending" makes waitForAnalysis give up with errTimeout once WaitTimeout
// elapses.
func TestWaitForAnalysisTimeout(t *testing.T) {
	alwaysPending := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, `{"status":"pending"}`)
	})
	srv := httptest.NewServer(alwaysPending)
	defer srv.Close()

	cfg := &runConfig{
		HappyDeliverURL: srv.URL,
		WaitTimeout:     25 * time.Millisecond,
		PollInterval:    10 * time.Millisecond,
	}
	if err := waitForAnalysis(context.Background(), srv.Client(), cfg, "x"); !errors.Is(err, errTimeout) {
		t.Fatalf("expected errTimeout, got %v", err)
	}
}
// TestWaitForAnalysisContextCancelled checks that cancelling the caller's
// context aborts polling with context.Canceled, long before the hour-long
// WaitTimeout could fire.
func TestWaitForAnalysisContextCancelled(t *testing.T) {
	alwaysPending := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, `{"status":"pending"}`)
	})
	srv := httptest.NewServer(alwaysPending)
	defer srv.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	timer := time.AfterFunc(20*time.Millisecond, cancel)
	defer timer.Stop()

	cfg := &runConfig{HappyDeliverURL: srv.URL, WaitTimeout: time.Hour, PollInterval: 10 * time.Millisecond}
	if err := waitForAnalysis(ctx, srv.Client(), cfg, "x"); !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context.Canceled, got %v", err)
	}
}
// --- End-to-end Collect with httptest + in-process SMTP ------------------
// TestCollectHappyPath runs Collect end to end against an httptest
// happyDeliver API and the in-process fake SMTP server. It checks the parsed
// result and the SMTP envelope the fake server received.
func TestCollectHappyPath(t *testing.T) {
	smtpAddr, smtpReq := startFakeSMTP(t)
	const reportJSON = `{
"score": 88, "grade": "B",
"summary": {
"dns_score": 90, "dns_grade": "A",
"authentication_score": 85, "authentication_grade": "B",
"spam_score": 80, "spam_grade": "B",
"blacklist_score": 100, "blacklist_grade": "A",
"header_score": 75, "header_grade": "C",
"content_score": 70, "content_grade": "C"
}
}`
	mux := http.NewServeMux()
	mux.HandleFunc("/api/test", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, `{"id":"id1","email":"id1@deliver.test"}`)
	})
	mux.HandleFunc("/api/test/id1", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, `{"status":"analyzed"}`)
	})
	mux.HandleFunc("/api/report/id1", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, reportJSON)
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()

	// Fail loudly on a malformed listener address rather than pointing
	// Collect at an empty host / port 0.
	host, port, err := net.SplitHostPort(smtpAddr)
	if err != nil {
		t.Fatalf("split SMTP address %q: %v", smtpAddr, err)
	}
	p, err := strconv.Atoi(port)
	if err != nil {
		t.Fatalf("parse SMTP port %q: %v", port, err)
	}

	provider := &happyDeliverProvider{}
	out, err := provider.Collect(context.Background(), sdk.CheckerOptions{
		"happydeliver_url": srv.URL,
		"smtp_host":        host,
		"smtp_port":        float64(p),
		"smtp_tls":         "none",
		"from_address":     "alice@example.org",
		"poll_interval":    float64(2),
	})
	if err != nil {
		t.Fatalf("Collect returned err: %v", err)
	}
	data, ok := out.(*HappyDeliverData)
	if !ok {
		t.Fatalf("Collect returned %T, want *HappyDeliverData", out)
	}
	if data.Phase != "ok" {
		t.Errorf("Phase = %q (Error=%q), want ok", data.Phase, data.Error)
	}
	if data.Scores[SectionOverall] != 88 {
		t.Errorf("overall score = %d", data.Scores[SectionOverall])
	}
	if data.RecipientEmail != "id1@deliver.test" {
		t.Errorf("RecipientEmail = %q", data.RecipientEmail)
	}
	if data.LatencySeconds < 0 {
		t.Errorf("LatencySeconds should be non-negative, got %v", data.LatencySeconds)
	}
	// Verify the SMTP server received a sane envelope.
	select {
	case got := <-smtpReq:
		if got.from != "alice@example.org" {
			t.Errorf("MAIL FROM = %q", got.from)
		}
		if got.to != "id1@deliver.test" {
			t.Errorf("RCPT TO = %q", got.to)
		}
		if !strings.Contains(got.data, "Subject:") {
			t.Errorf("data missing headers: %s", got.data)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("SMTP server never received a message")
	}
}
// TestCollectAllocateFailure checks that an allocation failure is reported in
// the returned data (Phase "allocate" plus a message) rather than as a Go
// error from Collect.
func TestCollectAllocateFailure(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "boom", http.StatusInternalServerError)
	}))
	defer srv.Close()
	provider := &happyDeliverProvider{}
	out, err := provider.Collect(context.Background(), sdk.CheckerOptions{
		"happydeliver_url": srv.URL,
		"smtp_host":        "irrelevant",
		"smtp_tls":         "none",
		"from_address":     "a@b.test",
	})
	if err != nil {
		t.Fatalf("Collect should swallow error, got %v", err)
	}
	// A checked assertion turns an unexpected result type into a test
	// failure instead of a panic.
	d, ok := out.(*HappyDeliverData)
	if !ok {
		t.Fatalf("Collect returned %T, want *HappyDeliverData", out)
	}
	if d.Phase != "allocate" {
		t.Errorf("phase = %q, want allocate", d.Phase)
	}
	if d.Error == "" {
		t.Error("expected an error message")
	}
}
// TestCollectParseFailure checks that an unparseable report is surfaced as
// Phase "parse" with an error message prefixed "parse:", not as a Go error.
func TestCollectParseFailure(t *testing.T) {
	smtpAddr, _ := startFakeSMTP(t)
	mux := http.NewServeMux()
	mux.HandleFunc("/api/test", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, `{"id":"i","email":"i@deliver.test"}`)
	})
	mux.HandleFunc("/api/test/i", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, `{"status":"analyzed"}`)
	})
	mux.HandleFunc("/api/report/i", func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, `not json`)
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()

	host, port, err := net.SplitHostPort(smtpAddr)
	if err != nil {
		t.Fatalf("split SMTP address %q: %v", smtpAddr, err)
	}
	p, err := strconv.Atoi(port)
	if err != nil {
		t.Fatalf("parse SMTP port %q: %v", port, err)
	}

	provider := &happyDeliverProvider{}
	out, err := provider.Collect(context.Background(), sdk.CheckerOptions{
		"happydeliver_url": srv.URL,
		"smtp_host":        host,
		"smtp_port":        float64(p),
		"smtp_tls":         "none",
		"from_address":     "alice@example.org",
		"poll_interval":    float64(2),
	})
	if err != nil {
		t.Fatal(err)
	}
	d, ok := out.(*HappyDeliverData)
	if !ok {
		t.Fatalf("Collect returned %T, want *HappyDeliverData", out)
	}
	if d.Phase != "parse" {
		t.Errorf("phase = %q, want parse (Error=%q)", d.Phase, d.Error)
	}
	if !strings.HasPrefix(d.Error, "parse:") {
		t.Errorf("error = %q", d.Error)
	}
}
// --- minimal in-process SMTP server -------------------------------------
// smtpReceived captures one transaction accepted by the fake SMTP server:
// the MAIL FROM and RCPT TO addresses (angle brackets stripped) and the DATA
// payload with the terminating "." line and SMTP dot-stuffing removed.
type smtpReceived struct{ from, to, data string }

// startFakeSMTP starts a minimal SMTP listener on an ephemeral loopback port.
// It returns the listener address and a channel that receives every completed
// DATA transaction. The listener is closed through t.Cleanup.
func startFakeSMTP(t *testing.T) (string, <-chan smtpReceived) {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	t.Cleanup(func() { _ = ln.Close() })
	// Buffered so a handler can report a few messages even when the test
	// never reads the channel.
	out := make(chan smtpReceived, 4)
	go func() {
		for {
			conn, err := ln.Accept()
			if err != nil {
				return // listener closed by t.Cleanup
			}
			go handleSMTPConn(conn, out)
		}
	}()
	return ln.Addr().String(), out
}

// handleSMTPConn speaks just enough SMTP to accept mail from a client.
// Every command is acknowledged. MAIL FROM, RCPT TO and DATA are recorded,
// RSET discards the transaction in progress, and the connection ends on QUIT
// or on any read error. Each completed DATA transaction is sent on out.
func handleSMTPConn(conn net.Conn, out chan<- smtpReceived) {
	defer conn.Close()
	br := bufio.NewReader(conn)
	write := func(s string) { _, _ = io.WriteString(conn, s) }
	write("220 fake.test ESMTP\r\n")
	var rec smtpReceived
	for {
		line, err := br.ReadString('\n')
		if err != nil {
			return
		}
		line = strings.TrimRight(line, "\r\n")
		// SMTP verbs are case-insensitive; match on an upper-cased copy but
		// extract addresses from the original line.
		cmd := strings.ToUpper(line)
		switch {
		case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"):
			write("250-fake.test\r\n250 OK\r\n")
		case strings.HasPrefix(cmd, "MAIL FROM:"):
			rec.from = extractAngleAddr(line)
			write("250 OK\r\n")
		case strings.HasPrefix(cmd, "RCPT TO:"):
			rec.to = extractAngleAddr(line)
			write("250 OK\r\n")
		case cmd == "DATA":
			write("354 send data\r\n")
			data, ok := readFakeSMTPData(br)
			if !ok {
				return
			}
			rec.data = data
			write("250 OK\r\n")
			out <- rec
			rec = smtpReceived{}
		case cmd == "QUIT":
			write("221 bye\r\n")
			return
		case cmd == "RSET":
			// RSET aborts the current mail transaction (RFC 5321 §4.1.1.5).
			rec = smtpReceived{}
			write("250 OK\r\n")
		default:
			// NOOP and anything unrecognised are simply accepted.
			write("250 OK\r\n")
		}
	}
}

// readFakeSMTPData reads a DATA payload up to the lone "." terminator line.
// It undoes the sender's dot-stuffing (RFC 5321 §4.5.2) by dropping one
// leading "." from each content line. It reports false if reading fails
// before the terminator arrives.
func readFakeSMTPData(br *bufio.Reader) (string, bool) {
	var b strings.Builder
	for {
		dl, err := br.ReadString('\n')
		if err != nil {
			return "", false
		}
		if strings.TrimRight(dl, "\r\n") == "." {
			return b.String(), true
		}
		b.WriteString(strings.TrimPrefix(dl, "."))
	}
}

// extractAngleAddr returns the text between the first "<" and the first ">"
// of an SMTP command line. Without such a pair it falls back to the trimmed
// text after the first ":", and returns "" if there is none.
func extractAngleAddr(line string) string {
	i := strings.Index(line, "<")
	j := strings.Index(line, ">")
	if i >= 0 && j > i {
		return line[i+1 : j]
	}
	if k := strings.Index(line, ":"); k > 0 {
		return strings.TrimSpace(line[k+1:])
	}
	return ""
}
// --- ensure URLs we build actually parse ---------------------------------
// TestJoinURLProducesParseableURL checks that joinURL output round-trips
// through url.Parse with a single-slash path.
func TestJoinURLProducesParseableURL(t *testing.T) {
	joined := joinURL("https://x.test/", "/api/foo")
	parsed, err := url.Parse(joined)
	if err != nil {
		t.Fatal(err)
	}
	if got := parsed.Path; got != "/api/foo" {
		t.Errorf("path = %q", got)
	}
}
// sanity: errTimeout text must remain stable for log scrapers.
// TestErrTimeoutMessage pins the errTimeout text and checks that the sentinel
// is still matched by errors.Is after %w wrapping.
func TestErrTimeoutMessage(t *testing.T) {
	if errTimeout.Error() != "timeout waiting for analysis" {
		t.Errorf("errTimeout text changed: %q", errTimeout.Error())
	}
	if !errors.Is(fmt.Errorf("wrap: %w", errTimeout), errTimeout) {
		t.Error("errTimeout not unwrappable")
	}
}

// TestHappyDeliverDataScoresJSONRoundTrip checks that HappyDeliverData keeps
// its phase and per-section scores through a JSON marshal/unmarshal cycle.
// It is kept separate from TestErrTimeoutMessage so a failure is reported
// under a name that describes what actually broke.
func TestHappyDeliverDataScoresJSONRoundTrip(t *testing.T) {
	d := HappyDeliverData{Phase: "ok", Scores: map[string]int{SectionOverall: 1}}
	raw, err := json.Marshal(d)
	if err != nil {
		t.Fatal(err)
	}
	var back HappyDeliverData
	if err := json.Unmarshal(raw, &back); err != nil {
		t.Fatal(err)
	}
	if back.Phase != "ok" {
		t.Errorf("round-trip lost phase: %+v", back)
	}
	if back.Scores[SectionOverall] != 1 {
		t.Errorf("round-trip lost scores: %+v", back)
	}
}