package checker

import (
	"encoding/xml"
	"fmt"
	"io"
	"net"
	"strings"
)

func init() {
	registerStartTLS("xmpp-client", starttlsXMPPClient)
	registerStartTLS("xmpp-server", starttlsXMPPServer)
}

// starttlsXMPPClient implements RFC 6120 STARTTLS for c2s streams.
func starttlsXMPPClient(conn net.Conn, sni string) error {
	return starttlsXMPP(conn, sni, "jabber:client")
}

// starttlsXMPPServer implements RFC 6120 STARTTLS for s2s streams.
func starttlsXMPPServer(conn net.Conn, sni string) error {
	return starttlsXMPP(conn, sni, "jabber:server")
}

// starttlsXMPP runs the plaintext part of an XMPP STARTTLS negotiation on
// conn: it opens a stream addressed to sni using the content namespace ns,
// waits for <stream:features>, requests <starttls/>, and waits for the
// server's verdict.
//
// It returns nil once the server has answered <proceed/>. It returns an error
// wrapping errStartTLSNotOffered when the advertised features do not include
// starttls, a plain error when the server answers <failure/>, and a wrapped
// I/O or XML error when reading or writing fails.
//
// NOTE(review): presumably the caller starts the TLS handshake on conn after
// a nil return; that is not visible from this file.
func starttlsXMPP(conn net.Conn, sni, ns string) error {
	// sni ends up inside an XML attribute value, so escape it. ns is one of
	// the two fixed namespace constants passed by the wrappers above.
	var to strings.Builder
	if err := xml.EscapeText(&to, []byte(sni)); err != nil {
		return fmt.Errorf("escape stream target %q: %w", sni, err)
	}
	header := fmt.Sprintf(
		`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='http://etherx.jabber.org/streams' to='%s' version='1.0'>`,
		ns, to.String(),
	)
	if _, err := io.WriteString(conn, header); err != nil {
		return fmt.Errorf("write stream header: %w", err)
	}

	dec := xml.NewDecoder(conn)

	// Read the inbound <stream:stream> opening and its <stream:features>.
	hasStartTLS, err := xmppReadFeatures(dec)
	if err != nil {
		return err
	}
	if !hasStartTLS {
		return fmt.Errorf("%w: XMPP features did not advertise starttls", errStartTLSNotOffered)
	}

	if _, err := io.WriteString(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`); err != nil {
		return fmt.Errorf("write starttls: %w", err)
	}
	return xmppReadProceed(dec)
}

// xmppReadFeatures consumes tokens up to and including the end of the first
// element whose local name is "features" and reports whether one of its
// direct children has the local name "starttls".
func xmppReadFeatures(dec *xml.Decoder) (bool, error) {
	// Advance to the features start tag. The enclosing stream element is
	// deliberately not skipped: it stays open for the life of the stream.
	for {
		tok, err := dec.Token()
		if err != nil {
			return false, fmt.Errorf("read stream features: %w", err)
		}
		if se, ok := tok.(xml.StartElement); ok && se.Name.Local == "features" {
			break
		}
	}

	// Scan the children of features.
	offered := false
	for {
		tok, err := dec.Token()
		if err != nil {
			return false, fmt.Errorf("read features body: %w", err)
		}
		switch el := tok.(type) {
		case xml.StartElement:
			if el.Name.Local == "starttls" {
				offered = true
			}
			// Consume the whole child so nested elements are never mistaken
			// for direct children of features.
			if err := dec.Skip(); err != nil {
				return false, fmt.Errorf("skip feature %q: %w", el.Name.Local, err)
			}
		case xml.EndElement:
			// Every child is skipped in full and the decoder enforces tag
			// matching, so the only end tag seen here closes features.
			return offered, nil
		}
	}
}

// xmppReadProceed reads tokens until the server answers the starttls request.
// It returns nil on <proceed/> and an error on <failure/> or on a read error.
func xmppReadProceed(dec *xml.Decoder) error {
	for {
		tok, err := dec.Token()
		if err != nil {
			return fmt.Errorf("read proceed: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			continue
		}
		switch se.Name.Local {
		case "proceed":
			// Consume the matching end tag too. For the self-closing form
			// this needs no further input; for <proceed></proceed> it keeps
			// the closing tag from being left on the connection.
			if err := dec.Skip(); err != nil {
				return fmt.Errorf("read proceed: %w", err)
			}
			return nil
		case "failure":
			return fmt.Errorf("server refused STARTTLS (<failure/>)")
		}
	}
}