From 4baa665693f798e1c4a4a6a82cfb0f61326871fd Mon Sep 17 00:00:00 2001 From: Pierre-Olivier Mercier Date: Sun, 4 Sep 2022 11:12:37 +0200 Subject: [PATCH] sessions: Can store values --- api.go | 9 ++++++++- db.go | 1 + session.go | 54 +++++++++++++++++++++++++++++++++++++++++------------- 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/api.go b/api.go index 8cda6fc..a5c700e 100644 --- a/api.go +++ b/api.go @@ -69,12 +69,13 @@ func adminRestricted(u *User, c *gin.Context) bool { func authMiddleware(access ...func(*User, *gin.Context) bool) gin.HandlerFunc { return func(c *gin.Context) { var user *User = nil + var session *Session = nil if cookie, err := c.Request.Cookie("auth"); err == nil { if sessionid, err := base64.StdEncoding.DecodeString(cookie.Value); err != nil { eraseCookie(c) c.AbortWithStatusJSON(http.StatusNotAcceptable, gin.H{"errmsg": err.Error()}) return - } else if session, err := getSession(sessionid); err != nil { + } else if session, err = getSession(sessionid); err != nil { eraseCookie(c) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"errmsg": err.Error()}) return @@ -97,9 +98,15 @@ func authMiddleware(access ...func(*User, *gin.Context) bool) gin.HandlerFunc { } // Retrieve corresponding user + c.Set("Session", session) c.Set("LoggedUser", user) // We are now ready to continue c.Next() + + // On return, check if the session has changed + if session != nil && session.HasChanged() { + session.Update() + } } } diff --git a/db.go b/db.go index ca31938..1352077 100644 --- a/db.go +++ b/db.go @@ -72,6 +72,7 @@ CREATE TABLE IF NOT EXISTS user_sessions( id_session BLOB(255) NOT NULL, id_user INTEGER, time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + val TEXT NOT NULL DEFAULT '{}', FOREIGN KEY(id_user) REFERENCES users(id_user) ) DEFAULT CHARACTER SET = utf8 COLLATE = utf8_bin; `); err != nil { diff --git a/session.go b/session.go index 80bdd6d..a828e35 100644 --- a/session.go +++ b/session.go @@ -2,18 +2,24 @@ package main import ( "crypto/rand" + "encoding/json" "time" ) type Session struct { - Id []byte `json:"id"` - IdUser *int64 `json:"login"` - Time time.Time `json:"time"` + Id []byte `json:"id"` + IdUser *int64 `json:"login"` + Time time.Time `json:"time"` + changed bool + Values map[string]interface{} `json:"values"` } func getSession(id []byte) (s *Session, err error) { s = new(Session) - err = DBQueryRow("SELECT id_session, id_user, time FROM user_sessions WHERE id_session=?", id).Scan(&s.Id, &s.IdUser, &s.Time) + var val string + err = DBQueryRow("SELECT id_session, id_user, time, val FROM user_sessions WHERE id_session=?", id).Scan(&s.Id, &s.IdUser, &s.Time, &val) + + err = json.Unmarshal([]byte(val), &s.Values) return } @@ -21,32 +27,54 @@ func NewSession() (*Session, error) { session_id := make([]byte, 255) if _, err := rand.Read(session_id); err != nil { return nil, err - } else if _, err := DBExec("INSERT INTO user_sessions (id_session, time) VALUES (?, ?)", session_id, time.Now()); err != nil { + } else if _, err := DBExec("INSERT INTO user_sessions (id_session, time, val) VALUES (?, ?, '{}')", session_id, time.Now()); err != nil { return nil, err } else { - return &Session{session_id, nil, time.Now()}, nil + return &Session{session_id, nil, time.Now(), false, map[string]interface{}{}}, nil } } -func (user User) NewSession() (*Session, error) { +func (user *User) NewSession() (*Session, error) { session_id := make([]byte, 255) if _, err := rand.Read(session_id); err != nil { return nil, err - } else if _, err := DBExec("INSERT INTO user_sessions (id_session, id_user, time) VALUES (?, ?, ?)", session_id, user.Id, time.Now()); err != nil { + } else if _, err := DBExec("INSERT INTO user_sessions (id_session, id_user, time, val) VALUES (?, ?, ?, '{}')", session_id, user.Id, time.Now()); err != nil { return nil, err } else { - return &Session{session_id, &user.Id, time.Now()}, nil + return &Session{session_id, &user.Id, time.Now(), false, map[string]interface{}{}}, nil } } -func (s Session) SetUser(user *User) (*Session, error) { +func (s *Session) SetUser(user *User) (*Session, error) { s.IdUser = &user.Id _, err := s.Update() - return &s, err + s.changed = false + return s, err } -func (s Session) Update() (int64, error) { - if res, err := DBExec("UPDATE user_sessions SET id_user = ?, time = ? WHERE id_session = ?", s.IdUser, s.Time, s.Id); err != nil { +func (s *Session) GetKey(key string) (v interface{}, ok bool) { + v, ok = s.Values[key] + return +} + +func (s *Session) DeleteKey(key string) { + delete(s.Values, key) + s.changed = true +} + +func (s *Session) SetKey(key string, val interface{}) { + s.Values[key] = val + s.changed = true +} + +func (s *Session) HasChanged() bool { + return s.changed +} + +func (s *Session) Update() (int64, error) { + if val, err := json.Marshal(s.Values); err != nil { + return 0, err + } else if res, err := DBExec("UPDATE user_sessions SET id_user = ?, time = ?, val = ? WHERE id_session = ?", s.IdUser, s.Time, string(val), s.Id); err != nil { return 0, err } else if nb, err := res.RowsAffected(); err != nil { return 0, err