mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-11 21:28:15 +00:00
- Fix CPU core display to show for all guests with CPU data - Previously only showed cores when CPU > 0 (truthy) - Now shows "(0.0/X cores)" consistently for all running/stopped guests - Improve code organization with new helper utilities - Clean up import statements and remove debug logs
479 lines
No EOL
12 KiB
Go
479 lines
No EOL
12 KiB
Go
package websocket
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"math"
	"net/http"
	"reflect"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/rcourtman/pulse-go-rewrite/internal/utils"
	"github.com/rs/zerolog/log"
)
|
|
|
|
// SetAllowedOrigins sets the allowed origins for CORS
|
|
func (h *Hub) SetAllowedOrigins(origins []string) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.allowedOrigins = origins
|
|
}
|
|
|
|
// checkOrigin validates the origin against allowed origins
|
|
func (h *Hub) checkOrigin(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
// No origin header, allow for non-browser clients
|
|
return true
|
|
}
|
|
|
|
h.mu.RLock()
|
|
allowedOrigins := h.allowedOrigins
|
|
h.mu.RUnlock()
|
|
|
|
// Check if wildcard is allowed
|
|
for _, allowed := range allowedOrigins {
|
|
if allowed == "*" {
|
|
return true
|
|
}
|
|
if allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
|
|
log.Warn().
|
|
Str("origin", origin).
|
|
Strs("allowedOrigins", allowedOrigins).
|
|
Msg("WebSocket connection rejected due to CORS")
|
|
|
|
return false
|
|
}
|
|
|
|
// Client represents a single WebSocket connection registered with a Hub.
// HandleWebSocket creates it and starts its readPump and writePump goroutines.
type Client struct {
	// hub is the Hub this client is registered with.
	hub *Hub
	// conn is the underlying gorilla/websocket connection.
	conn *websocket.Conn
	// send is the outbound queue drained by writePump. The hub closes it when
	// it removes the client (on unregister or when a broadcast finds it full).
	send chan []byte
	// id is generated by utils.GenerateID("client") and used in log fields.
	id string
	// lastPing is set when the client is created and updated by readPump's
	// pong handler, i.e. it records the last pong received, not a ping.
	lastPing time.Time
}
|
|
|
|
// Hub maintains active WebSocket clients and broadcasts messages.
//
// Run is the event loop that consumes register, unregister and broadcast and
// adds/removes entries in clients.
type Hub struct {
	// clients is the set of registered clients (the bool value is always true).
	clients map[*Client]bool
	// broadcast carries already-encoded messages from BroadcastMessage to Run.
	broadcast chan []byte
	// register receives new clients from HandleWebSocket.
	register chan *Client
	// unregister receives clients whose readPump has exited.
	unregister chan *Client
	// mu is the lock for clients, getState and allowedOrigins.
	mu sync.RWMutex
	getState func() interface{} // Function to get current state; may be nil
	allowedOrigins []string // Allowed origins for CORS; "*" allows any origin
}
|
|
|
|
// Message is the JSON envelope for messages exchanged over the WebSocket.
type Message struct {
	// Type names the message kind. This file sends "welcome", "initialState",
	// "rawData", "alert", "alertResolved", "custom", "ping" and "pong", and
	// handles incoming "ping" and "requestData".
	Type string `json:"type"`
	// Data is the message payload.
	Data interface{} `json:"data"`
	// Timestamp is an optional RFC 3339 time (set by Broadcast); it is omitted
	// from the JSON when empty.
	Timestamp string `json:"timestamp,omitempty"`
}
|
|
|
|
// SetStateGetter sets the state getter function
|
|
func (h *Hub) SetStateGetter(getState func() interface{}) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.getState = getState
|
|
}
|
|
|
|
// NewHub creates a new WebSocket hub
|
|
func NewHub(getState func() interface{}) *Hub {
|
|
return &Hub{
|
|
clients: make(map[*Client]bool),
|
|
broadcast: make(chan []byte, 256),
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
getState: getState,
|
|
allowedOrigins: []string{"*"}, // Default to allow all
|
|
}
|
|
}
|
|
|
|
// Run starts the hub's main loop
|
|
func (h *Hub) Run() {
|
|
pingTicker := time.NewTicker(30 * time.Second)
|
|
defer pingTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.mu.Lock()
|
|
h.clients[client] = true
|
|
h.mu.Unlock()
|
|
log.Info().Str("client", client.id).Msg("WebSocket client connected")
|
|
|
|
// Send initial state to the new client immediately
|
|
if h.getState != nil {
|
|
// Add a small delay to ensure client is ready
|
|
go func() {
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// First send a small welcome message
|
|
welcomeMsg := Message{
|
|
Type: "welcome",
|
|
Data: map[string]string{"message": "Connected to Pulse WebSocket"},
|
|
}
|
|
if data, err := json.Marshal(welcomeMsg); err == nil {
|
|
// Check if client is still registered before sending
|
|
if _, ok := h.clients[client]; ok {
|
|
log.Info().Str("client", client.id).Msg("Sending welcome message")
|
|
select {
|
|
case client.send <- data:
|
|
log.Info().Str("client", client.id).Msg("Welcome message sent")
|
|
default:
|
|
log.Warn().Str("client", client.id).Msg("Failed to send welcome message - buffer full")
|
|
}
|
|
} else {
|
|
log.Debug().Str("client", client.id).Msg("Client disconnected before welcome message")
|
|
}
|
|
}
|
|
|
|
// Then send the initial state after another delay
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
initialMsg := Message{
|
|
Type: "initialState",
|
|
Data: sanitizeData(h.getState()),
|
|
}
|
|
if data, err := json.Marshal(initialMsg); err == nil {
|
|
// Check if client is still registered before sending
|
|
if _, ok := h.clients[client]; ok {
|
|
log.Info().Str("client", client.id).Int("dataLen", len(data)).Int("dataKB", len(data)/1024).Msg("Sending initial state to client")
|
|
|
|
select {
|
|
case client.send <- data:
|
|
log.Info().Str("client", client.id).Msg("Initial state sent successfully")
|
|
default:
|
|
log.Warn().Str("client", client.id).Msg("Client send buffer full, skipping initial state")
|
|
}
|
|
} else {
|
|
log.Debug().Str("client", client.id).Msg("Client disconnected before initial state")
|
|
}
|
|
} else {
|
|
log.Error().Err(err).Str("client", client.id).Msg("Failed to marshal initial state")
|
|
}
|
|
}()
|
|
} else {
|
|
log.Warn().Msg("No getState function defined")
|
|
}
|
|
|
|
case client := <-h.unregister:
|
|
h.mu.Lock()
|
|
if _, ok := h.clients[client]; ok {
|
|
delete(h.clients, client)
|
|
close(client.send)
|
|
h.mu.Unlock()
|
|
log.Info().Str("client", client.id).Msg("WebSocket client disconnected")
|
|
} else {
|
|
h.mu.Unlock()
|
|
}
|
|
|
|
case message := <-h.broadcast:
|
|
h.mu.RLock()
|
|
clients := make([]*Client, 0, len(h.clients))
|
|
for client := range h.clients {
|
|
clients = append(clients, client)
|
|
}
|
|
h.mu.RUnlock()
|
|
|
|
for _, client := range clients {
|
|
select {
|
|
case client.send <- message:
|
|
default:
|
|
// Client's send channel is full, close it
|
|
h.mu.Lock()
|
|
delete(h.clients, client)
|
|
close(client.send)
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
case <-pingTicker.C:
|
|
h.sendPing()
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandleWebSocket handles WebSocket upgrade requests
|
|
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
log.Info().
|
|
Str("origin", r.Header.Get("Origin")).
|
|
Str("host", r.Host).
|
|
Str("userAgent", r.Header.Get("User-Agent")).
|
|
Msg("WebSocket upgrade request")
|
|
|
|
// Create upgrader with our origin check
|
|
upgrader := websocket.Upgrader{
|
|
ReadBufferSize: 1024 * 64, // 64KB to handle large state messages
|
|
WriteBufferSize: 1024 * 64, // 64KB to handle large state messages
|
|
CheckOrigin: h.checkOrigin,
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to upgrade WebSocket connection")
|
|
return
|
|
}
|
|
|
|
clientID := utils.GenerateID("client")
|
|
client := &Client{
|
|
hub: h,
|
|
conn: conn,
|
|
send: make(chan []byte, 1024), // Increased buffer for high-frequency updates
|
|
id: clientID,
|
|
lastPing: time.Now(),
|
|
}
|
|
|
|
log.Info().Str("client", clientID).Msg("WebSocket client created")
|
|
|
|
client.hub.register <- client
|
|
|
|
// Start goroutines for reading and writing
|
|
go client.writePump()
|
|
go client.readPump()
|
|
}
|
|
|
|
// BroadcastState broadcasts state update to all clients
|
|
func (h *Hub) BroadcastState(state interface{}) {
|
|
msg := Message{
|
|
Type: "rawData",
|
|
Data: state,
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// BroadcastAlert broadcasts alert to all clients
|
|
func (h *Hub) BroadcastAlert(alert interface{}) {
|
|
log.Info().Interface("alert", alert).Msg("Broadcasting alert to WebSocket clients")
|
|
msg := Message{
|
|
Type: "alert",
|
|
Data: alert,
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// BroadcastAlertResolved broadcasts alert resolution to all clients
|
|
func (h *Hub) BroadcastAlertResolved(alertID string) {
|
|
log.Info().Str("alertID", alertID).Msg("Broadcasting alert resolved to WebSocket clients")
|
|
msg := Message{
|
|
Type: "alertResolved",
|
|
Data: map[string]string{"alertId": alertID},
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// GetClientCount returns the number of connected clients
|
|
func (h *Hub) GetClientCount() int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
return len(h.clients)
|
|
}
|
|
|
|
// Broadcast sends a custom message to all connected clients
|
|
func (h *Hub) Broadcast(data interface{}) {
|
|
h.BroadcastMessage(Message{
|
|
Type: "custom",
|
|
Data: data,
|
|
Timestamp: time.Now().Format(time.RFC3339),
|
|
})
|
|
}
|
|
|
|
// BroadcastMessage sends a message to all clients
|
|
func (h *Hub) BroadcastMessage(msg Message) {
|
|
// Sanitize the message data to handle NaN values
|
|
msg.Data = sanitizeData(msg.Data)
|
|
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("type", msg.Type).Msg("Failed to marshal WebSocket message")
|
|
// Try to marshal without data to see what's failing
|
|
debugMsg := Message{Type: msg.Type, Data: "[error marshaling data]"}
|
|
if debugData, debugErr := json.Marshal(debugMsg); debugErr == nil {
|
|
log.Debug().Str("debugMsg", string(debugData)).Msg("Debug message")
|
|
}
|
|
return
|
|
}
|
|
|
|
select {
|
|
case h.broadcast <- data:
|
|
default:
|
|
log.Warn().Msg("WebSocket broadcast channel full")
|
|
}
|
|
}
|
|
|
|
// sendPing sends a ping message to all clients
|
|
func (h *Hub) sendPing() {
|
|
msg := Message{
|
|
Type: "ping",
|
|
Data: map[string]int64{"timestamp": time.Now().Unix()},
|
|
}
|
|
h.BroadcastMessage(msg)
|
|
}
|
|
|
|
// readPump handles incoming messages from the client
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
log.Info().Str("client", c.id).Msg("ReadPump exiting")
|
|
c.hub.unregister <- c
|
|
c.conn.Close()
|
|
}()
|
|
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
c.lastPing = time.Now()
|
|
return nil
|
|
})
|
|
|
|
log.Info().Str("client", c.id).Msg("ReadPump started")
|
|
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Error().Err(err).Str("client", c.id).Msg("WebSocket read error")
|
|
} else {
|
|
log.Info().Err(err).Str("client", c.id).Msg("WebSocket closed")
|
|
}
|
|
break
|
|
}
|
|
|
|
// Handle incoming messages
|
|
var msg Message
|
|
if err := json.Unmarshal(message, &msg); err != nil {
|
|
log.Error().Err(err).Str("client", c.id).Msg("Failed to unmarshal WebSocket message")
|
|
continue
|
|
}
|
|
|
|
// Handle different message types
|
|
switch msg.Type {
|
|
case "ping":
|
|
// Respond with pong
|
|
pong := Message{
|
|
Type: "pong",
|
|
Data: map[string]int64{"timestamp": time.Now().Unix()},
|
|
}
|
|
if data, err := json.Marshal(pong); err == nil {
|
|
c.send <- data
|
|
}
|
|
case "requestData":
|
|
// Send current state
|
|
if c.hub.getState != nil {
|
|
stateMsg := Message{
|
|
Type: "rawData",
|
|
Data: sanitizeData(c.hub.getState()),
|
|
}
|
|
if data, err := json.Marshal(stateMsg); err == nil {
|
|
c.send <- data
|
|
} else {
|
|
log.Error().Err(err).Msg("Failed to marshal state for requestData")
|
|
}
|
|
}
|
|
default:
|
|
log.Debug().Str("client", c.id).Str("type", msg.Type).Msg("Received WebSocket message")
|
|
}
|
|
}
|
|
}
|
|
|
|
// writePump handles outgoing messages to the client
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(54 * time.Second)
|
|
defer func() {
|
|
log.Info().Str("client", c.id).Msg("WritePump exiting")
|
|
ticker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
log.Info().Str("client", c.id).Msg("WritePump started")
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if !ok {
|
|
log.Debug().Str("client", c.id).Msg("Send channel closed")
|
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
// Send the primary message
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
|
log.Error().Err(err).Str("client", c.id).Msg("Failed to write message")
|
|
return
|
|
}
|
|
|
|
// Send any queued messages
|
|
n := len(c.send)
|
|
for i := 0; i < n; i++ {
|
|
select {
|
|
case msg := <-c.send:
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
|
return
|
|
}
|
|
default:
|
|
// No more messages
|
|
}
|
|
}
|
|
|
|
case <-ticker.C:
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
// sanitizeMaxDepth bounds the reflective walk in sanitizeReflectValue so that a
// self-referential value cannot recurse forever; anything deeper becomes nil.
const sanitizeMaxDepth = 256

