diff --git a/main.go b/main.go index e875efc..1b2804f 100644 --- a/main.go +++ b/main.go @@ -59,6 +59,7 @@ func StripPrefix(prefix string, h http.Handler) http.Handler { func main() { var bind = flag.String("bind", ":8081", "Bind port/socket") var dsn = flag.String("dsn", DSNGenerator(), "DSN to connect to the MySQL server") + flag.StringVar(&DevProxy, "dev", DevProxy, "Proxify traffic to this host for static assets") flag.StringVar(&baseURL, "baseurl", baseURL, "URL prepended to each URL") flag.UintVar(¤tPromo, "current-promo", currentPromo, "Year of the current promotion") flag.Var(&localAuthUsers, "local-auth-user", "Allow local authentication for this user (bypass OIDC).") @@ -76,6 +77,14 @@ func main() { baseURL = "" } + if DevProxy != "" { + Router().GET("/.svelte-kit/*_", serveOrReverse("")) + Router().GET("/node_modules/*_", serveOrReverse("")) + Router().GET("/@vite/*_", serveOrReverse("")) + Router().GET("/__vite_ping", serveOrReverse("")) + Router().GET("/src/*_", serveOrReverse("")) + } + initializeOIDC() // Initialize contents diff --git a/static.go b/static.go index 0f24406..2ea309e 100644 --- a/static.go +++ b/static.go @@ -1,17 +1,49 @@ package main import ( + "io" "net/http" + "net/url" + "path" "github.com/julienschmidt/httprouter" ) +var DevProxy string + func serveOrReverse(forced_url string) func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - if forced_url != "" { - r.URL.Path = forced_url + if DevProxy != "" { + if u, err := url.Parse(DevProxy); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else { + if forced_url != "" { + u.Path = path.Join(u.Path, forced_url) + } else { + u.Path = path.Join(u.Path, r.URL.Path) + } + + if r, err := http.NewRequest(r.Method, u.String(), r.Body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else if resp, err := http.DefaultClient.Do(r); err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + } else { + defer resp.Body.Close() + + for key := range resp.Header { + w.Header().Add(key, resp.Header.Get(key)) + } + w.WriteHeader(resp.StatusCode) + + io.Copy(w, resp.Body) + } + } + } else { + if forced_url != "" { + r.URL.Path = forced_url + } + http.FileServer(Assets).ServeHTTP(w, r) } - http.FileServer(Assets).ServeHTTP(w, r) } }