package checker import ( "context" "crypto/tls" "encoding/xml" "errors" "fmt" "io" "net" "strconv" "strings" "time" "github.com/miekg/dns" sdk "git.happydns.org/checker-sdk-go/checker" ) const ( streamsNS = "http://etherx.jabber.org/streams" clientNS = "jabber:client" serverNS = "jabber:server" tlsNS = "urn:ietf:params:xml:ns:xmpp-tls" ) // tlsProbeConfig returns a deliberately permissive TLS config for probing. // // InsecureSkipVerify is intentional: certificate chain and hostname validation // is the TLS checker's responsibility. This checker only observes which TLS // versions and cipher suites a server accepts, then hands the endpoints to // checker-tls for the actual certificate audit. // // MinVersion is set to TLS 1.0 so we can observe whether a server still // accepts deprecated protocol versions: that is exactly what we want to // report. A strict client config would prevent us from reaching those servers // at all. func tlsProbeConfig(serverName string) *tls.Config { return &tls.Config{ ServerName: serverName, InsecureSkipVerify: true, //nolint:gosec MinVersion: tls.VersionTLS10, } } // Collect runs the full XMPP probe for a domain. func (p *xmppProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) { domain, _ := sdk.GetOption[string](opts, "domain") domain = strings.TrimSuffix(domain, ".") if domain == "" { return nil, fmt.Errorf("domain is required") } if err := validateDomain(domain); err != nil { return nil, err } mode, _ := sdk.GetOption[string](opts, "mode") if mode == "" { mode = "both" } timeoutSecs := sdk.GetFloatOption(opts, "timeout", 10) if timeoutSecs < 1 { timeoutSecs = 10 } perEndpoint := time.Duration(timeoutSecs * float64(time.Second)) wantC2S := mode != "s2s" wantS2S := mode != "c2s" data := &XMPPData{ Domain: domain, RunAt: time.Now().UTC().Format(time.RFC3339), SRV: SRVLookup{Errors: map[string]string{}}, } resolver := net.DefaultResolver lookupSets := []struct { prefix string want bool dst *[]SRVRecord }{ {"_xmpp-client._tcp.", wantC2S, &data.SRV.Client}, {"_xmpp-server._tcp.", wantS2S, &data.SRV.Server}, {"_xmpps-client._tcp.", wantC2S, &data.SRV.ClientSecure}, {"_xmpps-server._tcp.", wantS2S, &data.SRV.ServerSecure}, {"_jabber._tcp.", wantC2S, &data.SRV.Jabber}, } for _, ls := range lookupSets { if !ls.want { continue } records, err := lookupSRV(ctx, resolver, ls.prefix, domain) if err != nil { data.SRV.Errors[ls.prefix] = err.Error() continue } *ls.dst = records } totalSRV := len(data.SRV.Client) + len(data.SRV.Server) + len(data.SRV.ClientSecure) + len(data.SRV.ServerSecure) if totalSRV == 0 { data.SRV.FallbackProbed = true if wantC2S { data.SRV.Client = []SRVRecord{{Target: domain, Port: 5222}} } if wantS2S { data.SRV.Server = []SRVRecord{{Target: domain, Port: 5269}} } } resolveAllInto(ctx, resolver, data.SRV.Client) resolveAllInto(ctx, resolver, data.SRV.Server) resolveAllInto(ctx, resolver, data.SRV.ClientSecure) resolveAllInto(ctx, resolver, data.SRV.ServerSecure) probeSet(ctx, data, domain, ModeClient, "_xmpp-client._tcp", data.SRV.Client, false, perEndpoint) probeSet(ctx, data, domain, ModeServer, "_xmpp-server._tcp", data.SRV.Server, false, perEndpoint) probeSet(ctx, data, domain, ModeClient, "_xmpps-client._tcp", data.SRV.ClientSecure, true, perEndpoint) probeSet(ctx, data, domain, ModeServer, "_xmpps-server._tcp", data.SRV.ServerSecure, true, perEndpoint) computeCoverage(data) // Collect intentionally does not populate data.Issues; judging the raw // payload is the job of the CheckRules (see rules.go). return data, nil } func probeSet(ctx context.Context, data *XMPPData, domain string, mode XMPPMode, prefix string, records []SRVRecord, directTLS bool, timeout time.Duration) { for _, rec := range records { addrs := addressesForProbe(rec) if len(addrs) == 0 { ep := EndpointProbe{ Mode: mode, SRVPrefix: prefix, Target: rec.Target, Port: rec.Port, DirectTLS: directTLS, Error: "no A/AAAA records for target", } data.Endpoints = append(data.Endpoints, ep) continue } for _, a := range addrs { ep := probeEndpoint(ctx, domain, mode, prefix, rec, a.ip, a.isV6, directTLS, timeout) data.Endpoints = append(data.Endpoints, ep) } } } type probeAddr struct { ip string isV6 bool } func addressesForProbe(rec SRVRecord) []probeAddr { var out []probeAddr for _, ip := range rec.IPv4 { out = append(out, probeAddr{ip: ip, isV6: false}) } for _, ip := range rec.IPv6 { out = append(out, probeAddr{ip: ip, isV6: true}) } return out } func probeEndpoint(ctx context.Context, domain string, mode XMPPMode, prefix string, rec SRVRecord, ip string, isV6, directTLS bool, timeout time.Duration) EndpointProbe { start := time.Now() result := EndpointProbe{ Mode: mode, SRVPrefix: prefix, Target: rec.Target, Port: rec.Port, Address: net.JoinHostPort(ip, strconv.Itoa(int(rec.Port))), IsIPv6: isV6, DirectTLS: directTLS, } defer func() { result.ElapsedMS = time.Since(start).Milliseconds() }() ns := clientNS if mode == ModeServer { ns = serverNS } dialCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() dialer := &net.Dialer{} rawConn, err := dialer.DialContext(dialCtx, "tcp", result.Address) if err != nil { result.Error = "tcp: " + err.Error() return result } result.TCPConnected = true defer rawConn.Close() _ = rawConn.SetDeadline(time.Now().Add(timeout)) var conn net.Conn = rawConn if directTLS { tlsConn := tls.Client(rawConn, tlsProbeConfig(domain)) if err := tlsConn.Handshake(); err != nil { result.Error = "tls-handshake: " + err.Error() return result } result.STARTTLSUpgraded = true state := tlsConn.ConnectionState() result.TLSVersion = tls.VersionName(state.Version) result.TLSCipher = tls.CipherSuiteName(state.CipherSuite) _ = tlsConn.SetDeadline(time.Now().Add(timeout)) conn = tlsConn feats, from, err := openStreamAndReadFeatures(conn, domain, ns, mode == ModeServer) if err != nil { result.Error = "stream: " + err.Error() return result } result.StreamOpened = true result.StreamFrom = from applyFeatures(&result, feats) return result } dec, from, err := openStream(conn, domain, ns, mode == ModeServer) if err != nil { result.Error = "stream: " + err.Error() return result } result.StreamOpened = true result.StreamFrom = from feats, err := readFeatures(dec) if err != nil { result.Error = "features: " + err.Error() return result } result.STARTTLSOffered = feats.StartTLS != nil if feats.StartTLS != nil && feats.StartTLS.Required != nil { result.STARTTLSRequired = true } if !result.STARTTLSOffered { // Record any features seen in plaintext, but do not proceed; we // intentionally refuse to send SASL over a non-TLS channel. applyFeatures(&result, feats) return result } if _, err := io.WriteString(conn, ``); err != nil { result.Error = "starttls-write: " + err.Error() return result } if err := expectProceed(dec); err != nil { result.Error = "starttls-proceed: " + err.Error() return result } tlsConn := tls.Client(rawConn, tlsProbeConfig(domain)) _ = tlsConn.SetDeadline(time.Now().Add(timeout)) if err := tlsConn.Handshake(); err != nil { result.Error = "tls-handshake: " + err.Error() return result } result.STARTTLSUpgraded = true state := tlsConn.ConnectionState() result.TLSVersion = tls.VersionName(state.Version) result.TLSCipher = tls.CipherSuiteName(state.CipherSuite) _ = tlsConn.SetDeadline(time.Now().Add(timeout)) feats2, _, err := openStreamAndReadFeatures(tlsConn, domain, ns, mode == ModeServer) if err != nil { result.Error = "post-tls stream: " + err.Error() return result } applyFeatures(&result, feats2) return result } // applyFeatures copies parsed stream features into the probe result. func applyFeatures(ep *EndpointProbe, feats *streamFeatures) { if feats == nil { return } ep.FeaturesRead = true if feats.Mechanisms != nil { ep.SASLMechanisms = append(ep.SASLMechanisms, feats.Mechanisms.Mechanism...) for _, m := range feats.Mechanisms.Mechanism { if strings.EqualFold(m, "EXTERNAL") { ep.SASLExternal = true } } } if feats.Dialback != nil { ep.DialbackOffered = true } } type streamFeatures struct { XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` StartTLS *startTLSEl Mechanisms *mechanismsEl Dialback *struct{} `xml:"urn:xmpp:features:dialback dialback"` } type startTLSEl struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` Required *struct{} `xml:"required"` } type mechanismsEl struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` Mechanism []string `xml:"mechanism"` } // openStreamAndReadFeatures performs the stream header exchange and parses // . Used both for the initial open and for the post-TLS // stream restart. func openStreamAndReadFeatures(conn io.ReadWriter, domain, ns string, server bool) (*streamFeatures, string, error) { dec, from, err := openStream(conn, domain, ns, server) if err != nil { return nil, "", err } feats, err := readFeatures(dec) if err != nil { return nil, from, err } return feats, from, nil } func openStream(conn io.ReadWriter, domain, ns string, server bool) (*xml.Decoder, string, error) { var header string if server { header = fmt.Sprintf(``, ns, streamsNS, domain) } else { header = fmt.Sprintf(``, ns, streamsNS, domain) } if _, err := io.WriteString(conn, header); err != nil { return nil, "", fmt.Errorf("write header: %w", err) } dec := xml.NewDecoder(conn) for { tok, err := dec.Token() if err != nil { return nil, "", fmt.Errorf("read header: %w", err) } switch t := tok.(type) { case xml.StartElement: if t.Name.Space == streamsNS && t.Name.Local == "stream" { var from string for _, a := range t.Attr { if a.Name.Local == "from" { from = a.Value } } return dec, from, nil } if t.Name.Space == streamsNS && t.Name.Local == "error" { _ = dec.Skip() return nil, "", errors.New("server returned stream:error on open") } return nil, "", fmt.Errorf("unexpected element %s", t.Name.Local) } } } func readFeatures(dec *xml.Decoder) (*streamFeatures, error) { for { tok, err := dec.Token() if err != nil { return nil, fmt.Errorf("read features: %w", err) } se, ok := tok.(xml.StartElement) if !ok { continue } if se.Name.Space == streamsNS && se.Name.Local == "features" { var feats streamFeatures if err := dec.DecodeElement(&feats, &se); err != nil { return nil, fmt.Errorf("decode features: %w", err) } return &feats, nil } if se.Name.Space == streamsNS && se.Name.Local == "error" { _ = dec.Skip() return nil, errors.New("stream:error before features") } } } func expectProceed(dec *xml.Decoder) error { for { tok, err := dec.Token() if err != nil { return fmt.Errorf("read proceed: %w", err) } se, ok := tok.(xml.StartElement) if !ok { continue } if se.Name.Space == tlsNS { switch se.Name.Local { case "proceed": _ = dec.Skip() return nil case "failure": _ = dec.Skip() return errors.New("server refused STARTTLS ()") } } } } // validateDomain enforces RFC 1123 hostname rules before the value is used in // DNS lookups and embedded in the XMPP stream header. func validateDomain(domain string) error { if len(domain) > 253 { return fmt.Errorf("domain name too long (max 253 characters, got %d)", len(domain)) } for _, label := range strings.Split(domain, ".") { if len(label) == 0 { return fmt.Errorf("domain contains an empty label") } if len(label) > 63 { return fmt.Errorf("domain label %q exceeds 63 characters", label) } if label[0] == '-' || label[len(label)-1] == '-' { return fmt.Errorf("domain label %q must not start or end with a hyphen", label) } for _, c := range label { if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-') { return fmt.Errorf("domain label %q contains invalid character %q", label, c) } } } return nil } func lookupSRV(ctx context.Context, r *net.Resolver, prefix, domain string) ([]SRVRecord, error) { name := prefix + dns.Fqdn(domain) _, records, err := r.LookupSRV(ctx, "", "", name) if err != nil { // Distinguish NXDOMAIN / no records from real errors. var dnsErr *net.DNSError if errors.As(err, &dnsErr) && (dnsErr.IsNotFound) { return nil, nil } return nil, err } // RFC 2782: single record "." with port 0 means "service explicitly not // available at this domain". We treat that as "no records" for probing. if len(records) == 1 && (records[0].Target == "." || records[0].Target == "") && records[0].Port == 0 { return nil, nil } out := make([]SRVRecord, 0, len(records)) for _, r := range records { out = append(out, SRVRecord{ Target: strings.TrimSuffix(r.Target, "."), Port: r.Port, Priority: r.Priority, Weight: r.Weight, }) } return out, nil } func resolveAllInto(ctx context.Context, r *net.Resolver, records []SRVRecord) { for i := range records { ips, err := r.LookupIPAddr(ctx, records[i].Target) if err != nil { continue } for _, ip := range ips { if v4 := ip.IP.To4(); v4 != nil { records[i].IPv4 = append(records[i].IPv4, v4.String()) } else { records[i].IPv6 = append(records[i].IPv6, ip.IP.String()) } } } } // computeCoverage walks raw endpoints and fills in the ReachabilitySpan // aggregate. It is still part of Collect because coverage is a raw summary // of what was actually reached, not a judgment (it has no severity). func computeCoverage(data *XMPPData) { for _, ep := range data.Endpoints { if ep.TCPConnected { if ep.IsIPv6 { data.Coverage.HasIPv6 = true } else { data.Coverage.HasIPv4 = true } } if !ep.STARTTLSUpgraded { continue } switch ep.Mode { case ModeClient: // c2s is reachable if SASL was advertised OR if STARTTLS // completed but features couldn't be read (benign for probes). if len(ep.SASLMechanisms) > 0 || !ep.FeaturesRead { data.Coverage.WorkingC2S = true } case ModeServer: // s2s reachable if TLS completed; the dialback/EXTERNAL // posture judgment is expressed by a rule, not here. data.Coverage.WorkingS2S = true } } }