diff --git a/token-validator/auth.go b/token-validator/auth.go index 9b10e99..460b67d 100644 --- a/token-validator/auth.go +++ b/token-validator/auth.go @@ -3,27 +3,18 @@ package main import ( "encoding/json" "errors" - "fmt" "net/http" "github.com/julienschmidt/httprouter" ) func init() { - router.GET("/auth", authHandler(apiHandler(validateAuthToken, printStudent))) - router.POST("/auth", apiHandler(checkAuth)) + router.GET("/api/auth", apiAuthHandler(validateAuthToken)) + router.POST("/api/auth", apiHandler(checkAuth)) } -func printStudent(std *Student, r *http.Request) error { - if std != nil { - return errors.New(fmt.Sprintf("%s", *std)) - } else { - return nil - } -} - -func validateAuthToken(_ httprouter.Params, _ []byte) (interface{}, error) { - return false, nil +func validateAuthToken(s Student, _ httprouter.Params, _ []byte) (interface{}, error) { + return s, nil } type loginForm struct { @@ -37,7 +28,7 @@ func checkAuth(_ httprouter.Params, body []byte) (interface{}, error) { return nil, err } - if r, err := http.NewRequest("GET", "https://owncloud.srs.epita.fr/remote.php/webdav/", nil); err != nil { + if r, err := http.NewRequest("GET", "https://fic.srs.epita.fr/2020/", nil); err != nil { return nil, err } else { r.SetBasicAuth(lf.Username, lf.Password) diff --git a/token-validator/db.go b/token-validator/db.go index 58376a8..7a2d883 100644 --- a/token-validator/db.go +++ b/token-validator/db.go @@ -112,6 +112,18 @@ CREATE TABLE IF NOT EXISTS student_sessions( time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY(id_student) REFERENCES students(id_student) ) DEFAULT CHARACTER SET = utf8 COLLATE = utf8_bin; +`); err != nil { + return err + } + if _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS student_tunnel_tokens( + token BLOB(255) NOT NULL, + token_text CHAR(10) NOT NULL, + id_student INTEGER NOT NULL, + time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + pubkey BLOB(50) DEFAULT NULL, + FOREIGN KEY(id_student) REFERENCES students(id_student) +) DEFAULT CHARACTER SET = utf8 COLLATE = utf8_bin; `); err != nil { return err } diff --git a/token-validator/handler.go b/token-validator/handler.go index b85285f..eb7e659 100644 --- a/token-validator/handler.go +++ b/token-validator/handler.go @@ -5,6 +5,7 @@ import ( "crypto/sha512" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "log" @@ -156,6 +157,22 @@ func apiHandler(f DispatchFunction, access ...func(*Student, *http.Request) erro return rawHandler(func (_ *http.Request, ps httprouter.Params, b []byte) (interface{}, error) { return f(ps, b) }, access...) } +func apiAuthHandler(f func(Student, httprouter.Params, []byte) (interface{}, error), access ...func(*Student, *http.Request) error) func(http.ResponseWriter, *http.Request, httprouter.Params) { + return rawHandler(func (r *http.Request, ps httprouter.Params, b []byte) (interface{}, error) { + if flds := strings.Fields(r.Header.Get("Authorization")); len(flds) != 2 || flds[0] != "Bearer" { + return nil, errors.New("Authorization required") + } else if sessionid, err := base64.StdEncoding.DecodeString(flds[1]); err != nil { + return nil, err + } else if session, err := getSession(sessionid); err != nil { + return nil, err + } else if std, err := getStudent(int(session.IdStudent)); err != nil { + return nil, err + } else { + return f(std, ps, b) + } + }, access...) +} + func studentHandler(f func(Student, []byte) (interface{}, error)) func(httprouter.Params, []byte) (interface{}, error) { return func(ps httprouter.Params, body []byte) (interface{}, error) { if sid, err := strconv.Atoi(string(ps.ByName("sid"))); err != nil { diff --git a/token-validator/wg.go b/token-validator/wg.go new file mode 100644 index 0000000..f9a296d --- /dev/null +++ b/token-validator/wg.go @@ -0,0 +1,207 @@ +package main + +import ( + "crypto/rand" + "crypto/sha512" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/julienschmidt/httprouter" +) + +func init() { + router.GET("/api/wg.conf", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + w.Header().Set("Content-Type", "text/plain") + + err := GenWGConfig(w) + + if err != nil { + w.Write([]byte(err.Error())) + } + }) + router.GET("/api/wg/", apiAuthHandler(showWgTunnel)) + router.GET("/api/wginfo", apiAuthHandler(func (student Student, ps httprouter.Params, body []byte) (interface{}, error) { + return getTunnelInfo(student), nil + })) + router.POST("/api/wg/", apiAuthHandler(genWgToken)) + router.GET("/api/wg/:token", apiAuthHandler(getWgTunnelInfo)) +} + +func showWgTunnel(student Student, ps httprouter.Params, body []byte) (interface{}, error) { + // Get tunnels assigned to the student + return student.GetTunnelTokens() +} + +func genWgToken(student Student, ps httprouter.Params, body []byte) (interface{}, error) { + // Generate a token to access related wg info + return student.NewTunnelToken() +} + +type TunnelInfo struct { + Status string `json:"status"` + SrvPubKey []byte `json:"srv_pubkey"` + SrvPort uint16 `json:"srv_port"` + CltIPv6 string `json:"clt_ipv6"` + CltRange uint8 `json:"clt_range"` + SrvGW6 string `json:"srv_gw6"` +} + +func getTunnelInfo(student Student) TunnelInfo { + return TunnelInfo{ + Status: "OK", + SrvPubKey: []byte{'T', 'B', 'D'}, + SrvPort: 42912, + CltIPv6: studentIP(student.Id), + CltRange: 80, + SrvGW6: "2a01:e0a:2b:2252::1", + } +} + +type PubTunnel struct { + PubKey []byte +} + +func getWgTunnelInfo(student Student, ps httprouter.Params, body []byte) (interface{}, error) { + // Access wg infos + tokenhex := []byte(ps.ByName("token")) + tokendec := make([]byte, hex.DecodedLen(len(tokenhex))) + n, err := hex.Decode(tokendec, tokenhex) + if err != nil { + return nil, err + } + + token, err := student.GetTunnelToken(tokendec[:n]) + if err != nil { + return nil, err + } + + var pt PubTunnel + if err := json.Unmarshal(body, &pt); err != nil { + return nil, err + } + + token.PubKey = pt.PubKey + _, err = token.Update() + if err != nil { + return nil, err + } + + return getTunnelInfo(student), nil +} + + +type TunnelToken struct { + token []byte + TokenText string + IdStudent int64 + PubKey []byte + Time time.Time +} + +func getTunnelToken(token []byte) (t TunnelToken, err error) { + err = DBQueryRow("SELECT token, token_text, id_student, pubkey, time FROM student_tunnel_tokens WHERE token=? ORDER BY time DESC", token).Scan(&t.token, &t.TokenText, &t.IdStudent, &t.PubKey, &t.Time) + return +} + +func tokenFromText(token string) []byte { + sha := sha512.Sum512_256([]byte(token)) + return sha[:] +} + +func (student Student) NewTunnelToken() (t TunnelToken, err error) { + tok := make([]byte, 7) + if _, err = rand.Read(tok); err != nil { + return + } + + t.TokenText = base64.RawStdEncoding.EncodeToString(tok) + t.token = tokenFromText(t.TokenText) + t.IdStudent = student.Id + + _, err = DBExec("INSERT INTO student_tunnel_tokens (token, token_text, id_student, time) VALUES (?, ?, ?, ?)", t.token, t.TokenText, student.Id, time.Now()) + return +} + +func (student Student) GetTunnelTokens() (ts []TunnelToken, err error) { + if rows, errr := DBQuery("SELECT token, token_text, id_student, pubkey, time FROM student_tunnel_tokens WHERE id_student = ? ORDER BY time DESC", student.Id); errr != nil { + return nil, errr + } else { + defer rows.Close() + + for rows.Next() { + var t TunnelToken + if err = rows.Scan(&t.token, &t.TokenText, &t.IdStudent, &t.PubKey, &t.Time); err != nil { + return + } + ts = append(ts, t) + } + if err = rows.Err(); err != nil { + return + } + + return + } +} + +func (student Student) GetTunnelToken(token []byte) (t TunnelToken, err error) { + err = DBQueryRow("SELECT token, token_text, id_student, pubkey, time FROM student_tunnel_tokens WHERE token = ? AND id_student = ? ORDER BY time DESC", token, student.Id).Scan(&t.token, &t.TokenText, &t.IdStudent, &t.PubKey, &t.Time) + return +} + +func (t *TunnelToken) Update() (int64, error) { + newtoken := tokenFromText(t.TokenText) + tm := time.Now() + + if res, err := DBExec("UPDATE student_tunnel_tokens SET token = ?, token_text = ?, id_student = ?, pubkey = ?, time = ? WHERE token = ?", newtoken, t.TokenText, t.IdStudent, t.PubKey, tm, t.token); err != nil { + return 0, err + } else if nb, err := res.RowsAffected(); err != nil { + return 0, err + } else { + t.token = newtoken + t.Time = tm + return nb, err + } +} + +func GetStudentsTunnels() (ts []TunnelToken, err error) { + if rows, errr := DBQuery("SELECT token, token_text, id_student, pubkey, time FROM student_tunnel_tokens T INNER JOIN (SELECT B.token, B.id_student, MAX(B.time) FROM student_tunnel_tokens B GROUP BY id_student) L ON T.token = L.token"); errr != nil { + return nil, errr + } else { + defer rows.Close() + + for rows.Next() { + var t TunnelToken + if err = rows.Scan(&t.token, &t.TokenText, &t.IdStudent, &t.PubKey, &t.Time); err != nil { + return + } + ts = append(ts, t) + } + + err = rows.Err() + return + } +} + +func studentIP(idstd int64) string { + return fmt.Sprintf("2a01:e0a:2b:2252:%x::", idstd) +} + +func GenWGConfig(w io.Writer) (error) { + ts, err := GetStudentsTunnels() + if err != nil { + return err + } + + for _, t := range ts { + w.Write([]byte(fmt.Sprintf(`[Peer] +PublicKey = %s +AllowedIPs = %s/%d`, base64.StdEncoding.EncodeToString(t.PubKey), studentIP(t.IdStudent), 80))) + } + + return nil +}