package checker

import (
	"context"
	"fmt"
	"net"
	"net/url"
	"strconv"
	"strings"
)

// parseURI parses a STUN/TURN URI per RFC 7064 / RFC 7065.
//
// Examples:
//
//	stun:turn.example.com
//	stun:turn.example.com:3478
//	stuns:turn.example.com:5349
//	turn:turn.example.com:3478?transport=udp
//	turns:turn.example.com:5349?transport=tcp
//	turn:[2001:db8::1]:3478
//
// The scheme is matched case-insensitively and surrounding whitespace is
// ignored. When no port is given, 5349 is used for secure endpoints and 3478
// otherwise. The returned Endpoint has URI set to the trimmed input and
// Source set to "uri".
//
// An error is returned for an empty URI, a missing or unknown scheme, a
// missing or malformed host, or a port that is not a valid 16-bit number.
func parseURI(raw string) (Endpoint, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return Endpoint{}, fmt.Errorf("empty URI")
	}

	colon := strings.IndexByte(raw, ':')
	if colon < 0 {
		return Endpoint{}, fmt.Errorf("missing scheme in %q", raw)
	}
	scheme := strings.ToLower(raw[:colon])
	rest := raw[colon+1:]

	ep := Endpoint{URI: raw, Source: "uri"}
	switch scheme {
	case "stun":
		// Plain STUN: zero values for IsTURN and Secure already apply.
	case "stuns":
		ep.Secure = true
	case "turn":
		ep.IsTURN = true
	case "turns":
		ep.IsTURN = true
		ep.Secure = true
	default:
		return Endpoint{}, fmt.Errorf("unknown scheme %q", scheme)
	}

	// Everything after the first '?' is the query string.
	hostport, query := rest, ""
	if q := strings.IndexByte(rest, '?'); q >= 0 {
		hostport, query = rest[:q], rest[q+1:]
	}

	host, portStr, err := splitURIHostPort(hostport)
	if err != nil {
		return Endpoint{}, fmt.Errorf("invalid host in %q: %w", raw, err)
	}
	if host == "" {
		return Endpoint{}, fmt.Errorf("missing host in %q", raw)
	}
	ep.Host = host

	// Default transport: UDP for stun/turn, TLS (over TCP) for stuns/turns.
	// Overridable via ?transport=
	if ep.Secure {
		ep.Transport = TransportTLS
	} else {
		ep.Transport = TransportUDP
	}
	if query != "" {
		// A query that fails to parse is ignored and the scheme default is
		// kept, rather than rejecting the whole URI.
		if values, qerr := url.ParseQuery(query); qerr == nil {
			// Unrecognized (or absent) transport values leave the default
			// chosen above in place.
			switch strings.ToLower(values.Get("transport")) {
			case "udp":
				// NOTE(review): on a secure scheme this yields TransportUDP
				// with Secure still true. RFC 7350 maps turns+udp to DTLS;
				// confirm what the prober expects before changing this.
				ep.Transport = TransportUDP
			case "tcp":
				if ep.Secure {
					ep.Transport = TransportTLS
				} else {
					ep.Transport = TransportTCP
				}
			case "tls":
				ep.Transport = TransportTLS
				ep.Secure = true
			case "dtls":
				ep.Transport = TransportDTLS
				ep.Secure = true
			}
		}
	}

	// The default port is chosen after the transport override because
	// ?transport=tls / ?transport=dtls can turn Secure on.
	if portStr == "" {
		ep.Port = 3478
		if ep.Secure {
			ep.Port = 5349
		}
		return ep, nil
	}
	p, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return Endpoint{}, fmt.Errorf("invalid port %q: %w", portStr, err)
	}
	ep.Port = uint16(p)
	return ep, nil
}

// splitURIHostPort splits the host[:port] part of a STUN/TURN URI.
//
// Unlike net.SplitHostPort, a missing port is not an error: port is returned
// empty. Brackets around an IPv6 literal are removed from host whether or not
// a port follows. An unbracketed IPv6 literal is accepted as a host without a
// port. Any other input that net.SplitHostPort rejects is reported as an
// error instead of being treated as a hostname.
func splitURIHostPort(hostport string) (host, port string, err error) {
	// Bracketed IPv6 literal without a port, e.g. "[2001:db8::1]".
	if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
		inner := hostport[1 : len(hostport)-1]
		if strings.ContainsAny(inner, "[]") {
			return "", "", fmt.Errorf("unexpected bracket in address %q", hostport)
		}
		return inner, "", nil
	}

	// Hostname or IPv4 address without a port.
	if !strings.Contains(hostport, ":") {
		return hostport, "", nil
	}

	host, port, err = net.SplitHostPort(hostport)
	if err == nil {
		return host, port, nil
	}

	// Be lenient with a bare IPv6 literal such as "::1": it has colons but
	// no port.
	if net.ParseIP(hostport) != nil {
		return hostport, "", nil
	}
	return "", "", err
}

