checker-authoritative-consi.../checker/rules_consistency.go

291 lines
8.9 KiB
Go

package checker
import (
"context"
"fmt"
"sort"
"strings"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
)
// serialConsistencyRule checks that all authoritative name servers agree on
// the SOA serial they serve for the zone.
type serialConsistencyRule struct{}

// Name returns the identifier of this rule.
func (r *serialConsistencyRule) Name() string {
	return "authoritative_consistency.serial_consistency"
}

// Description returns a human-readable summary of what the rule verifies.
func (r *serialConsistencyRule) Description() string {
	return "Verifies that every authoritative name server returns the same SOA serial (detects incomplete zone transfer)."
}

// Evaluate loads the observation and reports SOA serial drift between the
// authoritative servers. It returns the loading error state if the observation
// cannot be loaded, a "not tested" state when the zone has no SOA, a pass
// state when collectSerialDrift finds nothing, and otherwise the findings
// converted through findingsToStates.
func (r *serialConsistencyRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter, _ sdk.CheckerOptions) []sdk.CheckState {
	data, loadErr := loadObservation(ctx, obs)
	switch {
	case loadErr != nil:
		return []sdk.CheckState{*loadErr}
	case !data.HasSOA:
		return []sdk.CheckState{notTestedState("authoritative_consistency.serial_consistency.skipped", "Zone does not declare a SOA record.")}
	}
	if drift := collectSerialDrift(data); len(drift) > 0 {
		return findingsToStates(drift)
	}
	return []sdk.CheckState{passState("authoritative_consistency.serial_consistency.ok", "Every authoritative name server returns the same SOA serial.")}
}
// collectSerialDrift groups the probed servers by the SOA serial they
// returned. Servers with no result, non-authoritative answers and answers
// without a SOA are ignored. When more than one distinct serial is seen, it
// returns a single critical CodeSerialDrift finding that lists each serial in
// ascending numeric order, with its servers sorted by name. Otherwise it
// returns nil.
func collectSerialDrift(data *ObservationData) []Finding {
	serversBySerial := make(map[uint32][]string)
	for _, server := range data.Probed {
		res := data.Results[server]
		if res != nil && res.Authoritative && res.SOA != nil {
			serversBySerial[res.Serial] = append(serversBySerial[res.Serial], server)
		}
	}
	if len(serversBySerial) <= 1 {
		return nil
	}

	ordered := make([]uint32, 0, len(serversBySerial))
	for serial := range serversBySerial {
		ordered = append(ordered, serial)
	}
	sort.Slice(ordered, func(a, b int) bool { return ordered[a] < ordered[b] })

	parts := make([]string, len(ordered))
	for i, serial := range ordered {
		names := serversBySerial[serial]
		sort.Strings(names)
		parts[i] = fmt.Sprintf("serial %d: %s", serial, strings.Join(names, ", "))
	}

	return []Finding{{
		Code:     CodeSerialDrift,
		Severity: SeverityCrit,
		Message:  "SOA serial drift between authoritative servers: " + strings.Join(parts, "; "),
	}}
}
// serialVsSavedRule compares the SOA serial served live by the authoritative
// servers with the serial saved in happyDomain.
type serialVsSavedRule struct{}

// Name returns the identifier of this rule.
func (r *serialVsSavedRule) Name() string {
	return "authoritative_consistency.serial_vs_saved"
}

// Description returns a human-readable summary of what the rule verifies.
func (r *serialVsSavedRule) Description() string {
	return "Compares the live SOA serial with the one saved in happyDomain (detects un-pushed edits and out-of-band changes)."
}

// Evaluate loads the observation and compares the live serials against
// data.DeclaredSerial. It is skipped when the zone has no SOA or no saved
// serial (zero). The "warnOnStaleSaved" option, read with true as its
// fallback, decides whether servers behind the saved serial are reported.
func (r *serialVsSavedRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter, opts sdk.CheckerOptions) []sdk.CheckState {
	data, loadErr := loadObservation(ctx, obs)
	if loadErr != nil {
		return []sdk.CheckState{*loadErr}
	}
	if nothingSaved := !data.HasSOA || data.DeclaredSerial == 0; nothingSaved {
		return []sdk.CheckState{notTestedState("authoritative_consistency.serial_vs_saved.skipped", "No saved serial to compare against.")}
	}

	findings := collectSerialVsSaved(data, sdk.GetBoolOption(opts, "warnOnStaleSaved", true))
	if len(findings) != 0 {
		return findingsToStates(findings)
	}
	okMsg := fmt.Sprintf("Live serials match the saved value %d.", data.DeclaredSerial)
	return []sdk.CheckState{passState("authoritative_consistency.serial_vs_saved.ok", okMsg)}
}
// collectSerialVsSaved compares the SOA serial of every authoritative server
// that answered with a SOA against the serial saved in happyDomain
// (data.DeclaredSerial), using serialLess for the ordering. It returns, in
// this order:
//   - a warning-level CodeSerialStaleVsSaved finding listing the servers
//     behind the saved serial, only when warn is true;
//   - an info-level CodeSerialAheadOfSaved finding listing the servers ahead
//     of it.
//
// Servers for which serialLess holds in neither direction are not reported.
// A zero saved serial means there is nothing to compare, and nil is returned.
func collectSerialVsSaved(data *ObservationData, warn bool) []Finding {
	saved := data.DeclaredSerial
	if saved == 0 {
		return nil
	}

	var behind, ahead []string
	for _, server := range data.Probed {
		res := data.Results[server]
		if res == nil || !res.Authoritative || res.SOA == nil {
			continue
		}
		if serialLess(res.Serial, saved) {
			behind = append(behind, server)
		} else if serialLess(saved, res.Serial) {
			ahead = append(ahead, server)
		}
	}

	var out []Finding
	if warn && len(behind) != 0 {
		sort.Strings(behind)
		msg := fmt.Sprintf(
			"saved serial %d is newer than live serial on %s; changes have not propagated yet or have not been applied to the provider",
			saved, strings.Join(behind, ", "),
		)
		out = append(out, Finding{Code: CodeSerialStaleVsSaved, Severity: SeverityWarn, Message: msg})
	}
	if len(ahead) != 0 {
		sort.Strings(ahead)
		msg := fmt.Sprintf(
			"live serial on %s is ahead of the saved serial %d; the zone was modified outside happyDomain",
			strings.Join(ahead, ", "), saved,
		)
		out = append(out, Finding{Code: CodeSerialAheadOfSaved, Severity: SeverityInfo, Message: msg})
	}
	return out
}
// soaFieldsConsistencyRule checks that the authoritative name servers serve
// the same SOA RDATA (MNAME, RNAME and the timer fields).
type soaFieldsConsistencyRule struct{}

// Name returns the identifier of this rule.
func (r *soaFieldsConsistencyRule) Name() string { return "authoritative_consistency.soa_fields_consistency" }

// Description returns a human-readable summary of what the rule verifies.
func (r *soaFieldsConsistencyRule) Description() string {
	return "Verifies that every authoritative name server returns the same SOA RDATA (MNAME, RNAME, refresh, retry, expire, minimum)."
}

// Evaluate loads the observation and reports SOA RDATA differences found by
// collectSOAFieldsDrift. It is skipped when the zone declares no SOA.
func (r *soaFieldsConsistencyRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter, _ sdk.CheckerOptions) []sdk.CheckState {
	data, loadErr := loadObservation(ctx, obs)
	if loadErr != nil {
		return []sdk.CheckState{*loadErr}
	}
	if !data.HasSOA {
		const skipMsg = "Zone does not declare a SOA record."
		return []sdk.CheckState{notTestedState("authoritative_consistency.soa_fields_consistency.skipped", skipMsg)}
	}

	drift := collectSOAFieldsDrift(data)
	if len(drift) > 0 {
		return findingsToStates(drift)
	}
	const okMsg = "Every authoritative name server returns the same SOA RDATA."
	return []sdk.CheckState{passState("authoritative_consistency.soa_fields_consistency.ok", okMsg)}
}
// collectSOAFieldsDrift groups the authoritative servers that returned a SOA
// by the SOA RDATA they served. The serial is ignored here; serial drift is
// reported by collectSerialDrift. MNAME and RNAME are compared lower-cased and
// without a trailing dot.
//
// When more than one distinct RDATA is observed, it returns a single
// warning-level CodeSOAFieldsDrift finding that lists every variant together
// with the servers serving it. Otherwise it returns nil.
//
// Variants are listed by decreasing number of servers. Ties are broken by the
// alphabetically first server of each variant, so the message is the same
// from one run to the next. Map iteration order and sort.Slice are not
// deterministic on their own.
func collectSOAFieldsDrift(data *ObservationData) []Finding {
	// soaSig is the comparable part of a SOA RDATA, serial excluded.
	type soaSig struct {
		mname, rname                    string
		refresh, retry, expire, minimum uint32
	}
	sig := func(s *dns.SOA) soaSig {
		return soaSig{
			mname:   strings.ToLower(strings.TrimSuffix(s.Ns, ".")),
			rname:   strings.ToLower(strings.TrimSuffix(s.Mbox, ".")),
			refresh: s.Refresh,
			retry:   s.Retry,
			expire:  s.Expire,
			minimum: s.Minttl,
		}
	}

	groups := map[soaSig][]string{}
	for _, ns := range data.Probed {
		r := data.Results[ns]
		// Only authoritative answers are compared, as in the other consistency
		// collectors: the finding is about what the authoritative servers serve.
		if r == nil || !r.Authoritative || r.SOA == nil {
			continue
		}
		k := sig(r.SOA)
		groups[k] = append(groups[k], ns)
	}
	if len(groups) < 2 {
		return nil
	}

	type variant struct {
		sig     soaSig
		servers []string // sorted by name
	}
	variants := make([]variant, 0, len(groups))
	for k, srv := range groups {
		sort.Strings(srv)
		variants = append(variants, variant{sig: k, servers: srv})
	}
	sort.Slice(variants, func(i, j int) bool {
		a, b := variants[i], variants[j]
		if len(a.servers) != len(b.servers) {
			return len(a.servers) > len(b.servers)
		}
		return a.servers[0] < b.servers[0]
	})

	lines := make([]string, 0, len(variants))
	for _, v := range variants {
		lines = append(lines, fmt.Sprintf(
			"mname=%s rname=%s refresh=%d retry=%d expire=%d minimum=%d → %s",
			v.sig.mname, v.sig.rname, v.sig.refresh, v.sig.retry, v.sig.expire, v.sig.minimum,
			strings.Join(v.servers, ", "),
		))
	}
	return []Finding{{
		Code:     CodeSOAFieldsDrift,
		Severity: SeverityWarn,
		Message:  "SOA fields differ between authoritative servers: " + strings.Join(lines, "; "),
	}}
}
// nsRRsetConsistencyRule checks that the authoritative name servers agree on
// the NS RRset and that it matches the NS declared in the service.
type nsRRsetConsistencyRule struct{}

