checker-dane/checker/collect.go

282 lines
8.6 KiB
Go

package checker
import (
"context"
"encoding/json"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
"time"
sdk "git.happydns.org/checker-sdk-go/checker"
tlscontract "git.happydns.org/checker-tls/contract"
)
// tlsaOwner recognises TLSA owner names of the form "_<port>._<proto>" with
// an optional ".<base>" suffix. Submatch 1 is the decimal port, submatch 2
// the transport ("tcp" or "udp", lowercase only) and submatch 3 the base
// name, which is empty when no labels follow the protocol label. The base is
// whatever the happyDomain analyzer bucketed the TLSAs under; Collect
// resolves it against the service's subdomain and the zone apex.
var (
	tlsaOwner = regexp.MustCompile(`^_(\d+)\._(tcp|udp)(?:\.(.*))?$`)
)
// tlsaOwnerName renders the "_<port>._<proto>.<base>" owner name for an
// endpoint. A single trailing dot on base is dropped. If nothing remains,
// the base part is omitted entirely, so the result is "_<port>._<proto>"
// rather than the dangling "_443._tcp.".
func tlsaOwnerName(port uint16, proto, base string) string {
	name := fmt.Sprintf("_%d._%s", port, proto)
	if host := strings.TrimSuffix(base, "."); host != "" {
		name += "." + host
	}
	return name
}
// starttlsKey formats the "<port>/<proto>" key (for example "25/tcp") under
// which STARTTLS overrides are looked up in the OptionSTARTTLS map.
func starttlsKey(port uint16, proto string) string {
	return strconv.FormatUint(uint64(port), 10) + "/" + proto
}
// serviceMessage mirrors the on-wire happydns.ServiceMessage shape, kept
// local so this module does not depend on happyDomain core. Same pattern
// as checker-caa/checker/collect.go.
//
// Only the fields this checker needs are declared; encoding/json ignores
// any other keys present in the message.
type serviceMessage struct {
	// Type is the service type identifier; Collect rejects any value other
	// than serviceType.
	Type string `json:"_svctype"`
	// Domain is the message's "_domain" value; the code shown here does not
	// read it.
	Domain string `json:"_domain"`
	// Service is the service body, kept raw so Collect can decode it into
	// tlsasPayload once Type has been checked.
	Service json.RawMessage `json:"Service"`
}
// tlsasPayload mirrors the JSON shape of svcs.TLSAs (services/tlsa.go): the
// service body carries its records as an array under the "tlsa" key.
type tlsasPayload struct {
	// Records holds the service's TLSA records in the order they appear in
	// the JSON array.
	Records []tlsaRecord `json:"tlsa"`
}
// tlsaRecord decodes one dns.TLSA as serialized by miekg/dns: the record
// header under "Hdr", then the RDATA fields under their Go field names.
// From the header only the owner Name is kept; it is how Collect learns which
// endpoint each record applies to.
type tlsaRecord struct {
	// Hdr.Name is the record owner, expected to look like
	// "_<port>._<proto>[.<base>]", possibly with a trailing dot.
	Hdr struct {
		Name string `json:"Name"`
	} `json:"Hdr"`
	// Usage, Selector and MatchingType are the three TLSA parameter fields
	// (RFC 6698 §2.1.1–2.1.3); Collect copies them through unchanged.
	Usage        uint8 `json:"Usage"`
	Selector     uint8 `json:"Selector"`
	MatchingType uint8 `json:"MatchingType"`
	// Certificate is the certificate association data as the hex string
	// miekg/dns emits. Collect still trims and lowercases it before use, so
	// its case as stored does not matter.
	Certificate string `json:"Certificate"`
}
// defaultSTARTTLS maps well-known ports whose protocol upgrades a plaintext
// session to TLS to the STARTTLS service name checker-tls expects. Collect
// looks it up by port alone. Any port absent from this table (465, 993, 443,
// …) is probed as direct TLS unless the user overrides it through the
// OptionSTARTTLS map.
var defaultSTARTTLS = map[uint16]string{
	// Mail transfer, submission and retrieval.
	25:  "smtp",
	587: "submission",
	110: "pop3",
	143: "imap",

	// Directory access.
	389: "ldap",

	// XMPP client-to-server and server-to-server.
	5222: "xmpp-client",
	5269: "xmpp-server",
}
// Collect walks the bound TLSAs service and groups its records by endpoint
// (port, proto, host). It returns a *DANEData with one TargetResult per
// endpoint, plus the records whose owner name could not be attributed to an
// endpoint.
//
// No TLSA matching happens here; that is the rule's job: it reads the TLS
// chain via obs.GetRelated on the next evaluation. Each target carries a
// tlscontract.Ref built from the same endpoint DiscoverEntries publishes, so
// the rule can find the checker-tls observation for it.
//
// DNS names are case-insensitive (RFC 4343). Owner names, the apex and the
// subdomain are therefore lowercased before matching and grouping:
// "_443._TCP.WWW" is accepted, and it lands in the same target as
// "_443._tcp.www". Invalid records still report the owner exactly as stored.
//
// Errors are returned when ctx is already done, when the service option is
// missing or undecodable, when it is not a serviceType service, or when its
// body cannot be decoded.
func (p *daneProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	svc, err := serviceFromOptions(opts)
	if err != nil {
		return nil, err
	}
	if svc.Type != serviceType {
		return nil, fmt.Errorf("service is %q, expected %q", svc.Type, serviceType)
	}
	var pl tlsasPayload
	if err := json.Unmarshal(svc.Service, &pl); err != nil {
		return nil, fmt.Errorf("decode TLSAs service: %w", err)
	}

	// Lowercase apex and subdomain so they agree with the lowercased owner
	// names below. This matters both for joinName's "already fully
	// qualified" suffix test and for grouping.
	apex, _ := sdk.GetOption[string](opts, OptionDomain)
	apex = strings.ToLower(strings.TrimSuffix(apex, "."))
	subdomain, _ := sdk.GetOption[string](opts, OptionSubdomain)
	subdomain = strings.ToLower(strings.TrimSuffix(subdomain, "."))

	// STARTTLS overrides: map of "port/proto" → service name. Decoding is
	// deliberately best effort: errors are ignored so a malformed override
	// never blocks collection.
	var starttlsOverride map[string]string
	if v, ok := opts[OptionSTARTTLS]; ok {
		raw, _ := json.Marshal(v)
		_ = json.Unmarshal(raw, &starttlsOverride)
	}

	// Group records by endpoint key.
	type key struct {
		Port  uint16
		Proto string // "tcp" or "udp"
		Base  string // base host, lowercase, fully-qualified without trailing dot
	}
	groups := map[key][]TLSARecord{}
	var invalid []InvalidRecord
	for _, r := range pl.Records {
		owner := strings.TrimSuffix(r.Hdr.Name, ".")
		// Match on the lowercased name (tlsaOwner only accepts lowercase
		// "tcp"/"udp"), but keep the original spelling for reporting.
		m := tlsaOwner.FindStringSubmatch(strings.ToLower(owner))
		if len(m) != 4 {
			invalid = append(invalid, InvalidRecord{
				Owner:  owner,
				Reason: "owner name does not match _<port>._<tcp|udp>[.<base>]",
			})
			continue
		}
		port64, err := strconv.ParseUint(m[1], 10, 16)
		if err != nil || port64 == 0 {
			invalid = append(invalid, InvalidRecord{
				Owner:  owner,
				Reason: fmt.Sprintf("port %q out of range (1-65535)", m[1]),
			})
			continue
		}
		// Resolve base relative to the apex: TLSA owners in the service
		// are typically stored relative to the service's subdomain
		// bucket. Fall back to the apex when unspecified.
		base := joinName(m[3], subdomain, apex)
		if base == "" {
			invalid = append(invalid, InvalidRecord{
				Owner:  owner,
				Reason: "could not resolve a host name (apex and subdomain both empty)",
			})
			continue
		}
		k := key{Port: uint16(port64), Proto: m[2], Base: base}
		groups[k] = append(groups[k], TLSARecord{
			Usage:        r.Usage,
			Selector:     r.Selector,
			MatchingType: r.MatchingType,
			Certificate:  strings.ToLower(strings.TrimSpace(r.Certificate)),
		})
	}

	// Deterministic output ordering keeps diffs quiet across runs.
	keys := make([]key, 0, len(groups))
	for k := range groups {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		if keys[i].Base != keys[j].Base {
			return keys[i].Base < keys[j].Base
		}
		if keys[i].Port != keys[j].Port {
			return keys[i].Port < keys[j].Port
		}
		return keys[i].Proto < keys[j].Proto
	})

	targets := make([]TargetResult, 0, len(keys))
	for _, k := range keys {
		// Port-based default first. An explicit user override wins even
		// when it is "", which turns a STARTTLS port into a direct-TLS probe.
		starttls := defaultSTARTTLS[k.Port]
		if v, ok := starttlsOverride[starttlsKey(k.Port, k.Proto)]; ok {
			starttls = v
		}
		// NOTE(review): Proto is not carried into the TLSEndpoint built by
		// endpointFromTarget, so _udp targets are published like TCP ones.
		// Confirm checker-tls handles (or should skip) them.
		t := TargetResult{
			Owner:    tlsaOwnerName(k.Port, k.Proto, k.Base),
			Host:     k.Base,
			Port:     k.Port,
			Proto:    k.Proto,
			STARTTLS: starttls,
			Records:  groups[k],
		}
		t.Ref = tlscontract.Ref(endpointFromTarget(t))
		targets = append(targets, t)
	}

	data := &DANEData{
		Targets:     targets,
		Invalid:     invalid,
		CollectedAt: time.Now().UTC(),
	}
	// Only a bool option value is honoured; anything else leaves
	// DNSSECValidated nil.
	if v, ok := opts[OptionDNSSECValidated]; ok {
		if b, ok := v.(bool); ok {
			data.DNSSECValidated = &b
		}
	}
	return data, nil
}
// endpointFromTarget converts a collected target into the TLSEndpoint that
// checker-tls probes. The host doubles as the SNI name, and the target's
// STARTTLS service name (empty for direct TLS) is passed through.
//
// When the endpoint uses STARTTLS, the upgrade is marked as mandatory. RFC
// 7672 §2.2: once a TLSA record exists for an SMTP service, the client MUST
// use STARTTLS. That is what makes DANE on port 25 defeat STARTTLS-stripping
// downgrades, so a TLSA-protected endpoint is never treated as opportunistic.
func endpointFromTarget(t TargetResult) tlscontract.TLSEndpoint {
	var ep tlscontract.TLSEndpoint
	ep.Host = t.Host
	ep.SNI = t.Host
	ep.Port = t.Port
	ep.STARTTLS = t.STARTTLS
	ep.RequireSTARTTLS = t.STARTTLS != ""
	return ep
}
// DiscoverEntries publishes one tls.endpoint.v1 discovery entry per collected
// target so checker-tls probes them in its next cycle. Implements
// sdk.DiscoveryPublisher.
//
// data is expected to be the *DANEData returned by Collect. Any other value,
// or a nil pointer, yields no entries and no error. Entries are built with
// endpointFromTarget, the same function Collect uses for TargetResult.Ref,
// so the published endpoint and the reference the rule looks up agree.
//
// If an entry cannot be built, nothing is published. The error is returned
// wrapped with the offending TLSA owner name, and stays inspectable via
// errors.Is/As.
func (p *daneProvider) DiscoverEntries(data any) ([]sdk.DiscoveryEntry, error) {
	d, ok := data.(*DANEData)
	if !ok || d == nil {
		return nil, nil
	}
	out := make([]sdk.DiscoveryEntry, 0, len(d.Targets))
	for _, t := range d.Targets {
		entry, err := tlscontract.NewEntry(endpointFromTarget(t))
		if err != nil {
			return nil, fmt.Errorf("build tls endpoint entry for %q: %w", t.Owner, err)
		}
		out = append(out, entry)
	}
	return out, nil
}
// serviceFromOptions pulls the happyDomain service payload out of the checker
// options. It round-trips the option value through JSON, so it works whether
// the value arrived as a decoded map, a raw message or a struct. It fails
// when the option is absent or when the value cannot be re-encoded or decoded
// into a serviceMessage.
func serviceFromOptions(opts sdk.CheckerOptions) (*serviceMessage, error) {
	value, present := opts[OptionService]
	if !present {
		return nil, fmt.Errorf("service option missing")
	}

	encoded, err := json.Marshal(value)
	if err != nil {
		return nil, fmt.Errorf("marshal service option: %w", err)
	}

	msg := new(serviceMessage)
	if err = json.Unmarshal(encoded, msg); err != nil {
		return nil, fmt.Errorf("decode service option: %w", err)
	}
	return msg, nil
}
// joinName resolves a possibly-relative TLSA base name against the service's
// subdomain bucket and the zone apex. It returns a host name without a
// trailing dot:
//
//   - empty base (after trimming one trailing dot): the subdomain under the
//     apex, or the apex itself when there is no subdomain;
//   - base equal to the apex, or ending in "."+apex: already fully
//     qualified, returned unchanged;
//   - any other base: relative, so the subdomain (if any) and the apex are
//     appended.
//
// The fully-qualified test is case-insensitive because DNS names are
// (RFC 4343); the spelling of the returned labels is preserved. With an empty
// apex no base counts as fully qualified, so the result is only as qualified
// as the inputs. The result is "" when base, subdomain and apex are all empty.
func joinName(base, subdomain, apex string) string {
	base = strings.TrimSuffix(base, ".")

	// No base of its own: the records sit at the service's own name.
	if base == "" {
		if subdomain != "" {
			return strings.TrimSuffix(subdomain+"."+apex, ".")
		}
		return apex
	}

	// Already fully qualified under the apex: keep as-is. The comparison
	// is made on lowercased copies so "www.Example.com" is recognised
	// under apex "example.com".
	if apex != "" {
		lowerBase, lowerApex := strings.ToLower(base), strings.ToLower(apex)
		if lowerBase == lowerApex || strings.HasSuffix(lowerBase, "."+lowerApex) {
			return base
		}
	}

	// Otherwise base is relative to the subdomain bucket (or apex).
	if subdomain != "" {
		return strings.TrimSuffix(base+"."+subdomain+"."+apex, ".")
	}
	if apex != "" {
		return base + "." + apex
	}
	return base
}