package checker

import (
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
)

const (
	// xmppStreamsNS is the namespace bound to the "stream" prefix of the
	// <stream:stream> wrapper element (RFC 6120).
	xmppStreamsNS = "http://etherx.jabber.org/streams"

	// xmppStartTLSCommand is the initiating entity's request to begin TLS
	// negotiation (RFC 6120, STARTTLS negotiation).
	xmppStartTLSCommand = "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"
)

func init() {
	registerStartTLS("xmpp-client", starttlsXMPPClient)
	registerStartTLS("xmpp-server", starttlsXMPPServer)
}

// starttlsXMPPClient implements RFC 6120 STARTTLS for c2s streams.
func starttlsXMPPClient(conn net.Conn, sni string) error {
	return starttlsXMPP(conn, sni, "jabber:client")
}

// starttlsXMPPServer implements RFC 6120 STARTTLS for s2s streams.
func starttlsXMPPServer(conn net.Conn, sni string) error {
	return starttlsXMPP(conn, sni, "jabber:server")
}

// starttlsXMPP opens an XMPP stream on conn whose default namespace is ns
// and whose 'to' attribute is sni. It then reads the peer's
// <stream:features>. If <starttls/> is advertised, it sends the STARTTLS
// command and waits for <proceed/>.
//
// A nil return means the server answered <proceed/>, so the next bytes on
// conn belong to the TLS handshake. When the features omit starttls, or the
// peer sends an unexpected element before its features, the returned error
// wraps errStartTLSNotOffered.
func starttlsXMPP(conn net.Conn, sni, ns string) error {
	if _, err := io.WriteString(conn, xmppStreamHeader(ns, sni)); err != nil {
		return fmt.Errorf("write stream header: %w", err)
	}

	// encoding/xml buffers reads itself when conn is not an io.ByteReader.
	// That is harmless here: a compliant server sends nothing after
	// <proceed/> until it receives our ClientHello. As a result, no handshake
	// bytes can be left stranded in the decoder's buffer.
	dec := xml.NewDecoder(conn)

	hasStartTLS, err := xmppReadFeatures(dec)
	if err != nil {
		return err
	}
	if !hasStartTLS {
		return fmt.Errorf("%w: XMPP features did not advertise starttls", errStartTLSNotOffered)
	}

	if _, err := io.WriteString(conn, xmppStartTLSCommand); err != nil {
		return fmt.Errorf("write starttls: %w", err)
	}
	return xmppAwaitProceed(dec)
}

// xmppStreamHeader builds the initiating entity's opening stream header.
// The domain is escaped so that it cannot break out of its quoted
// attribute. The 'to' attribute is omitted entirely when domain is empty,
// rather than sent as an empty value.
func xmppStreamHeader(ns, domain string) string {
	to := ""
	if domain != "" {
		to = " to='" + xmppAttrEscape(domain) + "'"
	}
	return fmt.Sprintf("<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='%s'%s version='1.0'>",
		ns, xmppStreamsNS, to)
}

// xmppAttrEscape returns s escaped for use inside a quoted XML attribute.
func xmppAttrEscape(s string) string {
	var b strings.Builder
	// EscapeText only fails when its writer fails; strings.Builder never does.
	_ = xml.EscapeText(&b, []byte(s))
	return b.String()
}

// xmppReadFeatures reads the peer's <stream:stream> opening and its
// <stream:features>, and reports whether <starttls/> was listed.
// A peer that opens with <stream:error> (or with anything other than
// features) is not going to advertise STARTTLS. That case is reported
// immediately, rather than spinning on tokens until the deadline fires.
func xmppReadFeatures(dec *xml.Decoder) (bool, error) {
	for {
		tok, err := dec.Token()
		if err != nil {
			return false, fmt.Errorf("read stream features: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			// XML declaration, whitespace, comments: nothing to act on.
			continue
		}
		switch se.Name.Local {
		case "stream":
			// Outer <stream:stream> opening; its children follow.
		case "features":
			return xmppScanFeatures(dec)
		case "error":
			return false, xmppStreamError(dec, "before features")
		default:
			return false, fmt.Errorf("%w: unexpected element %q before features", errStartTLSNotOffered, se.Name.Local)
		}
	}
}

// xmppScanFeatures consumes the body of a <stream:features> element, whose
// start tag has just been read, up to and including its end tag. It reports
// whether a <starttls/> child was present.
func xmppScanFeatures(dec *xml.Decoder) (bool, error) {
	hasStartTLS := false
	for {
		tok, err := dec.Token()
		if err != nil {
			return false, fmt.Errorf("read features body: %w", err)
		}
		switch t := tok.(type) {
		case xml.StartElement:
			if t.Name.Local == "starttls" {
				hasStartTLS = true
			}
			// Consume the feature's own body (e.g. <required/>) so that
			// nested elements are not mistaken for sibling features.
			if err := dec.Skip(); err != nil {
				return false, fmt.Errorf("skip feature %q: %w", t.Name.Local, err)
			}
		case xml.EndElement:
			if t.Name.Local == "features" {
				return hasStartTLS, nil
			}
		}
	}
}

// xmppAwaitProceed waits for the server's answer to the STARTTLS command.
// It returns nil on <proceed/>, and an error on <failure/>, on
// <stream:error>, or on a read error. Other elements are ignored.
func xmppAwaitProceed(dec *xml.Decoder) error {
	for {
		tok, err := dec.Token()
		if err != nil {
			return fmt.Errorf("read proceed: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			continue
		}
		switch se.Name.Local {
		case "proceed":
			return nil
		case "failure":
			return errors.New("server refused STARTTLS (<failure/>)")
		case "error":
			return xmppStreamError(dec, "after starttls")
		}
	}
}

// xmppStreamError builds an error for a <stream:error> whose start tag has
// just been read. It includes the name of the first child element, which
// RFC 6120 expects to be the defined error condition (e.g. "host-unknown"),
// when one can be read. The when argument describes the protocol step that
// was in progress.
func xmppStreamError(dec *xml.Decoder, when string) error {
	for {
		tok, err := dec.Token()
		if err != nil {
			return fmt.Errorf("server returned <stream:error> %s", when)
		}
		switch t := tok.(type) {
		case xml.StartElement:
			return fmt.Errorf("server returned <stream:error> (%s) %s", t.Name.Local, when)
		case xml.EndElement:
			// Empty <stream:error/>: no condition to report.
			return fmt.Errorf("server returned <stream:error> %s", when)
		}
	}
}