mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 11:30:15 +00:00
1224 lines
36 KiB
Go
1224 lines
36 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
|
|
"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// extractPeerIP returns the host portion of a RemoteAddr-style "host:port"
// string. When the value cannot be split (for example because it carries no
// port) the input is returned unchanged.
func extractPeerIP(remoteAddr string) string {
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	// Not in host:port form; hand the value back as-is.
	return remoteAddr
}
|
|
|
|
// isValidPrivateOrigin reports whether host names a loopback or private
// network endpoint. It accepts:
//   - the literals "localhost", "127.0.0.1" and "::1";
//   - any IP address that is loopback or private according to the net package;
//   - names ending in ".local" or ".lan" made of at most three dot-separated
//     labels (hostname.local or hostname.subdomain.local).
//
// Everything else is rejected.
func isValidPrivateOrigin(host string) bool {
	switch host {
	case "localhost", "127.0.0.1", "::1":
		return true
	}

	// An IP literal is decided purely by its address class.
	if parsed := net.ParseIP(host); parsed != nil {
		return parsed.IsLoopback() || parsed.IsPrivate()
	}

	// Only clearly local domain suffixes are considered at all.
	if !strings.HasSuffix(host, ".local") && !strings.HasSuffix(host, ".lan") {
		return false
	}

	// At most three labels means at most two dots; deeper subdomain chains
	// are refused.
	return strings.Count(host, ".") <= 2
}
|
|
|
|
// normalizeForwardedProto maps a forwarded protocol value into the HTTP scheme
// space so that upgrades arriving through proxies that emit ws/wss still
// compare equal to the browser-sent Origin header (which is always
// http/https).
//
// Only the first entry of a comma-separated chain is considered. "ws" becomes
// "http" and "wss" becomes "https"; any other non-empty value is returned
// lower-cased and trimmed. fallback is returned when proto is empty or its
// first entry is blank.
func normalizeForwardedProto(proto string, fallback string) string {
	if proto == "" {
		return fallback
	}

	// Some proxies send comma-separated proto chains; take the first hop.
	firstHop, _, _ := strings.Cut(proto, ",")

	switch scheme := strings.TrimSpace(strings.ToLower(firstHop)); scheme {
	case "":
		return fallback
	case "ws":
		return "http"
	case "wss":
		return "https"
	default:
		// Covers "http", "https" and any unrecognised scheme alike.
		return scheme
	}
}
|
|
|
|
// SetAllowedOrigins sets the allowed origins for CORS
|
|
func (h *Hub) SetAllowedOrigins(origins []string) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.allowedOrigins = origins
|
|
}
|
|
|
|
// checkOrigin validates the origin against allowed origins
|
|
func (h *Hub) checkOrigin(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
// No origin header, allow for non-browser clients
|
|
return true
|
|
}
|
|
|
|
h.mu.RLock()
|
|
allowedOrigins := h.allowedOrigins
|
|
h.mu.RUnlock()
|
|
|
|
// Determine the actual origin (accounting for proxy headers)
|
|
scheme := "http"
|
|
if r.TLS != nil {
|
|
scheme = "https"
|
|
}
|
|
|
|
// Check X-Forwarded-Proto or X-Forwarded-Scheme for proxied requests
|
|
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
|
|
scheme = normalizeForwardedProto(forwardedProto, scheme)
|
|
} else if forwardedScheme := r.Header.Get("X-Forwarded-Scheme"); forwardedScheme != "" {
|
|
scheme = normalizeForwardedProto(forwardedScheme, scheme)
|
|
}
|
|
|
|
// Use X-Forwarded-Host if present (for proxied requests)
|
|
host := r.Host
|
|
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
|
host = forwardedHost
|
|
}
|
|
|
|
requestOrigin := scheme + "://" + host
|
|
|
|
// Allow same-origin requests
|
|
if origin == requestOrigin {
|
|
return true
|
|
}
|
|
|
|
// Check if wildcard is allowed
|
|
for _, allowed := range allowedOrigins {
|
|
if allowed == "*" {
|
|
return true
|
|
}
|
|
if allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// If no origins configured, only allow from truly private networks
|
|
// SECURITY: This is a relaxed policy for homelab deployments. For production,
|
|
// configure explicit allowed origins via WEBSOCKET_ALLOWED_ORIGINS env var.
|
|
if len(allowedOrigins) == 0 {
|
|
// Parse the origin URL to validate it properly
|
|
originHost := origin
|
|
if strings.HasPrefix(origin, "http://") {
|
|
originHost = strings.TrimPrefix(origin, "http://")
|
|
} else if strings.HasPrefix(origin, "https://") {
|
|
originHost = strings.TrimPrefix(origin, "https://")
|
|
}
|
|
|
|
// Extract just the hostname/IP part (remove port)
|
|
if colonIdx := strings.IndexByte(originHost, ':'); colonIdx != -1 {
|
|
originHost = originHost[:colonIdx]
|
|
}
|
|
|
|
// SECURITY: Validate that the peer IP is also from a private network
|
|
// This mitigates CSWSH where a malicious page on the same LAN tries to
|
|
// hijack a WebSocket connection using the victim's session cookie
|
|
peerIP := extractPeerIP(r.RemoteAddr)
|
|
peerIsPrivate := isValidPrivateOrigin(peerIP)
|
|
|
|
// Check if it's a valid private IP or localhost
|
|
if isValidPrivateOrigin(originHost) {
|
|
if !peerIsPrivate {
|
|
// Origin claims to be private but peer is public - suspicious
|
|
log.Warn().
|
|
Str("origin", origin).
|
|
Str("origin_host", originHost).
|
|
Str("peer_ip", peerIP).
|
|
Msg("WebSocket rejected - origin is private but peer is public (potential CSWSH)")
|
|
return false
|
|
}
|
|
log.Debug().
|
|
Str("origin", origin).
|
|
Str("host", originHost).
|
|
Str("peer_ip", peerIP).
|
|
Msg("Allowing WebSocket connection from private network (no explicit origins configured)")
|
|
return true
|
|
}
|
|
|
|
// Note: same-origin match already handled above (line 116)
|
|
log.Warn().
|
|
Str("origin", origin).
|
|
Str("requestOrigin", requestOrigin).
|
|
Str("peer_ip", peerIP).
|
|
Msg("WebSocket connection rejected - not from allowed local/private network")
|
|
return false
|
|
}
|
|
|
|
log.Warn().
|
|
Str("origin", origin).
|
|
Str("requestOrigin", requestOrigin).
|
|
Strs("allowedOrigins", allowedOrigins).
|
|
Msg("WebSocket connection rejected due to CORS")
|
|
|
|
return false
|
|
}
|
|
|
|
// Client represents a WebSocket client
|
|
type Client struct {
|
|
hub *Hub
|
|
conn *websocket.Conn
|
|
send chan []byte
|
|
id string
|
|
orgID string // Organization ID for tenant isolation
|
|
lastPing time.Time
|
|
closed atomic.Bool // Set when the client is unregistered; prevents sends to closed channel
|
|
writeFailures int32 // Consecutive write failures; disconnects after maxWriteFailures
|
|
}
|
|
|
|
// safeSend attempts to send data to the client's send channel.
|
|
// Returns false if the client is closed or the channel buffer is full.
|
|
// Uses defer/recover to handle the race between close(c.send) and send.
|
|
func (c *Client) safeSend(data []byte) (sent bool) {
|
|
// Early check to avoid most attempts on closed clients.
|
|
if c.closed.Load() {
|
|
return false
|
|
}
|
|
|
|
// Recover from panic if the channel was closed between the check above
|
|
// and the send below. This is a defensive pattern to prevent server crashes.
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
// Channel was closed concurrently; mark as not sent.
|
|
sent = false
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case c.send <- data:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// cloneAlertData returns a broadcast-safe copy of alert data to avoid data races when
|
|
// downstream sanitization/encoding happens concurrently with alert manager mutations.
|
|
func cloneAlertData(alert interface{}) interface{} {
|
|
switch a := alert.(type) {
|
|
case *alerts.Alert:
|
|
cloned := cloneAlert(a)
|
|
return cloned
|
|
case alerts.Alert:
|
|
cloned := cloneAlert(&a)
|
|
return cloned
|
|
default:
|
|
return alert
|
|
}
|
|
}
|
|
|
|
// cloneAlert performs a deep copy of the mutable fields within alerts.Alert.
|
|
func cloneAlert(src *alerts.Alert) alerts.Alert {
|
|
if src == nil {
|
|
return alerts.Alert{}
|
|
}
|
|
clone := *src
|
|
|
|
if src.AckTime != nil {
|
|
t := *src.AckTime
|
|
clone.AckTime = &t
|
|
}
|
|
|
|
if len(src.EscalationTimes) > 0 {
|
|
clone.EscalationTimes = append([]time.Time(nil), src.EscalationTimes...)
|
|
}
|
|
|
|
if src.Metadata != nil {
|
|
clone.Metadata = cloneMetadata(src.Metadata)
|
|
}
|
|
|
|
return clone
|
|
}
|
|
|
|
// cloneMetadata returns a deep copy of alert metadata so the result shares no
// maps or slices (of the kinds cloneMetadataValue recognises) with src.
// A nil map yields nil.
func cloneMetadata(src map[string]interface{}) map[string]interface{} {
	if src == nil {
		return nil
	}

	out := make(map[string]interface{}, len(src))
	for key := range src {
		out[key] = cloneMetadataValue(src[key])
	}
	return out
}

// cloneMetadataValue copies a single metadata value. Nested
// map[string]interface{} and []interface{} values are copied recursively;
// []string, []int and []float64 are duplicated element-wise; a
// map[string]string comes back as a fresh map[string]interface{}. Values of
// any other type are returned as-is.
func cloneMetadataValue(value interface{}) interface{} {
	switch typed := value.(type) {
	case []string:
		dup := make([]string, len(typed))
		copy(dup, typed)
		return dup
	case []int:
		dup := make([]int, len(typed))
		copy(dup, typed)
		return dup
	case []float64:
		dup := make([]float64, len(typed))
		copy(dup, typed)
		return dup
	case []interface{}:
		dup := make([]interface{}, len(typed))
		for idx := range typed {
			dup[idx] = cloneMetadataValue(typed[idx])
		}
		return dup
	case map[string]string:
		// Note the change of value type: the copy is a map[string]interface{}.
		dup := make(map[string]interface{}, len(typed))
		for name, text := range typed {
			dup[name] = text
		}
		return dup
	case map[string]interface{}:
		return cloneMetadata(typed)
	}
	return value
}
|
|
|
|
// TenantBroadcast represents a broadcast message targeted at a specific tenant.
|
|
type TenantBroadcast struct {
|
|
OrgID string
|
|
Message Message
|
|
}
|
|
|
|
// OrgAuthChecker decides whether a user/token may access an organization.
type OrgAuthChecker interface {
	// CanAccessOrg reports whether the given user (and optional token) is
	// allowed to access the organization identified by orgID.
	CanAccessOrg(userID string, token interface{}, orgID string) bool
}
|
|
|
|
// MultiTenantCheckResult is the outcome of a multi-tenant check.
type MultiTenantCheckResult struct {
	// Allowed is true when the connection may proceed.
	Allowed bool
	// FeatureEnabled is false when the feature flag is disabled.
	FeatureEnabled bool
	// Licensed is false when not licensed; only meaningful if FeatureEnabled is true.
	Licensed bool
	// Reason is a human-readable reason for denial.
	Reason string
}
|
|
|
|
// MultiTenantChecker checks if multi-tenant functionality is enabled and licensed.
|
|
type MultiTenantChecker interface {
|
|
// CheckMultiTenant checks if multi-tenant is enabled (feature flag) and licensed for the org.
|
|
// Returns a result with details about why access was denied if not allowed.
|
|
CheckMultiTenant(ctx context.Context, orgID string) MultiTenantCheckResult
|
|
}
|
|
|
|
// Hub maintains active WebSocket clients and broadcasts messages
|
|
type Hub struct {
|
|
clients map[*Client]bool // All clients (legacy support)
|
|
clientsByTenant map[string]map[*Client]bool // Per-tenant client tracking
|
|
broadcast chan []byte
|
|
broadcastSeq chan Message // Sequenced broadcast channel for ordering
|
|
tenantBroadcast chan TenantBroadcast // Per-tenant broadcast channel
|
|
register chan *Client
|
|
unregister chan *Client
|
|
stopChan chan struct{} // Signals shutdown
|
|
mu sync.RWMutex
|
|
getState func() interface{} // Function to get current state (legacy)
|
|
getStateByTenant func(orgID string) interface{} // Function to get state for specific tenant
|
|
allowedOrigins []string // Allowed origins for CORS
|
|
orgAuthChecker OrgAuthChecker // Org authorization checker
|
|
multiTenantChecker MultiTenantChecker // Multi-tenant feature flag and license checker
|
|
singleTenantMode bool // Ignore tenant selection and force default org
|
|
// Broadcast coalescing fields
|
|
coalesceWindow time.Duration
|
|
coalescePending *Message
|
|
coalesceTimer *time.Timer
|
|
coalesceMutex sync.Mutex
|
|
// Per-tenant coalescing
|
|
tenantCoalescePending map[string]*Message
|
|
tenantCoalesceTimers map[string]*time.Timer
|
|
}
|
|
|
|
// Message is the JSON envelope exchanged over the WebSocket.
type Message struct {
	// Type names the kind of message (for example "ping" or "alert").
	Type string `json:"type"`
	// Data is the message payload; its shape depends on Type.
	Data interface{} `json:"data"`
	// Timestamp is optional and omitted from the JSON when empty.
	Timestamp string `json:"timestamp,omitempty"`
}
|
|
|
|
// SetStateGetter sets the state getter function (legacy, for default tenant)
|
|
func (h *Hub) SetStateGetter(getState func() interface{}) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.getState = getState
|
|
}
|
|
|
|
// SetStateGetterForTenant sets the tenant-aware state getter function
|
|
func (h *Hub) SetStateGetterForTenant(getState func(orgID string) interface{}) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.getStateByTenant = getState
|
|
}
|
|
|
|
// SetOrgAuthChecker sets the organization authorization checker.
|
|
func (h *Hub) SetOrgAuthChecker(checker OrgAuthChecker) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.orgAuthChecker = checker
|
|
}
|
|
|
|
// SetMultiTenantChecker sets the multi-tenant feature flag and license checker.
|
|
func (h *Hub) SetMultiTenantChecker(checker MultiTenantChecker) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.multiTenantChecker = checker
|
|
}
|
|
|
|
// SetSingleTenantMode forces all connections to use the default org.
|
|
func (h *Hub) SetSingleTenantMode(enabled bool) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.singleTenantMode = enabled
|
|
}
|
|
|
|
// getStateForClient returns the state for a specific client based on their tenant
|
|
func (h *Hub) getStateForClient(client *Client) interface{} {
|
|
h.mu.RLock()
|
|
getStateByTenant := h.getStateByTenant
|
|
getState := h.getState
|
|
h.mu.RUnlock()
|
|
|
|
// Try tenant-specific getter first
|
|
if getStateByTenant != nil && client.orgID != "" {
|
|
return getStateByTenant(client.orgID)
|
|
}
|
|
|
|
// Fall back to default getter
|
|
if getState != nil {
|
|
return getState()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NewHub creates a new WebSocket hub
|
|
func NewHub(getState func() interface{}) *Hub {
|
|
return &Hub{
|
|
clients: make(map[*Client]bool),
|
|
clientsByTenant: make(map[string]map[*Client]bool),
|
|
broadcast: make(chan []byte, 256),
|
|
broadcastSeq: make(chan Message, 256), // Buffered sequenced channel
|
|
tenantBroadcast: make(chan TenantBroadcast, 256), // Per-tenant broadcasts
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
stopChan: make(chan struct{}),
|
|
getState: getState,
|
|
allowedOrigins: []string{}, // Default to empty (will be set based on actual host)
|
|
coalesceWindow: 100 * time.Millisecond, // Coalesce rapid updates within 100ms
|
|
tenantCoalescePending: make(map[string]*Message),
|
|
tenantCoalesceTimers: make(map[string]*time.Timer),
|
|
}
|
|
}
|
|
|
|
// Run starts the hub's main loop
|
|
func (h *Hub) Run() {
|
|
// Start broadcast sequencer goroutine
|
|
go h.runBroadcastSequencer()
|
|
|
|
pingTicker := time.NewTicker(30 * time.Second)
|
|
defer pingTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.mu.Lock()
|
|
h.clients[client] = true
|
|
// Also register by tenant if org ID is set
|
|
if client.orgID != "" {
|
|
if h.clientsByTenant[client.orgID] == nil {
|
|
h.clientsByTenant[client.orgID] = make(map[*Client]bool)
|
|
}
|
|
h.clientsByTenant[client.orgID][client] = true
|
|
}
|
|
h.mu.Unlock()
|
|
log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client connected")
|
|
|
|
// Send initial state to the new client immediately
|
|
// Use tenant-aware state getter if available
|
|
hasGetState := h.getState != nil || h.getStateByTenant != nil
|
|
log.Debug().Bool("hasGetState", hasGetState).Msg("Checking getState function for new client")
|
|
if hasGetState {
|
|
// Add a small delay to ensure client is ready
|
|
go func() {
|
|
log.Debug().Str("client", client.id).Msg("Starting initial state goroutine")
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// First send a small welcome message
|
|
welcomeMsg := Message{
|
|
Type: "welcome",
|
|
Data: map[string]string{"message": "Connected to Pulse WebSocket", "orgId": client.orgID},
|
|
}
|
|
if data, err := json.Marshal(welcomeMsg); err == nil {
|
|
// Check if client is still registered before sending (must hold lock)
|
|
h.mu.RLock()
|
|
_, stillRegistered := h.clients[client]
|
|
h.mu.RUnlock()
|
|
|
|
if stillRegistered {
|
|
log.Info().Str("client", client.id).Msg("Sending welcome message")
|
|
if client.safeSend(data) {
|
|
log.Info().Str("client", client.id).Msg("Welcome message sent")
|
|
} else {
|
|
log.Warn().Str("client", client.id).Msg("Failed to send welcome message - client closed or buffer full")
|
|
}
|
|
} else {
|
|
log.Debug().Str("client", client.id).Msg("Client disconnected before welcome message")
|
|
}
|
|
}
|
|
|
|
// Then send the initial state after another delay
|
|
time.Sleep(100 * time.Millisecond)
|
|
log.Debug().Str("client", client.id).Msg("About to get state")
|
|
|
|
// Get the state using tenant-aware getter
|
|
stateData := h.getStateForClient(client)
|
|
log.Debug().Str("client", client.id).Interface("stateType", fmt.Sprintf("%T", stateData)).Msg("Got state for initial message")
|
|
|
|
initialMsg := Message{
|
|
Type: "initialState",
|
|
Data: sanitizeData(stateData),
|
|
}
|
|
if data, err := json.Marshal(initialMsg); err == nil {
|
|
// Check if client is still registered before sending (must hold lock)
|
|
h.mu.RLock()
|
|
_, stillRegistered := h.clients[client]
|
|
h.mu.RUnlock()
|
|
|
|
if stillRegistered {
|
|
log.Info().Str("client", client.id).Int("dataLen", len(data)).Int("dataKB", len(data)/1024).Msg("Sending initial state to client")
|
|
if client.safeSend(data) {
|
|
log.Info().Str("client", client.id).Msg("Initial state sent successfully")
|
|
} else {
|
|
log.Warn().Str("client", client.id).Msg("Client closed or buffer full, skipping initial state")
|
|
}
|
|
} else {
|
|
log.Debug().Str("client", client.id).Msg("Client disconnected before initial state")
|
|
}
|
|
} else {
|
|
log.Error().Err(err).Str("client", client.id).Msg("Failed to marshal initial state")
|
|
}
|
|
}()
|
|
} else {
|
|
log.Warn().Msg("No getState function defined")
|
|
}
|
|
|
|
case client := <-h.unregister:
|
|
h.mu.Lock()
|
|
if _, ok := h.clients[client]; ok {
|
|
delete(h.clients, client)
|
|
// Also remove from tenant map
|
|
if client.orgID != "" && h.clientsByTenant[client.orgID] != nil {
|
|
delete(h.clientsByTenant[client.orgID], client)
|
|
// Clean up empty tenant maps
|
|
if len(h.clientsByTenant[client.orgID]) == 0 {
|
|
delete(h.clientsByTenant, client.orgID)
|
|
}
|
|
}
|
|
client.closed.Store(true) // Mark closed before closing channel to prevent sends
|
|
close(client.send)
|
|
h.mu.Unlock()
|
|
log.Info().Str("client", client.id).Str("org_id", client.orgID).Msg("WebSocket client disconnected")
|
|
} else {
|
|
h.mu.Unlock()
|
|
}
|
|
|
|
case message := <-h.broadcast:
|
|
h.mu.RLock()
|
|
clients := make([]*Client, 0, len(h.clients))
|
|
for client := range h.clients {
|
|
clients = append(clients, client)
|
|
}
|
|
h.mu.RUnlock()
|
|
|
|
for _, client := range clients {
|
|
if !client.safeSend(message) {
|
|
// Client closed or buffer full - remove if still registered
|
|
h.mu.Lock()
|
|
if _, stillPresent := h.clients[client]; stillPresent {
|
|
delete(h.clients, client)
|
|
client.closed.Store(true)
|
|
close(client.send)
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
case <-pingTicker.C:
|
|
h.sendPing()
|
|
|
|
case <-h.stopChan:
|
|
log.Info().Msg("WebSocket hub shutting down")
|
|
// Close all client connections
|
|
h.mu.Lock()
|
|
for client := range h.clients {
|
|
client.closed.Store(true)
|
|
close(client.send)
|
|
}
|
|
h.clients = make(map[*Client]bool)
|
|
h.mu.Unlock()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop gracefully shuts down the hub
|
|
func (h *Hub) Stop() {
|
|
close(h.stopChan)
|
|
}
|
|
|
|
// HandleWebSocket handles WebSocket upgrade requests
|
|
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
log.Info().
|
|
Str("origin", r.Header.Get("Origin")).
|
|
Str("host", r.Host).
|
|
Str("userAgent", r.Header.Get("User-Agent")).
|
|
Msg("WebSocket upgrade request")
|
|
|
|
// Extract org ID from request for tenant isolation
|
|
// Priority: Header > Cookie > Query param > Default
|
|
orgID := r.Header.Get("X-Pulse-Org-ID")
|
|
if orgID == "" {
|
|
if cookie, err := r.Cookie("pulse_org_id"); err == nil {
|
|
orgID = cookie.Value
|
|
}
|
|
}
|
|
if orgID == "" {
|
|
orgID = r.URL.Query().Get("org_id")
|
|
}
|
|
if orgID == "" {
|
|
orgID = "default"
|
|
}
|
|
|
|
// Multi-tenant feature flag and license check for non-default orgs
|
|
h.mu.RLock()
|
|
mtChecker := h.multiTenantChecker
|
|
authChecker := h.orgAuthChecker
|
|
singleTenantMode := h.singleTenantMode
|
|
h.mu.RUnlock()
|
|
|
|
if singleTenantMode && orgID != "" && orgID != "default" {
|
|
log.Debug().
|
|
Str("requested_org", orgID).
|
|
Msg("Ignoring non-default org for single-tenant WebSocket runtime")
|
|
orgID = "default"
|
|
}
|
|
|
|
if orgID != "default" {
|
|
// Check if multi-tenant is enabled and licensed
|
|
if mtChecker != nil {
|
|
result := mtChecker.CheckMultiTenant(r.Context(), orgID)
|
|
if !result.Allowed {
|
|
log.Warn().
|
|
Str("org_id", orgID).
|
|
Bool("feature_enabled", result.FeatureEnabled).
|
|
Bool("licensed", result.Licensed).
|
|
Str("reason", result.Reason).
|
|
Msg("WebSocket connection denied - multi-tenant check failed")
|
|
|
|
if !result.FeatureEnabled {
|
|
// Feature flag disabled - 501 Not Implemented
|
|
http.Error(w, "Multi-tenant functionality is not enabled", http.StatusNotImplemented)
|
|
} else {
|
|
// Feature enabled but not licensed - 402 Payment Required
|
|
http.Error(w, "Multi-tenant access requires an Enterprise license", http.StatusPaymentRequired)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// Authorization check - auth context should already be set by AuthContextMiddleware
|
|
if authChecker != nil {
|
|
// Get user and token from context (set by AuthContextMiddleware)
|
|
userID := getUserFromContext(r.Context())
|
|
token := getAPITokenFromContext(r.Context())
|
|
|
|
if !authChecker.CanAccessOrg(userID, token, orgID) {
|
|
log.Warn().
|
|
Str("org_id", orgID).
|
|
Str("user_id", userID).
|
|
Msg("WebSocket connection denied - unauthorized org access")
|
|
http.Error(w, "Unauthorized to access organization", http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create upgrader with our origin check
|
|
upgrader := websocket.Upgrader{
|
|
ReadBufferSize: 1024 * 1024 * 4, // 4MB to handle large state messages
|
|
WriteBufferSize: 1024 * 1024 * 4, // 4MB to handle large state messages
|
|
CheckOrigin: h.checkOrigin,
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to upgrade WebSocket connection")
|
|
return
|
|
}
|
|
|
|
clientID := utils.GenerateID("client")
|
|
client := &Client{
|
|
hub: h,
|
|
conn: conn,
|
|
orgID: orgID,
|
|
// Keep buffer bounded to avoid holding large state snapshots indefinitely if a client stalls.
|
|
send: make(chan []byte, 128),
|
|
id: clientID,
|
|
lastPing: time.Now(),
|
|
}
|
|
|
|
log.Info().Str("client", clientID).Str("org_id", orgID).Msg("WebSocket client created")
|
|
|
|
client.hub.register <- client
|
|
|
|
// Start goroutines for reading and writing
|
|
go client.writePump()
|
|
go client.readPump()
|
|
}
|
|
|
|
// getUserFromContext extracts the user ID from request context.
|
|
func getUserFromContext(ctx context.Context) string {
|
|
return auth.GetUser(ctx)
|
|
}
|
|
|
|
// getAPITokenFromContext extracts the API token from request context.
|
|
func getAPITokenFromContext(ctx context.Context) interface{} {
|
|
return auth.GetAPIToken(ctx)
|
|
}
|
|
|
|
// dispatchToClients fan-outs a marshaled payload to all clients, dropping any that
|
|
// cannot keep up to prevent unbounded buffering.
|
|
func (h *Hub) dispatchToClients(data []byte, dropLog string) {
|
|
h.mu.RLock()
|
|
clients := make([]*Client, 0, len(h.clients))
|
|
for client := range h.clients {
|
|
clients = append(clients, client)
|
|
}
|
|
h.mu.RUnlock()
|
|
|
|
for _, client := range clients {
|
|
if !client.safeSend(data) {
|
|
// Client closed or buffer full - remove if still registered
|
|
h.mu.Lock()
|
|
if _, stillPresent := h.clients[client]; stillPresent {
|
|
delete(h.clients, client)
|
|
client.closed.Store(true)
|
|
close(client.send)
|
|
log.Warn().Str("client", client.id).Msg(dropLog)
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) popCoalescedMessage() *Message {
|
|
h.coalesceMutex.Lock()
|
|
defer h.coalesceMutex.Unlock()
|
|
|
|
if h.coalescePending == nil {
|
|
return nil
|
|
}
|
|
|
|
msg := *h.coalescePending
|
|
h.coalescePending = nil
|
|
h.coalesceTimer = nil
|
|
return &msg
|
|
}
|
|
|
|
// runBroadcastSequencer handles sequenced broadcasts with coalescing for rapid state updates
|
|
func (h *Hub) runBroadcastSequencer() {
|
|
for {
|
|
select {
|
|
case msg := <-h.broadcastSeq:
|
|
// Handle raw data (state) messages with coalescing
|
|
if msg.Type == "rawData" {
|
|
h.coalesceMutex.Lock()
|
|
|
|
// Cancel pending timer if exists
|
|
if h.coalesceTimer != nil {
|
|
h.coalesceTimer.Stop()
|
|
}
|
|
|
|
// Update pending message
|
|
current := msg
|
|
h.coalescePending = ¤t
|
|
|
|
// Set timer to send after coalesce window
|
|
h.coalesceTimer = time.AfterFunc(h.coalesceWindow, func() {
|
|
pending := h.popCoalescedMessage()
|
|
if pending != nil {
|
|
if data, err := json.Marshal(*pending); err == nil {
|
|
h.dispatchToClients(data, "Client send channel full, dropping coalesced message and closing connection")
|
|
}
|
|
}
|
|
})
|
|
|
|
h.coalesceMutex.Unlock()
|
|
} else {
|
|
// Non-state messages (alerts, etc.) - send immediately
|
|
if data, err := json.Marshal(msg); err == nil {
|
|
h.dispatchToClients(data, "Client send channel full, dropping message and closing connection")
|
|
}
|
|
}
|
|
|
|
case tb := <-h.tenantBroadcast:
|
|
// Handle tenant-specific broadcasts with coalescing
|
|
if tb.Message.Type == "rawData" {
|
|
h.coalesceMutex.Lock()
|
|
|
|
// Cancel pending timer for this tenant if exists
|
|
if timer := h.tenantCoalesceTimers[tb.OrgID]; timer != nil {
|
|
timer.Stop()
|
|
}
|
|
|
|
// Update pending message for this tenant
|
|
msgCopy := tb.Message
|
|
h.tenantCoalescePending[tb.OrgID] = &msgCopy
|
|
|
|
// Set timer to send after coalesce window
|
|
orgID := tb.OrgID // Capture for closure
|
|
h.tenantCoalesceTimers[orgID] = time.AfterFunc(h.coalesceWindow, func() {
|
|
h.coalesceMutex.Lock()
|
|
pending := h.tenantCoalescePending[orgID]
|
|
delete(h.tenantCoalescePending, orgID)
|
|
delete(h.tenantCoalesceTimers, orgID)
|
|
h.coalesceMutex.Unlock()
|
|
|
|
if pending != nil {
|
|
if data, err := json.Marshal(*pending); err == nil {
|
|
h.dispatchToTenantClients(orgID, data, "Client send channel full, dropping tenant coalesced message and closing connection")
|
|
}
|
|
}
|
|
})
|
|
|
|
h.coalesceMutex.Unlock()
|
|
} else {
|
|
// Non-state messages - send immediately to tenant
|
|
if data, err := json.Marshal(tb.Message); err == nil {
|
|
h.dispatchToTenantClients(tb.OrgID, data, "Client send channel full, dropping tenant message and closing connection")
|
|
}
|
|
}
|
|
|
|
case <-h.stopChan:
|
|
log.Debug().Msg("Broadcast sequencer shutting down")
|
|
// Cancel pending timer if exists
|
|
h.coalesceMutex.Lock()
|
|
if h.coalesceTimer != nil {
|
|
h.coalesceTimer.Stop()
|
|
}
|
|
// Cancel all tenant timers
|
|
for _, timer := range h.tenantCoalesceTimers {
|
|
if timer != nil {
|
|
timer.Stop()
|
|
}
|
|
}
|
|
h.coalesceMutex.Unlock()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastState broadcasts state update to all clients via sequencer
|
|
func (h *Hub) BroadcastState(state interface{}) {
|
|
// Debug log to track docker hosts
|
|
dockerHostsCount := -1
|
|
// Use reflection to get dockerHosts field from any struct type
|
|
v := reflect.ValueOf(state)
|
|
if v.Kind() == reflect.Struct {
|
|
field := v.FieldByName("DockerHosts")
|
|
if field.IsValid() && field.Kind() == reflect.Slice {
|
|
dockerHostsCount = field.Len()
|
|
}
|
|
}
|
|
log.Debug().Int("dockerHostsCount", dockerHostsCount).Msg("Broadcasting state")
|
|
|
|
msg := Message{
|
|
Type: "rawData",
|
|
Data: state,
|
|
}
|
|
|
|
// Send through sequencer for ordering and coalescing
|
|
select {
|
|
case h.broadcastSeq <- msg:
|
|
default:
|
|
log.Warn().Msg("Broadcast sequencer channel full, dropping state update")
|
|
}
|
|
}
|
|
|
|
// BroadcastStateToTenant broadcasts state update only to clients of a specific tenant.
|
|
func (h *Hub) BroadcastStateToTenant(orgID string, state interface{}) {
|
|
log.Debug().Str("org_id", orgID).Msg("Broadcasting state to tenant")
|
|
|
|
msg := Message{
|
|
Type: "rawData",
|
|
Data: state,
|
|
}
|
|
|
|
// Send through tenant broadcast channel
|
|
select {
|
|
case h.tenantBroadcast <- TenantBroadcast{OrgID: orgID, Message: msg}:
|
|
default:
|
|
log.Warn().Str("org_id", orgID).Msg("Tenant broadcast channel full, dropping state update")
|
|
}
|
|
}
|
|
|
|
// dispatchToTenantClients sends a message only to clients of a specific tenant.
|
|
func (h *Hub) dispatchToTenantClients(orgID string, data []byte, dropLog string) {
|
|
h.mu.RLock()
|
|
tenantClients := h.clientsByTenant[orgID]
|
|
if tenantClients == nil {
|
|
h.mu.RUnlock()
|
|
return
|
|
}
|
|
clients := make([]*Client, 0, len(tenantClients))
|
|
for client := range tenantClients {
|
|
clients = append(clients, client)
|
|
}
|
|
h.mu.RUnlock()
|
|
|
|
for _, client := range clients {
|
|
if !client.safeSend(data) {
|
|
// Client closed or buffer full - remove if still registered
|
|
h.mu.Lock()
|
|
if _, stillPresent := h.clients[client]; stillPresent {
|
|
delete(h.clients, client)
|
|
if h.clientsByTenant[client.orgID] != nil {
|
|
delete(h.clientsByTenant[client.orgID], client)
|
|
}
|
|
client.closed.Store(true)
|
|
close(client.send)
|
|
log.Warn().Str("client", client.id).Str("org_id", client.orgID).Msg(dropLog)
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetTenantClientCount returns the number of connected clients for a specific tenant.
|
|
func (h *Hub) GetTenantClientCount(orgID string) int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
if tenantClients := h.clientsByTenant[orgID]; tenantClients != nil {
|
|
return len(tenantClients)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// BroadcastAlert broadcasts alert to all clients
|
|
func (h *Hub) BroadcastAlert(alert interface{}) {
|
|
log.Info().Interface("alert", alert).Msg("Broadcasting alert to WebSocket clients")
|
|
msg := Message{
|
|
Type: "alert",
|
|
Data: cloneAlertData(alert),
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// BroadcastAlertResolved broadcasts alert resolution to all clients
|
|
func (h *Hub) BroadcastAlertResolved(alertID string) {
|
|
log.Info().Str("alertID", alertID).Msg("Broadcasting alert resolved to WebSocket clients")
|
|
msg := Message{
|
|
Type: "alertResolved",
|
|
Data: map[string]string{"alertId": alertID},
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// GetClientCount returns the number of connected clients
|
|
func (h *Hub) GetClientCount() int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
return len(h.clients)
|
|
}
|
|
|
|
// Broadcast sends a custom message to all connected clients
|
|
func (h *Hub) Broadcast(data interface{}) {
|
|
h.BroadcastMessage(Message{
|
|
Type: "custom",
|
|
Data: data,
|
|
Timestamp: time.Now().Format(time.RFC3339),
|
|
})
|
|
}
|
|
|
|
// BroadcastMessage sends a message to all clients
|
|
func (h *Hub) BroadcastMessage(msg Message) {
|
|
// Sanitize the message data to handle NaN values
|
|
msg.Data = sanitizeData(msg.Data)
|
|
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("type", msg.Type).Msg("Failed to marshal WebSocket message")
|
|
// Try to marshal without data to see what's failing
|
|
debugMsg := Message{Type: msg.Type, Data: "[error marshaling data]"}
|
|
if debugData, debugErr := json.Marshal(debugMsg); debugErr == nil {
|
|
log.Debug().Str("debugMsg", string(debugData)).Msg("Debug message")
|
|
}
|
|
return
|
|
}
|
|
|
|
select {
|
|
case h.broadcast <- data:
|
|
default:
|
|
log.Warn().Msg("WebSocket broadcast channel full")
|
|
}
|
|
}
|
|
|
|
// sendPing sends a ping message to all clients
|
|
func (h *Hub) sendPing() {
|
|
msg := Message{
|
|
Type: "ping",
|
|
Data: map[string]int64{"timestamp": time.Now().Unix()},
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// readPump handles incoming messages from the client
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
log.Info().Str("client", c.id).Msg("ReadPump exiting")
|
|
c.hub.unregister <- c
|
|
c.conn.Close()
|
|
}()
|
|
|
|
if err := c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
|
|
log.Warn().Err(err).Str("client", c.id).Msg("Failed to set initial read deadline")
|
|
}
|
|
c.conn.SetPongHandler(func(string) error {
|
|
if err := c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
|
|
log.Warn().Err(err).Str("client", c.id).Msg("Failed to refresh read deadline on pong")
|
|
}
|
|
c.lastPing = time.Now()
|
|
return nil
|
|
})
|
|
|
|
log.Info().Str("client", c.id).Msg("ReadPump started")
|
|
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Error().Err(err).Str("client", c.id).Msg("WebSocket read error")
|
|
} else {
|
|
log.Info().Err(err).Str("client", c.id).Msg("WebSocket closed")
|
|
}
|
|
break
|
|
}
|
|
|
|
// Handle incoming messages
|
|
var msg Message
|
|
if err := json.Unmarshal(message, &msg); err != nil {
|
|
log.Error().Err(err).Str("client", c.id).Msg("Failed to unmarshal WebSocket message")
|
|
continue
|
|
}
|
|
|
|
// Handle different message types
|
|
switch msg.Type {
|
|
case "ping":
|
|
// Respond with pong
|
|
pong := Message{
|
|
Type: "pong",
|
|
Data: map[string]int64{"timestamp": time.Now().Unix()},
|
|
}
|
|
if data, err := json.Marshal(pong); err == nil {
|
|
c.safeSend(data)
|
|
}
|
|
case "requestData":
|
|
// Send current state
|
|
if c.hub.getState != nil {
|
|
stateMsg := Message{
|
|
Type: "rawData",
|
|
Data: sanitizeData(c.hub.getState()),
|
|
}
|
|
if data, err := json.Marshal(stateMsg); err == nil {
|
|
c.safeSend(data)
|
|
} else {
|
|
log.Error().Err(err).Msg("Failed to marshal state for requestData")
|
|
}
|
|
}
|
|
default:
|
|
log.Debug().Str("client", c.id).Str("type", msg.Type).Msg("Received WebSocket message")
|
|
}
|
|
}
|
|
}
|
|
|
|
// writePump handles outgoing messages to the client
|
|
func (c *Client) writePump() {
|
|
// Maximum consecutive write failures before disconnecting.
|
|
// This provides graceful degradation for slow clients.
|
|
const maxWriteFailures = 3
|
|
// Write deadline for messages. Increased from 10s to 30s to handle
|
|
// large state payloads on slower connections (e.g., Raspberry Pi, slow networks).
|
|
const writeDeadline = 30 * time.Second
|
|
// Ping deadline can be shorter since pings are small
|
|
const pingDeadline = 10 * time.Second
|
|
|
|
ticker := time.NewTicker(54 * time.Second)
|
|
defer func() {
|
|
log.Info().Str("client", c.id).Msg("WritePump exiting")
|
|
ticker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
log.Info().Str("client", c.id).Msg("WritePump started")
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
if err := c.conn.SetWriteDeadline(time.Now().Add(writeDeadline)); err != nil {
|
|
log.Warn().Err(err).Str("client", c.id).Msg("Failed to set write deadline before message send")
|
|
}
|
|
if !ok {
|
|
log.Debug().Str("client", c.id).Msg("Send channel closed")
|
|
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
|
|
log.Warn().Err(err).Str("client", c.id).Msg("Failed to send close message")
|
|
}
|
|
return
|
|
}
|
|
|
|
// Send the primary message
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
|
c.writeFailures++
|
|
log.Warn().
|
|
Err(err).
|
|
Str("client", c.id).
|
|
Int("msgSize", len(message)).
|
|
Int32("consecutiveFailures", c.writeFailures).
|
|
Msg("Failed to write message")
|
|
|
|
// Graceful degradation: only disconnect after multiple consecutive failures
|
|
if c.writeFailures >= maxWriteFailures {
|
|
log.Error().
|
|
Str("client", c.id).
|
|
Int32("failures", c.writeFailures).
|
|
Msg("Too many consecutive write failures, disconnecting client")
|
|
return
|
|
}
|
|
// Skip this message and continue - don't disconnect immediately
|
|
continue
|
|
}
|
|
|
|
// Reset failure count on successful write
|
|
c.writeFailures = 0
|
|
|
|
// Send any queued messages
|
|
n := len(c.send)
|
|
flushLoop:
|
|
for i := 0; i < n; i++ {
|
|
select {
|
|
case msg := <-c.send:
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
|
log.Warn().Err(err).Str("client", c.id).Int("msgSize", len(msg)).Msg("Failed to flush queued message")
|
|
// Don't disconnect on queued message failure, just break the flush loop
|
|
break flushLoop
|
|
}
|
|
default:
|
|
// No more messages
|
|
}
|
|
}
|
|
|
|
case <-ticker.C:
|
|
if err := c.conn.SetWriteDeadline(time.Now().Add(pingDeadline)); err != nil {
|
|
log.Warn().Err(err).Str("client", c.id).Msg("Failed to set write deadline for ping")
|
|
}
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
log.Debug().Err(err).Str("client", c.id).Msg("Failed to send ping; closing connection")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// jsonMarshalerType is the reflect.Type of json.Marshaler. Values whose type
// provides its own JSON encoding are left untouched by scrubNonFinite.
var jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()

// scrubVisit identifies a pointer, map or slice that scrubNonFinite is currently
// walking, so self-referential data cannot make the walk recurse forever.
type scrubVisit struct {
	typ reflect.Type
	ptr uintptr
	n   int
}

// sanitizeData converts data into its generic JSON form (maps, slices, strings,
// float64, bool, nil) with NaN and ±Inf floats removed.
//
// encoding/json refuses to marshal non-finite floats, so a plain round trip can
// never sanitize them. When the first marshal fails for that reason, a scrubbed
// copy of data is built by reflection and marshaled instead. In the copy a
// non-finite float becomes nil where the slot can hold nil (interface values,
// pointers) and 0 where it cannot (typed float fields, elements and map values).
// The caller's data is never modified.
//
// If data still cannot be marshaled (for any other reason), it is returned
// unchanged so the caller's own marshal reports the error.
func sanitizeData(data interface{}) interface{} {
	jsonBytes, err := json.Marshal(data)
	if err != nil {
		if !isNonFiniteFloatError(err) {
			return data
		}
		cleaned, ok := scrubForJSON(data)
		if !ok {
			return data
		}
		if jsonBytes, err = json.Marshal(cleaned); err != nil {
			return data
		}
	}

	var jsonData interface{}
	if err := json.Unmarshal(jsonBytes, &jsonData); err != nil {
		return data
	}

	return sanitizeValue(jsonData)
}

// isNonFiniteFloatError reports whether err is the error json.Marshal returns
// for a NaN or infinite float.
func isNonFiniteFloatError(err error) bool {
	unsupported, ok := err.(*json.UnsupportedValueError)
	if !ok {
		return false
	}
	return isNonFiniteFloat(unsupported.Value)
}

// isNonFiniteFloat reports whether v is a float32/float64 holding NaN or ±Inf.
func isNonFiniteFloat(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Float32, reflect.Float64:
		f := v.Float()
		return math.IsNaN(f) || math.IsInf(f, 0)
	}
	return false
}

// scrubForJSON returns a copy of data with non-finite floats removed. ok is
// false when the reflection walk panicked; the caller then falls back to the
// original data instead of crashing.
func scrubForJSON(data interface{}) (cleaned interface{}, ok bool) {
	defer func() {
		if r := recover(); r != nil {
			cleaned, ok = nil, false
		}
	}()

	// Walk the interface slot itself (not its dynamic value) so that a
	// top-level NaN can be replaced by nil.
	root := reflect.ValueOf(&data).Elem()
	return scrubNonFinite(root, make(map[scrubVisit]bool)).Interface(), true
}

// scrubNonFinite returns a value of the same type as v in which every reachable
// non-finite float has been replaced (see sanitizeData). Containers that need
// changes are copied rather than modified. Unexported struct fields are copied
// as-is and not descended into. active holds the pointers, maps and slices on
// the current walk path; one that is met again is returned unchanged.
func scrubNonFinite(v reflect.Value, active map[scrubVisit]bool) reflect.Value {
	if !v.IsValid() {
		return v
	}
	t := v.Type()
	if t.Implements(jsonMarshalerType) || reflect.PointerTo(t).Implements(jsonMarshalerType) {
		return v
	}

	switch v.Kind() {
	case reflect.Float32, reflect.Float64:
		if isNonFiniteFloat(v) {
			return reflect.Zero(t)
		}

	case reflect.Interface:
		if v.IsNil() {
			return v
		}
		inner := v.Elem()
		if isNonFiniteFloat(inner) {
			return reflect.Zero(t) // nil interface, encoded as JSON null
		}
		out := reflect.New(t).Elem()
		out.Set(scrubNonFinite(inner, active))
		return out

	case reflect.Pointer:
		if v.IsNil() {
			return v
		}
		if isNonFiniteFloat(v.Elem()) {
			return reflect.Zero(t) // nil pointer, encoded as JSON null
		}
		key := scrubVisit{typ: t, ptr: v.Pointer()}
		if active[key] {
			return v
		}
		active[key] = true
		defer delete(active, key)

		out := reflect.New(t.Elem())
		out.Elem().Set(scrubNonFinite(v.Elem(), active))
		return out.Convert(t)

	case reflect.Map:
		if v.IsNil() {
			return v
		}
		key := scrubVisit{typ: t, ptr: v.Pointer()}
		if active[key] {
			return v
		}
		active[key] = true
		defer delete(active, key)

		out := reflect.MakeMapWithSize(t, v.Len())
		for it := v.MapRange(); it.Next(); {
			out.SetMapIndex(it.Key(), scrubNonFinite(it.Value(), active))
		}
		return out

	case reflect.Slice:
		if v.IsNil() {
			return v
		}
		key := scrubVisit{typ: t, ptr: v.Pointer(), n: v.Len()}
		if active[key] {
			return v
		}
		active[key] = true
		defer delete(active, key)

		out := reflect.MakeSlice(t, v.Len(), v.Len())
		for i := 0; i < v.Len(); i++ {
			out.Index(i).Set(scrubNonFinite(v.Index(i), active))
		}
		return out

	case reflect.Array:
		out := reflect.New(t).Elem()
		for i := 0; i < v.Len(); i++ {
			out.Index(i).Set(scrubNonFinite(v.Index(i), active))
		}
		return out

	case reflect.Struct:
		out := reflect.New(t).Elem()
		out.Set(v) // carries unexported fields over unchanged
		for i := 0; i < t.NumField(); i++ {
			if t.Field(i).PkgPath != "" {
				continue // unexported: cannot be set through reflection
			}
			out.Field(i).Set(scrubNonFinite(v.Field(i), active))
		}
		return out
	}

	return v
}

// sanitizeValue returns a copy of a decoded JSON value in which any NaN or ±Inf
// float is replaced with nil. Maps and slices are rebuilt recursively; every
// other value (string, bool, nil, finite numbers, ...) is returned as-is.
func sanitizeValue(data interface{}) interface{} {
	switch v := data.(type) {
	case float64:
		if math.IsNaN(v) || math.IsInf(v, 0) {
			return nil
		}
		return v
	case float32:
		if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
			return nil
		}
		return v
	case map[string]interface{}:
		sanitized := make(map[string]interface{}, len(v))
		for k, val := range v {
			sanitized[k] = sanitizeValue(val)
		}
		return sanitized
	case []interface{}:
		sanitized := make([]interface{}, len(v))
		for i, val := range v {
			sanitized[i] = sanitizeValue(val)
		}
		return sanitized
	default:
		// For all other types (string, bool, nil, etc.), return as-is
		return v
	}
}
|