package checker import ( "bufio" "context" "encoding/json" "errors" "fmt" "io" "mime" "mime/multipart" "net" "net/http" "net/http/httptest" "net/mail" "net/url" "strconv" "strings" "testing" "time" sdk "git.happydns.org/checker-sdk-go/checker" ) func baseValidOptions() sdk.CheckerOptions { return sdk.CheckerOptions{ "happydeliver_url": "https://deliver.example.org", "happydeliver_token": "tok", "smtp_host": "smtp.example.org", "smtp_port": float64(587), "smtp_tls": "starttls", "from_address": "test@example.org", } } func TestLoadConfigDefaults(t *testing.T) { cfg, err := loadConfig(baseValidOptions()) if err != nil { t.Fatalf("unexpected error: %v", err) } if cfg.SMTPPort != 587 { t.Errorf("SMTPPort = %d, want 587", cfg.SMTPPort) } if cfg.SMTPTLS != "starttls" { t.Errorf("SMTPTLS = %q, want starttls", cfg.SMTPTLS) } if cfg.Subject != defaultSubject { t.Errorf("Subject = %q, want default", cfg.Subject) } if cfg.BodyText != defaultBodyText { t.Errorf("BodyText not defaulted") } if cfg.BodyHTML != defaultBodyHTML { t.Errorf("BodyHTML not defaulted") } if cfg.WaitTimeout != 900*time.Second { t.Errorf("WaitTimeout = %v, want 900s", cfg.WaitTimeout) } if cfg.PollInterval != 5*time.Second { t.Errorf("PollInterval = %v, want 5s", cfg.PollInterval) } if cfg.FromAddress != "test@example.org" { t.Errorf("FromAddress = %q", cfg.FromAddress) } if cfg.FromHeader != "" { t.Errorf("FromHeader = %q", cfg.FromHeader) } } func TestLoadConfigPollClamping(t *testing.T) { for _, tc := range []struct { in, want int }{ {0, 2}, {1, 2}, {2, 2}, {30, 30}, {60, 60}, {120, 60}, } { opts := baseValidOptions() opts["poll_interval"] = float64(tc.in) cfg, err := loadConfig(opts) if err != nil { t.Fatalf("in=%d: %v", tc.in, err) } if cfg.PollInterval != time.Duration(tc.want)*time.Second { t.Errorf("in=%d: got %v, want %ds", tc.in, cfg.PollInterval, tc.want) } } } func TestLoadConfigValidationErrors(t *testing.T) { cases := []struct { name string mutate func(sdk.CheckerOptions) want string }{ {"missing url", func(o sdk.CheckerOptions) { delete(o, "happydeliver_url") }, "happydeliver_url is required"}, {"bad url", func(o sdk.CheckerOptions) { o["happydeliver_url"] = "not a url" }, "invalid happydeliver_url"}, {"non-http scheme", func(o sdk.CheckerOptions) { o["happydeliver_url"] = "ftp://x.test" }, "must use http or https"}, {"missing host", func(o sdk.CheckerOptions) { o["happydeliver_url"] = "https://" }, "missing a host"}, {"missing smtp_host", func(o sdk.CheckerOptions) { delete(o, "smtp_host") }, "smtp_host is required"}, {"missing from", func(o sdk.CheckerOptions) { delete(o, "from_address") }, "from_address is required"}, {"bad from", func(o sdk.CheckerOptions) { o["from_address"] = "not-an-address" }, "invalid from_address"}, {"bad tls mode", func(o sdk.CheckerOptions) { o["smtp_tls"] = "weird" }, "smtp_tls must be one of"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { opts := baseValidOptions() tc.mutate(opts) _, err := loadConfig(opts) if err == nil { t.Fatalf("expected error containing %q, got nil", tc.want) } if !strings.Contains(err.Error(), tc.want) { t.Fatalf("err = %v, want substring %q", err, tc.want) } }) } } func TestLoadConfigAcceptsDisplayNameFrom(t *testing.T) { opts := baseValidOptions() opts["from_address"] = "Alice " cfg, err := loadConfig(opts) if err != nil { t.Fatalf("unexpected error: %v", err) } if cfg.FromAddress != "alice@example.org" { t.Errorf("FromAddress should be the bare address, got %q", cfg.FromAddress) } if !strings.Contains(cfg.FromHeader, "") { t.Errorf("FromHeader should keep display form, got %q", cfg.FromHeader) } } func TestLoadConfigTrimsURL(t *testing.T) { opts := baseValidOptions() opts["happydeliver_url"] = " https://deliver.example.org " cfg, err := loadConfig(opts) if err != nil { t.Fatalf("unexpected error: %v", err) } if cfg.HappyDeliverURL != "https://deliver.example.org" { t.Errorf("URL not trimmed: %q", cfg.HappyDeliverURL) } } func TestJoinURL(t *testing.T) { cases := map[string]string{ "https://x": "https://x/api", "https://x/": "https://x/api", "https://x///": "https://x/api", "https://x/foo": "https://x/foo/api", } for base, want := range cases { if got := joinURL(base, "/api"); got != want { t.Errorf("joinURL(%q) = %q, want %q", base, got, want) } } } func TestHostFromAddress(t *testing.T) { cases := map[string]string{ "user@example.org": "example.org", "a@b@example.org": "example.org", "no-at-sign": "localhost", "": "localhost", } for in, want := range cases { if got := hostFromAddress(in); got != want { t.Errorf("hostFromAddress(%q) = %q, want %q", in, got, want) } } } func TestStringOpt(t *testing.T) { opts := sdk.CheckerOptions{"k": "v", "n": float64(1)} if stringOpt(opts, "k") != "v" { t.Error("expected v") } if stringOpt(opts, "missing") != "" { t.Error("missing key should give empty string") } if stringOpt(opts, "n") != "" { t.Error("non-string value should give empty string") } } func TestBuildMessageStructure(t *testing.T) { cfg := &runConfig{ FromAddress: "alice@example.org", FromHeader: "Alice ", Subject: "Hello: accents", BodyText: "plain body", BodyHTML: "

html body

", } raw := buildMessage(cfg, "rcpt@deliver.test") msg, err := mail.ReadMessage(strings.NewReader(string(raw))) if err != nil { t.Fatalf("not a parseable RFC 5322 message: %v\n--\n%s", err, raw) } if got := msg.Header.Get("From"); !strings.Contains(got, "alice@example.org") { t.Errorf("From header = %q", got) } if got := msg.Header.Get("To"); got != "rcpt@deliver.test" { t.Errorf("To header = %q", got) } // Subject: Q-encoded UTF-8. dec := new(mime.WordDecoder) subj, err := dec.DecodeHeader(msg.Header.Get("Subject")) if err != nil { t.Fatalf("subject decode: %v", err) } if subj != "Hello: accents" { t.Errorf("decoded subject = %q", subj) } if msg.Header.Get("MIME-Version") != "1.0" { t.Errorf("missing MIME-Version") } if mid := msg.Header.Get("Message-ID"); !strings.Contains(mid, "@example.org>") { t.Errorf("Message-ID = %q", mid) } mediaType, params, err := mime.ParseMediaType(msg.Header.Get("Content-Type")) if err != nil { t.Fatalf("content-type: %v", err) } if mediaType != "multipart/alternative" { t.Errorf("media type = %q", mediaType) } if params["boundary"] == "" { t.Fatal("missing boundary") } mr := multipart.NewReader(msg.Body, params["boundary"]) var seenText, seenHTML bool for { p, err := mr.NextPart() if err == io.EOF { break } if err != nil { t.Fatalf("next part: %v", err) } if cte := p.Header.Get("Content-Transfer-Encoding"); cte != "8bit" { t.Errorf("part CTE = %q, want 8bit", cte) } body, _ := io.ReadAll(p) ct := p.Header.Get("Content-Type") switch { case strings.HasPrefix(ct, "text/plain"): seenText = true if !strings.Contains(string(body), "plain body") { t.Errorf("plain part body = %q", body) } case strings.HasPrefix(ct, "text/html"): seenHTML = true if !strings.Contains(string(body), "html body") { t.Errorf("html part body = %q", body) } default: t.Errorf("unexpected part Content-Type %q", ct) } } if !seenText || !seenHTML { t.Errorf("missing parts: text=%v html=%v", seenText, seenHTML) } } func TestBuildMessageBodyTextNormalisation(t *testing.T) { cfg := &runConfig{ FromAddress: "a@b.test", FromHeader: "", Subject: "s", BodyText: "no newline", BodyHTML: "

x

", } raw := string(buildMessage(cfg, "r@x.test")) // The plain body must be CRLF-terminated before the next boundary line. if !strings.Contains(raw, "no newline\r\n--") { t.Errorf("plain body was not CRLF-normalised before boundary:\n%s", raw) } } // --- HTTP client tests --------------------------------------------------- func TestAllocateTestSuccess(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/api/test" { t.Errorf("unexpected request %s %s", r.Method, r.URL.Path) } if got := r.Header.Get("Authorization"); got != "Bearer hunter2" { t.Errorf("auth header = %q", got) } _, _ = io.WriteString(w, `{"id":"abc","email":"abc@deliver.test"}`) })) defer srv.Close() cfg := &runConfig{HappyDeliverURL: srv.URL, HappyDeliverToken: "hunter2"} tr, err := allocateTest(context.Background(), srv.Client(), cfg) if err != nil { t.Fatalf("unexpected error: %v", err) } if tr.ID != "abc" || tr.Email != "abc@deliver.test" { t.Errorf("got %+v", tr) } } func TestAllocateTestNon2xx(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "nope", http.StatusServiceUnavailable) })) defer srv.Close() _, err := allocateTest(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}) if err == nil || !strings.Contains(err.Error(), "503") { t.Fatalf("expected 503 error, got %v", err) } } func TestAllocateTestEmptyAllocation(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"id":"","email":""}`) })) defer srv.Close() _, err := allocateTest(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}) if err == nil || !strings.Contains(err.Error(), "empty test allocation") { t.Fatalf("got %v", err) } } func TestAllocateTestNoToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Authorization"); got != "" { t.Errorf("Authorization should be empty, got %q", got) } _, _ = io.WriteString(w, `{"id":"x","email":"x@y"}`) })) defer srv.Close() if _, err := allocateTest(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}); err != nil { t.Fatal(err) } } func TestGetTestStatus(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/test/abc" { t.Errorf("path = %q", r.URL.Path) } _, _ = io.WriteString(w, `{"status":"analyzed"}`) })) defer srv.Close() st, err := getTestStatus(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}, "abc") if err != nil { t.Fatal(err) } if st != "analyzed" { t.Errorf("status = %q", st) } } func TestFetchReport(t *testing.T) { body := `{"score":42}` srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/report/xyz" { t.Errorf("path = %q", r.URL.Path) } _, _ = io.WriteString(w, body) })) defer srv.Close() raw, err := fetchReport(context.Background(), srv.Client(), &runConfig{HappyDeliverURL: srv.URL}, "xyz") if err != nil { t.Fatal(err) } if string(raw) != body { t.Errorf("body = %q", raw) } } func TestWaitForAnalysisPolls(t *testing.T) { var hits int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { hits++ if hits < 3 { _, _ = io.WriteString(w, `{"status":"pending"}`) return } _, _ = io.WriteString(w, `{"status":"analyzed"}`) })) defer srv.Close() cfg := &runConfig{HappyDeliverURL: srv.URL, WaitTimeout: 5 * time.Second, PollInterval: 10 * time.Millisecond} if err := waitForAnalysis(context.Background(), srv.Client(), cfg, "x"); err != nil { t.Fatalf("unexpected error: %v", err) } if hits < 3 { t.Errorf("expected at least 3 hits, got %d", hits) } } func TestWaitForAnalysisTimeout(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"status":"pending"}`) })) defer srv.Close() cfg := &runConfig{HappyDeliverURL: srv.URL, WaitTimeout: 25 * time.Millisecond, PollInterval: 10 * time.Millisecond} err := waitForAnalysis(context.Background(), srv.Client(), cfg, "x") if !errors.Is(err, errTimeout) { t.Fatalf("expected errTimeout, got %v", err) } } func TestWaitForAnalysisContextCancelled(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"status":"pending"}`) })) defer srv.Close() ctx, cancel := context.WithCancel(context.Background()) cfg := &runConfig{HappyDeliverURL: srv.URL, WaitTimeout: time.Hour, PollInterval: 10 * time.Millisecond} go func() { time.Sleep(20 * time.Millisecond) cancel() }() err := waitForAnalysis(ctx, srv.Client(), cfg, "x") if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } } // --- End-to-end Collect with httptest + in-process SMTP ------------------ func TestCollectHappyPath(t *testing.T) { smtpAddr, smtpReq := startFakeSMTP(t) const reportJSON = `{ "score": 88, "grade": "B", "summary": { "dns_score": 90, "dns_grade": "A", "authentication_score": 85, "authentication_grade": "B", "spam_score": 80, "spam_grade": "B", "blacklist_score": 100, "blacklist_grade": "A", "header_score": 75, "header_grade": "C", "content_score": 70, "content_grade": "C" } }` mux := http.NewServeMux() mux.HandleFunc("/api/test", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"id":"id1","email":"id1@deliver.test"}`) }) mux.HandleFunc("/api/test/id1", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"status":"analyzed"}`) }) mux.HandleFunc("/api/report/id1", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, reportJSON) }) srv := httptest.NewServer(mux) defer srv.Close() host, port, _ := net.SplitHostPort(smtpAddr) p, _ := strconv.Atoi(port) provider := &happyDeliverProvider{} out, err := provider.Collect(context.Background(), sdk.CheckerOptions{ "happydeliver_url": srv.URL, "smtp_host": host, "smtp_port": float64(p), "smtp_tls": "none", "from_address": "alice@example.org", "poll_interval": float64(2), }) if err != nil { t.Fatalf("Collect returned err: %v", err) } data, ok := out.(*HappyDeliverData) if !ok { t.Fatalf("Collect returned %T, want *HappyDeliverData", out) } if data.Phase != "ok" { t.Errorf("Phase = %q (Error=%q), want ok", data.Phase, data.Error) } if data.Scores[SectionOverall] != 88 { t.Errorf("overall score = %d", data.Scores[SectionOverall]) } if data.RecipientEmail != "id1@deliver.test" { t.Errorf("RecipientEmail = %q", data.RecipientEmail) } if data.LatencySeconds < 0 { t.Errorf("LatencySeconds should be non-negative, got %v", data.LatencySeconds) } // Verify the SMTP server received a sane envelope. select { case got := <-smtpReq: if got.from != "alice@example.org" { t.Errorf("MAIL FROM = %q", got.from) } if got.to != "id1@deliver.test" { t.Errorf("RCPT TO = %q", got.to) } if !strings.Contains(got.data, "Subject:") { t.Errorf("data missing headers: %s", got.data) } case <-time.After(2 * time.Second): t.Fatal("SMTP server never received a message") } } func TestCollectAllocateFailure(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "boom", http.StatusInternalServerError) })) defer srv.Close() provider := &happyDeliverProvider{} out, err := provider.Collect(context.Background(), sdk.CheckerOptions{ "happydeliver_url": srv.URL, "smtp_host": "irrelevant", "smtp_tls": "none", "from_address": "a@b.test", }) if err != nil { t.Fatalf("Collect should swallow error, got %v", err) } d := out.(*HappyDeliverData) if d.Phase != "allocate" { t.Errorf("phase = %q, want allocate", d.Phase) } if d.Error == "" { t.Error("expected an error message") } } func TestCollectParseFailure(t *testing.T) { smtpAddr, _ := startFakeSMTP(t) mux := http.NewServeMux() mux.HandleFunc("/api/test", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"id":"i","email":"i@deliver.test"}`) }) mux.HandleFunc("/api/test/i", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{"status":"analyzed"}`) }) mux.HandleFunc("/api/report/i", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `not json`) }) srv := httptest.NewServer(mux) defer srv.Close() host, port, _ := net.SplitHostPort(smtpAddr) p, _ := strconv.Atoi(port) provider := &happyDeliverProvider{} out, err := provider.Collect(context.Background(), sdk.CheckerOptions{ "happydeliver_url": srv.URL, "smtp_host": host, "smtp_port": float64(p), "smtp_tls": "none", "from_address": "alice@example.org", "poll_interval": float64(2), }) if err != nil { t.Fatal(err) } d := out.(*HappyDeliverData) if d.Phase != "parse" { t.Errorf("phase = %q, want parse (Error=%q)", d.Phase, d.Error) } if !strings.HasPrefix(d.Error, "parse:") { t.Errorf("error = %q", d.Error) } } // --- minimal in-process SMTP server ------------------------------------- type smtpReceived struct{ from, to, data string } func startFakeSMTP(t *testing.T) (string, <-chan smtpReceived) { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } t.Cleanup(func() { _ = ln.Close() }) out := make(chan smtpReceived, 4) go func() { for { conn, err := ln.Accept() if err != nil { return } go handleSMTPConn(conn, out) } }() return ln.Addr().String(), out } func handleSMTPConn(conn net.Conn, out chan<- smtpReceived) { defer conn.Close() br := bufio.NewReader(conn) write := func(s string) { _, _ = io.WriteString(conn, s) } write("220 fake.test ESMTP\r\n") var rec smtpReceived for { line, err := br.ReadString('\n') if err != nil { return } line = strings.TrimRight(line, "\r\n") switch { case strings.HasPrefix(strings.ToUpper(line), "EHLO"), strings.HasPrefix(strings.ToUpper(line), "HELO"): write("250-fake.test\r\n250 OK\r\n") case strings.HasPrefix(strings.ToUpper(line), "MAIL FROM:"): rec.from = extractAngleAddr(line) write("250 OK\r\n") case strings.HasPrefix(strings.ToUpper(line), "RCPT TO:"): rec.to = extractAngleAddr(line) write("250 OK\r\n") case strings.ToUpper(line) == "DATA": write("354 send data\r\n") var b strings.Builder for { dl, err := br.ReadString('\n') if err != nil { return } if dl == ".\r\n" || strings.TrimRight(dl, "\r\n") == "." { break } b.WriteString(dl) } rec.data = b.String() write("250 OK\r\n") out <- rec rec = smtpReceived{} case strings.ToUpper(line) == "QUIT": write("221 bye\r\n") return case strings.ToUpper(line) == "RSET": write("250 OK\r\n") case strings.ToUpper(line) == "NOOP": write("250 OK\r\n") default: write("250 OK\r\n") } } } func extractAngleAddr(line string) string { i := strings.Index(line, "<") j := strings.Index(line, ">") if i >= 0 && j > i { return line[i+1 : j] } if k := strings.Index(line, ":"); k > 0 { return strings.TrimSpace(line[k+1:]) } return "" } // --- ensure URLs we build actually parse --------------------------------- func TestJoinURLProducesParseableURL(t *testing.T) { u, err := url.Parse(joinURL("https://x.test/", "/api/foo")) if err != nil { t.Fatal(err) } if u.Path != "/api/foo" { t.Errorf("path = %q", u.Path) } } // sanity: errTimeout text must remain stable for log scrapers. func TestErrTimeoutMessage(t *testing.T) { if errTimeout.Error() != "timeout waiting for analysis" { t.Errorf("errTimeout text changed: %q", errTimeout.Error()) } if !errors.Is(fmt.Errorf("wrap: %w", errTimeout), errTimeout) { t.Error("errTimeout not unwrappable") } // Make sure JSON marshalling of HappyDeliverData round-trips. d := HappyDeliverData{Phase: "ok", Scores: map[string]int{SectionOverall: 1}} raw, err := json.Marshal(d) if err != nil { t.Fatal(err) } var back HappyDeliverData if err := json.Unmarshal(raw, &back); err != nil { t.Fatal(err) } if back.Scores[SectionOverall] != 1 { t.Errorf("round-trip lost scores: %+v", back) } }