package tlsenum

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	stdtls "crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net"
	"os"
	"testing"
	"time"

	utls "github.com/refraction-networking/utls"
)

// selfSignedCert mints a fresh ECDSA P-256 key and a self-signed server
// certificate for "test.local" (valid from one hour ago until one hour from
// now). Everything stays in memory; the result is ready to place in a stdlib
// tls.Config for the fake server.
func selfSignedCert() (stdtls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return stdtls.Certificate{}, err
	}

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "test.local"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		DNSNames:     []string{"test.local"},
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	// Self-signed: the template is passed as both subject and issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, priv.Public(), priv)
	if err != nil {
		return stdtls.Certificate{}, err
	}
	privDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return stdtls.Certificate{}, err
	}

	// X509KeyPair wants PEM, so wrap both DER blobs before handing them over.
	toPEM := func(blockType string, der []byte) []byte {
		return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der})
	}
	return stdtls.X509KeyPair(toPEM("CERTIFICATE", certDER), toPEM("EC PRIVATE KEY", privDER))
}

// runFakeStartTLSServer accepts one connection, expects a "STARTTLS\r\n"
// line, replies "OK\r\n", then runs a TLS handshake. It returns once the
// handshake completes (or fails) and the connection is closed.
func runFakeStartTLSServer(ln net.Listener, cert stdtls.Certificate) error { c, err := ln.Accept() if err != nil { return err } defer c.Close() buf := make([]byte, len("STARTTLS\r\n")) if _, err := io.ReadFull(c, buf); err != nil { return err } if string(buf) != "STARTTLS\r\n" { return fmt.Errorf("unexpected pre-tls line: %q", string(buf)) } if _, err := c.Write([]byte("OK\r\n")); err != nil { return err } tc := stdtls.Server(c, &stdtls.Config{ Certificates: []stdtls.Certificate{cert}, MinVersion: stdtls.VersionTLS12, }) defer tc.Close() return tc.Handshake() } // liveTarget returns a host:port to enumerate against, or skips the test if // the environment hasn't opted in. Network tests are gated behind // TLSENUM_LIVE=1 so the unit-test suite stays hermetic. func liveTarget(t *testing.T) (addr, sni string) { t.Helper() if os.Getenv("TLSENUM_LIVE") == "" { t.Skip("set TLSENUM_LIVE=1 to run live enumeration tests") } host := os.Getenv("TLSENUM_HOST") if host == "" { host = "tls-v1-2.badssl.com" } port := os.Getenv("TLSENUM_PORT") if port == "" { port = "1012" } return net.JoinHostPort(host, port), host } func TestProbe_TLS12_AESGCM(t *testing.T) { addr, sni := liveTarget(t) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() r := Probe(ctx, addr, sni, utls.VersionTLS12, 0xC02F /* ECDHE-RSA-AES128-GCM-SHA256 */, ProbeOptions{Timeout: 5 * time.Second}) if !r.Accepted { t.Fatalf("expected ECDHE-RSA-AES128-GCM-SHA256 to be accepted on TLS 1.2 target; got err=%v", r.Err) } if r.NegotiatedVersion != utls.VersionTLS12 { t.Fatalf("negotiated version = %x, want %x", r.NegotiatedVersion, utls.VersionTLS12) } } func TestEnumerate_BasicShape(t *testing.T) { addr, sni := liveTarget(t) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() res, err := Enumerate(ctx, addr, sni, EnumerateOptions{ ProbeTimeout: 5 * time.Second, }) if err != nil { t.Fatalf("Enumerate: %v", err) } if len(res.SupportedVersions) == 0 { 
t.Fatalf("no supported versions discovered") } for v, ciphers := range res.CiphersByVersion { if len(ciphers) == 0 { t.Errorf("version %s listed as supported but no ciphers recorded", VersionName(v)) } t.Logf("%s: %d cipher(s)", VersionName(v), len(ciphers)) } } // TestProbe_UpgraderInvoked uses a tiny in-memory STARTTLS-style server: a // goroutine listens, reads one "STARTTLS\r\n" line, replies "OK\r\n", then // performs a real Go-stdlib TLS handshake. We probe through the matching // Upgrader and assert the handshake succeeds — proving the callback runs in // the right place between dial and ClientHello. func TestProbe_UpgraderInvoked(t *testing.T) { cert, err := selfSignedCert() if err != nil { t.Fatalf("self-signed cert: %v", err) } ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } defer ln.Close() srvDone := make(chan error, 1) go func() { srvDone <- runFakeStartTLSServer(ln, cert) }() upgrader := func(c net.Conn) error { if _, err := c.Write([]byte("STARTTLS\r\n")); err != nil { return err } buf := make([]byte, 16) n, err := c.Read(buf) if err != nil { return err } if got := string(buf[:n]); got != "OK\r\n" { return fmt.Errorf("unexpected reply: %q", got) } return nil } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() r := Probe(ctx, ln.Addr().String(), "test.local", utls.VersionTLS12, 0xC02B, /* ECDHE-ECDSA-AES128-GCM-SHA256 (matches the P-256 cert) */ ProbeOptions{Timeout: 3 * time.Second, Upgrader: upgrader}) if !r.Accepted { t.Fatalf("expected handshake to succeed through upgrader; err=%v", r.Err) } if r.NegotiatedVersion != utls.VersionTLS12 { t.Fatalf("negotiated %#x, want %#x", r.NegotiatedVersion, utls.VersionTLS12) } if err := <-srvDone; err != nil { t.Logf("fake server done with: %v", err) // accept clean close from utls } } func TestProbe_UpgraderError(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } defer 
ln.Close() go func() { c, _ := ln.Accept() if c != nil { c.Close() } }() wantErr := errors.New("plaintext refused starttls") r := Probe(context.Background(), ln.Addr().String(), "x", utls.VersionTLS12, 0xC02F, ProbeOptions{Timeout: 2 * time.Second, Upgrader: func(net.Conn) error { return wantErr }}) if r.Accepted { t.Fatalf("expected probe to fail when upgrader returns error") } if r.Err == nil || !errors.Is(r.Err, wantErr) { t.Fatalf("expected wrapped upgrader error, got %v", r.Err) } } func TestVersionName(t *testing.T) { cases := map[uint16]string{ utls.VersionTLS10: "TLS 1.0", utls.VersionTLS11: "TLS 1.1", utls.VersionTLS12: "TLS 1.2", utls.VersionTLS13: "TLS 1.3", 0x9999: "0x9999", } for v, want := range cases { if got := VersionName(v); got != want { t.Errorf("VersionName(%#x) = %q, want %q", v, got, want) } } }