add basic session management

This commit is contained in:
nemunaire 2019-09-10 19:11:13 +02:00
parent 8ac28310f9
commit 74dfd0a42a
5 changed files with 227 additions and 4 deletions

View File

@ -1,13 +1,18 @@
package api
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"github.com/julienschmidt/httprouter"
"git.nemunai.re/libredns/struct"
)
type Response interface {
@ -70,3 +75,42 @@ func apiHandler(f func(httprouter.Params, io.Reader) (Response)) func(http.Respo
f(ps, r.Body).WriteResponse(w)
}
}
func apiAuthHandler(f func(libredns.User, httprouter.Params, io.Reader) (Response)) func(http.ResponseWriter, *http.Request, httprouter.Params) {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if addr := r.Header.Get("X-Forwarded-For"); addr != "" {
r.RemoteAddr = addr
}
log.Printf("%s \"%s %s\" [%s]\n", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent())
// Read the body
if r.ContentLength < 0 || r.ContentLength > 6553600 {
http.Error(w, fmt.Sprintf("{errmsg:\"Request too large or request size unknown\"}"), http.StatusRequestEntityTooLarge)
return
}
if flds := strings.Fields(r.Header.Get("Authorization")); len(flds) != 2 || flds[0] != "Bearer" {
APIErrorResponse{
err: errors.New("Authorization required"),
status: http.StatusUnauthorized,
}.WriteResponse(w)
} else if sessionid, err := base64.StdEncoding.DecodeString(flds[1]); err != nil {
APIErrorResponse{
err: err,
status: http.StatusUnauthorized,
}.WriteResponse(w)
} else if session, err := libredns.GetSession(sessionid); err != nil {
APIErrorResponse{
err: err,
status: http.StatusUnauthorized,
}.WriteResponse(w)
} else if std, err := libredns.GetUser(int(session.IdUser)); err != nil {
APIErrorResponse{
err: err,
status: http.StatusUnauthorized,
}.WriteResponse(w)
} else {
f(std, ps, r.Body).WriteResponse(w)
}
}
}

97
api/user_auth.go Normal file
View File

@ -0,0 +1,97 @@
package api
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/julienschmidt/httprouter"
"git.nemunai.re/libredns/struct"
)
var AuthFunc = checkAuth
func init() {
router.GET("/api/users/auth", apiAuthHandler(validateAuthToken))
router.POST("/api/users/auth", apiHandler(func(ps httprouter.Params, b io.Reader) (Response) {
return AuthFunc(ps, b)
}))
}
func validateAuthToken(u libredns.User, _ httprouter.Params, _ io.Reader) (Response) {
return APIResponse{
response: u,
}
}
type loginForm struct {
Email string
Password string
}
func dummyAuth(_ httprouter.Params, body io.Reader) Response {
var lf loginForm
if err := json.NewDecoder(body).Decode(&lf); err != nil {
return APIErrorResponse{
err: err,
}
}
if user, err := libredns.GetUserByEmail(lf.Email); err != nil {
return APIErrorResponse{
err: err,
}
} else {
session, err := user.NewSession()
if err != nil {
return APIErrorResponse{
err: err,
}
}
res := map[string]interface{}{}
res["status"] = "OK"
res["id_session"] = session.Id
return APIResponse{
response: res,
}
}
}
func checkAuth(_ httprouter.Params, body io.Reader) Response {
var lf loginForm
if err := json.NewDecoder(body).Decode(&lf); err != nil {
return APIErrorResponse{
err: err,
}
}
if user, err := libredns.GetUserByEmail(lf.Email); err != nil {
return APIErrorResponse{
err: err,
}
} else if !user.CheckAuth(lf.Password) {
return APIErrorResponse{
err: errors.New(`{"status": "Invalid username or password"}`),
status: http.StatusUnauthorized,
}
} else {
session, err := user.NewSession()
if err != nil {
return APIErrorResponse{
err: err,
}
}
res := map[string]interface{}{}
res["status"] = "OK"
res["id_session"] = session.Id
return APIResponse{
response: res,
}
}
}

View File

@ -64,6 +64,16 @@ CREATE TABLE IF NOT EXISTS users(
salt BINARY(64) NOT NULL,
registration_time TIMESTAMP NOT NULL
) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
`); err != nil {
return err
}
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS user_sessions(
id_session BLOB(255) NOT NULL,
id_user INTEGER NOT NULL,
time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(id_user) REFERENCES users(id_user)
);
`); err != nil {
return err
}

58
struct/session.go Normal file
View File

@ -0,0 +1,58 @@
package libredns
import (
"crypto/rand"
"time"
)
type Session struct {
Id []byte `json:"id"`
IdUser int64 `json:"login"`
Time time.Time `json:"time"`
}
func GetSession(id []byte) (s Session, err error) {
err = DBQueryRow("SELECT id_session, id_user, time FROM user_sessions WHERE id_session=?", id).Scan(&s.Id, &s.IdUser, &s.Time)
return
}
func (user User) NewSession() (Session, error) {
session_id := make([]byte, 255)
if _, err := rand.Read(session_id); err != nil {
return Session{}, err
} else if _, err := DBExec("INSERT INTO user_sessions (id_session, id_user, time) VALUES (?, ?, ?)", session_id, user.Id, time.Now()); err != nil {
return Session{}, err
} else {
return Session{session_id, user.Id, time.Now()}, nil
}
}
func (s Session) Update() (int64, error) {
if res, err := DBExec("UPDATE user_sessions SET id_user = ?, time = ? WHERE id_session = ?", s.IdUser, s.Time, s.Id); err != nil {
return 0, err
} else if nb, err := res.RowsAffected(); err != nil {
return 0, err
} else {
return nb, err
}
}
func (s Session) Delete() (int64, error) {
if res, err := DBExec("DELETE FROM user_sessions WHERE id_session = ?", s.Id); err != nil {
return 0, err
} else if nb, err := res.RowsAffected(); err != nil {
return 0, err
} else {
return nb, err
}
}
func ClearSession() (int64, error) {
if res, err := DBExec("DELETE FROM user_sessions"); err != nil {
return 0, err
} else if nb, err := res.RowsAffected(); err != nil {
return 0, err
} else {
return nb, err
}
}

View File

@ -37,12 +37,12 @@ func GetUsers() (users []User, err error) {
}
func GetUser(id int) (u User, err error) {
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time WHERE id_user=?", id).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time FROM users WHERE id_user=?", id).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
return
}
func GetUserByEmail(email string) (u User, err error) {
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time WHERE email=?", email).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time FROM users WHERE email=?", email).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
return
}
@ -72,7 +72,21 @@ func NewUser(email string, password string) (User, error) {
}
}
func (u User) Update() (int64, error) {
func (u *User) CheckAuth(password string) bool {
pass := GenPassword(password, u.salt)
if len(pass) != len(u.password) {
return false
} else {
for k := range pass {
if pass[k] != u.password[k] {
return false
}
}
return true
}
}
func (u *User) Update() (int64, error) {
if res, err := DBExec("UPDATE users SET email = ?, password = ?, salt = ?, registration_time = ? WHERE id_user = ?", u.Email, u.password, u.salt, u.RegistrationTime, u.Id); err != nil {
return 0, err
} else if nb, err := res.RowsAffected(); err != nil {
@ -82,7 +96,7 @@ func (u User) Update() (int64, error) {
}
}
func (u User) Delete() (int64, error) {
func (u *User) Delete() (int64, error) {
if res, err := DBExec("DELETE FROM users WHERE id_user = ?", u.Id); err != nil {
return 0, err
} else if nb, err := res.RowsAffected(); err != nil {