checker-xmpp/checker/collect.go

511 lines
14 KiB
Go

package checker
import (
"context"
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// XML namespaces used when opening XMPP streams and negotiating STARTTLS.
const (
	// streamsNS qualifies <stream:stream>, <stream:features> and
	// <stream:error>.
	streamsNS = "http://etherx.jabber.org/streams"

	// clientNS is the default content namespace of client-to-server streams.
	clientNS = "jabber:client"

	// serverNS is the default content namespace of server-to-server streams.
	serverNS = "jabber:server"

	// tlsNS qualifies <starttls/>, <proceed/> and <failure/>.
	tlsNS = "urn:ietf:params:xml:ns:xmpp-tls"
)
// tlsProbeConfig builds the TLS client configuration used for every probe
// handshake (direct TLS and STARTTLS alike), sending serverName as SNI.
//
// The configuration is permissive on purpose:
//   - InsecureSkipVerify: this checker does not judge certificates. Chain
//     and hostname validation belong to checker-tls, which is handed the
//     endpoints discovered here. This probe only records which TLS version
//     and cipher suite the server negotiated.
//   - MinVersion TLS 1.0: a server that still accepts deprecated protocol
//     versions must remain reachable so that fact can be reported. A strict
//     client configuration would simply fail to connect to it.
func tlsProbeConfig(serverName string) *tls.Config {
	cfg := &tls.Config{
		MinVersion:         tls.VersionTLS10,
		ServerName:         serverName,
		InsecureSkipVerify: true, //nolint:gosec // deliberate, see doc comment
	}
	return cfg
}
// Collect runs the full XMPP probe for one domain and returns the raw
// observations as an *XMPPData.
//
// Recognised options:
//   - "domain" (required): the XMPP domain. A trailing dot is removed and
//     the remainder must pass validateDomain.
//   - "mode": "c2s" limits the probe to client services and "s2s" to server
//     services. Any other value, including an unset option, probes both.
//   - "timeout": per-endpoint timeout in seconds (default 10). Values below 1
//     are replaced by 10.
//
// The pipeline runs in this order:
//  1. SRV lookups for each wanted prefix. Lookup errors are stored in
//     SRV.Errors under the prefix.
//  2. A fallback to <domain>:5222 / <domain>:5269 when none of the
//     _xmpp(s)-client / _xmpp(s)-server lookups returned a record.
//  3. A/AAAA resolution of every target.
//  4. One connection probe per resolved address.
//  5. The coverage summary.
func (p *xmppProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	domain, _ := sdk.GetOption[string](opts, "domain")
	if domain = strings.TrimSuffix(domain, "."); domain == "" {
		return nil, fmt.Errorf("domain is required")
	}
	if err := validateDomain(domain); err != nil {
		return nil, err
	}

	// An unset or unrecognised mode selects both directions.
	mode, _ := sdk.GetOption[string](opts, "mode")
	wantC2S, wantS2S := mode != "s2s", mode != "c2s"

	timeoutSecs := sdk.GetFloatOption(opts, "timeout", 10)
	if timeoutSecs < 1 {
		timeoutSecs = 10
	}
	perEndpoint := time.Duration(timeoutSecs * float64(time.Second))

	data := &XMPPData{
		Domain: domain,
		RunAt:  time.Now().UTC().Format(time.RFC3339),
		SRV:    SRVLookup{Errors: map[string]string{}},
	}
	srv := &data.SRV
	resolver := net.DefaultResolver

	// _jabber._tcp results are only stored in srv.Jabber. They are neither
	// counted for the fallback decision nor resolved or probed below.
	queries := []struct {
		prefix  string
		enabled bool
		into    *[]SRVRecord
	}{
		{"_xmpp-client._tcp.", wantC2S, &srv.Client},
		{"_xmpp-server._tcp.", wantS2S, &srv.Server},
		{"_xmpps-client._tcp.", wantC2S, &srv.ClientSecure},
		{"_xmpps-server._tcp.", wantS2S, &srv.ServerSecure},
		{"_jabber._tcp.", wantC2S, &srv.Jabber},
	}
	for _, q := range queries {
		if !q.enabled {
			continue
		}
		if recs, err := lookupSRV(ctx, resolver, q.prefix, domain); err != nil {
			srv.Errors[q.prefix] = err.Error()
		} else {
			*q.into = recs
		}
	}

	if len(srv.Client)+len(srv.Server)+len(srv.ClientSecure)+len(srv.ServerSecure) == 0 {
		// No usable SRV record at all: probe the domain itself on the
		// standard XMPP ports.
		srv.FallbackProbed = true
		if wantC2S {
			srv.Client = []SRVRecord{{Target: domain, Port: 5222}}
		}
		if wantS2S {
			srv.Server = []SRVRecord{{Target: domain, Port: 5269}}
		}
	}

	// The record slices below share backing arrays with data.SRV.
	// resolveAllInto fills the IPv4/IPv6 fields in place, so the addresses
	// are both stored in the report and visible to probeSet.
	plans := []struct {
		mode      XMPPMode
		prefix    string
		records   []SRVRecord
		directTLS bool
	}{
		{ModeClient, "_xmpp-client._tcp", srv.Client, false},
		{ModeServer, "_xmpp-server._tcp", srv.Server, false},
		{ModeClient, "_xmpps-client._tcp", srv.ClientSecure, true},
		{ModeServer, "_xmpps-server._tcp", srv.ServerSecure, true},
	}
	for _, pl := range plans {
		resolveAllInto(ctx, resolver, pl.records)
	}
	for _, pl := range plans {
		probeSet(ctx, data, domain, pl.mode, pl.prefix, pl.records, pl.directTLS, perEndpoint)
	}

	computeCoverage(data)
	// Collect intentionally does not populate data.Issues; judging the raw
	// payload is the job of the CheckRules (see rules.go).
	return data, nil
}
// probeSet probes every resolved address of every record in records and
// appends one EndpointProbe per address to data.Endpoints. Entries follow
// record order, and within a record its IPv4 addresses come before its IPv6
// ones.
//
// A record with no resolved address still yields a single entry, carrying an
// error, so the target stays visible in the report.
func probeSet(ctx context.Context, data *XMPPData, domain string, mode XMPPMode, prefix string, records []SRVRecord, directTLS bool, timeout time.Duration) {
	for i := range records {
		rec := records[i]
		addrs := addressesForProbe(rec)
		if len(addrs) > 0 {
			for _, target := range addrs {
				data.Endpoints = append(data.Endpoints,
					probeEndpoint(ctx, domain, mode, prefix, rec, target.ip, target.isV6, directTLS, timeout))
			}
			continue
		}
		data.Endpoints = append(data.Endpoints, EndpointProbe{
			Mode:      mode,
			SRVPrefix: prefix,
			Target:    rec.Target,
			Port:      rec.Port,
			DirectTLS: directTLS,
			Error:     "no A/AAAA records for target",
		})
	}
}
// probeAddr is one concrete IP address to connect to for an SRV target.
type probeAddr struct {
	ip   string // textual IP address, without port
	isV6 bool   // true when ip came from the record's IPv6 list
}

// addressesForProbe flattens a record's resolved addresses into probe
// targets: all IPv4 addresses first, then all IPv6 addresses, each list in
// its stored order. It returns nil when the record has no addresses.
func addressesForProbe(rec SRVRecord) []probeAddr {
	var targets []probeAddr
	collect := func(ips []string, v6 bool) {
		for _, ip := range ips {
			targets = append(targets, probeAddr{ip: ip, isV6: v6})
		}
	}
	collect(rec.IPv4, false)
	collect(rec.IPv6, true)
	return targets
}
// probeEndpoint connects to one resolved address of an SRV target and walks
// the XMPP connection setup as far as it safely can, without ever
// authenticating.
//
// With directTLS (the _xmpps-* records), it performs a TLS handshake
// immediately and then opens the stream. Otherwise it opens a plaintext
// stream, reads the features and, if STARTTLS is offered, upgrades the
// connection and restarts the stream over TLS.
//
// domain is used both as the stream 'to' attribute and as the TLS SNI name.
// timeout bounds the TCP dial. It is also applied as the I/O deadline once
// connected, and refreshed around each TLS handshake.
//
// On failure, Error is set with a stage prefix ("tcp:", "tls-handshake:",
// "stream:", "features:", "starttls-write:", "starttls-proceed:",
// "post-tls stream:") and the partially filled result is returned.
// ElapsedMS always holds the wall-clock time spent in this call.
func probeEndpoint(ctx context.Context, domain string, mode XMPPMode, prefix string, rec SRVRecord, ip string, isV6, directTLS bool, timeout time.Duration) (result EndpointProbe) {
	start := time.Now()
	result = EndpointProbe{
		Mode:      mode,
		SRVPrefix: prefix,
		Target:    rec.Target,
		Port:      rec.Port,
		Address:   net.JoinHostPort(ip, strconv.Itoa(int(rec.Port))),
		IsIPv6:    isV6,
		DirectTLS: directTLS,
	}
	// result is a named return value so that this deferred write reaches the
	// caller. With an unnamed result, "return result" copies the value before
	// deferred calls run, and ElapsedMS would always be reported as 0.
	defer func() { result.ElapsedMS = time.Since(start).Milliseconds() }()

	isServer := mode == ModeServer
	ns := clientNS
	if isServer {
		ns = serverNS
	}

	dialCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	var dialer net.Dialer
	rawConn, err := dialer.DialContext(dialCtx, "tcp", result.Address)
	if err != nil {
		result.Error = "tcp: " + err.Error()
		return result
	}
	result.TCPConnected = true
	defer rawConn.Close()
	// The dial timeout no longer applies once connected; bound the I/O too.
	_ = rawConn.SetDeadline(time.Now().Add(timeout))

	// upgradeTLS runs a TLS handshake over rawConn and records the negotiated
	// version and cipher suite. It refreshes the I/O deadline before and after
	// the handshake. On failure it sets result.Error and returns nil.
	upgradeTLS := func() *tls.Conn {
		tlsConn := tls.Client(rawConn, tlsProbeConfig(domain))
		_ = tlsConn.SetDeadline(time.Now().Add(timeout))
		// HandshakeContext also aborts if the caller's ctx is cancelled; the
		// deadline above still bounds the handshake.
		if err := tlsConn.HandshakeContext(ctx); err != nil {
			result.Error = "tls-handshake: " + err.Error()
			return nil
		}
		result.STARTTLSUpgraded = true
		state := tlsConn.ConnectionState()
		result.TLSVersion = tls.VersionName(state.Version)
		result.TLSCipher = tls.CipherSuiteName(state.CipherSuite)
		_ = tlsConn.SetDeadline(time.Now().Add(timeout))
		return tlsConn
	}

	if directTLS {
		tlsConn := upgradeTLS()
		if tlsConn == nil {
			return result
		}
		feats, from, err := openStreamAndReadFeatures(tlsConn, domain, ns, isServer)
		if err != nil {
			result.Error = "stream: " + err.Error()
			return result
		}
		result.StreamOpened = true
		result.StreamFrom = from
		applyFeatures(&result, feats)
		return result
	}

	// Plaintext stream first, then STARTTLS.
	dec, from, err := openStream(rawConn, domain, ns, isServer)
	if err != nil {
		result.Error = "stream: " + err.Error()
		return result
	}
	result.StreamOpened = true
	result.StreamFrom = from
	feats, err := readFeatures(dec)
	if err != nil {
		result.Error = "features: " + err.Error()
		return result
	}
	result.STARTTLSOffered = feats.StartTLS != nil
	if feats.StartTLS != nil && feats.StartTLS.Required != nil {
		result.STARTTLSRequired = true
	}
	if !result.STARTTLSOffered {
		// Record any features seen in plaintext, but do not proceed; we
		// intentionally refuse to send SASL over a non-TLS channel.
		applyFeatures(&result, feats)
		return result
	}
	if _, err := io.WriteString(rawConn, `<starttls xmlns='`+tlsNS+`'/>`); err != nil {
		result.Error = "starttls-write: " + err.Error()
		return result
	}
	if err := expectProceed(dec); err != nil {
		result.Error = "starttls-proceed: " + err.Error()
		return result
	}
	tlsConn := upgradeTLS()
	if tlsConn == nil {
		return result
	}
	// Features advertised after the TLS stream restart replace the plaintext
	// ones; only these are recorded.
	feats2, _, err := openStreamAndReadFeatures(tlsConn, domain, ns, isServer)
	if err != nil {
		result.Error = "post-tls stream: " + err.Error()
		return result
	}
	applyFeatures(&result, feats2)
	return result
}
// applyFeatures merges decoded stream features into ep. A nil feats is a
// no-op. Otherwise it:
//   - sets FeaturesRead;
//   - sets DialbackOffered when the dialback feature is present;
//   - appends the advertised SASL mechanisms in order, setting SASLExternal
//     if any equals "EXTERNAL" (case-insensitively).
//
// Flags are only ever set, never cleared.
func applyFeatures(ep *EndpointProbe, feats *streamFeatures) {
	if feats == nil {
		return
	}
	ep.FeaturesRead = true
	if feats.Dialback != nil {
		ep.DialbackOffered = true
	}
	if feats.Mechanisms == nil {
		return
	}
	for _, mech := range feats.Mechanisms.Mechanism {
		ep.SASLMechanisms = append(ep.SASLMechanisms, mech)
		if strings.EqualFold(mech, "EXTERNAL") {
			ep.SASLExternal = true
		}
	}
}
// streamFeatures is the subset of <stream:features/> this checker inspects.
// A feature that is not advertised decodes as a nil pointer; children not
// listed here are ignored.
type streamFeatures struct {
	XMLName    xml.Name      `xml:"http://etherx.jabber.org/streams features"`
	StartTLS   *startTLSEl   `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
	Mechanisms *mechanismsEl `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
	// Dialback is non-nil when the server advertises the
	// urn:xmpp:features:dialback feature.
	Dialback *struct{} `xml:"urn:xmpp:features:dialback dialback"`
}

// startTLSEl is the <starttls/> feature element.
type startTLSEl struct {
	XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
	// Required is non-nil when <starttls/> contains a <required/> child.
	Required *struct{} `xml:"required"`
}

// mechanismsEl is the SASL <mechanisms/> feature. It lists the offered
// mechanism names in document order.
type mechanismsEl struct {
	XMLName   xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
	Mechanism []string `xml:"mechanism"`
}
// openStreamAndReadFeatures sends a stream header on conn, waits for the
// peer's header, and decodes the <stream:features/> that follows. It is used
// for the first stream on direct-TLS connections and for the stream restart
// after STARTTLS.
//
// It returns the features and the 'from' attribute of the peer's header.
// When the header exchange itself fails, from is empty. When only the
// features step fails, from is still returned alongside the error.
func openStreamAndReadFeatures(conn io.ReadWriter, domain, ns string, server bool) (*streamFeatures, string, error) {
	dec, peerFrom, err := openStream(conn, domain, ns, server)
	if err != nil {
		return nil, "", err
	}
	// readFeatures yields a nil *streamFeatures whenever it fails, so its
	// results can be forwarded unchanged.
	feats, err := readFeatures(dec)
	return feats, peerFrom, err
}
// openStream writes the initial stream header to conn, then reads until the
// peer's <stream:stream> start tag arrives.
//
// ns is the default content namespace. When server is true, the header also
// declares the "db" prefix for jabber:server:dialback. The header carries
// to=domain and no from attribute.
//
// On success it returns the decoder, positioned right after the peer's stream
// header so <stream:features/> can be read next, and the header's 'from'
// attribute ("" if absent). It returns an error if the peer answers with
// <stream:error/> or any other first element, or if the write or read fails.
func openStream(conn io.ReadWriter, domain, ns string, server bool) (*xml.Decoder, string, error) {
	dbDecl := ""
	if server {
		dbDecl = " xmlns:db='jabber:server:dialback'"
	}
	header := fmt.Sprintf(`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='%s'%s version='1.0' to='%s'>`,
		ns, streamsNS, dbDecl, domain)
	if _, err := io.WriteString(conn, header); err != nil {
		return nil, "", fmt.Errorf("write header: %w", err)
	}

	dec := xml.NewDecoder(conn)
	for {
		tok, err := dec.Token()
		if err != nil {
			return nil, "", fmt.Errorf("read header: %w", err)
		}
		el, isStart := tok.(xml.StartElement)
		if !isStart {
			// XML declaration, whitespace, comments: nothing to act on.
			continue
		}
		inStreamsNS := el.Name.Space == streamsNS
		switch {
		case inStreamsNS && el.Name.Local == "stream":
			peerFrom := ""
			for _, attr := range el.Attr {
				if attr.Name.Local == "from" {
					peerFrom = attr.Value // if repeated, the last one wins
				}
			}
			return dec, peerFrom, nil
		case inStreamsNS && el.Name.Local == "error":
			_ = dec.Skip()
			return nil, "", errors.New("server returned stream:error on open")
		default:
			return nil, "", fmt.Errorf("unexpected element %s", el.Name.Local)
		}
	}
}
// readFeatures consumes tokens from dec until it finds <stream:features/>,
// then decodes it. Other tokens and elements are skipped.
//
// It returns an error if a <stream:error/> arrives first, if the features
// element cannot be decoded, or if reading fails (EOF, deadline). The
// returned *streamFeatures is nil whenever the error is non-nil.
func readFeatures(dec *xml.Decoder) (*streamFeatures, error) {
	for {
		tok, err := dec.Token()
		if err != nil {
			return nil, fmt.Errorf("read features: %w", err)
		}
		el, isStart := tok.(xml.StartElement)
		if !isStart || el.Name.Space != streamsNS {
			continue
		}
		switch el.Name.Local {
		case "features":
			feats := new(streamFeatures)
			if err := dec.DecodeElement(feats, &el); err != nil {
				return nil, fmt.Errorf("decode features: %w", err)
			}
			return feats, nil
		case "error":
			_ = dec.Skip()
			return nil, errors.New("stream:error before features")
		}
	}
}
// expectProceed reads from dec until the server answers a <starttls/>
// request, and returns nil on <proceed/>.
//
// It returns an error on:
//   - <failure/>;
//   - <stream:error/>;
//   - a read failure (EOF, deadline expiry, malformed XML).
//
// Unrelated elements are ignored. A stream error is reported as soon as it
// is seen, consistent with openStream and readFeatures. Otherwise it would
// be skipped: the probe would keep waiting for more input and could even
// accept a later <proceed/> on a stream the server has already declared
// dead.
func expectProceed(dec *xml.Decoder) error {
	for {
		tok, err := dec.Token()
		if err != nil {
			return fmt.Errorf("read proceed: %w", err)
		}
		se, ok := tok.(xml.StartElement)
		if !ok {
			continue
		}
		switch {
		case se.Name.Space == tlsNS && se.Name.Local == "proceed":
			_ = dec.Skip()
			return nil
		case se.Name.Space == tlsNS && se.Name.Local == "failure":
			_ = dec.Skip()
			return errors.New("server refused STARTTLS (<failure/>)")
		case se.Name.Space == streamsNS && se.Name.Local == "error":
			_ = dec.Skip()
			return errors.New("server returned stream:error instead of <proceed/>")
		}
	}
}
// validateDomain checks that domain is a plain ASCII hostname (letters,
// digits and hyphens, RFC 1123 style). It runs before domain is used in DNS
// queries and interpolated into the XMPP stream header.
//
// The limits are:
//   - at most 253 bytes overall;
//   - 1 to 63 bytes per dot-separated label;
//   - no label may start or end with '-'.
//
// Because every label must be non-empty, a trailing dot is rejected, so
// callers strip it first.
func validateDomain(domain string) error {
	if n := len(domain); n > 253 {
		return fmt.Errorf("domain name too long (max 253 characters, got %d)", n)
	}
	allowed := func(r rune) bool {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-':
			return true
		default:
			return false
		}
	}
	labels := strings.Split(domain, ".")
	for i := range labels {
		label := labels[i]
		switch {
		case label == "":
			return fmt.Errorf("domain contains an empty label")
		case len(label) > 63:
			return fmt.Errorf("domain label %q exceeds 63 characters", label)
		case strings.HasPrefix(label, "-"), strings.HasSuffix(label, "-"):
			return fmt.Errorf("domain label %q must not start or end with a hyphen", label)
		}
		for _, r := range label {
			if !allowed(r) {
				return fmt.Errorf("domain label %q contains invalid character %q", label, r)
			}
		}
	}
	return nil
}
// lookupSRV queries the SRV records at prefix + domain (prefix ends with a
// dot, e.g. "_xmpp-client._tcp.") and converts them to SRVRecord values,
// stripping the trailing dot from each target.
//
// It returns (nil, nil) in two cases:
//   - the resolver reports the name as not found (IsNotFound), which is
//     treated as "no service published" rather than a failure;
//   - the answer is the single RFC 2782 "service not available" record
//     (target "." with port 0).
//
// Any other resolver error is returned unchanged.
func lookupSRV(ctx context.Context, r *net.Resolver, prefix, domain string) ([]SRVRecord, error) {
	_, answers, err := r.LookupSRV(ctx, "", "", prefix+dns.Fqdn(domain))
	if err != nil {
		var dnsErr *net.DNSError
		if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
			return nil, nil
		}
		return nil, err
	}

	if len(answers) == 1 {
		only := answers[0]
		if only.Port == 0 && (only.Target == "." || only.Target == "") {
			return nil, nil
		}
	}

	out := make([]SRVRecord, len(answers))
	for i, ans := range answers {
		out[i] = SRVRecord{
			Target:   strings.TrimSuffix(ans.Target, "."),
			Port:     ans.Port,
			Priority: ans.Priority,
			Weight:   ans.Weight,
		}
	}
	return out, nil
}
// resolveAllInto resolves each record's Target and appends the results to
// that record's IPv4 or IPv6 list. Records are modified in place through the
// slice. Addresses that have an IPv4 form (To4 succeeds) go to IPv4; all
// others go to IPv6.
//
// Resolution is best effort: a failed lookup leaves the record without
// addresses and is not reported here.
func resolveAllInto(ctx context.Context, r *net.Resolver, records []SRVRecord) {
	for i := range records {
		rec := &records[i]
		addrs, err := r.LookupIPAddr(ctx, rec.Target)
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			v4 := addr.IP.To4()
			if v4 == nil {
				rec.IPv6 = append(rec.IPv6, addr.IP.String())
				continue
			}
			rec.IPv4 = append(rec.IPv4, v4.String())
		}
	}
}
// computeCoverage derives data.Coverage from data.Endpoints. It is part of
// collection rather than judgment: it records what was reached and assigns
// no severity. Flags are only ever set to true, never cleared.
//
//   - HasIPv4 / HasIPv6: at least one endpoint of that address family
//     accepted a TCP connection.
//   - WorkingC2S: a client endpoint completed TLS and either advertised SASL
//     mechanisms or its post-TLS features could not be read (considered
//     benign for a probe).
//   - WorkingS2S: a server endpoint completed TLS. The dialback/EXTERNAL
//     posture is left to the rules.
func computeCoverage(data *XMPPData) {
	for i := range data.Endpoints {
		ep := &data.Endpoints[i] // index to avoid copying each probe
		switch {
		case ep.TCPConnected && ep.IsIPv6:
			data.Coverage.HasIPv6 = true
		case ep.TCPConnected:
			data.Coverage.HasIPv4 = true
		}

		if !ep.STARTTLSUpgraded {
			continue
		}
		switch ep.Mode {
		case ModeServer:
			data.Coverage.WorkingS2S = true
		case ModeClient:
			if !ep.FeaturesRead || len(ep.SASLMechanisms) > 0 {
				data.Coverage.WorkingC2S = true
			}
		}
	}
}