From 0267a576041407b25ea6ecbf733cfcaabbff03cd Mon Sep 17 00:00:00 2001
From: rcourtman
Date: Fri, 3 Oct 2025 12:21:58 +0000
Subject: [PATCH] Guard alert clones for WebSocket broadcasts

---
 internal/alerts/alerts.go              | 125 +++++++++++++++++++------
 internal/websocket/concurrency_test.go |  81 ++++++++++++++++
 internal/websocket/hub.go              |  87 ++++++++++++++++-
 3 files changed, 263 insertions(+), 30 deletions(-)
 create mode 100644 internal/websocket/concurrency_test.go

diff --git a/internal/alerts/alerts.go b/internal/alerts/alerts.go
index aee9c5252..36810547f 100644
--- a/internal/alerts/alerts.go
+++ b/internal/alerts/alerts.go
@@ -46,6 +46,75 @@ type Alert struct {
 	EscalationTimes []time.Time `json:"escalationTimes,omitempty"` // Times when escalations were sent
 }
 
+// Clone returns a deep copy of the alert (AckTime, EscalationTimes and Metadata are copied) so it can be safely shared across goroutines; a nil receiver returns nil.
+func (a *Alert) Clone() *Alert {
+	if a == nil {
+		return nil
+	}
+
+	clone := *a
+
+	if a.AckTime != nil {
+		t := *a.AckTime
+		clone.AckTime = &t
+	}
+
+	if len(a.EscalationTimes) > 0 {
+		clone.EscalationTimes = append([]time.Time(nil), a.EscalationTimes...)
+	}
+
+	if a.Metadata != nil {
+		clone.Metadata = cloneMetadata(a.Metadata)
+	}
+
+	return &clone
+}
+
+func cloneMetadata(src map[string]interface{}) map[string]interface{} {
+	if src == nil {
+		return nil
+	}
+
+	dst := make(map[string]interface{}, len(src))
+	for k, v := range src {
+		dst[k] = cloneMetadataValue(v)
+	}
+	return dst
+}
+
+func cloneMetadataValue(val interface{}) interface{} {
+	switch v := val.(type) {
+	case map[string]interface{}:
+		return cloneMetadata(v)
+	case map[string]string: // keep the concrete type so type assertions on the clone behave like on the original
+		m := make(map[string]string, len(v))
+		for key, value := range v {
+			m[key] = value
+		}
+		return m
+	case []interface{}:
+		arr := make([]interface{}, len(v))
+		for i, elem := range v {
+			arr[i] = cloneMetadataValue(elem)
+		}
+		return arr
+	case []string:
+		arr := make([]string, len(v))
+		copy(arr, v)
+		return arr
+	case []int:
+		arr := make([]int, len(v))
+		copy(arr, v)
+		return arr
+	case []float64:
+		arr := make([]float64, len(v))
+		copy(arr, v)
+		return arr
+	default: // any type not listed above is returned as-is, so other reference types stay shared
+		return v
+	}
+}
+
 // ResolvedAlert represents a recently resolved alert
 type ResolvedAlert struct {
 	*Alert
@@ -341,6 +410,21 @@ func (m *Manager) SetEscalateCallback(cb func(alert *Alert, level int)) {
 	m.onEscalate = cb
 }
 
+// dispatchAlert hands a clone of alert to the onAlert callback so later mutations of the original cannot race with consumers.
+// It does nothing when no callback is set or alert is nil; async selects a new goroutine instead of the caller's goroutine.
+func (m *Manager) dispatchAlert(alert *Alert, async bool) {
+	if m.onAlert == nil || alert == nil {
+		return
+	}
+
+	alertCopy := alert.Clone()
+	if async {
+		go m.onAlert(alertCopy) // the callback value and the clone are evaluated here, before the goroutine starts
+	} else {
+		m.onAlert(alertCopy)
+	}
+}
+
 // UpdateConfig updates the alert configuration
 func (m *Manager) UpdateConfig(config AlertConfig) {
 	m.mu.Lock()
@@ -1163,9 +1247,7 @@ func (m *Manager) checkZFSPoolHealth(storage models.Storage) {
 			m.recentAlerts[stateAlertID] = alert
 			m.historyManager.AddAlert(*alert)
 
-			if m.onAlert != nil {
-				m.onAlert(alert)
-			}
+			m.dispatchAlert(alert, false)
 
 			log.Warn().
 				Str("pool", pool.Name).
@@ -1219,9 +1301,7 @@ func (m *Manager) checkZFSPoolHealth(storage models.Storage) {
 			m.recentAlerts[errorsAlertID] = alert
 			m.historyManager.AddAlert(*alert)
 
-			if m.onAlert != nil {
-				m.onAlert(alert)
-			}
+			m.dispatchAlert(alert, false) // sync: callback gets its own clone
 
 			log.Error().
 				Str("pool", pool.Name).
@@ -1285,9 +1365,7 @@ func (m *Manager) checkZFSPoolHealth(storage models.Storage) {
 				m.recentAlerts[alertID] = alert
 				m.historyManager.AddAlert(*alert)
 
-				if m.onAlert != nil {
-					m.onAlert(alert)
-				}
+				m.dispatchAlert(alert, false) // sync: callback gets its own clone
 
 				log.Warn().
 					Str("pool", pool.Name).
@@ -1555,7 +1633,7 @@ func (m *Manager) checkMetric(resourceID, resourceName, node, instance, resource
 		// Notify callback
 		if m.onAlert != nil {
 			log.Info().Str("alertID", alertID).Msg("Calling onAlert callback")
-			go m.onAlert(alert)
+			m.dispatchAlert(alert, true) // async on a clone; dispatchAlert re-checks m.onAlert, the outer check only picks the log message
 		} else {
 			log.Warn().Msg("No onAlert callback set!")
 		}
@@ -1886,9 +1964,7 @@ func (m *Manager) checkNodeOffline(node models.Node) {
 	m.historyManager.AddAlert(*alert)
 
 	// Send notification after confirmation
-	if m.onAlert != nil {
-		m.onAlert(alert)
-	}
+	m.dispatchAlert(alert, false) // sync: callback gets its own clone
 
 	// Log the critical event
 	log.Error().
@@ -2007,9 +2083,7 @@ func (m *Manager) checkPBSOffline(pbs models.PBSInstance) {
 		Int("confirmations", m.offlineConfirmations[pbs.ID]).
 		Msg("PBS instance is offline")
 
-	if m.onAlert != nil {
-		go m.onAlert(alert)
-	}
+	m.dispatchAlert(alert, true) // async: callback runs in a goroutine on its own clone
 }
 
 // clearPBSOfflineAlert removes offline alert when PBS comes back online
@@ -2120,9 +2194,7 @@ func (m *Manager) checkStorageOffline(storage models.Storage) {
 		Int("confirmations", m.offlineConfirmations[storage.ID]).
 		Msg("Storage is offline/unavailable")
 
-	if m.onAlert != nil {
-		go m.onAlert(alert)
-	}
+	m.dispatchAlert(alert, true) // async: callback runs in a goroutine on its own clone
 }
 
 // clearStorageOfflineAlert removes offline alert when storage comes back online
@@ -2250,9 +2322,7 @@ func (m *Manager) checkGuestPoweredOff(guestID, name, node, instanceName, guestT
 	m.historyManager.AddAlert(*alert)
 
 	// Send notification after confirmation
-	if m.onAlert != nil {
-		m.onAlert(alert)
-	}
+	m.dispatchAlert(alert, false) // sync: callback gets its own clone
 
 	// Log the event
 	log.Warn().
@@ -2880,6 +2950,7 @@ func (m *Manager) LoadActiveAlerts() error {
 		if alert.Level == AlertLevelCritical && now.Sub(alert.StartTime) < 2*time.Hour {
 			// Use a goroutine and add a small delay to avoid notification spam on startup
 			if m.onAlert != nil {
+				alertCopy := alert.Clone() // clone now so the delayed goroutine below never touches the original alert
 				go func(a *Alert) {
 					time.Sleep(10 * time.Second) // Wait for system to stabilize after restart
 					log.Info().
@@ -2887,7 +2958,7 @@ func (m *Manager) LoadActiveAlerts() error {
 						Str("resource", a.ResourceName).
 						Msg("Sending notification for restored critical alert")
 					m.onAlert(a)
-				}(alert)
+				}(alertCopy) // pass the clone taken above, not the original pointer
 			}
 		}
 	}
@@ -3034,9 +3105,7 @@ func (m *Manager) CheckDiskHealth(instance, node string, disk proxmox.Disk) {
 		m.recentAlerts[alertID] = alert
 		m.historyManager.AddAlert(*alert)
 
-		if m.onAlert != nil {
-			m.onAlert(alert)
-		}
+		m.dispatchAlert(alert, false) // sync: callback gets its own clone
 
 		log.Error().
 			Str("node", node).
@@ -3103,9 +3172,7 @@ func (m *Manager) CheckDiskHealth(instance, node string, disk proxmox.Disk) {
 		m.recentAlerts[wearoutAlertID] = alert
 		m.historyManager.AddAlert(*alert)
 
-		if m.onAlert != nil {
-			m.onAlert(alert)
-		}
+		m.dispatchAlert(alert, false) // sync: callback gets its own clone
 
 		log.Warn().
 			Str("node", node).
diff --git a/internal/websocket/concurrency_test.go b/internal/websocket/concurrency_test.go
new file mode 100644
index 000000000..e5ec6293a
--- /dev/null
+++ b/internal/websocket/concurrency_test.go
@@ -0,0 +1,81 @@
+package websocket
+
+import (
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
+	"github.com/rs/zerolog"
+)
+
+func TestBroadcastAlertConcurrentMutation(t *testing.T) {
+	origLevel := zerolog.GlobalLevel()
+	zerolog.SetGlobalLevel(zerolog.Disabled)
+	t.Cleanup(func() {
+		zerolog.SetGlobalLevel(origLevel)
+	})
+
+	hub := NewHub(nil)
+
+	done := make(chan struct{})
+	var drain sync.WaitGroup
+	drain.Add(1)
+	go func() {
+		defer drain.Done()
+		for {
+			select {
+			case <-done:
+				return
+			case _, ok := <-hub.broadcast:
+				if !ok {
+					return
+				}
+			}
+		}
+	}()
+
+	alert := &alerts.Alert{
+		ID:         "test-alert",
+		Type:       "cpu",
+		Level:      alerts.AlertLevelWarning,
+		ResourceID: "vm/100",
+		Message:    "CPU high",
+		Metadata: map[string]interface{}{
+			"initial": true,
+		},
+		StartTime: time.Now(),
+	}
+
+	var mu sync.Mutex
+	var wg sync.WaitGroup
+	iterations := 1000
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		for i := 0; i < iterations; i++ {
+			mu.Lock()
+			alert.Value = float64(i)
+			if alert.Metadata != nil {
+				alert.Metadata["iteration"] = i
+			}
+			mu.Unlock()
+			time.Sleep(time.Microsecond)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		for i := 0; i < iterations; i++ {
+			mu.Lock()
+			alertCopy := alert.Clone() // clone under mu so the copy itself does not race with the writer goroutine
+			mu.Unlock()
+			hub.BroadcastAlert(alertCopy)
+		}
+	}()
+
+	wg.Wait()
+	close(done)
+	drain.Wait()
+}
diff --git a/internal/websocket/hub.go b/internal/websocket/hub.go
index bbf9ab98e..b06437282 100644
--- a/internal/websocket/hub.go
+++ b/internal/websocket/hub.go
@@ -11,6 +11,7 @@ import (
 	"time"
 
 	"github.com/gorilla/websocket"
+	"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
 	"github.com/rcourtman/pulse-go-rewrite/internal/utils"
 	"github.com/rs/zerolog/log"
 )
@@ -148,6 +149,90 @@ type Client struct {
 	lastPing time.Time
 }
 
+// cloneAlertData returns a broadcast-safe copy of alert data to avoid data races when
+// downstream sanitization/encoding happens concurrently with alert manager mutations.
+func cloneAlertData(alert interface{}) interface{} {
+	switch a := alert.(type) {
+	case *alerts.Alert:
+		cloned := cloneAlert(a)
+		return cloned
+	case alerts.Alert:
+		cloned := cloneAlert(&a)
+		return cloned
+	default:
+		return alert
+	}
+}
+
+// cloneAlert deep-copies AckTime, EscalationTimes and Metadata, mirroring (*alerts.Alert).Clone but returning a value; a nil src yields a zero Alert.
+func cloneAlert(src *alerts.Alert) alerts.Alert {
+	if src == nil {
+		return alerts.Alert{}
+	}
+	clone := *src
+
+	if src.AckTime != nil {
+		t := *src.AckTime
+		clone.AckTime = &t
+	}
+
+	if len(src.EscalationTimes) > 0 {
+		clone.EscalationTimes = append([]time.Time(nil), src.EscalationTimes...)
+	}
+
+	if src.Metadata != nil {
+		clone.Metadata = cloneMetadata(src.Metadata)
+	}
+
+	return clone
+}
+
+// cloneMetadata creates a deep copy of alert metadata to detach from shared maps/slices.
+func cloneMetadata(src map[string]interface{}) map[string]interface{} {
+	if src == nil {
+		return nil
+	}
+
+	dst := make(map[string]interface{}, len(src))
+	for k, v := range src {
+		dst[k] = cloneMetadataValue(v)
+	}
+	return dst
+}
+
+func cloneMetadataValue(value interface{}) interface{} {
+	switch v := value.(type) {
+	case map[string]interface{}:
+		return cloneMetadata(v)
+	case map[string]string: // keep the concrete type so the copy has the same dynamic type as the source value
+		m := make(map[string]string, len(v))
+		for key, val := range v {
+			m[key] = val
+		}
+		return m
+	case []interface{}:
+		arr := make([]interface{}, len(v))
+		for i, elem := range v {
+			arr[i] = cloneMetadataValue(elem)
+		}
+		return arr
+	case []string:
+		arr := make([]string, len(v))
+		copy(arr, v)
+		return arr
+	case []int:
+		arr := make([]int, len(v))
+		copy(arr, v)
+		return arr
+	case []float64:
+		arr := make([]float64, len(v))
+		copy(arr, v)
+		return arr
+	default: // any type not listed above is returned as-is, so other reference types stay shared
+		return v
+	}
+}
+
 // Hub maintains active WebSocket clients and broadcasts messages
 type Hub struct {
 	clients map[*Client]bool
@@ -350,7 +435,7 @@ func (h *Hub) BroadcastAlert(alert interface{}) {
 	log.Info().Interface("alert", alert).Msg("Broadcasting alert to WebSocket clients")
 	msg := Message{
 		Type: "alert",
-		Data: alert,
+		Data: cloneAlertData(alert),
 	}
 	h.BroadcastMessage(msg)
 }