package checker

import (
	"context"
	"encoding/base64"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"

	sdk "git.happydns.org/checker-sdk-go/checker"
)

// Collect gathers DNSSEC observations for the zone named by the "domain_name"
// option.
//
// The domain is trimmed, validated with validateDomainName and turned into an
// FQDN via lowerFQDN. The zone's authoritative servers are obtained from
// resolveAuthNS using the resolver from the "resolver" option, or
// systemResolver() when that option is empty. HasDS is filled from
// hasParentDS. Every returned address is then probed in parallel by
// collectFromServer.
//
// A missing or malformed domain name is returned as an error. Failures after
// that point do not produce an error return: they are recorded in
// DNSSECData.Errors or in the per-server UDPError/TCPError fields, so a
// partial report can still be built.
func (p *dnssecProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	domain, _ := sdk.GetOption[string](opts, "domain_name")
	domain = strings.TrimSuffix(strings.TrimSpace(domain), ".")
	if domain == "" {
		return nil, fmt.Errorf("missing 'domain_name' option")
	}
	if err := validateDomainName(domain); err != nil {
		return nil, err
	}
	zone := lowerFQDN(domain)

	resolver, _ := sdk.GetOption[string](opts, "resolver")
	if resolver == "" {
		resolver = systemResolver()
	}

	data := &DNSSECData{
		Domain:      strings.TrimSuffix(zone, "."),
		CollectedAt: time.Now().UTC(),
		Servers:     map[string]PerServerView{},
	}

	hosts, addrs, nsErrors, err := resolveAuthNS(ctx, zone, resolver)
	if err != nil {
		// Without a server list there is nothing to probe; report the failure
		// inside the data rather than as a hard error.
		data.Errors = append(data.Errors, err.Error())
		return data, nil
	}
	data.NameServers = hosts
	data.Errors = append(data.Errors, nsErrors...)
	data.HasDS = hasParentDS(ctx, zone, resolver)

	// Per-server collection runs in parallel. Each goroutine writes only to
	// its own slot in views, so no lock is needed; the result map is filled
	// from this goroutine after wg.Wait(). i and addr are passed as arguments
	// so each goroutine gets its own copy regardless of the module's Go
	// version (per-iteration loop variables only exist from Go 1.22 on).
	views := make([]PerServerView, len(addrs))
	var wg sync.WaitGroup
	wg.Add(len(addrs))
	for i, addr := range addrs {
		go func(i int, addr string) {
			defer wg.Done()
			views[i] = collectFromServer(ctx, addr, zone)
		}(i, addr)
	}
	wg.Wait()

	for _, v := range views {
		data.Servers[v.Server] = v
	}
	return data, nil
}