// sanitizeData converts data into a JSON-compatible value tree. The tree is
// built from map[string]interface{}, []interface{}, float64, string, bool and
// nil, and every NaN or ±Inf float in it is replaced with nil, so the result
// can always be encoded by json.Marshal.
//
// encoding/json refuses to encode NaN and ±Inf (Marshal returns an error). A
// plain Marshal/Unmarshal round trip therefore cannot strip them: it fails on
// exactly the inputs that need sanitizing. The round trip is still used as
// the fast path when it succeeds, because the value then contains no NaN/Inf.
//
// Otherwise the value is walked with reflection. The walk follows encoding/json's
// naming rules (json tags, "-", omitempty, and promoted fields of untagged
// embedded structs) closely enough for plain data structs. Subtrees that
// encode cleanly still go through encoding/json, so custom MarshalJSON
// methods are honoured there. Values JSON cannot represent at all (channels,
// funcs) also become nil.
//
// If the fast-path output cannot be decoded, data is returned unchanged.
func sanitizeData(data interface{}) interface{} {
	jsonBytes, err := json.Marshal(data)
	if err == nil {
		var jsonData interface{}
		if err := json.Unmarshal(jsonBytes, &jsonData); err != nil {
			return data
		}
		// Decoded JSON can never contain NaN/Inf, so nothing is left to scrub.
		return jsonData
	}

	return sanitizeReflectValue(reflect.ValueOf(data), 0)
}

// sanitizeReflectValue converts v to the generic JSON value tree described on
// sanitizeData, replacing NaN/±Inf floats and unrepresentable values with nil.
func sanitizeReflectValue(v reflect.Value, depth int) interface{} {
	if !v.IsValid() || depth > sanitizeMaxDepth {
		return nil
	}

	switch v.Kind() {
	case reflect.Float32, reflect.Float64:
		if f := v.Float(); math.IsNaN(f) || math.IsInf(f, 0) {
			return nil
		}
	case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice:
		if v.IsNil() {
			return nil
		}
	}

	// Anything encoding/json accepts as-is goes through the standard round
	// trip, which also applies MarshalJSON/MarshalText. Values reached through
	// unexported embedded fields cannot be passed to Interface() and are
	// converted by the switch below instead.
	if v.CanInterface() {
		if b, err := json.Marshal(v.Interface()); err == nil {
			var out interface{}
			if err := json.Unmarshal(b, &out); err == nil {
				return out
			}
		}
	}

	switch v.Kind() {
	case reflect.Ptr, reflect.Interface:
		return sanitizeReflectValue(v.Elem(), depth+1)
	case reflect.Struct:
		out := make(map[string]interface{}, v.NumField())
		sanitizeStructFields(v, out, depth)
		return out
	case reflect.Map:
		out := make(map[string]interface{}, v.Len())
		iter := v.MapRange()
		for iter.Next() {
			k := iter.Key()
			var key string
			if k.Kind() == reflect.String {
				key = k.String()
			} else {
				key = fmt.Sprint(k)
			}
			out[key] = sanitizeReflectValue(iter.Value(), depth+1)
		}
		return out
	case reflect.Slice, reflect.Array:
		out := make([]interface{}, v.Len())
		for i := range out {
			out[i] = sanitizeReflectValue(v.Index(i), depth+1)
		}
		return out
	case reflect.Float32, reflect.Float64:
		return v.Float()
	case reflect.String:
		return v.String()
	case reflect.Bool:
		return v.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Numbers decode to float64 on the fast path; stay consistent with it.
		return float64(v.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return float64(v.Uint())
	default:
		// Channels, funcs, complex numbers, unsafe pointers: not representable in JSON.
		return nil
	}
}

// sanitizeStructFields stores the JSON-visible fields of struct value v into
// out, using encoding/json's field names.
//
// Fields of untagged embedded structs are promoted. A promoted field never
// overrides a field declared directly on v, mirroring json's shallow-wins rule.
func sanitizeStructFields(v reflect.Value, out map[string]interface{}, depth int) {
	t := v.Type()
	var embedded []reflect.Value

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		ft := field.Type
		if ft.Kind() == reflect.Ptr {
			ft = ft.Elem()
		}
		// Unexported fields are skipped, except embedded structs, whose exported
		// fields json still promotes.
		if field.PkgPath != "" && !(field.Anonymous && ft.Kind() == reflect.Struct) {
			continue
		}

		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}
		name, opts := tag, ""
		if comma := strings.Index(tag, ","); comma >= 0 {
			name, opts = tag[:comma], tag[comma+1:]
		}

		fv := v.Field(i)
		if field.Anonymous && name == "" && ft.Kind() == reflect.Struct {
			if fv.Kind() == reflect.Ptr {
				if fv.IsNil() {
					continue
				}
				fv = fv.Elem()
			}
			embedded = append(embedded, fv)
			continue
		}

		if name == "" {
			name = field.Name
		}
		if sanitizeHasOption(opts, "omitempty") && sanitizeIsEmpty(fv) {
			continue
		}
		out[name] = sanitizeReflectValue(fv, depth+1)
	}

	for _, ev := range embedded {
		promoted := make(map[string]interface{})
		sanitizeStructFields(ev, promoted, depth+1)
		for k, val := range promoted {
			if _, exists := out[k]; !exists {
				out[k] = val
			}
		}
	}
}

