package main import ( "context" "log" "net/http" "strconv" "sync" "time" "github.com/julienschmidt/httprouter" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" ) var ( WSClients = map[int64][]WSClient{} WSClientsMutex = sync.RWMutex{} WSAdmin = []WSClient{} WSAdminMutex = sync.RWMutex{} ) func init() { router.GET("/api/surveys/:sid/ws", rawAuthHandler(SurveyWS, loggedUser)) router.GET("/api/surveys/:sid/ws-admin", rawAuthHandler(SurveyWSAdmin, adminRestricted)) router.GET("/api/surveys/:sid/ws/stats", apiHandler(surveyHandler(func(s Survey, body []byte) HTTPResponse { return APIResponse{ WSSurveyStats(s.Id), } }), adminRestricted)) } func WSSurveyStats(sid int64) map[string]interface{} { var users []string var nb int WSClientsMutex.RLock() defer WSClientsMutex.RUnlock() if w, ok := WSClients[sid]; ok { nb = len(w) for _, ws := range w { users = append(users, ws.u.Login) } } return map[string]interface{}{ "nb_clients": nb, "users": users, } } type WSClient struct { ws *websocket.Conn c chan WSMessage u *User sid int64 } func SurveyWS_run(ws *websocket.Conn, c chan WSMessage, sid int64, u *User) { for { msg, ok := <-c if !ok { break } ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() err := wsjson.Write(ctx, ws, msg) if err != nil { log.Println("Error on WebSocket:", err) ws.Close(websocket.StatusInternalError, "error on write") break } } ws.Close(websocket.StatusNormalClosure, "end") WSClientsMutex.Lock() defer WSClientsMutex.Unlock() for i, clt := range WSClients[sid] { if clt.ws == ws { WSClients[sid][i] = WSClients[sid][len(WSClients[sid])-1] WSClients[sid] = WSClients[sid][:len(WSClients[sid])-1] break } } log.Println(u.Login, "disconnected") } func msgCurrentState(survey *Survey) (msg WSMessage) { if *survey.Direct == 0 { msg = WSMessage{ Action: "pause", } } else { msg = WSMessage{ Action: "new_question", QuestionId: survey.Direct, } } return } func SurveyWS(w http.ResponseWriter, r *http.Request, ps httprouter.Params, u *User, body []byte) { if sid, err := strconv.Atoi(string(ps.ByName("sid"))); err != nil { http.Error(w, "{\"errmsg\": \"Invalid survey identifier\"}", http.StatusBadRequest) return } else if survey, err := getSurvey(sid); err != nil { http.Error(w, "{\"errmsg\": \"Survey not found\"}", http.StatusNotFound) return } else if survey.Direct == nil { http.Error(w, "{\"errmsg\": \"Not a direct survey\"}", http.StatusBadRequest) return } else { ws, err := websocket.Accept(w, r, nil) if err != nil { log.Fatal("error get connection", err) } log.Println(u.Login, "is now connected to WS", sid) c := make(chan WSMessage, 1) WSClientsMutex.Lock() defer WSClientsMutex.Unlock() WSClients[survey.Id] = append(WSClients[survey.Id], WSClient{ws, c, u, survey.Id}) // Send current state c <- msgCurrentState(&survey) go SurveyWS_run(ws, c, survey.Id, u) } } func WSWriteAll(message WSMessage) { WSClientsMutex.RLock() defer WSClientsMutex.RUnlock() for _, wss := range WSClients { for _, ws := range wss { ws.c <- message } } } type WSMessage struct { Action string `json:"action"` SurveyId *int64 `json:"survey,omitempty"` QuestionId *int64 `json:"question,omitempty"` Stats map[string]interface{} `json:"stats,omitempty"` UserId *int64 `json:"user,omitempty"` Response string `json:"value,omitempty"` } func (s *Survey) WSWriteAll(message WSMessage) { WSClientsMutex.RLock() defer WSClientsMutex.RUnlock() if wss, ok := WSClients[s.Id]; ok { for _, ws := range wss { ws.c <- message } } } func (s *Survey) WSCloseAll(message string) { WSClientsMutex.RLock() defer WSClientsMutex.RUnlock() if wss, ok := WSClients[s.Id]; ok { for _, ws := range wss { close(ws.c) } } } // Admin ////////////////////////////////////////////////////////////// func SurveyWSAdmin_run(ctx context.Context, ws *websocket.Conn, c chan WSMessage, sid int64, u *User) { ct := time.Tick(25000 * time.Millisecond) loopadmin: for { select { case <-ct: ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() err := wsjson.Write(ctx, ws, WSMessage{ Action: "stats", Stats: WSSurveyStats(sid), }) if err != nil { log.Println("Error on WebSocket:", err) ws.Close(websocket.StatusInternalError, "error on write") break } case msg, ok := <-c: if !ok { break loopadmin } ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() err := wsjson.Write(ctx, ws, msg) if err != nil { log.Println("Error on WebSocket:", err) ws.Close(websocket.StatusInternalError, "error on write") break } } } ws.Close(websocket.StatusNormalClosure, "end") WSAdminMutex.Lock() defer WSAdminMutex.Unlock() for i, clt := range WSAdmin { if clt.ws == ws { WSAdmin[i] = WSAdmin[len(WSAdmin)-1] WSAdmin = WSAdmin[:len(WSAdmin)-1] break } } log.Println(u.Login, "admin disconnected") } func SurveyWSAdmin(w http.ResponseWriter, r *http.Request, ps httprouter.Params, u *User, body []byte) { if sid, err := strconv.Atoi(string(ps.ByName("sid"))); err != nil { http.Error(w, "{\"errmsg\": \"Invalid survey identifier\"}", http.StatusBadRequest) return } else if survey, err := getSurvey(sid); err != nil { http.Error(w, "{\"errmsg\": \"Survey not found\"}", http.StatusNotFound) return } else if survey.Direct == nil { http.Error(w, "{\"errmsg\": \"Not a direct survey\"}", http.StatusBadRequest) return } else { ws, err := websocket.Accept(w, r, nil) if err != nil { log.Fatal("error get connection", err) } log.Println(u.Login, "is now connected to WS-admin", sid) c := make(chan WSMessage, 2) WSAdminMutex.Lock() defer WSAdminMutex.Unlock() WSAdmin = append(WSAdmin, WSClient{ws, c, u, survey.Id}) // Send current state c <- msgCurrentState(&survey) go SurveyWSAdmin_run(r.Context(), ws, c, survey.Id, u) go func(c chan WSMessage, sid int) { var v WSMessage var err error for { err = wsjson.Read(context.Background(), ws, &v) if err != nil { log.Println("Error when receiving message:", err) close(c) break } if v.Action == "new_question" && v.QuestionId != nil { if survey, err := getSurvey(sid); err != nil { log.Println("Unable to retrieve survey:", err) } else { survey.Direct = v.QuestionId _, err = survey.Update() if err != nil { log.Println("Unable to update survey:", err) } survey.WSWriteAll(v) v.SurveyId = &survey.Id WSAdminWriteAll(v) } } else if v.Action == "pause" { if survey, err := getSurvey(sid); err != nil { log.Println("Unable to retrieve survey:", err) } else { var u int64 = 0 survey.Direct = &u _, err = survey.Update() if err != nil { log.Println("Unable to update survey:", err) } survey.WSWriteAll(v) v.SurveyId = &survey.Id WSAdminWriteAll(v) } } else if v.Action == "end" { if survey, err := getSurvey(sid); err != nil { log.Println("Unable to retrieve survey:", err) } else { survey.Direct = nil _, err = survey.Update() if err != nil { log.Println("Unable to update survey:", err) } survey.WSCloseAll("Fin du live") v.SurveyId = &survey.Id WSAdminWriteAll(v) } } else if v.Action == "get_stats" { err = wsjson.Write(context.Background(), ws, WSMessage{Action: "stats", Stats: WSSurveyStats(int64(sid))}) } else if v.Action == "get_responses" { if survey, err := getSurvey(sid); err != nil { log.Println("Unable to retrieve survey:", err) } else if questions, err := survey.GetQuestions(); err != nil { log.Println("Unable to retrieve questions:", err) } else { for _, q := range questions { if responses, err := q.GetResponses(); err != nil { log.Println("Unable to retrieve questions:", err) } else { for _, r := range responses { wsjson.Write(context.Background(), ws, WSMessage{Action: "new_response", UserId: &r.IdUser, QuestionId: &q.Id, Response: r.Answer}) } } } } } else { log.Println("Unknown admin action:", v.Action) } } }(c, sid) } } func WSAdminWriteAll(message WSMessage) { WSAdminMutex.RLock() defer WSAdminMutex.RUnlock() for _, ws := range WSAdmin { ws.c <- message } } func (s *Survey) WSAdminWriteAll(message WSMessage) { WSAdminMutex.RLock() defer WSAdminMutex.RUnlock() for _, ws := range WSAdmin { log.Println("snd", message, ws.sid, s.Id) if ws.sid == s.Id { ws.c <- message } } }