// collectFromServer queries one authoritative server for the zone's DNSSEC
// material and returns what it observed.
//
// Queries are issued in a fixed order: DNSKEY, SOA, NSEC3PARAM, a random probe
// name, CDS, CDNSKEY. A failed query does not stop the remaining ones;
// authQuery records only the first UDP and the first TCP failure in the view,
// so this order decides which failure is reported.
func collectFromServer(ctx context.Context, server, zone string) PerServerView {
	view := PerServerView{Server: server}

	// DNSKEY RRset and the RRSIGs covering it. DNSKEYTTL keeps the smallest
	// TTL seen (0 is treated as "not yet set").
	if resp := authQuery(ctx, server, zone, dns.TypeDNSKEY, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			switch v := rr.(type) {
			case *dns.DNSKEY:
				view.DNSKEYs = append(view.DNSKEYs, dnskeyRecordOf(v))
				if view.DNSKEYTTL == 0 || v.Hdr.Ttl < view.DNSKEYTTL {
					view.DNSKEYTTL = v.Hdr.Ttl
				}
			case *dns.RRSIG:
				if v.TypeCovered == dns.TypeDNSKEY {
					view.DNSKEYRRSIGs = append(view.DNSKEYRRSIGs, rrsigOf(v))
				}
			}
		}
	}

	// SOA record and the RRSIGs covering it.
	if resp := authQuery(ctx, server, zone, dns.TypeSOA, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			switch v := rr.(type) {
			case *dns.SOA:
				view.SOA = &SOAObservation{
					Serial:  v.Serial,
					Minimum: v.Minttl,
					MName:   v.Ns,
					TTL:     v.Hdr.Ttl,
				}
			case *dns.RRSIG:
				if v.TypeCovered == dns.TypeSOA {
					view.SOARRSIGs = append(view.SOARRSIGs, rrsigOf(v))
				}
			}
		}
	}

	// NSEC3PARAM; if several are returned, the last one wins.
	if resp := authQuery(ctx, server, zone, dns.TypeNSEC3PARAM, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			if v, ok := rr.(*dns.NSEC3PARAM); ok {
				view.NSEC3PARAM = &NSEC3ParamObservation{
					HashAlgorithm: v.Hash,
					Flags:         v.Flags,
					Iterations:    v.Iterations,
					SaltLength:    v.SaltLength,
					Salt:          strings.ToLower(v.Salt),
				}
			}
		}
	}

	// Query an A record under a random label (presumably non-existent; depends
	// on randomLabel) to observe the zone's denial-of-existence scheme.
	probe := randomLabel() + "." + zone
	view.ProbeName = strings.TrimSuffix(probe, ".")
	if resp := authQuery(ctx, server, probe, dns.TypeA, &view, true); resp != nil {
		view.DenialKind, view.DenialRecords = classifyDenial(resp, view.NSEC3PARAM)
	} else if len(view.DNSKEYs) == 0 {
		view.DenialKind = DenialNone
	}

	// Child-side DS publication (CDS).
	if resp := authQuery(ctx, server, zone, dns.TypeCDS, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			if v, ok := rr.(*dns.CDS); ok {
				view.CDS = append(view.CDS, DSRecord{
					KeyTag:     v.KeyTag,
					Algorithm:  v.Algorithm,
					DigestType: v.DigestType,
					Digest:     strings.ToLower(v.Digest),
				})
			}
		}
	}

	// Child-side DNSKEY publication (CDNSKEY). Built with the same helper as
	// DNSKEY records, so KeySize is populated too.
	if resp := authQuery(ctx, server, zone, dns.TypeCDNSKEY, &view, true); resp != nil {
		for _, rr := range resp.Answer {
			if v, ok := rr.(*dns.CDNSKEY); ok {
				view.CDNSKEY = append(view.CDNSKEY, dnskeyRecordOf(&v.DNSKEY))
			}
		}
	}

	return view
}

// dnskeyRecordOf converts a DNSKEY (or the DNSKEY embedded in a CDNSKEY) into
// the report's DNSKEYRecord. IsKSK reflects the SEP flag bit (value 0x0001).
func dnskeyRecordOf(k *dns.DNSKEY) DNSKEYRecord {
	return DNSKEYRecord{
		Flags:     k.Flags,
		Protocol:  k.Protocol,
		Algorithm: k.Algorithm,
		PublicKey: k.PublicKey,
		KeyTag:    k.KeyTag(),
		KeySize:   estimateKeySize(k),
		IsKSK:     k.Flags&0x0001 != 0, // SEP bit
	}
}

// authQuery sends a single question for (name, qtype, IN) to server.
//
// The query is sent with RD=0 and, when dnssec is true, DO=1. It goes out via
// dnsExchange's default transport first and is retried over TCP when the
// answer is truncated. On error, the first UDP error and the first TCP error
// are recorded in view, so the report can show which probes failed without
// aborting the rest.
//
// It returns nil when the initial exchange fails. If only the TCP retry
// fails, it returns the truncated answer rather than nothing.
func authQuery(ctx context.Context, server, name string, qtype uint16, view *PerServerView, dnssec bool) *dns.Msg {
	q := dns.Question{Name: dns.Fqdn(name), Qtype: qtype, Qclass: dns.ClassINET}

	r, err := dnsExchange(ctx, "", server, q, false, dnssec)
	if err != nil {
		if view.UDPError == "" {
			view.UDPError = fmt.Sprintf("%s %s: %v", dns.TypeToString[qtype], name, err)
		}
		return nil
	}
	if r == nil || !r.Truncated {
		return r
	}

	r2, err2 := dnsExchange(ctx, "tcp", server, q, false, dnssec)
	if err2 != nil {
		if view.TCPError == "" {
			view.TCPError = fmt.Sprintf("%s %s (TCP): %v", dns.TypeToString[qtype], name, err2)
		}
		return r // fall back to the truncated answer rather than nothing
	}
	return r2
}

