checker-tls/checker/starttls_xmpp.go

100 lines
2.7 KiB
Go

package checker
import (
	"encoding/xml"
	"fmt"
	"io"
	"net"
	"strings"
)
// init registers the two XMPP STARTTLS negotiators with the package's
// STARTTLS registry: one for client-to-server streams and one for
// server-to-server streams.
func init() {
	const (
		clientProto = "xmpp-client"
		serverProto = "xmpp-server"
	)

	registerStartTLS(clientProto, starttlsXMPPClient)
	registerStartTLS(serverProto, starttlsXMPPServer)
}
// starttlsXMPPClient runs the RFC 6120 STARTTLS negotiation for a
// client-to-server (c2s) stream. It delegates to starttlsXMPP using the
// "jabber:client" content namespace and returns its result unchanged.
func starttlsXMPPClient(conn net.Conn, sni string) error {
	const clientNS = "jabber:client"
	return starttlsXMPP(conn, sni, clientNS)
}
// starttlsXMPPServer runs the RFC 6120 STARTTLS negotiation for a
// server-to-server (s2s) stream. It delegates to starttlsXMPP using the
// "jabber:server" content namespace and returns its result unchanged.
func starttlsXMPPServer(conn net.Conn, sni string) error {
	const serverNS = "jabber:server"
	return starttlsXMPP(conn, sni, serverNS)
}
// starttlsXMPP performs the plaintext half of an XMPP STARTTLS upgrade on
// conn: it opens a stream in content namespace ns addressed to sni, checks
// that the peer's <stream:features/> advertise <starttls/>, requests the
// upgrade and waits for <proceed/>.
//
// It returns nil once <proceed/> has been seen. It returns an error wrapping
// errStartTLSNotOffered when the peer sends features without <starttls/> or
// sends an unexpected element before features. Any other failure (I/O error,
// <stream:error/>, <failure/>, unexpected reply to <starttls/>) is returned
// as a plain error.
//
// The function sets no deadline itself; reads block until the peer answers
// or a deadline set on conn by the caller expires.
func starttlsXMPP(conn net.Conn, sni, ns string) error {
	// sni lands inside a single-quoted XML attribute, so it must be escaped:
	// a raw quote, '&' or '<' would otherwise yield a malformed (or
	// attacker-shaped) stream header.
	var to strings.Builder
	if err := xml.EscapeText(&to, []byte(sni)); err != nil {
		return fmt.Errorf("escape stream recipient: %w", err)
	}
	header := fmt.Sprintf(`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='http://etherx.jabber.org/streams' version='1.0' to='%s'>`, ns, to.String())
	if _, err := io.WriteString(conn, header); err != nil {
		return fmt.Errorf("write stream header: %w", err)
	}

	// NOTE(review): the decoder reads conn through its own buffer. That is
	// presumably harmless because the peer should stay silent after
	// <proceed/> until the TLS ClientHello arrives; confirm against callers.
	dec := xml.NewDecoder(conn)

	offered, err := xmppStartTLSOffered(dec)
	if err != nil {
		return err
	}
	if !offered {
		return fmt.Errorf("%w: XMPP features did not advertise starttls", errStartTLSNotOffered)
	}

	if _, err := io.WriteString(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`); err != nil {
		return fmt.Errorf("write starttls: %w", err)
	}
	return xmppAwaitProceed(dec)
}

// xmppStartTLSOffered reads the peer's <stream:stream> opening and the
// <stream:features/> element that follows, and reports whether the features
// contain a <starttls/> child. A peer that opens with <stream:error/> or any
// element other than features is reported as an error straight away instead
// of being read until the connection deadline.
func xmppStartTLSOffered(dec *xml.Decoder) (bool, error) {
	for {
		tok, err := dec.Token()
		if err != nil {
			return false, fmt.Errorf("read stream features: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			continue
		}
		switch se.Name.Local {
		case "stream":
			// Outer <stream:stream> opening; its children follow.
		case "features":
			return xmppFeaturesHaveStartTLS(dec)
		case "error":
			return false, fmt.Errorf("server returned <stream:error/> before features")
		default:
			return false, fmt.Errorf("%w: unexpected element %q before features", errStartTLSNotOffered, se.Name.Local)
		}
	}
}

// xmppFeaturesHaveStartTLS consumes the body of an already-opened
// <stream:features> element up to and including its end tag and reports
// whether one of its direct children was named "starttls".
func xmppFeaturesHaveStartTLS(dec *xml.Decoder) (bool, error) {
	offered := false
	for {
		tok, err := dec.Token()
		if err != nil {
			return false, fmt.Errorf("read features body: %w", err)
		}
		switch el := tok.(type) {
		case xml.StartElement:
			if el.Name.Local == "starttls" {
				offered = true
			}
			// Skip the feature's subtree so only direct children of
			// <stream:features> are ever inspected here.
			if err := dec.Skip(); err != nil {
				return false, fmt.Errorf("skip feature %q: %w", el.Name.Local, err)
			}
		case xml.EndElement:
			if el.Name.Local == "features" {
				return offered, nil
			}
		}
	}
}

// xmppAwaitProceed reads the peer's reply to <starttls/>. It returns nil on
// <proceed/> and an error on <failure/>, on <stream:error/>, or on any other
// element: none of those will ever turn into <proceed/>, so waiting for the
// connection deadline would only delay the report.
func xmppAwaitProceed(dec *xml.Decoder) error {
	for {
		tok, err := dec.Token()
		if err != nil {
			return fmt.Errorf("read proceed: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			// Whitespace, comments and the like between elements.
			continue
		}
		switch se.Name.Local {
		case "proceed":
			return nil
		case "failure":
			return fmt.Errorf("server refused STARTTLS (<failure/>)")
		case "error":
			return fmt.Errorf("server returned <stream:error/> instead of <proceed/>")
		default:
			return fmt.Errorf("unexpected element %q while waiting for <proceed/>", se.Name.Local)
		}
	}
}