middleware: skip JWT parsing for session ID tokens to suppress spurious log
When a Bearer token is a valid session ID (base32, 103 chars), the JWT middleware now silently hands off to the session store instead of logging a misleading "bad JWT claims" error. Also exports IsValidSessionID from the session store package and derives the session ID length from a constant tied to the key size in the usecase package, removing the hardcoded 103.
This commit is contained in:
parent
10775fe36c
commit
07d4c244d1
6 changed files with 245 additions and 7 deletions
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"git.happydns.org/happyDomain/internal/session"
|
||||
"git.happydns.org/happyDomain/model"
|
||||
)
|
||||
|
||||
|
|
@ -61,6 +62,12 @@ func JwtAuthMiddleware(authService happydns.AuthenticationUsecase, signingMethod
|
|||
return
|
||||
}
|
||||
|
||||
// Session IDs are handled by the session store; skip JWT parsing.
|
||||
if session.IsValidSessionID(token) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the token and retrieve claims
|
||||
claims := &UserClaims{}
|
||||
_, err := jwt.ParseWithClaims(token, claims,
|
||||
|
|
|
|||
130
internal/api/middleware/jwt_auth_middleware_test.go
Normal file
130
internal/api/middleware/jwt_auth_middleware_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
// This file is part of the happyDomain (R) project.
|
||||
// Copyright (c) 2020-2026 happyDomain
|
||||
// Authors: Pierre-Olivier Mercier, et al.
|
||||
//
|
||||
// This program is offered under a commercial and under the AGPL license.
|
||||
// For commercial licensing, contact us at <contact@happydomain.org>.
|
||||
//
|
||||
// For AGPL licensing:
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"git.happydns.org/happyDomain/internal/api/middleware"
|
||||
sessionUC "git.happydns.org/happyDomain/internal/usecase/session"
|
||||
"git.happydns.org/happyDomain/model"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// stubAuthUsecase is a no-op implementation of happydns.AuthenticationUsecase.
|
||||
// The middleware should never reach a method of this stub when the token is a
|
||||
// session ID, and should never reach it on a malformed JWT either (it returns
|
||||
// after logging). We still assert it was not called by leaving the methods
|
||||
// panicking — if any test trips one, we know the branching logic regressed.
|
||||
type stubAuthUsecase struct{}
|
||||
|
||||
func (stubAuthUsecase) AuthenticateUserWithPassword(_ happydns.LoginRequest) (*happydns.User, error) {
|
||||
panic("AuthenticateUserWithPassword should not be called in these tests")
|
||||
}
|
||||
|
||||
func (stubAuthUsecase) CompleteAuthentication(_ happydns.UserInfo) (*happydns.User, error) {
|
||||
panic("CompleteAuthentication should not be called in these tests")
|
||||
}
|
||||
|
||||
// captureLog redirects the default logger to a buffer for the duration of fn.
|
||||
func captureLog(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
prevOut := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetOutput(&buf)
|
||||
log.SetFlags(0)
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevOut)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
|
||||
fn()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func newRouter() *gin.Engine {
|
||||
r := gin.New()
|
||||
r.Use(middleware.JwtAuthMiddleware(stubAuthUsecase{}, "HS256", []byte("test-secret")))
|
||||
r.GET("/", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
return r
|
||||
}
|
||||
|
||||
func Test_JwtAuthMiddleware_SessionIDTokenIsSilent(t *testing.T) {
|
||||
r := newRouter()
|
||||
|
||||
output := captureLog(t, func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+sessionUC.NewSessionID())
|
||||
rr := httptest.NewRecorder()
|
||||
r.ServeHTTP(rr, req)
|
||||
})
|
||||
|
||||
if strings.Contains(output, "bad JWT claims") {
|
||||
t.Errorf("expected no %q log for a session-ID token, got:\n%s", "bad JWT claims", output)
|
||||
}
|
||||
if strings.TrimSpace(output) != "" {
|
||||
t.Errorf("expected no log output at all for a session-ID token, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_JwtAuthMiddleware_MalformedTokenStillLogs(t *testing.T) {
|
||||
r := newRouter()
|
||||
|
||||
output := captureLog(t, func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// Contains a dot, so it can't match the session-ID shape and will be
|
||||
// routed to the JWT parser, which will fail.
|
||||
req.Header.Set("Authorization", "Bearer not.a.jwt")
|
||||
rr := httptest.NewRecorder()
|
||||
r.ServeHTTP(rr, req)
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "bad JWT claims") {
|
||||
t.Errorf("expected %q log for a malformed JWT, got:\n%s", "bad JWT claims", output)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_JwtAuthMiddleware_NoAuthHeaderIsSilent(t *testing.T) {
|
||||
r := newRouter()
|
||||
|
||||
output := captureLog(t, func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
r.ServeHTTP(rr, req)
|
||||
})
|
||||
|
||||
if strings.TrimSpace(output) != "" {
|
||||
t.Errorf("expected no log output without an Authorization header, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
|
@ -75,11 +75,11 @@ func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, err
|
|||
|
||||
if _, ok := r.Header["Authorization"]; ok && len(r.Header["Authorization"]) > 0 {
|
||||
if flds := strings.Fields(r.Header["Authorization"][0]); len(flds) == 2 && flds[0] == "Bearer" {
|
||||
if isValidSessionID(flds[1]) {
|
||||
if IsValidSessionID(flds[1]) {
|
||||
session.ID = flds[1]
|
||||
}
|
||||
} else if user, _, ok := r.BasicAuth(); ok {
|
||||
if isValidSessionID(user) {
|
||||
if IsValidSessionID(user) {
|
||||
session.ID = user
|
||||
}
|
||||
}
|
||||
|
|
@ -242,10 +242,10 @@ func (s *SessionStore) save(session *sessions.Session, ua string) error {
|
|||
return s.storage.UpdateSession(mysession)
|
||||
}
|
||||
|
||||
// isValidSessionID returns true if s looks like a session ID generated by
|
||||
// NewSessionID: base32 standard alphabet ([A-Z2-7]), exactly 103 characters.
|
||||
func isValidSessionID(s string) bool {
|
||||
if len(s) != 103 {
|
||||
// IsValidSessionID returns true if s looks like a session ID generated by
|
||||
// NewSessionID: base32 standard alphabet ([A-Z2-7]), exactly SessionIDLen characters.
|
||||
func IsValidSessionID(s string) bool {
|
||||
if len(s) != sessionUC.SessionIDLen {
|
||||
return false
|
||||
}
|
||||
for _, c := range s {
|
||||
|
|
|
|||
70
internal/session/sessions_test.go
Normal file
70
internal/session/sessions_test.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
// This file is part of the happyDomain (R) project.
|
||||
// Copyright (c) 2020-2026 happyDomain
|
||||
// Authors: Pierre-Olivier Mercier, et al.
|
||||
//
|
||||
// This program is offered under a commercial and under the AGPL license.
|
||||
// For commercial licensing, contact us at <contact@happydomain.org>.
|
||||
//
|
||||
// For AGPL licensing:
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package session_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.happydns.org/happyDomain/internal/session"
|
||||
sessionUC "git.happydns.org/happyDomain/internal/usecase/session"
|
||||
)
|
||||
|
||||
func Test_IsValidSessionID_RoundTrip(t *testing.T) {
|
||||
// A freshly-generated session ID must always be considered valid.
|
||||
for range 32 {
|
||||
id := sessionUC.NewSessionID()
|
||||
if !session.IsValidSessionID(id) {
|
||||
t.Fatalf("NewSessionID() produced %q which IsValidSessionID rejected", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsValidSessionID_Rejects(t *testing.T) {
|
||||
valid := sessionUC.NewSessionID()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"one char short", valid[:len(valid)-1]},
|
||||
{"one char long", valid + "A"},
|
||||
{"all lowercase", strings.ToLower(valid)},
|
||||
{"with base32 padding", strings.Repeat("A", sessionUC.SessionIDLen-1) + "="},
|
||||
{"digit 0 (not in base32 alphabet)", strings.Repeat("A", sessionUC.SessionIDLen-1) + "0"},
|
||||
{"digit 1 (not in base32 alphabet)", strings.Repeat("A", sessionUC.SessionIDLen-1) + "1"},
|
||||
{"digit 8 (not in base32 alphabet)", strings.Repeat("A", sessionUC.SessionIDLen-1) + "8"},
|
||||
{"digit 9 (not in base32 alphabet)", strings.Repeat("A", sessionUC.SessionIDLen-1) + "9"},
|
||||
{"embedded space", strings.Repeat("A", sessionUC.SessionIDLen-1) + " "},
|
||||
{"non-ASCII", strings.Repeat("A", sessionUC.SessionIDLen-1) + "é"},
|
||||
{"looks like a JWT", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0In0.sig"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if session.IsValidSessionID(tc.in) {
|
||||
t.Errorf("IsValidSessionID(%q) = true, want false", tc.in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -196,8 +196,14 @@ func (s *Service) ByID(userID happydns.Identifier) error {
|
|||
return s.CloseUserSessions(&happydns.User{Id: userID})
|
||||
}
|
||||
|
||||
// sessionIDKeyLen is the number of random bytes used to generate a session ID.
|
||||
const sessionIDKeyLen = 64
|
||||
|
||||
// SessionIDLen is the length of a session ID string (base32, no padding).
|
||||
const SessionIDLen = (sessionIDKeyLen*8 + 4) / 5
|
||||
|
||||
// NewSessionID generates a random session identifier encoded
|
||||
// as a base32 string without padding characters.
|
||||
func NewSessionID() string {
|
||||
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(securecookie.GenerateRandomKey(64))
|
||||
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(securecookie.GenerateRandomKey(sessionIDKeyLen))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -533,6 +533,31 @@ func Test_NewSessionID(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test_SessionIDLen_matches_NewSessionID guards the invariant that
|
||||
// session.SessionIDLen is the actual length produced by NewSessionID. If the
|
||||
// encoding or key size ever changes without updating the constant, the
|
||||
// middleware's IsValidSessionID check would silently reject freshly generated
|
||||
// IDs and break Bearer-token auth.
|
||||
func Test_SessionIDLen_matches_NewSessionID(t *testing.T) {
|
||||
if got := len(session.NewSessionID()); got != session.SessionIDLen {
|
||||
t.Fatalf("SessionIDLen = %d, but NewSessionID() produced %d chars", session.SessionIDLen, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Test_SessionIDLen_MatchesNewSessionID guards against the derived
|
||||
// SessionIDLen constant drifting from the actual output of NewSessionID.
|
||||
// It runs several iterations because base32's length for a given input is
|
||||
// fully deterministic, but a future change to sessionIDKeyLen could leave
|
||||
// the formula off-by-one if the ceiling adjustment is wrong.
|
||||
func Test_SessionIDLen_MatchesNewSessionID(t *testing.T) {
|
||||
for range 32 {
|
||||
id := session.NewSessionID()
|
||||
if len(id) != session.SessionIDLen {
|
||||
t.Fatalf("NewSessionID() produced %d-char string, SessionIDLen is %d", len(id), session.SessionIDLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SessionExpiration(t *testing.T) {
|
||||
db, _ := inmemory.Instantiate()
|
||||
sessionService := session.NewService(db)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue