diff --git a/auth.go b/auth.go index 8c8d1b7..8e6adb5 100644 --- a/auth.go +++ b/auth.go @@ -77,31 +77,51 @@ func logout(c *gin.Context) { c.JSON(http.StatusOK, true) } -func completeAuth(c *gin.Context, username string, email string, firstname string, lastname string, groups string, session *Session) (usr *User, err error) { +func completeAuth(c *gin.Context, username string, email string, firstname string, lastname string, promo uint, groups string, face_url string, session *Session) (usr *User, err error) { if !userExists(username) { - if usr, err = NewUser(username, email, firstname, lastname, groups); err != nil { + if promo == 0 { + promo = currentPromo + } + if usr, err = NewUser(username, email, firstname, lastname, promo, groups); err != nil { return } } else if usr, err = getUserByLogin(username); err != nil { return } + upd_user := false + + // Update user's promo if it has changed + if promo != 0 && promo != usr.Promo { + usr.Promo = promo + upd_user = true + } + + // Update user's group if they have been modified if len(groups) > 0 { if len(groups) > 255 { groups = groups[:255] } if usr.Groups != groups { usr.Groups = groups - usr.Update() + upd_user = true } } + if upd_user { + usr.Update() + } + if session == nil { session, err = usr.NewSession() - } else { - _, err = session.SetUser(usr) + if err != nil { + return + } } - + if face_url != "" { + session.SetKey("picture", face_url) + } + _, err = session.SetUser(usr) if err != nil { return } @@ -137,7 +157,7 @@ func dummyAuth(c *gin.Context) { return } - if usr, err := completeAuth(c, lf["username"], lf["email"], lf["firstname"], lf["lastname"], "", nil); err != nil { + if usr, err := completeAuth(c, lf["username"], lf["email"], lf["firstname"], lf["lastname"], currentPromo, "", "", nil); err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"errmsg": err.Error()}) return } else { diff --git a/auth_krb5.go b/auth_krb5.go index f78e6d9..55bdc68 100644 --- a/auth_krb5.go +++ b/auth_krb5.go @@ -83,7 +83,7 @@ func checkAuthKrb5(c *gin.Context) { return } - if usr, err := completeAuth(c, lf.Login, lf.Login+"@epita.fr", "", "", "", nil); err != nil { + if usr, err := completeAuth(c, lf.Login, lf.Login+"@epita.fr", "", "", currentPromo, "", "", nil); err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"errmsg": err.Error()}) return } else { diff --git a/auth_oidc.go b/auth_oidc.go index 973b878..553baae 100644 --- a/auth_oidc.go +++ b/auth_oidc.go @@ -48,7 +48,7 @@ func initializeOIDC(router *gin.Engine) { Endpoint: provider.Endpoint(), // "openid" is a required scope for OpenID Connect flows. - Scopes: []string{oidc.ScopeOpenID, "profile", "email", "epita"}, + Scopes: []string{oidc.ScopeOpenID, "profile", "email", "epita", "picture"}, } oidcConfig := oidc.Config{ @@ -105,15 +105,20 @@ func OIDC_CRI_complete(c *gin.Context) { } var claims struct { - Firstname string `json:"given_name"` - Lastname string `json:"family_name"` - Nickname string `json:"nickname"` - Username string `json:"preferred_username"` - Email string `json:"email"` - Groups []map[string]interface{} `json:"groups"` + Firstname string `json:"given_name"` + Lastname string `json:"family_name"` + Username string `json:"preferred_username"` + Email string `json:"email"` + Groups []map[string]interface{} `json:"groups"` + Campuses []string `json:"campuses"` + GraduationYears []uint `json:"graduation_years"` + Picture string `json:"picture"` + PictureSquare string `json:"picture_square"` + PictureThumb string `json:"picture_thumb"` } if err := idToken.Claims(&claims); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"errmsg": err.Error()}) + log.Println("Unable to extract claims to Claims:", err.Error()) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"errmsg": "Something goes wrong when analyzing your claims. Contact administrator to fix the issue."}) return } @@ -124,7 +129,16 @@ func OIDC_CRI_complete(c *gin.Context) { } } - if _, err := completeAuth(c, claims.Username, claims.Email, claims.Firstname, claims.Lastname, groups, session); err != nil { + var promo uint + if len(claims.GraduationYears) > 0 { + for _, gy := range claims.GraduationYears { + if gy > promo { + promo = gy + } + } + } + + if _, err := completeAuth(c, claims.Username, claims.Email, claims.Firstname, claims.Lastname, promo, groups, claims.PictureSquare, session); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"errmsg": err.Error()}) return } diff --git a/users.go b/users.go index b999398..5f31017 100644 --- a/users.go +++ b/users.go @@ -207,14 +207,14 @@ func userExists(login string) bool { return err == nil && z == 1 } -func NewUser(login string, email string, firstname string, lastname string, groups string) (*User, error) { +func NewUser(login string, email string, firstname string, lastname string, promo uint, groups string) (*User, error) { t := time.Now() - if res, err := DBExec("INSERT INTO users (login, email, firstname, lastname, time, promo, groups) VALUES (?, ?, ?, ?, ?, ?, ?)", login, email, firstname, lastname, t, currentPromo, groups); err != nil { + if res, err := DBExec("INSERT INTO users (login, email, firstname, lastname, time, promo, groups) VALUES (?, ?, ?, ?, ?, ?, ?)", login, email, firstname, lastname, t, promo, groups); err != nil { return nil, err } else if sid, err := res.LastInsertId(); err != nil { return nil, err } else { - return &User{sid, login, email, firstname, lastname, t, currentPromo, groups, false}, nil + return &User{sid, login, email, firstname, lastname, t, promo, groups, false}, nil } }