checker-smtp/checker/probe_test.go

370 lines
9.2 KiB
Go

package checker
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
)
// fakeSMTPServer is a tiny SMTP responder used to drive the probe in tests.
// It serves a single accepted connection (see start) with a fixed built-in
// dialogue: a 220 banner, an EHLO reply advertising PIPELINING, SIZE and
// 8BITMIME (plus STARTTLS and AUTH depending on the switches below), and
// canned replies to HELO, STARTTLS, MAIL FROM, RCPT TO, RSET and QUIT. The
// boolean switches perturb that default healthy behaviour to simulate
// misbehaving servers.
type fakeSMTPServer struct {
	t        *testing.T
	listener net.Listener
	addr     string         // host of the listening socket
	port     uint16         // port of the listening socket
	tlsCfg   *tls.Config    // server certificate used when STARTTLS is issued
	wg       sync.WaitGroup // tracks the goroutine launched by start

	// Behaviour switches. Each is false in the zero value;
	// newFakeSMTPServer turns offerSTARTTLS on.
	offerSTARTTLS bool // advertise STARTTLS in EHLO and accept the command
	failHandshake bool // reply 220 to STARTTLS, then close without a handshake
	rejectEHLO    bool // answer EHLO with 502
	rejectMAIL    bool // answer MAIL FROM with 550
	rejectRCPT    bool // answer RCPT TO with 550
	authPreTLS    bool // advertise AUTH PLAIN LOGIN in the EHLO reply
	noBanner      bool // close shortly after accept without sending a banner
}
// newFakeSMTPServer returns a fakeSMTPServer listening on an ephemeral
// loopback port, with a freshly generated self-signed certificate and
// STARTTLS advertised. The caller must call start to serve a connection and
// stop to release the listener. Any setup failure aborts the test via
// t.Fatalf.
func newFakeSMTPServer(t *testing.T) *fakeSMTPServer {
	t.Helper()
	// Build the TLS config before listening: it can call t.Fatalf, and
	// doing it first means such a failure cannot leak an open listener.
	cfg := selfSignedTLSConfig(t)
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	// Read host and port straight from the TCP address instead of
	// re-parsing its string form and silently ignoring parse errors.
	tcpAddr, ok := l.Addr().(*net.TCPAddr)
	if !ok {
		_ = l.Close()
		t.Fatalf("listener address %v is not a TCP address", l.Addr())
	}
	return &fakeSMTPServer{
		t:             t,
		listener:      l,
		addr:          tcpAddr.IP.String(),
		port:          uint16(tcpAddr.Port),
		tlsCfg:        cfg,
		offerSTARTTLS: true,
	}
}
// selfSignedTLSConfig returns a server-side tls.Config (minimum TLS 1.2)
// holding a freshly generated ECDSA P-256 certificate for CN "fake.test" and
// IP 127.0.0.1. The certificate is valid from one hour ago until one hour
// from now. It is its own issuer, so clients must skip verification (probeAt
// sets InsecureSkipVerify). Any failure aborts the test via t.Fatalf.
func selfSignedTLSConfig(t *testing.T) *tls.Config {
	t.Helper()
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("genkey: %v", err)
	}
	now := time.Now()
	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "fake.test"},
		NotBefore:    now.Add(-time.Hour),
		NotAfter:     now.Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		IPAddresses:  []net.IP{net.ParseIP("127.0.0.1")},
	}
	// Template doubles as parent: the certificate signs itself.
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
	if err != nil {
		t.Fatalf("cert: %v", err)
	}
	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		t.Fatalf("marshal key: %v", err)
	}
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	pair, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		t.Fatalf("x509keypair: %v", err)
	}
	return &tls.Config{Certificates: []tls.Certificate{pair}, MinVersion: tls.VersionTLS12}
}
// start serves exactly one client on a background goroutine. It waits for a
// single Accept and hands the connection to handle. If the listener is
// closed first (see stop), Accept fails and the goroutine simply exits. The
// WaitGroup lets stop block until that goroutine has returned.
func (s *fakeSMTPServer) start() {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		if conn, err := s.listener.Accept(); err == nil {
			s.handle(conn)
		}
	}()
}
// handle speaks a minimal, scripted SMTP dialogue on conn. It stops when
// the client sends QUIT, the connection errors, or the fixed session
// deadline expires. The server's behaviour switches select which canned
// replies are sent. On STARTTLS it performs a real server-side TLS handshake
// and keeps reading commands over the encrypted connection.
func (s *fakeSMTPServer) handle(conn net.Conn) {
	// Bound the whole session so a stuck client cannot hang this goroutine
	// and, through it, stop's wg.Wait. A tls.Conn layered on top reads and
	// writes through this same socket, so the deadline keeps applying after
	// STARTTLS.
	const sessionTimeout = 30 * time.Second
	_ = conn.SetDeadline(time.Now().Add(sessionTimeout))
	// Close whichever connection is current at return time. After a
	// successful STARTTLS that is the tls.Conn, whose Close also sends a
	// close_notify alert. A plain `defer conn.Close()` would bind the
	// original plaintext socket instead.
	defer func() { _ = conn.Close() }()

	// w writes one CRLF-terminated reply line. It captures the conn
	// variable (not its current value), so after STARTTLS it writes to the
	// TLS connection without being rebuilt. Write errors are deliberately
	// ignored; a dead connection surfaces on the next read.
	w := func(line string) { _, _ = conn.Write([]byte(line + "\r\n")) }

	if s.noBanner {
		// Just close after a tiny delay.
		time.Sleep(10 * time.Millisecond)
		return
	}
	w("220 fake.test ESMTP")

	br := bufio.NewReader(conn)
	for {
		line, err := br.ReadString('\n')
		if err != nil {
			return
		}
		up := strings.ToUpper(strings.TrimRight(line, "\r\n"))
		switch {
		case strings.HasPrefix(up, "EHLO"):
			if s.rejectEHLO {
				w("502 EHLO not supported")
				continue
			}
			w("250-fake.test")
			w("250-PIPELINING")
			w("250-SIZE 52428800")
			w("250-8BITMIME")
			if s.authPreTLS {
				w("250-AUTH PLAIN LOGIN")
			}
			if s.offerSTARTTLS {
				w("250-STARTTLS")
			}
			w("250 HELP")
		case strings.HasPrefix(up, "HELO"):
			w("250 fake.test")
		case up == "STARTTLS":
			if !s.offerSTARTTLS {
				w("502 not advertised")
				continue
			}
			w("220 ready")
			if s.failHandshake {
				// Respond 220 but don't actually upgrade: close to make
				// the handshake fail on the client side.
				time.Sleep(10 * time.Millisecond)
				return
			}
			tlsConn := tls.Server(conn, s.tlsCfg)
			if err := tlsConn.Handshake(); err != nil {
				return
			}
			conn = tlsConn
			// The old reader wraps the plaintext socket; read the
			// remaining commands from the TLS stream instead.
			br = bufio.NewReader(conn)
		case strings.HasPrefix(up, "MAIL FROM"):
			if s.rejectMAIL {
				w("550 sender rejected")
			} else {
				w("250 sender ok")
			}
		case strings.HasPrefix(up, "RCPT TO"):
			if s.rejectRCPT {
				w("550 rcpt rejected")
			} else {
				w("250 rcpt ok")
			}
		case up == "RSET":
			w("250 reset")
		case up == "QUIT":
			w("221 bye")
			return
		default:
			w("502 unrecognized")
		}
	}
}
// stop closes the listener, which unblocks a start goroutine still waiting
// in Accept. It then waits for that goroutine, including any in-progress
// handle call, to return. The close error is intentionally ignored.
func (s *fakeSMTPServer) stop() {
	defer s.wg.Wait()
	_ = s.listener.Close()
}
// runProbe fills in defaults for any unset probe inputs and then probes this
// fake server's actual listening address via probeAt. The defaults are:
// target and ip "127.0.0.1", a 5s timeout, and HELO name "client.test".
// Tests assert on the returned EndpointProbe.
//
// probeEndpoint itself is not used because it hard-codes port 25, while the
// fake server listens on an ephemeral port. probeAt is a test-local copy of
// the probe logic that accepts an arbitrary port.
func (s *fakeSMTPServer) runProbe(t *testing.T, in probeInputs) EndpointProbe {
	t.Helper()
	if in.target == "" {
		in.target = "127.0.0.1"
	}
	if in.ip == "" {
		in.ip = "127.0.0.1"
	}
	if in.timeout == 0 {
		in.timeout = 5 * time.Second
	}
	if in.heloName == "" {
		in.heloName = "client.test"
	}
	return probeAt(t, s.addr, s.port, in)
}
// probeAt runs the probe sequence against an arbitrary host and port. It
// replicates probeEndpoint, which cannot be reused here because it
// hard-codes port 25. Keeping this body in lockstep with collect.go is the
// test's job.
//
// The sequence is:
//  1. TCP dial, bounded by in.timeout.
//  2. Read the banner, which must be 220.
//  3. EHLO, recording the STARTTLS/PIPELINING/8BITMIME/AUTH extensions.
//  4. If STARTTLS is offered and answered with 220: a TLS handshake with
//     certificate verification skipped (the fake server is self-signed),
//     followed by a second EHLO.
//  5. If in.testNull is set: MAIL FROM:<> and RCPT TO:<postmaster@in.domain>,
//     recording whether the RCPT got a 2xx reply.
//
// The first hard failure is written to ep.Error, prefixed by the stage that
// failed, and the partially filled probe is returned.
func probeAt(t *testing.T, host string, port uint16, in probeInputs) EndpointProbe {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), in.timeout)
	defer cancel()

	addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
	ep := EndpointProbe{Target: in.target, Port: port, IP: in.ip, Address: addr}

	conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
	if err != nil {
		ep.Error = "tcp: " + err.Error()
		return ep
	}
	ep.TCPConnected = true
	defer conn.Close()
	_ = conn.SetDeadline(time.Now().Add(in.timeout))
	sc := newSMTPConn(conn, in.timeout)

	// Banner.
	code, text, _, err := sc.readResponse()
	if err != nil {
		ep.Error = "banner: " + err.Error()
		return ep
	}
	ep.BannerReceived = true
	ep.BannerCode = code
	ep.BannerLine = text
	if code != 220 {
		ep.Error = "banner not 220"
		return ep
	}

	// EHLO.
	ehloCode, _, lines, err := sc.cmd("EHLO " + in.heloName)
	if err != nil {
		ep.Error = "ehlo: " + err.Error()
		return ep
	}
	// Judge rejection by the numeric reply code rather than by indexing
	// lines[0][0], which panics on an empty reply. As before, only 5xx
	// counts as a rejection.
	if ehloCode/100 == 5 {
		ep.Error = "ehlo rejected"
		return ep
	}
	ep.EHLOReceived = true
	_, exts := parseEHLO(lines)
	idx := buildExtensions(exts)
	ep.STARTTLSOffered = idx.has("STARTTLS")
	ep.HasPipelining = idx.has("PIPELINING")
	ep.Has8BITMIME = idx.has("8BITMIME")
	ep.AUTHPreTLS = idx.parseAuth()

	// Opportunistic STARTTLS. A failed or non-220 STARTTLS reply is not
	// treated as an error; the probe just continues in plaintext.
	if ep.STARTTLSOffered {
		c, _, _, err := sc.cmd("STARTTLS")
		if err == nil && c == 220 {
			tlsConn := tls.Client(conn, &tls.Config{ServerName: "fake.test", InsecureSkipVerify: true})
			_ = tlsConn.SetDeadline(time.Now().Add(in.timeout))
			if err := tlsConn.Handshake(); err != nil {
				ep.Error = "handshake: " + err.Error()
				return ep
			}
			ep.STARTTLSUpgraded = true
			ep.TLSVersion = tls.VersionName(tlsConn.ConnectionState().Version)
			sc.swap(tlsConn)
			_, _, _, _ = sc.cmd("EHLO " + in.heloName)
		}
	}

	// Null-sender acceptance: judged solely on the RCPT reply code.
	if in.testNull {
		_, _, _, _ = sc.cmd("MAIL FROM:<>")
		c, _, _, _ := sc.cmd("RCPT TO:<postmaster@" + in.domain + ">")
		ok := c >= 200 && c < 300
		ep.NullSenderAccepted = &ok
	}
	sc.close()
	return ep
}
// TestProbe_HappySTARTTLS runs the probe against the default healthy fake
// server. It expects a 220 banner, a completed STARTTLS upgrade, the
// PIPELINING and 8BITMIME extensions, and an accepted null sender.
func TestProbe_HappySTARTTLS(t *testing.T) {
	srv := newFakeSMTPServer(t)
	defer srv.stop()
	srv.start()

	got := srv.runProbe(t, probeInputs{domain: "example.com", testNull: true})

	bannerOK := got.TCPConnected && got.BannerReceived && got.BannerCode == 220
	if !bannerOK {
		t.Fatalf("banner: %+v", got)
	}

	upgraded := got.EHLOReceived && got.STARTTLSOffered && got.STARTTLSUpgraded
	if !upgraded {
		t.Errorf("expected STARTTLS upgrade, got %+v", got)
	}

	extsOK := got.HasPipelining && got.Has8BITMIME
	if !extsOK {
		t.Errorf("extension flags: %+v", got)
	}

	if nsa := got.NullSenderAccepted; nsa == nil || !*nsa {
		t.Errorf("null sender: %+v", nsa)
	}
}
// TestProbe_NoSTARTTLS checks that a server which never advertises STARTTLS
// is reported as neither offering nor completing a TLS upgrade.
func TestProbe_NoSTARTTLS(t *testing.T) {
	srv := newFakeSMTPServer(t)
	srv.offerSTARTTLS = false
	defer srv.stop()
	srv.start()

	got := srv.runProbe(t, probeInputs{domain: "example.com"})
	if sawTLS := got.STARTTLSOffered || got.STARTTLSUpgraded; sawTLS {
		t.Errorf("expected no STARTTLS, got %+v", got)
	}
}
// TestProbe_AUTHBeforeTLS checks that AUTH mechanisms advertised over a
// plaintext session (no STARTTLS offered) are recorded in AUTHPreTLS.
func TestProbe_AUTHBeforeTLS(t *testing.T) {
	srv := newFakeSMTPServer(t)
	srv.offerSTARTTLS, srv.authPreTLS = false, true
	defer srv.stop()
	srv.start()

	got := srv.runProbe(t, probeInputs{domain: "example.com"})
	if n := len(got.AUTHPreTLS); n == 0 {
		t.Errorf("expected AUTH pre-TLS, got %+v", got)
	}
}
// TestProbe_NoBanner checks that a server which closes the connection
// without greeting is reported as having sent no banner. The probe error
// must be attributed to the banner stage.
func TestProbe_NoBanner(t *testing.T) {
	srv := newFakeSMTPServer(t)
	srv.noBanner = true
	defer srv.stop()
	srv.start()

	got := srv.runProbe(t, probeInputs{
		domain:  "example.com",
		timeout: 500 * time.Millisecond,
	})
	if got.BannerReceived {
		t.Errorf("expected no banner, got %+v", got)
	}
	if msg := got.Error; !strings.HasPrefix(msg, "banner:") {
		t.Errorf("error should mention banner, got %q", msg)
	}
}
// TestProbe_RejectsEHLO checks that a 5xx reply to EHLO is reported as an
// EHLO rejection. The banner must still be seen, EHLOReceived must stay
// false, and the probe error must be "ehlo rejected".
func TestProbe_RejectsEHLO(t *testing.T) {
	s := newFakeSMTPServer(t)
	s.rejectEHLO = true
	defer s.stop()
	s.start()
	ep := s.runProbe(t, probeInputs{domain: "example.com"})
	// Without this guard the test would also pass if the probe had failed
	// earlier (TCP connect or banner) for an unrelated reason.
	if !ep.BannerReceived {
		t.Fatalf("banner not received: %+v", ep)
	}
	if ep.EHLOReceived {
		t.Errorf("expected EHLO rejection, got %+v", ep)
	}
	if ep.Error != "ehlo rejected" {
		t.Errorf("error: got %q, want %q", ep.Error, "ehlo rejected")
	}
}
// TestProbe_TCPRefused checks that a failed TCP connect is reported with
// TCPConnected=false and a "tcp:"-prefixed error.
func TestProbe_TCPRefused(t *testing.T) {
	// Find a loopback port that is almost certainly closed: bind an
	// ephemeral port, note it, then release it. This avoids assuming a
	// fixed port (such as 1, which tcpmux or other services may use) is
	// unbound on the test machine.
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	tcpAddr, ok := l.Addr().(*net.TCPAddr)
	if !ok {
		_ = l.Close()
		t.Fatalf("listener address %v is not a TCP address", l.Addr())
	}
	port := uint16(tcpAddr.Port)
	if err := l.Close(); err != nil {
		t.Fatalf("close listener: %v", err)
	}

	ep := probeAt(t, "127.0.0.1", port, probeInputs{
		target: "x", ip: "127.0.0.1", domain: "example.com", timeout: 500 * time.Millisecond,
	})
	if ep.TCPConnected {
		t.Errorf("expected TCP failure, got %+v", ep)
	}
	if !strings.HasPrefix(ep.Error, "tcp:") {
		t.Errorf("error: %q", ep.Error)
	}
}