checker: harden HTTP client, cap response size, drop dead legacy rule
This commit is contained in:
parent
181c5961f1
commit
2710dfb459
9 changed files with 407 additions and 118 deletions
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -13,6 +14,37 @@ import (
|
|||
sdk "git.happydns.org/checker-sdk-go/checker"
|
||||
)
|
||||
|
||||
// maxResponseBytes caps the body size we'll read from the Zonemaster API.
|
||||
// Real result payloads are tens to a few hundred KB; 8 MiB is generous head-
|
||||
// room and still bounded so a misbehaving or hostile endpoint can't exhaust
|
||||
// memory.
|
||||
const maxResponseBytes = 8 << 20
|
||||
|
||||
// maxCollectDuration caps the total time spent collecting (start + poll +
|
||||
// fetch). The caller's context still wins if it has a tighter deadline.
|
||||
const maxCollectDuration = 15 * time.Minute
|
||||
|
||||
// pollInterval is how often we ask the Zonemaster API for test progress.
|
||||
const pollInterval = 2 * time.Second
|
||||
|
||||
// zmHTTPClient is the HTTP client used for all Zonemaster API calls. It has
|
||||
// per-phase timeouts so a stalling endpoint can never hang us indefinitely
|
||||
// even if the caller passes a context without a deadline.
|
||||
var zmHTTPClient = &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
|
||||
domainName, ok := opts["domainName"].(string)
|
||||
if !ok || domainName == "" {
|
||||
|
|
@ -36,6 +68,11 @@ func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOption
|
|||
profile = prof
|
||||
}
|
||||
|
||||
// Cap the total collection time even when the caller's context has no
|
||||
// deadline. The caller's deadline still wins if it's tighter.
|
||||
ctx, cancel := context.WithTimeout(ctx, maxCollectDuration)
|
||||
defer cancel()
|
||||
|
||||
// Step 1: start the test.
|
||||
startResult, err := zmCallJSONRPC(ctx, apiURL, "start_domain_test", zmStartTestParams{
|
||||
Domain: domainName,
|
||||
|
|
@ -56,9 +93,10 @@ func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOption
|
|||
}
|
||||
|
||||
// Step 2: poll for completion.
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
poll:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
@ -75,12 +113,11 @@ func (p *zonemasterProvider) Collect(ctx context.Context, opts sdk.CheckerOption
|
|||
}
|
||||
|
||||
if progress >= 100 {
|
||||
goto testComplete
|
||||
break poll
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testComplete:
|
||||
// Step 3: fetch results.
|
||||
rawResults, err := zmCallJSONRPC(ctx, apiURL, "get_test_results", zmGetResultsParams{
|
||||
ID: testID,
|
||||
|
|
@ -117,19 +154,37 @@ func zmCallJSONRPC(ctx context.Context, apiURL, method string, params any) (json
|
|||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := zmHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Cap the body we'll ever read so a misbehaving endpoint can't exhaust
|
||||
// memory. +1 lets us detect that the cap was hit.
|
||||
limited := io.LimitReader(resp.Body, maxResponseBytes+1)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
b, readErr := io.ReadAll(limited)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("API returned status %d (failed to read body: %v)", resp.StatusCode, readErr)
|
||||
}
|
||||
if len(b) > maxResponseBytes {
|
||||
b = b[:maxResponseBytes]
|
||||
}
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
|
||||
body, readErr := io.ReadAll(limited)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if len(body) > maxResponseBytes {
|
||||
return nil, fmt.Errorf("API response exceeds %d bytes", maxResponseBytes)
|
||||
}
|
||||
|
||||
var rpcResp zmJSONRPCResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil {
|
||||
if err := json.Unmarshal(body, &rpcResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue