100 lines
2.7 KiB
Go
100 lines
2.7 KiB
Go
package checker
|
|
|
|
import (
	"encoding/xml"
	"fmt"
	"io"
	"net"
	"strings"
)
|
|
|
|
func init() {
|
|
registerStartTLS("xmpp-client", starttlsXMPPClient)
|
|
registerStartTLS("xmpp-server", starttlsXMPPServer)
|
|
}
|
|
|
|
// starttlsXMPPClient implements RFC 6120 STARTTLS for c2s streams.
|
|
func starttlsXMPPClient(conn net.Conn, sni string) error {
|
|
return starttlsXMPP(conn, sni, "jabber:client")
|
|
}
|
|
|
|
// starttlsXMPPServer implements RFC 6120 STARTTLS for s2s streams.
|
|
func starttlsXMPPServer(conn net.Conn, sni string) error {
|
|
return starttlsXMPP(conn, sni, "jabber:server")
|
|
}
|
|
|
|
func starttlsXMPP(conn net.Conn, sni, ns string) error {
|
|
header := fmt.Sprintf(`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='http://etherx.jabber.org/streams' version='1.0' to='%s'>`, ns, sni)
|
|
if _, err := io.WriteString(conn, header); err != nil {
|
|
return fmt.Errorf("write stream header: %w", err)
|
|
}
|
|
|
|
dec := xml.NewDecoder(conn)
|
|
|
|
// Read the inbound <stream:stream> opening and its <stream:features>.
|
|
// A peer that opens with <stream:error/> (or anything other than features)
|
|
// is not going to advertise STARTTLS: surface that immediately rather
|
|
// than spinning on tokens until the deadline fires.
|
|
hasStartTLS := false
|
|
outer:
|
|
for {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return fmt.Errorf("read stream features: %w", err)
|
|
}
|
|
se, ok := tok.(xml.StartElement)
|
|
if !ok {
|
|
continue
|
|
}
|
|
switch se.Name.Local {
|
|
case "stream":
|
|
// Outer <stream:stream> opening. Continue reading children.
|
|
continue
|
|
case "features":
|
|
for {
|
|
t2, err := dec.Token()
|
|
if err != nil {
|
|
return fmt.Errorf("read features body: %w", err)
|
|
}
|
|
switch ee := t2.(type) {
|
|
case xml.StartElement:
|
|
if ee.Name.Local == "starttls" {
|
|
hasStartTLS = true
|
|
}
|
|
if err := dec.Skip(); err != nil {
|
|
return fmt.Errorf("skip feature %q: %w", ee.Name.Local, err)
|
|
}
|
|
case xml.EndElement:
|
|
if ee.Name.Local == "features" {
|
|
break outer
|
|
}
|
|
}
|
|
}
|
|
case "error":
|
|
return fmt.Errorf("server returned <stream:error/> before features")
|
|
default:
|
|
return fmt.Errorf("%w: unexpected element %q before features", errStartTLSNotOffered, se.Name.Local)
|
|
}
|
|
}
|
|
if !hasStartTLS {
|
|
return fmt.Errorf("%w: XMPP features did not advertise starttls", errStartTLSNotOffered)
|
|
}
|
|
|
|
if _, err := io.WriteString(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`); err != nil {
|
|
return fmt.Errorf("write starttls: %w", err)
|
|
}
|
|
|
|
for {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return fmt.Errorf("read proceed: %w", err)
|
|
}
|
|
if se, ok := tok.(xml.StartElement); ok {
|
|
switch se.Name.Local {
|
|
case "proceed":
|
|
return nil
|
|
case "failure":
|
|
return fmt.Errorf("server refused STARTTLS (<failure/>)")
|
|
}
|
|
}
|
|
}
|
|
}
|