Harden contract validation, STARTTLS edge cases, and rule output

This commit is contained in:
nemunaire 2026-04-26 16:39:22 +07:00
commit fa212f0fae
9 changed files with 104 additions and 39 deletions

View file

@ -68,7 +68,7 @@ func FetchChain(ctx context.Context, host string, port uint16, starttls string,
tlsConn := tls.Client(conn, &tls.Config{ tlsConn := tls.Client(conn, &tls.Config{
ServerName: host, ServerName: host,
InsecureSkipVerify: true, InsecureSkipVerify: true, // #nosec G402 -- intentional: caller receives the chain even when PKIX rejects it
}) })
if err := tlsConn.HandshakeContext(dialCtx); err != nil { if err := tlsConn.HandshakeContext(dialCtx); err != nil {
return nil, fmt.Errorf("tls handshake: %w", err) return nil, fmt.Errorf("tls handshake: %w", err)

View file

@ -172,7 +172,7 @@ func probe(ctx context.Context, ep contract.TLSEndpoint, timeout time.Duration)
func handshake(conn net.Conn, ep contract.TLSEndpoint, sni string) (*tls.Conn, error) { func handshake(conn net.Conn, ep contract.TLSEndpoint, sni string) (*tls.Conn, error) {
cfg := &tls.Config{ cfg := &tls.Config{
ServerName: sni, ServerName: sni,
InsecureSkipVerify: true, InsecureSkipVerify: true, // #nosec G402 -- intentional: chain verified separately in probe()
} }
if ep.STARTTLS == "" { if ep.STARTTLS == "" {
@ -198,7 +198,7 @@ func handshake(conn net.Conn, ep contract.TLSEndpoint, sni string) (*tls.Conn, e
} }
var ( var (
errStartTLSNotOffered = errors.New("starttls not advertised by server") errStartTLSNotOffered = errors.New("starttls not advertised by server")
errUnsupportedStartTLSProto = errors.New("unsupported starttls protocol") errUnsupportedStartTLSProto = errors.New("unsupported starttls protocol")
) )

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"sort"
"strings"
sdk "git.happydns.org/checker-sdk-go/checker" sdk "git.happydns.org/checker-sdk-go/checker"
) )
@ -81,25 +83,38 @@ func (r *cipherSuiteRule) Evaluate(ctx context.Context, obs sdk.ObservationGette
return []sdk.CheckState{emptyCaseState("tls.cipher_suite.no_endpoints")} return []sdk.CheckState{emptyCaseState("tls.cipher_suite.no_endpoints")}
} }
var out []sdk.CheckState // Collapse per-endpoint cipher suites into a single info state. One
// row per endpoint drowns out actionable rules in the UI on domains
// with many endpoints; an aggregated list is enough for visibility.
suites := map[string]int{}
endpoints := map[string][]string{}
for _, ref := range sortedRefs(data) { for _, ref := range sortedRefs(data) {
p := data.Probes[ref] p := data.Probes[ref]
if p.CipherSuite == "" { if p.CipherSuite == "" {
continue continue
} }
out = append(out, sdk.CheckState{ suites[p.CipherSuite]++
Status: sdk.StatusInfo, endpoints[p.CipherSuite] = append(endpoints[p.CipherSuite], p.Endpoint)
Code: "tls.cipher_suite.negotiated",
Subject: subjectOf(p),
Message: fmt.Sprintf("Cipher suite %s negotiated.", p.CipherSuite),
Meta: metaOf(p),
})
} }
if len(out) == 0 { if len(suites) == 0 {
return []sdk.CheckState{unknownState( return []sdk.CheckState{unknownState(
"tls.cipher_suite.skipped", "tls.cipher_suite.skipped",
"No endpoint completed a TLS handshake.", "No endpoint completed a TLS handshake.",
)} )}
} }
return out names := make([]string, 0, len(suites))
for s := range suites {
names = append(names, s)
}
sort.Strings(names)
parts := make([]string, 0, len(names))
for _, n := range names {
parts = append(parts, fmt.Sprintf("%s (%d)", n, suites[n]))
}
return []sdk.CheckState{{
Status: sdk.StatusInfo,
Code: "tls.cipher_suite.negotiated",
Message: "Negotiated cipher suites: " + strings.Join(parts, ", "),
Meta: map[string]any{"suites": endpoints},
}}
} }

View file

@ -36,6 +36,10 @@ func starttlsIMAP(conn net.Conn, sni string) error {
supportsSTARTTLS = true supportsSTARTTLS = true
} }
if strings.HasPrefix(line, "A001 ") { if strings.HasPrefix(line, "A001 ") {
rest := strings.TrimSpace(line[len("A001 "):])
if !strings.HasPrefix(strings.ToUpper(rest), "OK") {
return fmt.Errorf("CAPABILITY rejected by server: %s", rest)
}
break break
} }
} }

