From 505cbae9af9b3e8128edbe1424a68e6f23979f44 Mon Sep 17 00:00:00 2001 From: Pierre-Olivier Mercier Date: Wed, 15 Oct 2025 15:03:39 +0700 Subject: [PATCH] Build DNS validation module (MX, SPF, DKIM, DMARC records) --- internal/analyzer/dns.go | 566 ++++++++++++++++++++++++++++++ internal/analyzer/dns_test.go | 633 ++++++++++++++++++++++++++++++++++ 2 files changed, 1199 insertions(+) create mode 100644 internal/analyzer/dns.go create mode 100644 internal/analyzer/dns_test.go diff --git a/internal/analyzer/dns.go b/internal/analyzer/dns.go new file mode 100644 index 0000000..07c0346 --- /dev/null +++ b/internal/analyzer/dns.go @@ -0,0 +1,566 @@ +// This file is part of the happyDeliver (R) project. +// Copyright (c) 2025 happyDomain +// Authors: Pierre-Olivier Mercier, et al. +// +// This program is offered under a commercial and under the AGPL license. +// For commercial licensing, contact us at . +// +// For AGPL licensing: +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package analyzer + +import ( + "context" + "fmt" + "net" + "regexp" + "strings" + "time" + + "git.happydns.org/happyDeliver/internal/api" +) + +// DNSAnalyzer analyzes DNS records for email domains +type DNSAnalyzer struct { + Timeout time.Duration + resolver *net.Resolver +} + +// NewDNSAnalyzer creates a new DNS analyzer with configurable timeout +func NewDNSAnalyzer(timeout time.Duration) *DNSAnalyzer { + if timeout == 0 { + timeout = 10 * time.Second // Default timeout + } + return &DNSAnalyzer{ + Timeout: timeout, + resolver: &net.Resolver{ + PreferGo: true, + }, + } +} + +// DNSResults represents DNS validation results for an email +type DNSResults struct { + Domain string + MXRecords []MXRecord + SPFRecord *SPFRecord + DKIMRecords []DKIMRecord + DMARCRecord *DMARCRecord + Errors []string +} + +// MXRecord represents an MX record +type MXRecord struct { + Host string + Priority uint16 + Valid bool + Error string +} + +// SPFRecord represents an SPF record +type SPFRecord struct { + Record string + Valid bool + Error string +} + +// DKIMRecord represents a DKIM record +type DKIMRecord struct { + Selector string + Domain string + Record string + Valid bool + Error string +} + +// DMARCRecord represents a DMARC record +type DMARCRecord struct { + Record string + Policy string // none, quarantine, reject + Valid bool + Error string +} + +// AnalyzeDNS performs DNS validation for the email's domain +func (d *DNSAnalyzer) AnalyzeDNS(email *EmailMessage, authResults *api.AuthenticationResults) *DNSResults { + // Extract domain from From address + domain := d.extractDomain(email) + if domain == "" { + return &DNSResults{ + Errors: []string{"Unable to extract domain from email"}, + } + } + + results := &DNSResults{ + Domain: domain, + } + + // Check MX records + results.MXRecords = d.checkMXRecords(domain) + + // Check SPF record + results.SPFRecord = d.checkSPFRecord(domain) + + // Check DKIM records (from authentication results) + if authResults != nil && authResults.Dkim != nil { + for _, dkim := range *authResults.Dkim { + if dkim.Domain != nil && dkim.Selector != nil { + dkimRecord := d.checkDKIMRecord(*dkim.Domain, *dkim.Selector) + if dkimRecord != nil { + results.DKIMRecords = append(results.DKIMRecords, *dkimRecord) + } + } + } + } + + // Check DMARC record + results.DMARCRecord = d.checkDMARCRecord(domain) + + return results +} + +// extractDomain extracts the domain from the email's From address +func (d *DNSAnalyzer) extractDomain(email *EmailMessage) string { + if email.From != nil && email.From.Address != "" { + parts := strings.Split(email.From.Address, "@") + if len(parts) == 2 { + return strings.ToLower(strings.TrimSpace(parts[1])) + } + } + return "" +} + +// checkMXRecords looks up MX records for a domain +func (d *DNSAnalyzer) checkMXRecords(domain string) []MXRecord { + ctx, cancel := context.WithTimeout(context.Background(), d.Timeout) + defer cancel() + + mxRecords, err := d.resolver.LookupMX(ctx, domain) + if err != nil { + return []MXRecord{ + { + Valid: false, + Error: fmt.Sprintf("Failed to lookup MX records: %v", err), + }, + } + } + + if len(mxRecords) == 0 { + return []MXRecord{ + { + Valid: false, + Error: "No MX records found", + }, + } + } + + var results []MXRecord + for _, mx := range mxRecords { + results = append(results, MXRecord{ + Host: mx.Host, + Priority: mx.Pref, + Valid: true, + }) + } + + return results +} + +// checkSPFRecord looks up and validates SPF record for a domain +func (d *DNSAnalyzer) checkSPFRecord(domain string) *SPFRecord { + ctx, cancel := context.WithTimeout(context.Background(), d.Timeout) + defer cancel() + + txtRecords, err := d.resolver.LookupTXT(ctx, domain) + if err != nil { + return &SPFRecord{ + Valid: false, + Error: fmt.Sprintf("Failed to lookup TXT records: %v", err), + } + } + + // Find SPF record (starts with "v=spf1") + var spfRecord string + spfCount := 0 + for _, txt := range txtRecords { + if strings.HasPrefix(txt, "v=spf1") { + spfRecord = txt + spfCount++ + } + } + + if spfCount == 0 { + return &SPFRecord{ + Valid: false, + Error: "No SPF record found", + } + } + + if spfCount > 1 { + return &SPFRecord{ + Record: spfRecord, + Valid: false, + Error: "Multiple SPF records found (RFC violation)", + } + } + + // Basic validation + if !d.validateSPF(spfRecord) { + return &SPFRecord{ + Record: spfRecord, + Valid: false, + Error: "SPF record appears malformed", + } + } + + return &SPFRecord{ + Record: spfRecord, + Valid: true, + } +} + +// validateSPF performs basic SPF record validation +func (d *DNSAnalyzer) validateSPF(record string) bool { + // Must start with v=spf1 + if !strings.HasPrefix(record, "v=spf1") { + return false + } + + // Check for common syntax issues + // Should have a final mechanism (all, +all, -all, ~all, ?all) + validEndings := []string{" all", " +all", " -all", " ~all", " ?all"} + hasValidEnding := false + for _, ending := range validEndings { + if strings.HasSuffix(record, ending) { + hasValidEnding = true + break + } + } + + return hasValidEnding +} + +// checkDKIMRecord looks up and validates DKIM record for a domain and selector +func (d *DNSAnalyzer) checkDKIMRecord(domain, selector string) *DKIMRecord { + // DKIM records are at: selector._domainkey.domain + dkimDomain := fmt.Sprintf("%s._domainkey.%s", selector, domain) + + ctx, cancel := context.WithTimeout(context.Background(), d.Timeout) + defer cancel() + + txtRecords, err := d.resolver.LookupTXT(ctx, dkimDomain) + if err != nil { + return &DKIMRecord{ + Selector: selector, + Domain: domain, + Valid: false, + Error: fmt.Sprintf("Failed to lookup DKIM record: %v", err), + } + } + + if len(txtRecords) == 0 { + return &DKIMRecord{ + Selector: selector, + Domain: domain, + Valid: false, + Error: "No DKIM record found", + } + } + + // Concatenate all TXT record parts (DKIM can be split) + dkimRecord := strings.Join(txtRecords, "") + + // Basic validation - should contain "v=DKIM1" and "p=" (public key) + if !d.validateDKIM(dkimRecord) { + return &DKIMRecord{ + Selector: selector, + Domain: domain, + Record: dkimRecord, + Valid: false, + Error: "DKIM record appears malformed", + } + } + + return &DKIMRecord{ + Selector: selector, + Domain: domain, + Record: dkimRecord, + Valid: true, + } +} + +// validateDKIM performs basic DKIM record validation +func (d *DNSAnalyzer) validateDKIM(record string) bool { + // Should contain p= tag (public key) + if !strings.Contains(record, "p=") { + return false + } + + // Often contains v=DKIM1 but not required + // If v= is present, it should be DKIM1 + if strings.Contains(record, "v=") && !strings.Contains(record, "v=DKIM1") { + return false + } + + return true +} + +// checkDMARCRecord looks up and validates DMARC record for a domain +func (d *DNSAnalyzer) checkDMARCRecord(domain string) *DMARCRecord { + // DMARC records are at: _dmarc.domain + dmarcDomain := fmt.Sprintf("_dmarc.%s", domain) + + ctx, cancel := context.WithTimeout(context.Background(), d.Timeout) + defer cancel() + + txtRecords, err := d.resolver.LookupTXT(ctx, dmarcDomain) + if err != nil { + return &DMARCRecord{ + Valid: false, + Error: fmt.Sprintf("Failed to lookup DMARC record: %v", err), + } + } + + // Find DMARC record (starts with "v=DMARC1") + var dmarcRecord string + for _, txt := range txtRecords { + if strings.HasPrefix(txt, "v=DMARC1") { + dmarcRecord = txt + break + } + } + + if dmarcRecord == "" { + return &DMARCRecord{ + Valid: false, + Error: "No DMARC record found", + } + } + + // Extract policy + policy := d.extractDMARCPolicy(dmarcRecord) + + // Basic validation + if !d.validateDMARC(dmarcRecord) { + return &DMARCRecord{ + Record: dmarcRecord, + Policy: policy, + Valid: false, + Error: "DMARC record appears malformed", + } + } + + return &DMARCRecord{ + Record: dmarcRecord, + Policy: policy, + Valid: true, + } +} + +// extractDMARCPolicy extracts the policy from a DMARC record +func (d *DNSAnalyzer) extractDMARCPolicy(record string) string { + // Look for p=none, p=quarantine, or p=reject + re := regexp.MustCompile(`p=(none|quarantine|reject)`) + matches := re.FindStringSubmatch(record) + if len(matches) > 1 { + return matches[1] + } + return "unknown" +} + +// validateDMARC performs basic DMARC record validation +func (d *DNSAnalyzer) validateDMARC(record string) bool { + // Must start with v=DMARC1 + if !strings.HasPrefix(record, "v=DMARC1") { + return false + } + + // Must have a policy tag + if !strings.Contains(record, "p=") { + return false + } + + return true +} + +// GenerateDNSChecks generates check results for DNS validation +func (d *DNSAnalyzer) GenerateDNSChecks(results *DNSResults) []api.Check { + var checks []api.Check + + if results == nil { + return checks + } + + // MX record check + checks = append(checks, d.generateMXCheck(results)) + + // SPF record check + if results.SPFRecord != nil { + checks = append(checks, d.generateSPFCheck(results.SPFRecord)) + } + + // DKIM record checks + for _, dkim := range results.DKIMRecords { + checks = append(checks, d.generateDKIMCheck(&dkim)) + } + + // DMARC record check + if results.DMARCRecord != nil { + checks = append(checks, d.generateDMARCCheck(results.DMARCRecord)) + } + + return checks +} + +// generateMXCheck creates a check for MX records +func (d *DNSAnalyzer) generateMXCheck(results *DNSResults) api.Check { + check := api.Check{ + Category: api.Dns, + Name: "MX Records", + } + + if len(results.MXRecords) == 0 || !results.MXRecords[0].Valid { + check.Status = api.CheckStatusFail + check.Score = 0.0 + check.Severity = api.PtrTo(api.Critical) + + if len(results.MXRecords) > 0 && results.MXRecords[0].Error != "" { + check.Message = results.MXRecords[0].Error + } else { + check.Message = "No valid MX records found" + } + check.Advice = api.PtrTo("Configure MX records for your domain to receive email") + } else { + check.Status = api.CheckStatusPass + check.Score = 1.0 + check.Severity = api.PtrTo(api.Info) + check.Message = fmt.Sprintf("Found %d valid MX record(s)", len(results.MXRecords)) + + // Add details about MX records + var mxList []string + for _, mx := range results.MXRecords { + mxList = append(mxList, fmt.Sprintf("%s (priority %d)", mx.Host, mx.Priority)) + } + details := strings.Join(mxList, ", ") + check.Details = &details + check.Advice = api.PtrTo("Your MX records are properly configured") + } + + return check +} + +// generateSPFCheck creates a check for SPF records +func (d *DNSAnalyzer) generateSPFCheck(spf *SPFRecord) api.Check { + check := api.Check{ + Category: api.Dns, + Name: "SPF Record", + } + + if !spf.Valid { + // If no record exists at all, it's a failure + if spf.Record == "" { + check.Status = api.CheckStatusFail + check.Score = 0.0 + check.Message = spf.Error + check.Severity = api.PtrTo(api.High) + check.Advice = api.PtrTo("Configure an SPF record for your domain to improve deliverability") + } else { + // If record exists but is invalid, it's a warning + check.Status = api.CheckStatusWarn + check.Score = 0.5 + check.Message = "SPF record found but appears invalid" + check.Severity = api.PtrTo(api.Medium) + check.Advice = api.PtrTo("Review and fix your SPF record syntax") + check.Details = &spf.Record + } + } else { + check.Status = api.CheckStatusPass + check.Score = 1.0 + check.Message = "Valid SPF record found" + check.Severity = api.PtrTo(api.Info) + check.Details = &spf.Record + check.Advice = api.PtrTo("Your SPF record is properly configured") + } + + return check +} + +// generateDKIMCheck creates a check for DKIM records +func (d *DNSAnalyzer) generateDKIMCheck(dkim *DKIMRecord) api.Check { + check := api.Check{ + Category: api.Dns, + Name: fmt.Sprintf("DKIM Record (%s)", dkim.Selector), + } + + if !dkim.Valid { + check.Status = api.CheckStatusFail + check.Score = 0.0 + check.Message = fmt.Sprintf("DKIM record not found or invalid: %s", dkim.Error) + check.Severity = api.PtrTo(api.High) + check.Advice = api.PtrTo("Ensure DKIM record is published in DNS for the selector used") + details := fmt.Sprintf("Selector: %s, Domain: %s", dkim.Selector, dkim.Domain) + check.Details = &details + } else { + check.Status = api.CheckStatusPass + check.Score = 1.0 + check.Message = "Valid DKIM record found" + check.Severity = api.PtrTo(api.Info) + details := fmt.Sprintf("Selector: %s, Domain: %s", dkim.Selector, dkim.Domain) + check.Details = &details + check.Advice = api.PtrTo("Your DKIM record is properly published") + } + + return check +} + +// generateDMARCCheck creates a check for DMARC records +func (d *DNSAnalyzer) generateDMARCCheck(dmarc *DMARCRecord) api.Check { + check := api.Check{ + Category: api.Dns, + Name: "DMARC Record", + } + + if !dmarc.Valid { + check.Status = api.CheckStatusFail + check.Score = 0.0 + check.Message = dmarc.Error + check.Severity = api.PtrTo(api.High) + check.Advice = api.PtrTo("Configure a DMARC record for your domain to improve deliverability and prevent spoofing") + } else { + check.Status = api.CheckStatusPass + check.Score = 1.0 + check.Message = fmt.Sprintf("Valid DMARC record found with policy: %s", dmarc.Policy) + check.Severity = api.PtrTo(api.Info) + check.Details = &dmarc.Record + + // Provide advice based on policy + switch dmarc.Policy { + case "none": + advice := "DMARC policy is set to 'none' (monitoring only). Consider upgrading to 'quarantine' or 'reject' for better protection" + check.Advice = &advice + case "quarantine": + advice := "DMARC policy is set to 'quarantine'. This provides good protection" + check.Advice = &advice + case "reject": + advice := "DMARC policy is set to 'reject'. This provides the strongest protection" + check.Advice = &advice + default: + advice := "Your DMARC record is properly configured" + check.Advice = &advice + } + } + + return check +} diff --git a/internal/analyzer/dns_test.go b/internal/analyzer/dns_test.go new file mode 100644 index 0000000..fe501d5 --- /dev/null +++ b/internal/analyzer/dns_test.go @@ -0,0 +1,633 @@ +// This file is part of the happyDeliver (R) project. +// Copyright (c) 2025 happyDomain +// Authors: Pierre-Olivier Mercier, et al. +// +// This program is offered under a commercial and under the AGPL license. +// For commercial licensing, contact us at . +// +// For AGPL licensing: +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package analyzer + +import ( + "net/mail" + "strings" + "testing" + "time" + + "git.happydns.org/happyDeliver/internal/api" +) + +func TestNewDNSAnalyzer(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + expectedTimeout time.Duration + }{ + { + name: "Default timeout", + timeout: 0, + expectedTimeout: 10 * time.Second, + }, + { + name: "Custom timeout", + timeout: 5 * time.Second, + expectedTimeout: 5 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + analyzer := NewDNSAnalyzer(tt.timeout) + if analyzer.Timeout != tt.expectedTimeout { + t.Errorf("Timeout = %v, want %v", analyzer.Timeout, tt.expectedTimeout) + } + if analyzer.resolver == nil { + t.Error("Resolver should not be nil") + } + }) + } +} + +func TestExtractDomain(t *testing.T) { + tests := []struct { + name string + fromAddress string + expectedDomain string + }{ + { + name: "Valid email", + fromAddress: "user@example.com", + expectedDomain: "example.com", + }, + { + name: "Email with subdomain", + fromAddress: "user@mail.example.com", + expectedDomain: "mail.example.com", + }, + { + name: "Email with uppercase", + fromAddress: "User@Example.COM", + expectedDomain: "example.com", + }, + { + name: "Invalid email (no @)", + fromAddress: "invalid-email", + expectedDomain: "", + }, + { + name: "Empty email", + fromAddress: "", + expectedDomain: "", + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + email := &EmailMessage{ + Header: make(mail.Header), + } + if tt.fromAddress != "" { + email.From = &mail.Address{ + Address: tt.fromAddress, + } + } + + domain := analyzer.extractDomain(email) + if domain != tt.expectedDomain { + t.Errorf("extractDomain() = %q, want %q", domain, tt.expectedDomain) + } + }) + } +} + +func TestValidateSPF(t *testing.T) { + tests := []struct { + name string + record string + expected bool + }{ + { + name: "Valid SPF with -all", + record: "v=spf1 include:_spf.example.com -all", + expected: true, + }, + { + name: "Valid SPF with ~all", + record: "v=spf1 ip4:192.0.2.0/24 ~all", + expected: true, + }, + { + name: "Valid SPF with +all", + record: "v=spf1 +all", + expected: true, + }, + { + name: "Valid SPF with ?all", + record: "v=spf1 mx ?all", + expected: true, + }, + { + name: "Invalid SPF - no version", + record: "include:_spf.example.com -all", + expected: false, + }, + { + name: "Invalid SPF - no all mechanism", + record: "v=spf1 include:_spf.example.com", + expected: false, + }, + { + name: "Invalid SPF - wrong version", + record: "v=spf2 include:_spf.example.com -all", + expected: false, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := analyzer.validateSPF(tt.record) + if result != tt.expected { + t.Errorf("validateSPF(%q) = %v, want %v", tt.record, result, tt.expected) + } + }) + } +} + +func TestValidateDKIM(t *testing.T) { + tests := []struct { + name string + record string + expected bool + }{ + { + name: "Valid DKIM with version", + record: "v=DKIM1; k=rsa; p=MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQ...", + expected: true, + }, + { + name: "Valid DKIM without version", + record: "k=rsa; p=MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQ...", + expected: true, + }, + { + name: "Invalid DKIM - no public key", + record: "v=DKIM1; k=rsa", + expected: false, + }, + { + name: "Invalid DKIM - wrong version", + record: "v=DKIM2; k=rsa; p=MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQ...", + expected: false, + }, + { + name: "Invalid DKIM - empty", + record: "", + expected: false, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := analyzer.validateDKIM(tt.record) + if result != tt.expected { + t.Errorf("validateDKIM(%q) = %v, want %v", tt.record, result, tt.expected) + } + }) + } +} + +func TestExtractDMARCPolicy(t *testing.T) { + tests := []struct { + name string + record string + expectedPolicy string + }{ + { + name: "Policy none", + record: "v=DMARC1; p=none; rua=mailto:dmarc@example.com", + expectedPolicy: "none", + }, + { + name: "Policy quarantine", + record: "v=DMARC1; p=quarantine; pct=100", + expectedPolicy: "quarantine", + }, + { + name: "Policy reject", + record: "v=DMARC1; p=reject; sp=reject", + expectedPolicy: "reject", + }, + { + name: "No policy", + record: "v=DMARC1", + expectedPolicy: "unknown", + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := analyzer.extractDMARCPolicy(tt.record) + if result != tt.expectedPolicy { + t.Errorf("extractDMARCPolicy(%q) = %q, want %q", tt.record, result, tt.expectedPolicy) + } + }) + } +} + +func TestValidateDMARC(t *testing.T) { + tests := []struct { + name string + record string + expected bool + }{ + { + name: "Valid DMARC", + record: "v=DMARC1; p=quarantine; rua=mailto:dmarc@example.com", + expected: true, + }, + { + name: "Valid DMARC minimal", + record: "v=DMARC1; p=none", + expected: true, + }, + { + name: "Invalid DMARC - no version", + record: "p=quarantine", + expected: false, + }, + { + name: "Invalid DMARC - no policy", + record: "v=DMARC1", + expected: false, + }, + { + name: "Invalid DMARC - wrong version", + record: "v=DMARC2; p=reject", + expected: false, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := analyzer.validateDMARC(tt.record) + if result != tt.expected { + t.Errorf("validateDMARC(%q) = %v, want %v", tt.record, result, tt.expected) + } + }) + } +} + +func TestGenerateMXCheck(t *testing.T) { + tests := []struct { + name string + results *DNSResults + expectedStatus api.CheckStatus + expectedScore float32 + }{ + { + name: "Valid MX records", + results: &DNSResults{ + Domain: "example.com", + MXRecords: []MXRecord{ + {Host: "mail.example.com", Priority: 10, Valid: true}, + {Host: "mail2.example.com", Priority: 20, Valid: true}, + }, + }, + expectedStatus: api.CheckStatusPass, + expectedScore: 1.0, + }, + { + name: "No MX records", + results: &DNSResults{ + Domain: "example.com", + MXRecords: []MXRecord{ + {Valid: false, Error: "No MX records found"}, + }, + }, + expectedStatus: api.CheckStatusFail, + expectedScore: 0.0, + }, + { + name: "MX lookup failed", + results: &DNSResults{ + Domain: "example.com", + MXRecords: []MXRecord{ + {Valid: false, Error: "DNS lookup failed"}, + }, + }, + expectedStatus: api.CheckStatusFail, + expectedScore: 0.0, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + check := analyzer.generateMXCheck(tt.results) + + if check.Status != tt.expectedStatus { + t.Errorf("Status = %v, want %v", check.Status, tt.expectedStatus) + } + if check.Score != tt.expectedScore { + t.Errorf("Score = %v, want %v", check.Score, tt.expectedScore) + } + if check.Category != api.Dns { + t.Errorf("Category = %v, want %v", check.Category, api.Dns) + } + }) + } +} + +func TestGenerateSPFCheck(t *testing.T) { + tests := []struct { + name string + spf *SPFRecord + expectedStatus api.CheckStatus + expectedScore float32 + }{ + { + name: "Valid SPF", + spf: &SPFRecord{ + Record: "v=spf1 include:_spf.example.com -all", + Valid: true, + }, + expectedStatus: api.CheckStatusPass, + expectedScore: 1.0, + }, + { + name: "Invalid SPF", + spf: &SPFRecord{ + Record: "v=spf1 invalid syntax", + Valid: false, + Error: "SPF record appears malformed", + }, + expectedStatus: api.CheckStatusWarn, + expectedScore: 0.5, + }, + { + name: "No SPF record", + spf: &SPFRecord{ + Valid: false, + Error: "No SPF record found", + }, + expectedStatus: api.CheckStatusFail, + expectedScore: 0.0, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + check := analyzer.generateSPFCheck(tt.spf) + + if check.Status != tt.expectedStatus { + t.Errorf("Status = %v, want %v", check.Status, tt.expectedStatus) + } + if check.Score != tt.expectedScore { + t.Errorf("Score = %v, want %v", check.Score, tt.expectedScore) + } + if check.Category != api.Dns { + t.Errorf("Category = %v, want %v", check.Category, api.Dns) + } + }) + } +} + +func TestGenerateDKIMCheck(t *testing.T) { + tests := []struct { + name string + dkim *DKIMRecord + expectedStatus api.CheckStatus + expectedScore float32 + }{ + { + name: "Valid DKIM", + dkim: &DKIMRecord{ + Selector: "default", + Domain: "example.com", + Record: "v=DKIM1; k=rsa; p=MIGfMA0...", + Valid: true, + }, + expectedStatus: api.CheckStatusPass, + expectedScore: 1.0, + }, + { + name: "Invalid DKIM", + dkim: &DKIMRecord{ + Selector: "default", + Domain: "example.com", + Valid: false, + Error: "No DKIM record found", + }, + expectedStatus: api.CheckStatusFail, + expectedScore: 0.0, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + check := analyzer.generateDKIMCheck(tt.dkim) + + if check.Status != tt.expectedStatus { + t.Errorf("Status = %v, want %v", check.Status, tt.expectedStatus) + } + if check.Score != tt.expectedScore { + t.Errorf("Score = %v, want %v", check.Score, tt.expectedScore) + } + if check.Category != api.Dns { + t.Errorf("Category = %v, want %v", check.Category, api.Dns) + } + if !strings.Contains(check.Name, tt.dkim.Selector) { + t.Errorf("Check name should contain selector %s", tt.dkim.Selector) + } + }) + } +} + +func TestGenerateDMARCCheck(t *testing.T) { + tests := []struct { + name string + dmarc *DMARCRecord + expectedStatus api.CheckStatus + expectedScore float32 + }{ + { + name: "Valid DMARC - reject", + dmarc: &DMARCRecord{ + Record: "v=DMARC1; p=reject", + Policy: "reject", + Valid: true, + }, + expectedStatus: api.CheckStatusPass, + expectedScore: 1.0, + }, + { + name: "Valid DMARC - quarantine", + dmarc: &DMARCRecord{ + Record: "v=DMARC1; p=quarantine", + Policy: "quarantine", + Valid: true, + }, + expectedStatus: api.CheckStatusPass, + expectedScore: 1.0, + }, + { + name: "Valid DMARC - none", + dmarc: &DMARCRecord{ + Record: "v=DMARC1; p=none", + Policy: "none", + Valid: true, + }, + expectedStatus: api.CheckStatusPass, + expectedScore: 1.0, + }, + { + name: "No DMARC record", + dmarc: &DMARCRecord{ + Valid: false, + Error: "No DMARC record found", + }, + expectedStatus: api.CheckStatusFail, + expectedScore: 0.0, + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + check := analyzer.generateDMARCCheck(tt.dmarc) + + if check.Status != tt.expectedStatus { + t.Errorf("Status = %v, want %v", check.Status, tt.expectedStatus) + } + if check.Score != tt.expectedScore { + t.Errorf("Score = %v, want %v", check.Score, tt.expectedScore) + } + if check.Category != api.Dns { + t.Errorf("Category = %v, want %v", check.Category, api.Dns) + } + + // Check that advice mentions policy for valid DMARC + if tt.dmarc.Valid && check.Advice != nil { + if tt.dmarc.Policy == "none" && !strings.Contains(*check.Advice, "none") { + t.Error("Advice should mention 'none' policy") + } + } + }) + } +} + +func TestGenerateDNSChecks(t *testing.T) { + tests := []struct { + name string + results *DNSResults + minChecks int + }{ + { + name: "Nil results", + results: nil, + minChecks: 0, + }, + { + name: "Complete results", + results: &DNSResults{ + Domain: "example.com", + MXRecords: []MXRecord{ + {Host: "mail.example.com", Priority: 10, Valid: true}, + }, + SPFRecord: &SPFRecord{ + Record: "v=spf1 include:_spf.example.com -all", + Valid: true, + }, + DKIMRecords: []DKIMRecord{ + { + Selector: "default", + Domain: "example.com", + Valid: true, + }, + }, + DMARCRecord: &DMARCRecord{ + Record: "v=DMARC1; p=quarantine", + Policy: "quarantine", + Valid: true, + }, + }, + minChecks: 4, // MX, SPF, DKIM, DMARC + }, + { + name: "Partial results", + results: &DNSResults{ + Domain: "example.com", + MXRecords: []MXRecord{ + {Host: "mail.example.com", Priority: 10, Valid: true}, + }, + }, + minChecks: 1, // Only MX + }, + } + + analyzer := NewDNSAnalyzer(5 * time.Second) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checks := analyzer.GenerateDNSChecks(tt.results) + + if len(checks) < tt.minChecks { + t.Errorf("Got %d checks, want at least %d", len(checks), tt.minChecks) + } + + // Verify all checks have the DNS category + for _, check := range checks { + if check.Category != api.Dns { + t.Errorf("Check %s has category %v, want %v", check.Name, check.Category, api.Dns) + } + } + }) + } +} + +func TestAnalyzeDNS_NoDomain(t *testing.T) { + analyzer := NewDNSAnalyzer(5 * time.Second) + email := &EmailMessage{ + Header: make(mail.Header), + // No From address + } + + results := analyzer.AnalyzeDNS(email, nil) + + if results == nil { + t.Fatal("Expected results, got nil") + } + + if len(results.Errors) == 0 { + t.Error("Expected error when no domain can be extracted") + } +}