mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-19 07:54:10 +00:00
feat: add WebSocket multi-tenant isolation
- Enhance WebSocket hub with tenant awareness - Add tenant isolation for real-time updates - Add hub tenant isolation tests
This commit is contained in:
parent
1e77763870
commit
62e8d609de
2 changed files with 430 additions and 29 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
|
@ -15,6 +16,7 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
|
||||
"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
|
@ -175,6 +177,7 @@ type Client struct {
|
|||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
id string
|
||||
orgID string // Organization ID for tenant isolation
|
||||
lastPing time.Time
|
||||
closed atomic.Bool // Set when the client is unregistered; prevents sends to closed channel
|
||||
writeFailures int32 // Consecutive write failures; disconnects after maxWriteFailures
|
||||
|
|
@ -278,22 +281,57 @@ func cloneMetadataValue(value interface{}) interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
// TenantBroadcast represents a broadcast message targeted at a specific tenant.
|
||||
type TenantBroadcast struct {
|
||||
OrgID string
|
||||
Message Message
|
||||
}
|
||||
|
||||
// OrgAuthChecker checks if a user/token can access an organization.
|
||||
type OrgAuthChecker interface {
|
||||
// CanAccessOrg checks if the given user (and optional token) can access the org.
|
||||
CanAccessOrg(userID string, token interface{}, orgID string) bool
|
||||
}
|
||||
|
||||
// MultiTenantCheckResult contains the result of a multi-tenant check.
|
||||
type MultiTenantCheckResult struct {
|
||||
Allowed bool
|
||||
FeatureEnabled bool // false = feature flag disabled
|
||||
Licensed bool // false = not licensed (only meaningful if FeatureEnabled is true)
|
||||
Reason string // Human-readable reason for denial
|
||||
}
|
||||
|
||||
// MultiTenantChecker checks if multi-tenant functionality is enabled and licensed.
|
||||
type MultiTenantChecker interface {
|
||||
// CheckMultiTenant checks if multi-tenant is enabled (feature flag) and licensed for the org.
|
||||
// Returns a result with details about why access was denied if not allowed.
|
||||
CheckMultiTenant(ctx context.Context, orgID string) MultiTenantCheckResult
|
||||
}
|
||||
|
||||
// Hub maintains active WebSocket clients and broadcasts messages
|
||||
type Hub struct {
|
||||
clients map[*Client]bool
|
||||
broadcast chan []byte
|
||||
broadcastSeq chan Message // Sequenced broadcast channel for ordering
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
stopChan chan struct{} // Signals shutdown
|
||||
mu sync.RWMutex
|
||||
getState func() interface{} // Function to get current state
|
||||
allowedOrigins []string // Allowed origins for CORS
|
||||
clients map[*Client]bool // All clients (legacy support)
|
||||
clientsByTenant map[string]map[*Client]bool // Per-tenant client tracking
|
||||
broadcast chan []byte
|
||||
broadcastSeq chan Message // Sequenced broadcast channel for ordering
|
||||
tenantBroadcast chan TenantBroadcast // Per-tenant broadcast channel
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
stopChan chan struct{} // Signals shutdown
|
||||
mu sync.RWMutex
|
||||
getState func() interface{} // Function to get current state (legacy)
|
||||
getStateByTenant func(orgID string) interface{} // Function to get state for specific tenant
|
||||
allowedOrigins []string // Allowed origins for CORS
|
||||
orgAuthChecker OrgAuthChecker // Org authorization checker
|
||||
multiTenantChecker MultiTenantChecker // Multi-tenant feature flag and license checker
|
||||
// Broadcast coalescing fields
|
||||
coalesceWindow time.Duration
|
||||
coalescePending *Message
|
||||
coalesceTimer *time.Timer
|
||||
coalesceMutex sync.Mutex
|
||||
// Per-tenant coalescing
|
||||
tenantCoalescePending map[string]*Message
|
||||
tenantCoalesceTimers map[string]*time.Timer
|
||||
}
|
||||
|
||||
// Message represents a WebSocket message
|
||||
|
|
@ -303,25 +341,70 @@ type Message struct {
|
|||
Timestamp string `json:"timestamp,omitempty"`
|
||||
}
|
||||
|
||||
// SetStateGetter sets the state getter function
|
||||
// SetStateGetter sets the state getter function (legacy, for default tenant)
|
||||
func (h *Hub) SetStateGetter(getState func() interface{}) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.getState = getState
|
||||
}
|
||||
|
||||
// SetStateGetterForTenant sets the tenant-aware state getter function
|
||||
func (h *Hub) SetStateGetterForTenant(getState func(orgID string) interface{}) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.getStateByTenant = getState
|
||||
}
|
||||
|
||||
// SetOrgAuthChecker sets the organization authorization checker.
|
||||
func (h *Hub) SetOrgAuthChecker(checker OrgAuthChecker) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.orgAuthChecker = checker
|
||||
}
|
||||
|
||||
// SetMultiTenantChecker sets the multi-tenant feature flag and license checker.
|
||||
func (h *Hub) SetMultiTenantChecker(checker MultiTenantChecker) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.multiTenantChecker = checker
|
||||
}
|
||||
|
||||
// getStateForClient returns the state for a specific client based on their tenant
|
||||
func (h *Hub) getStateForClient(client *Client) interface{} {
|
||||
h.mu.RLock()
|
||||
getStateByTenant := h.getStateByTenant
|
||||
getState := h.getState
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Try tenant-specific getter first
|
||||
if getStateByTenant != nil && client.orgID != "" {
|
||||
return getStateByTenant(client.orgID)
|
||||
}
|
||||
|
||||
// Fall back to default getter
|
||||
if getState != nil {
|
||||
return getState()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewHub creates a new WebSocket hub
|
||||
func NewHub(getState func() interface{}) *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*Client]bool),
|
||||
broadcast: make(chan []byte, 256),
|
||||
broadcastSeq: make(chan Message, 256), // Buffered sequenced channel
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
stopChan: make(chan struct{}),
|
||||
getState: getState,
|
||||
allowedOrigins: []string{}, // Default to empty (will be set based on actual host)
|
||||
coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms
|
||||
clients: make(map[*Client]bool),
|
||||
clientsByTenant: make(map[string]map[*Client]bool),
|
||||
broadcast: make(chan []byte, 256),
|
||||
broadcastSeq: make(chan Message, 256), // Buffered sequenced channel
|
||||
tenantBroadcast: make(chan TenantBroadcast, 256), // Per-tenant broadcasts
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
stopChan: make(chan struct{}),
|
||||
getState: getState,
|
||||
allowedOrigins: []string{}, // Default to empty (will be set based on actual host)
|
||||
coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms
|
||||
tenantCoalescePending: make(map[string]*Message),
|
||||
tenantCoalesceTimers: make(map[string]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -338,12 +421,21 @@ func (h *Hub) Run() {
|
|||
case client := <-h.register:
|
||||
h.mu.Lock()
|
||||
h.clients[client] = true
|
||||
// Also register by tenant if org ID is set
|
||||
if client.orgID != "" {
|
||||
if h.clientsByTenant[client.orgID] == nil {
|
||||
h.clientsByTenant[client.orgID] = make(map[*Client]bool)
|
||||
}
|
||||
h.clientsByTenant[client.orgID][client] = true
|
||||
}
|
||||
h.mu.Unlock()
|
||||
log.Info().Str("client", client.id).Msg("WebSocket client connected")
|
||||
log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client connected")
|
||||
|
||||
// Send initial state to the new client immediately
|
||||
log.Debug().Bool("hasGetState", h.getState != nil).Msg("Checking getState function for new client")
|
||||
if h.getState != nil {
|
||||
// Use tenant-aware state getter if available
|
||||
hasGetState := h.getState != nil || h.getStateByTenant != nil
|
||||
log.Debug().Bool("hasGetState", hasGetState).Msg("Checking getState function for new client")
|
||||
if hasGetState {
|
||||
// Add a small delay to ensure client is ready
|
||||
go func() {
|
||||
log.Debug().Str("client", client.id).Msg("Starting initial state goroutine")
|
||||
|
|
@ -352,7 +444,7 @@ func (h *Hub) Run() {
|
|||
// First send a small welcome message
|
||||
welcomeMsg := Message{
|
||||
Type: "welcome",
|
||||
Data: map[string]string{"message": "Connected to Pulse WebSocket"},
|
||||
Data: map[string]string{"message": "Connected to Pulse WebSocket", "orgId": client.orgID},
|
||||
}
|
||||
if data, err := json.Marshal(welcomeMsg); err == nil {
|
||||
// Check if client is still registered before sending (must hold lock)
|
||||
|
|
@ -376,8 +468,8 @@ func (h *Hub) Run() {
|
|||
time.Sleep(100 * time.Millisecond)
|
||||
log.Debug().Str("client", client.id).Msg("About to get state")
|
||||
|
||||
// Get the state
|
||||
stateData := h.getState()
|
||||
// Get the state using tenant-aware getter
|
||||
stateData := h.getStateForClient(client)
|
||||
log.Debug().Str("client", client.id).Interface("stateType", fmt.Sprintf("%T", stateData)).Msg("Got state for initial message")
|
||||
|
||||
initialMsg := Message{
|
||||
|
|
@ -412,10 +504,18 @@ func (h *Hub) Run() {
|
|||
h.mu.Lock()
|
||||
if _, ok := h.clients[client]; ok {
|
||||
delete(h.clients, client)
|
||||
// Also remove from tenant map
|
||||
if client.orgID != "" && h.clientsByTenant[client.orgID] != nil {
|
||||
delete(h.clientsByTenant[client.orgID], client)
|
||||
// Clean up empty tenant maps
|
||||
if len(h.clientsByTenant[client.orgID]) == 0 {
|
||||
delete(h.clientsByTenant, client.orgID)
|
||||
}
|
||||
}
|
||||
client.closed.Store(true) // Mark closed before closing channel to prevent sends
|
||||
close(client.send)
|
||||
h.mu.Unlock()
|
||||
log.Info().Str("client", client.id).Msg("WebSocket client disconnected")
|
||||
log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client disconnected")
|
||||
} else {
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
|
@ -472,6 +572,67 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||
Str("userAgent", r.Header.Get("User-Agent")).
|
||||
Msg("WebSocket upgrade request")
|
||||
|
||||
// Extract org ID from request for tenant isolation
|
||||
// Priority: Header > Cookie > Query param > Default
|
||||
orgID := r.Header.Get("X-Pulse-Org-ID")
|
||||
if orgID == "" {
|
||||
if cookie, err := r.Cookie("pulse_org_id"); err == nil {
|
||||
orgID = cookie.Value
|
||||
}
|
||||
}
|
||||
if orgID == "" {
|
||||
orgID = r.URL.Query().Get("org_id")
|
||||
}
|
||||
if orgID == "" {
|
||||
orgID = "default"
|
||||
}
|
||||
|
||||
// Multi-tenant feature flag and license check for non-default orgs
|
||||
h.mu.RLock()
|
||||
mtChecker := h.multiTenantChecker
|
||||
authChecker := h.orgAuthChecker
|
||||
h.mu.RUnlock()
|
||||
|
||||
if orgID != "default" {
|
||||
// Check if multi-tenant is enabled and licensed
|
||||
if mtChecker != nil {
|
||||
result := mtChecker.CheckMultiTenant(r.Context(), orgID)
|
||||
if !result.Allowed {
|
||||
log.Warn().
|
||||
Str("org_id", orgID).
|
||||
Bool("feature_enabled", result.FeatureEnabled).
|
||||
Bool("licensed", result.Licensed).
|
||||
Str("reason", result.Reason).
|
||||
Msg("WebSocket connection denied - multi-tenant check failed")
|
||||
|
||||
if !result.FeatureEnabled {
|
||||
// Feature flag disabled - 501 Not Implemented
|
||||
http.Error(w, "Multi-tenant functionality is not enabled", http.StatusNotImplemented)
|
||||
} else {
|
||||
// Feature enabled but not licensed - 402 Payment Required
|
||||
http.Error(w, "Multi-tenant access requires an Enterprise license", http.StatusPaymentRequired)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Authorization check - auth context should already be set by AuthContextMiddleware
|
||||
if authChecker != nil {
|
||||
// Get user and token from context (set by AuthContextMiddleware)
|
||||
userID := getUserFromContext(r.Context())
|
||||
token := getAPITokenFromContext(r.Context())
|
||||
|
||||
if !authChecker.CanAccessOrg(userID, token, orgID) {
|
||||
log.Warn().
|
||||
Str("org_id", orgID).
|
||||
Str("user_id", userID).
|
||||
Msg("WebSocket connection denied - unauthorized org access")
|
||||
http.Error(w, "Unauthorized to access organization", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create upgrader with our origin check
|
||||
upgrader := websocket.Upgrader{
|
||||
ReadBufferSize: 1024 * 1024 * 4, // 4MB to handle large state messages
|
||||
|
|
@ -487,15 +648,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
clientID := utils.GenerateID("client")
|
||||
client := &Client{
|
||||
hub: h,
|
||||
conn: conn,
|
||||
hub: h,
|
||||
conn: conn,
|
||||
orgID: orgID,
|
||||
// Keep buffer bounded to avoid holding large state snapshots indefinitely if a client stalls.
|
||||
send: make(chan []byte, 128),
|
||||
id: clientID,
|
||||
lastPing: time.Now(),
|
||||
}
|
||||
|
||||
log.Info().Str("client", clientID).Msg("WebSocket client created")
|
||||
log.Info().Str("client", clientID).Str("org_id", orgID).Msg("WebSocket client created")
|
||||
|
||||
client.hub.register <- client
|
||||
|
||||
|
|
@ -504,6 +666,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||
go client.readPump()
|
||||
}
|
||||
|
||||
// getUserFromContext extracts the user ID from request context.
|
||||
func getUserFromContext(ctx context.Context) string {
|
||||
return auth.GetUser(ctx)
|
||||
}
|
||||
|
||||
// getAPITokenFromContext extracts the API token from request context.
|
||||
func getAPITokenFromContext(ctx context.Context) interface{} {
|
||||
return auth.GetAPIToken(ctx)
|
||||
}
|
||||
|
||||
// dispatchToClients fan-outs a marshaled payload to all clients, dropping any that
|
||||
// cannot keep up to prevent unbounded buffering.
|
||||
func (h *Hub) dispatchToClients(data []byte, dropLog string) {
|
||||
|
|
@ -579,6 +751,44 @@ func (h *Hub) runBroadcastSequencer() {
|
|||
}
|
||||
}
|
||||
|
||||
case tb := <-h.tenantBroadcast:
|
||||
// Handle tenant-specific broadcasts with coalescing
|
||||
if tb.Message.Type == "rawData" {
|
||||
h.coalesceMutex.Lock()
|
||||
|
||||
// Cancel pending timer for this tenant if exists
|
||||
if timer := h.tenantCoalesceTimers[tb.OrgID]; timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
// Update pending message for this tenant
|
||||
msgCopy := tb.Message
|
||||
h.tenantCoalescePending[tb.OrgID] = &msgCopy
|
||||
|
||||
// Set timer to send after coalesce window
|
||||
orgID := tb.OrgID // Capture for closure
|
||||
h.tenantCoalesceTimers[orgID] = time.AfterFunc(h.coalesceWindow, func() {
|
||||
h.coalesceMutex.Lock()
|
||||
pending := h.tenantCoalescePending[orgID]
|
||||
delete(h.tenantCoalescePending, orgID)
|
||||
delete(h.tenantCoalesceTimers, orgID)
|
||||
h.coalesceMutex.Unlock()
|
||||
|
||||
if pending != nil {
|
||||
if data, err := json.Marshal(*pending); err == nil {
|
||||
h.dispatchToTenantClients(orgID, data, "Client send channel full, dropping tenant coalesced message and closing connection")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
h.coalesceMutex.Unlock()
|
||||
} else {
|
||||
// Non-state messages - send immediately to tenant
|
||||
if data, err := json.Marshal(tb.Message); err == nil {
|
||||
h.dispatchToTenantClients(tb.OrgID, data, "Client send channel full, dropping tenant message and closing connection")
|
||||
}
|
||||
}
|
||||
|
||||
case <-h.stopChan:
|
||||
log.Debug().Msg("Broadcast sequencer shutting down")
|
||||
// Cancel pending timer if exists
|
||||
|
|
@ -586,6 +796,12 @@ func (h *Hub) runBroadcastSequencer() {
|
|||
if h.coalesceTimer != nil {
|
||||
h.coalesceTimer.Stop()
|
||||
}
|
||||
// Cancel all tenant timers
|
||||
for _, timer := range h.tenantCoalesceTimers {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
}
|
||||
h.coalesceMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
|
@ -619,6 +835,65 @@ func (h *Hub) BroadcastState(state interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
// BroadcastStateToTenant broadcasts state update only to clients of a specific tenant.
|
||||
func (h *Hub) BroadcastStateToTenant(orgID string, state interface{}) {
|
||||
log.Debug().Str("org_id", orgID).Msg("Broadcasting state to tenant")
|
||||
|
||||
msg := Message{
|
||||
Type: "rawData",
|
||||
Data: state,
|
||||
}
|
||||
|
||||
// Send through tenant broadcast channel
|
||||
select {
|
||||
case h.tenantBroadcast <- TenantBroadcast{OrgID: orgID, Message: msg}:
|
||||
default:
|
||||
log.Warn().Str("org_id", orgID).Msg("Tenant broadcast channel full, dropping state update")
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchToTenantClients sends a message only to clients of a specific tenant.
|
||||
func (h *Hub) dispatchToTenantClients(orgID string, data []byte, dropLog string) {
|
||||
h.mu.RLock()
|
||||
tenantClients := h.clientsByTenant[orgID]
|
||||
if tenantClients == nil {
|
||||
h.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
clients := make([]*Client, 0, len(tenantClients))
|
||||
for client := range tenantClients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
for _, client := range clients {
|
||||
if !client.safeSend(data) {
|
||||
// Client closed or buffer full - remove if still registered
|
||||
h.mu.Lock()
|
||||
if _, stillPresent := h.clients[client]; stillPresent {
|
||||
delete(h.clients, client)
|
||||
if h.clientsByTenant[client.orgID] != nil {
|
||||
delete(h.clientsByTenant[client.orgID], client)
|
||||
}
|
||||
client.closed.Store(true)
|
||||
close(client.send)
|
||||
log.Warn().Str("client", client.id).Str("org_id", client.orgID).Msg(dropLog)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTenantClientCount returns the number of connected clients for a specific tenant.
|
||||
func (h *Hub) GetTenantClientCount(orgID string) int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
if tenantClients := h.clientsByTenant[orgID]; tenantClients != nil {
|
||||
return len(tenantClients)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// BroadcastAlert broadcasts alert to all clients
|
||||
func (h *Hub) BroadcastAlert(alert interface{}) {
|
||||
log.Info().Interface("alert", alert).Msg("Broadcasting alert to WebSocket clients")
|
||||
|
|
|
|||
126
internal/websocket/hub_tenant_test.go
Normal file
126
internal/websocket/hub_tenant_test.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// MockStateGetter implements StateGetter interface
|
||||
type MockStateGetter struct {
|
||||
state interface{}
|
||||
}
|
||||
|
||||
func (m *MockStateGetter) GetState() interface{} {
|
||||
return m.state
|
||||
}
|
||||
|
||||
// MockTenantStateGetter implements TenantStateGetter interface
|
||||
type MockTenantStateGetter struct {
|
||||
state map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *MockTenantStateGetter) GetState(orgID string) interface{} {
|
||||
return m.state[orgID]
|
||||
}
|
||||
|
||||
// MockOrgAuthChecker implements OrgAuthChecker interface
|
||||
type MockOrgAuthChecker struct {
|
||||
called bool
|
||||
allow bool
|
||||
}
|
||||
|
||||
func (m *MockOrgAuthChecker) CanAccessOrg(userID string, token interface{}, orgID string) bool {
|
||||
m.called = true
|
||||
return m.allow
|
||||
}
|
||||
|
||||
func TestHub_Tenant_Broadcasting(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
go hub.Run()
|
||||
defer func() {
|
||||
// Stop hub logic if exposed, or just let it die with test
|
||||
}()
|
||||
|
||||
// Setup Tenant State Getter
|
||||
mockState := &MockTenantStateGetter{
|
||||
state: map[string]interface{}{
|
||||
"org1": map[string]string{"foo": "bar"},
|
||||
"org2": map[string]string{"baz": "qux"},
|
||||
},
|
||||
}
|
||||
hub.SetStateGetterForTenant(mockState.GetState)
|
||||
|
||||
// Test GetTenantClientCount
|
||||
assert.Equal(t, 0, hub.GetTenantClientCount("org1"))
|
||||
|
||||
// Test BroadcastStateToTenant (should not panic even with 0 clients)
|
||||
hub.BroadcastStateToTenant("org1", map[string]string{"status": "ok"})
|
||||
hub.BroadcastStateToTenant("org2", nil)
|
||||
hub.BroadcastStateToTenant("missing", nil)
|
||||
|
||||
// Allow async broadcast to process
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestHub_Setters_Coverage(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
|
||||
// Test SetOrgAuthChecker
|
||||
checker := &MockOrgAuthChecker{allow: true}
|
||||
hub.SetOrgAuthChecker(checker)
|
||||
|
||||
// Verify it was set
|
||||
assert.NotNil(t, hub.orgAuthChecker)
|
||||
|
||||
// Trigger the checker
|
||||
success := hub.orgAuthChecker.CanAccessOrg("user", "tok", "org")
|
||||
assert.True(t, success)
|
||||
assert.True(t, checker.called)
|
||||
}
|
||||
|
||||
func TestHub_DispatchToTenantClients(t *testing.T) {
|
||||
// This tests the internal logic of iterating clients
|
||||
hub := NewHub(nil)
|
||||
|
||||
// Create a mock client
|
||||
client := &Client{
|
||||
hub: hub,
|
||||
send: make(chan []byte, 256),
|
||||
orgID: "org1",
|
||||
// isActive removed
|
||||
}
|
||||
|
||||
// Manually register (simulating register channel)
|
||||
hub.clients[client] = true
|
||||
hub.register <- client
|
||||
|
||||
// Allow registration to process
|
||||
go hub.Run()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Now broadcast to org1 (internal method)
|
||||
msg := []byte("test message")
|
||||
// dispatchToTenantClients(orgID string, data []byte, dropLog string)
|
||||
// But wait, dispatchToTenantClients is private (lowercase). Can we call it?
|
||||
// Tests are in package `websocket`, so yes we can access private methods of `hub.go`.
|
||||
hub.dispatchToTenantClients("org1", msg, "Dropping test message")
|
||||
|
||||
// Check if client received it
|
||||
select {
|
||||
case received := <-client.send:
|
||||
assert.Equal(t, msg, received)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Client did not receive tenant broadcast")
|
||||
}
|
||||
|
||||
// Broadcast to org2 (should not receive)
|
||||
hub.dispatchToTenantClients("org2", msg, "Dropping test message")
|
||||
select {
|
||||
case <-client.send:
|
||||
t.Fatal("Client received message for wrong tenant")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Success
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue