checker-xmpp/checker/collect.go

664 lines
19 KiB
Go

package checker
import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net"
	"sort"
	"strconv"
	"strings"
	"time"

	sdk "git.happydns.org/checker-sdk-go/checker"
	"github.com/miekg/dns"
)
// streamsNS is the RFC 6120 stream namespace (<stream:stream>,
// <stream:features>, <stream:error>).
const streamsNS = "http://etherx.jabber.org/streams"

// clientNS is the default content namespace sent on client-to-server streams.
const clientNS = "jabber:client"

// serverNS is the default content namespace sent on server-to-server streams.
const serverNS = "jabber:server"

// tlsNS is the namespace of the STARTTLS negotiation elements
// (<starttls/>, <proceed/>, <failure/>).
const tlsNS = "urn:ietf:params:xml:ns:xmpp-tls"
// tlsProbeConfig returns a fresh client TLS configuration for probing an XMPP
// endpoint, sending serverName as SNI.
//
// Certificate verification is disabled on purpose. This probe only needs the
// handshake to complete so it can inspect the XMPP stream; certificate
// validation is left to the dedicated TLS checker. MinVersion is lowered to
// TLS 1.0, below crypto/tls's client default, so the probe can still complete
// against servers that only speak older protocol versions.
func tlsProbeConfig(serverName string) *tls.Config {
	return &tls.Config{
		ServerName:         serverName,
		InsecureSkipVerify: true, //nolint:gosec // cert validation is the TLS checker's job
		MinVersion:         tls.VersionTLS10,
	}
}
// Collect runs the full XMPP probe for a domain.
//
// Recognised options:
//   - "domain" (required): the XMPP domain; one trailing dot is stripped.
//   - "mode": "c2s" probes only client-to-server, "s2s" only
//     server-to-server; any other value, including empty, probes both.
//   - "timeout": per-endpoint budget in seconds (default 10; values below 1
//     are replaced by 10).
//
// It looks up the SRV records relevant to the selected mode. When none of the
// xmpp/xmpps client/server records exist, it falls back to the bare domain on
// ports 5222/5269. It then resolves every target, probes each resolved
// address, and returns an *XMPPData carrying the raw results, the coverage
// summary and the derived issues. The only error returned is for a missing
// domain; probe failures are recorded in the data instead.
func (p *xmppProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	rawDomain, _ := sdk.GetOption[string](opts, "domain")
	domain := strings.TrimSuffix(rawDomain, ".")
	if domain == "" {
		return nil, fmt.Errorf("domain is required")
	}

	// An empty or unrecognised mode selects both directions.
	mode, _ := sdk.GetOption[string](opts, "mode")
	wantC2S := mode != "s2s"
	wantS2S := mode != "c2s"

	secs := sdk.GetFloatOption(opts, "timeout", 10)
	if secs < 1 {
		secs = 10
	}
	perEndpoint := time.Duration(secs * float64(time.Second))

	data := &XMPPData{
		Domain: domain,
		RunAt:  time.Now().UTC().Format(time.RFC3339),
		SRV:    SRVLookup{Errors: map[string]string{}},
	}
	srv := &data.SRV
	resolver := net.DefaultResolver

	// SRV lookups, in query order. _jabber._tcp is looked up only so a legacy
	// record can be reported; it is never probed.
	type srvQuery struct {
		prefix string
		want   bool
		dst    *[]SRVRecord
	}
	for _, q := range []srvQuery{
		{"_xmpp-client._tcp.", wantC2S, &srv.Client},
		{"_xmpp-server._tcp.", wantS2S, &srv.Server},
		{"_xmpps-client._tcp.", wantC2S, &srv.ClientSecure},
		{"_xmpps-server._tcp.", wantS2S, &srv.ServerSecure},
		{"_jabber._tcp.", wantC2S, &srv.Jabber},
	} {
		if !q.want {
			continue
		}
		found, err := lookupSRV(ctx, resolver, q.prefix, domain)
		if err != nil {
			srv.Errors[q.prefix] = err.Error()
			continue
		}
		*q.dst = found
	}

	// Nothing published under the xmpp/xmpps names: probe the domain itself
	// on the default ports and record that we did so.
	if len(srv.Client)+len(srv.Server)+len(srv.ClientSecure)+len(srv.ServerSecure) == 0 {
		srv.FallbackProbed = true
		if wantC2S {
			srv.Client = []SRVRecord{{Target: domain, Port: 5222}}
		}
		if wantS2S {
			srv.Server = []SRVRecord{{Target: domain, Port: 5269}}
		}
	}

	type probePlan struct {
		mode      XMPPMode
		prefix    string
		records   []SRVRecord
		directTLS bool
	}
	plans := []probePlan{
		{ModeClient, "_xmpp-client._tcp", srv.Client, false},
		{ModeServer, "_xmpp-server._tcp", srv.Server, false},
		{ModeClient, "_xmpps-client._tcp", srv.ClientSecure, true},
		{ModeServer, "_xmpps-server._tcp", srv.ServerSecure, true},
	}
	// Resolve every target before probing anything. resolveAllInto fills the
	// address lists in place, and the slices in plans share those elements.
	for _, pl := range plans {
		resolveAllInto(ctx, resolver, pl.records)
	}
	for _, pl := range plans {
		probeSet(ctx, data, domain, pl.mode, pl.prefix, pl.records, pl.directTLS, perEndpoint)
	}

	computeCoverage(data)
	data.Issues = deriveIssues(data, wantC2S, wantS2S)
	return data, nil
}
// probeSet probes every resolved address of every record and appends one
// EndpointProbe per address to data.Endpoints, in record order.
//
// A record whose target resolved to no addresses still contributes a single
// EndpointProbe carrying the error "no A/AAAA records for target", so
// unresolvable SRV targets remain visible in the results.
func probeSet(ctx context.Context, data *XMPPData, domain string, mode XMPPMode, prefix string, records []SRVRecord, directTLS bool, timeout time.Duration) {
	for i := range records {
		rec := records[i]
		targets := addressesForProbe(rec)
		if len(targets) > 0 {
			for _, t := range targets {
				data.Endpoints = append(data.Endpoints,
					probeEndpoint(ctx, domain, mode, prefix, rec, t.ip, t.isV6, directTLS, timeout))
			}
			continue
		}
		data.Endpoints = append(data.Endpoints, EndpointProbe{
			Mode:      mode,
			SRVPrefix: prefix,
			Target:    rec.Target,
			Port:      rec.Port,
			DirectTLS: directTLS,
			Error:     "no A/AAAA records for target",
		})
	}
}
// probeAddr is one concrete address to dial for an SRV target.
type probeAddr struct {
	ip   string // textual IP address (no port)
	isV6 bool   // true when ip came from the record's IPv6 list
}

