package checker import ( "errors" "net" "testing" ) // TestUpgraderFor_DirectTLS verifies that an empty dialect returns a nil // upgrader with ok=true: tlsenum's contract is that nil means "no upgrade // phase", so direct-TLS endpoints must round-trip through this branch // without producing a shim that would call into the registry. func TestUpgraderFor_DirectTLS(t *testing.T) { up, ok := upgraderFor("", "example.test") if !ok { t.Fatalf("expected ok=true for empty dialect") } if up != nil { t.Fatalf("expected nil upgrader for empty dialect, got %T", up) } } func TestUpgraderFor_UnknownDialect(t *testing.T) { up, ok := upgraderFor("totally-not-a-dialect", "example.test") if ok { t.Fatalf("expected ok=false for unknown dialect") } if up != nil { t.Fatalf("expected nil upgrader for unknown dialect, got %T", up) } } // TestUpgraderFor_KnownDialect_ForwardsSNI registers a temporary fake dialect // in the registry, asks upgraderFor for its callback, invokes the callback, // and asserts the registered upgrader received the expected SNI. We can't // reuse a real dialect for this because they all read/write protocol-specific // banners on the connection — the point of this test is the SNI plumbing in // the closure, not the dialect's own behavior. func TestUpgraderFor_KnownDialect_ForwardsSNI(t *testing.T) { const dialect = "test-fake" const wantSNI = "host.example.test" var ( gotSNI string gotConn net.Conn ) wantErr := errors.New("sentinel from fake upgrader") registerStartTLS(dialect, func(c net.Conn, sni string) error { gotConn = c gotSNI = sni return wantErr }) defer delete(starttlsUpgraders, dialect) up, ok := upgraderFor(dialect, wantSNI) if !ok || up == nil { t.Fatalf("expected non-nil upgrader and ok=true, got nil=%v ok=%v", up == nil, ok) } // Use a closed pipe end as a sentinel net.Conn — the registered upgrader // captures it without doing I/O, so a real connection is unnecessary. a, b := net.Pipe() _ = a.Close() _ = b.Close() if err := up(a); !errors.Is(err, wantErr) { t.Fatalf("expected sentinel error to propagate, got %v", err) } if gotSNI != wantSNI { t.Fatalf("registered upgrader received SNI %q, want %q", gotSNI, wantSNI) } if gotConn != a { t.Fatalf("registered upgrader received a different conn than the one passed in") } } // TestUpgraderFor_RealDialects_AllRegistered guards against silently dropping // a dialect from the registry: every protocol referenced by the contract's // STARTTLS values must resolve to a non-nil upgrader. The list mirrors the // dialects implemented in starttls_*.go. func TestUpgraderFor_RealDialects_AllRegistered(t *testing.T) { dialects := []string{"smtp", "submission", "imap", "pop3", "xmpp-client", "xmpp-server", "ldap"} for _, d := range dialects { t.Run(d, func(t *testing.T) { up, ok := upgraderFor(d, "host.example") if !ok || up == nil { t.Fatalf("dialect %q is not registered", d) } }) } }