511 lines
14 KiB
Go
511 lines
14 KiB
Go
package checker
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
|
|
sdk "git.happydns.org/checker-sdk-go/checker"
|
|
)
|
|
|
|
// XML namespaces used while negotiating an XMPP stream.
const (
	// streamsNS qualifies <stream:stream>, <stream:features> and <stream:error>.
	streamsNS = "http://etherx.jabber.org/streams"
	// clientNS is the default content namespace for client-to-server streams.
	clientNS = "jabber:client"
	// serverNS is the default content namespace for server-to-server streams.
	serverNS = "jabber:server"
	// tlsNS qualifies the STARTTLS negotiation elements.
	tlsNS = "urn:ietf:params:xml:ns:xmpp-tls"
)
|
|
|
|
// tlsProbeConfig builds the deliberately permissive TLS client configuration
// used when probing an endpoint; serverName is sent as the TLS server name.
//
// Certificate chain and hostname validation are switched off on purpose
// (InsecureSkipVerify): auditing certificates is the TLS checker's
// responsibility. This checker only observes which TLS versions and cipher
// suites a server accepts, then hands the endpoints to checker-tls for the
// actual certificate audit.
//
// The minimum version is lowered to TLS 1.0 so that servers still accepting
// deprecated protocol versions can be reached and reported; a strict client
// configuration would prevent us from talking to them at all.
func tlsProbeConfig(serverName string) *tls.Config {
	cfg := new(tls.Config)
	cfg.ServerName = serverName
	cfg.InsecureSkipVerify = true //nolint:gosec
	cfg.MinVersion = tls.VersionTLS10
	return cfg
}
|
|
|
|
// Collect runs the full XMPP probe for a domain.
|
|
func (p *xmppProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
|
|
domain, _ := sdk.GetOption[string](opts, "domain")
|
|
domain = strings.TrimSuffix(domain, ".")
|
|
if domain == "" {
|
|
return nil, fmt.Errorf("domain is required")
|
|
}
|
|
if err := validateDomain(domain); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mode, _ := sdk.GetOption[string](opts, "mode")
|
|
if mode == "" {
|
|
mode = "both"
|
|
}
|
|
timeoutSecs := sdk.GetFloatOption(opts, "timeout", 10)
|
|
if timeoutSecs < 1 {
|
|
timeoutSecs = 10
|
|
}
|
|
perEndpoint := time.Duration(timeoutSecs * float64(time.Second))
|
|
|
|
wantC2S := mode != "s2s"
|
|
wantS2S := mode != "c2s"
|
|
|
|
data := &XMPPData{
|
|
Domain: domain,
|
|
RunAt: time.Now().UTC().Format(time.RFC3339),
|
|
SRV: SRVLookup{Errors: map[string]string{}},
|
|
}
|
|
|
|
resolver := net.DefaultResolver
|
|
|
|
lookupSets := []struct {
|
|
prefix string
|
|
want bool
|
|
dst *[]SRVRecord
|
|
}{
|
|
{"_xmpp-client._tcp.", wantC2S, &data.SRV.Client},
|
|
{"_xmpp-server._tcp.", wantS2S, &data.SRV.Server},
|
|
{"_xmpps-client._tcp.", wantC2S, &data.SRV.ClientSecure},
|
|
{"_xmpps-server._tcp.", wantS2S, &data.SRV.ServerSecure},
|
|
{"_jabber._tcp.", wantC2S, &data.SRV.Jabber},
|
|
}
|
|
for _, ls := range lookupSets {
|
|
if !ls.want {
|
|
continue
|
|
}
|
|
records, err := lookupSRV(ctx, resolver, ls.prefix, domain)
|
|
if err != nil {
|
|
data.SRV.Errors[ls.prefix] = err.Error()
|
|
continue
|
|
}
|
|
*ls.dst = records
|
|
}
|
|
|
|
totalSRV := len(data.SRV.Client) + len(data.SRV.Server) + len(data.SRV.ClientSecure) + len(data.SRV.ServerSecure)
|
|
if totalSRV == 0 {
|
|
data.SRV.FallbackProbed = true
|
|
if wantC2S {
|
|
data.SRV.Client = []SRVRecord{{Target: domain, Port: 5222}}
|
|
}
|
|
if wantS2S {
|
|
data.SRV.Server = []SRVRecord{{Target: domain, Port: 5269}}
|
|
}
|
|
}
|
|
|
|
resolveAllInto(ctx, resolver, data.SRV.Client)
|
|
resolveAllInto(ctx, resolver, data.SRV.Server)
|
|
resolveAllInto(ctx, resolver, data.SRV.ClientSecure)
|
|
resolveAllInto(ctx, resolver, data.SRV.ServerSecure)
|
|
|
|
probeSet(ctx, data, domain, ModeClient, "_xmpp-client._tcp", data.SRV.Client, false, perEndpoint)
|
|
probeSet(ctx, data, domain, ModeServer, "_xmpp-server._tcp", data.SRV.Server, false, perEndpoint)
|
|
probeSet(ctx, data, domain, ModeClient, "_xmpps-client._tcp", data.SRV.ClientSecure, true, perEndpoint)
|
|
probeSet(ctx, data, domain, ModeServer, "_xmpps-server._tcp", data.SRV.ServerSecure, true, perEndpoint)
|
|
|
|
computeCoverage(data)
|
|
|
|
// Collect intentionally does not populate data.Issues; judging the raw
|
|
// payload is the job of the CheckRules (see rules.go).
|
|
|
|
return data, nil
|
|
}
|
|
|
|
func probeSet(ctx context.Context, data *XMPPData, domain string, mode XMPPMode, prefix string, records []SRVRecord, directTLS bool, timeout time.Duration) {
|
|
for _, rec := range records {
|
|
addrs := addressesForProbe(rec)
|
|
if len(addrs) == 0 {
|
|
ep := EndpointProbe{
|
|
Mode: mode,
|
|
SRVPrefix: prefix,
|
|
Target: rec.Target,
|
|
Port: rec.Port,
|
|
DirectTLS: directTLS,
|
|
Error: "no A/AAAA records for target",
|
|
}
|
|
data.Endpoints = append(data.Endpoints, ep)
|
|
continue
|
|
}
|
|
for _, a := range addrs {
|
|
ep := probeEndpoint(ctx, domain, mode, prefix, rec, a.ip, a.isV6, directTLS, timeout)
|
|
data.Endpoints = append(data.Endpoints, ep)
|
|
}
|
|
}
|
|
}
|
|
|
|
// probeAddr is one address to dial for an SRV target, tagged with its
// address family.
type probeAddr struct {
	// ip is the textual IP address.
	ip string
	// isV6 is true when ip was taken from the record's IPv6 list.
	isV6 bool
}
|
|
|
|
func addressesForProbe(rec SRVRecord) []probeAddr {
|
|
var out []probeAddr
|
|
for _, ip := range rec.IPv4 {
|
|
out = append(out, probeAddr{ip: ip, isV6: false})
|
|
}
|
|
for _, ip := range rec.IPv6 {
|
|
out = append(out, probeAddr{ip: ip, isV6: true})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func probeEndpoint(ctx context.Context, domain string, mode XMPPMode, prefix string, rec SRVRecord, ip string, isV6, directTLS bool, timeout time.Duration) EndpointProbe {
|
|
start := time.Now()
|
|
result := EndpointProbe{
|
|
Mode: mode,
|
|
SRVPrefix: prefix,
|
|
Target: rec.Target,
|
|
Port: rec.Port,
|
|
Address: net.JoinHostPort(ip, strconv.Itoa(int(rec.Port))),
|
|
IsIPv6: isV6,
|
|
DirectTLS: directTLS,
|
|
}
|
|
defer func() { result.ElapsedMS = time.Since(start).Milliseconds() }()
|
|
|
|
ns := clientNS
|
|
if mode == ModeServer {
|
|
ns = serverNS
|
|
}
|
|
|
|
dialCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
dialer := &net.Dialer{}
|
|
rawConn, err := dialer.DialContext(dialCtx, "tcp", result.Address)
|
|
if err != nil {
|
|
result.Error = "tcp: " + err.Error()
|
|
return result
|
|
}
|
|
result.TCPConnected = true
|
|
defer rawConn.Close()
|
|
_ = rawConn.SetDeadline(time.Now().Add(timeout))
|
|
|
|
var conn net.Conn = rawConn
|
|
|
|
if directTLS {
|
|
tlsConn := tls.Client(rawConn, tlsProbeConfig(domain))
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
result.Error = "tls-handshake: " + err.Error()
|
|
return result
|
|
}
|
|
result.STARTTLSUpgraded = true
|
|
state := tlsConn.ConnectionState()
|
|
result.TLSVersion = tls.VersionName(state.Version)
|
|
result.TLSCipher = tls.CipherSuiteName(state.CipherSuite)
|
|
_ = tlsConn.SetDeadline(time.Now().Add(timeout))
|
|
conn = tlsConn
|
|
|
|
feats, from, err := openStreamAndReadFeatures(conn, domain, ns, mode == ModeServer)
|
|
if err != nil {
|
|
result.Error = "stream: " + err.Error()
|
|
return result
|
|
}
|
|
result.StreamOpened = true
|
|
result.StreamFrom = from
|
|
applyFeatures(&result, feats)
|
|
return result
|
|
}
|
|
|
|
dec, from, err := openStream(conn, domain, ns, mode == ModeServer)
|
|
if err != nil {
|
|
result.Error = "stream: " + err.Error()
|
|
return result
|
|
}
|
|
result.StreamOpened = true
|
|
result.StreamFrom = from
|
|
|
|
feats, err := readFeatures(dec)
|
|
if err != nil {
|
|
result.Error = "features: " + err.Error()
|
|
return result
|
|
}
|
|
result.STARTTLSOffered = feats.StartTLS != nil
|
|
if feats.StartTLS != nil && feats.StartTLS.Required != nil {
|
|
result.STARTTLSRequired = true
|
|
}
|
|
|
|
if !result.STARTTLSOffered {
|
|
// Record any features seen in plaintext, but do not proceed; we
|
|
// intentionally refuse to send SASL over a non-TLS channel.
|
|
applyFeatures(&result, feats)
|
|
return result
|
|
}
|
|
|
|
if _, err := io.WriteString(conn, `<starttls xmlns='`+tlsNS+`'/>`); err != nil {
|
|
result.Error = "starttls-write: " + err.Error()
|
|
return result
|
|
}
|
|
if err := expectProceed(dec); err != nil {
|
|
result.Error = "starttls-proceed: " + err.Error()
|
|
return result
|
|
}
|
|
|
|
tlsConn := tls.Client(rawConn, tlsProbeConfig(domain))
|
|
_ = tlsConn.SetDeadline(time.Now().Add(timeout))
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
result.Error = "tls-handshake: " + err.Error()
|
|
return result
|
|
}
|
|
result.STARTTLSUpgraded = true
|
|
state := tlsConn.ConnectionState()
|
|
result.TLSVersion = tls.VersionName(state.Version)
|
|
result.TLSCipher = tls.CipherSuiteName(state.CipherSuite)
|
|
|
|
_ = tlsConn.SetDeadline(time.Now().Add(timeout))
|
|
feats2, _, err := openStreamAndReadFeatures(tlsConn, domain, ns, mode == ModeServer)
|
|
if err != nil {
|
|
result.Error = "post-tls stream: " + err.Error()
|
|
return result
|
|
}
|
|
applyFeatures(&result, feats2)
|
|
return result
|
|
}
|
|
|
|
// applyFeatures copies parsed stream features into the probe result.
|
|
func applyFeatures(ep *EndpointProbe, feats *streamFeatures) {
|
|
if feats == nil {
|
|
return
|
|
}
|
|
ep.FeaturesRead = true
|
|
if feats.Mechanisms != nil {
|
|
ep.SASLMechanisms = append(ep.SASLMechanisms, feats.Mechanisms.Mechanism...)
|
|
for _, m := range feats.Mechanisms.Mechanism {
|
|
if strings.EqualFold(m, "EXTERNAL") {
|
|
ep.SASLExternal = true
|
|
}
|
|
}
|
|
}
|
|
if feats.Dialback != nil {
|
|
ep.DialbackOffered = true
|
|
}
|
|
}
|
|
|
|
type streamFeatures struct {
|
|
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
|
|
StartTLS *startTLSEl
|
|
Mechanisms *mechanismsEl
|
|
Dialback *struct{} `xml:"urn:xmpp:features:dialback dialback"`
|
|
}
|
|
|
|
// startTLSEl is the <starttls/> child of <stream:features>.
type startTLSEl struct {
	XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`

	// Required is non-nil when the element contains a <required/> child.
	Required *struct{} `xml:"required"`
}
|
|
|
|
// mechanismsEl is the SASL <mechanisms/> child of <stream:features>.
type mechanismsEl struct {
	XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`

	// Mechanism holds the text of every <mechanism/> child, in document order.
	Mechanism []string `xml:"mechanism"`
}
|
|
|
|
// openStreamAndReadFeatures performs the stream header exchange and parses
|
|
// <stream:features>. Used both for the initial open and for the post-TLS
|
|
// stream restart.
|
|
func openStreamAndReadFeatures(conn io.ReadWriter, domain, ns string, server bool) (*streamFeatures, string, error) {
|
|
dec, from, err := openStream(conn, domain, ns, server)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
feats, err := readFeatures(dec)
|
|
if err != nil {
|
|
return nil, from, err
|
|
}
|
|
return feats, from, nil
|
|
}
|
|
|
|
func openStream(conn io.ReadWriter, domain, ns string, server bool) (*xml.Decoder, string, error) {
|
|
var header string
|
|
if server {
|
|
header = fmt.Sprintf(`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='%s' xmlns:db='jabber:server:dialback' version='1.0' to='%s'>`, ns, streamsNS, domain)
|
|
} else {
|
|
header = fmt.Sprintf(`<?xml version='1.0'?><stream:stream xmlns='%s' xmlns:stream='%s' version='1.0' to='%s'>`, ns, streamsNS, domain)
|
|
}
|
|
if _, err := io.WriteString(conn, header); err != nil {
|
|
return nil, "", fmt.Errorf("write header: %w", err)
|
|
}
|
|
dec := xml.NewDecoder(conn)
|
|
for {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("read header: %w", err)
|
|
}
|
|
switch t := tok.(type) {
|
|
case xml.StartElement:
|
|
if t.Name.Space == streamsNS && t.Name.Local == "stream" {
|
|
var from string
|
|
for _, a := range t.Attr {
|
|
if a.Name.Local == "from" {
|
|
from = a.Value
|
|
}
|
|
}
|
|
return dec, from, nil
|
|
}
|
|
if t.Name.Space == streamsNS && t.Name.Local == "error" {
|
|
_ = dec.Skip()
|
|
return nil, "", errors.New("server returned stream:error on open")
|
|
}
|
|
return nil, "", fmt.Errorf("unexpected element %s", t.Name.Local)
|
|
}
|
|
}
|
|
}
|
|
|
|
func readFeatures(dec *xml.Decoder) (*streamFeatures, error) {
|
|
for {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read features: %w", err)
|
|
}
|
|
se, ok := tok.(xml.StartElement)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if se.Name.Space == streamsNS && se.Name.Local == "features" {
|
|
var feats streamFeatures
|
|
if err := dec.DecodeElement(&feats, &se); err != nil {
|
|
return nil, fmt.Errorf("decode features: %w", err)
|
|
}
|
|
return &feats, nil
|
|
}
|
|
if se.Name.Space == streamsNS && se.Name.Local == "error" {
|
|
_ = dec.Skip()
|
|
return nil, errors.New("stream:error before features")
|
|
}
|
|
}
|
|
}
|
|
|
|
func expectProceed(dec *xml.Decoder) error {
|
|
for {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return fmt.Errorf("read proceed: %w", err)
|
|
}
|
|
se, ok := tok.(xml.StartElement)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if se.Name.Space == tlsNS {
|
|
switch se.Name.Local {
|
|
case "proceed":
|
|
_ = dec.Skip()
|
|
return nil
|
|
case "failure":
|
|
_ = dec.Skip()
|
|
return errors.New("server refused STARTTLS (<failure/>)")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// validateDomain enforces RFC 1123 hostname rules before the value is used in
// DNS lookups and embedded in the XMPP stream header.
//
// The name may be at most 253 bytes long. Every dot-separated label must be
// non-empty, at most 63 bytes long, must not start or end with a hyphen, and
// may contain only ASCII letters, digits and hyphens. The first violation
// found is returned as an error; a valid name yields nil.
func validateDomain(domain string) error {
	if n := len(domain); n > 253 {
		return fmt.Errorf("domain name too long (max 253 characters, got %d)", n)
	}

	labels := strings.Split(domain, ".")
	for i := range labels {
		label := labels[i]
		switch {
		case label == "":
			return fmt.Errorf("domain contains an empty label")
		case len(label) > 63:
			return fmt.Errorf("domain label %q exceeds 63 characters", label)
		case strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-"):
			return fmt.Errorf("domain label %q must not start or end with a hyphen", label)
		}

		for _, r := range label {
			switch {
			case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-':
				continue
			}
			return fmt.Errorf("domain label %q contains invalid character %q", label, r)
		}
	}
	return nil
}
|
|
|
|
func lookupSRV(ctx context.Context, r *net.Resolver, prefix, domain string) ([]SRVRecord, error) {
|
|
name := prefix + dns.Fqdn(domain)
|
|
_, records, err := r.LookupSRV(ctx, "", "", name)
|
|
if err != nil {
|
|
// Distinguish NXDOMAIN / no records from real errors.
|
|
var dnsErr *net.DNSError
|
|
if errors.As(err, &dnsErr) && (dnsErr.IsNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
// RFC 2782: single record "." with port 0 means "service explicitly not
|
|
// available at this domain". We treat that as "no records" for probing.
|
|
if len(records) == 1 && (records[0].Target == "." || records[0].Target == "") && records[0].Port == 0 {
|
|
return nil, nil
|
|
}
|
|
out := make([]SRVRecord, 0, len(records))
|
|
for _, r := range records {
|
|
out = append(out, SRVRecord{
|
|
Target: strings.TrimSuffix(r.Target, "."),
|
|
Port: r.Port,
|
|
Priority: r.Priority,
|
|
Weight: r.Weight,
|
|
})
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func resolveAllInto(ctx context.Context, r *net.Resolver, records []SRVRecord) {
|
|
for i := range records {
|
|
ips, err := r.LookupIPAddr(ctx, records[i].Target)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
for _, ip := range ips {
|
|
if v4 := ip.IP.To4(); v4 != nil {
|
|
records[i].IPv4 = append(records[i].IPv4, v4.String())
|
|
} else {
|
|
records[i].IPv6 = append(records[i].IPv6, ip.IP.String())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// computeCoverage walks raw endpoints and fills in the ReachabilitySpan
|
|
// aggregate. It is still part of Collect because coverage is a raw summary
|
|
// of what was actually reached, not a judgment (it has no severity).
|
|
func computeCoverage(data *XMPPData) {
|
|
for _, ep := range data.Endpoints {
|
|
if ep.TCPConnected {
|
|
if ep.IsIPv6 {
|
|
data.Coverage.HasIPv6 = true
|
|
} else {
|
|
data.Coverage.HasIPv4 = true
|
|
}
|
|
}
|
|
if !ep.STARTTLSUpgraded {
|
|
continue
|
|
}
|
|
switch ep.Mode {
|
|
case ModeClient:
|
|
// c2s is reachable if SASL was advertised OR if STARTTLS
|
|
// completed but features couldn't be read (benign for probes).
|
|
if len(ep.SASLMechanisms) > 0 || !ep.FeaturesRead {
|
|
data.Coverage.WorkingC2S = true
|
|
}
|
|
case ModeServer:
|
|
// s2s reachable if TLS completed; the dialback/EXTERNAL
|
|
// posture judgment is expressed by a rule, not here.
|
|
data.Coverage.WorkingS2S = true
|
|
}
|
|
}
|
|
}
|