View file

@ -7,6 +7,11 @@ import (
"strings" "strings"
) )
// EHLOHostname is the hostname sent in the SMTP EHLO command during STARTTLS
// negotiation. Override it at startup (e.g. via -ldflags or programmatically)
// to match the identity of the host running the checker.
var EHLOHostname = "checker.localhost"
func init() { func init() {
registerStartTLS("smtp", starttlsSMTP) registerStartTLS("smtp", starttlsSMTP)
registerStartTLS("submission", starttlsSMTP) registerStartTLS("submission", starttlsSMTP)
@ -20,7 +25,7 @@ func starttlsSMTP(conn net.Conn, sni string) error {
return fmt.Errorf("read greeting: %w", err) return fmt.Errorf("read greeting: %w", err)
} }
if _, err := rw.WriteString("EHLO checker.happydomain.org\r\n"); err != nil { if _, err := fmt.Fprintf(rw, "EHLO %s\r\n", EHLOHostname); err != nil {
return fmt.Errorf("write ehlo: %w", err) return fmt.Errorf("write ehlo: %w", err)
} }
if err := rw.Flush(); err != nil { if err := rw.Flush(); err != nil {

View file

@ -31,6 +31,9 @@ func starttlsXMPP(conn net.Conn, sni, ns string) error {
dec := xml.NewDecoder(conn) dec := xml.NewDecoder(conn)
// Read the inbound <stream:stream> opening and its <stream:features>. // Read the inbound <stream:stream> opening and its <stream:features>.
// A peer that opens with <stream:error/> (or anything other than features)
// is not going to advertise STARTTLS: surface that immediately rather
// than spinning on tokens until the deadline fires.
hasStartTLS := false hasStartTLS := false
outer: outer:
for { for {
@ -38,29 +41,38 @@ outer:
if err != nil { if err != nil {
return fmt.Errorf("read stream features: %w", err) return fmt.Errorf("read stream features: %w", err)
} }
if se, ok := tok.(xml.StartElement); ok { se, ok := tok.(xml.StartElement)
if se.Name.Local == "features" { if !ok {
// Scan features children. continue
for { }
t2, err := dec.Token() switch se.Name.Local {
if err != nil { case "stream":
return fmt.Errorf("read features body: %w", err) // Outer <stream:stream> opening. Continue reading children.
continue
case "features":
for {
t2, err := dec.Token()
if err != nil {
return fmt.Errorf("read features body: %w", err)
}
switch ee := t2.(type) {
case xml.StartElement:
if ee.Name.Local == "starttls" {
hasStartTLS = true
} }
switch ee := t2.(type) { if err := dec.Skip(); err != nil {
case xml.StartElement: return fmt.Errorf("skip feature %q: %w", ee.Name.Local, err)
if ee.Name.Local == "starttls" { }
hasStartTLS = true case xml.EndElement:
} if ee.Name.Local == "features" {
if err := dec.Skip(); err != nil { break outer
return fmt.Errorf("skip feature %q: %w", ee.Name.Local, err)
}
case xml.EndElement:
if ee.Name.Local == "features" {
break outer
}
} }
} }
} }
case "error":
return fmt.Errorf("server returned <stream:error/> before features")
default:
return fmt.Errorf("%w: unexpected element %q before features", errStartTLSNotOffered, se.Name.Local)
} }
} }
if !hasStartTLS { if !hasStartTLS {

View file

@ -78,11 +78,11 @@ type TLSProbe struct {
// no certificate. // no certificate.
NoPeerCert bool `json:"no_peer_cert,omitempty"` NoPeerCert bool `json:"no_peer_cert,omitempty"`
HostnameMatch *bool `json:"hostname_match,omitempty"` HostnameMatch *bool `json:"hostname_match,omitempty"`
ChainValid *bool `json:"chain_valid,omitempty"` ChainValid *bool `json:"chain_valid,omitempty"`
ChainVerifyErr string `json:"chain_verify_err,omitempty"` ChainVerifyErr string `json:"chain_verify_err,omitempty"`
NotAfter time.Time `json:"not_after,omitempty"` NotAfter time.Time `json:"not_after,omitempty"`
Issuer string `json:"issuer,omitempty"` Issuer string `json:"issuer,omitempty"`
// IssuerDN is the leaf's issuer as an RFC 2253 DN string, suitable for // IssuerDN is the leaf's issuer as an RFC 2253 DN string, suitable for
// matching the CCADB CAA Identifiers CSV "Subject" column when the AKI // matching the CCADB CAA Identifiers CSV "Subject" column when the AKI
// lookup misses. // lookup misses.

View file

@ -16,6 +16,7 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
sdk "git.happydns.org/checker-sdk-go/checker" sdk "git.happydns.org/checker-sdk-go/checker"
) )
@ -58,10 +59,27 @@ type TLSEndpoint struct {
RequireSTARTTLS bool `json:"require,omitempty"` RequireSTARTTLS bool `json:"require,omitempty"`
} }
// Validate rejects endpoints that cannot be probed: empty Host or zero Port.
// STARTTLS dialect is intentionally not checked here (the checker surfaces
// unsupported dialects at runtime via the tls.starttls_dialect_supported
// rule), and SNI defaults to Host downstream.
func (ep TLSEndpoint) Validate() error {
if strings.TrimSpace(strings.TrimSuffix(ep.Host, ".")) == "" {
return fmt.Errorf("contract: TLSEndpoint.Host is required")
}
if ep.Port == 0 {
return fmt.Errorf("contract: TLSEndpoint.Port must be 1-65535")
}
return nil
}
// NewEntry wraps ep in an sdk.DiscoveryEntry with Type, a deterministic Ref // NewEntry wraps ep in an sdk.DiscoveryEntry with Type, a deterministic Ref
// derived from ep, and a marshaled Payload. The returned entry can be // derived from ep, and a marshaled Payload. The returned entry can be
// returned as-is from a DiscoveryPublisher implementation. // returned as-is from a DiscoveryPublisher implementation.
func NewEntry(ep TLSEndpoint) (sdk.DiscoveryEntry, error) { func NewEntry(ep TLSEndpoint) (sdk.DiscoveryEntry, error) {
if err := ep.Validate(); err != nil {
return sdk.DiscoveryEntry{}, err
}
payload, err := json.Marshal(ep) payload, err := json.Marshal(ep)
if err != nil { if err != nil {
return sdk.DiscoveryEntry{}, fmt.Errorf("contract: marshal TLSEndpoint: %w", err) return sdk.DiscoveryEntry{}, fmt.Errorf("contract: marshal TLSEndpoint: %w", err)
@ -95,7 +113,7 @@ func Ref(ep TLSEndpoint) string {
req = "1" req = "1"
} }
canonical := fmt.Sprintf("%s|%d|%s|%s|%s", ep.Host, ep.Port, sni, ep.STARTTLS, req) canonical := fmt.Sprintf("%s|%d|%s|%s|%s", ep.Host, ep.Port, sni, ep.STARTTLS, req)
sum := sha1.Sum([]byte(canonical)) sum := sha1.Sum([]byte(canonical)) // #nosec G401 G505 -- non-cryptographic stable key; see doc comment above
return hex.EncodeToString(sum[:8]) return hex.EncodeToString(sum[:8])
} }
@ -109,6 +127,9 @@ func ParseEntry(e sdk.DiscoveryEntry) (TLSEndpoint, error) {
if err := json.Unmarshal(e.Payload, &ep); err != nil { if err := json.Unmarshal(e.Payload, &ep); err != nil {
return TLSEndpoint{}, fmt.Errorf("contract: unmarshal TLSEndpoint: %w", err) return TLSEndpoint{}, fmt.Errorf("contract: unmarshal TLSEndpoint: %w", err)
} }
if err := ep.Validate(); err != nil {
return TLSEndpoint{}, err
}
return ep, nil return ep, nil
} }

View file

@ -10,11 +10,19 @@ import (
var Version = "custom-build" var Version = "custom-build"
// EHLOHostname is set via -ldflags to identify this checker instance in SMTP
// EHLO greetings. Falls back to the package default ("checker.localhost") when
// left empty.
var EHLOHostname = ""
var listenAddr = flag.String("listen", ":8080", "HTTP listen address") var listenAddr = flag.String("listen", ":8080", "HTTP listen address")
func main() { func main() {
flag.Parse() flag.Parse()
tls.Version = Version tls.Version = Version
if EHLOHostname != "" {
tls.EHLOHostname = EHLOHostname
}
srv := server.New(tls.Provider()) srv := server.New(tls.Provider())
if err := srv.ListenAndServe(*listenAddr); err != nil { if err := srv.ListenAndServe(*listenAddr); err != nil {