From a48bc1f1bc88813847ec6cb3398253055459d9ac Mon Sep 17 00:00:00 2001 From: Pierre-Olivier Mercier Date: Fri, 11 Nov 2022 11:14:43 +0100 Subject: [PATCH] OIDC: Retrieve promotion from OIDC claims --- auth.go | 24 ++++++++++++++++++++---- auth_krb5.go | 2 +- auth_oidc.go | 27 +++++++++++++++++++-------- users.go | 6 +++--- 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/auth.go b/auth.go index 8c8d1b7..a05ea0d 100644 --- a/auth.go +++ b/auth.go @@ -77,25 +77,41 @@ func logout(c *gin.Context) { c.JSON(http.StatusOK, true) } -func completeAuth(c *gin.Context, username string, email string, firstname string, lastname string, groups string, session *Session) (usr *User, err error) { +func completeAuth(c *gin.Context, username string, email string, firstname string, lastname string, promo uint, groups string, session *Session) (usr *User, err error) { if !userExists(username) { - if usr, err = NewUser(username, email, firstname, lastname, groups); err != nil { + if promo == 0 { + promo = currentPromo + } + if usr, err = NewUser(username, email, firstname, lastname, promo, groups); err != nil { return } } else if usr, err = getUserByLogin(username); err != nil { return } + upd_user := false + + // Update user's promo if it has changed + if promo != 0 && promo != usr.Promo { + usr.Promo = promo + upd_user = true + } + + // Update user's group if they have been modified if len(groups) > 0 { if len(groups) > 255 { groups = groups[:255] } if usr.Groups != groups { usr.Groups = groups - usr.Update() + upd_user = true } } + if upd_user { + usr.Update() + } + if session == nil { session, err = usr.NewSession() } else { @@ -137,7 +153,7 @@ func dummyAuth(c *gin.Context) { return } - if usr, err := completeAuth(c, lf["username"], lf["email"], lf["firstname"], lf["lastname"], "", nil); err != nil { + if usr, err := completeAuth(c, lf["username"], lf["email"], lf["firstname"], lf["lastname"], currentPromo, "", nil); err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"errmsg": err.Error()}) return } else { diff --git a/auth_krb5.go b/auth_krb5.go index f78e6d9..dda6374 100644 --- a/auth_krb5.go +++ b/auth_krb5.go @@ -83,7 +83,7 @@ func checkAuthKrb5(c *gin.Context) { return } - if usr, err := completeAuth(c, lf.Login, lf.Login+"@epita.fr", "", "", "", nil); err != nil { + if usr, err := completeAuth(c, lf.Login, lf.Login+"@epita.fr", "", "", currentPromo, "", nil); err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"errmsg": err.Error()}) return } else { diff --git a/auth_oidc.go b/auth_oidc.go index 973b878..6e56990 100644 --- a/auth_oidc.go +++ b/auth_oidc.go @@ -105,15 +105,17 @@ func OIDC_CRI_complete(c *gin.Context) { } var claims struct { - Firstname string `json:"given_name"` - Lastname string `json:"family_name"` - Nickname string `json:"nickname"` - Username string `json:"preferred_username"` - Email string `json:"email"` - Groups []map[string]interface{} `json:"groups"` + Firstname string `json:"given_name"` + Lastname string `json:"family_name"` + Username string `json:"preferred_username"` + Email string `json:"email"` + Groups []map[string]interface{} `json:"groups"` + Campuses []string `json:"campuses"` + GraduationYears []uint `json:"graduation_years"` } if err := idToken.Claims(&claims); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"errmsg": err.Error()}) + log.Println("Unable to extract claims to Claims:", err.Error()) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"errmsg": "Something goes wrong when analyzing your claims. Contact administrator to fix the issue."}) return } @@ -124,7 +126,16 @@ func OIDC_CRI_complete(c *gin.Context) { } } - if _, err := completeAuth(c, claims.Username, claims.Email, claims.Firstname, claims.Lastname, groups, session); err != nil { + var promo uint + if len(claims.GraduationYears) > 0 { + for _, gy := range claims.GraduationYears { + if gy > promo { + promo = gy + } + } + } + + if _, err := completeAuth(c, claims.Username, claims.Email, claims.Firstname, claims.Lastname, promo, groups, session); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"errmsg": err.Error()}) return } diff --git a/users.go b/users.go index b999398..5f31017 100644 --- a/users.go +++ b/users.go @@ -207,14 +207,14 @@ func userExists(login string) bool { return err == nil && z == 1 } -func NewUser(login string, email string, firstname string, lastname string, groups string) (*User, error) { +func NewUser(login string, email string, firstname string, lastname string, promo uint, groups string) (*User, error) { t := time.Now() - if res, err := DBExec("INSERT INTO users (login, email, firstname, lastname, time, promo, groups) VALUES (?, ?, ?, ?, ?, ?, ?)", login, email, firstname, lastname, t, currentPromo, groups); err != nil { + if res, err := DBExec("INSERT INTO users (login, email, firstname, lastname, time, promo, groups) VALUES (?, ?, ?, ?, ?, ?, ?)", login, email, firstname, lastname, t, promo, groups); err != nil { return nil, err } else if sid, err := res.LastInsertId(); err != nil { return nil, err } else { - return &User{sid, login, email, firstname, lastname, t, currentPromo, groups, false}, nil + return &User{sid, login, email, firstname, lastname, t, promo, groups, false}, nil } }