// addressesForProbe flattens rec's resolved addresses into dial targets: all
// IPv4 addresses first, then all IPv6 addresses, each in stored order. It
// returns nil when the record has no addresses at all.
func addressesForProbe(rec SRVRecord) []probeAddr {
	var targets []probeAddr
	addFamily := func(ips []string, v6 bool) {
		for _, ip := range ips {
			targets = append(targets, probeAddr{ip: ip, isV6: v6})
		}
	}
	addFamily(rec.IPv4, false)
	addFamily(rec.IPv6, true)
	return targets
}
// probeEndpoint dials one resolved address of an SRV record and walks the
// XMPP stream negotiation as far as it can go without authenticating.
//
// With directTLS (XEP-0368 xmpps-* records) it performs the TLS handshake
// first, then opens the stream and reads <stream:features>. Otherwise it
// opens a plaintext stream and reads the features. Only if STARTTLS is
// advertised does it continue: it sends <starttls/>, waits for <proceed/>,
// performs the TLS handshake and restarts the stream to read the post-TLS
// features. No SASL exchange is ever attempted.
//
// The handshake uses tlsProbeConfig(domain), so the XMPP domain (not the SRV
// target) is sent as SNI. Dialing is bounded by a timeout-long context. Each
// later phase gets a fresh timeout-long connection deadline. Cancelling ctx
// aborts dialing and the TLS handshake.
//
// The result always describes how far negotiation got. On failure, Error is
// "<stage>: <cause>", where stage is one of: tcp, tls-handshake, stream,
// features, starttls-write, starttls-proceed, post-tls stream. ElapsedMS is
// the wall-clock duration of the whole probe.
func probeEndpoint(ctx context.Context, domain string, mode XMPPMode, prefix string, rec SRVRecord, ip string, isV6, directTLS bool, timeout time.Duration) (result EndpointProbe) {
	start := time.Now()
	result = EndpointProbe{
		Mode:      mode,
		SRVPrefix: prefix,
		Target:    rec.Target,
		Port:      rec.Port,
		Address:   net.JoinHostPort(ip, strconv.Itoa(int(rec.Port))),
		IsIPv6:    isV6,
		DirectTLS: directTLS,
	}
	// result is a named return value so this deferred write reaches the
	// caller. With an unnamed return, `return result` copies the value out
	// before deferred calls run, and ElapsedMS would always be reported as 0.
	defer func() { result.ElapsedMS = time.Since(start).Milliseconds() }()

	server := mode == ModeServer
	ns := clientNS
	if server {
		ns = serverNS
	}

	dialCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	var dialer net.Dialer
	rawConn, err := dialer.DialContext(dialCtx, "tcp", result.Address)
	if err != nil {
		result.Error = "tcp: " + err.Error()
		return result
	}
	result.TCPConnected = true
	defer rawConn.Close()
	// Best effort: if setting the deadline fails, the next I/O reports it.
	_ = rawConn.SetDeadline(time.Now().Add(timeout))

	// handshake upgrades rawConn to TLS and records the negotiated protocol
	// version and cipher suite. The direct-TLS and STARTTLS paths share it.
	handshake := func() (*tls.Conn, error) {
		tlsConn := tls.Client(rawConn, tlsProbeConfig(domain))
		_ = tlsConn.SetDeadline(time.Now().Add(timeout))
		// HandshakeContext lets a cancelled ctx abort the handshake; the
		// deadline set above still bounds its duration.
		if err := tlsConn.HandshakeContext(ctx); err != nil {
			return nil, err
		}
		result.STARTTLSUpgraded = true
		state := tlsConn.ConnectionState()
		result.TLSVersion = tls.VersionName(state.Version)
		result.TLSCipher = tls.CipherSuiteName(state.CipherSuite)
		// Fresh budget for the stream exchange that follows the handshake.
		_ = tlsConn.SetDeadline(time.Now().Add(timeout))
		return tlsConn, nil
	}

	if directTLS {
		tlsConn, err := handshake()
		if err != nil {
			result.Error = "tls-handshake: " + err.Error()
			return result
		}
		feats, from, err := openStreamAndReadFeatures(tlsConn, domain, ns, server)
		if err != nil {
			result.Error = "stream: " + err.Error()
			return result
		}
		result.StreamOpened = true
		result.StreamFrom = from
		applyFeatures(&result, feats)
		return result
	}

	dec, from, err := openStream(rawConn, domain, ns, server)
	if err != nil {
		result.Error = "stream: " + err.Error()
		return result
	}
	result.StreamOpened = true
	result.StreamFrom = from
	feats, err := readFeatures(dec)
	if err != nil {
		result.Error = "features: " + err.Error()
		return result
	}
	result.STARTTLSOffered = feats.StartTLS != nil
	result.STARTTLSRequired = feats.StartTLS != nil && feats.StartTLS.Required != nil
	if !result.STARTTLSOffered {
		// Record any features seen in plaintext, but do not proceed; we
		// intentionally refuse to send SASL over a non-TLS channel.
		applyFeatures(&result, feats)
		return result
	}

	if _, err := io.WriteString(rawConn, `<starttls xmlns='`+tlsNS+`'/>`); err != nil {
		result.Error = "starttls-write: " + err.Error()
		return result
	}
	if err := expectProceed(dec); err != nil {
		result.Error = "starttls-proceed: " + err.Error()
		return result
	}
	tlsConn, err := handshake()
	if err != nil {
		result.Error = "tls-handshake: " + err.Error()
		return result
	}
	// RFC 6120 requires a stream restart after the TLS upgrade; the
	// post-TLS features are what the checks below care about.
	feats2, _, err := openStreamAndReadFeatures(tlsConn, domain, ns, server)
	if err != nil {
		result.Error = "post-tls stream: " + err.Error()
		return result
	}
	applyFeatures(&result, feats2)
	return result
}
// applyFeatures records the parsed stream features on ep. It marks
// FeaturesRead, appends every advertised SASL mechanism to SASLMechanisms,
// sets SASLExternal when EXTERNAL is among them (case-insensitive), and sets
// DialbackOffered when the dialback feature is present. Flags are only ever
// set, never cleared. A nil feats leaves ep untouched.
func applyFeatures(ep *EndpointProbe, feats *streamFeatures) {
	if feats == nil {
		return
	}
	ep.FeaturesRead = true
	ep.DialbackOffered = ep.DialbackOffered || feats.Dialback != nil

	mechs := feats.Mechanisms
	if mechs == nil {
		return
	}
	for _, name := range mechs.Mechanism {
		ep.SASLMechanisms = append(ep.SASLMechanisms, name)
		if strings.EqualFold(name, "EXTERNAL") {
			ep.SASLExternal = true
		}
	}
}
// streamFeatures is the decoded <stream:features/> element. Only the
// features this checker inspects are mapped; encoding/xml ignores any other
// child elements.
type streamFeatures struct {
	XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
	// StartTLS and Mechanisms carry no tag; encoding/xml takes their element
	// name and namespace from the XMLName tag of their types. Each is nil
	// when the corresponding element was not advertised.
	StartTLS   *startTLSEl
	Mechanisms *mechanismsEl
	// Dialback is non-nil when the dialback stream feature was advertised.
	Dialback *struct{} `xml:"urn:xmpp:features:dialback dialback"`
}

