checker-http/checker/provider_test.go

244 lines
6.7 KiB
Go

// This file is part of the happyDomain (R) project.
// Copyright (c) 2020-2026 happyDomain
// Authors: Pierre-Olivier Mercier, et al.
package checker
import (
"context"
"encoding/json"
"net"
"sort"
"testing"
sdk "git.happydns.org/checker-sdk-go/checker"
happydns "git.happydns.org/happyDomain/model"
"git.happydns.org/happyDomain/services/abstract"
"github.com/miekg/dns"
)
// mkServer builds an abstract.Server whose address records are owned by
// name. A non-empty ipv4 populates the A record and a non-empty ipv6
// populates the AAAA record; an empty string leaves that record nil.
// Every record is class IN with a 300-second TTL.
func mkServer(t *testing.T, name string, ipv4, ipv6 string) *abstract.Server {
	t.Helper()
	header := func(rrtype uint16) dns.RR_Header {
		return dns.RR_Header{Name: name, Rrtype: rrtype, Class: dns.ClassINET, Ttl: 300}
	}
	out := new(abstract.Server)
	if ipv4 != "" {
		out.A = &dns.A{Hdr: header(dns.TypeA), A: net.ParseIP(ipv4)}
	}
	if ipv6 != "" {
		out.AAAA = &dns.AAAA{Hdr: header(dns.TypeAAAA), AAAA: net.ParseIP(ipv6)}
	}
	return out
}
// mkServiceMessage JSON-encodes srv and wraps it in a
// happydns.ServiceMessage tagged with the "abstract.Server" service type.
// The calling test is aborted if srv cannot be marshalled.
func mkServiceMessage(t *testing.T, srv *abstract.Server) happydns.ServiceMessage {
	t.Helper()
	payload, err := json.Marshal(srv)
	if err != nil {
		t.Fatal(err)
	}
	var msg happydns.ServiceMessage
	msg.ServiceMeta = happydns.ServiceMeta{Type: "abstract.Server"}
	msg.Service = payload
	return msg
}
// TestProvider_KeyAndDefinition checks that the provider reports the HTTP
// observation key and exposes a checker definition with ID "http". The
// definition must be service-scoped and limited to abstract.Server. It must
// also carry rules, a positive default interval, and every expected user
// option.
func TestProvider_KeyAndDefinition(t *testing.T) {
	p := Provider()
	if key := p.Key(); key != ObservationKeyHTTP {
		t.Errorf("Key() = %q, want %q", key, ObservationKeyHTTP)
	}

	dp, ok := p.(sdk.CheckerDefinitionProvider)
	if !ok {
		t.Fatal("provider does not implement CheckerDefinitionProvider")
	}
	def := dp.Definition()
	if def == nil || def.ID != "http" {
		t.Fatalf("unexpected definition: %+v", def)
	}

	// Availability: applies to services, and only to abstract.Server.
	if !def.Availability.ApplyToService {
		t.Errorf("ApplyToService should be true")
	}
	limits := def.Availability.LimitToServices
	if len(limits) != 1 || limits[0] != "abstract.Server" {
		t.Errorf("LimitToServices: %+v", limits)
	}

	if len(def.Rules) == 0 {
		t.Error("Rules slice empty")
	}
	if def.Interval == nil || def.Interval.Default <= 0 {
		t.Error("Interval default not set")
	}

	// Every expected user-tunable option must be declared.
	declared := make(map[string]struct{}, len(def.Options.UserOpts))
	for _, opt := range def.Options.UserOpts {
		declared[opt.Id] = struct{}{}
	}
	required := []string{
		OptionProbeTimeoutMs,
		OptionMaxRedirects,
		OptionUserAgent,
		OptionRequireHTTPS,
		OptionRequireHSTS,
		OptionMinHSTSMaxAgeDays,
		OptionRequireCSP,
	}
	for _, id := range required {
		if _, found := declared[id]; !found {
			t.Errorf("UserOpts missing %q", id)
		}
	}
}
func TestResolveServer_Success(t *testing.T) {
srv := mkServer(t, "example.test.", "203.0.113.10", "")
opts := sdk.CheckerOptions{OptionService: mkServiceMessage(t, srv)}
got, err := resolveServer(opts)
if err != nil {
t.Fatalf("resolveServer: %v", err)
}
if got.A == nil || got.A.A.String() != "203.0.113.10" {
t.Errorf("unexpected server: %+v", got)
}
}
func TestResolveServer_MissingService(t *testing.T) {
if _, err := resolveServer(sdk.CheckerOptions{}); err == nil {
t.Fatal("expected error for missing service option")
}
}
func TestResolveServer_WrongType(t *testing.T) {
msg := happydns.ServiceMessage{
ServiceMeta: happydns.ServiceMeta{Type: "abstract.NotServer"},
Service: json.RawMessage(`{}`),
}
if _, err := resolveServer(sdk.CheckerOptions{OptionService: msg}); err == nil {
t.Fatal("expected error for wrong service type")
}
}
// TestResolveServer_BadJSON checks that a correctly-typed service message
// with an undecodable payload is reported as an error.
func TestResolveServer_BadJSON(t *testing.T) {
	var msg happydns.ServiceMessage
	msg.ServiceMeta = happydns.ServiceMeta{Type: "abstract.Server"}
	msg.Service = json.RawMessage(`{not json`)

	_, err := resolveServer(sdk.CheckerOptions{OptionService: msg})
	if err == nil {
		t.Fatal("expected error for malformed service payload")
	}
}
// TestDiscoverIPs_DedupesAndMerges checks that discoverIPs merges the A and
// AAAA answers for a host and omits addresses already present in the seen
// set.
func TestDiscoverIPs_DedupesAndMerges(t *testing.T) {
	// Stand up a loopback DNS server that returns multiple A and AAAA
	// records, then point a custom resolver at it.
	mux := dns.NewServeMux()
	mux.HandleFunc("multi.test.", func(w dns.ResponseWriter, r *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(r)
		switch r.Question[0].Qtype {
		case dns.TypeA:
			for _, ip := range []string{"203.0.113.10", "203.0.113.11"} {
				m.Answer = append(m.Answer, &dns.A{
					Hdr: dns.RR_Header{Name: "multi.test.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
					A:   net.ParseIP(ip),
				})
			}
		case dns.TypeAAAA:
			m.Answer = append(m.Answer, &dns.AAAA{
				Hdr:  dns.RR_Header{Name: "multi.test.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
				AAAA: net.ParseIP("2001:db8::a"),
			})
		}
		// A failed write surfaces to the resolver as a lookup failure,
		// which the assertions below already catch.
		_ = w.WriteMsg(m)
	})

	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}

	// Block until the server is actually serving. Otherwise Shutdown may
	// run before ActivateAndServe has marked the server as started. In
	// that case Shutdown returns an error and the serving goroutine (and
	// its socket) leaks.
	started := make(chan struct{})
	serveErr := make(chan error, 1)
	srv := &dns.Server{
		PacketConn:        pc,
		Handler:           mux,
		NotifyStartedFunc: func() { close(started) },
	}
	go func() { serveErr <- srv.ActivateAndServe() }()
	select {
	case <-started:
	case err := <-serveErr:
		_ = pc.Close()
		t.Fatalf("dns server failed to start: %v", err)
	}
	t.Cleanup(func() {
		if err := srv.Shutdown(); err != nil {
			t.Errorf("dns server shutdown: %v", err)
		}
	})

	prev := resolver
	resolver = &net.Resolver{
		PreferGo: true,
		// Send every lookup to the loopback server; the address the
		// resolver asks for is deliberately ignored.
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "udp", pc.LocalAddr().String())
		},
	}
	// Registered after the shutdown cleanup, so it runs first (LIFO).
	t.Cleanup(func() { resolver = prev })

	seen := map[string]struct{}{"203.0.113.10": {}} // already pinned
	got := discoverIPs(context.Background(), "multi.test", seen)
	sort.Strings(got)
	want := []string{"2001:db8::a", "203.0.113.11"}
	if len(got) != len(want) {
		t.Fatalf("got %v, want %v", got, want)
	}
	for i := range got {
		if got[i] != want[i] {
			t.Errorf("ip[%d] = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestDiscoverIPs_LookupFailureIsNonFatal checks that discoverIPs returns
// nil, rather than failing, when every DNS exchange errors out.
func TestDiscoverIPs_LookupFailureIsNonFatal(t *testing.T) {
	saved := resolver
	defer func() { resolver = saved }()

	// A dialer that never connects makes every query fail immediately.
	failingDial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return nil, net.ErrClosed
	}
	resolver = &net.Resolver{PreferGo: true, Dial: failingDial}

	got := discoverIPs(context.Background(), "nope.test", map[string]struct{}{})
	if got != nil {
		t.Errorf("expected nil on resolver failure, got %v", got)
	}
}
// TestAddressesFromServer is a table test for addressesFromServer. It
// checks the host name it derives (the record owner without its trailing
// dot) and the ordered address list, IPv4 before IPv6, for v4-only,
// v6-only, dual-stack and empty servers.
func TestAddressesFromServer(t *testing.T) {
	tests := []struct {
		label  string
		server *abstract.Server
		host   string
		ips    []string
	}{
		{
			label:  "v4 only",
			server: mkServer(t, "example.test.", "203.0.113.1", ""),
			host:   "example.test",
			ips:    []string{"203.0.113.1"},
		},
		{
			label:  "v6 only",
			server: mkServer(t, "v6.example.test.", "", "2001:db8::1"),
			host:   "v6.example.test",
			ips:    []string{"2001:db8::1"},
		},
		{
			label:  "dual stack",
			server: mkServer(t, "dual.example.test.", "203.0.113.2", "2001:db8::2"),
			host:   "dual.example.test",
			ips:    []string{"203.0.113.2", "2001:db8::2"},
		},
		{
			label:  "empty",
			server: &abstract.Server{},
			host:   "",
			ips:    nil,
		},
	}

	for _, tc := range tests {
		t.Run(tc.label, func(t *testing.T) {
			gotHost, gotIPs := addressesFromServer(tc.server, "")
			if gotHost != tc.host {
				t.Errorf("host = %q, want %q", gotHost, tc.host)
			}
			if len(gotIPs) != len(tc.ips) {
				t.Fatalf("ips = %+v, want %+v", gotIPs, tc.ips)
			}
			for i := range gotIPs {
				if gotIPs[i] != tc.ips[i] {
					t.Errorf("ip[%d] = %q, want %q", i, gotIPs[i], tc.ips[i])
				}
			}
		})
	}
}