feat: add WebSocket multi-tenant isolation

- Enhance WebSocket hub with tenant awareness
- Add tenant isolation for real-time updates
- Add hub tenant isolation tests
This commit is contained in:
rcourtman 2026-01-24 22:43:50 +00:00
parent 1e77763870
commit 62e8d609de
2 changed files with 430 additions and 29 deletions

View file

@ -1,6 +1,7 @@
package websocket
import (
"context"
"encoding/json"
"fmt"
"math"
@ -15,6 +16,7 @@ import (
"github.com/gorilla/websocket"
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
"github.com/rs/zerolog/log"
)
@ -175,6 +177,7 @@ type Client struct {
conn *websocket.Conn
send chan []byte
id string
orgID string // Organization ID for tenant isolation
lastPing time.Time
closed atomic.Bool // Set when the client is unregistered; prevents sends to closed channel
writeFailures int32 // Consecutive write failures; disconnects after maxWriteFailures
@ -278,22 +281,57 @@ func cloneMetadataValue(value interface{}) interface{} {
}
}
// TenantBroadcast represents a broadcast message targeted at a specific tenant.
type TenantBroadcast struct {
OrgID string
Message Message
}
// OrgAuthChecker checks if a user/token can access an organization.
type OrgAuthChecker interface {
// CanAccessOrg checks if the given user (and optional token) can access the org.
CanAccessOrg(userID string, token interface{}, orgID string) bool
}
// MultiTenantCheckResult contains the result of a multi-tenant check.
type MultiTenantCheckResult struct {
Allowed bool
FeatureEnabled bool // false = feature flag disabled
Licensed bool // false = not licensed (only meaningful if FeatureEnabled is true)
Reason string // Human-readable reason for denial
}
// MultiTenantChecker checks if multi-tenant functionality is enabled and licensed.
type MultiTenantChecker interface {
// CheckMultiTenant checks if multi-tenant is enabled (feature flag) and licensed for the org.
// Returns a result with details about why access was denied if not allowed.
CheckMultiTenant(ctx context.Context, orgID string) MultiTenantCheckResult
}
// Hub maintains active WebSocket clients and broadcasts messages
type Hub struct {
clients map[*Client]bool
broadcast chan []byte
broadcastSeq chan Message // Sequenced broadcast channel for ordering
register chan *Client
unregister chan *Client
stopChan chan struct{} // Signals shutdown
mu sync.RWMutex
getState func() interface{} // Function to get current state
allowedOrigins []string // Allowed origins for CORS
clients map[*Client]bool // All clients (legacy support)
clientsByTenant map[string]map[*Client]bool // Per-tenant client tracking
broadcast chan []byte
broadcastSeq chan Message // Sequenced broadcast channel for ordering
tenantBroadcast chan TenantBroadcast // Per-tenant broadcast channel
register chan *Client
unregister chan *Client
stopChan chan struct{} // Signals shutdown
mu sync.RWMutex
getState func() interface{} // Function to get current state (legacy)
getStateByTenant func(orgID string) interface{} // Function to get state for specific tenant
allowedOrigins []string // Allowed origins for CORS
orgAuthChecker OrgAuthChecker // Org authorization checker
multiTenantChecker MultiTenantChecker // Multi-tenant feature flag and license checker
// Broadcast coalescing fields
coalesceWindow time.Duration
coalescePending *Message
coalesceTimer *time.Timer
coalesceMutex sync.Mutex
// Per-tenant coalescing
tenantCoalescePending map[string]*Message
tenantCoalesceTimers map[string]*time.Timer
}
// Message represents a WebSocket message
@ -303,25 +341,70 @@ type Message struct {
Timestamp string `json:"timestamp,omitempty"`
}
// SetStateGetter sets the state getter function
// SetStateGetter sets the state getter function (legacy, for default tenant)
func (h *Hub) SetStateGetter(getState func() interface{}) {
h.mu.Lock()
defer h.mu.Unlock()
h.getState = getState
}
// SetStateGetterForTenant sets the tenant-aware state getter function
func (h *Hub) SetStateGetterForTenant(getState func(orgID string) interface{}) {
h.mu.Lock()
defer h.mu.Unlock()
h.getStateByTenant = getState
}
// SetOrgAuthChecker sets the organization authorization checker.
func (h *Hub) SetOrgAuthChecker(checker OrgAuthChecker) {
h.mu.Lock()
defer h.mu.Unlock()
h.orgAuthChecker = checker
}
// SetMultiTenantChecker sets the multi-tenant feature flag and license checker.
func (h *Hub) SetMultiTenantChecker(checker MultiTenantChecker) {
h.mu.Lock()
defer h.mu.Unlock()
h.multiTenantChecker = checker
}
// getStateForClient returns the state for a specific client based on their tenant
func (h *Hub) getStateForClient(client *Client) interface{} {
h.mu.RLock()
getStateByTenant := h.getStateByTenant
getState := h.getState
h.mu.RUnlock()
// Try tenant-specific getter first
if getStateByTenant != nil && client.orgID != "" {
return getStateByTenant(client.orgID)
}
// Fall back to default getter
if getState != nil {
return getState()
}
return nil
}
// NewHub creates a new WebSocket hub
func NewHub(getState func() interface{}) *Hub {
return &Hub{
clients: make(map[*Client]bool),
broadcast: make(chan []byte, 256),
broadcastSeq: make(chan Message, 256), // Buffered sequenced channel
register: make(chan *Client),
unregister: make(chan *Client),
stopChan: make(chan struct{}),
getState: getState,
allowedOrigins: []string{}, // Default to empty (will be set based on actual host)
coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms
clients: make(map[*Client]bool),
clientsByTenant: make(map[string]map[*Client]bool),
broadcast: make(chan []byte, 256),
broadcastSeq: make(chan Message, 256), // Buffered sequenced channel
tenantBroadcast: make(chan TenantBroadcast, 256), // Per-tenant broadcasts
register: make(chan *Client),
unregister: make(chan *Client),
stopChan: make(chan struct{}),
getState: getState,
allowedOrigins: []string{}, // Default to empty (will be set based on actual host)
coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms
tenantCoalescePending: make(map[string]*Message),
tenantCoalesceTimers: make(map[string]*time.Timer),
}
}
@ -338,12 +421,21 @@ func (h *Hub) Run() {
case client := <-h.register:
h.mu.Lock()
h.clients[client] = true
// Also register by tenant if org ID is set
if client.orgID != "" {
if h.clientsByTenant[client.orgID] == nil {
h.clientsByTenant[client.orgID] = make(map[*Client]bool)
}
h.clientsByTenant[client.orgID][client] = true
}
h.mu.Unlock()
log.Info().Str("client", client.id).Msg("WebSocket client connected")
log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client connected")
// Send initial state to the new client immediately
log.Debug().Bool("hasGetState", h.getState != nil).Msg("Checking getState function for new client")
if h.getState != nil {
// Use tenant-aware state getter if available
hasGetState := h.getState != nil || h.getStateByTenant != nil
log.Debug().Bool("hasGetState", hasGetState).Msg("Checking getState function for new client")
if hasGetState {
// Add a small delay to ensure client is ready
go func() {
log.Debug().Str("client", client.id).Msg("Starting initial state goroutine")
@ -352,7 +444,7 @@ func (h *Hub) Run() {
// First send a small welcome message
welcomeMsg := Message{
Type: "welcome",
Data: map[string]string{"message": "Connected to Pulse WebSocket"},
Data: map[string]string{"message": "Connected to Pulse WebSocket", "orgId": client.orgID},
}
if data, err := json.Marshal(welcomeMsg); err == nil {
// Check if client is still registered before sending (must hold lock)
@ -376,8 +468,8 @@ func (h *Hub) Run() {
time.Sleep(100 * time.Millisecond)
log.Debug().Str("client", client.id).Msg("About to get state")
// Get the state
stateData := h.getState()
// Get the state using tenant-aware getter
stateData := h.getStateForClient(client)
log.Debug().Str("client", client.id).Interface("stateType", fmt.Sprintf("%T", stateData)).Msg("Got state for initial message")
initialMsg := Message{
@ -412,10 +504,18 @@ func (h *Hub) Run() {
h.mu.Lock()
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
// Also remove from tenant map
if client.orgID != "" && h.clientsByTenant[client.orgID] != nil {
delete(h.clientsByTenant[client.orgID], client)
// Clean up empty tenant maps
if len(h.clientsByTenant[client.orgID]) == 0 {
delete(h.clientsByTenant, client.orgID)
}
}
client.closed.Store(true) // Mark closed before closing channel to prevent sends
close(client.send)
h.mu.Unlock()
log.Info().Str("client", client.id).Msg("WebSocket client disconnected")
log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client disconnected")
} else {
h.mu.Unlock()
}
@ -472,6 +572,67 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
Str("userAgent", r.Header.Get("User-Agent")).
Msg("WebSocket upgrade request")
// Extract org ID from request for tenant isolation
// Priority: Header > Cookie > Query param > Default
orgID := r.Header.Get("X-Pulse-Org-ID")
if orgID == "" {
if cookie, err := r.Cookie("pulse_org_id"); err == nil {
orgID = cookie.Value
}
}
if orgID == "" {
orgID = r.URL.Query().Get("org_id")
}
if orgID == "" {
orgID = "default"
}
// Multi-tenant feature flag and license check for non-default orgs
h.mu.RLock()
mtChecker := h.multiTenantChecker
authChecker := h.orgAuthChecker
h.mu.RUnlock()
if orgID != "default" {
// Check if multi-tenant is enabled and licensed
if mtChecker != nil {
result := mtChecker.CheckMultiTenant(r.Context(), orgID)
if !result.Allowed {
log.Warn().
Str("org_id", orgID).
Bool("feature_enabled", result.FeatureEnabled).
Bool("licensed", result.Licensed).
Str("reason", result.Reason).
Msg("WebSocket connection denied - multi-tenant check failed")
if !result.FeatureEnabled {
// Feature flag disabled - 501 Not Implemented
http.Error(w, "Multi-tenant functionality is not enabled", http.StatusNotImplemented)
} else {
// Feature enabled but not licensed - 402 Payment Required
http.Error(w, "Multi-tenant access requires an Enterprise license", http.StatusPaymentRequired)
}
return
}
}
// Authorization check - auth context should already be set by AuthContextMiddleware
if authChecker != nil {
// Get user and token from context (set by AuthContextMiddleware)
userID := getUserFromContext(r.Context())
token := getAPITokenFromContext(r.Context())
if !authChecker.CanAccessOrg(userID, token, orgID) {
log.Warn().
Str("org_id", orgID).
Str("user_id", userID).
Msg("WebSocket connection denied - unauthorized org access")
http.Error(w, "Unauthorized to access organization", http.StatusForbidden)
return
}
}
}
// Create upgrader with our origin check
upgrader := websocket.Upgrader{
ReadBufferSize: 1024 * 1024 * 4, // 4MB to handle large state messages
@ -487,15 +648,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
clientID := utils.GenerateID("client")
client := &Client{
hub: h,
conn: conn,
hub: h,
conn: conn,
orgID: orgID,
// Keep buffer bounded to avoid holding large state snapshots indefinitely if a client stalls.
send: make(chan []byte, 128),
id: clientID,
lastPing: time.Now(),
}
log.Info().Str("client", clientID).Msg("WebSocket client created")
log.Info().Str("client", clientID).Str("org_id", orgID).Msg("WebSocket client created")
client.hub.register <- client
@ -504,6 +666,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
go client.readPump()
}
// getUserFromContext extracts the user ID from request context.
func getUserFromContext(ctx context.Context) string {
return auth.GetUser(ctx)
}
// getAPITokenFromContext extracts the API token from request context.
func getAPITokenFromContext(ctx context.Context) interface{} {
return auth.GetAPIToken(ctx)
}
// dispatchToClients fan-outs a marshaled payload to all clients, dropping any that
// cannot keep up to prevent unbounded buffering.
func (h *Hub) dispatchToClients(data []byte, dropLog string) {
@ -579,6 +751,44 @@ func (h *Hub) runBroadcastSequencer() {
}
}
case tb := <-h.tenantBroadcast:
// Handle tenant-specific broadcasts with coalescing
if tb.Message.Type == "rawData" {
h.coalesceMutex.Lock()
// Cancel pending timer for this tenant if exists
if timer := h.tenantCoalesceTimers[tb.OrgID]; timer != nil {
timer.Stop()
}
// Update pending message for this tenant
msgCopy := tb.Message
h.tenantCoalescePending[tb.OrgID] = &msgCopy
// Set timer to send after coalesce window
orgID := tb.OrgID // Capture for closure
h.tenantCoalesceTimers[orgID] = time.AfterFunc(h.coalesceWindow, func() {
h.coalesceMutex.Lock()
pending := h.tenantCoalescePending[orgID]
delete(h.tenantCoalescePending, orgID)
delete(h.tenantCoalesceTimers, orgID)
h.coalesceMutex.Unlock()
if pending != nil {
if data, err := json.Marshal(*pending); err == nil {
h.dispatchToTenantClients(orgID, data, "Client send channel full, dropping tenant coalesced message and closing connection")
}
}
})
h.coalesceMutex.Unlock()
} else {
// Non-state messages - send immediately to tenant
if data, err := json.Marshal(tb.Message); err == nil {
h.dispatchToTenantClients(tb.OrgID, data, "Client send channel full, dropping tenant message and closing connection")
}
}
case <-h.stopChan:
log.Debug().Msg("Broadcast sequencer shutting down")
// Cancel pending timer if exists
@ -586,6 +796,12 @@ func (h *Hub) runBroadcastSequencer() {
if h.coalesceTimer != nil {
h.coalesceTimer.Stop()
}
// Cancel all tenant timers
for _, timer := range h.tenantCoalesceTimers {
if timer != nil {
timer.Stop()
}
}
h.coalesceMutex.Unlock()
return
}
@ -619,6 +835,65 @@ func (h *Hub) BroadcastState(state interface{}) {
}
}
// BroadcastStateToTenant broadcasts state update only to clients of a specific tenant.
func (h *Hub) BroadcastStateToTenant(orgID string, state interface{}) {
log.Debug().Str("org_id", orgID).Msg("Broadcasting state to tenant")
msg := Message{
Type: "rawData",
Data: state,
}
// Send through tenant broadcast channel
select {
case h.tenantBroadcast <- TenantBroadcast{OrgID: orgID, Message: msg}:
default:
log.Warn().Str("org_id", orgID).Msg("Tenant broadcast channel full, dropping state update")
}
}
// dispatchToTenantClients sends a message only to clients of a specific tenant.
func (h *Hub) dispatchToTenantClients(orgID string, data []byte, dropLog string) {
h.mu.RLock()
tenantClients := h.clientsByTenant[orgID]
if tenantClients == nil {
h.mu.RUnlock()
return
}
clients := make([]*Client, 0, len(tenantClients))
for client := range tenantClients {
clients = append(clients, client)
}
h.mu.RUnlock()
for _, client := range clients {
if !client.safeSend(data) {
// Client closed or buffer full - remove if still registered
h.mu.Lock()
if _, stillPresent := h.clients[client]; stillPresent {
delete(h.clients, client)
if h.clientsByTenant[client.orgID] != nil {
delete(h.clientsByTenant[client.orgID], client)
}
client.closed.Store(true)
close(client.send)
log.Warn().Str("client", client.id).Str("org_id", client.orgID).Msg(dropLog)
}
h.mu.Unlock()
}
}
}
// GetTenantClientCount returns the number of connected clients for a specific tenant.
func (h *Hub) GetTenantClientCount(orgID string) int {
h.mu.RLock()
defer h.mu.RUnlock()
if tenantClients := h.clientsByTenant[orgID]; tenantClients != nil {
return len(tenantClients)
}
return 0
}
// BroadcastAlert broadcasts alert to all clients
func (h *Hub) BroadcastAlert(alert interface{}) {
log.Info().Interface("alert", alert).Msg("Broadcasting alert to WebSocket clients")

View file

@ -0,0 +1,126 @@
package websocket
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// MockStateGetter implements StateGetter interface
type MockStateGetter struct {
state interface{}
}
func (m *MockStateGetter) GetState() interface{} {
return m.state
}
// MockTenantStateGetter implements TenantStateGetter interface
type MockTenantStateGetter struct {
state map[string]interface{}
}
func (m *MockTenantStateGetter) GetState(orgID string) interface{} {
return m.state[orgID]
}
// MockOrgAuthChecker implements OrgAuthChecker interface
type MockOrgAuthChecker struct {
called bool
allow bool
}
func (m *MockOrgAuthChecker) CanAccessOrg(userID string, token interface{}, orgID string) bool {
m.called = true
return m.allow
}
func TestHub_Tenant_Broadcasting(t *testing.T) {
hub := NewHub(nil)
go hub.Run()
defer func() {
// Stop hub logic if exposed, or just let it die with test
}()
// Setup Tenant State Getter
mockState := &MockTenantStateGetter{
state: map[string]interface{}{
"org1": map[string]string{"foo": "bar"},
"org2": map[string]string{"baz": "qux"},
},
}
hub.SetStateGetterForTenant(mockState.GetState)
// Test GetTenantClientCount
assert.Equal(t, 0, hub.GetTenantClientCount("org1"))
// Test BroadcastStateToTenant (should not panic even with 0 clients)
hub.BroadcastStateToTenant("org1", map[string]string{"status": "ok"})
hub.BroadcastStateToTenant("org2", nil)
hub.BroadcastStateToTenant("missing", nil)
// Allow async broadcast to process
time.Sleep(100 * time.Millisecond)
}
func TestHub_Setters_Coverage(t *testing.T) {
hub := NewHub(nil)
// Test SetOrgAuthChecker
checker := &MockOrgAuthChecker{allow: true}
hub.SetOrgAuthChecker(checker)
// Verify it was set
assert.NotNil(t, hub.orgAuthChecker)
// Trigger the checker
success := hub.orgAuthChecker.CanAccessOrg("user", "tok", "org")
assert.True(t, success)
assert.True(t, checker.called)
}
func TestHub_DispatchToTenantClients(t *testing.T) {
// This tests the internal logic of iterating clients
hub := NewHub(nil)
// Create a mock client
client := &Client{
hub: hub,
send: make(chan []byte, 256),
orgID: "org1",
// isActive removed
}
// Manually register (simulating register channel)
hub.clients[client] = true
hub.register <- client
// Allow registration to process
go hub.Run()
time.Sleep(50 * time.Millisecond)
// Now broadcast to org1 (internal method)
msg := []byte("test message")
// dispatchToTenantClients(orgID string, data []byte, dropLog string)
// But wait, dispatchToTenantClients is private (lowercase). Can we call it?
// Tests are in package `websocket`, so yes we can access private methods of `hub.go`.
hub.dispatchToTenantClients("org1", msg, "Dropping test message")
// Check if client received it
select {
case received := <-client.send:
assert.Equal(t, msg, received)
case <-time.After(100 * time.Millisecond):
t.Fatal("Client did not receive tenant broadcast")
}
// Broadcast to org2 (should not receive)
hub.dispatchToTenantClients("org2", msg, "Dropping test message")
select {
case <-client.send:
t.Fatal("Client received message for wrong tenant")
case <-time.After(50 * time.Millisecond):
// Success
}
}