// This file is part of checker-dnsviz. // // checker-dnsviz is free software: you can redistribute it and/or modify it // under the terms of the GNU General Public License as published by the Free // Software Foundation, version 2. // // checker-dnsviz is distributed in the hope that it will be useful, but // WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY // or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for // more details. // // You should have received a copy of the GNU General Public License along with // checker-dnsviz. If not, see . // // SPDX-License-Identifier: GPL-2.0-only // Package collect contains the DNSViz subprocess invocation. It is kept // separate from the checker package so that the checker package (pure analysis // logic) can be imported under MIT terms without pulling in GPL-covered code. package collect import ( "bytes" "context" "fmt" "os/exec" "strings" "time" checker "git.happydns.org/checker-dnsviz/checker" sdk "git.happydns.org/checker-sdk-go/checker" ) const ( defaultProbeTimeout = 120 * time.Second maxDNSVizOutputBytes = 16 << 20 // 16 MiB ) // Collector holds the runtime configuration for DNSViz invocations. type Collector struct { // Bin is the path to the dnsviz CLI. Defaults to "dnsviz". Bin string // ExtraArgs is a whitespace-separated list of extra arguments appended to // `dnsviz probe`. Defaults to "-A". ExtraArgs string } // Collect runs `dnsviz probe | dnsviz grok` against the domain named in opts // and returns the structured analysis as a *checker.DNSVizData. func (c *Collector) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) { domain, _ := sdk.GetOption[string](opts, "domain_name") domain = strings.TrimSpace(strings.TrimSuffix(domain, ".")) if domain == "" { return nil, fmt.Errorf("missing 'domain_name' option") } if !isValidDomainName(domain) { return nil, fmt.Errorf("invalid 'domain_name' option") } timeout := defaultProbeTimeout if n := sdk.GetIntOption(opts, "probeTimeoutSeconds", 0); n > 0 { timeout = time.Duration(n) * time.Second } bin := strings.TrimSpace(c.Bin) if bin == "" { bin = "dnsviz" } extraArgs := c.ExtraArgs if extraArgs == "" { extraArgs = "-A" } probeCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() probeOut, probeErr, err := runProbe(probeCtx, bin, domain, extraArgs) if err != nil { return nil, fmt.Errorf("dnsviz probe failed: %w (stderr: %s)", err, truncate(probeErr, 4096)) } grokCtx, cancelGrok := context.WithTimeout(ctx, timeout) defer cancelGrok() grokOut, grokErr, err := runGrok(grokCtx, bin, probeOut) if err != nil { return nil, fmt.Errorf("dnsviz grok failed: %w (stderr: %s)", err, truncate(grokErr, 4096)) } zones, order, err := checker.ParseGrokOutput(grokOut) if err != nil { return nil, fmt.Errorf("decoding dnsviz grok output: %w", err) } return &checker.DNSVizData{ Domain: domain, Zones: zones, Order: order, Raw: grokOut, ProbeStderr: probeErr, GrokStderr: grokErr, }, nil } func runProbe(ctx context.Context, bin, domain, extraArgs string) ([]byte, string, error) { args := []string{"probe"} args = append(args, strings.Fields(extraArgs)...) args = append(args, "--", domain) cmd := exec.CommandContext(ctx, bin, args...) var stdout, stderr bytes.Buffer cmd.Stdout = &capLimit{B: &stdout, max: maxDNSVizOutputBytes} cmd.Stderr = &capLimit{B: &stderr, max: maxDNSVizOutputBytes} if err := cmd.Run(); err != nil { return stdout.Bytes(), stderr.String(), err } return stdout.Bytes(), stderr.String(), nil } func runGrok(ctx context.Context, bin string, probeJSON []byte) ([]byte, string, error) { cmd := exec.CommandContext(ctx, bin, "grok") cmd.Stdin = bytes.NewReader(probeJSON) var stdout, stderr bytes.Buffer cmd.Stdout = &capLimit{B: &stdout, max: maxDNSVizOutputBytes} cmd.Stderr = &capLimit{B: &stderr, max: maxDNSVizOutputBytes} if err := cmd.Run(); err != nil { return stdout.Bytes(), stderr.String(), err } return stdout.Bytes(), stderr.String(), nil } func isValidDomainName(s string) bool { if s == "" || len(s) > 253 || s[0] == '-' || s[0] == '.' { return false } for i := 0; i < len(s); i++ { c := s[i] switch { case c >= 'a' && c <= 'z': case c >= 'A' && c <= 'Z': case c >= '0' && c <= '9': case c == '-' || c == '.' || c == '_': default: return false } } return true } func truncate(s string, n int) string { if len(s) <= n { return s } return s[:n] + "…" } type capLimit struct { B *bytes.Buffer max int } func (c *capLimit) Write(p []byte) (int, error) { remaining := c.max - c.B.Len() if remaining <= 0 { return len(p), nil } if len(p) > remaining { c.B.Write(p[:remaining]) return len(p), nil } return c.B.Write(p) }