// Name returns the identifier of this rule.
func (r *nsRRsetConsistencyRule) Name() string { return "authoritative_consistency.ns_rrset_consistency" }

// Description returns a human-readable summary of what the rule verifies.
func (r *nsRRsetConsistencyRule) Description() string {
	return "Verifies every authoritative name server returns the same NS RRset, and that this RRset matches the NS declared in the service."
}

// Evaluate loads the observation and reports the findings of
// collectNSRRsetDrift. Unlike the SOA-based rules, it does not depend on the
// zone declaring a SOA.
func (r *nsRRsetConsistencyRule) Evaluate(ctx context.Context, obs sdk.ObservationGetter, _ sdk.CheckerOptions) []sdk.CheckState {
	data, loadErr := loadObservation(ctx, obs)
	if loadErr != nil {
		return []sdk.CheckState{*loadErr}
	}
	if drift := collectNSRRsetDrift(data); len(drift) != 0 {
		return findingsToStates(drift)
	}
	return []sdk.CheckState{passState("authoritative_consistency.ns_rrset_consistency.ok", "NS RRset is consistent across authoritative servers and matches the declared list.")}
}
// collectNSRRsetDrift compares the NS RRset returned by each authoritative
// server. Servers with no result, non-authoritative answers and empty RRsets
// are ignored. It can produce up to two warning-level findings:
//   - CodeNSRRsetDrift, when the servers do not all return the same RRset;
//   - CodeNSRRsetMismatchConfig, when the RRset served by the largest group
//     of servers differs from data.DeclaredNS. This is only checked when
//     DeclaredNS is non-empty.
//
// The RRset is treated as a set: a sorted copy of the names is compared, so
// listing the same names in another order is not reported as drift. Groups are
// ordered by decreasing number of servers, with ties broken by their RRset
// key. Both the message and the choice of the "majority" RRset are therefore
// deterministic.
func collectNSRRsetDrift(data *ObservationData) []Finding {
	type rrsetGroup struct {
		key     string   // sorted names joined with "|"
		names   []string // sorted NS names of the RRset
		servers []string // servers that returned it
	}

	byKey := map[string]*rrsetGroup{}
	for _, ns := range data.Probed {
		r := data.Results[ns]
		if r == nil || !r.Authoritative || len(r.NSRRset) == 0 {
			continue
		}
		// Sort a copy so the observation data itself is left untouched.
		names := append([]string(nil), r.NSRRset...)
		sort.Strings(names)
		k := strings.Join(names, "|")
		g, ok := byKey[k]
		if !ok {
			g = &rrsetGroup{key: k, names: names}
			byKey[k] = g
		}
		g.servers = append(g.servers, ns)
	}
	if len(byKey) == 0 {
		return nil
	}

	groups := make([]*rrsetGroup, 0, len(byKey))
	for _, g := range byKey {
		sort.Strings(g.servers)
		groups = append(groups, g)
	}
	sort.Slice(groups, func(i, j int) bool {
		if len(groups[i].servers) != len(groups[j].servers) {
			return len(groups[i].servers) > len(groups[j].servers)
		}
		return groups[i].key < groups[j].key
	})

	var findings []Finding
	if len(groups) > 1 {
		lines := make([]string, 0, len(groups))
		for _, g := range groups {
			lines = append(lines, fmt.Sprintf("NS RRset [%s] → %s", strings.Join(g.names, ", "), strings.Join(g.servers, ", ")))
		}
		findings = append(findings, Finding{
			Code:     CodeNSRRsetDrift,
			Severity: SeverityWarn,
			Message:  "NS RRset differs between authoritative servers: " + strings.Join(lines, "; "),
		})
	}

	if len(data.DeclaredNS) == 0 {
		return findings
	}

	// groups[0] is the largest group; it is non-empty because empty RRsets
	// were skipped above.
	majority := groups[0].names
	missing, extra := diffStringSets(data.DeclaredNS, majority)
	if len(missing) > 0 || len(extra) > 0 {
		findings = append(findings, Finding{
			Code:     CodeNSRRsetMismatchConfig,
			Severity: SeverityWarn,
			Message: fmt.Sprintf(
				"NS RRset served by authoritative servers does not match declared service: missing=%v extra=%v",
				missing, extra,
			),
		})
	}
	return findings
}