From 62e8d609de71ab4ca883e7eb96abcbca70d8a161 Mon Sep 17 00:00:00 2001 From: rcourtman Date: Sat, 24 Jan 2026 22:43:50 +0000 Subject: [PATCH] feat: add WebSocket multi-tenant isolation - Enhance WebSocket hub with tenant awareness - Add tenant isolation for real-time updates - Add hub tenant isolation tests --- internal/websocket/hub.go | 333 +++++++++++++++++++++++--- internal/websocket/hub_tenant_test.go | 126 ++++++++++ 2 files changed, 430 insertions(+), 29 deletions(-) create mode 100644 internal/websocket/hub_tenant_test.go diff --git a/internal/websocket/hub.go b/internal/websocket/hub.go index 7f943871d..6ff402be6 100644 --- a/internal/websocket/hub.go +++ b/internal/websocket/hub.go @@ -1,6 +1,7 @@ package websocket import ( + "context" "encoding/json" "fmt" "math" @@ -15,6 +16,7 @@ import ( "github.com/gorilla/websocket" "github.com/rcourtman/pulse-go-rewrite/internal/alerts" "github.com/rcourtman/pulse-go-rewrite/internal/utils" + "github.com/rcourtman/pulse-go-rewrite/pkg/auth" "github.com/rs/zerolog/log" ) @@ -175,6 +177,7 @@ type Client struct { conn *websocket.Conn send chan []byte id string + orgID string // Organization ID for tenant isolation lastPing time.Time closed atomic.Bool // Set when the client is unregistered; prevents sends to closed channel writeFailures int32 // Consecutive write failures; disconnects after maxWriteFailures @@ -278,22 +281,57 @@ func cloneMetadataValue(value interface{}) interface{} { } } +// TenantBroadcast represents a broadcast message targeted at a specific tenant. +type TenantBroadcast struct { + OrgID string + Message Message +} + +// OrgAuthChecker checks if a user/token can access an organization. +type OrgAuthChecker interface { + // CanAccessOrg checks if the given user (and optional token) can access the org. + CanAccessOrg(userID string, token interface{}, orgID string) bool +} + +// MultiTenantCheckResult contains the result of a multi-tenant check. +type MultiTenantCheckResult struct { + Allowed bool + FeatureEnabled bool // false = feature flag disabled + Licensed bool // false = not licensed (only meaningful if FeatureEnabled is true) + Reason string // Human-readable reason for denial +} + +// MultiTenantChecker checks if multi-tenant functionality is enabled and licensed. +type MultiTenantChecker interface { + // CheckMultiTenant checks if multi-tenant is enabled (feature flag) and licensed for the org. + // Returns a result with details about why access was denied if not allowed. + CheckMultiTenant(ctx context.Context, orgID string) MultiTenantCheckResult +} + // Hub maintains active WebSocket clients and broadcasts messages type Hub struct { - clients map[*Client]bool - broadcast chan []byte - broadcastSeq chan Message // Sequenced broadcast channel for ordering - register chan *Client - unregister chan *Client - stopChan chan struct{} // Signals shutdown - mu sync.RWMutex - getState func() interface{} // Function to get current state - allowedOrigins []string // Allowed origins for CORS + clients map[*Client]bool // All clients (legacy support) + clientsByTenant map[string]map[*Client]bool // Per-tenant client tracking + broadcast chan []byte + broadcastSeq chan Message // Sequenced broadcast channel for ordering + tenantBroadcast chan TenantBroadcast // Per-tenant broadcast channel + register chan *Client + unregister chan *Client + stopChan chan struct{} // Signals shutdown + mu sync.RWMutex + getState func() interface{} // Function to get current state (legacy) + getStateByTenant func(orgID string) interface{} // Function to get state for specific tenant + allowedOrigins []string // Allowed origins for CORS + orgAuthChecker OrgAuthChecker // Org authorization checker + multiTenantChecker MultiTenantChecker // Multi-tenant feature flag and license checker // Broadcast coalescing fields coalesceWindow time.Duration coalescePending *Message coalesceTimer *time.Timer coalesceMutex sync.Mutex + // Per-tenant coalescing + tenantCoalescePending map[string]*Message + tenantCoalesceTimers map[string]*time.Timer } // Message represents a WebSocket message @@ -303,25 +341,70 @@ type Message struct { Timestamp string `json:"timestamp,omitempty"` } -// SetStateGetter sets the state getter function +// SetStateGetter sets the state getter function (legacy, for default tenant) func (h *Hub) SetStateGetter(getState func() interface{}) { h.mu.Lock() defer h.mu.Unlock() h.getState = getState } +// SetStateGetterForTenant sets the tenant-aware state getter function +func (h *Hub) SetStateGetterForTenant(getState func(orgID string) interface{}) { + h.mu.Lock() + defer h.mu.Unlock() + h.getStateByTenant = getState +} + +// SetOrgAuthChecker sets the organization authorization checker. +func (h *Hub) SetOrgAuthChecker(checker OrgAuthChecker) { + h.mu.Lock() + defer h.mu.Unlock() + h.orgAuthChecker = checker +} + +// SetMultiTenantChecker sets the multi-tenant feature flag and license checker. +func (h *Hub) SetMultiTenantChecker(checker MultiTenantChecker) { + h.mu.Lock() + defer h.mu.Unlock() + h.multiTenantChecker = checker +} + +// getStateForClient returns the state for a specific client based on their tenant +func (h *Hub) getStateForClient(client *Client) interface{} { + h.mu.RLock() + getStateByTenant := h.getStateByTenant + getState := h.getState + h.mu.RUnlock() + + // Try tenant-specific getter first + if getStateByTenant != nil && client.orgID != "" { + return getStateByTenant(client.orgID) + } + + // Fall back to default getter + if getState != nil { + return getState() + } + + return nil +} + // NewHub creates a new WebSocket hub func NewHub(getState func() interface{}) *Hub { return &Hub{ - clients: make(map[*Client]bool), - broadcast: make(chan []byte, 256), - broadcastSeq: make(chan Message, 256), // Buffered sequenced channel - register: make(chan *Client), - unregister: make(chan *Client), - stopChan: make(chan struct{}), - getState: getState, - allowedOrigins: []string{}, // Default to empty (will be set based on actual host) - coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms + clients: make(map[*Client]bool), + clientsByTenant: make(map[string]map[*Client]bool), + broadcast: make(chan []byte, 256), + broadcastSeq: make(chan Message, 256), // Buffered sequenced channel + tenantBroadcast: make(chan TenantBroadcast, 256), // Per-tenant broadcasts + register: make(chan *Client), + unregister: make(chan *Client), + stopChan: make(chan struct{}), + getState: getState, + allowedOrigins: []string{}, // Default to empty (will be set based on actual host) + coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms + tenantCoalescePending: make(map[string]*Message), + tenantCoalesceTimers: make(map[string]*time.Timer), } } @@ -338,12 +421,21 @@ func (h *Hub) Run() { case client := <-h.register: h.mu.Lock() h.clients[client] = true + // Also register by tenant if org ID is set + if client.orgID != "" { + if h.clientsByTenant[client.orgID] == nil { + h.clientsByTenant[client.orgID] = make(map[*Client]bool) + } + h.clientsByTenant[client.orgID][client] = true + } h.mu.Unlock() - log.Info().Str("client", client.id).Msg("WebSocket client connected") + log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client connected") // Send initial state to the new client immediately - log.Debug().Bool("hasGetState", h.getState != nil).Msg("Checking getState function for new client") - if h.getState != nil { + // Use tenant-aware state getter if available + hasGetState := h.getState != nil || h.getStateByTenant != nil + log.Debug().Bool("hasGetState", hasGetState).Msg("Checking getState function for new client") + if hasGetState { // Add a small delay to ensure client is ready go func() { log.Debug().Str("client", client.id).Msg("Starting initial state goroutine") @@ -352,7 +444,7 @@ func (h *Hub) Run() { // First send a small welcome message welcomeMsg := Message{ Type: "welcome", - Data: map[string]string{"message": "Connected to Pulse WebSocket"}, + Data: map[string]string{"message": "Connected to Pulse WebSocket", "orgId": client.orgID}, } if data, err := json.Marshal(welcomeMsg); err == nil { // Check if client is still registered before sending (must hold lock) @@ -376,8 +468,8 @@ func (h *Hub) Run() { time.Sleep(100 * time.Millisecond) log.Debug().Str("client", client.id).Msg("About to get state") - // Get the state - stateData := h.getState() + // Get the state using tenant-aware getter + stateData := h.getStateForClient(client) log.Debug().Str("client", client.id).Interface("stateType", fmt.Sprintf("%T", stateData)).Msg("Got state for initial message") initialMsg := Message{ @@ -412,10 +504,18 @@ func (h *Hub) Run() { h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) + // Also remove from tenant map + if client.orgID != "" && h.clientsByTenant[client.orgID] != nil { + delete(h.clientsByTenant[client.orgID], client) + // Clean up empty tenant maps + if len(h.clientsByTenant[client.orgID]) == 0 { + delete(h.clientsByTenant, client.orgID) + } + } client.closed.Store(true) // Mark closed before closing channel to prevent sends close(client.send) h.mu.Unlock() - log.Info().Str("client", client.id).Msg("WebSocket client disconnected") + log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client disconnected") } else { h.mu.Unlock() } @@ -472,6 +572,67 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { Str("userAgent", r.Header.Get("User-Agent")). Msg("WebSocket upgrade request") + // Extract org ID from request for tenant isolation + // Priority: Header > Cookie > Query param > Default + orgID := r.Header.Get("X-Pulse-Org-ID") + if orgID == "" { + if cookie, err := r.Cookie("pulse_org_id"); err == nil { + orgID = cookie.Value + } + } + if orgID == "" { + orgID = r.URL.Query().Get("org_id") + } + if orgID == "" { + orgID = "default" + } + + // Multi-tenant feature flag and license check for non-default orgs + h.mu.RLock() + mtChecker := h.multiTenantChecker + authChecker := h.orgAuthChecker + h.mu.RUnlock() + + if orgID != "default" { + // Check if multi-tenant is enabled and licensed + if mtChecker != nil { + result := mtChecker.CheckMultiTenant(r.Context(), orgID) + if !result.Allowed { + log.Warn(). + Str("org_id", orgID). + Bool("feature_enabled", result.FeatureEnabled). + Bool("licensed", result.Licensed). + Str("reason", result.Reason). + Msg("WebSocket connection denied - multi-tenant check failed") + + if !result.FeatureEnabled { + // Feature flag disabled - 501 Not Implemented + http.Error(w, "Multi-tenant functionality is not enabled", http.StatusNotImplemented) + } else { + // Feature enabled but not licensed - 402 Payment Required + http.Error(w, "Multi-tenant access requires an Enterprise license", http.StatusPaymentRequired) + } + return + } + } + + // Authorization check - auth context should already be set by AuthContextMiddleware + if authChecker != nil { + // Get user and token from context (set by AuthContextMiddleware) + userID := getUserFromContext(r.Context()) + token := getAPITokenFromContext(r.Context()) + + if !authChecker.CanAccessOrg(userID, token, orgID) { + log.Warn(). + Str("org_id", orgID). + Str("user_id", userID). + Msg("WebSocket connection denied - unauthorized org access") + http.Error(w, "Unauthorized to access organization", http.StatusForbidden) + return + } + } + } + // Create upgrader with our origin check upgrader := websocket.Upgrader{ ReadBufferSize: 1024 * 1024 * 4, // 4MB to handle large state messages @@ -487,15 +648,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { clientID := utils.GenerateID("client") client := &Client{ - hub: h, - conn: conn, + hub: h, + conn: conn, + orgID: orgID, // Keep buffer bounded to avoid holding large state snapshots indefinitely if a client stalls. send: make(chan []byte, 128), id: clientID, lastPing: time.Now(), } - log.Info().Str("client", clientID).Msg("WebSocket client created") + log.Info().Str("client", clientID).Str("org_id", orgID).Msg("WebSocket client created") client.hub.register <- client @@ -504,6 +666,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { go client.readPump() } +// getUserFromContext extracts the user ID from request context. +func getUserFromContext(ctx context.Context) string { + return auth.GetUser(ctx) +} + +// getAPITokenFromContext extracts the API token from request context. +func getAPITokenFromContext(ctx context.Context) interface{} { + return auth.GetAPIToken(ctx) +} + // dispatchToClients fan-outs a marshaled payload to all clients, dropping any that // cannot keep up to prevent unbounded buffering. func (h *Hub) dispatchToClients(data []byte, dropLog string) { @@ -579,6 +751,44 @@ func (h *Hub) runBroadcastSequencer() { } } + case tb := <-h.tenantBroadcast: + // Handle tenant-specific broadcasts with coalescing + if tb.Message.Type == "rawData" { + h.coalesceMutex.Lock() + + // Cancel pending timer for this tenant if exists + if timer := h.tenantCoalesceTimers[tb.OrgID]; timer != nil { + timer.Stop() + } + + // Update pending message for this tenant + msgCopy := tb.Message + h.tenantCoalescePending[tb.OrgID] = &msgCopy + + // Set timer to send after coalesce window + orgID := tb.OrgID // Capture for closure + h.tenantCoalesceTimers[orgID] = time.AfterFunc(h.coalesceWindow, func() { + h.coalesceMutex.Lock() + pending := h.tenantCoalescePending[orgID] + delete(h.tenantCoalescePending, orgID) + delete(h.tenantCoalesceTimers, orgID) + h.coalesceMutex.Unlock() + + if pending != nil { + if data, err := json.Marshal(*pending); err == nil { + h.dispatchToTenantClients(orgID, data, "Client send channel full, dropping tenant coalesced message and closing connection") + } + } + }) + + h.coalesceMutex.Unlock() + } else { + // Non-state messages - send immediately to tenant + if data, err := json.Marshal(tb.Message); err == nil { + h.dispatchToTenantClients(tb.OrgID, data, "Client send channel full, dropping tenant message and closing connection") + } + } + case <-h.stopChan: log.Debug().Msg("Broadcast sequencer shutting down") // Cancel pending timer if exists @@ -586,6 +796,12 @@ func (h *Hub) runBroadcastSequencer() { if h.coalesceTimer != nil { h.coalesceTimer.Stop() } + // Cancel all tenant timers + for _, timer := range h.tenantCoalesceTimers { + if timer != nil { + timer.Stop() + } + } h.coalesceMutex.Unlock() return } @@ -619,6 +835,65 @@ func (h *Hub) BroadcastState(state interface{}) { } } +// BroadcastStateToTenant broadcasts state update only to clients of a specific tenant. +func (h *Hub) BroadcastStateToTenant(orgID string, state interface{}) { + log.Debug().Str("org_id", orgID).Msg("Broadcasting state to tenant") + + msg := Message{ + Type: "rawData", + Data: state, + } + + // Send through tenant broadcast channel + select { + case h.tenantBroadcast <- TenantBroadcast{OrgID: orgID, Message: msg}: + default: + log.Warn().Str("org_id", orgID).Msg("Tenant broadcast channel full, dropping state update") + } +} + +// dispatchToTenantClients sends a message only to clients of a specific tenant. +func (h *Hub) dispatchToTenantClients(orgID string, data []byte, dropLog string) { + h.mu.RLock() + tenantClients := h.clientsByTenant[orgID] + if tenantClients == nil { + h.mu.RUnlock() + return + } + clients := make([]*Client, 0, len(tenantClients)) + for client := range tenantClients { + clients = append(clients, client) + } + h.mu.RUnlock() + + for _, client := range clients { + if !client.safeSend(data) { + // Client closed or buffer full - remove if still registered + h.mu.Lock() + if _, stillPresent := h.clients[client]; stillPresent { + delete(h.clients, client) + if h.clientsByTenant[client.orgID] != nil { + delete(h.clientsByTenant[client.orgID], client) + } + client.closed.Store(true) + close(client.send) + log.Warn().Str("client", client.id).Str("org_id", client.orgID).Msg(dropLog) + } + h.mu.Unlock() + } + } +} + +// GetTenantClientCount returns the number of connected clients for a specific tenant. +func (h *Hub) GetTenantClientCount(orgID string) int { + h.mu.RLock() + defer h.mu.RUnlock() + if tenantClients := h.clientsByTenant[orgID]; tenantClients != nil { + return len(tenantClients) + } + return 0 +} + // BroadcastAlert broadcasts alert to all clients func (h *Hub) BroadcastAlert(alert interface{}) { log.Info().Interface("alert", alert).Msg("Broadcasting alert to WebSocket clients") diff --git a/internal/websocket/hub_tenant_test.go b/internal/websocket/hub_tenant_test.go new file mode 100644 index 000000000..085e29a17 --- /dev/null +++ b/internal/websocket/hub_tenant_test.go @@ -0,0 +1,126 @@ +package websocket + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// MockStateGetter implements StateGetter interface +type MockStateGetter struct { + state interface{} +} + +func (m *MockStateGetter) GetState() interface{} { + return m.state +} + +// MockTenantStateGetter implements TenantStateGetter interface +type MockTenantStateGetter struct { + state map[string]interface{} +} + +func (m *MockTenantStateGetter) GetState(orgID string) interface{} { + return m.state[orgID] +} + +// MockOrgAuthChecker implements OrgAuthChecker interface +type MockOrgAuthChecker struct { + called bool + allow bool +} + +func (m *MockOrgAuthChecker) CanAccessOrg(userID string, token interface{}, orgID string) bool { + m.called = true + return m.allow +} + +func TestHub_Tenant_Broadcasting(t *testing.T) { + hub := NewHub(nil) + go hub.Run() + defer func() { + // Stop hub logic if exposed, or just let it die with test + }() + + // Setup Tenant State Getter + mockState := &MockTenantStateGetter{ + state: map[string]interface{}{ + "org1": map[string]string{"foo": "bar"}, + "org2": map[string]string{"baz": "qux"}, + }, + } + hub.SetStateGetterForTenant(mockState.GetState) + + // Test GetTenantClientCount + assert.Equal(t, 0, hub.GetTenantClientCount("org1")) + + // Test BroadcastStateToTenant (should not panic even with 0 clients) + hub.BroadcastStateToTenant("org1", map[string]string{"status": "ok"}) + hub.BroadcastStateToTenant("org2", nil) + hub.BroadcastStateToTenant("missing", nil) + + // Allow async broadcast to process + time.Sleep(100 * time.Millisecond) +} + +func TestHub_Setters_Coverage(t *testing.T) { + hub := NewHub(nil) + + // Test SetOrgAuthChecker + checker := &MockOrgAuthChecker{allow: true} + hub.SetOrgAuthChecker(checker) + + // Verify it was set + assert.NotNil(t, hub.orgAuthChecker) + + // Trigger the checker + success := hub.orgAuthChecker.CanAccessOrg("user", "tok", "org") + assert.True(t, success) + assert.True(t, checker.called) +} + +func TestHub_DispatchToTenantClients(t *testing.T) { + // This tests the internal logic of iterating clients + hub := NewHub(nil) + + // Create a mock client + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + orgID: "org1", + // isActive removed + } + + // Manually register (simulating register channel) + hub.clients[client] = true + hub.register <- client + + // Allow registration to process + go hub.Run() + time.Sleep(50 * time.Millisecond) + + // Now broadcast to org1 (internal method) + msg := []byte("test message") + // dispatchToTenantClients(orgID string, data []byte, dropLog string) + // But wait, dispatchToTenantClients is private (lowercase). Can we call it? + // Tests are in package `websocket`, so yes we can access private methods of `hub.go`. + hub.dispatchToTenantClients("org1", msg, "Dropping test message") + + // Check if client received it + select { + case received := <-client.send: + assert.Equal(t, msg, received) + case <-time.After(100 * time.Millisecond): + t.Fatal("Client did not receive tenant broadcast") + } + + // Broadcast to org2 (should not receive) + hub.dispatchToTenantClients("org2", msg, "Dropping test message") + select { + case <-client.send: + t.Fatal("Client received message for wrong tenant") + case <-time.After(50 * time.Millisecond): + // Success + } +}