From 16d2285ff31c0ee6809441ff58e91c19b11c20a8 Mon Sep 17 00:00:00 2001 From: nemunaire Date: Fri, 22 Jun 2018 00:29:12 +0200 Subject: [PATCH] Initial commit --- .gitignore | 1 + api/handlers.go | 82 ++++++++++++++++++++++++++++++++++ api/router.go | 11 +++++ api/version.go | 13 ++++++ main.go | 115 ++++++++++++++++++++++++++++++++++++++++++++++++ struct/db.go | 96 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 318 insertions(+) create mode 100644 .gitignore create mode 100644 api/handlers.go create mode 100644 api/router.go create mode 100644 api/version.go create mode 100644 main.go create mode 100644 struct/db.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8980d88 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +checkhome diff --git a/api/handlers.go b/api/handlers.go new file mode 100644 index 0000000..6fa5550 --- /dev/null +++ b/api/handlers.go @@ -0,0 +1,82 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + + "github.com/julienschmidt/httprouter" +) + +type DispatchFunction func(httprouter.Params, []byte) (interface{}, error) + +func apiHandler(f DispatchFunction) func(http.ResponseWriter, *http.Request, httprouter.Params) { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + if addr := r.Header.Get("X-Forwarded-For"); addr != "" { + r.RemoteAddr = addr + } + log.Printf("%s \"%s %s\" [%s]\n", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent()) + + // Read the body + if r.ContentLength < 0 || r.ContentLength > 6553600 { + http.Error(w, fmt.Sprintf("{errmsg:\"Request too large or request size unknown\"}"), http.StatusRequestEntityTooLarge) + return + } + var body []byte + if r.ContentLength > 0 { + tmp := make([]byte, 1024) + for { + n, err := r.Body.Read(tmp) + for j := 0; j < n; j++ { + body = append(body, tmp[j]) + } + if err != nil || n <= 0 { + break + } + } + } + + var ret interface{} + var err error = nil + + ret, err = f(ps, body) + + // Format response + resStatus := http.StatusOK + if err != nil { + ret = map[string]string{"errmsg": err.Error()} + resStatus = http.StatusBadRequest + log.Println(r.RemoteAddr, resStatus, err.Error()) + } + + if ret == nil { + ret = map[string]string{"errmsg": "Page not found"} + resStatus = http.StatusNotFound + } + + if str, found := ret.(string); found { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resStatus) + io.WriteString(w, str) + } else if bts, found := ret.([]byte); found { + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", "attachment") + w.Header().Set("Content-Transfer-Encoding", "binary") + w.WriteHeader(resStatus) + w.Write(bts) + } else if j, err := json.Marshal(ret); err != nil { + w.Header().Set("Content-Type", "application/json") + http.Error(w, fmt.Sprintf("{\"errmsg\":%q}", err), http.StatusInternalServerError) + } else { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resStatus) + w.Write(j) + } + } +} + +func notFound(ps httprouter.Params, _ []byte) (interface{}, error) { + return nil, nil +} diff --git a/api/router.go b/api/router.go new file mode 100644 index 0000000..a6bd873 --- /dev/null +++ b/api/router.go @@ -0,0 +1,11 @@ +package api + +import ( + "github.com/julienschmidt/httprouter" +) + +var router = httprouter.New() + +func Router() *httprouter.Router { + return router +} diff --git a/api/version.go b/api/version.go new file mode 100644 index 0000000..1386a96 --- /dev/null +++ b/api/version.go @@ -0,0 +1,13 @@ +package api + +import ( + "github.com/julienschmidt/httprouter" +) + +func init() { + router.GET("/api/version", apiHandler(showVersion)) +} + +func showVersion(_ httprouter.Params, body []byte) (interface{}, error) { + return map[string]interface{}{"version": 0.1}, nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..4766062 --- /dev/null +++ b/main.go @@ -0,0 +1,115 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "net/url" + "os" + "os/signal" + "path" + "path/filepath" + "strings" + "syscall" + + "git.nemunai.re/checkhome/api" + "git.nemunai.re/checkhome/struct" +) + +var StaticDir string + +type ResponseWriterPrefix struct { + real http.ResponseWriter + prefix string +} + +func (r ResponseWriterPrefix) Header() http.Header { + return r.real.Header() +} + +func (r ResponseWriterPrefix) WriteHeader(s int) { + if v, exists := r.real.Header()["Location"]; exists { + r.real.Header().Set("Location", r.prefix + v[0]) + } + r.real.WriteHeader(s) +} + +func (r ResponseWriterPrefix) Write(z []byte) (int, error) { + return r.real.Write(z) +} + +func StripPrefix(prefix string, h http.Handler) http.Handler { + if prefix == "" { + return h + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if prefix != "/" && r.URL.Path == "/" { + http.Redirect(w, r, prefix + "/", http.StatusFound) + } else if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + r2 := new(http.Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = p + h.ServeHTTP(ResponseWriterPrefix{w, prefix}, r2) + } else { + h.ServeHTTP(w, r) + } + }) +} + +func main() { + var bind = flag.String("bind", "127.0.0.1:8080", "Bind port/socket") + var dsn = flag.String("dsn", ckh.DSNGenerator(), "DSN to connect to the MySQL server") + var baseURL = flag.String("baseurl", "/", "URL prepended to each URL") + flag.StringVar(&StaticDir, "static", StaticDir, "Directory containing static files") + flag.Parse() + + // Sanitize options + var err error + log.Println("Checking paths...") + if StaticDir, err = filepath.Abs(StaticDir); err != nil { + log.Fatal(err) + } + if *baseURL != "/" { + tmp := path.Clean(*baseURL) + baseURL = &tmp + } else { + tmp := "" + baseURL = &tmp + } + + log.Println("Opening database...") + if err := ckh.DBInit(*dsn); err != nil { + log.Fatal("Cannot open the database: ", err) + } + defer ckh.DBClose() + + log.Println("Creating database...") + if err := ckh.DBCreate(); err != nil { + log.Fatal("Cannot create database: ", err) + } + // Prepare graceful shutdown + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) + + srv := &http.Server{ + Addr: *bind, + Handler: StripPrefix(*baseURL, api.Router()), + } + + // Serve content + go func() { + log.Fatal(srv.ListenAndServe()) + }() + log.Println(fmt.Sprintf("Ready, listening on %s", *bind)) + + // Wait shutdown signal + <-interrupt + + log.Print("The service is shutting down...") + srv.Shutdown(context.Background()) + log.Println("done") +} diff --git a/struct/db.go b/struct/db.go new file mode 100644 index 0000000..acdc217 --- /dev/null +++ b/struct/db.go @@ -0,0 +1,96 @@ +package ckh + +import ( + "database/sql" + "log" + "os" + "strings" + "time" + _ "github.com/go-sql-driver/mysql" +) + +// db stores the connection to the database +var db *sql.DB + +// DSNGenerator returns DSN filed with values from environment +func DSNGenerator() string { + db_user := "checkhome" + db_password := "checkhome" + db_host := "" + db_db := "checkhome" + + if v, exists := os.LookupEnv("MYSQL_HOST"); exists { + db_host = v + } + if v, exists := os.LookupEnv("MYSQL_PASSWORD"); exists { + db_password = v + } else if v, exists := os.LookupEnv("MYSQL_ROOT_PASSWORD"); exists { + db_user = "root" + db_password = v + } + if v, exists := os.LookupEnv("MYSQL_USER"); exists { + db_user = v + } + if v, exists := os.LookupEnv("MYSQL_DATABASE"); exists { + db_db = v + } + + return db_user + ":" + db_password + "@" + db_host + "/" + db_db +} + +// DBInit establishes the connection to the database +func DBInit(dsn string) (err error) { + if db, err = sql.Open("mysql", dsn + "?parseTime=true&foreign_key_checks=1"); err != nil { + return + } + + _, err = db.Exec(`SET SESSION sql_mode = 'STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION';`) + for i := 0; err != nil && i < 15; i += 1 { + if _, err = db.Exec(`SET SESSION sql_mode = 'STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION';`); err != nil && i <= 5 { + log.Println("An error occurs when trying to connect to DB, will retry in 2 seconds: ", err) + time.Sleep(2 * time.Second) + } + } + + return +} + +// DBCreate creates all necessary tables used by the package +func DBCreate() (err error) { + ct := ` +CREATE TABLE IF NOT EXISTS rooms (id_room INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, label TEXT NOT NULL); +CREATE TABLE IF NOT EXISTS tags (id_tag INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, label TEXT NOT NULL); +CREATE TABLE IF NOT EXISTS items (id_item INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, label TEXT NOT NULL, description TEXT NOT NULL, id_room INTEGER); +CREATE TABLE IF NOT EXISTS item_tag (id_item INTEGER NOT NULL, id_tag INTEGER NOT NULL); +CREATE TABLE IF NOT EXISTS users (id_user INTEGER NOT NULL, username VARCHAR(255), password BINARY(64)); +` + for _, ln := range strings.Split(ct, "\n") { + if len(ln) == 0 { + continue + } + if _, err = db.Exec(ln); err != nil { + return + } + } + return +} + +func DBClose() error { + return db.Close() +} + +func DBPrepare(query string) (*sql.Stmt, error) { + return db.Prepare(query) +} + +func DBQuery(query string, args ...interface{}) (*sql.Rows, error) { + return db.Query(query, args...) +} + +func DBExec(query string, args ...interface{}) (sql.Result, error) { + return db.Exec(query, args...) +} + +func DBQueryRow(query string, args ...interface{}) *sql.Row { + return db.QueryRow(query, args...) +}