package checker import ( "context" "crypto/tls" "errors" "fmt" "net" "strconv" "time" "github.com/pion/dtls/v3" "github.com/pion/turn/v4" ) // dialedConn wraps the network conn used to talk to a STUN/TURN server, // always exposing a PacketConn (turn/stun talk in datagrams). For // stream transports (TCP/TLS) we wrap with turn.NewSTUNConn which frames // STUN messages on top of the byte stream per RFC 5389 ยง7.2.2. type dialedConn struct { pc net.PacketConn underlying net.Conn // non-nil for TCP/TLS; nil for UDP and DTLS tlsState *tls.ConnectionState dtlsState *dtls.State remoteAddr net.Addr } func (d *dialedConn) Close() error { var err error if d.pc != nil { err = d.pc.Close() } if d.underlying != nil { if e := d.underlying.Close(); e != nil && err == nil { err = e } } return err } // dtlsPacketConn adapts *dtls.Conn (net.Conn) to net.PacketConn. // DTLS frames messages at the record level; no additional length-prefix // framing (as turn.NewSTUNConn adds for TCP) is needed or correct here. type dtlsPacketConn struct { conn *dtls.Conn raddr net.Addr } func (d *dtlsPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { n, err := d.conn.Read(b) return n, d.raddr, err } func (d *dtlsPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) { return d.conn.Write(b) } func (d *dtlsPacketConn) Close() error { return d.conn.Close() } func (d *dtlsPacketConn) LocalAddr() net.Addr { return d.conn.LocalAddr() } func (d *dtlsPacketConn) SetDeadline(t time.Time) error { return d.conn.SetDeadline(t) } func (d *dtlsPacketConn) SetReadDeadline(t time.Time) error { return d.conn.SetReadDeadline(t) } func (d *dtlsPacketConn) SetWriteDeadline(t time.Time) error { return d.conn.SetWriteDeadline(t) } // dial establishes the appropriate L4(/secure) connection to ep. // timeout is applied per dial step (TCP connect, TLS handshake, DTLS handshake). func dial(ctx context.Context, ep Endpoint, timeout time.Duration) (*dialedConn, error) { addr := net.JoinHostPort(ep.Host, strconv.Itoa(int(ep.Port))) dctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() switch ep.Transport { case TransportUDP: raddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, fmt.Errorf("resolve udp %s: %w", addr, err) } // Use the dual-stack wildcard ("") so the kernel can pick an IPv6 // source when the resolved server address is IPv6. conn, err := net.ListenPacket("udp", ":0") if err != nil { return nil, fmt.Errorf("listen udp: %w", err) } return &dialedConn{pc: conn, remoteAddr: raddr}, nil case TransportTCP: var d net.Dialer c, err := d.DialContext(dctx, "tcp", addr) if err != nil { return nil, fmt.Errorf("dial tcp %s: %w", addr, err) } return &dialedConn{ pc: turn.NewSTUNConn(c), underlying: c, remoteAddr: c.RemoteAddr(), }, nil case TransportTLS: var d net.Dialer raw, err := d.DialContext(dctx, "tcp", addr) if err != nil { return nil, fmt.Errorf("dial tcp %s: %w", addr, err) } tlsConn := tls.Client(raw, &tls.Config{ServerName: ep.Host, MinVersion: tls.VersionTLS12}) if err := tlsConn.HandshakeContext(dctx); err != nil { raw.Close() return nil, fmt.Errorf("tls handshake %s: %w", addr, err) } state := tlsConn.ConnectionState() return &dialedConn{ pc: turn.NewSTUNConn(tlsConn), underlying: tlsConn, tlsState: &state, remoteAddr: tlsConn.RemoteAddr(), }, nil case TransportDTLS: raddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, fmt.Errorf("resolve udp %s: %w", addr, err) } udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, fmt.Errorf("listen udp: %w", err) } dconn, err := dtls.Client(udpConn, raddr, &dtls.Config{ ServerName: ep.Host, }) if err != nil { udpConn.Close() return nil, fmt.Errorf("dtls setup %s: %w", addr, err) } if err := dconn.HandshakeContext(dctx); err != nil { dconn.Close() udpConn.Close() return nil, fmt.Errorf("dtls handshake %s: %w", addr, err) } state, _ := dconn.ConnectionState() return &dialedConn{ pc: &dtlsPacketConn{conn: dconn, raddr: raddr}, dtlsState: &state, remoteAddr: raddr, }, nil default: return nil, errors.New("unknown transport") } }