package dav import ( "context" "errors" "fmt" "net" "net/http" "net/url" "strings" ) // Discover resolves the DAV context URL per RFC 6764. Every leg is recorded // in the result even on failure so the report can pinpoint the broken step. func Discover(ctx context.Context, client *http.Client, kind Kind, domain, explicitURL string) DiscoveryResult { res := DiscoveryResult{} if explicitURL != "" { res.ContextURL = explicitURL res.Source = "explicit" return res } // Always probe /.well-known even if SRV would suffice: it's the #1 // misconfig hotspot and we want to surface it. wellKnown := "https://" + domain + kind.WellKnownPath() res.WellKnownURL = wellKnown ctxURL, chain, code, err := followWellKnown(ctx, client, wellKnown) res.WellKnownCode = code res.WellKnownChain = chain if err != nil { res.WellKnownError = err.Error() } else if ctxURL != "" { res.ContextURL = ctxURL res.Source = "well-known" } discoverSRV(ctx, kind, domain, &res) if res.ContextURL == "" && len(res.SecureSRV) > 0 { target := res.SecureSRV[0] path := res.TXTPath if path == "" { path = "/" } res.ContextURL = srvURL(target, path, true) res.Source = "srv-txt" } if res.ContextURL == "" && res.Error == "" { res.Error = "could not resolve a context URL via /.well-known or SRV" } return res } // followWellKnown follows up to 5 redirects manually so we can record the // chain and the *first* status, since RFC 6764 ยง5 expects a 3xx and a 200 // at this position is the misconfig we want to flag. func followWellKnown(ctx context.Context, client *http.Client, u string) (finalURL string, chain []string, firstCode int, err error) { chain = make([]string, 0, 5) cur := u for i := 0; i < 5; i++ { req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, cur, nil) if reqErr != nil { return "", chain, firstCode, reqErr } // Snapshot disables the client's own redirect-following so we can // record each hop ourselves. c := *client c.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse } resp, doErr := c.Do(req) if doErr != nil { return "", chain, firstCode, doErr } resp.Body.Close() chain = append(chain, fmt.Sprintf("%d %s", resp.StatusCode, cur)) if i == 0 { firstCode = resp.StatusCode } if resp.StatusCode >= 300 && resp.StatusCode < 400 { loc := resp.Header.Get("Location") if loc == "" { return "", chain, firstCode, errors.New("redirect with empty Location header") } next, parseErr := resolveLocation(cur, loc) if parseErr != nil { return "", chain, firstCode, parseErr } cur = next continue } if resp.StatusCode == http.StatusOK { return cur, chain, firstCode, nil } return "", chain, firstCode, fmt.Errorf("unexpected status %d", resp.StatusCode) } return "", chain, firstCode, errors.New("too many redirects") } func resolveLocation(base, loc string) (string, error) { baseURL, err := url.Parse(base) if err != nil { return "", err } locURL, err := url.Parse(loc) if err != nil { return "", err } return baseURL.ResolveReference(locURL).String(), nil } func discoverSRV(ctx context.Context, kind Kind, domain string, res *DiscoveryResult) { resolver := net.DefaultResolver type srvResult struct { records []SRVRecord err error } secureCh := make(chan srvResult, 1) plainCh := make(chan srvResult, 1) go func() { r, err := lookupSRV(ctx, resolver, kind.ServiceName(true), "tcp", domain) secureCh <- srvResult{r, err} }() go func() { r, err := lookupSRV(ctx, resolver, kind.ServiceName(false), "tcp", domain) plainCh <- srvResult{r, err} }() secureRes := <-secureCh if secureRes.err != nil && !isNoSuchHost(secureRes.err) { res.SRVError = secureRes.err.Error() } res.SecureSRV = secureRes.records plainRes := <-plainCh if plainRes.err != nil && !isNoSuchHost(plainRes.err) && res.SRVError == "" { res.SRVError = plainRes.err.Error() } res.PlaintextSRV = plainRes.records var txtName string if len(res.SecureSRV) > 0 { txtName = kind.ServiceName(true) + "._tcp." + trimTrailingDot(res.SecureSRV[0].Target) } else if len(res.PlaintextSRV) > 0 { txtName = kind.ServiceName(false) + "._tcp." + trimTrailingDot(res.PlaintextSRV[0].Target) } if txtName != "" { txts, err := resolver.LookupTXT(ctx, txtName) if err != nil && !isNoSuchHost(err) { res.TXTError = err.Error() } for _, t := range txts { if strings.HasPrefix(t, "path=") { res.TXTPath = strings.TrimPrefix(t, "path=") break } } } } func lookupSRV(ctx context.Context, r *net.Resolver, service, proto, name string) ([]SRVRecord, error) { _, addrs, err := r.LookupSRV(ctx, strings.TrimPrefix(service, "_"), proto, name) if err != nil { return nil, err } out := make([]SRVRecord, 0, len(addrs)) for _, a := range addrs { out = append(out, SRVRecord{ Target: trimTrailingDot(a.Target), Port: a.Port, Priority: a.Priority, Weight: a.Weight, }) } return out, nil } func srvURL(r SRVRecord, path string, secure bool) string { scheme := "https" defaultPort := uint16(443) if !secure { scheme = "http" defaultPort = 80 } host := r.Target if r.Port != defaultPort { host = fmt.Sprintf("%s:%d", r.Target, r.Port) } if !strings.HasPrefix(path, "/") { path = "/" + path } return scheme + "://" + host + path } func trimTrailingDot(s string) string { return strings.TrimSuffix(s, ".") } func isNoSuchHost(err error) bool { var dnsErr *net.DNSError if errors.As(err, &dnsErr) { return dnsErr.IsNotFound } return false }