430 lines
13 KiB
Go
430 lines
13 KiB
Go
// This file is part of the happyDomain (R) project.
|
||
// Copyright (c) 2020-2026 happyDomain
|
||
// Authors: Pierre-Olivier Mercier, et al.
|
||
|
||
package checker
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/tls"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"strings"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
sdk "git.happydns.org/checker-sdk-go/checker"
|
||
happydns "git.happydns.org/happyDomain/model"
|
||
"golang.org/x/net/html"
|
||
)
|
||
|
||
// verboseLogging is enabled via the CHECKER_HTTP_VERBOSE environment variable;
// when off, per-probe logging is silenced to keep production logs clean.
// The variable is read once at package initialization, so changing the
// environment after startup has no effect.
var verboseLogging = os.Getenv("CHECKER_HTTP_VERBOSE") != ""
|
||
|
||
// Collect resolves the Target from CheckerOptions, runs the root
// collector synchronously (its output is the canonical HTTPData), then
// runs every registered Collector in parallel and merges their JSON
// payloads into HTTPData.Extensions under their Key().
//
// A failing or deadline-exceeding extension collector never fails the
// whole check: its payload is simply absent from Extensions (logged
// only when verboseLogging is on).
func (p *httpProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	target, err := buildTarget(ctx, opts)
	if err != nil {
		return nil, err
	}
	// The root collector runs first and synchronously: its result is the
	// *HTTPData envelope that every extension payload attaches to.
	rootOut, err := rootCollector{}.Collect(ctx, target)
	if err != nil {
		return nil, err
	}
	data, ok := rootOut.(*HTTPData)
	if !ok {
		return nil, fmt.Errorf("rootCollector returned %T, expected *HTTPData", rootOut)
	}

	// Snapshot the registry under its lock so concurrent registrations
	// cannot mutate the slice we iterate below.
	registry.mu.Lock()
	collectors := append([]Collector(nil), registry.collectors...)
	registry.mu.Unlock()
	if len(collectors) == 0 {
		return data, nil
	}

	// result pairs a collector's Key() with its marshalled payload (or
	// the collection/marshalling error).
	type result struct {
		key string
		raw json.RawMessage
		err error
	}
	// Each collector may issue several probes (one per scheme × IP), so we
	// budget it as runProbe does (timeout × (maxRedirects+1)) multiplied by
	// a small factor for the fan-out. The deadline is shared so a single
	// hung collector cannot keep the caller waiting longer than the
	// slowest legitimate collector.
	collectorBudget := target.Timeout * time.Duration(target.MaxRedirects+1) * 4
	cctx, cancel := context.WithTimeout(ctx, collectorBudget)
	defer cancel()

	// The channel is buffered to len(collectors) so goroutines that
	// finish after the deadline path below gives up can still send and
	// terminate — no goroutine leak, their results are just dropped.
	results := make(chan result, len(collectors))
	for _, c := range collectors {
		go func(c Collector) {
			out, err := c.Collect(cctx, target)
			if err != nil {
				results <- result{key: c.Key(), err: err}
				return
			}
			raw, mErr := json.Marshal(out)
			results <- result{key: c.Key(), raw: raw, err: mErr}
		}(c)
	}

	exts := make(map[string]json.RawMessage, len(collectors))
	pending := len(collectors)
	for pending > 0 {
		select {
		case r := <-results:
			pending--
			if r.err != nil {
				if verboseLogging {
					log.Printf("checker-http: collector %q failed: %v", r.key, r.err)
				}
				continue
			}
			exts[r.key] = r.raw
		case <-cctx.Done():
			// Deadline hit: abandon whatever is still pending. cancel()
			// (deferred above) tells the stragglers' contexts to stop.
			if verboseLogging {
				log.Printf("checker-http: %d collector(s) did not return before deadline (%v); abandoning", pending, cctx.Err())
			}
			pending = 0
		}
	}
	// Leave Extensions nil (rather than an empty map) when nothing was
	// collected, so the JSON output omits the field entirely.
	if len(exts) > 0 {
		data.Extensions = exts
	}
	return data, nil
}
|
||
|
||
// LoadExtension decodes a sub-observation written by a Collector into the
|
||
// caller-supplied typed value. Returns false (without error) when the
|
||
// extension is absent — most rules treat that as "no_data" rather than
|
||
// an error.
|
||
func LoadExtension[T any](data *HTTPData, key string) (*T, bool, error) {
|
||
raw, ok := data.Extensions[key]
|
||
if !ok || len(raw) == 0 {
|
||
return nil, false, nil
|
||
}
|
||
var v T
|
||
if err := json.Unmarshal(raw, &v); err != nil {
|
||
return nil, true, fmt.Errorf("decode extension %q: %w", key, err)
|
||
}
|
||
return &v, true, nil
|
||
}
|
||
|
||
// buildTarget centralises option parsing and IP discovery so every
|
||
// Collector receives a fully resolved Target.
|
||
func buildTarget(ctx context.Context, opts sdk.CheckerOptions) (Target, error) {
|
||
server, err := resolveServer(opts)
|
||
if err != nil {
|
||
return Target{}, err
|
||
}
|
||
|
||
timeoutMs := sdk.GetIntOption(opts, OptionProbeTimeoutMs, DefaultProbeTimeoutMs)
|
||
if timeoutMs <= 0 {
|
||
timeoutMs = DefaultProbeTimeoutMs
|
||
}
|
||
maxRedirects := sdk.GetIntOption(opts, OptionMaxRedirects, DefaultMaxRedirects)
|
||
if maxRedirects < 0 {
|
||
maxRedirects = DefaultMaxRedirects
|
||
}
|
||
userAgent := DefaultUserAgent
|
||
if v, ok := sdk.GetOption[string](opts, OptionUserAgent); ok && v != "" {
|
||
userAgent = v
|
||
}
|
||
|
||
// Origin is the FQDN where the service is mounted: svc.Domain holds the
|
||
// subdomain (relative to apex; "@" for apex), and the domain_name
|
||
// autofill carries the zone apex.
|
||
apex := ""
|
||
if v, ok := sdk.GetOption[string](opts, OptionDomainName); ok {
|
||
apex = strings.TrimSuffix(v, ".")
|
||
}
|
||
subdomain := ""
|
||
if svc, ok := sdk.GetOption[happydns.ServiceMessage](opts, OptionService); ok {
|
||
subdomain = strings.TrimSuffix(svc.Domain, ".")
|
||
}
|
||
origin := sdk.JoinRelative(subdomain, apex)
|
||
host, ips := addressesFromServer(server, origin)
|
||
if host == "" {
|
||
host = origin
|
||
}
|
||
// abstract.Server only pins one A and one AAAA. Resolve the host to
|
||
// pick up any additional records the authoritative DNS exposes, so
|
||
// multi-IP deployments aren't silently under-probed. Failures are
|
||
// non-fatal; the pinned IPs remain.
|
||
seen := make(map[string]struct{}, len(ips)+4)
|
||
for _, ip := range ips {
|
||
seen[ip] = struct{}{}
|
||
}
|
||
ips = append(ips, discoverIPs(ctx, host, seen)...)
|
||
if len(ips) == 0 {
|
||
return Target{}, fmt.Errorf("abstract.Server has no A/AAAA records")
|
||
}
|
||
|
||
return Target{
|
||
Host: host,
|
||
IPs: ips,
|
||
Timeout: time.Duration(timeoutMs) * time.Millisecond,
|
||
MaxRedirects: maxRedirects,
|
||
UserAgent: userAgent,
|
||
}, nil
|
||
}
|
||
|
||
// runProbe issues a single GET against scheme://host/ while pinning every
// TCP connection to the given (ip, port), and returns an HTTPProbe
// describing the outcome: status, folded headers, cookie metadata,
// redirect chain, timing, and (optionally) HTML resource references.
// It never returns an error — failures are recorded in probe.Error so
// the caller can report per-IP partial results.
func runProbe(ctx context.Context, host, ip, scheme string, port uint16, timeout time.Duration, maxRedirects int, ua string, parseHTML bool) HTTPProbe {
	addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
	probe := HTTPProbe{
		Scheme:  scheme,
		Host:    host,
		IP:      ip,
		Port:    port,
		Address: addr,
		IsIPv6:  strings.Contains(ip, ":"),
	}

	dialer := &net.Dialer{Timeout: timeout}
	// tcpConnected is set the moment a dial succeeds, so we can
	// distinguish pure-TCP failures from later TLS/HTTP errors without
	// resorting to error-string matching. Atomic because the transport
	// may dial from another goroutine than the one reading the flag.
	var tcpConnected atomic.Bool
	// Force every dial to the chosen IP, regardless of what hostname is
	// in the URL; that way we can attribute results to a specific A/AAAA
	// record and bypass local resolver oddities.
	transport := &http.Transport{
		DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
			conn, err := dialer.DialContext(ctx, network, addr)
			if err != nil {
				return nil, err
			}
			tcpConnected.Store(true)
			return conn, nil
		},
		TLSClientConfig: &tls.Config{
			ServerName: host,
			// Deep TLS posture is delegated to checker-tls. We still want
			// HTTPS errors (expired cert, bad chain, ...) to surface as
			// probe errors, so verification stays enabled.
		},
		TLSHandshakeTimeout:   timeout,
		ResponseHeaderTimeout: timeout,
		DisableKeepAlives:     true,
	}
	defer transport.CloseIdleConnections()

	// Bound the whole probe (dial + TLS + headers + body across all
	// redirect hops) by a single per-probe deadline derived from ctx, so
	// a slow target can't pin a worker beyond the parent's lifetime and
	// outer cancellation propagates to in-flight I/O.
	probeBudget := timeout * time.Duration(maxRedirects+1)
	probeCtx, cancel := context.WithTimeout(ctx, probeBudget)
	defer cancel()

	var redirectChain []RedirectStep
	client := &http.Client{
		Transport: transport,
		// No client-level Timeout: probeCtx already bounds the request,
		// and a separate http.Client.Timeout would race with it.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			prev := via[len(via)-1]
			// req.Response is the 3xx response that triggered this hop;
			// it carries the redirecting status code (301/302/307/308…).
			status := 0
			if req.Response != nil {
				status = req.Response.StatusCode
			}
			redirectChain = append(redirectChain, RedirectStep{
				From:   prev.URL.String(),
				To:     req.URL.String(),
				Status: status,
			})
			// The transport's DialContext is pinned to the original
			// (ip, port) and TLS ServerName is pinned to the original
			// host. Following a redirect that changes host, scheme, or
			// port would silently route the request to the wrong
			// backend. Stop and return the 3xx so the caller can see
			// the Location, but don't follow it on this probe.
			// (A port change shows up in req.URL.Host as "host:port",
			// so the Host comparison covers it.)
			if !strings.EqualFold(req.URL.Host, host) ||
				!strings.EqualFold(req.URL.Scheme, scheme) {
				return http.ErrUseLastResponse
			}
			if len(via) > maxRedirects {
				return fmt.Errorf("stopped after %d redirects", maxRedirects)
			}
			return nil
		},
	}

	target := &url.URL{Scheme: scheme, Host: host, Path: "/"}
	req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, target.String(), nil)
	if err != nil {
		probe.Error = err.Error()
		return probe
	}
	req.Header.Set("User-Agent", ua)
	req.Header.Set("Accept", "text/html,application/xhtml+xml;q=0.9,*/*;q=0.5")

	start := time.Now()
	resp, err := client.Do(req)
	probe.ElapsedMS = time.Since(start).Milliseconds()
	if err != nil {
		probe.Error = err.Error()
		// The dialer wrapper sets tcpConnected the moment a TCP
		// connection is established, so we can attribute the failure
		// to a post-TCP layer (TLS, HTTP, redirect policy) without
		// any error-string heuristics.
		probe.TCPConnected = tcpConnected.Load()
		probe.RedirectChain = redirectChain
		return probe
	}
	defer resp.Body.Close()

	probe.TCPConnected = true
	probe.StatusCode = resp.StatusCode
	if resp.Request != nil && resp.Request.URL != nil {
		probe.FinalURL = resp.Request.URL.String()
	}

	// Per RFC 7230 §3.2.2, repeated headers (other than Set-Cookie) are
	// semantically equivalent to a single header whose value is the
	// comma-joined list; folding here preserves directives like a second
	// CSP or HSTS header that would otherwise be dropped. Set-Cookie is
	// excluded from the map since cookies are surfaced via resp.Cookies().
	probe.Headers = make(map[string]string, len(resp.Header))
	for k, v := range resp.Header {
		if len(v) == 0 {
			continue
		}
		lk := strings.ToLower(k)
		if lk == "set-cookie" {
			continue
		}
		probe.Headers[lk] = strings.Join(v, ", ")
	}

	// resp.Cookies() and resp.Header.Values("Set-Cookie") yield entries
	// in the same order, so we can pair them positionally to recover the
	// raw byte length of each Set-Cookie line for the size rule.
	rawSetCookies := resp.Header.Values("Set-Cookie")
	for i, c := range resp.Cookies() {
		ci := CookieInfo{
			Name:      c.Name,
			Domain:    c.Domain,
			Path:      c.Path,
			Secure:    c.Secure,
			HttpOnly:  c.HttpOnly,
			SameSite:  sameSiteString(c.SameSite),
			HasExpiry: !c.Expires.IsZero() || c.MaxAge > 0,
		}
		if i < len(rawSetCookies) {
			ci.Size = len(rawSetCookies[i])
		}
		probe.Cookies = append(probe.Cookies, ci)
	}
	probe.RedirectChain = redirectChain

	// Read one extra byte to detect whether we hit the cap. Anything
	// beyond MaxBodyBytes is dropped, but the probe surfaces
	// BodyTruncated so callers know SRI/HTML rules saw a partial view.
	body, err := io.ReadAll(io.LimitReader(resp.Body, MaxBodyBytes+1))
	if err == nil {
		if len(body) > MaxBodyBytes {
			body = body[:MaxBodyBytes]
			probe.BodyTruncated = true
		}
		probe.HTMLBytes = len(body)
		if parseHTML && isHTMLContent(probe.Headers["content-type"]) {
			probe.Resources = extractResources(body, host)
		}
	}
	// Body read errors are deliberately tolerated: the probe still
	// carries status, headers, cookies and redirects; HTMLBytes stays 0.

	return probe
}
|
||
|
||
// sameSiteString renders an http.SameSite mode as the label stored in
// CookieInfo; the default/unset mode (and any unknown value) yields "".
func sameSiteString(s http.SameSite) string {
	if s == http.SameSiteLaxMode {
		return "Lax"
	}
	if s == http.SameSiteStrictMode {
		return "Strict"
	}
	if s == http.SameSiteNoneMode {
		return "None"
	}
	return ""
}
|
||
|
||
// isHTMLContent reports whether a Content-Type header value denotes an
// HTML or XHTML document, matching case-insensitively anywhere in the
// value (so charset parameters don't interfere).
func isHTMLContent(ct string) bool {
	normalized := strings.ToLower(ct)
	for _, marker := range []string{"text/html", "application/xhtml"} {
		if strings.Contains(normalized, marker) {
			return true
		}
	}
	return false
}
|
||
|
||
// extractResources walks the HTML body and collects <script src=...>,
|
||
// <link href=... rel="stylesheet"|"preload"...> and inline-eligible <img>
|
||
// references, with a flag for whether the resource is cross-origin
|
||
// (different host than the page); SRI is only meaningful in that case.
|
||
func extractResources(body []byte, pageHost string) []HTMLResource {
|
||
doc, err := html.Parse(bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
var out []HTMLResource
|
||
var walk func(*html.Node)
|
||
walk = func(n *html.Node) {
|
||
if n.Type == html.ElementNode {
|
||
switch n.Data {
|
||
case "script":
|
||
if src, ok := attr(n, "src"); ok && src != "" {
|
||
out = append(out, mkResource("script", src, n, pageHost))
|
||
}
|
||
case "link":
|
||
rel, _ := attr(n, "rel")
|
||
if href, ok := attr(n, "href"); ok && href != "" && relIsAsset(rel) {
|
||
r := mkResource("link", href, n, pageHost)
|
||
r.Rel = rel
|
||
out = append(out, r)
|
||
}
|
||
}
|
||
}
|
||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||
walk(c)
|
||
}
|
||
}
|
||
walk(doc)
|
||
return out
|
||
}
|
||
|
||
// relIsAsset reports whether a <link> rel attribute marks an asset
// reference worth inspecting (stylesheet, preload, modulepreload),
// matched case-insensitively as a substring so space-separated rel
// lists like "alternate stylesheet" still qualify.
func relIsAsset(rel string) bool {
	lowered := strings.ToLower(rel)
	switch {
	case strings.Contains(lowered, "stylesheet"):
		return true
	case strings.Contains(lowered, "preload"):
		return true
	case strings.Contains(lowered, "modulepreload"):
		// Unreachable in practice ("modulepreload" contains "preload"),
		// kept for parity with the documented rel list.
		return true
	default:
		return false
	}
}
|
||
|
||
func mkResource(tag, ref string, n *html.Node, pageHost string) HTMLResource {
|
||
r := HTMLResource{Tag: tag, URL: ref}
|
||
if integ, ok := attr(n, "integrity"); ok && integ != "" {
|
||
r.Integrity = integ
|
||
}
|
||
if u, err := url.Parse(ref); err == nil && u.Host != "" && !strings.EqualFold(u.Host, pageHost) {
|
||
r.CrossOrigin = true
|
||
}
|
||
return r
|
||
}
|
||
|
||
func attr(n *html.Node, key string) (string, bool) {
|
||
for _, a := range n.Attr {
|
||
if strings.EqualFold(a.Key, key) {
|
||
return a.Val, true
|
||
}
|
||
}
|
||
return "", false
|
||
}
|