// checker-sip/checker/collect.go (422 lines, 12 KiB, Go)

package checker
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// Collect runs the full SIP probe against a domain.
//
// Options read from opts:
//   - "domain" (required): surrounding whitespace and one trailing dot are
//     stripped before use.
//   - "timeout": per-endpoint probe budget in seconds (default 5; values
//     below 1 are replaced by 5).
//   - "probeUDP", "probeTCP", "probeTLS": per-transport switches, all
//     defaulting to true.
//
// The probe runs in four steps: a best-effort NAPTR lookup, one SRV lookup
// per enabled transport, a fallback to the bare domain on ports 5060/5061
// when no SRV record was found at all, and finally address resolution plus
// an OPTIONS probe of every target. Transports are probed concurrently, but
// the Endpoints slice is always assembled in UDP, TCP, TLS order so that
// repeated runs over the same input yield identically ordered results.
//
// DNS failures are recorded in the SRV.Errors map instead of being
// returned; the only error returned is for a missing domain. On success the
// returned value is a *SIPData.
func (p *sipProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	domain, _ := sdk.GetOption[string](opts, "domain")
	domain = strings.TrimSuffix(strings.TrimSpace(domain), ".")
	if domain == "" {
		return nil, errors.New("domain is required")
	}

	timeoutSecs := sdk.GetFloatOption(opts, "timeout", 5)
	if timeoutSecs < 1 {
		timeoutSecs = 5
	}
	perEndpoint := time.Duration(timeoutSecs * float64(time.Second))

	data := &SIPData{
		Domain: domain,
		RunAt:  time.Now().UTC().Format(time.RFC3339),
		SRV:    SRVLookup{Errors: map[string]string{}},
	}
	resolver := net.DefaultResolver

	// NAPTR lookup is best-effort: a failure is only recorded under the
	// "naptr" key and never aborts the probe.
	if naptr, err := lookupNAPTR(ctx, domain); err != nil {
		data.SRV.Errors["naptr"] = err.Error()
	} else {
		data.NAPTR = naptr
	}

	// One entry per transport, in the order results are reported.
	type transportPlan struct {
		prefix       string       // SRV owner-name prefix, e.g. "_sip._udp."
		enabled      bool         // whether the caller asked for this transport
		records      *[]SRVRecord // destination inside data.SRV
		transport    Transport
		fallbackPort uint16 // port probed on the bare domain when no SRV exists
	}
	plans := []transportPlan{
		{"_sip._udp.", sdk.GetBoolOption(opts, "probeUDP", true), &data.SRV.UDP, TransportUDP, 5060},
		{"_sip._tcp.", sdk.GetBoolOption(opts, "probeTCP", true), &data.SRV.TCP, TransportTCP, 5060},
		{"_sips._tcp.", sdk.GetBoolOption(opts, "probeTLS", true), &data.SRV.SIPS, TransportTLS, 5061},
	}

	// SRV lookups (per transport). Errors are kept per-prefix; "not
	// found" is normalised to nil by lookupSRV.
	for _, pl := range plans {
		if !pl.enabled {
			continue
		}
		recs, err := lookupSRV(ctx, resolver, pl.prefix, domain)
		if err != nil {
			data.SRV.Errors[pl.prefix] = err.Error()
			continue
		}
		if recs != nil {
			*pl.records = recs
		}
	}

	// Fallback when no SRV at all: synthesize a single target on each
	// enabled transport against the bare domain.
	if len(data.SRV.UDP)+len(data.SRV.TCP)+len(data.SRV.SIPS) == 0 {
		data.SRV.FallbackProbed = true
		for _, pl := range plans {
			if pl.enabled {
				*pl.records = []SRVRecord{{Target: domain, Port: pl.fallbackPort}}
			}
		}
	}

	// Probe the transports concurrently. Each goroutine writes only its own
	// slot of results, so no lock is needed, and the slots are concatenated
	// in plan order afterwards to keep the output deterministic.
	results := make([][]EndpointProbe, len(plans))
	var wg sync.WaitGroup
	for i, pl := range plans {
		wg.Add(1)
		go func(slot int, prefix string, t Transport, records []SRVRecord) {
			defer wg.Done()
			// records shares its backing array with data.SRV, so the
			// resolved addresses show up in the returned data as well.
			resolveAllInto(ctx, resolver, records)
			results[slot] = probeSet(ctx, prefix, t, records, perEndpoint)
		}(i, pl.prefix, pl.transport, *pl.records)
	}
	wg.Wait()

	for _, eps := range results {
		data.Endpoints = append(data.Endpoints, eps...)
	}
	return data, nil
}
// ─── DNS ──────────────────────────────────────────────────────────────
// lookupSRV queries the SRV records published at prefix + domain (for
// example "_sip._udp." + "example.org") through resolver r.
//
// It returns (nil, nil) when the resolver reports the name as not found or
// when the only answer is the RFC 2782 null target, and (nil, err) for any
// other lookup failure. Otherwise every answer is converted to an SRVRecord
// with the trailing dot removed from its target.
func lookupSRV(ctx context.Context, r *net.Resolver, prefix, domain string) ([]SRVRecord, error) {
	_, answers, err := r.LookupSRV(ctx, "", "", prefix+dns.Fqdn(domain))
	if err != nil {
		var notFound *net.DNSError
		if errors.As(err, &notFound) && notFound.IsNotFound {
			return nil, nil
		}
		return nil, err
	}

	// RFC 2782 null-target: a single "." record with port 0 means
	// "service explicitly unavailable". An empty target is accepted too,
	// as a defensive measure.
	if len(answers) == 1 {
		only := answers[0]
		if only.Port == 0 && (only.Target == "." || only.Target == "") {
			return nil, nil
		}
	}

	recs := make([]SRVRecord, len(answers))
	for i, srv := range answers {
		recs[i] = SRVRecord{
			Target:   strings.TrimSuffix(srv.Target, "."),
			Port:     srv.Port,
			Priority: srv.Priority,
			Weight:   srv.Weight,
		}
	}
	return recs, nil
}
// lookupNAPTR fetches the NAPTR records of domain and keeps those whose
// service field starts with "SIP+" or "SIPS+" (case-insensitive).
//
// The resolvers come from /etc/resolv.conf; when that file is unusable or
// lists no servers, the public resolvers 1.1.1.1 and 8.8.8.8 are used
// instead and the fallback is logged. Servers are tried in order: a
// transport error, SERVFAIL or any other non-success rcode moves on to the
// next server, NXDOMAIN returns (nil, nil) immediately, and the first
// successful answer is returned. When every server fails, the last error
// seen is returned.
func lookupNAPTR(ctx context.Context, domain string) ([]NAPTRRecord, error) {
	cfg, err := dns.ClientConfigFromFile("/etc/resolv.conf")
	if err != nil || cfg == nil || len(cfg.Servers) == 0 {
		log.Printf("checker-sip: /etc/resolv.conf unusable (%v), falling back to public resolvers 1.1.1.1/8.8.8.8 for NAPTR lookup of %s", err, domain)
		cfg = &dns.ClientConfig{Servers: []string{"1.1.1.1", "8.8.8.8"}, Port: "53"}
	}

	query := new(dns.Msg)
	query.SetQuestion(dns.Fqdn(domain), dns.TypeNAPTR)
	query.RecursionDesired = true
	// Ask a validating resolver to perform DNSSEC validation and signal
	// the result via the AD bit. EDNS0 with DO=1 is required for the
	// resolver to honour AD on the response.
	query.AuthenticatedData = true
	query.SetEdns0(4096, true)

	// Split the caller's deadline across the configured resolvers so a
	// single slow server can't consume the whole context budget. Falls
	// back to 3s per server when ctx has no deadline.
	perServer := 3 * time.Second
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		if left := time.Until(deadline); left > 0 {
			perServer = left / time.Duration(len(cfg.Servers))
		}
	}

	client := new(dns.Client)
	var lastErr error
	for _, server := range cfg.Servers {
		qctx, cancel := context.WithTimeout(ctx, perServer)
		reply, _, err := client.ExchangeContext(qctx, query, net.JoinHostPort(server, cfg.Port))
		cancel()
		if err != nil {
			lastErr = err
			continue
		}

		switch reply.Rcode {
		case dns.RcodeSuccess:
			// Usable answer; filtered below.
		case dns.RcodeNameError:
			return nil, nil
		case dns.RcodeServerFailure:
			lastErr = fmt.Errorf("SERVFAIL from %s (possible DNSSEC validation failure)", server)
			continue
		default:
			lastErr = fmt.Errorf("rcode %s", dns.RcodeToString[reply.Rcode])
			continue
		}

		var found []NAPTRRecord
		for _, rr := range reply.Answer {
			naptr, isNAPTR := rr.(*dns.NAPTR)
			if !isNAPTR {
				continue
			}
			service := strings.ToUpper(naptr.Service)
			if !(strings.HasPrefix(service, "SIP+") || strings.HasPrefix(service, "SIPS+")) {
				continue
			}
			found = append(found, NAPTRRecord{
				Service:     naptr.Service,
				Regexp:      naptr.Regexp,
				Replacement: strings.TrimSuffix(naptr.Replacement, "."),
				Flags:       naptr.Flags,
				Order:       naptr.Order,
				Preference:  naptr.Preference,
			})
		}
		return found, nil
	}
	return nil, lastErr
}
// resolveAllInto looks up the IP addresses of every record's Target via r
// and appends them, in place, to that record's IPv4 or IPv6 list. A target
// whose lookup fails is skipped and keeps whatever addresses it already had.
func resolveAllInto(ctx context.Context, r *net.Resolver, records []SRVRecord) {
	for idx := range records {
		rec := &records[idx]
		addrs, err := r.LookupIPAddr(ctx, rec.Target)
		if err != nil {
			// Best-effort: an unresolved target simply ends up without addresses.
			continue
		}
		for _, addr := range addrs {
			v4 := addr.IP.To4()
			if v4 == nil {
				rec.IPv6 = append(rec.IPv6, addr.IP.String())
				continue
			}
			rec.IPv4 = append(rec.IPv4, v4.String())
		}
	}
}
// ─── Probing ──────────────────────────────────────────────────────────
// probeSet probes every resolved address of every record over transport t
// and returns one EndpointProbe per address, in record order. A record
// without any address yields a single entry carrying only an error message.
// timeout is the budget handed to each individual probe.
func probeSet(ctx context.Context, prefix string, t Transport, records []SRVRecord, timeout time.Duration) []EndpointProbe {
	var probes []EndpointProbe
	for _, rec := range records {
		targets := allAddrs(rec)
		if len(targets) > 0 {
			for _, addr := range targets {
				probes = append(probes, probeEndpoint(ctx, t, prefix, rec, addr, timeout))
			}
			continue
		}
		// Nothing to dial: report the target itself as failed.
		unresolved := EndpointProbe{
			Transport: t,
			SRVPrefix: prefix,
			Target:    rec.Target,
			Port:      rec.Port,
			Error:     "no A/AAAA records for target",
		}
		probes = append(probes, unresolved)
	}
	return probes
}
// probeAddr is one textual IP address to probe, tagged with its family.
type probeAddr struct {
	ip   string // address as stored on the SRV record
	isV6 bool   // true when ip came from the record's IPv6 list
}