// classifyDenial inspects the Authority section of a negative response
// (NXDOMAIN, or NOERROR/NoData) and maps it to NSEC, NSEC3 or NSEC3 opt-out.
// NoData and NXDOMAIN proofs are classified the same way: from the operator's
// point of view, only the negative-answer scheme matters.
//
// It also returns the String() form of every NSEC/NSEC3 record found.
// DenialNone is returned when the response carries no such records.
//
// Opt-out is read from the NSEC3 records themselves: bit 0 of their Flags
// field is the Opt-Out flag (RFC 5155 §3.1.2.1). NSEC3PARAM Flags must be
// zero on compliant zones (RFC 5155 §4.1.2). nsec3p's opt-out bit is
// therefore only a secondary signal, kept so zones previously reported as
// opt-out still are.
func classifyDenial(r *dns.Msg, nsec3p *NSEC3ParamObservation) (DenialKind, []string) {
	var dump []string
	hasNSEC, hasNSEC3, optOut := false, false, false
	for _, rr := range r.Ns {
		switch v := rr.(type) {
		case *dns.NSEC:
			hasNSEC = true
			dump = append(dump, v.String())
		case *dns.NSEC3:
			hasNSEC3 = true
			if v.Flags&0x01 != 0 {
				optOut = true
			}
			dump = append(dump, v.String())
		}
	}
	if hasNSEC3 && !optOut && nsec3p != nil && nsec3p.Flags&0x01 != 0 {
		optOut = true
	}

	switch {
	case hasNSEC3 && optOut:
		return DenialOptOut, dump
	case hasNSEC3:
		return DenialNSEC3, dump
	case hasNSEC:
		return DenialNSEC, dump
	default:
		return DenialNone, dump
	}
}

// rrsigOf copies the fields of an RRSIG that the report keeps (everything
// except the signature itself and the RR header).
func rrsigOf(v *dns.RRSIG) RRSIGObservation {
	return RRSIGObservation{
		TypeCovered: v.TypeCovered,
		Algorithm:   v.Algorithm,
		Labels:      v.Labels,
		OrigTTL:     v.OrigTtl,
		Inception:   v.Inception,
		Expiration:  v.Expiration,
		KeyTag:      v.KeyTag,
		SignerName:  v.SignerName,
	}
}

// estimateKeySize returns the key size in bits.
//
// For RSA-family algorithms it is the modulus length, computed as the byte
// length times 8. ECDSA P-256/P-384 return their curve size. ED25519 and ED448
// return their public-key length in bits (256 and 456).
//
// It is best-effort: an unparsable RSA PublicKey, or an algorithm not listed
// here, yields 0 so rules that care about size can skip rather than misjudge.
func estimateKeySize(k *dns.DNSKEY) int {
	switch k.Algorithm {
	case dns.RSAMD5, dns.RSASHA1, dns.RSASHA1NSEC3SHA1, dns.RSASHA256, dns.RSASHA512:
		raw, err := base64.StdEncoding.DecodeString(k.PublicKey)
		if err != nil || len(raw) < 3 {
			return 0
		}
		// RFC 3110 §2: a 1-octet exponent length, or a 0 octet followed by a
		// 2-octet exponent length; then the exponent; then the modulus. Only
		// the modulus length is needed. len(raw) >= 3 is guaranteed above,
		// so raw[1] and raw[2] are safe to read.
		off, explen := 1, int(raw[0])
		if raw[0] == 0 {
			off, explen = 3, int(raw[1])<<8|int(raw[2])
		}
		modOff := off + explen
		if modOff >= len(raw) {
			return 0
		}
		return (len(raw) - modOff) * 8
	case dns.ECDSAP256SHA256:
		return 256
	case dns.ECDSAP384SHA384:
		return 384
	case dns.ED25519:
		return 256
	case dns.ED448:
		return 456
	}
	return 0
}

// validateDomainName enforces RFC 1035-style limits on a trimmed domain (no
// trailing dot): at most 253 octets in total, and each label 1..63 octets
// long.
//
// Labels may contain ASCII letters, digits, hyphens or underscores. The
// underscore is allowed so the checker stays usable on zones with
// _-prefixed labels such as _dmarc. The position of hyphens within a label
// is not checked.
func validateDomainName(d string) error {
	if len(d) > 253 {
		return fmt.Errorf("domain name too long (%d > 253 octets)", len(d))
	}
	for _, label := range strings.Split(d, ".") {
		if l := len(label); l == 0 || l > 63 {
			return fmt.Errorf("invalid label length in domain name")
		}
		for i := 0; i < len(label); i++ {
			c := label[i]
			switch {
			case c >= 'a' && c <= 'z':
			case c >= 'A' && c <= 'Z':
			case c >= '0' && c <= '9':
			case c == '-' || c == '_':
			default:
				return fmt.Errorf("invalid character %q in domain name", c)
			}
		}
	}
	return nil
}