add basic session management
This commit is contained in:
parent
8ac28310f9
commit
74dfd0a42a
|
@ -1,13 +1,18 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"git.nemunai.re/libredns/struct"
|
||||
)
|
||||
|
||||
type Response interface {
|
||||
|
@ -70,3 +75,42 @@ func apiHandler(f func(httprouter.Params, io.Reader) (Response)) func(http.Respo
|
|||
f(ps, r.Body).WriteResponse(w)
|
||||
}
|
||||
}
|
||||
|
||||
func apiAuthHandler(f func(libredns.User, httprouter.Params, io.Reader) (Response)) func(http.ResponseWriter, *http.Request, httprouter.Params) {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
if addr := r.Header.Get("X-Forwarded-For"); addr != "" {
|
||||
r.RemoteAddr = addr
|
||||
}
|
||||
log.Printf("%s \"%s %s\" [%s]\n", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent())
|
||||
|
||||
// Read the body
|
||||
if r.ContentLength < 0 || r.ContentLength > 6553600 {
|
||||
http.Error(w, fmt.Sprintf("{errmsg:\"Request too large or request size unknown\"}"), http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
if flds := strings.Fields(r.Header.Get("Authorization")); len(flds) != 2 || flds[0] != "Bearer" {
|
||||
APIErrorResponse{
|
||||
err: errors.New("Authorization required"),
|
||||
status: http.StatusUnauthorized,
|
||||
}.WriteResponse(w)
|
||||
} else if sessionid, err := base64.StdEncoding.DecodeString(flds[1]); err != nil {
|
||||
APIErrorResponse{
|
||||
err: err,
|
||||
status: http.StatusUnauthorized,
|
||||
}.WriteResponse(w)
|
||||
} else if session, err := libredns.GetSession(sessionid); err != nil {
|
||||
APIErrorResponse{
|
||||
err: err,
|
||||
status: http.StatusUnauthorized,
|
||||
}.WriteResponse(w)
|
||||
} else if std, err := libredns.GetUser(int(session.IdUser)); err != nil {
|
||||
APIErrorResponse{
|
||||
err: err,
|
||||
status: http.StatusUnauthorized,
|
||||
}.WriteResponse(w)
|
||||
} else {
|
||||
f(std, ps, r.Body).WriteResponse(w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"git.nemunai.re/libredns/struct"
|
||||
)
|
||||
|
||||
var AuthFunc = checkAuth
|
||||
|
||||
func init() {
|
||||
router.GET("/api/users/auth", apiAuthHandler(validateAuthToken))
|
||||
router.POST("/api/users/auth", apiHandler(func(ps httprouter.Params, b io.Reader) (Response) {
|
||||
return AuthFunc(ps, b)
|
||||
}))
|
||||
}
|
||||
|
||||
func validateAuthToken(u libredns.User, _ httprouter.Params, _ io.Reader) (Response) {
|
||||
return APIResponse{
|
||||
response: u,
|
||||
}
|
||||
}
|
||||
|
||||
type loginForm struct {
|
||||
Email string
|
||||
Password string
|
||||
}
|
||||
|
||||
func dummyAuth(_ httprouter.Params, body io.Reader) Response {
|
||||
var lf loginForm
|
||||
if err := json.NewDecoder(body).Decode(&lf); err != nil {
|
||||
return APIErrorResponse{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if user, err := libredns.GetUserByEmail(lf.Email); err != nil {
|
||||
return APIErrorResponse{
|
||||
err: err,
|
||||
}
|
||||
} else {
|
||||
session, err := user.NewSession()
|
||||
if err != nil {
|
||||
return APIErrorResponse{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
res := map[string]interface{}{}
|
||||
res["status"] = "OK"
|
||||
res["id_session"] = session.Id
|
||||
|
||||
return APIResponse{
|
||||
response: res,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkAuth(_ httprouter.Params, body io.Reader) Response {
|
||||
var lf loginForm
|
||||
if err := json.NewDecoder(body).Decode(&lf); err != nil {
|
||||
return APIErrorResponse{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if user, err := libredns.GetUserByEmail(lf.Email); err != nil {
|
||||
return APIErrorResponse{
|
||||
err: err,
|
||||
}
|
||||
} else if !user.CheckAuth(lf.Password) {
|
||||
return APIErrorResponse{
|
||||
err: errors.New(`{"status": "Invalid username or password"}`),
|
||||
status: http.StatusUnauthorized,
|
||||
}
|
||||
} else {
|
||||
session, err := user.NewSession()
|
||||
if err != nil {
|
||||
return APIErrorResponse{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
res := map[string]interface{}{}
|
||||
res["status"] = "OK"
|
||||
res["id_session"] = session.Id
|
||||
|
||||
return APIResponse{
|
||||
response: res,
|
||||
}
|
||||
}
|
||||
}
|
10
struct/db.go
10
struct/db.go
|
@ -64,6 +64,16 @@ CREATE TABLE IF NOT EXISTS users(
|
|||
salt BINARY(64) NOT NULL,
|
||||
registration_time TIMESTAMP NOT NULL
|
||||
) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_sessions(
|
||||
id_session BLOB(255) NOT NULL,
|
||||
id_user INTEGER NOT NULL,
|
||||
time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(id_user) REFERENCES users(id_user)
|
||||
);
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
package libredns
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Id []byte `json:"id"`
|
||||
IdUser int64 `json:"login"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
func GetSession(id []byte) (s Session, err error) {
|
||||
err = DBQueryRow("SELECT id_session, id_user, time FROM user_sessions WHERE id_session=?", id).Scan(&s.Id, &s.IdUser, &s.Time)
|
||||
return
|
||||
}
|
||||
|
||||
func (user User) NewSession() (Session, error) {
|
||||
session_id := make([]byte, 255)
|
||||
if _, err := rand.Read(session_id); err != nil {
|
||||
return Session{}, err
|
||||
} else if _, err := DBExec("INSERT INTO user_sessions (id_session, id_user, time) VALUES (?, ?, ?)", session_id, user.Id, time.Now()); err != nil {
|
||||
return Session{}, err
|
||||
} else {
|
||||
return Session{session_id, user.Id, time.Now()}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s Session) Update() (int64, error) {
|
||||
if res, err := DBExec("UPDATE user_sessions SET id_user = ?, time = ? WHERE id_session = ?", s.IdUser, s.Time, s.Id); err != nil {
|
||||
return 0, err
|
||||
} else if nb, err := res.RowsAffected(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return nb, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s Session) Delete() (int64, error) {
|
||||
if res, err := DBExec("DELETE FROM user_sessions WHERE id_session = ?", s.Id); err != nil {
|
||||
return 0, err
|
||||
} else if nb, err := res.RowsAffected(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return nb, err
|
||||
}
|
||||
}
|
||||
|
||||
func ClearSession() (int64, error) {
|
||||
if res, err := DBExec("DELETE FROM user_sessions"); err != nil {
|
||||
return 0, err
|
||||
} else if nb, err := res.RowsAffected(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return nb, err
|
||||
}
|
||||
}
|
|
@ -37,12 +37,12 @@ func GetUsers() (users []User, err error) {
|
|||
}
|
||||
|
||||
func GetUser(id int) (u User, err error) {
|
||||
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time WHERE id_user=?", id).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
|
||||
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time FROM users WHERE id_user=?", id).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserByEmail(email string) (u User, err error) {
|
||||
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time WHERE email=?", email).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
|
||||
err = DBQueryRow("SELECT id_user, email, password, salt, registration_time FROM users WHERE email=?", email).Scan(&u.Id, &u.Email, &u.password, &u.salt, &u.RegistrationTime)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,21 @@ func NewUser(email string, password string) (User, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (u User) Update() (int64, error) {
|
||||
func (u *User) CheckAuth(password string) bool {
|
||||
pass := GenPassword(password, u.salt)
|
||||
if len(pass) != len(u.password) {
|
||||
return false
|
||||
} else {
|
||||
for k := range pass {
|
||||
if pass[k] != u.password[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) Update() (int64, error) {
|
||||
if res, err := DBExec("UPDATE users SET email = ?, password = ?, salt = ?, registration_time = ? WHERE id_user = ?", u.Email, u.password, u.salt, u.RegistrationTime, u.Id); err != nil {
|
||||
return 0, err
|
||||
} else if nb, err := res.RowsAffected(); err != nil {
|
||||
|
@ -82,7 +96,7 @@ func (u User) Update() (int64, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (u User) Delete() (int64, error) {
|
||||
func (u *User) Delete() (int64, error) {
|
||||
if res, err := DBExec("DELETE FROM users WHERE id_user = ?", u.Id); err != nil {
|
||||
return 0, err
|
||||
} else if nb, err := res.RowsAffected(); err != nil {
|
||||
|
|
Loading…
Reference in New Issue