332 lines
9.7 KiB
Go
332 lines
9.7 KiB
Go
package checker
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
asn1 "github.com/jcmturner/gofork/encoding/asn1"
|
|
"github.com/jcmturner/gokrb5/v8/iana/errorcode"
|
|
"github.com/jcmturner/gokrb5/v8/iana/etypeID"
|
|
"github.com/jcmturner/gokrb5/v8/iana/nametype"
|
|
"github.com/jcmturner/gokrb5/v8/iana/patype"
|
|
"github.com/jcmturner/gokrb5/v8/messages"
|
|
"github.com/jcmturner/gokrb5/v8/types"
|
|
)
|
|
|
|
// buildKRBError constructs a marshaled KRB-ERROR with the given code and
|
|
// optional EData payload.
|
|
func buildKRBError(t *testing.T, realm string, code int32, edata []byte) []byte {
|
|
t.Helper()
|
|
sname := types.NewPrincipalName(nametype.KRB_NT_SRV_INST, "krbtgt/"+realm)
|
|
k := messages.NewKRBError(sname, realm, code, "")
|
|
k.STime = time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)
|
|
k.Susec = 0
|
|
k.EData = edata
|
|
raw, err := k.Marshal()
|
|
if err != nil {
|
|
t.Fatalf("marshal KRBError: %v", err)
|
|
}
|
|
return raw
|
|
}
|
|
|
|
// buildETypeInfo2EData marshals a PADataSequence containing one
|
|
// PA_ETYPE_INFO2 entry per supplied (etype, salt) pair.
|
|
func buildETypeInfo2EData(t *testing.T, entries []types.ETypeInfo2Entry) []byte {
|
|
t.Helper()
|
|
value, err := asn1.Marshal(types.ETypeInfo2(entries))
|
|
if err != nil {
|
|
t.Fatalf("marshal ETypeInfo2: %v", err)
|
|
}
|
|
pas := types.PADataSequence{
|
|
{PADataType: patype.PA_ETYPE_INFO2, PADataValue: value},
|
|
{PADataType: patype.PA_PK_AS_REQ, PADataValue: []byte{0x00}},
|
|
}
|
|
raw, err := asn1.Marshal(pas)
|
|
if err != nil {
|
|
t.Fatalf("marshal PADataSequence: %v", err)
|
|
}
|
|
return raw
|
|
}
|
|
|
|
func TestParseASResponse_KRBErrorPreauthRequiredWithEData(t *testing.T) {
|
|
edata := buildETypeInfo2EData(t, []types.ETypeInfo2Entry{
|
|
{EType: etypeID.AES256_CTS_HMAC_SHA1_96, Salt: "EXAMPLE.COMuser"},
|
|
{EType: etypeID.RC4_HMAC, Salt: ""},
|
|
})
|
|
raw := buildKRBError(t, "EXAMPLE.COM", errorcode.KDC_ERR_PREAUTH_REQUIRED, edata)
|
|
|
|
var out ASProbeResult
|
|
parseASResponse(raw, &out)
|
|
|
|
if out.Error != "" {
|
|
t.Fatalf("unexpected parse error: %s", out.Error)
|
|
}
|
|
if out.ErrorCode != errorcode.KDC_ERR_PREAUTH_REQUIRED {
|
|
t.Errorf("ErrorCode = %d, want KDC_ERR_PREAUTH_REQUIRED", out.ErrorCode)
|
|
}
|
|
if !out.PreauthReq {
|
|
t.Error("PreauthReq should be true for KDC_ERR_PREAUTH_REQUIRED")
|
|
}
|
|
if out.ServerRealm != "EXAMPLE.COM" {
|
|
t.Errorf("ServerRealm = %q, want EXAMPLE.COM", out.ServerRealm)
|
|
}
|
|
if out.ServerTime.IsZero() {
|
|
t.Error("ServerTime should be populated from STime")
|
|
}
|
|
if !out.PKINITOffered {
|
|
t.Error("PKINITOffered should be true (PA_PK_AS_REQ present)")
|
|
}
|
|
if len(out.Enctypes) != 2 {
|
|
t.Fatalf("Enctypes len = %d, want 2", len(out.Enctypes))
|
|
}
|
|
|
|
var sawAES, sawRC4 bool
|
|
for _, e := range out.Enctypes {
|
|
switch e.ID {
|
|
case etypeID.AES256_CTS_HMAC_SHA1_96:
|
|
sawAES = true
|
|
if e.Weak {
|
|
t.Error("AES256 should not be flagged weak")
|
|
}
|
|
if e.Source != "etype-info2" {
|
|
t.Errorf("AES256 Source = %q, want etype-info2", e.Source)
|
|
}
|
|
case etypeID.RC4_HMAC:
|
|
sawRC4 = true
|
|
if !e.Weak {
|
|
t.Error("RC4_HMAC should be flagged weak")
|
|
}
|
|
}
|
|
}
|
|
if !sawAES || !sawRC4 {
|
|
t.Errorf("missing expected enctypes (sawAES=%v sawRC4=%v)", sawAES, sawRC4)
|
|
}
|
|
}
|
|
|
|
func TestParseASResponse_KRBErrorPrincipalUnknownNoEData(t *testing.T) {
|
|
raw := buildKRBError(t, "EXAMPLE.COM", errorcode.KDC_ERR_C_PRINCIPAL_UNKNOWN, nil)
|
|
|
|
var out ASProbeResult
|
|
parseASResponse(raw, &out)
|
|
|
|
if out.Error != "" {
|
|
t.Fatalf("unexpected parse error: %s", out.Error)
|
|
}
|
|
if out.ErrorCode != errorcode.KDC_ERR_C_PRINCIPAL_UNKNOWN {
|
|
t.Errorf("ErrorCode = %d, want KDC_ERR_C_PRINCIPAL_UNKNOWN", out.ErrorCode)
|
|
}
|
|
if out.PreauthReq {
|
|
t.Error("PreauthReq should be false")
|
|
}
|
|
if len(out.Enctypes) != 0 {
|
|
t.Errorf("Enctypes should be empty, got %d", len(out.Enctypes))
|
|
}
|
|
}
|
|
|
|
func TestParseASResponse_GarbageBytes(t *testing.T) {
|
|
var out ASProbeResult
|
|
parseASResponse([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, &out)
|
|
if out.Error == "" {
|
|
t.Fatal("expected an Error string for unparsable bytes")
|
|
}
|
|
if !strings.Contains(out.Error, "deadbeefcafe") {
|
|
t.Errorf("Error should include hex prefix of payload, got %q", out.Error)
|
|
}
|
|
}
|
|
|
|
func TestExtractEData_ETypeInfoFallback(t *testing.T) {
|
|
// PA_ETYPE_INFO (legacy) only. Salt is octet-string here.
|
|
value, err := asn1.Marshal(types.ETypeInfo{
|
|
{EType: etypeID.AES128_CTS_HMAC_SHA1_96, Salt: []byte("salty")},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal ETypeInfo: %v", err)
|
|
}
|
|
edata, err := asn1.Marshal(types.PADataSequence{
|
|
{PADataType: patype.PA_ETYPE_INFO, PADataValue: value},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("marshal PADataSequence: %v", err)
|
|
}
|
|
|
|
enctypes, pkinit := extractEData(edata)
|
|
if pkinit {
|
|
t.Error("PKINIT should not be reported when no PA_PK_AS_REQ is present")
|
|
}
|
|
if len(enctypes) != 1 {
|
|
t.Fatalf("got %d enctypes, want 1", len(enctypes))
|
|
}
|
|
if enctypes[0].Source != "etype-info" {
|
|
t.Errorf("Source = %q, want etype-info", enctypes[0].Source)
|
|
}
|
|
if enctypes[0].Salt != "salty" {
|
|
t.Errorf("Salt = %q, want salty", enctypes[0].Salt)
|
|
}
|
|
}
|
|
|
|
func TestExtractEData_ETypeInfo2WinsOverInfo(t *testing.T) {
|
|
// Both PA_ETYPE_INFO2 and PA_ETYPE_INFO advertise the same enctype.
|
|
// The legacy info should be skipped (de-duplicated).
|
|
v2, _ := asn1.Marshal(types.ETypeInfo2{
|
|
{EType: etypeID.AES256_CTS_HMAC_SHA1_96, Salt: "fromInfo2"},
|
|
})
|
|
v1, _ := asn1.Marshal(types.ETypeInfo{
|
|
{EType: etypeID.AES256_CTS_HMAC_SHA1_96, Salt: []byte("fromInfo")},
|
|
})
|
|
edata, _ := asn1.Marshal(types.PADataSequence{
|
|
{PADataType: patype.PA_ETYPE_INFO2, PADataValue: v2},
|
|
{PADataType: patype.PA_ETYPE_INFO, PADataValue: v1},
|
|
})
|
|
got, _ := extractEData(edata)
|
|
if len(got) != 1 {
|
|
t.Fatalf("got %d entries, want 1 (de-duplicated)", len(got))
|
|
}
|
|
if got[0].Salt != "fromInfo2" {
|
|
t.Errorf("salt = %q, want fromInfo2 (etype-info2 must take precedence)", got[0].Salt)
|
|
}
|
|
}
|
|
|
|
func TestExtractEData_BadASN1(t *testing.T) {
|
|
enctypes, pkinit := extractEData([]byte{0xff, 0x00})
|
|
if enctypes != nil || pkinit {
|
|
t.Errorf("expected nil/false on garbage, got %v / %v", enctypes, pkinit)
|
|
}
|
|
}
|
|
|
|
func TestEtypeName(t *testing.T) {
|
|
if got := etypeName(etypeID.AES256_CTS_HMAC_SHA1_96); !strings.Contains(strings.ToLower(got), "aes256") {
|
|
t.Errorf("AES256 name = %q, want it to mention aes256", got)
|
|
}
|
|
if got := etypeName(99999); got != "etype-99999" {
|
|
t.Errorf("unknown etype = %q, want etype-99999", got)
|
|
}
|
|
}
|
|
|
|
func TestErrorcodeNameAndKRBErrorInfo(t *testing.T) {
|
|
name := errorcodeName(errorcode.KDC_ERR_PREAUTH_REQUIRED)
|
|
if !strings.Contains(name, "PREAUTH") {
|
|
t.Errorf("errorcodeName = %q, want it to contain PREAUTH", name)
|
|
}
|
|
|
|
// Typed KRBError: errors.As path.
|
|
sname := types.NewPrincipalName(nametype.KRB_NT_SRV_INST, "krbtgt/EXAMPLE.COM")
|
|
krb := messages.NewKRBError(sname, "EXAMPLE.COM", errorcode.KDC_ERR_PREAUTH_REQUIRED, "")
|
|
code, n, ok := krbErrorInfo(krb)
|
|
if !ok || code != errorcode.KDC_ERR_PREAUTH_REQUIRED || !strings.Contains(n, "PREAUTH") {
|
|
t.Errorf("krbErrorInfo(typed) = %d %q %v", code, n, ok)
|
|
}
|
|
|
|
// String fallback: gokrb5 sometimes wraps the code only inside the message.
|
|
wrapped := errors.New("login failed: KRB Error: (24) KDC_ERR_PREAUTH_FAILED - bla")
|
|
code, n, ok = krbErrorInfo(wrapped)
|
|
if !ok || code != 24 {
|
|
t.Errorf("krbErrorInfo(string) code=%d ok=%v", code, ok)
|
|
}
|
|
if !strings.Contains(n, "PREAUTH_FAILED") {
|
|
t.Errorf("krbErrorInfo(string) name = %q", n)
|
|
}
|
|
|
|
if _, _, ok := krbErrorInfo(nil); ok {
|
|
t.Error("krbErrorInfo(nil) should return ok=false")
|
|
}
|
|
if _, _, ok := krbErrorInfo(errors.New("plain old error")); ok {
|
|
t.Error("krbErrorInfo on a non-KRB error should return ok=false")
|
|
}
|
|
}
|
|
|
|
func TestRoleForPrefix(t *testing.T) {
|
|
cases := map[string]string{
|
|
"_kerberos._tcp.": "kdc",
|
|
"_kerberos._udp.": "kdc",
|
|
"_kerberos-master._tcp.": "master",
|
|
"_kerberos-adm._tcp.": "kadmin",
|
|
"_kpasswd._tcp.": "kpasswd",
|
|
"_kpasswd._udp.": "kpasswd",
|
|
}
|
|
for in, want := range cases {
|
|
if got := roleForPrefix(in); got != want {
|
|
t.Errorf("roleForPrefix(%q) = %q, want %q", in, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestOptFloat(t *testing.T) {
|
|
cases := []struct {
|
|
in any
|
|
want float64
|
|
}{
|
|
{float64(2.5), 2.5},
|
|
{float32(1.5), 1.5},
|
|
{int(7), 7},
|
|
{int64(8), 8},
|
|
{"3.14", 3.14},
|
|
{"nope", 42}, // falls back to default
|
|
{nil, 42}, // missing key path is exercised below
|
|
}
|
|
for _, c := range cases {
|
|
opts := map[string]any{"k": c.in}
|
|
got := optFloat(opts, "k", 42)
|
|
if got != c.want {
|
|
t.Errorf("optFloat(%v) = %v, want %v", c.in, got, c.want)
|
|
}
|
|
}
|
|
if got := optFloat(map[string]any{}, "missing", 99); got != 99 {
|
|
t.Errorf("optFloat(missing) = %v, want 99", got)
|
|
}
|
|
}
|
|
|
|
func TestOptBool(t *testing.T) {
|
|
cases := []struct {
|
|
in any
|
|
def bool
|
|
want bool
|
|
}{
|
|
{true, false, true},
|
|
{false, true, false},
|
|
{"true", false, true},
|
|
{"1", false, true},
|
|
{"false", true, false}, // unrecognized string -> default
|
|
{nil, true, true},
|
|
{42, false, false}, // unsupported type -> default
|
|
}
|
|
for _, c := range cases {
|
|
opts := map[string]any{}
|
|
if c.in != nil {
|
|
opts["k"] = c.in
|
|
}
|
|
got := optBool(opts, "k", c.def)
|
|
if got != c.want {
|
|
t.Errorf("optBool(%v, def=%v) = %v, want %v", c.in, c.def, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSmallHelpers(t *testing.T) {
|
|
if got := abs(-3 * time.Second); got != 3*time.Second {
|
|
t.Errorf("abs negative = %v", got)
|
|
}
|
|
if got := abs(2 * time.Second); got != 2*time.Second {
|
|
t.Errorf("abs positive = %v", got)
|
|
}
|
|
if got := firstNonEmpty("", "", "x", "y"); got != "x" {
|
|
t.Errorf("firstNonEmpty = %q", got)
|
|
}
|
|
if got := firstNonEmpty("", ""); got != "" {
|
|
t.Errorf("firstNonEmpty(all empty) = %q", got)
|
|
}
|
|
if got := first([]byte{1, 2, 3}, 16); len(got) != 3 {
|
|
t.Errorf("first(short) len = %d, want 3", len(got))
|
|
}
|
|
if got := first([]byte{1, 2, 3, 4, 5}, 2); len(got) != 2 || got[0] != 1 || got[1] != 2 {
|
|
t.Errorf("first(long) = %v", got)
|
|
}
|
|
list := []EnctypeEntry{{ID: 18}, {ID: 17}}
|
|
if !hasEnctype(list, 17) {
|
|
t.Error("hasEnctype should find 17")
|
|
}
|
|
if hasEnctype(list, 23) {
|
|
t.Error("hasEnctype should not find 23")
|
|
}
|
|
}
|