diff --git a/pkg/analyzer/dns_dmarc.go b/pkg/analyzer/dns_dmarc.go index 28548ea..b89500b 100644 --- a/pkg/analyzer/dns_dmarc.go +++ b/pkg/analyzer/dns_dmarc.go @@ -25,13 +25,15 @@ import ( "context" "fmt" "net" - "regexp" + "strconv" "strings" "git.happydns.org/happyDeliver/internal/model" "git.happydns.org/happyDeliver/internal/utils" ) +var dmarcPolicyStrength = map[string]int{"none": 0, "quarantine": 1, "reject": 2} + // lookupDMARCAt queries _dmarc. and returns the raw DMARC1 TXT record. // notFound=true means no record exists (NXDOMAIN or empty); false means a real DNS error occurred. func (d *DNSAnalyzer) lookupDMARCAt(domain string) (record string, notFound bool, err error) { @@ -56,17 +58,62 @@ func (d *DNSAnalyzer) lookupDMARCAt(domain string) (record string, notFound bool // parseDMARCRecord parses a raw DMARC TXT record into a DMARCRecord model. func (d *DNSAnalyzer) parseDMARCRecord(foundDomain, rawRecord string) *model.DMARCRecord { - policy := d.extractDMARCPolicy(rawRecord) - subdomainPolicy := d.extractDMARCSubdomainPolicy(rawRecord) - nonexistentSubdomainPolicy := d.extractDMARCNonexistentSubdomainPolicy(rawRecord) - percentage := d.extractDMARCPercentage(rawRecord) - testMode := d.extractDMARCTestMode(rawRecord) - psd := d.extractDMARCPSD(rawRecord) - spfAlignment := d.extractDMARCSPFAlignment(rawRecord) - dkimAlignment := d.extractDMARCDKIMAlignment(rawRecord) - deprecatedPct := percentage != nil - deprecatedRf := d.hasDMARCTag(rawRecord, "rf") - deprecatedRi := d.hasDMARCTag(rawRecord, "ri") + tags := parseDKIMTags(rawRecord) + + // Policy + policy := "unknown" + switch tags["p"] { + case "none", "quarantine", "reject": + policy = tags["p"] + } + + // SPF alignment (default: relaxed) + spfAlignment := utils.PtrTo(model.DMARCRecordSpfAlignmentRelaxed) + if tags["aspf"] == "s" { + spfAlignment = utils.PtrTo(model.DMARCRecordSpfAlignmentStrict) + } + + // DKIM alignment (default: relaxed) + dkimAlignment := utils.PtrTo(model.DMARCRecordDkimAlignmentRelaxed) + if tags["adkim"] == "s" { + dkimAlignment = utils.PtrTo(model.DMARCRecordDkimAlignmentStrict) + } + + // Subdomain policy + var subdomainPolicy *model.DMARCRecordSubdomainPolicy + switch tags["sp"] { + case "none", "quarantine", "reject": + subdomainPolicy = utils.PtrTo(model.DMARCRecordSubdomainPolicy(tags["sp"])) + } + + // Non-existent subdomain policy (DMARCbis np=) + var nonexistentSubdomainPolicy *model.DMARCRecordNonexistentSubdomainPolicy + switch tags["np"] { + case "none", "quarantine", "reject": + nonexistentSubdomainPolicy = utils.PtrTo(model.DMARCRecordNonexistentSubdomainPolicy(tags["np"])) + } + + // Percentage (pct=, deprecated in DMARCbis) + var percentage *int + if pctStr, ok := tags["pct"]; ok { + if pct, err := strconv.Atoi(pctStr); err == nil && pct >= 0 && pct <= 100 { + percentage = &pct + } + } + + // Test mode (DMARCbis t=) + var testMode *bool + if t, ok := tags["t"]; ok { + v := t == "y" + testMode = &v + } + + // PSD (DMARCbis psd=) + var psd *model.DMARCRecordPsd + switch tags["psd"] { + case "y", "n", "u": + psd = utils.PtrTo(model.DMARCRecordPsd(tags["psd"])) + } rec := &model.DMARCRecord{ Domain: &foundDomain, @@ -80,13 +127,13 @@ func (d *DNSAnalyzer) parseDMARCRecord(foundDomain, rawRecord string) *model.DMA SpfAlignment: spfAlignment, DkimAlignment: dkimAlignment, } - if deprecatedPct { + if percentage != nil { rec.DeprecatedPct = utils.PtrTo(true) } - if deprecatedRf { + if _, ok := tags["rf"]; ok { rec.DeprecatedRf = utils.PtrTo(true) } - if deprecatedRi { + if _, ok := tags["ri"]; ok { rec.DeprecatedRi = utils.PtrTo(true) } @@ -158,129 +205,17 @@ func (d *DNSAnalyzer) checkDMARCRecord(domain string) *model.DMARCRecord { return d.parseDMARCRecord(foundDomain, raw) } -// extractDMARCPolicy extracts the policy from a DMARC record -func (d *DNSAnalyzer) extractDMARCPolicy(record string) string { - // Look for p=none, p=quarantine, or p=reject - re := regexp.MustCompile(`p=(none|quarantine|reject)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - return matches[1] - } - return "unknown" -} - -// extractDMARCSPFAlignment extracts SPF alignment mode from a DMARC record -// Returns "relaxed" (default) or "strict" -func (d *DNSAnalyzer) extractDMARCSPFAlignment(record string) *model.DMARCRecordSpfAlignment { - // Look for aspf=s (strict) or aspf=r (relaxed) - re := regexp.MustCompile(`aspf=(r|s)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - if matches[1] == "s" { - return utils.PtrTo(model.DMARCRecordSpfAlignmentStrict) - } - return utils.PtrTo(model.DMARCRecordSpfAlignmentRelaxed) - } - // Default is relaxed if not specified - return utils.PtrTo(model.DMARCRecordSpfAlignmentRelaxed) -} - -// extractDMARCDKIMAlignment extracts DKIM alignment mode from a DMARC record -// Returns "relaxed" (default) or "strict" -func (d *DNSAnalyzer) extractDMARCDKIMAlignment(record string) *model.DMARCRecordDkimAlignment { - // Look for adkim=s (strict) or adkim=r (relaxed) - re := regexp.MustCompile(`adkim=(r|s)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - if matches[1] == "s" { - return utils.PtrTo(model.DMARCRecordDkimAlignmentStrict) - } - return utils.PtrTo(model.DMARCRecordDkimAlignmentRelaxed) - } - // Default is relaxed if not specified - return utils.PtrTo(model.DMARCRecordDkimAlignmentRelaxed) -} - -// extractDMARCSubdomainPolicy extracts subdomain policy from a DMARC record -// Returns the sp tag value or nil if not specified (defaults to main policy) -func (d *DNSAnalyzer) extractDMARCSubdomainPolicy(record string) *model.DMARCRecordSubdomainPolicy { - // Look for sp=none, sp=quarantine, or sp=reject - re := regexp.MustCompile(`sp=(none|quarantine|reject)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - return utils.PtrTo(model.DMARCRecordSubdomainPolicy(matches[1])) - } - // If sp is not specified, it defaults to the main policy (p tag) - // Return nil to indicate it's using the default - return nil -} - -// extractDMARCNonexistentSubdomainPolicy extracts non-existent subdomain policy from a DMARC record. -// Returns the np tag value or nil if not specified (defaults to effective sp/p policy). -// The np= tag is introduced by DMARCbis (draft-ietf-dmarc-dmarcbis). -func (d *DNSAnalyzer) extractDMARCNonexistentSubdomainPolicy(record string) *model.DMARCRecordNonexistentSubdomainPolicy { - re := regexp.MustCompile(`np=(none|quarantine|reject)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - return utils.PtrTo(model.DMARCRecordNonexistentSubdomainPolicy(matches[1])) - } - return nil -} - -// extractDMARCPercentage extracts the percentage from a DMARC record. -// Returns the pct tag value or nil if not specified (defaults to 100). -// Note: pct= is deprecated in DMARCbis; use t= (test_mode) instead. -func (d *DNSAnalyzer) extractDMARCPercentage(record string) *int { - re := regexp.MustCompile(`pct=(\d+)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - var pct int - fmt.Sscanf(matches[1], "%d", &pct) - if pct >= 0 && pct <= 100 { - return &pct - } - } - return nil -} - -// extractDMARCTestMode extracts the DMARCbis t= tag (test mode). -// Returns true for t=y, false for t=n, nil if absent (defaults to false / full enforcement). -func (d *DNSAnalyzer) extractDMARCTestMode(record string) *bool { - re := regexp.MustCompile(`(?:^|;)\s*t=(y|n)(?:;|$|\s)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - v := matches[1] == "y" - return &v - } - return nil -} - -// extractDMARCPSD extracts the DMARCbis psd= tag value as a typed enum. -// Returns nil if the tag is absent (defaults to "u" / unknown). -func (d *DNSAnalyzer) extractDMARCPSD(record string) *model.DMARCRecordPsd { - v := d.extractDMARCPSDValue(record) - if v == "" { - return nil - } - return utils.PtrTo(model.DMARCRecordPsd(v)) -} - -// extractDMARCPSDValue returns the raw string value of psd= ("y", "n", "u") or "". +// extractDMARCPSDValue returns the raw psd= value ("y", "n", "u") or "" if absent. +// Used during DNS Tree Walk before full record parsing. func (d *DNSAnalyzer) extractDMARCPSDValue(record string) string { - re := regexp.MustCompile(`(?:^|;)\s*psd=(y|n|u)(?:;|$|\s)`) - matches := re.FindStringSubmatch(record) - if len(matches) > 1 { - return matches[1] + v := parseDKIMTags(record)["psd"] + switch v { + case "y", "n", "u": + return v } return "" } -// hasDMARCTag reports whether the given tag name appears in the record. -func (d *DNSAnalyzer) hasDMARCTag(record, tag string) bool { - re := regexp.MustCompile(`(?:^|;)\s*` + regexp.QuoteMeta(tag) + `=`) - return re.MatchString(record) -} - // validateDMARC performs basic DMARC record validation. // Per DMARCbis, p= is now RECOMMENDED (not required): a record with a valid // rua= but no p= is treated as p=none and considered valid. @@ -343,12 +278,10 @@ func (d *DNSAnalyzer) calculateDMARCScore(results *model.DNSResults) (score int) score += 5 } - policyStrength := map[string]int{"none": 0, "quarantine": 1, "reject": 2} - // Subdomain policy scoring (sp tag): +15 for equal-or-stricter, -15 for weaker if results.DmarcRecord.SubdomainPolicy != nil { subPolicy := string(*results.DmarcRecord.SubdomainPolicy) - if policyStrength[subPolicy] >= policyStrength[effectivePolicy] { + if dmarcPolicyStrength[subPolicy] >= dmarcPolicyStrength[effectivePolicy] { score += 15 } else { score -= 15 @@ -357,19 +290,17 @@ func (d *DNSAnalyzer) calculateDMARCScore(results *model.DNSResults) (score int) score += 15 // inherits main policy — good default } - // Non-existent subdomain policy scoring (np tag, DMARCbis) - score -= 15 + // Non-existent subdomain policy scoring (np tag, DMARCbis): +15 for equal-or-stricter, -15 for weaker effectiveSubPolicy := effectivePolicy if results.DmarcRecord.SubdomainPolicy != nil { effectiveSubPolicy = string(*results.DmarcRecord.SubdomainPolicy) } if results.DmarcRecord.NonexistentSubdomainPolicy == nil { score += 15 // inherits subdomain/main policy — good default + } else if dmarcPolicyStrength[string(*results.DmarcRecord.NonexistentSubdomainPolicy)] >= dmarcPolicyStrength[effectiveSubPolicy] { + score += 15 } else { - npStrength := policyStrength[string(*results.DmarcRecord.NonexistentSubdomainPolicy)] - if npStrength >= policyStrength[effectiveSubPolicy] { - score += 15 - } + score -= 15 } // pct= scaling (deprecated in DMARCbis, kept for backward compatibility). diff --git a/pkg/analyzer/dns_dmarc_test.go b/pkg/analyzer/dns_dmarc_test.go index 46a3518..5c34a32 100644 --- a/pkg/analyzer/dns_dmarc_test.go +++ b/pkg/analyzer/dns_dmarc_test.go @@ -221,7 +221,7 @@ func containsStr(s, sub string) bool { return false } -func TestExtractDMARCPolicy(t *testing.T) { +func TestParseDMARCRecordPolicy(t *testing.T) { tests := []struct { name string record string @@ -253,15 +253,18 @@ func TestExtractDMARCPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCPolicy(tt.record) - if result != tt.expectedPolicy { - t.Errorf("extractDMARCPolicy(%q) = %q, want %q", tt.record, result, tt.expectedPolicy) + rec := analyzer.parseDMARCRecord("example.com", tt.record) + if rec.Policy == nil { + t.Fatalf("parseDMARCRecord(%q).Policy = nil", tt.record) + } + if string(*rec.Policy) != tt.expectedPolicy { + t.Errorf("parseDMARCRecord(%q).Policy = %q, want %q", tt.record, string(*rec.Policy), tt.expectedPolicy) } }) } } -func TestExtractDMARCTestMode(t *testing.T) { +func TestParseDMARCRecordTestMode(t *testing.T) { tests := []struct { name string record string @@ -288,24 +291,24 @@ func TestExtractDMARCTestMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCTestMode(tt.record) + result := analyzer.parseDMARCRecord("example.com", tt.record).TestMode if tt.wantMode == nil { if result != nil { - t.Errorf("extractDMARCTestMode(%q) = %v, want nil", tt.record, *result) + t.Errorf("parseDMARCRecord(%q).TestMode = %v, want nil", tt.record, *result) } } else { if result == nil { - t.Fatalf("extractDMARCTestMode(%q) = nil, want %v", tt.record, *tt.wantMode) + t.Fatalf("parseDMARCRecord(%q).TestMode = nil, want %v", tt.record, *tt.wantMode) } if *result != *tt.wantMode { - t.Errorf("extractDMARCTestMode(%q) = %v, want %v", tt.record, *result, *tt.wantMode) + t.Errorf("parseDMARCRecord(%q).TestMode = %v, want %v", tt.record, *result, *tt.wantMode) } } }) } } -func TestExtractDMARCPSD(t *testing.T) { +func TestParseDMARCRecordPSD(t *testing.T) { tests := []struct { name string record string @@ -337,43 +340,48 @@ func TestExtractDMARCPSD(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCPSD(tt.record) + result := analyzer.parseDMARCRecord("example.com", tt.record).Psd if tt.wantPSD == nil { if result != nil { - t.Errorf("extractDMARCPSD(%q) = %v, want nil", tt.record, *result) + t.Errorf("parseDMARCRecord(%q).Psd = %v, want nil", tt.record, *result) } } else { if result == nil { - t.Fatalf("extractDMARCPSD(%q) = nil, want %q", tt.record, *tt.wantPSD) + t.Fatalf("parseDMARCRecord(%q).Psd = nil, want %q", tt.record, *tt.wantPSD) } if string(*result) != *tt.wantPSD { - t.Errorf("extractDMARCPSD(%q) = %q, want %q", tt.record, string(*result), *tt.wantPSD) + t.Errorf("parseDMARCRecord(%q).Psd = %q, want %q", tt.record, string(*result), *tt.wantPSD) } } }) } } -func TestHasDMARCTag(t *testing.T) { +func TestParseDMARCRecordDeprecatedTags(t *testing.T) { tests := []struct { - name string - record string - tag string - want bool + name string + record string + wantRf bool + wantRi bool }{ - {name: "rf tag present", record: "v=DMARC1; p=none; rf=afrf", tag: "rf", want: true}, - {name: "ri tag present", record: "v=DMARC1; p=none; ri=86400", tag: "ri", want: true}, - {name: "rf tag absent", record: "v=DMARC1; p=quarantine; rua=mailto:x@example.com", tag: "rf", want: false}, - {name: "ri tag absent", record: "v=DMARC1; p=quarantine", tag: "ri", want: false}, + {name: "rf tag present", record: "v=DMARC1; p=none; rf=afrf", wantRf: true, wantRi: false}, + {name: "ri tag present", record: "v=DMARC1; p=none; ri=86400", wantRf: false, wantRi: true}, + {name: "rf tag absent", record: "v=DMARC1; p=quarantine; rua=mailto:x@example.com", wantRf: false, wantRi: false}, + {name: "ri tag absent", record: "v=DMARC1; p=quarantine", wantRf: false, wantRi: false}, } analyzer := NewDNSAnalyzer(5 * time.Second) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.hasDMARCTag(tt.record, tt.tag) - if result != tt.want { - t.Errorf("hasDMARCTag(%q, %q) = %v, want %v", tt.record, tt.tag, result, tt.want) + rec := analyzer.parseDMARCRecord("example.com", tt.record) + gotRf := rec.DeprecatedRf != nil && *rec.DeprecatedRf + gotRi := rec.DeprecatedRi != nil && *rec.DeprecatedRi + if gotRf != tt.wantRf { + t.Errorf("parseDMARCRecord(%q).DeprecatedRf = %v, want %v", tt.record, gotRf, tt.wantRf) + } + if gotRi != tt.wantRi { + t.Errorf("parseDMARCRecord(%q).DeprecatedRi = %v, want %v", tt.record, gotRi, tt.wantRi) } }) } @@ -429,142 +437,36 @@ func TestValidateDMARC(t *testing.T) { } } -func TestExtractDMARCSPFAlignment(t *testing.T) { - tests := []struct { - name string - record string - expectedAlignment string - }{ - { - name: "SPF alignment - strict", - record: "v=DMARC1; p=quarantine; aspf=s", - expectedAlignment: "strict", - }, - { - name: "SPF alignment - relaxed (explicit)", - record: "v=DMARC1; p=quarantine; aspf=r", - expectedAlignment: "relaxed", - }, - { - name: "SPF alignment - relaxed (default, not specified)", - record: "v=DMARC1; p=quarantine", - expectedAlignment: "relaxed", - }, - { - name: "Both alignments specified - check SPF strict", - record: "v=DMARC1; p=quarantine; aspf=s; adkim=r", - expectedAlignment: "strict", - }, - { - name: "Both alignments specified - check SPF relaxed", - record: "v=DMARC1; p=quarantine; aspf=r; adkim=s", - expectedAlignment: "relaxed", - }, - { - name: "Complex record with SPF strict", - record: "v=DMARC1; p=reject; rua=mailto:dmarc@example.com; aspf=s; adkim=s; pct=100", - expectedAlignment: "strict", - }, - } - - analyzer := NewDNSAnalyzer(5 * time.Second) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCSPFAlignment(tt.record) - if result == nil { - t.Fatalf("extractDMARCSPFAlignment(%q) returned nil, expected non-nil", tt.record) - } - if string(*result) != tt.expectedAlignment { - t.Errorf("extractDMARCSPFAlignment(%q) = %q, want %q", tt.record, string(*result), tt.expectedAlignment) - } - }) - } -} - -func TestExtractDMARCDKIMAlignment(t *testing.T) { - tests := []struct { - name string - record string - expectedAlignment string - }{ - { - name: "DKIM alignment - strict", - record: "v=DMARC1; p=reject; adkim=s", - expectedAlignment: "strict", - }, - { - name: "DKIM alignment - relaxed (explicit)", - record: "v=DMARC1; p=reject; adkim=r", - expectedAlignment: "relaxed", - }, - { - name: "DKIM alignment - relaxed (default, not specified)", - record: "v=DMARC1; p=none", - expectedAlignment: "relaxed", - }, - { - name: "Both alignments specified - check DKIM strict", - record: "v=DMARC1; p=quarantine; aspf=r; adkim=s", - expectedAlignment: "strict", - }, - { - name: "Both alignments specified - check DKIM relaxed", - record: "v=DMARC1; p=quarantine; aspf=s; adkim=r", - expectedAlignment: "relaxed", - }, - { - name: "Complex record with DKIM strict", - record: "v=DMARC1; p=reject; rua=mailto:dmarc@example.com; aspf=r; adkim=s; pct=100", - expectedAlignment: "strict", - }, - } - - analyzer := NewDNSAnalyzer(5 * time.Second) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCDKIMAlignment(tt.record) - if result == nil { - t.Fatalf("extractDMARCDKIMAlignment(%q) returned nil, expected non-nil", tt.record) - } - if string(*result) != tt.expectedAlignment { - t.Errorf("extractDMARCDKIMAlignment(%q) = %q, want %q", tt.record, string(*result), tt.expectedAlignment) - } - }) - } -} - -func TestExtractDMARCSubdomainPolicy(t *testing.T) { +func TestParseDMARCRecordAlignment(t *testing.T) { tests := []struct { name string record string - expectedPolicy *string + expectedSPF string + expectedDKIM string }{ { - name: "Subdomain policy - none", - record: "v=DMARC1; p=quarantine; sp=none", - expectedPolicy: utils.PtrTo("none"), + name: "SPF strict, DKIM relaxed", + record: "v=DMARC1; p=quarantine; aspf=s; adkim=r", + expectedSPF: "strict", + expectedDKIM: "relaxed", }, { - name: "Subdomain policy - quarantine", - record: "v=DMARC1; p=reject; sp=quarantine", - expectedPolicy: utils.PtrTo("quarantine"), + name: "SPF relaxed explicit, DKIM strict", + record: "v=DMARC1; p=quarantine; aspf=r; adkim=s", + expectedSPF: "relaxed", + expectedDKIM: "strict", }, { - name: "Subdomain policy - reject", - record: "v=DMARC1; p=quarantine; sp=reject", - expectedPolicy: utils.PtrTo("reject"), + name: "Defaults when neither specified", + record: "v=DMARC1; p=quarantine", + expectedSPF: "relaxed", + expectedDKIM: "relaxed", }, { - name: "No subdomain policy specified (defaults to main policy)", - record: "v=DMARC1; p=quarantine", - expectedPolicy: nil, - }, - { - name: "Complex record with subdomain policy", - record: "v=DMARC1; p=reject; sp=quarantine; rua=mailto:dmarc@example.com; pct=100", - expectedPolicy: utils.PtrTo("quarantine"), + name: "Both strict in complex record", + record: "v=DMARC1; p=reject; rua=mailto:dmarc@example.com; aspf=s; adkim=s; pct=100", + expectedSPF: "strict", + expectedDKIM: "strict", }, } @@ -572,53 +474,53 @@ func TestExtractDMARCSubdomainPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCSubdomainPolicy(tt.record) - if tt.expectedPolicy == nil { - if result != nil { - t.Errorf("extractDMARCSubdomainPolicy(%q) = %v, want nil", tt.record, result) - } - } else { - if result == nil { - t.Fatalf("extractDMARCSubdomainPolicy(%q) returned nil, expected %q", tt.record, *tt.expectedPolicy) - } - if string(*result) != *tt.expectedPolicy { - t.Errorf("extractDMARCSubdomainPolicy(%q) = %q, want %q", tt.record, string(*result), *tt.expectedPolicy) - } + rec := analyzer.parseDMARCRecord("example.com", tt.record) + if rec.SpfAlignment == nil { + t.Fatalf("parseDMARCRecord(%q).SpfAlignment = nil", tt.record) + } + if string(*rec.SpfAlignment) != tt.expectedSPF { + t.Errorf("SpfAlignment = %q, want %q", string(*rec.SpfAlignment), tt.expectedSPF) + } + if rec.DkimAlignment == nil { + t.Fatalf("parseDMARCRecord(%q).DkimAlignment = nil", tt.record) + } + if string(*rec.DkimAlignment) != tt.expectedDKIM { + t.Errorf("DkimAlignment = %q, want %q", string(*rec.DkimAlignment), tt.expectedDKIM) } }) } } -func TestExtractDMARCNonexistentSubdomainPolicy(t *testing.T) { +func TestParseDMARCRecordSubdomainPolicy(t *testing.T) { tests := []struct { name string record string - expectedPolicy *string + expectedSP *string + expectedNP *string }{ { - name: "Non-existent subdomain policy - none", - record: "v=DMARC1; p=quarantine; np=none", - expectedPolicy: utils.PtrTo("none"), + name: "sp=none, no np", + record: "v=DMARC1; p=quarantine; sp=none", + expectedSP: utils.PtrTo("none"), + expectedNP: nil, }, { - name: "Non-existent subdomain policy - quarantine", - record: "v=DMARC1; p=reject; np=quarantine", - expectedPolicy: utils.PtrTo("quarantine"), + name: "sp=reject, np=reject", + record: "v=DMARC1; p=reject; sp=quarantine; np=reject; rua=mailto:dmarc@example.com; pct=100", + expectedSP: utils.PtrTo("quarantine"), + expectedNP: utils.PtrTo("reject"), }, { - name: "Non-existent subdomain policy - reject", - record: "v=DMARC1; p=quarantine; np=reject", - expectedPolicy: utils.PtrTo("reject"), + name: "No sp or np (both default)", + record: "v=DMARC1; p=quarantine", + expectedSP: nil, + expectedNP: nil, }, { - name: "No np tag (defaults to effective sp/p policy)", - record: "v=DMARC1; p=quarantine", - expectedPolicy: nil, - }, - { - name: "Complex record with np and sp tags", - record: "v=DMARC1; p=reject; sp=quarantine; np=reject; rua=mailto:dmarc@example.com; pct=100", - expectedPolicy: utils.PtrTo("reject"), + name: "np=quarantine, no sp", + record: "v=DMARC1; p=reject; np=quarantine", + expectedSP: nil, + expectedNP: utils.PtrTo("quarantine"), }, } @@ -626,86 +528,63 @@ func TestExtractDMARCNonexistentSubdomainPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCNonexistentSubdomainPolicy(tt.record) - if tt.expectedPolicy == nil { - if result != nil { - t.Errorf("extractDMARCNonexistentSubdomainPolicy(%q) = %v, want nil", tt.record, result) + rec := analyzer.parseDMARCRecord("example.com", tt.record) + if tt.expectedSP == nil { + if rec.SubdomainPolicy != nil { + t.Errorf("parseDMARCRecord(%q).SubdomainPolicy = %v, want nil", tt.record, *rec.SubdomainPolicy) } } else { - if result == nil { - t.Fatalf("extractDMARCNonexistentSubdomainPolicy(%q) returned nil, expected %q", tt.record, *tt.expectedPolicy) + if rec.SubdomainPolicy == nil { + t.Fatalf("parseDMARCRecord(%q).SubdomainPolicy = nil, want %q", tt.record, *tt.expectedSP) } - if string(*result) != *tt.expectedPolicy { - t.Errorf("extractDMARCNonexistentSubdomainPolicy(%q) = %q, want %q", tt.record, string(*result), *tt.expectedPolicy) + if string(*rec.SubdomainPolicy) != *tt.expectedSP { + t.Errorf("SubdomainPolicy = %q, want %q", string(*rec.SubdomainPolicy), *tt.expectedSP) + } + } + if tt.expectedNP == nil { + if rec.NonexistentSubdomainPolicy != nil { + t.Errorf("parseDMARCRecord(%q).NonexistentSubdomainPolicy = %v, want nil", tt.record, *rec.NonexistentSubdomainPolicy) + } + } else { + if rec.NonexistentSubdomainPolicy == nil { + t.Fatalf("parseDMARCRecord(%q).NonexistentSubdomainPolicy = nil, want %q", tt.record, *tt.expectedNP) + } + if string(*rec.NonexistentSubdomainPolicy) != *tt.expectedNP { + t.Errorf("NonexistentSubdomainPolicy = %q, want %q", string(*rec.NonexistentSubdomainPolicy), *tt.expectedNP) } } }) } } -func TestExtractDMARCPercentage(t *testing.T) { +func TestParseDMARCRecordPercentage(t *testing.T) { tests := []struct { name string record string expectedPercentage *int }{ - { - name: "Percentage - 100", - record: "v=DMARC1; p=quarantine; pct=100", - expectedPercentage: utils.PtrTo(100), - }, - { - name: "Percentage - 50", - record: "v=DMARC1; p=quarantine; pct=50", - expectedPercentage: utils.PtrTo(50), - }, - { - name: "Percentage - 25", - record: "v=DMARC1; p=reject; pct=25", - expectedPercentage: utils.PtrTo(25), - }, - { - name: "Percentage - 0", - record: "v=DMARC1; p=none; pct=0", - expectedPercentage: utils.PtrTo(0), - }, - { - name: "No percentage specified (defaults to 100)", - record: "v=DMARC1; p=quarantine", - expectedPercentage: nil, - }, - { - name: "Complex record with percentage", - record: "v=DMARC1; p=reject; sp=quarantine; rua=mailto:dmarc@example.com; pct=75", - expectedPercentage: utils.PtrTo(75), - }, - { - name: "Invalid percentage > 100 (ignored)", - record: "v=DMARC1; p=quarantine; pct=150", - expectedPercentage: nil, - }, - { - name: "Invalid percentage < 0 (ignored)", - record: "v=DMARC1; p=quarantine; pct=-10", - expectedPercentage: nil, - }, + {name: "pct=100", record: "v=DMARC1; p=quarantine; pct=100", expectedPercentage: utils.PtrTo(100)}, + {name: "pct=50", record: "v=DMARC1; p=quarantine; pct=50", expectedPercentage: utils.PtrTo(50)}, + {name: "pct=0", record: "v=DMARC1; p=none; pct=0", expectedPercentage: utils.PtrTo(0)}, + {name: "no pct", record: "v=DMARC1; p=quarantine", expectedPercentage: nil}, + {name: "pct=150 ignored", record: "v=DMARC1; p=quarantine; pct=150", expectedPercentage: nil}, } analyzer := NewDNSAnalyzer(5 * time.Second) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := analyzer.extractDMARCPercentage(tt.record) + result := analyzer.parseDMARCRecord("example.com", tt.record).Percentage if tt.expectedPercentage == nil { if result != nil { - t.Errorf("extractDMARCPercentage(%q) = %v, want nil", tt.record, *result) + t.Errorf("parseDMARCRecord(%q).Percentage = %d, want nil", tt.record, *result) } } else { if result == nil { - t.Fatalf("extractDMARCPercentage(%q) returned nil, expected %d", tt.record, *tt.expectedPercentage) + t.Fatalf("parseDMARCRecord(%q).Percentage = nil, want %d", tt.record, *tt.expectedPercentage) } if *result != *tt.expectedPercentage { - t.Errorf("extractDMARCPercentage(%q) = %d, want %d", tt.record, *result, *tt.expectedPercentage) + t.Errorf("parseDMARCRecord(%q).Percentage = %d, want %d", tt.record, *result, *tt.expectedPercentage) } } })