// startTLSEl is the <starttls/> stream feature.
type startTLSEl struct {
	XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
	// Required is non-nil when <required/> appears inside <starttls/>. The tag
	// has no namespace, so the child is matched by local name only.
	Required *struct{} `xml:"required"`
}

// mechanismsEl is the SASL <mechanisms/> stream feature.
type mechanismsEl struct {
	XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
	// Mechanism holds the text of each <mechanism/> child, in document order.
	Mechanism []string `xml:"mechanism"`
}
// openStreamAndReadFeatures performs the stream header exchange on conn and
// then parses the peer's <stream:features>. It is used both for the initial
// stream open and for the stream restart after a TLS upgrade.
//
// It returns the features and the "from" attribute of the peer's stream
// header. If the header exchange fails, from is "". If only the features step
// fails, from is still returned alongside the error.
func openStreamAndReadFeatures(conn io.ReadWriter, domain, ns string, server bool) (*streamFeatures, string, error) {
	var (
		dec   *xml.Decoder
		from  string
		feats *streamFeatures
		err   error
	)
	if dec, from, err = openStream(conn, domain, ns, server); err != nil {
		return nil, "", err
	}
	if feats, err = readFeatures(dec); err != nil {
		return nil, from, err
	}
	return feats, from, nil
}
// openStream sends the opening <stream:stream> header addressed to domain and
// reads the peer's stream header in reply.
//
// ns is the default content namespace written into the header. When server
// is true, the dialback namespace (xmlns:db) is declared as well. domain is
// XML-escaped before it is placed in the to='' attribute, so an unusual
// option value cannot break out of the attribute and inject markup.
//
// On success it returns a decoder positioned just after the peer's stream
// header, ready for readFeatures, and the header's "from" attribute ("" when
// absent). It fails when:
//   - writing the header fails or reading the reply fails;
//   - the peer answers with <stream:error>;
//   - the first element received is anything other than <stream:stream>.
func openStream(conn io.ReadWriter, domain, ns string, server bool) (*xml.Decoder, string, error) {
	var to strings.Builder
	// Writes to a strings.Builder never fail, so EscapeText cannot error here.
	_ = xml.EscapeText(&to, []byte(domain))
	dialbackDecl := ""
	if server {
		dialbackDecl = ` xmlns:db='jabber:server:dialback'`
	}
	header := fmt.Sprintf(`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='%s'%s version='1.0' to='%s'>`,
		ns, streamsNS, dialbackDecl, to.String())
	if _, err := io.WriteString(conn, header); err != nil {
		return nil, "", fmt.Errorf("write header: %w", err)
	}

	dec := xml.NewDecoder(conn)
	for {
		tok, err := dec.Token()
		if err != nil {
			return nil, "", fmt.Errorf("read header: %w", err)
		}
		start, ok := tok.(xml.StartElement)
		if !ok {
			// XML declaration, whitespace, comments: keep reading.
			continue
		}
		switch {
		case start.Name.Space == streamsNS && start.Name.Local == "stream":
			var from string
			for _, a := range start.Attr {
				if a.Name.Local == "from" {
					from = a.Value
				}
			}
			return dec, from, nil
		case start.Name.Space == streamsNS && start.Name.Local == "error":
			_ = dec.Skip()
			return nil, "", errors.New("server returned stream:error on open")
		default:
			return nil, "", fmt.Errorf("unexpected element %s", start.Name.Local)
		}
	}
}
// readFeatures reads tokens from dec until it reaches <stream:features> in
// the streams namespace, then decodes that element.
//
// Other tokens, including the children of unrelated elements, are passed
// over. A <stream:error> is skipped and reported as an error, as is any read
// or decode failure.
func readFeatures(dec *xml.Decoder) (*streamFeatures, error) {
	for {
		tok, err := dec.Token()
		if err != nil {
			return nil, fmt.Errorf("read features: %w", err)
		}
		start, isStart := tok.(xml.StartElement)
		if !isStart || start.Name.Space != streamsNS {
			continue
		}
		switch start.Name.Local {
		case "features":
			feats := new(streamFeatures)
			if err := dec.DecodeElement(feats, &start); err != nil {
				return nil, fmt.Errorf("decode features: %w", err)
			}
			return feats, nil
		case "error":
			_ = dec.Skip()
			return nil, errors.New("stream:error before features")
		}
	}
}
// expectProceed waits for the server's reply to <starttls/>.
//
// It returns nil on <proceed/>. It returns an error when:
//   - the server sends <failure/>;
//   - the server sends <stream:error> instead of <proceed/>;
//   - reading from dec fails.
// Other elements are ignored. The <stream:error> case is detected explicitly,
// as readFeatures does, so the probe reports it immediately instead of
// waiting for EOF or the connection deadline.
func expectProceed(dec *xml.Decoder) error {
	for {
		tok, err := dec.Token()
		if err != nil {
			return fmt.Errorf("read proceed: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			continue
		}
		switch {
		case se.Name.Space == tlsNS && se.Name.Local == "proceed":
			_ = dec.Skip()
			return nil
		case se.Name.Space == tlsNS && se.Name.Local == "failure":
			_ = dec.Skip()
			return errors.New("server refused STARTTLS (<failure/>)")
		case se.Name.Space == streamsNS && se.Name.Local == "error":
			_ = dec.Skip()
			return errors.New("server returned stream:error instead of STARTTLS <proceed/>")
		}
	}
}
// lookupSRV queries the SRV RRset named prefix + domain (prefix ends in ".",
// e.g. "_xmpp-client._tcp."). The name is passed to r.LookupSRV with empty
// service and proto, so it is looked up as-is.
//
// It returns (nil, nil) in two cases:
//   - the resolver reports the name as not found;
//   - no usable record remains after filtering.
// Other lookup failures are returned unchanged. Records whose target is "."
// are dropped, with any port: RFC 2782 uses that target to mean "service
// decidedly not available", so there is no host to probe. Returned targets
// have their trailing dot removed.
func lookupSRV(ctx context.Context, r *net.Resolver, prefix, domain string) ([]SRVRecord, error) {
	name := prefix + dns.Fqdn(domain)
	_, records, err := r.LookupSRV(ctx, "", "", name)
	if err != nil {
		// Distinguish NXDOMAIN / no records from real errors.
		var dnsErr *net.DNSError
		if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
			return nil, nil
		}
		return nil, err
	}

	out := make([]SRVRecord, 0, len(records))
	for _, rr := range records {
		if rr.Target == "." || rr.Target == "" {
			continue
		}
		out = append(out, SRVRecord{
			Target:   strings.TrimSuffix(rr.Target, "."),
			Port:     rr.Port,
			Priority: rr.Priority,
			Weight:   rr.Weight,
		})
	}
	if len(out) == 0 {
		return nil, nil
	}
	return out, nil
}
// resolveAllInto looks up the addresses of each record's Target and appends
// them, in resolver order, to that record's IPv4 or IPv6 list in place.
// Addresses with a 4-byte form (including IPv4-mapped IPv6) go to IPv4.
//
// Lookup failures are deliberately ignored: the record keeps empty address
// lists, and probeSet then reports it as having no A/AAAA records.
func resolveAllInto(ctx context.Context, r *net.Resolver, records []SRVRecord) {
	for i := range records {
		rec := &records[i]
		addrs, err := r.LookupIPAddr(ctx, rec.Target)
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			v4 := addr.IP.To4()
			if v4 == nil {
				rec.IPv6 = append(rec.IPv6, addr.IP.String())
				continue
			}
			rec.IPv4 = append(rec.IPv4, v4.String())
		}
	}
}
// computeCoverage summarises data.Endpoints into data.Coverage:
//   - HasIPv4 / HasIPv6: at least one endpoint of that family accepted TCP.
//   - WorkingC2S: a client endpoint finished its TLS upgrade and either
//     advertised SASL mechanisms or could not have its features read (the
//     latter is treated as benign for a probe).
//   - WorkingS2S: a server endpoint finished its TLS upgrade. A missing s2s
//     authentication method is reported as an issue by deriveIssues, not as
//     a coverage gap.
// Flags are only ever set, never cleared.
func computeCoverage(data *XMPPData) {
	for i := range data.Endpoints {
		ep := &data.Endpoints[i]
		if ep.TCPConnected && ep.IsIPv6 {
			data.Coverage.HasIPv6 = true
		}
		if ep.TCPConnected && !ep.IsIPv6 {
			data.Coverage.HasIPv4 = true
		}
		if !ep.STARTTLSUpgraded {
			continue
		}
		switch {
		case ep.Mode == ModeServer:
			data.Coverage.WorkingS2S = true
		case ep.Mode == ModeClient && (len(ep.SASLMechanisms) > 0 || !ep.FeaturesRead):
			data.Coverage.WorkingC2S = true
		}
	}
}
// deriveIssues turns the collected probe data into user-facing findings.
//
// Findings are emitted in this order:
//  1. No xmpp SRV records (the bare-domain fallback was probed).
//  2. A leftover _jabber._tcp record.
//  3. SRV lookup failures, sorted by SRV prefix so the output is stable
//     across runs.
//  4. Per-endpoint problems:
//     - STARTTLS not advertised (only when the plaintext features were
//       actually read) or not required;
//     - TLS negotiation failures, for both STARTTLS and direct-TLS endpoints;
//     - unreachable addresses;
//     - s2s authentication posture.
//  5. Domain-wide findings:
//     - every endpoint down;
//     - c2s SASL posture, only when wantC2S and at least one c2s endpoint
//       completed TLS;
//     - no reachable IPv6;
//     - no XEP-0368 direct-TLS record.
//
// The third parameter (wantS2S) is currently unused; s2s checks run for
// whatever s2s endpoints were probed.
func deriveIssues(data *XMPPData, wantC2S, _ bool) []Issue {
	var issues []Issue

	// 1. No SRV published.
	if data.SRV.FallbackProbed {
		issues = append(issues, Issue{
			Code:     CodeNoSRV,
			Severity: SeverityCrit,
			Message:  "No XMPP SRV records found for " + data.Domain + ".",
			Fix:      "Publish _xmpp-client._tcp." + data.Domain + " and _xmpp-server._tcp." + data.Domain + " SRV records.",
		})
	}

	// 2. Legacy _jabber.
	if len(data.SRV.Jabber) > 0 {
		issues = append(issues, Issue{
			Code:     CodeLegacyJabber,
			Severity: SeverityWarn,
			Message:  "Obsolete _jabber._tcp SRV record still published.",
			Fix:      "Remove _jabber._tcp records; _xmpp-client._tcp supersedes them.",
		})
	}

	// 3. SRV lookup errors (real DNS failures, not NXDOMAIN). Map iteration
	// order is random, so walk the prefixes in sorted order to keep the
	// issue list deterministic.
	prefixes := make([]string, 0, len(data.SRV.Errors))
	for prefix := range data.SRV.Errors {
		prefixes = append(prefixes, prefix)
	}
	sort.Strings(prefixes)
	for _, prefix := range prefixes {
		issues = append(issues, Issue{
			Code:     CodeSRVServfail,
			Severity: SeverityWarn,
			Message:  "DNS lookup failed for " + prefix + data.Domain + ": " + data.SRV.Errors[prefix],
			Fix:      "Check the authoritative DNS servers for this domain.",
		})
	}

	// 4. Endpoint-level issues.
	allDown := true
	sawSCRAM := map[XMPPMode]bool{}
	sawSCRAMPlus := map[XMPPMode]bool{}
	sawPlainOnly := map[XMPPMode]bool{}
	sawAnyWorking := map[XMPPMode]bool{}
	for _, ep := range data.Endpoints {
		if ep.TCPConnected && ep.STARTTLSUpgraded {
			allDown = false
			sawAnyWorking[ep.Mode] = true
		}

		if ep.TCPConnected && ep.StreamOpened && !ep.DirectTLS {
			switch {
			// FeaturesRead guards against a false positive: if the plaintext
			// <stream:features> could not be read, STARTTLSOffered is false
			// only because nothing was parsed, not because it was absent.
			case !ep.STARTTLSOffered && ep.FeaturesRead:
				issues = append(issues, Issue{
					Code:     CodeStartTLSMissing,
					Severity: SeverityCrit,
					Message:  "STARTTLS not advertised on " + ep.Address + " (" + ep.SRVPrefix + ").",
					Fix:      "Enable STARTTLS in the XMPP server configuration and require it for all connections.",
					Endpoint: ep.Address,
				})
			case ep.STARTTLSOffered && !ep.STARTTLSRequired:
				issues = append(issues, Issue{
					Code:     CodeStartTLSNotRequired,
					Severity: SeverityWarn,
					Message:  "STARTTLS offered but not <required/> on " + ep.Address + ".",
					Fix:      "Set the server to require TLS (e.g. `c2s_require_encryption = true` in Prosody, `starttls_required` in ejabberd).",
					Endpoint: ep.Address,
				})
			}
		}

		// TLS negotiation failed after TCP connected. Either STARTTLS was
		// offered but the upgrade did not complete, or a direct-TLS
		// (XEP-0368) handshake failed.
		if ep.TCPConnected && !ep.STARTTLSUpgraded && ep.Error != "" {
			switch {
			case ep.STARTTLSOffered:
				issues = append(issues, Issue{
					Code:     CodeStartTLSFailed,
					Severity: SeverityCrit,
					Message:  "STARTTLS handshake failed on " + ep.Address + ": " + ep.Error + ".",
					Fix:      "Run the TLS checker on this port for cert and cipher details.",
					Endpoint: ep.Address,
				})
			case ep.DirectTLS:
				issues = append(issues, Issue{
					Code:     CodeStartTLSFailed,
					Severity: SeverityCrit,
					Message:  "Direct TLS handshake failed on " + ep.Address + ": " + ep.Error + ".",
					Fix:      "Run the TLS checker on this port for cert and cipher details.",
					Endpoint: ep.Address,
				})
			}
		}

		if !ep.TCPConnected && ep.Error != "" {
			issues = append(issues, Issue{
				Code:     CodeTCPUnreachable,
				Severity: SeverityWarn,
				Message:  "Cannot reach " + ep.Address + ": " + ep.Error + ".",
				Fix:      "Verify firewall rules and that the XMPP server is listening on this address.",
				Endpoint: ep.Address,
			})
		}

		// SASL posture (c2s only).
		if ep.Mode == ModeClient && ep.STARTTLSUpgraded && len(ep.SASLMechanisms) > 0 {
			hasSCRAM := false
			hasSCRAMPlus := false
			hasPlain := false
			nonPlain := false
			for _, m := range ep.SASLMechanisms {
				u := strings.ToUpper(m)
				if strings.HasPrefix(u, "SCRAM-") {
					hasSCRAM = true
					if strings.HasSuffix(u, "-PLUS") {
						hasSCRAMPlus = true
					}
				}
				if u == "PLAIN" {
					hasPlain = true
				} else {
					nonPlain = true
				}
			}
			if hasSCRAM {
				sawSCRAM[ep.Mode] = true
			}
			if hasSCRAMPlus {
				sawSCRAMPlus[ep.Mode] = true
			}
			if hasPlain && !nonPlain {
				sawPlainOnly[ep.Mode] = true
			}
		}

		// S2S auth posture, only meaningful if we actually parsed the
		// post-TLS features. Many public servers don't respond fully to
		// anonymous s2s probes; in that case we emit a probe_incomplete
		// info instead of falsely asserting "no auth".
		if ep.Mode == ModeServer && ep.STARTTLSUpgraded {
			if !ep.FeaturesRead {
				issues = append(issues, Issue{
					Code:     CodeS2SProbeIncomplete,
					Severity: SeverityInfo,
					Message:  "Could not read post-TLS stream features on " + ep.Address + "; server may require an authenticated origin for s2s.",
					Fix:      "This is often benign for well-run public servers. Try from a real federating host if in doubt.",
					Endpoint: ep.Address,
				})
			} else if !ep.DialbackOffered && !ep.SASLExternal {
				issues = append(issues, Issue{
					Code:     CodeS2SNoAuth,
					Severity: SeverityCrit,
					Message:  "No dialback or SASL EXTERNAL advertised on " + ep.Address + " after TLS; federation will fail.",
					Fix:      "Enable server-to-server dialback, or provision a cert usable for SASL EXTERNAL.",
					Endpoint: ep.Address,
				})
			}
		}
	}

	if len(data.Endpoints) > 0 && allDown {
		issues = append(issues, Issue{
			Code:     CodeAllEndpointsDown,
			Severity: SeverityCrit,
			Message:  "None of the XMPP endpoints could complete STARTTLS.",
			Fix:      "Verify the server is running and reachable on the published SRV ports.",
		})
	}

	if wantC2S && sawAnyWorking[ModeClient] {
		if !sawSCRAM[ModeClient] {
			issues = append(issues, Issue{
				Code:     CodeSASLNoSCRAM,
				Severity: SeverityWarn,
				Message:  "No SCRAM-SHA-* SASL mechanism offered on c2s.",
				Fix:      "Enable SCRAM-SHA-256 (and SCRAM-SHA-1 for compatibility).",
			})
		}
		if !sawSCRAMPlus[ModeClient] {
			issues = append(issues, Issue{
				Code:     CodeSASLNoSCRAMPlus,
				Severity: SeverityInfo,
				Message:  "No SCRAM-SHA-*-PLUS offered (channel binding).",
				Fix:      "Enable SCRAM-SHA-256-PLUS to protect against TLS MITM.",
			})
		}
		if sawPlainOnly[ModeClient] {
			// SASL PLAIN transmits the cleartext password (inside TLS);
			// SCRAM proves knowledge of it without sending it.
			issues = append(issues, Issue{
				Code:     CodeSASLPlainOnly,
				Severity: SeverityCrit,
				Message:  "Only SASL PLAIN is offered on c2s.",
				Fix:      "Enable SCRAM-SHA-256 so clients no longer have to send the cleartext password to the server.",
			})
		}
	}

	// IPv6 coverage.
	if data.Coverage.HasIPv4 && !data.Coverage.HasIPv6 {
		issues = append(issues, Issue{
			Code:     CodeNoIPv6,
			Severity: SeverityInfo,
			Message:  "No IPv6 endpoint reachable.",
			Fix:      "Publish AAAA records for the SRV targets.",
		})
	}

	// XEP-0368 direct TLS coverage.
	if wantC2S && sawAnyWorking[ModeClient] && len(data.SRV.ClientSecure) == 0 {
		issues = append(issues, Issue{
			Code:     CodeNoDirectTLS,
			Severity: SeverityInfo,
			Message:  "No XEP-0368 direct-TLS SRV record (_xmpps-client._tcp) published.",
			Fix:      "Publish _xmpps-client._tcp SRV records pointing at port 5223 to allow TLS from the first byte.",
		})
	}

	return issues
}