303 lines
8.2 KiB
Go
303 lines
8.2 KiB
Go
package checker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
|
|
sdk "git.happydns.org/checker-sdk-go/checker"
|
|
)
|
|
|
|
func (p *dnssecProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
|
|
domain, _ := sdk.GetOption[string](opts, "domain_name")
|
|
domain = strings.TrimSuffix(strings.TrimSpace(domain), ".")
|
|
if domain == "" {
|
|
return nil, fmt.Errorf("missing 'domain_name' option")
|
|
}
|
|
if err := validateDomainName(domain); err != nil {
|
|
return nil, err
|
|
}
|
|
zone := lowerFQDN(domain)
|
|
|
|
resolver, _ := sdk.GetOption[string](opts, "resolver")
|
|
if resolver == "" {
|
|
resolver = systemResolver()
|
|
}
|
|
|
|
data := &DNSSECData{
|
|
Domain: strings.TrimSuffix(zone, "."),
|
|
CollectedAt: time.Now().UTC(),
|
|
Servers: map[string]PerServerView{},
|
|
}
|
|
|
|
hosts, addrs, nsErrors, err := resolveAuthNS(ctx, zone, resolver)
|
|
if err != nil {
|
|
data.Errors = append(data.Errors, err.Error())
|
|
return data, nil
|
|
}
|
|
data.NameServers = hosts
|
|
data.Errors = append(data.Errors, nsErrors...)
|
|
|
|
data.HasDS = hasParentDS(ctx, zone, resolver)
|
|
|
|
// Per-server collection runs in parallel; each goroutine writes to its
|
|
// own slot and a final pass copies it into the result map under the lock.
|
|
views := make([]PerServerView, len(addrs))
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(addrs))
|
|
for i, addr := range addrs {
|
|
go func() {
|
|
defer wg.Done()
|
|
views[i] = collectFromServer(ctx, addr, zone)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
for _, v := range views {
|
|
data.Servers[v.Server] = v
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
func collectFromServer(ctx context.Context, server, zone string) PerServerView {
|
|
view := PerServerView{Server: server}
|
|
|
|
dnskeyResp := authQuery(ctx, server, zone, dns.TypeDNSKEY, &view, true)
|
|
if dnskeyResp != nil {
|
|
for _, rr := range dnskeyResp.Answer {
|
|
switch v := rr.(type) {
|
|
case *dns.DNSKEY:
|
|
rec := DNSKEYRecord{
|
|
Flags: v.Flags,
|
|
Protocol: v.Protocol,
|
|
Algorithm: v.Algorithm,
|
|
PublicKey: v.PublicKey,
|
|
KeyTag: v.KeyTag(),
|
|
KeySize: estimateKeySize(v),
|
|
IsKSK: v.Flags&0x0001 != 0, // SEP bit
|
|
}
|
|
view.DNSKEYs = append(view.DNSKEYs, rec)
|
|
if view.DNSKEYTTL == 0 || v.Hdr.Ttl < view.DNSKEYTTL {
|
|
view.DNSKEYTTL = v.Hdr.Ttl
|
|
}
|
|
case *dns.RRSIG:
|
|
if v.TypeCovered == dns.TypeDNSKEY {
|
|
view.DNSKEYRRSIGs = append(view.DNSKEYRRSIGs, rrsigOf(v))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
soaResp := authQuery(ctx, server, zone, dns.TypeSOA, &view, true)
|
|
if soaResp != nil {
|
|
for _, rr := range soaResp.Answer {
|
|
switch v := rr.(type) {
|
|
case *dns.SOA:
|
|
view.SOA = &SOAObservation{
|
|
Serial: v.Serial,
|
|
Minimum: v.Minttl,
|
|
MName: v.Ns,
|
|
TTL: v.Hdr.Ttl,
|
|
}
|
|
case *dns.RRSIG:
|
|
if v.TypeCovered == dns.TypeSOA {
|
|
view.SOARRSIGs = append(view.SOARRSIGs, rrsigOf(v))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
nsec3pResp := authQuery(ctx, server, zone, dns.TypeNSEC3PARAM, &view, true)
|
|
if nsec3pResp != nil {
|
|
for _, rr := range nsec3pResp.Answer {
|
|
if v, ok := rr.(*dns.NSEC3PARAM); ok {
|
|
view.NSEC3PARAM = &NSEC3ParamObservation{
|
|
HashAlgorithm: v.Hash,
|
|
Flags: v.Flags,
|
|
Iterations: v.Iterations,
|
|
SaltLength: v.SaltLength,
|
|
Salt: strings.ToLower(v.Salt),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
probe := randomLabel() + "." + zone
|
|
view.ProbeName = strings.TrimSuffix(probe, ".")
|
|
if probeResp := authQuery(ctx, server, probe, dns.TypeA, &view, true); probeResp != nil {
|
|
view.DenialKind, view.DenialRecords = classifyDenial(probeResp, view.NSEC3PARAM)
|
|
} else if len(view.DNSKEYs) == 0 {
|
|
view.DenialKind = DenialNone
|
|
}
|
|
|
|
if cdsResp := authQuery(ctx, server, zone, dns.TypeCDS, &view, true); cdsResp != nil {
|
|
for _, rr := range cdsResp.Answer {
|
|
if v, ok := rr.(*dns.CDS); ok {
|
|
view.CDS = append(view.CDS, DSRecord{
|
|
KeyTag: v.KeyTag,
|
|
Algorithm: v.Algorithm,
|
|
DigestType: v.DigestType,
|
|
Digest: strings.ToLower(v.Digest),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
if cdkResp := authQuery(ctx, server, zone, dns.TypeCDNSKEY, &view, true); cdkResp != nil {
|
|
for _, rr := range cdkResp.Answer {
|
|
if v, ok := rr.(*dns.CDNSKEY); ok {
|
|
view.CDNSKEY = append(view.CDNSKEY, DNSKEYRecord{
|
|
Flags: v.Flags,
|
|
Protocol: v.Protocol,
|
|
Algorithm: v.Algorithm,
|
|
PublicKey: v.PublicKey,
|
|
KeyTag: v.KeyTag(),
|
|
IsKSK: v.Flags&0x0001 != 0,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return view
|
|
}
|
|
|
|
// authQuery sends q to the auth server with DO=1 and RD=0, retries over TCP
|
|
// on truncation, and records the first error in the per-server view so the
|
|
// report can show which probes failed without aborting the rest.
|
|
func authQuery(ctx context.Context, server, name string, qtype uint16, view *PerServerView, dnssec bool) *dns.Msg {
|
|
q := dns.Question{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET}
|
|
r, err := dnsExchange(ctx, "", server, q, false, dnssec)
|
|
if err != nil {
|
|
if view.UDPError == "" {
|
|
view.UDPError = fmt.Sprintf("%s %s: %v", dns.TypeToString[qtype], name, err)
|
|
}
|
|
return nil
|
|
}
|
|
if r != nil && r.Truncated {
|
|
r2, err2 := dnsExchange(ctx, "tcp", server, q, false, dnssec)
|
|
if err2 != nil {
|
|
if view.TCPError == "" {
|
|
view.TCPError = fmt.Sprintf("%s %s (TCP): %v", dns.TypeToString[qtype], name, err2)
|
|
}
|
|
return r // fall back to the truncated answer rather than nothing
|
|
}
|
|
return r2
|
|
}
|
|
return r
|
|
}
|
|
|
|
// classifyDenial inspects the Authority section of a NXDOMAIN-ish response
|
|
// and maps it to NSEC / NSEC3 / OPT-OUT. NoData responses (NOERROR with NSEC
|
|
// proofs in Authority) are classified the same way: from the operator's POV,
|
|
// the negative-answer scheme is what matters.
|
|
func classifyDenial(r *dns.Msg, nsec3p *NSEC3ParamObservation) (DenialKind, []string) {
|
|
var dump []string
|
|
hasNSEC, hasNSEC3 := false, false
|
|
for _, rr := range r.Ns {
|
|
switch rr.(type) {
|
|
case *dns.NSEC:
|
|
hasNSEC = true
|
|
dump = append(dump, rr.String())
|
|
case *dns.NSEC3:
|
|
hasNSEC3 = true
|
|
dump = append(dump, rr.String())
|
|
}
|
|
}
|
|
switch {
|
|
case hasNSEC3:
|
|
if nsec3p != nil && nsec3p.Flags&0x01 != 0 {
|
|
return DenialOptOut, dump
|
|
}
|
|
return DenialNSEC3, dump
|
|
case hasNSEC:
|
|
return DenialNSEC, dump
|
|
default:
|
|
return DenialNone, dump
|
|
}
|
|
}
|
|
|
|
func rrsigOf(v *dns.RRSIG) RRSIGObservation {
|
|
return RRSIGObservation{
|
|
TypeCovered: v.TypeCovered,
|
|
Algorithm: v.Algorithm,
|
|
Labels: v.Labels,
|
|
OrigTTL: v.OrigTtl,
|
|
Inception: v.Inception,
|
|
Expiration: v.Expiration,
|
|
KeyTag: v.KeyTag,
|
|
SignerName: v.SignerName,
|
|
}
|
|
}
|
|
|
|
// estimateKeySize returns the modulus size in bits for RSA-family keys and
|
|
// the curve size for ECDSA / EdDSA. Best-effort: an unparsable PublicKey
|
|
// yields 0 so rules that care about size can skip rather than mis-judge.
|
|
func estimateKeySize(k *dns.DNSKEY) int {
|
|
switch k.Algorithm {
|
|
case dns.RSAMD5, dns.RSASHA1, dns.RSASHA1NSEC3SHA1, dns.RSASHA256, dns.RSASHA512:
|
|
raw, err := base64.StdEncoding.DecodeString(k.PublicKey)
|
|
if err != nil || len(raw) < 3 {
|
|
return 0
|
|
}
|
|
// RFC 3110: 1-byte exponent length OR 1-byte 0 + 2-byte length, then
|
|
// the exponent, then the modulus. We only need the modulus length.
|
|
var explen int
|
|
var off int
|
|
if raw[0] == 0 {
|
|
if len(raw) < 3 {
|
|
return 0
|
|
}
|
|
explen = int(raw[1])<<8 | int(raw[2])
|
|
off = 3
|
|
} else {
|
|
explen = int(raw[0])
|
|
off = 1
|
|
}
|
|
modOff := off + explen
|
|
if modOff >= len(raw) {
|
|
return 0
|
|
}
|
|
return (len(raw) - modOff) * 8
|
|
case dns.ECDSAP256SHA256:
|
|
return 256
|
|
case dns.ECDSAP384SHA384:
|
|
return 384
|
|
case dns.ED25519:
|
|
return 256
|
|
case dns.ED448:
|
|
return 456
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// validateDomainName enforces RFC 1035 limits on a trimmed domain (no
// trailing dot):
//   - at most 253 octets in total;
//   - each dot-separated label 1..63 octets long;
//   - labels made only of ASCII letters, digits, hyphens or underscores.
//
// Underscores are accepted to keep the checker usable on names with
// _-prefixed labels such as _dmarc. Hyphen placement is not checked.
//
// Non-ASCII input is rejected, so internationalized names must be given in
// their ASCII ("xn--") form. The offending character is reported as a whole
// rune: "café" is flagged as 'é', not as the first byte of its UTF-8 encoding.
func validateDomainName(d string) error {
	if len(d) > 253 {
		return fmt.Errorf("domain name too long (%d > 253 octets)", len(d))
	}
	for _, label := range strings.Split(d, ".") {
		// Length limits are in octets (bytes), as on the wire.
		if l := len(label); l == 0 || l > 63 {
			return fmt.Errorf("invalid label length in domain name")
		}
		// Range over runes, not bytes, so a multi-byte character is
		// reported as itself in the error message.
		for _, c := range label {
			switch {
			case c >= 'a' && c <= 'z',
				c >= 'A' && c <= 'Z',
				c >= '0' && c <= '9',
				c == '-', c == '_':
			default:
				return fmt.Errorf("invalid character %q in domain name", c)
			}
		}
	}
	return nil
}
|
|
|