package checker

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"time"

	"github.com/pion/dtls/v3"
	"github.com/pion/turn/v4"
)

// dialedConn wraps the network conn used to talk to a STUN/TURN server,
// always exposing a PacketConn (turn/stun talk in datagrams). For TCP, TLS
// and DTLS the net.Conn is wrapped with turn.NewSTUNConn, which frames STUN
// messages on top of the connection per RFC 5389 §7.2.2.
type dialedConn struct {
	pc            net.PacketConn
	underlying    net.Conn             // nil for UDP, non-nil for TCP/TLS/DTLS
	tlsState      *tls.ConnectionState // set only for TLS
	dtlsState     *dtls.State          // set only for DTLS
	remoteAddr    net.Addr
	closeUnderlay func() error // closes underlying after pc; nil for UDP
}

// Close closes pc and then, if set, runs closeUnderlay, returning the first
// error encountered.
//
// pion/turn's STUNConn.Close already closes the net.Conn it wraps, so the
// subsequent closeUnderlay call normally hits an already-closed conn
// (net.ErrClosed for TCP/TLS, dtls.ErrConnClosed for DTLS). Those errors are
// ignored so a clean shutdown is not reported as a failure.
func (d *dialedConn) Close() error {
	var err error
	if d.pc != nil {
		err = d.pc.Close()
	}
	if d.closeUnderlay != nil {
		if e := d.closeUnderlay(); e != nil && err == nil &&
			!errors.Is(e, net.ErrClosed) && !errors.Is(e, dtls.ErrConnClosed) {
			err = e
		}
	}
	return err
}

// dial establishes the appropriate L4 (and, for TLS/DTLS, secure) connection
// to ep and returns it as a dialedConn.
//
// timeout bounds the whole dial: one deadline derived from ctx covers the TCP
// connect and the TLS or DTLS handshake together (it is not reset per step).
// NOTE(review): net.ResolveUDPAddr does not observe ctx or timeout, so DNS
// resolution for UDP/DTLS is not bounded by timeout.
func dial(ctx context.Context, ep Endpoint, timeout time.Duration) (*dialedConn, error) {
	addr := net.JoinHostPort(ep.Host, fmt.Sprintf("%d", ep.Port))
	dctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	switch ep.Transport {
	case TransportUDP:
		raddr, err := net.ResolveUDPAddr("udp", addr)
		if err != nil {
			return nil, fmt.Errorf("resolve udp %s: %w", addr, err)
		}
		// Bind the unspecified address (nil) instead of "0.0.0.0" so that, on
		// dual-stack hosts, the socket can also send to IPv6 servers. This
		// matches the DTLS case below.
		conn, err := net.ListenUDP("udp", nil)
		if err != nil {
			return nil, fmt.Errorf("listen udp: %w", err)
		}
		return &dialedConn{pc: conn, remoteAddr: raddr}, nil

	case TransportTCP:
		var d net.Dialer
		c, err := d.DialContext(dctx, "tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("dial tcp %s: %w", addr, err)
		}
		return &dialedConn{
			pc:            turn.NewSTUNConn(c),
			underlying:    c,
			remoteAddr:    c.RemoteAddr(),
			closeUnderlay: c.Close,
		}, nil

	case TransportTLS:
		var d net.Dialer
		raw, err := d.DialContext(dctx, "tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("dial tcp %s: %w", addr, err)
		}
		tlsConn := tls.Client(raw, &tls.Config{ServerName: ep.Host, MinVersion: tls.VersionTLS12})
		if err := tlsConn.HandshakeContext(dctx); err != nil {
			raw.Close() // best-effort cleanup; the handshake error is what matters
			return nil, fmt.Errorf("tls handshake %s: %w", addr, err)
		}
		state := tlsConn.ConnectionState()
		return &dialedConn{
			pc:            turn.NewSTUNConn(tlsConn),
			underlying:    tlsConn,
			tlsState:      &state,
			remoteAddr:    tlsConn.RemoteAddr(),
			closeUnderlay: tlsConn.Close,
		}, nil

	case TransportDTLS:
		raddr, err := net.ResolveUDPAddr("udp", addr)
		if err != nil {
			return nil, fmt.Errorf("resolve udp %s: %w", addr, err)
		}
		udpConn, err := net.ListenUDP("udp", nil)
		if err != nil {
			return nil, fmt.Errorf("listen udp: %w", err)
		}
		dconn, err := dtls.Client(udpConn, raddr, &dtls.Config{
			ServerName: ep.Host,
		})
		if err != nil {
			udpConn.Close()
			return nil, fmt.Errorf("dtls setup %s: %w", addr, err)
		}
		if err := dconn.HandshakeContext(dctx); err != nil {
			dconn.Close()
			udpConn.Close()
			return nil, fmt.Errorf("dtls handshake %s: %w", addr, err)
		}
		// The ok flag is deliberately ignored: if pion cannot report the
		// state, dtlsState is left as the zero State rather than failing an
		// otherwise successful handshake.
		// NOTE(review): consider surfacing !ok as an error instead.
		state, _ := dconn.ConnectionState()
		return &dialedConn{
			pc:            turn.NewSTUNConn(dconn),
			underlying:    dconn,
			dtlsState:     &state,
			remoteAddr:    raddr,
			closeUnderlay: dconn.Close,
		}, nil

	default:
		return nil, fmt.Errorf("unknown transport %v", ep.Transport)
	}
}