Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 103 additions & 3 deletions backend/controllers/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,126 @@ package controllers

import (
"net/http"
"net/url"
"os"
"strings"

"ccsync_backend/utils"

"github.com/gorilla/sessions"
"github.com/gorilla/websocket"
)

// getEnv returns the environment mode, defaulting to "development"
func getEnv() string {
env := os.Getenv("ENV")
if env == "" {
return "development"
}
return env
}

type JobStatus struct {
Job string `json:"job"`
Status string `json:"status"`
}

var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
// checkWebSocketOrigin validates the Origin header against allowed origins
func checkWebSocketOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
// No origin header - could be same-origin or non-browser client
return true
}

// Get allowed origin from environment
allowedOrigin := os.Getenv("ALLOWED_ORIGIN")
if allowedOrigin == "" {
allowedOrigin = os.Getenv("FRONTEND_ORIGIN_DEV")
}

// In development, allow localhost origins
if getEnv() != "production" {
if strings.HasPrefix(origin, "http://localhost") ||
strings.HasPrefix(origin, "http://127.0.0.1") {
return true
}
}

// Check against configured allowed origin
if allowedOrigin != "" && origin == allowedOrigin {
return true
},
}

// If no ALLOWED_ORIGIN configured, check if origin hostname matches the request host
// Parse the origin URL to safely extract and compare hostnames
originURL, err := url.Parse(origin)
if err == nil {
// Extract just the hostname (without port) from both origin and request
originHost := originURL.Hostname()
requestHost := r.Host
// Remove port from request host if present
if colonIdx := strings.LastIndex(requestHost, ":"); colonIdx != -1 {
// Check if it's not an IPv6 address (which also contains colons)
if !strings.Contains(requestHost, "]") || colonIdx > strings.Index(requestHost, "]") {
requestHost = requestHost[:colonIdx]
}
}
if originHost == requestHost {
return true
}
}

utils.Logger.Warnf("WebSocket connection rejected from origin: %s", origin)
return false
}

var upgrader = websocket.Upgrader{
CheckOrigin: checkWebSocketOrigin,
}

var clients = make(map[*websocket.Conn]bool)
var broadcast = make(chan JobStatus)

// AuthenticatedWebSocketHandler creates a WebSocket handler that requires authentication
func AuthenticatedWebSocketHandler(store *sessions.CookieStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Validate session before upgrading to WebSocket
session, err := store.Get(r, "session-name")
if err != nil {
utils.Logger.Warnf("WebSocket auth failed: could not get session: %v", err)
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

userInfo, ok := session.Values["user"].(map[string]interface{})
if !ok || userInfo == nil {
utils.Logger.Warnf("WebSocket auth failed: no user in session")
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

// User is authenticated, proceed with WebSocket upgrade
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
utils.Logger.Error("WebSocket Upgrade Error:", err)
return
}
defer ws.Close()

clients[ws] = true
for {
_, _, err := ws.ReadMessage()
if err != nil {
delete(clients, ws)
break
}
}
}
}

// WebSocketHandler is kept for backward compatibility but should not be used
// Use AuthenticatedWebSocketHandler instead
func WebSocketHandler(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func main() {

mux.HandleFunc("/health", controllers.HealthCheckHandler)

mux.HandleFunc("/ws", controllers.WebSocketHandler)
mux.HandleFunc("/ws", controllers.AuthenticatedWebSocketHandler(store))

// API documentation endpoint
mux.HandleFunc("/api/docs/", httpSwagger.WrapHandler)
Expand Down