// discoverEndpoints returns the list of endpoints to probe.
//
// If serverURI is set, it is parsed with parseURI and is the only candidate.
// Otherwise SRV records for the _stun, _stuns, _turn and _turns services are
// looked up under zone (a trailing dot and surrounding whitespace are
// ignored). In both cases the result is filtered to the requested transports;
// an empty transports list disables filtering.
//
// An error is returned when serverURI does not parse, when neither serverURI
// nor zone is provided, when ctx is cancelled or expires during the lookups,
// or when no SRV record is found at all. The returned slice may be empty with
// a nil error if every discovered endpoint is filtered out.
func discoverEndpoints(ctx context.Context, zone, serverURI string, transports []Transport) ([]Endpoint, error) {
	if serverURI != "" {
		ep, err := parseURI(serverURI)
		if err != nil {
			return nil, err
		}
		return filterByTransport([]Endpoint{ep}, transports), nil
	}

	zone = strings.TrimSuffix(strings.TrimSpace(zone), ".")
	if zone == "" {
		return nil, fmt.Errorf("either serverURI or zone is required")
	}

	type srvSpec struct {
		service   string // _stun, _turn, _stuns, _turns
		proto     string // _udp / _tcp
		isTURN    bool
		secure    bool
		transport Transport
	}
	specs := []srvSpec{
		{"_stun", "_udp", false, false, TransportUDP},
		{"_stun", "_tcp", false, false, TransportTCP},
		{"_stuns", "_tcp", false, true, TransportTLS},
		{"_turn", "_udp", true, false, TransportUDP},
		{"_turn", "_tcp", true, false, TransportTCP},
		{"_turns", "_tcp", true, true, TransportTLS},
	}

	resolver := net.DefaultResolver
	var endpoints []Endpoint
	for _, s := range specs {
		// LookupSRV adds the leading underscores itself.
		_, addrs, err := resolver.LookupSRV(ctx,
			strings.TrimPrefix(s.service, "_"),
			strings.TrimPrefix(s.proto, "_"), zone)
		if err != nil {
			// A cancelled or expired context must not be reported as
			// "no records found".
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, fmt.Errorf("looking up SRV records under %s: %w", zone, ctxErr)
			}
			// Any other failure (typically "no such record") is expected
			// for services the zone does not publish. LookupSRV can still
			// return usable records alongside an error, so addrs is
			// processed below regardless.
		}

		scheme := "stun"
		if s.isTURN {
			scheme = "turn"
		}
		if s.secure {
			scheme += "s"
		}

		for _, srv := range addrs {
			host := strings.TrimSuffix(srv.Target, ".")
			if host == "" {
				// A target of "." means the service is explicitly not
				// offered (RFC 2782); there is nothing to probe.
				continue
			}

			uri := scheme + ":" + net.JoinHostPort(host, strconv.Itoa(int(srv.Port)))
			if s.isTURN && s.transport == TransportTCP {
				// Without this the URI would be identical to the UDP one and
				// would parse back as a UDP endpoint.
				uri += "?transport=tcp"
			}

			endpoints = append(endpoints, Endpoint{
				URI:       uri,
				Host:      host,
				Port:      srv.Port,
				Transport: s.transport,
				Secure:    s.secure,
				IsTURN:    s.isTURN,
				Source:    fmt.Sprintf("srv:%s.%s.%s", s.service, s.proto, zone),
			})
		}
	}

	if len(endpoints) == 0 {
		return nil, fmt.Errorf("no STUN/TURN SRV records found under %s", zone)
	}
	return filterByTransport(endpoints, transports), nil
}

// filterByTransport returns the endpoints whose Transport appears in allowed,
// preserving their order. An empty allowed list means "no filtering": eps is
// returned as is.
func filterByTransport(eps []Endpoint, allowed []Transport) []Endpoint {
	if len(allowed) == 0 {
		return eps
	}

	permitted := make(map[Transport]struct{}, len(allowed))
	for _, t := range allowed {
		permitted[t] = struct{}{}
	}

	kept := make([]Endpoint, 0, len(eps))
	for _, ep := range eps {
		if _, ok := permitted[ep.Transport]; ok {
			kept = append(kept, ep)
		}
	}
	return kept
}

// parseTransports converts a comma-separated transport list ("udp,tcp,tls,dtls")
// into Transport values. Tokens are matched case-insensitively with
// surrounding whitespace ignored; unknown tokens are skipped and duplicates
// are reported once, in order of first appearance.
//
// An empty or whitespace-only input selects the default set: UDP, TCP and TLS.
// If the input is non-empty but contains no recognized token, the result is
// nil.
func parseTransports(raw string) []Transport {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return []Transport{TransportUDP, TransportTCP, TransportTLS}
	}

	var out []Transport
	seen := make(map[Transport]bool)
	for _, tok := range strings.Split(raw, ",") {
		var t Transport
		switch strings.ToLower(strings.TrimSpace(tok)) {
		case "udp":
			t = TransportUDP
		case "tcp":
			t = TransportTCP
		case "tls":
			t = TransportTLS
		case "dtls":
			t = TransportDTLS
		default:
			continue
		}
		if !seen[t] {
			seen[t] = true
			out = append(out, t)
		}
	}
	return out
}