checker-dnssec/checker/collect.go

303 lines
8.2 KiB
Go

package checker
import (
"context"
"encoding/base64"
"fmt"
"strings"
"sync"
"time"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// Collect gathers the DNSSEC observations for the zone named by the
// "domain_name" option and returns them as a *DNSSECData.
//
// The optional "resolver" option selects the resolver handed to
// resolveAuthNS and hasParentDS; when it is empty, systemResolver() is used.
//
// An error is returned only when the domain name is missing or fails
// validateDomainName. A failure of resolveAuthNS is recorded in data.Errors
// and the partially filled result is returned with a nil error, so the caller
// still receives a report.
func (p *dnssecProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	// The second result of GetOption is ignored on purpose: an empty value is
	// rejected just below.
	domain, _ := sdk.GetOption[string](opts, "domain_name")
	domain = strings.TrimSuffix(strings.TrimSpace(domain), ".")
	if domain == "" {
		return nil, fmt.Errorf("missing 'domain_name' option")
	}
	if err := validateDomainName(domain); err != nil {
		return nil, err
	}
	zone := lowerFQDN(domain)

	// Same here: an empty resolver simply selects the system default.
	resolver, _ := sdk.GetOption[string](opts, "resolver")
	if resolver == "" {
		resolver = systemResolver()
	}

	data := &DNSSECData{
		Domain:      strings.TrimSuffix(zone, "."),
		CollectedAt: time.Now().UTC(),
		Servers:     map[string]PerServerView{},
	}

	hosts, addrs, nsErrors, err := resolveAuthNS(ctx, zone, resolver)
	if err != nil {
		data.Errors = append(data.Errors, err.Error())
		return data, nil
	}
	data.NameServers = hosts
	data.Errors = append(data.Errors, nsErrors...)
	data.HasDS = hasParentDS(ctx, zone, resolver)

	// Per-server collection runs in parallel. Each goroutine writes only to
	// its own slot of views, so no lock is needed; the result map is filled
	// sequentially once every goroutine has finished.
	views := make([]PerServerView, len(addrs))
	var wg sync.WaitGroup
	wg.Add(len(addrs))
	for i, addr := range addrs {
		// Pin the loop variables for the closure. Before Go 1.22 they are
		// shared across iterations, which would make every goroutine race
		// on the same i/addr; on newer toolchains this is a harmless no-op.
		i, addr := i, addr
		go func() {
			defer wg.Done()
			views[i] = collectFromServer(ctx, addr, zone)
		}()
	}
	wg.Wait()

	for _, v := range views {
		data.Servers[v.Server] = v
	}
	return data, nil
}
// collectFromServer runs the per-server probes against one authoritative
// server and returns everything it observed for zone.
//
// The probes are, in order: DNSKEY, SOA, NSEC3PARAM, an A query for a name
// built from randomLabel() under the zone (used to classify the denial of
// existence scheme), CDS and CDNSKEY. Each probe goes through authQuery,
// which records failures in the view instead of aborting, so a failing probe
// only leaves its part of the view empty.
func collectFromServer(ctx context.Context, server, zone string) PerServerView {
	view := PerServerView{Server: server}

	// DNSKEY RRset, its signatures, and the smallest TTL among the keys.
	if resp := authQuery(ctx, server, zone, dns.TypeDNSKEY, &view, true); resp != nil {
		// ttlSeen separates "no DNSKEY seen yet" from a genuine TTL of 0.
		// Using DNSKEYTTL == 0 as the sentinel would let a later, larger TTL
		// overwrite a real zero and report the wrong minimum.
		ttlSeen := false
		for _, rr := range resp.Answer {
			switch v := rr.(type) {
			case *dns.DNSKEY:
				view.DNSKEYs = append(view.DNSKEYs, DNSKEYRecord{
					Flags:     v.Flags,
					Protocol:  v.Protocol,
					Algorithm: v.Algorithm,
					PublicKey: v.PublicKey,
					KeyTag:    v.KeyTag(),
					KeySize:   estimateKeySize(v),
					IsKSK:     v.Flags&0x0001 != 0, // SEP bit
				})
				if !ttlSeen || v.Hdr.Ttl < view.DNSKEYTTL {
					view.DNSKEYTTL = v.Hdr.Ttl
					ttlSeen = true
				}
			case *dns.RRSIG:
				// Keep only signatures that cover the DNSKEY RRset.
				if v.TypeCovered == dns.TypeDNSKEY {
					view.DNSKEYRRSIGs = append(view.DNSKEYRRSIGs, rrsigOf(v))
				}
			}
		}
	}

	// SOA and its signatures. If several SOA records are present the last
	// one in the answer wins.
	if resp := authQuery(ctx, server, zone, dns.TypeSOA, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			switch v := rr.(type) {
			case *dns.SOA:
				view.SOA = &SOAObservation{
					Serial:  v.Serial,
					Minimum: v.Minttl,
					MName:   v.Ns,
					TTL:     v.Hdr.Ttl,
				}
			case *dns.RRSIG:
				if v.TypeCovered == dns.TypeSOA {
					view.SOARRSIGs = append(view.SOARRSIGs, rrsigOf(v))
				}
			}
		}
	}

	// NSEC3PARAM; the salt is lower-cased before it is stored.
	if resp := authQuery(ctx, server, zone, dns.TypeNSEC3PARAM, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			if v, ok := rr.(*dns.NSEC3PARAM); ok {
				view.NSEC3PARAM = &NSEC3ParamObservation{
					HashAlgorithm: v.Hash,
					Flags:         v.Flags,
					Iterations:    v.Iterations,
					SaltLength:    v.SaltLength,
					Salt:          strings.ToLower(v.Salt),
				}
			}
		}
	}

	// Negative-answer probe: the name is expected not to exist, so the
	// response should carry the zone's denial-of-existence records.
	probe := randomLabel() + "." + zone
	view.ProbeName = strings.TrimSuffix(probe, ".")
	if resp := authQuery(ctx, server, probe, dns.TypeA, &view, true); resp != nil {
		view.DenialKind, view.DenialRecords = classifyDenial(resp, view.NSEC3PARAM)
	} else if len(view.DNSKEYs) == 0 {
		// No probe answer and no keys were observed: record DenialNone.
		view.DenialKind = DenialNone
	}

	// CDS records; digests are lower-cased before they are stored.
	if resp := authQuery(ctx, server, zone, dns.TypeCDS, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			if v, ok := rr.(*dns.CDS); ok {
				view.CDS = append(view.CDS, DSRecord{
					KeyTag:     v.KeyTag,
					Algorithm:  v.Algorithm,
					DigestType: v.DigestType,
					Digest:     strings.ToLower(v.Digest),
				})
			}
		}
	}

	// CDNSKEY records; KeySize is left at its zero value for these.
	if resp := authQuery(ctx, server, zone, dns.TypeCDNSKEY, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			if v, ok := rr.(*dns.CDNSKEY); ok {
				view.CDNSKEY = append(view.CDNSKEY, DNSKEYRecord{
					Flags:     v.Flags,
					Protocol:  v.Protocol,
					Algorithm: v.Algorithm,
					PublicKey: v.PublicKey,
					KeyTag:    v.KeyTag(),
					IsKSK:     v.Flags&0x0001 != 0, // SEP bit
				})
			}
		}
	}

	return view
}
// authQuery asks server for (name, qtype) in class IN through dnsExchange and
// returns the response, or nil when the first exchange fails.
//
// The first exchange uses dnsExchange's default transport (empty network
// string); when that answer comes back truncated the query is repeated over
// "tcp". The dnssec argument is forwarded to dnsExchange unchanged
// (presumably it controls the DO bit, and the constant false the RD bit;
// confirm against dnsExchange).
//
// Failures never abort the caller: only the first failure of each kind is
// written to view.UDPError / view.TCPError. If the TCP retry fails, the
// truncated first answer is returned instead of nothing.
func authQuery(ctx context.Context, server, name string, qtype uint16, view *PerServerView, dnssec bool) *dns.Msg {
	question := dns.Question{
		Name:   dns.Fqdn(name),
		Qtype:  qtype,
		Qclass: dns.ClassINET,
	}
	typeName := dns.TypeToString[qtype]

	first, err := dnsExchange(ctx, "", server, question, false, dnssec)
	if err != nil {
		if view.UDPError == "" {
			view.UDPError = fmt.Sprintf("%s %s: %v", typeName, name, err)
		}
		return nil
	}

	// Complete (or absent) answer: nothing to retry.
	if first == nil || !first.Truncated {
		return first
	}

	retry, err := dnsExchange(ctx, "tcp", server, question, false, dnssec)
	if err == nil {
		return retry
	}
	if view.TCPError == "" {
		view.TCPError = fmt.Sprintf("%s %s (TCP): %v", typeName, name, err)
	}
	// A truncated answer is still more useful to the caller than none.
	return first
}
// classifyDenial inspects the Authority section of a negative response and
// maps it to NSEC / NSEC3 / OPT-OUT. NoData responses (NOERROR with NSEC
// proofs in Authority) are classified the same way: from the operator's POV,
// the negative-answer scheme is what matters.
//
// It returns the detected kind together with the textual form of every NSEC
// and NSEC3 record found. NSEC3 takes precedence when both record types are
// present. A nil message, or one without such records, yields DenialNone.
//
// Opt-out is read from the Flags field of the NSEC3 records themselves, which
// is where RFC 5155 defines the Opt-Out bit; in NSEC3PARAM the bit is unused
// and required to be zero. The NSEC3PARAM flags in nsec3p (which may be nil)
// are still honoured so that earlier classifications do not change.
func classifyDenial(r *dns.Msg, nsec3p *NSEC3ParamObservation) (DenialKind, []string) {
	// Opt-Out is the least significant bit of the NSEC3 Flags field.
	const optOutFlag = 0x01

	if r == nil {
		return DenialNone, nil
	}

	var dump []string
	hasNSEC, hasNSEC3, optOut := false, false, false
	for _, rr := range r.Ns {
		switch v := rr.(type) {
		case *dns.NSEC:
			hasNSEC = true
			dump = append(dump, v.String())
		case *dns.NSEC3:
			hasNSEC3 = true
			if v.Flags&optOutFlag != 0 {
				optOut = true
			}
			dump = append(dump, v.String())
		}
	}

	switch {
	case hasNSEC3:
		if optOut || (nsec3p != nil && nsec3p.Flags&optOutFlag != 0) {
			return DenialOptOut, dump
		}
		return DenialNSEC3, dump
	case hasNSEC:
		return DenialNSEC, dump
	default:
		return DenialNone, dump
	}
}
// rrsigOf copies the fields of an RRSIG record that the report keeps into an
// RRSIGObservation. The signature bytes and the record header are not copied.
func rrsigOf(v *dns.RRSIG) RRSIGObservation {
	var obs RRSIGObservation

	// What is signed, and with which algorithm and key.
	obs.TypeCovered = v.TypeCovered
	obs.Algorithm = v.Algorithm
	obs.KeyTag = v.KeyTag
	obs.SignerName = v.SignerName

	// Validity window and the remaining signature metadata.
	obs.Inception = v.Inception
	obs.Expiration = v.Expiration
	obs.Labels = v.Labels
	obs.OrigTTL = v.OrigTtl

	return obs
}
// estimateKeySize returns a size in bits for the public key in k.
//
// Fixed-size algorithms map to a constant: 256 for ECDSA P-256 and Ed25519,
// 384 for ECDSA P-384, 456 for Ed448. For the RSA family the modulus length
// is derived from the RFC 3110 key encoding. Best-effort: an unknown
// algorithm or an unparsable PublicKey yields 0, so rules that care about
// size can skip rather than mis-judge.
func estimateKeySize(k *dns.DNSKEY) int {
	switch k.Algorithm {
	case dns.ECDSAP256SHA256, dns.ED25519:
		return 256
	case dns.ECDSAP384SHA384:
		return 384
	case dns.ED448:
		return 456
	case dns.RSAMD5, dns.RSASHA1, dns.RSASHA1NSEC3SHA1, dns.RSASHA256, dns.RSASHA512:
		// RSA keys are parsed after the switch.
	default:
		return 0
	}

	key, err := base64.StdEncoding.DecodeString(k.PublicKey)
	if err != nil || len(key) < 3 {
		return 0
	}

	// RFC 3110 layout: exponent length, exponent, modulus. The length is a
	// single byte, or a zero byte followed by a 2-byte big-endian length.
	expLen, hdrLen := int(key[0]), 1
	if key[0] == 0 {
		expLen = int(key[1])<<8 | int(key[2])
		hdrLen = 3
	}

	// Everything after the exponent is the modulus; it must not be empty.
	modStart := hdrLen + expLen
	if modStart >= len(key) {
		return 0
	}
	return 8 * (len(key) - modStart)
}
// validateDomainName enforces RFC 1035 limits on a trimmed domain (no trailing
// dot): up to 253 octets total, each label 1..63 octets and made of ASCII
// letters, digits, hyphens or underscores (the latter is permitted to keep the
// checker usable on zones that publish _-prefixed labels such as _dmarc).
//
// It returns nil when d is acceptable and a descriptive error otherwise.
// Lengths are counted in octets, but characters are inspected rune by rune so
// that a non-ASCII character is reported as itself rather than as the first
// byte of its UTF-8 encoding.
func validateDomainName(d string) error {
	if len(d) > 253 {
		return fmt.Errorf("domain name too long (%d > 253 octets)", len(d))
	}
	for _, label := range strings.Split(d, ".") {
		// Empty labels cover "", leading dots and consecutive dots.
		if n := len(label); n == 0 || n > 63 {
			return fmt.Errorf("invalid label length in domain name")
		}
		for _, r := range label {
			switch {
			case r >= 'a' && r <= 'z',
				r >= 'A' && r <= 'Z',
				r >= '0' && r <= '9',
				r == '-', r == '_':
				// allowed
			default:
				return fmt.Errorf("invalid character %q in domain name", r)
			}
		}
	}
	return nil
}