diff --git a/checker/starttls.go b/checker/starttls.go index 8fb1edd..892e533 100644 --- a/checker/starttls.go +++ b/checker/starttls.go @@ -1,6 +1,41 @@ package checker -import "net" +import ( + "bufio" + "fmt" + "io" + "net" +) + +// maxSTARTTLSLineBytes caps the length of a single line read from a STARTTLS +// peer. Real banners and CAPABILITY responses are well under 1 KiB; this +// bound prevents a malicious or buggy server from exhausting memory by +// withholding the line terminator. +const maxSTARTTLSLineBytes = 8 * 1024 + +// readLineLimited reads bytes from r up to and including the next '\n', or +// until maxSTARTTLSLineBytes have been read without one (in which case it +// returns an error). The returned string keeps the trailing '\n' so callers +// can use the same parsing logic as bufio.Reader.ReadString('\n'). +func readLineLimited(r *bufio.Reader) (string, error) { + out := make([]byte, 0, 128) + for { + b, err := r.ReadByte() + if err != nil { + if err == io.EOF && len(out) > 0 { + return string(out), io.ErrUnexpectedEOF + } + return string(out), err + } + out = append(out, b) + if b == '\n' { + return string(out), nil + } + if len(out) >= maxSTARTTLSLineBytes { + return string(out), fmt.Errorf("line exceeds %d bytes without terminator", maxSTARTTLSLineBytes) + } + } +} // starttlsUpgrader performs the plaintext portion of a STARTTLS upgrade on // conn, leaving conn ready for tls.Client(conn, …).Handshake(). On success diff --git a/checker/starttls_imap.go b/checker/starttls_imap.go index 7a7a6fb..777e38d 100644 --- a/checker/starttls_imap.go +++ b/checker/starttls_imap.go @@ -15,7 +15,7 @@ func init() { func starttlsIMAP(conn net.Conn, sni string) error { rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - if _, err := rw.ReadString('\n'); err != nil { + if _, err := readLineLimited(rw.Reader); err != nil { return fmt.Errorf("read greeting: %w", err) } @@ -23,12 +23,12 @@ func starttlsIMAP(conn net.Conn, sni string) error { return fmt.Errorf("write CAPABILITY: %w", err) } if err := rw.Flush(); err != nil { - return err + return fmt.Errorf("flush CAPABILITY: %w", err) } supportsSTARTTLS := false for { - line, err := rw.ReadString('\n') + line, err := readLineLimited(rw.Reader) if err != nil { return fmt.Errorf("read CAPABILITY: %w", err) } @@ -44,13 +44,13 @@ func starttlsIMAP(conn net.Conn, sni string) error { } if _, err := rw.WriteString("A002 STARTTLS\r\n"); err != nil { - return err + return fmt.Errorf("write STARTTLS: %w", err) } if err := rw.Flush(); err != nil { - return err + return fmt.Errorf("flush STARTTLS: %w", err) } for { - line, err := rw.ReadString('\n') + line, err := readLineLimited(rw.Reader) if err != nil { return fmt.Errorf("read STARTTLS response: %w", err) } diff --git a/checker/starttls_pop3.go b/checker/starttls_pop3.go index 887933a..f46414c 100644 --- a/checker/starttls_pop3.go +++ b/checker/starttls_pop3.go @@ -15,7 +15,7 @@ func init() { func starttlsPOP3(conn net.Conn, sni string) error { rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - greeting, err := rw.ReadString('\n') + greeting, err := readLineLimited(rw.Reader) if err != nil { return fmt.Errorf("read greeting: %w", err) } @@ -24,19 +24,19 @@ func starttlsPOP3(conn net.Conn, sni string) error { } if _, err := rw.WriteString("CAPA\r\n"); err != nil { - return err + return fmt.Errorf("write CAPA: %w", err) } if err := rw.Flush(); err != nil { - return err + return fmt.Errorf("flush CAPA: %w", err) } - first, err := rw.ReadString('\n') + first, err := readLineLimited(rw.Reader) if err != nil { return fmt.Errorf("read CAPA: %w", err) } supportsSTLS := false if strings.HasPrefix(first, "+OK") { for { - line, err := rw.ReadString('\n') + line, err := readLineLimited(rw.Reader) if err != nil { return fmt.Errorf("read CAPA body: %w", err) } @@ -54,12 +54,12 @@ func starttlsPOP3(conn net.Conn, sni string) error { } if _, err := rw.WriteString("STLS\r\n"); err != nil { - return err + return fmt.Errorf("write STLS: %w", err) } if err := rw.Flush(); err != nil { - return err + return fmt.Errorf("flush STLS: %w", err) } - resp, err := rw.ReadString('\n') + resp, err := readLineLimited(rw.Reader) if err != nil { return fmt.Errorf("read STLS response: %w", err) } diff --git a/checker/starttls_smtp.go b/checker/starttls_smtp.go index 39db327..dfbaa19 100644 --- a/checker/starttls_smtp.go +++ b/checker/starttls_smtp.go @@ -60,7 +60,7 @@ func readSMTPGreeting(r *bufio.Reader) error { func readSMTPResponse(r *bufio.Reader) ([]string, error) { var out []string for { - line, err := r.ReadString('\n') + line, err := readLineLimited(r) if err != nil { return out, err } diff --git a/checker/starttls_test.go b/checker/starttls_test.go new file mode 100644 index 0000000..9d05e8f --- /dev/null +++ b/checker/starttls_test.go @@ -0,0 +1,274 @@ +package checker + +import ( + "bufio" + "errors" + "io" + "net" + "strings" + "testing" + "time" +) + +// runStartTLS drives upgrader against a fake server. The server callback runs +// on the peer end of an in-memory pipe and may read/write the plaintext +// dialect transcript. The test deadline guards both ends from hanging. +func runStartTLS(t *testing.T, upgrader func(net.Conn, string) error, sni string, server func(net.Conn) error) error { + t.Helper() + clientConn, serverConn := net.Pipe() + deadline := time.Now().Add(2 * time.Second) + _ = clientConn.SetDeadline(deadline) + _ = serverConn.SetDeadline(deadline) + + srvErr := make(chan error, 1) + go func() { + defer serverConn.Close() + srvErr <- server(serverConn) + }() + + clientErr := upgrader(clientConn, sni) + clientConn.Close() + + if err := <-srvErr; err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) { + t.Logf("server side returned: %v", err) + } + return clientErr +} + +// readLineCRLF reads one CRLF-terminated line. +func readLineCRLF(r *bufio.Reader) (string, error) { + line, err := r.ReadString('\n') + return strings.TrimRight(line, "\r\n"), err +} + +func TestStartTLS_SMTP_OK(t *testing.T) { + err := runStartTLS(t, starttlsSMTP, "mail.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + if _, err := io.WriteString(c, "220 mail.example.com ESMTP\r\n"); err != nil { + return err + } + ehlo, err := readLineCRLF(br) + if err != nil { + return err + } + if !strings.HasPrefix(ehlo, "EHLO ") { + return errors.New("expected EHLO") + } + if _, err := io.WriteString(c, "250-mail.example.com\r\n250-SIZE 10485760\r\n250 STARTTLS\r\n"); err != nil { + return err + } + stls, err := readLineCRLF(br) + if err != nil { + return err + } + if stls != "STARTTLS" { + return errors.New("expected STARTTLS") + } + _, err = io.WriteString(c, "220 ready\r\n") + return err + }) + if err != nil { + t.Fatalf("expected success, got: %v", err) + } +} + +func TestStartTLS_SMTP_NotAdvertised(t *testing.T) { + err := runStartTLS(t, starttlsSMTP, "mail.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + _, _ = io.WriteString(c, "220 mail.example.com ESMTP\r\n") + if _, err := readLineCRLF(br); err != nil { + return err + } + _, err := io.WriteString(c, "250-mail.example.com\r\n250 SIZE 10485760\r\n") + return err + }) + if !errors.Is(err, errStartTLSNotOffered) { + t.Fatalf("expected errStartTLSNotOffered, got: %v", err) + } +} + +func TestStartTLS_SMTP_Refused(t *testing.T) { + err := runStartTLS(t, starttlsSMTP, "mail.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + _, _ = io.WriteString(c, "220 mail.example.com ESMTP\r\n") + _, _ = readLineCRLF(br) + _, _ = io.WriteString(c, "250-mail.example.com\r\n250 STARTTLS\r\n") + _, _ = readLineCRLF(br) + _, err := io.WriteString(c, "454 TLS not available\r\n") + return err + }) + if err == nil { + t.Fatal("expected refusal error") + } + if errors.Is(err, errStartTLSNotOffered) { + t.Fatalf("refusal should not be classified as not-offered: %v", err) + } +} + +func TestStartTLS_IMAP_OK(t *testing.T) { + err := runStartTLS(t, starttlsIMAP, "imap.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + _, _ = io.WriteString(c, "* OK IMAP4rev1 ready\r\n") + cap1, err := readLineCRLF(br) + if err != nil { + return err + } + if !strings.HasSuffix(cap1, "CAPABILITY") { + return errors.New("expected CAPABILITY") + } + _, _ = io.WriteString(c, "* CAPABILITY IMAP4rev1 STARTTLS LOGINDISABLED\r\nA001 OK CAPABILITY completed\r\n") + stls, err := readLineCRLF(br) + if err != nil { + return err + } + if !strings.HasSuffix(stls, "STARTTLS") { + return errors.New("expected STARTTLS") + } + _, err = io.WriteString(c, "A002 OK Begin TLS\r\n") + return err + }) + if err != nil { + t.Fatalf("expected success, got: %v", err) + } +} + +func TestStartTLS_IMAP_NotAdvertised(t *testing.T) { + err := runStartTLS(t, starttlsIMAP, "imap.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + _, _ = io.WriteString(c, "* OK IMAP4rev1 ready\r\n") + _, _ = readLineCRLF(br) + _, err := io.WriteString(c, "* CAPABILITY IMAP4rev1 LOGINDISABLED\r\nA001 OK CAPABILITY completed\r\n") + return err + }) + if !errors.Is(err, errStartTLSNotOffered) { + t.Fatalf("expected errStartTLSNotOffered, got: %v", err) + } +} + +func TestStartTLS_POP3_OK(t *testing.T) { + err := runStartTLS(t, starttlsPOP3, "pop.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + _, _ = io.WriteString(c, "+OK POP3 ready\r\n") + capa, err := readLineCRLF(br) + if err != nil { + return err + } + if capa != "CAPA" { + return errors.New("expected CAPA") + } + _, _ = io.WriteString(c, "+OK capa list\r\nUSER\r\nSTLS\r\n.\r\n") + stls, err := readLineCRLF(br) + if err != nil { + return err + } + if stls != "STLS" { + return errors.New("expected STLS") + } + _, err = io.WriteString(c, "+OK begin TLS\r\n") + return err + }) + if err != nil { + t.Fatalf("expected success, got: %v", err) + } +} + +func TestStartTLS_POP3_NotAdvertised(t *testing.T) { + err := runStartTLS(t, starttlsPOP3, "pop.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + _, _ = io.WriteString(c, "+OK POP3 ready\r\n") + _, _ = readLineCRLF(br) + _, err := io.WriteString(c, "+OK capa list\r\nUSER\r\n.\r\n") + return err + }) + if !errors.Is(err, errStartTLSNotOffered) { + t.Fatalf("expected errStartTLSNotOffered, got: %v", err) + } +} + +func TestStartTLS_XMPP_OK(t *testing.T) { + err := runStartTLS(t, starttlsXMPPClient, "xmpp.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + // Read the client's stream header (one line is enough for our writer). + buf := make([]byte, 1024) + if _, err := br.Read(buf); err != nil { + return err + } + _, _ = io.WriteString(c, + ``+ + ``) + // Read the request from the client. + if _, err := br.Read(buf); err != nil { + return err + } + _, err := io.WriteString(c, ``) + return err + }) + if err != nil { + t.Fatalf("expected success, got: %v", err) + } +} + +func TestStartTLS_XMPP_NotAdvertised(t *testing.T) { + err := runStartTLS(t, starttlsXMPPClient, "xmpp.example.com", func(c net.Conn) error { + br := bufio.NewReader(c) + buf := make([]byte, 1024) + if _, err := br.Read(buf); err != nil { + return err + } + _, err := io.WriteString(c, + ``+ + `PLAIN`) + return err + }) + if !errors.Is(err, errStartTLSNotOffered) { + t.Fatalf("expected errStartTLSNotOffered, got: %v", err) + } +} + +func TestStartTLS_LDAP_OK(t *testing.T) { + err := runStartTLS(t, starttlsLDAP, "ldap.example.com", func(c net.Conn) error { + // Drain the StartTLS request (fixed 31 bytes: 0x30 0x1d + 29 bytes). + req := make([]byte, 31) + if _, err := io.ReadFull(c, req); err != nil { + return err + } + // Build a minimal ExtendedResponse with resultCode=0. + // LDAPMessage SEQUENCE { messageID INTEGER 1, [APPLICATION 24] SEQUENCE { resultCode ENUMERATED 0, matchedDN "", diagnosticMessage "" } } + resp := []byte{ + 0x30, 0x0c, // SEQUENCE, length 12 + 0x02, 0x01, 0x01, // messageID = 1 + 0x78, 0x07, // [APPLICATION 24], length 7 + 0x0a, 0x01, 0x00, // resultCode ENUMERATED 0 + 0x04, 0x00, // matchedDN "" + 0x04, 0x00, // diagnosticMessage "" + } + _, err := c.Write(resp) + return err + }) + if err != nil { + t.Fatalf("expected success, got: %v", err) + } +} + +func TestStartTLS_LDAP_Refused(t *testing.T) { + err := runStartTLS(t, starttlsLDAP, "ldap.example.com", func(c net.Conn) error { + req := make([]byte, 31) + if _, err := io.ReadFull(c, req); err != nil { + return err + } + // resultCode = 53 (unwillingToPerform) -> classified as not-offered. + resp := []byte{ + 0x30, 0x0c, + 0x02, 0x01, 0x01, + 0x78, 0x07, + 0x0a, 0x01, 0x35, + 0x04, 0x00, + 0x04, 0x00, + } + _, err := c.Write(resp) + return err + }) + if !errors.Is(err, errStartTLSNotOffered) { + t.Fatalf("expected errStartTLSNotOffered for resultCode 53, got: %v", err) + } +} diff --git a/checker/starttls_xmpp.go b/checker/starttls_xmpp.go index d810654..dfed8f2 100644 --- a/checker/starttls_xmpp.go +++ b/checker/starttls_xmpp.go @@ -32,6 +32,7 @@ func starttlsXMPP(conn net.Conn, sni, ns string) error { // Read the inbound opening and its . hasStartTLS := false +outer: for { tok, err := dec.Token() if err != nil { @@ -50,17 +51,18 @@ func starttlsXMPP(conn net.Conn, sni, ns string) error { if ee.Name.Local == "starttls" { hasStartTLS = true } - _ = dec.Skip() + if err := dec.Skip(); err != nil { + return fmt.Errorf("skip feature %q: %w", ee.Name.Local, err) + } case xml.EndElement: if ee.Name.Local == "features" { - goto doneFeatures + break outer } } } } } } -doneFeatures: if !hasStartTLS { return fmt.Errorf("%w: XMPP features did not advertise starttls", errStartTLSNotOffered) }