224 lines
5.9 KiB
Go
224 lines
5.9 KiB
Go
// SPDX-License-Identifier: GPL-2.0-only
|
|
|
|
package collect
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
checker "git.happydns.org/checker-dnsviz/checker"
|
|
sdk "git.happydns.org/checker-sdk-go/checker"
|
|
)
|
|
|
|
func TestIsValidDomainName(t *testing.T) {
|
|
good := []string{
|
|
"example.com",
|
|
"a",
|
|
"sub.domain.example.com",
|
|
"_dmarc.example.com",
|
|
"xn--bcher-kva.de",
|
|
"123.example.com",
|
|
}
|
|
bad := []string{
|
|
"",
|
|
"-bad.example",
|
|
".bad.example",
|
|
"foo bar.example",
|
|
"foo;rm -rf.example",
|
|
"foo$bar",
|
|
"héllo.example",
|
|
strings.Repeat("a", 254),
|
|
}
|
|
for _, s := range good {
|
|
if !isValidDomainName(s) {
|
|
t.Errorf("expected %q valid", s)
|
|
}
|
|
}
|
|
for _, s := range bad {
|
|
if isValidDomainName(s) {
|
|
t.Errorf("expected %q invalid", s)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTruncate(t *testing.T) {
|
|
if truncate("short", 100) != "short" {
|
|
t.Error("short string should pass through")
|
|
}
|
|
got := truncate("abcdef", 3)
|
|
if got != "abc…" {
|
|
t.Errorf("truncate=%q", got)
|
|
}
|
|
}
|
|
|
|
func TestCapLimit_DropsOverflow(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
c := &capLimit{B: &buf, max: 4}
|
|
n, err := c.Write([]byte("abcdef"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if n != 6 {
|
|
t.Errorf("Write should report full length to keep os/exec happy, got %d", n)
|
|
}
|
|
if buf.String() != "abcd" {
|
|
t.Errorf("buffer=%q want abcd", buf.String())
|
|
}
|
|
// A subsequent write while full is silently discarded.
|
|
n, _ = c.Write([]byte("g"))
|
|
if n != 1 || buf.String() != "abcd" {
|
|
t.Errorf("post-cap write: n=%d buf=%q", n, buf.String())
|
|
}
|
|
}
|
|
|
|
func TestCollect_MissingDomain(t *testing.T) {
|
|
c := &Collector{}
|
|
if _, err := c.Collect(context.Background(), sdk.CheckerOptions{}); err == nil {
|
|
t.Fatal("expected error for missing domain_name")
|
|
}
|
|
}
|
|
|
|
func TestCollect_RejectsInjection(t *testing.T) {
|
|
c := &Collector{}
|
|
_, err := c.Collect(context.Background(), sdk.CheckerOptions{"domain_name": "-A"})
|
|
if err == nil || !strings.Contains(err.Error(), "invalid 'domain_name'") {
|
|
t.Errorf("expected invalid domain error, got %v", err)
|
|
}
|
|
_, err = c.Collect(context.Background(), sdk.CheckerOptions{"domain_name": "foo;rm -rf /"})
|
|
if err == nil || !strings.Contains(err.Error(), "invalid 'domain_name'") {
|
|
t.Errorf("expected invalid domain error, got %v", err)
|
|
}
|
|
}
|
|
|
|
// fakeDNSVizScript writes a small POSIX shell script named "dnsviz" into a
// per-test temporary directory and returns its path. This lets Collect run
// end-to-end without the real Python tool. The script dispatches on its
// first argument:
//
//   - probe: prints a fixed JSON document;
//   - grok: prints a canned grok JSON, regardless of stdin, describing
//     "example.com." and "com." with SECURE delegations;
//   - failprobe: writes "boom" to stderr and exits with status 7;
//   - anything else: reports the unexpected arguments on stderr and exits
//     with status 2. A change in how dnsviz is invoked then fails loudly
//     instead of yielding empty output with a zero exit status.
//
// The calling test is skipped on Windows, where no POSIX shell is available.
func fakeDNSVizScript(t *testing.T) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("POSIX shell needed for the fake dnsviz")
	}

	dir := t.TempDir()
	path := filepath.Join(dir, "dnsviz")
	// The heredoc delimiters are quoted ('EOF'), so the shell performs no
	// parameter or command expansion on the canned JSON bodies.
	body := `#!/bin/sh
case "$1" in
probe)
	cat <<'EOF'
{"_meta":{"phase":"probe"}}
EOF
	;;
grok)
	cat <<'EOF'
{
  "example.com.": {
    "status": "NOERROR",
    "delegation": {"status": "SECURE"}
  },
  "com.": {"delegation": {"status": "SECURE"}}
}
EOF
	;;
failprobe)
	echo "boom" 1>&2
	exit 7
	;;
*)
	echo "fake dnsviz: unexpected arguments: $*" 1>&2
	exit 2
	;;
esac
`
	if err := os.WriteFile(path, []byte(body), 0o755); err != nil {
		t.Fatal(err)
	}
	return path
}
|
|
|
|
func TestCollect_EndToEnd(t *testing.T) {
|
|
bin := fakeDNSVizScript(t)
|
|
c := &Collector{Bin: bin, ExtraArgs: ""}
|
|
out, err := c.Collect(context.Background(), sdk.CheckerOptions{
|
|
"domain_name": "example.com.",
|
|
"probeTimeoutSeconds": float64(5),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Collect: %v", err)
|
|
}
|
|
d, ok := out.(*checker.DNSVizData)
|
|
if !ok {
|
|
t.Fatalf("type=%T", out)
|
|
}
|
|
if d.Domain != "example.com" {
|
|
t.Errorf("domain not normalized: %q", d.Domain)
|
|
}
|
|
if len(d.Zones) != 2 {
|
|
t.Errorf("expected 2 zones from grok, got %d", len(d.Zones))
|
|
}
|
|
if d.Zones["example.com."].Status != "SECURE" {
|
|
t.Errorf("status=%q", d.Zones["example.com."].Status)
|
|
}
|
|
if len(d.Raw) == 0 {
|
|
t.Error("raw should be populated")
|
|
}
|
|
}
|
|
|
|
func TestCollect_ProbeFailure(t *testing.T) {
|
|
// A non-existent binary makes probe fail. The error path should bubble up
|
|
// and not be conflated with successful execution.
|
|
c := &Collector{Bin: "/nonexistent/dnsviz/binary"}
|
|
_, err := c.Collect(context.Background(), sdk.CheckerOptions{"domain_name": "example.com"})
|
|
if err == nil || !strings.Contains(err.Error(), "dnsviz probe failed") {
|
|
t.Errorf("expected probe-failed error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCollect_ContextCanceled(t *testing.T) {
|
|
bin := fakeDNSVizScript(t)
|
|
c := &Collector{Bin: bin}
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
_, err := c.Collect(ctx, sdk.CheckerOptions{"domain_name": "example.com"})
|
|
if err == nil {
|
|
t.Fatal("expected error from cancelled context")
|
|
}
|
|
// Either probe or grok should report cancellation. We don't assert on
|
|
// the exact wording: just that it surfaced.
|
|
if !errors.Is(err, context.Canceled) && !strings.Contains(err.Error(), "dnsviz probe failed") &&
|
|
!strings.Contains(err.Error(), "dnsviz grok failed") {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCollect_TimeoutHonoured(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("POSIX shell needed")
|
|
}
|
|
dir := t.TempDir()
|
|
bin := filepath.Join(dir, "dnsviz")
|
|
// Sleep longer than the configured timeout; both probe and grok will
|
|
// stall, so the call should return a timeout-flavoured error.
|
|
// `exec sleep` so the shell process replaces itself with sleep, leaving
|
|
// a single PID for exec.CommandContext to SIGKILL on timeout (otherwise
|
|
// the orphaned sleep keeps the stdout pipe open and Wait blocks).
|
|
body := "#!/bin/sh\nexec sleep 5\n"
|
|
if err := os.WriteFile(bin, []byte(body), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
c := &Collector{Bin: bin}
|
|
start := time.Now()
|
|
_, err := c.Collect(context.Background(), sdk.CheckerOptions{
|
|
"domain_name": "example.com",
|
|
"probeTimeoutSeconds": float64(1),
|
|
})
|
|
elapsed := time.Since(start)
|
|
if err == nil {
|
|
t.Fatal("expected timeout error")
|
|
}
|
|
if elapsed > 4*time.Second {
|
|
t.Errorf("timeout not enforced: elapsed %v", elapsed)
|
|
}
|
|
}
|