diff --git a/backend/controllers/websocket.go b/backend/controllers/websocket.go index 5d245f5d..67d31b87 100644 --- a/backend/controllers/websocket.go +++ b/backend/controllers/websocket.go @@ -8,9 +8,19 @@ import ( "ccsync_backend/utils" + "github.com/gorilla/sessions" "github.com/gorilla/websocket" ) +// getEnv returns the environment mode, defaulting to "development" +func getEnv() string { + env := os.Getenv("ENV") + if env == "" { + return "development" + } + return env +} + type JobStatus struct { Job string `json:"job"` Status string `json:"status"` @@ -21,7 +31,7 @@ func checkWebSocketOrigin(r *http.Request) bool { origin := r.Header.Get("Origin") // In development mode, be more permissive - if os.Getenv("ENV") != "production" { + if getEnv() != "production" { if origin == "" || strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1") { @@ -73,6 +83,45 @@ var upgrader = websocket.Upgrader{ var clients = make(map[*websocket.Conn]bool) var broadcast = make(chan JobStatus) +// AuthenticatedWebSocketHandler creates a WebSocket handler that requires authentication +func AuthenticatedWebSocketHandler(store *sessions.CookieStore) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Validate session before upgrading to WebSocket + session, err := store.Get(r, "session-name") + if err != nil { + utils.Logger.Warnf("WebSocket auth failed: could not get session: %v", err) + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + + userInfo, ok := session.Values["user"].(map[string]interface{}) + if !ok || userInfo == nil { + utils.Logger.Warnf("WebSocket auth failed: no user in session") + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + + // User is authenticated, proceed with WebSocket upgrade + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + utils.Logger.Error("WebSocket Upgrade Error:", err) + return + } + defer ws.Close() + + clients[ws] = true + for { + _, _, err := ws.ReadMessage() + if err != nil { + delete(clients, ws) + break + } + } + } +} + +// WebSocketHandler is kept for backward compatibility but should not be used +// Use AuthenticatedWebSocketHandler instead func WebSocketHandler(w http.ResponseWriter, r *http.Request) { ws, err := upgrader.Upgrade(w, r, nil) if err != nil { diff --git a/backend/main.go b/backend/main.go index 20acf19e..e76e0c1f 100644 --- a/backend/main.go +++ b/backend/main.go @@ -127,7 +127,7 @@ func main() { mux.HandleFunc("/health", controllers.HealthCheckHandler) - mux.HandleFunc("/ws", controllers.WebSocketHandler) + mux.HandleFunc("/ws", controllers.AuthenticatedWebSocketHandler(store)) // API documentation endpoint mux.HandleFunc("/api/docs/", httpSwagger.WrapHandler)