// sanitizeHasOption reports whether the comma-separated json tag options
// contain want.
func sanitizeHasOption(opts, want string) bool {
	for _, o := range strings.Split(opts, ",") {
		if o == want {
			return true
		}
	}
	return false
}

// sanitizeIsEmpty mirrors encoding/json's notion of an "empty" value for omitempty.
func sanitizeIsEmpty(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
		return v.Len() == 0
	case reflect.Bool:
		return !v.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() == 0
	case reflect.Float32, reflect.Float64:
		return v.Float() == 0
	case reflect.Interface, reflect.Ptr:
		return v.IsNil()
	}
	return false
}
|
|
|
|
// sanitizeValue returns a copy of a generic JSON-style value in which every NaN
// or ±Inf float64/float32 has been replaced with nil. Such a value is built
// from map[string]interface{}, []interface{} and scalars, as produced by
// json.Unmarshal into an interface{}.
//
// Maps and slices are copied recursively. All other values (strings, bools,
// nil, finite floats and anything else) are returned unchanged.
func sanitizeValue(data interface{}) interface{} {
	switch v := data.(type) {
	case float64:
		if math.IsNaN(v) || math.IsInf(v, 0) {
			return nil
		}
		return v
	case float32:
		if f := float64(v); math.IsNaN(f) || math.IsInf(f, 0) {
			return nil
		}
		return v
	case map[string]interface{}:
		// The copy has exactly len(v) keys; pre-size to avoid rehashing.
		sanitized := make(map[string]interface{}, len(v))
		for k, val := range v {
			sanitized[k] = sanitizeValue(val)
		}
		return sanitized
	case []interface{}:
		sanitized := make([]interface{}, len(v))
		for i, val := range v {
			sanitized[i] = sanitizeValue(val)
		}
		return sanitized
	default:
		return v
	}
}