checker-http/checker/collect.go

430 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// This file is part of the happyDomain (R) project.
// Copyright (c) 2020-2026 happyDomain
// Authors: Pierre-Olivier Mercier, et al.
package checker
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync/atomic"
"time"
sdk "git.happydns.org/checker-sdk-go/checker"
happydns "git.happydns.org/happyDomain/model"
"golang.org/x/net/html"
)
// verboseLogging is enabled via the CHECKER_HTTP_VERBOSE environment variable;
// when off, per-probe logging is silenced to keep production logs clean.
// The variable is read once, at package initialisation, and any non-empty
// value turns verbose logging on.
var verboseLogging = os.Getenv("CHECKER_HTTP_VERBOSE") != ""
// Collect resolves the Target from CheckerOptions, runs the root
// collector synchronously (its output is the canonical HTTPData), then
// runs every registered Collector in parallel and merges their JSON
// payloads into HTTPData.Extensions under their Key().
//
// Errors from buildTarget or from the root collector are returned to the
// caller. Extension collectors are best-effort: a collector that fails, whose
// output cannot be marshalled, or that misses the shared deadline is simply
// left out of Extensions (and logged when verbose logging is on). The
// returned value is always a *HTTPData on success.
func (p *httpProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	target, err := buildTarget(ctx, opts)
	if err != nil {
		return nil, err
	}

	rootOut, err := rootCollector{}.Collect(ctx, target)
	if err != nil {
		return nil, err
	}
	data, ok := rootOut.(*HTTPData)
	if !ok {
		return nil, fmt.Errorf("rootCollector returned %T, expected *HTTPData", rootOut)
	}

	// Snapshot the registry under its lock so collectors run without holding it.
	registry.mu.Lock()
	collectors := append([]Collector(nil), registry.collectors...)
	registry.mu.Unlock()
	if len(collectors) == 0 {
		return data, nil
	}

	type result struct {
		key string
		raw json.RawMessage
		err error
	}

	// Each collector may issue several probes (one per scheme/IP pair), so we
	// budget it as runProbe does (timeout x (maxRedirects+1)) multiplied by
	// a small factor for the fan-out. The deadline is shared so a single
	// hung collector cannot keep the caller waiting longer than the
	// slowest legitimate collector.
	collectorBudget := target.Timeout * time.Duration(target.MaxRedirects+1) * 4
	cctx, cancel := context.WithTimeout(ctx, collectorBudget)
	defer cancel()

	// The channel is buffered for one result per collector, so a goroutine
	// that finishes after we stopped listening can still send and exit.
	results := make(chan result, len(collectors))
	for _, c := range collectors {
		go func(c Collector) {
			out, err := c.Collect(cctx, target)
			if err != nil {
				results <- result{key: c.Key(), err: err}
				return
			}
			raw, mErr := json.Marshal(out)
			results <- result{key: c.Key(), raw: raw, err: mErr}
		}(c)
	}

	exts := make(map[string]json.RawMessage, len(collectors))
	// record stores one collector outcome; failures are logged and dropped.
	record := func(r result) {
		if r.err != nil {
			if verboseLogging {
				log.Printf("checker-http: collector %q failed: %v", r.key, r.err)
			}
			return
		}
		exts[r.key] = r.raw
	}

	pending := len(collectors)
	for pending > 0 {
		select {
		case r := <-results:
			pending--
			record(r)
		case <-cctx.Done():
			// When both cases are ready, select chooses one at random, so
			// results may already be sitting in the buffer. Take them
			// (without blocking) before giving up on the stragglers.
			for drained := false; pending > 0 && !drained; {
				select {
				case r := <-results:
					pending--
					record(r)
				default:
					drained = true
				}
			}
			if pending > 0 && verboseLogging {
				log.Printf("checker-http: %d collector(s) did not return before deadline (%v); abandoning", pending, cctx.Err())
			}
			pending = 0
		}
	}

	if len(exts) > 0 {
		data.Extensions = exts
	}
	return data, nil
}
// LoadExtension decodes a sub-observation written by a Collector into the
// caller-supplied typed value.
//
// The boolean reports whether an extension was stored under key. It is false
// (without error) when data is nil, when the key is absent, or when the stored
// payload is empty; most rules treat that as "no_data" rather than an error.
// When the payload exists but cannot be decoded into T, the result is
// (nil, true, err) with the JSON error wrapped.
func LoadExtension[T any](data *HTTPData, key string) (*T, bool, error) {
	// A nil observation carries no extensions; report "absent" instead of
	// dereferencing it.
	if data == nil {
		return nil, false, nil
	}
	raw, ok := data.Extensions[key]
	if !ok || len(raw) == 0 {
		return nil, false, nil
	}
	var v T
	if err := json.Unmarshal(raw, &v); err != nil {
		return nil, true, fmt.Errorf("decode extension %q: %w", key, err)
	}
	return &v, true, nil
}
// buildTarget turns the raw checker options into a fully resolved Target, so
// that option parsing and IP discovery happen in exactly one place and every
// Collector works from the same data.
//
// It fails when resolveServer fails, or when neither the server description
// nor the extra discovery step yields any IP address.
func buildTarget(ctx context.Context, opts sdk.CheckerOptions) (Target, error) {
	srv, err := resolveServer(opts)
	if err != nil {
		return Target{}, err
	}

	// Non-positive timeouts and negative redirect limits are replaced by the
	// defaults rather than rejected.
	probeTimeoutMs := sdk.GetIntOption(opts, OptionProbeTimeoutMs, DefaultProbeTimeoutMs)
	if probeTimeoutMs <= 0 {
		probeTimeoutMs = DefaultProbeTimeoutMs
	}
	redirectLimit := sdk.GetIntOption(opts, OptionMaxRedirects, DefaultMaxRedirects)
	if redirectLimit < 0 {
		redirectLimit = DefaultMaxRedirects
	}

	ua := DefaultUserAgent
	if custom, found := sdk.GetOption[string](opts, OptionUserAgent); found && custom != "" {
		ua = custom
	}

	// The origin is the FQDN the service is mounted on. The service's Domain
	// is the name relative to the apex ("@" for the apex itself) and the
	// domain_name option carries the zone apex; trailing dots are dropped
	// from both before they are joined.
	var zoneApex, relName string
	if name, found := sdk.GetOption[string](opts, OptionDomainName); found {
		zoneApex = strings.TrimSuffix(name, ".")
	}
	if svc, found := sdk.GetOption[happydns.ServiceMessage](opts, OptionService); found {
		relName = strings.TrimSuffix(svc.Domain, ".")
	}
	fqdn := sdk.JoinRelative(relName, zoneApex)

	host, pinned := addressesFromServer(srv, fqdn)
	if host == "" {
		host = fqdn
	}

	// The server description only pins one A and one AAAA record. Ask
	// discoverIPs for anything else the host resolves to, handing it the set
	// of addresses we already have, so multi-IP deployments are not
	// under-probed. discoverIPs returns no error: a failed lookup just means
	// the pinned addresses are all we probe.
	known := make(map[string]struct{}, len(pinned)+4)
	for i := range pinned {
		known[pinned[i]] = struct{}{}
	}
	addrs := append(pinned, discoverIPs(ctx, host, known)...)
	if len(addrs) == 0 {
		return Target{}, fmt.Errorf("abstract.Server has no A/AAAA records")
	}

	return Target{
		Host:         host,
		IPs:          addrs,
		Timeout:      time.Duration(probeTimeoutMs) * time.Millisecond,
		MaxRedirects: redirectLimit,
		UserAgent:    ua,
	}, nil
}
// runProbe performs one GET of scheme://host/ with every connection forced to
// ip:port, and returns what it observed as an HTTPProbe.
//
// It never returns a Go error: request-construction and transport failures
// are recorded in the probe's Error field, together with whether a TCP
// connection was established and whichever redirect hops were seen so far.
//
// Parameters:
//   - ctx: parent context; the probe derives its own deadline from it.
//   - host: name used in the request URL and as the TLS ServerName.
//   - ip, port: the address actually dialed, whatever the URL says.
//   - scheme: URL scheme of the request (not validated here).
//   - timeout: limit for each of dial, TLS handshake and response headers;
//     the probe as a whole is bounded by timeout x (maxRedirects+1).
//   - maxRedirects: how many same-host, same-scheme redirects are followed
//     before the probe fails.
//   - ua: value of the User-Agent request header.
//   - parseHTML: when true and the response Content-Type is HTML, the body
//     is scanned for sub-resources.
func runProbe(ctx context.Context, host, ip, scheme string, port uint16, timeout time.Duration, maxRedirects int, ua string, parseHTML bool) HTTPProbe {
	addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
	probe := HTTPProbe{
		Scheme:  scheme,
		Host:    host,
		IP:      ip,
		Port:    port,
		Address: addr,
		IsIPv6:  strings.Contains(ip, ":"),
	}

	dialer := &net.Dialer{Timeout: timeout}

	// tcpConnected is set the moment a dial succeeds, so we can
	// distinguish pure-TCP failures from later TLS/HTTP errors without
	// resorting to error-string matching.
	var tcpConnected atomic.Bool

	// Force every dial to the chosen IP, regardless of what hostname is
	// in the URL; that way we can attribute results to a specific A/AAAA
	// record and bypass local resolver oddities.
	transport := &http.Transport{
		DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
			conn, err := dialer.DialContext(ctx, network, addr)
			if err != nil {
				return nil, err
			}
			tcpConnected.Store(true)
			return conn, nil
		},
		TLSClientConfig: &tls.Config{
			ServerName: host,
			// Deep TLS posture is delegated to checker-tls. We still want
			// HTTPS errors (expired cert, bad chain, ...) to surface as
			// probe errors, so verification stays enabled.
		},
		TLSHandshakeTimeout:   timeout,
		ResponseHeaderTimeout: timeout,
		DisableKeepAlives:     true,
	}
	defer transport.CloseIdleConnections()

	// Bound the whole probe (dial + TLS + headers + body across all
	// redirect hops) by a single per-probe deadline derived from ctx, so
	// a slow target can't pin a worker beyond the parent's lifetime and
	// outer cancellation propagates to in-flight I/O.
	probeBudget := timeout * time.Duration(maxRedirects+1)
	probeCtx, cancel := context.WithTimeout(ctx, probeBudget)
	defer cancel()

	var redirectChain []RedirectStep
	client := &http.Client{
		Transport: transport,
		// No client-level Timeout: probeCtx already bounds the request,
		// and a separate http.Client.Timeout would race with it.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			prev := via[len(via)-1]
			// req.Response is the 3xx response that triggered this hop;
			// it carries the redirecting status code (301/302/307/308...).
			status := 0
			if req.Response != nil {
				status = req.Response.StatusCode
			}
			// The hop is recorded even when it is not followed below.
			redirectChain = append(redirectChain, RedirectStep{
				From:   prev.URL.String(),
				To:     req.URL.String(),
				Status: status,
			})
			// The transport's DialContext is pinned to the original
			// (ip, port) and TLS ServerName is pinned to the original
			// host. Following a redirect that changes host, scheme, or
			// port would silently route the request to the wrong
			// backend. Stop and return the 3xx so the caller can see
			// the Location, but don't follow it on this probe.
			if !strings.EqualFold(req.URL.Host, host) ||
				!strings.EqualFold(req.URL.Scheme, scheme) {
				return http.ErrUseLastResponse
			}
			// via holds the original request plus every hop already
			// followed, so this allows exactly maxRedirects hops.
			if len(via) > maxRedirects {
				return fmt.Errorf("stopped after %d redirects", maxRedirects)
			}
			return nil
		},
	}

	target := &url.URL{Scheme: scheme, Host: host, Path: "/"}
	req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, target.String(), nil)
	if err != nil {
		probe.Error = err.Error()
		return probe
	}
	req.Header.Set("User-Agent", ua)
	req.Header.Set("Accept", "text/html,application/xhtml+xml;q=0.9,*/*;q=0.5")

	start := time.Now()
	resp, err := client.Do(req)
	probe.ElapsedMS = time.Since(start).Milliseconds()
	if err != nil {
		probe.Error = err.Error()
		// The dialer wrapper sets tcpConnected the moment a TCP
		// connection is established, so we can attribute the failure
		// to a post-TCP layer (TLS, HTTP, redirect policy) without
		// any error-string heuristics.
		probe.TCPConnected = tcpConnected.Load()
		probe.RedirectChain = redirectChain
		return probe
	}
	defer resp.Body.Close()

	probe.TCPConnected = true
	probe.StatusCode = resp.StatusCode
	if resp.Request != nil && resp.Request.URL != nil {
		probe.FinalURL = resp.Request.URL.String()
	}

	// Per RFC 7230 §3.2.2, repeated headers (other than Set-Cookie) are
	// semantically equivalent to a single header whose value is the
	// comma-joined list; folding here preserves directives like a second
	// CSP or HSTS header that would otherwise be dropped. Set-Cookie is
	// excluded from the map since cookies are surfaced via resp.Cookies().
	probe.Headers = make(map[string]string, len(resp.Header))
	for k, v := range resp.Header {
		if len(v) == 0 {
			continue
		}
		lk := strings.ToLower(k)
		if lk == "set-cookie" {
			continue
		}
		probe.Headers[lk] = strings.Join(v, ", ")
	}

	// resp.Cookies() silently skips Set-Cookie lines it cannot parse, so its
	// indexes do not necessarily line up with the raw header values. Each
	// parsed cookie carries the line it came from in Raw; use that for the
	// size rule so one malformed line cannot shift every later size.
	for _, c := range resp.Cookies() {
		probe.Cookies = append(probe.Cookies, CookieInfo{
			Name:      c.Name,
			Domain:    c.Domain,
			Path:      c.Path,
			Secure:    c.Secure,
			HttpOnly:  c.HttpOnly,
			SameSite:  sameSiteString(c.SameSite),
			HasExpiry: !c.Expires.IsZero() || c.MaxAge > 0,
			Size:      len(c.Raw),
		})
	}
	probe.RedirectChain = redirectChain

	// Read one extra byte to detect whether we hit the cap. Anything
	// beyond MaxBodyBytes is dropped, but the probe surfaces
	// BodyTruncated so callers know SRI/HTML rules saw a partial view.
	// A body read error is not reported: the probe just carries no body
	// size and no resources.
	body, err := io.ReadAll(io.LimitReader(resp.Body, MaxBodyBytes+1))
	if err == nil {
		if len(body) > MaxBodyBytes {
			body = body[:MaxBodyBytes]
			probe.BodyTruncated = true
		}
		probe.HTMLBytes = len(body)
		if parseHTML && isHTMLContent(probe.Headers["content-type"]) {
			probe.Resources = extractResources(body, host)
		}
	}
	return probe
}
// sameSiteString renders a cookie's SameSite mode as the attribute value
// "Lax", "Strict" or "None". Every other mode, including the default mode and
// an unset value, yields the empty string.
func sameSiteString(s http.SameSite) string {
	if s == http.SameSiteLaxMode {
		return "Lax"
	}
	if s == http.SameSiteStrictMode {
		return "Strict"
	}
	if s == http.SameSiteNoneMode {
		return "None"
	}
	return ""
}
// isHTMLContent reports whether a Content-Type header value denotes an HTML or
// XHTML document. The match is a case-insensitive substring search, so
// parameters such as "; charset=utf-8" do not get in the way.
func isHTMLContent(ct string) bool {
	lowered := strings.ToLower(ct)
	for _, marker := range [...]string{"text/html", "application/xhtml"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// extractResources parses body as HTML and lists the sub-resources it
// references, in document order:
//
//   - every <script> element with a non-empty src attribute;
//   - every <link> element with a non-empty href whose rel value is accepted
//     by relIsAsset (the rel value is copied to the resource's Rel field).
//
// Each entry is built by mkResource, which also records the integrity
// attribute and whether the reference points at a host other than pageHost;
// SRI is only meaningful in that cross-origin case. A body that cannot be
// parsed yields nil.
func extractResources(body []byte, pageHost string) []HTMLResource {
	root, err := html.Parse(bytes.NewReader(body))
	if err != nil {
		return nil
	}

	var found []HTMLResource

	// Pre-order traversal with an explicit stack. Children are pushed last
	// to first so that they are popped, and therefore visited, first to last.
	stack := []*html.Node{root}
	for len(stack) > 0 {
		node := stack[len(stack)-1]
		stack = stack[:len(stack)-1]

		if node.Type == html.ElementNode {
			switch node.Data {
			case "script":
				if src, has := attr(node, "src"); has && src != "" {
					found = append(found, mkResource("script", src, node, pageHost))
				}
			case "link":
				rel, _ := attr(node, "rel")
				if href, has := attr(node, "href"); has && href != "" && relIsAsset(rel) {
					res := mkResource("link", href, node, pageHost)
					res.Rel = rel
					found = append(found, res)
				}
			}
		}

		for child := node.LastChild; child != nil; child = child.PrevSibling {
			stack = append(stack, child)
		}
	}
	return found
}
// relIsAsset reports whether a <link> rel attribute marks the link as a
// fetched asset: a stylesheet, a preload or a modulepreload.
//
// rel is a whitespace-separated list of link types compared without regard
// to case, so each token is matched as a whole. Values that merely contain
// one of the keywords inside a longer token (for example "preloaded") are
// not treated as assets.
func relIsAsset(rel string) bool {
	for _, linkType := range strings.Fields(strings.ToLower(rel)) {
		switch linkType {
		case "stylesheet", "preload", "modulepreload":
			return true
		}
	}
	return false
}
// mkResource builds the HTMLResource for one element: tag and ref are stored
// as given, the element's non-empty integrity attribute (if any) is copied,
// and CrossOrigin is set when ref parses as a URL that names a host other
// than pageHost (compared case-insensitively). References without a host,
// such as relative paths, and references that fail to parse are left as
// same-origin.
func mkResource(tag, ref string, n *html.Node, pageHost string) HTMLResource {
	res := HTMLResource{Tag: tag, URL: ref}

	if sri, present := attr(n, "integrity"); present && sri != "" {
		res.Integrity = sri
	}

	parsed, err := url.Parse(ref)
	if err != nil || parsed.Host == "" {
		return res
	}
	if !strings.EqualFold(parsed.Host, pageHost) {
		res.CrossOrigin = true
	}
	return res
}
// attr looks up an attribute of n by name, ignoring case. It returns the
// value of the first matching attribute and true, or "" and false when the
// element has no such attribute.
func attr(n *html.Node, key string) (string, bool) {
	for i := range n.Attr {
		if candidate := &n.Attr[i]; strings.EqualFold(candidate.Key, key) {
			return candidate.Val, true
		}
	}
	return "", false
}