// allAddrs flattens the addresses of r into a single list: every IPv4
// address first, then every IPv6 address, each in its original order. The
// result is an empty, non-nil slice when r has no addresses.
func allAddrs(r SRVRecord) []probeAddr {
	addrs := make([]probeAddr, len(r.IPv4), len(r.IPv4)+len(r.IPv6))
	for i, v4 := range r.IPv4 {
		addrs[i] = probeAddr{ip: v4}
	}
	for _, v6 := range r.IPv6 {
		addrs = append(addrs, probeAddr{ip: v6, isV6: true})
	}
	return addrs
}
// probeEndpoint probes a single address of rec over transport t and returns
// the filled-in EndpointProbe. The result is a named value so that the
// deferred closure can record the total elapsed time after the
// transport-specific probe has finished. A transport other than UDP, TCP or
// TLS is not probed and yields only the identifying fields.
func probeEndpoint(ctx context.Context, t Transport, prefix string, rec SRVRecord, a probeAddr, timeout time.Duration) (ep EndpointProbe) {
	began := time.Now()
	defer func() { ep.ElapsedMS = time.Since(began).Milliseconds() }()

	ep.Transport = t
	ep.SRVPrefix = prefix
	ep.Target = rec.Target
	ep.Port = rec.Port
	ep.Address = net.JoinHostPort(a.ip, strconv.Itoa(int(rec.Port)))
	ep.IsIPv6 = a.isV6

	userAgent := "happyDomain-checker-sip/" + Version
	switch t {
	case TransportUDP:
		probeUDP(ctx, &ep, rec.Target, userAgent, timeout)
	case TransportTCP:
		probeTCP(ctx, &ep, rec.Target, userAgent, timeout)
	case TransportTLS:
		probeTLSConn(ctx, &ep, rec.Target, userAgent, timeout)
	}
	return ep
}
// probeUDP opens a UDP socket to ep.Address and runs the OPTIONS exchange
// over it. The dial and the exchange share one deadline of timeout from
// now. The reply is read as a single datagram of at most 8192 bytes and
// parsed from that buffer. A dial failure is stored in both ep.ReachableErr
// and ep.Error.
func probeUDP(ctx context.Context, ep *EndpointProbe, target, ua string, timeout time.Duration) {
	deadline := time.Now().Add(timeout)
	dialer := net.Dialer{Deadline: deadline}
	conn, err := dialer.DialContext(ctx, "udp", ep.Address)
	if err != nil {
		reason := err.Error()
		ep.ReachableErr = reason
		ep.Error = "udp dial: " + reason
		return
	}
	defer conn.Close()

	readDatagram := func(c net.Conn) (*sipResponse, error) {
		datagram := make([]byte, 8192)
		n, err := c.Read(datagram)
		if err != nil {
			return nil, err
		}
		return parseSIPResponse(bytes.NewReader(datagram[:n]))
	}
	runOptionsExchange(ep, conn, deadline, target, ua, TransportUDP, readDatagram)
}
// probeTCP connects to ep.Address over TCP and runs the OPTIONS exchange on
// the stream. The dial and the exchange share one deadline of timeout from
// now. A dial failure is stored in both ep.ReachableErr and ep.Error.
func probeTCP(ctx context.Context, ep *EndpointProbe, target, ua string, timeout time.Duration) {
	deadline := time.Now().Add(timeout)
	dialer := net.Dialer{Deadline: deadline}
	conn, err := dialer.DialContext(ctx, "tcp", ep.Address)
	if err != nil {
		reason := err.Error()
		ep.ReachableErr = reason
		ep.Error = "tcp dial: " + reason
		return
	}
	defer conn.Close()

	// The response is parsed straight off the connection.
	readStream := func(c net.Conn) (*sipResponse, error) {
		return parseSIPResponse(c)
	}
	runOptionsExchange(ep, conn, deadline, target, ua, TransportTCP, readStream)
}
// probeTLSConn connects to ep.Address over TCP, performs a TLS handshake
// with target as the server name, records the negotiated TLS version and
// cipher suite on ep, and then runs the OPTIONS exchange over the TLS
// session. Dial, handshake and exchange share one deadline of timeout from
// now. A dial failure is stored in both ep.ReachableErr and ep.Error; a
// handshake failure is stored in ep.Error only.
func probeTLSConn(ctx context.Context, ep *EndpointProbe, target, ua string, timeout time.Duration) {
	deadline := time.Now().Add(timeout)
	dialer := net.Dialer{Deadline: deadline}
	raw, err := dialer.DialContext(ctx, "tcp", ep.Address)
	if err != nil {
		reason := err.Error()
		ep.ReachableErr = reason
		ep.Error = "tcp dial: " + reason
		return
	}

	// SetDeadline only fails on a closed/invalid socket; the next handshake
	// or I/O call will surface that with a clearer error.
	_ = raw.SetDeadline(deadline)

	// We deliberately skip cert verification, checker-tls is the
	// source of truth for TLS posture. We just want to reach SIP over
	// TLS.
	conn := tls.Client(raw, &tls.Config{
		InsecureSkipVerify: true, //nolint:gosec
		ServerName:         target,
	})
	if hsErr := conn.HandshakeContext(ctx); hsErr != nil {
		_ = raw.Close()
		ep.Error = "tls handshake: " + hsErr.Error()
		return
	}
	defer conn.Close()

	negotiated := conn.ConnectionState()
	ep.TLSVersion = tls.VersionName(negotiated.Version)
	ep.TLSCipher = tls.CipherSuiteName(negotiated.CipherSuite)

	readStream := func(c net.Conn) (*sipResponse, error) {
		return parseSIPResponse(c)
	}
	runOptionsExchange(ep, conn, deadline, target, ua, TransportTLS, readStream)
}
// runOptionsExchange is the OPTIONS round-trip common to all transports,
// run once a connection has been established. It flags ep as reachable,
// applies deadline to conn, writes the request, obtains the reply through
// the transport-specific readResp, and copies the reply onto ep. The
// string form of t prefixes the error messages; a write or read failure is
// stored in ep.Error and ends the exchange.
func runOptionsExchange(
	ep *EndpointProbe,
	conn net.Conn,
	deadline time.Time,
	target, ua string,
	t Transport,
	readResp func(net.Conn) (*sipResponse, error),
) {
	ep.Reachable = true

	// SetDeadline only fails on a closed/invalid socket; the next I/O call
	// will surface that with a clearer error.
	_ = conn.SetDeadline(deadline)

	label := string(t)
	request := buildOptionsRequest(target, ep.Port, t, localAddrFor(conn), ua)

	// The round-trip time is measured from just before the write.
	sent := time.Now()
	_, writeErr := conn.Write([]byte(request))
	if writeErr != nil {
		ep.Error = label + " write: " + writeErr.Error()
		return
	}
	ep.OptionsSent = true

	reply, readErr := readResp(conn)
	if readErr != nil {
		ep.Error = "no " + label + " response: " + readErr.Error()
		return
	}
	applyResponse(ep, reply, sent)
}
// applyResponse copies the fields of a parsed SIP response onto ep: the raw
// status code, a "<code> <phrase>" status line with the phrase trimmed of
// surrounding whitespace, the round-trip time in milliseconds measured from
// sent, and the Server, User-Agent, Allow and Contact values.
func applyResponse(ep *EndpointProbe, resp *sipResponse, sent time.Time) {
	phrase := strings.TrimSpace(resp.StatusPhrase)

	ep.OptionsRawCode = resp.StatusCode
	ep.OptionsStatus = fmt.Sprintf("%d %s", resp.StatusCode, phrase)

	rtt := time.Since(sent)
	ep.OptionsRTTMs = rtt.Milliseconds()

	ep.ServerHeader, ep.UserAgent = resp.Server, resp.UserAgent
	ep.AllowMethods, ep.ContactURI = resp.Allow, resp.Contact
}