diff --git a/change.go b/change.go index 08dbd39..35d3877 100644 --- a/change.go +++ b/change.go @@ -16,34 +16,50 @@ func checkPasswdConstraint(password string) error { func changePassword(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { - displayTmpl(w, "change.html", map[string]interface{}{}) + csrfToken, err := setCSRFToken(w) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + displayTmpl(w, "change.html", map[string]interface{}{"csrf_token": csrfToken}) return } + if !validateCSRF(r) { + csrfToken, _ := setCSRFToken(w) + displayTmplError(w, http.StatusForbidden, "change.html", map[string]interface{}{"error": "Invalid or missing CSRF token. Please try again.", "csrf_token": csrfToken}) + return + } + + renderError := func(status int, msg string) { + csrfToken, _ := setCSRFToken(w) + displayTmplError(w, status, "change.html", map[string]interface{}{"error": msg, "csrf_token": csrfToken}) + } + // Check the two new passwords are identical if r.PostFormValue("newpassword") != r.PostFormValue("new2password") { - displayTmplError(w, http.StatusNotAcceptable, "change.html", map[string]interface{}{"error": "New passwords are not identical. Please retry."}) + renderError(http.StatusNotAcceptable, "New passwords are not identical. Please retry.") } else if len(r.PostFormValue("login")) == 0 { - displayTmplError(w, http.StatusNotAcceptable, "change.html", map[string]interface{}{"error": "Please provide a valid login"}) + renderError(http.StatusNotAcceptable, "Please provide a valid login") } else if err := checkPasswdConstraint(r.PostFormValue("newpassword")); err != nil { - displayTmplError(w, http.StatusNotAcceptable, "change.html", map[string]interface{}{"error": "The password you chose doesn't respect all constraints: " + err.Error()}) + renderError(http.StatusNotAcceptable, "The password you chose doesn't respect all constraints: "+err.Error()) } else { conn, err := myLDAP.Connect() if err != nil || conn == nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "change.html", map[string]interface{}{"error": err.Error()}) + renderError(http.StatusInternalServerError, "Unable to process your request. Please try again later.") } else if err := conn.ServiceBind(); err != nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "change.html", map[string]interface{}{"error": err.Error()}) + renderError(http.StatusInternalServerError, "Unable to process your request. Please try again later.") } else if dn, err := conn.SearchDN(r.PostFormValue("login"), true); err != nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "change.html", map[string]interface{}{"error": err.Error()}) + renderError(http.StatusUnauthorized, "Invalid login or password.") } else if err := conn.Bind(dn, r.PostFormValue("password")); err != nil { log.Println(err) - displayTmplError(w, http.StatusUnauthorized, "change.html", map[string]interface{}{"error": err.Error()}) + renderError(http.StatusUnauthorized, "Invalid login or password.") } else if err := conn.ChangePassword(dn, r.PostFormValue("newpassword")); err != nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "change.html", map[string]interface{}{"error": err.Error()}) + renderError(http.StatusInternalServerError, "Unable to process your request. Please try again later.") } else { displayMsg(w, "Password successfully changed!", http.StatusOK) } diff --git a/csrf.go b/csrf.go new file mode 100644 index 0000000..a94be83 --- /dev/null +++ b/csrf.go @@ -0,0 +1,39 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" + "net/http" +) + +func generateCSRFToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} + +func setCSRFToken(w http.ResponseWriter) (string, error) { + token, err := generateCSRFToken() + if err != nil { + return "", err + } + http.SetCookie(w, &http.Cookie{ + Name: "csrf_token", + Value: token, + Path: "/", + HttpOnly: false, // must be readable via form hidden field comparison + SameSite: http.SameSiteStrictMode, + }) + return token, nil +} + +func validateCSRF(r *http.Request) bool { + cookie, err := r.Cookie("csrf_token") + if err != nil || cookie.Value == "" { + return false + } + formToken := r.PostFormValue("csrf_token") + return formToken != "" && cookie.Value == formToken +} diff --git a/ldap.go b/ldap.go index e890b27..85271fe 100644 --- a/ldap.go +++ b/ldap.go @@ -23,6 +23,7 @@ type LDAP struct { MailPort int MailUser string MailPassword string + MailFrom string } func (l LDAP) Connect() (*LDAPConn, error) { diff --git a/lost.go b/lost.go index 0ddd36f..6481d62 100644 --- a/lost.go +++ b/lost.go @@ -1,54 +1,64 @@ package main import ( - "crypto/sha512" + "crypto/rand" "encoding/base64" - "encoding/binary" "io" "log" "net/http" "os" "os/exec" + "sync" "time" "gopkg.in/gomail.v2" ) -func (l LDAPConn) genToken(dn string, previous bool) string { - hour := time.Now() - // Generate the previous token? - if previous { - hour.Add(time.Hour * -1) +type resetTokenEntry struct { + dn string + expiresAt time.Time +} + +var resetTokenStore = struct { + mu sync.Mutex + tokens map[string]resetTokenEntry +}{tokens: make(map[string]resetTokenEntry)} + +func generateResetToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err } + return base64.URLEncoding.EncodeToString(b), nil +} - b := make([]byte, binary.MaxVarintLen64) - binary.PutVarint(b, hour.Round(time.Hour).Unix()) +func storeResetToken(token string, dn string) { + resetTokenStore.mu.Lock() + defer resetTokenStore.mu.Unlock() - // Search the email address and current password - entries, err := l.GetEntry(dn) - if err != nil { - log.Println("Unable to generate token:", err) - return "#err" - } - - email := "" - curpasswd := "" - for _, e := range entries { - if e.Name == "mail" { - email += e.Values[0] - } else if e.Name == "userPassword" { - curpasswd += e.Values[0] + // Clean expired tokens + now := time.Now() + for t, e := range resetTokenStore.tokens { + if now.After(e.expiresAt) { + delete(resetTokenStore.tokens, t) } } + resetTokenStore.tokens[token] = resetTokenEntry{ + dn: dn, + expiresAt: now.Add(time.Hour), + } +} - // Hash that - hash := sha512.New() - hash.Write(b) - hash.Write([]byte(dn)) - hash.Write([]byte(email)) - hash.Write([]byte(curpasswd)) - - return base64.StdEncoding.EncodeToString(hash.Sum(nil)[:]) +func consumeResetToken(token string) (string, bool) { + resetTokenStore.mu.Lock() + defer resetTokenStore.mu.Unlock() + entry, ok := resetTokenStore.tokens[token] + if !ok || time.Now().After(entry.expiresAt) { + delete(resetTokenStore.tokens, token) + return "", false + } + delete(resetTokenStore.tokens, token) + return entry.dn, true } func lostPasswordToken(conn *LDAPConn, login string) (string, string, error) { @@ -64,15 +74,31 @@ func lostPasswordToken(conn *LDAPConn, login string) (string, string, error) { return "", "", err } - // Generate the token - token := conn.genToken(dn, false) + // Generate a cryptographically random token + token, err := generateResetToken() + if err != nil { + return "", "", err + } + + // Store token server-side with expiration + storeResetToken(token, dn) return token, dn, nil } func lostPassword(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { - displayTmpl(w, "lost.html", map[string]interface{}{}) + csrfToken, err := setCSRFToken(w) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + displayTmpl(w, "lost.html", map[string]interface{}{"csrf_token": csrfToken}) + return + } + + if !validateCSRF(r) { + displayTmplError(w, http.StatusForbidden, "lost.html", map[string]interface{}{"error": "Invalid or missing CSRF token. Please try again."}) return } @@ -80,7 +106,7 @@ func lostPassword(w http.ResponseWriter, r *http.Request) { conn, err := myLDAP.Connect() if err != nil || conn == nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": err.Error()}) + displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": "Unable to process your request. Please try again later."}) return } @@ -88,7 +114,8 @@ func lostPassword(w http.ResponseWriter, r *http.Request) { token, dn, err := lostPasswordToken(conn, r.PostFormValue("login")) if err != nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": err.Error()}) + // Return generic message to avoid user enumeration + displayMsg(w, "If an account with that login exists, a password recovery email has been sent.", http.StatusOK) return } @@ -96,7 +123,7 @@ func lostPassword(w http.ResponseWriter, r *http.Request) { entries, err := conn.GetEntry(dn) if err != nil { log.Println(err) - displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": err.Error()}) + displayMsg(w, "If an account with that login exists, a password recovery email has been sent.", http.StatusOK) return } @@ -113,16 +140,16 @@ func lostPassword(w http.ResponseWriter, r *http.Request) { if email == "" { log.Println("Unable to find a valid adress for user " + dn) - displayTmplError(w, http.StatusBadRequest, "lost.html", map[string]interface{}{"error": "We were unable to find a valid email address associated with your account. Please contact an administrator."}) + displayMsg(w, "If an account with that login exists, a password recovery email has been sent.", http.StatusOK) return } // Send the email m := gomail.NewMessage() - m.SetHeader("From", "noreply@nemunai.re") + m.SetHeader("From", myLDAP.MailFrom) m.SetHeader("To", email) m.SetHeader("Subject", "SSO nemunai.re: password recovery") - m.SetBody("text/plain", "Hello "+cn+"!\n\nSomeone, and we hope it's you, requested to reset your account password. \nIn order to continue, go to:\n"+BASEURL+"/reset?l="+r.PostFormValue("login")+"&t="+token+"\n\nBest regards,\n-- \nnemunai.re SSO") + m.SetBody("text/plain", "Hello "+cn+"!\n\nSomeone, and we hope it's you, requested to reset your account password. \nIn order to continue, go to:\n"+myPublicURL+"/reset?l="+r.PostFormValue("login")+"&t="+token+"\n\nThis link expires in 1 hour and can only be used once.\n\nBest regards,\n-- \nnemunai.re SSO") var s gomail.Sender if myLDAP.MailHost != "" { @@ -130,7 +157,7 @@ func lostPassword(w http.ResponseWriter, r *http.Request) { s, err = d.Dial() if err != nil { log.Println("Unable to connect to email server: " + err.Error()) - displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": "Unable to connect to email server: " + err.Error()}) + displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": "Unable to send password recovery email. Please try again later."}) return } } else { @@ -165,7 +192,7 @@ func lostPassword(w http.ResponseWriter, r *http.Request) { if err := gomail.Send(s, m); err != nil { log.Println("Unable to send email: " + err.Error()) - displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": "Unable to send email: " + err.Error()}) + displayTmplError(w, http.StatusInternalServerError, "lost.html", map[string]interface{}{"error": "Unable to send password recovery email. Please try again later."}) return } diff --git a/main.go b/main.go index 9ab3e58..67e4234 100644 --- a/main.go +++ b/main.go @@ -17,13 +17,14 @@ import ( "syscall" ) -const BASEURL = "https://ldap.nemunai.re" +var myPublicURL = "https://ldap.nemunai.re" var myLDAP = LDAP{ Host: "localhost", Port: 389, BaseDN: "dc=example,dc=com", MailPort: 587, + MailFrom: "noreply@nemunai.re", } type ResponseWriterPrefix struct { @@ -70,8 +71,11 @@ func main() { var bind = flag.String("bind", "127.0.0.1:8080", "Bind port/socket") var baseURL = flag.String("baseurl", "/", "URL prepended to each URL") var configfile = flag.String("config", "", "path to the configuration file") + var publicURL = flag.String("public-url", myPublicURL, "Public base URL used in password reset emails") flag.Parse() + myPublicURL = *publicURL + // Sanitize options log.Println("Checking paths...") if *baseURL != "/" { @@ -141,9 +145,25 @@ func main() { if val, ok := os.LookupEnv("SMTP_USER"); ok { myLDAP.MailUser = val } - if val, ok := os.LookupEnv("SMTP_PASSWORD"); ok { + if val, ok := os.LookupEnv("SMTP_PASSWORD_FILE"); ok { + if fd, err := os.Open(val); err != nil { + log.Fatal(err) + } else if cnt, err := os.ReadFile(val); err != nil { + fd.Close() + log.Fatal(err) + } else { + fd.Close() + myLDAP.MailPassword = string(cnt) + } + } else if val, ok := os.LookupEnv("SMTP_PASSWORD"); ok { myLDAP.MailPassword = val } + if val, ok := os.LookupEnv("SMTP_FROM"); ok { + myLDAP.MailFrom = val + } + if val, ok := os.LookupEnv("PUBLIC_URL"); ok { + myPublicURL = val + } if flag.NArg() > 0 { switch flag.Arg(0) { @@ -164,7 +184,7 @@ func main() { log.Fatal(err.Error()) } - fmt.Printf("Reset link for %s: %s/reset?l=%s&t=%s", dn, BASEURL, login, token) + fmt.Printf("Reset link for %s: %s/reset?l=%s&t=%s", dn, myPublicURL, login, token) return case "serve": case "server": diff --git a/reset.go b/reset.go index f644507..225d22d 100644 --- a/reset.go +++ b/reset.go @@ -3,7 +3,6 @@ package main import ( "log" "net/http" - "strings" ) func resetPassword(w http.ResponseWriter, r *http.Request) { @@ -14,22 +13,46 @@ func resetPassword(w http.ResponseWriter, r *http.Request) { base := map[string]interface{}{ "login": r.URL.Query().Get("l"), - "token": strings.Replace(r.URL.Query().Get("t"), " ", "+", -1), + "token": r.URL.Query().Get("t"), } if r.Method != "POST" { + csrfToken, err := setCSRFToken(w) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + base["csrf_token"] = csrfToken displayTmpl(w, "reset.html", base) return } + renderError := func(status int, msg string) { + csrfToken, _ := setCSRFToken(w) + base["error"] = msg + base["csrf_token"] = csrfToken + displayTmplError(w, status, "reset.html", base) + } + + if !validateCSRF(r) { + renderError(http.StatusForbidden, "Invalid or missing CSRF token. Please try again.") + return + } + // Check the two new passwords are identical if r.PostFormValue("newpassword") != r.PostFormValue("new2password") { - base["error"] = "New passwords are not identical. Please retry." - displayTmplError(w, http.StatusNotAcceptable, "reset.html", base) + renderError(http.StatusNotAcceptable, "New passwords are not identical. Please retry.") return } else if err := checkPasswdConstraint(r.PostFormValue("newpassword")); err != nil { - base["error"] = "The password you chose doesn't respect all constraints: " + err.Error() - displayTmplError(w, http.StatusNotAcceptable, "reset.html", base) + renderError(http.StatusNotAcceptable, "The password you chose doesn't respect all constraints: "+err.Error()) + return + } + + // Validate and consume the token (single-use, server-side) + token := r.PostFormValue("token") + dn, ok := consumeResetToken(token) + if !ok { + renderError(http.StatusNotAcceptable, "Token invalid or expired, please retry the lost password procedure. Tokens expire after 1 hour.") return } @@ -37,41 +60,22 @@ func resetPassword(w http.ResponseWriter, r *http.Request) { conn, err := myLDAP.Connect() if err != nil || conn == nil { log.Println(err) - base["error"] = err.Error() - displayTmplError(w, http.StatusInternalServerError, "reset.html", base) + renderError(http.StatusInternalServerError, "Unable to process your request. Please try again later.") return } - // Bind as service to perform the search + // Bind as service to perform the password change err = conn.ServiceBind() if err != nil { log.Println(err) - base["error"] = err.Error() - displayTmplError(w, http.StatusInternalServerError, "reset.html", base) - return - } - - // Search the dn of the given user - dn, err := conn.SearchDN(r.PostFormValue("login"), true) - if err != nil { - log.Println(err) - base["error"] = err.Error() - displayTmplError(w, http.StatusInternalServerError, "reset.html", base) - return - } - - // Check token validity (allow current token + last one) - if conn.genToken(dn, false) != r.PostFormValue("token") && conn.genToken(dn, true) != r.PostFormValue("token") { - base["error"] = "Token invalid, please retry the lost password procedure. Please note that our token expires after 1 hour." - displayTmplError(w, http.StatusNotAcceptable, "reset.html", base) + renderError(http.StatusInternalServerError, "Unable to process your request. Please try again later.") return } // Replace the password by the new given if err := conn.ChangePassword(dn, r.PostFormValue("newpassword")); err != nil { log.Println(err) - base["error"] = err.Error() - displayTmplError(w, http.StatusInternalServerError, "reset.html", base) + renderError(http.StatusInternalServerError, "Unable to process your request. Please try again later.") return } diff --git a/static/change.html b/static/change.html index 019e7f5..60ffe7e 100644 --- a/static/change.html +++ b/static/change.html @@ -3,6 +3,7 @@