Pulse/internal/agentexec/server.go
2026-03-18 16:06:30 +00:00

1081 lines
31 KiB
Go

package agentexec
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
)
// upgrader promotes agent HTTP requests to WebSocket connections using 1 KiB
// read and write buffers. Requests that carry an Origin header are screened by
// isAllowedWebSocketOrigin.
var upgrader = websocket.Upgrader{
	CheckOrigin:     isAllowedWebSocketOrigin,
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}
// Package-level hooks and tunables. They are variables rather than constants
// so they can be replaced (presumably by tests — confirm).
var (
	// jsonMarshal is not referenced in this part of the file; presumably it is
	// used elsewhere in the package — verify before removing.
	jsonMarshal = json.Marshal

	// writeTextMessage sends payload on c as a single WebSocket text frame.
	writeTextMessage = func(c *websocket.Conn, payload []byte) error {
		return c.WriteMessage(websocket.TextMessage, payload)
	}

	// pingInterval is how often pingLoop pings each agent.
	pingInterval = 5 * time.Second
	// pingWriteWait bounds how long one ping control write may take.
	pingWriteWait = 5 * time.Second
	// readFileTimeout is how long ReadFile waits for the agent's result.
	readFileTimeout = 30 * time.Second

	// errServerShuttingDown is returned to request waiters once Shutdown runs.
	errServerShuttingDown = errors.New("agent execution server is shutting down")
)
// Limits applied to inbound frames and to request fields before anything is
// forwarded to an agent.
const (
	// maxWebSocketMessageBytes is the read limit set on every agent connection (1 MiB).
	maxWebSocketMessageBytes int64 = 1 << 20

	maxAgentIDLength                = 128
	maxRequestIDLength              = 128
	maxExecuteCommandLength         = 32 * 1024
	maxTargetIDLength               = 256
	maxExecuteCommandTimeoutSeconds = 3600

	// defaultReadFileMaxBytes replaces a zero ReadFilePayload.MaxBytes (1 MiB).
	defaultReadFileMaxBytes int64 = 1 << 20
	// maxReadFileMaxBytes is the largest MaxBytes a caller may request (10 MiB).
	maxReadFileMaxBytes   int64 = 10 << 20
	maxReadFilePathLength       = 4096
)

// safeTargetIDPattern restricts container/VM target IDs to ASCII letters,
// digits, '.', '_' and '-'.
var safeTargetIDPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
// Server manages WebSocket connections from agents.
//
// Construct it with NewServer; the maps below are guarded by mu.
type Server struct {
	// mu guards agents, pendingReqs and deploySubs.
	mu sync.RWMutex
	// agents holds the current connection for each registered agent.
	agents map[string]*agentConn // agentID -> connection
	// pendingReqs holds result waiters created by ExecuteCommand/ReadFile,
	// keyed by pendingRequestKey(agentID, requestID).
	pendingReqs map[string]chan CommandResultPayload // scoped request key -> response channel
	// deploySubs holds deploy-progress subscribers created by SubscribeDeployProgress.
	deploySubs map[string]chan DeployProgressPayload // deploySubKey(agentID, jobID) -> progress subscriber
	// validateToken, when non-nil, must accept a registration's (token, agentID) pair.
	validateToken func(token string, agentID string) bool
	// shutdown is closed once (guarded by shutdownOnce) when Shutdown is called.
	shutdown     chan struct{}
	shutdownOnce sync.Once
}
// agentConn is one agent's live WebSocket connection plus the metadata it
// registered with.
type agentConn struct {
	conn  *websocket.Conn
	agent ConnectedAgent
	// writeMu serializes writes to conn once the agent is registered.
	writeMu sync.Mutex
	// done is closed when the connection is replaced by a newer registration
	// or the server shuts down; readLoop and pingLoop stop on it. Prefer
	// signalDone to close it.
	done     chan struct{}
	doneOnce sync.Once
}
// signalDone closes ac.done at most once. A panic from closing an
// already-closed channel is recovered, because some call sites/tests may close
// done directly instead of going through this method.
func (ac *agentConn) signalDone() {
	closeDone := func() {
		defer func() {
			_ = recover() // done was already closed elsewhere; nothing left to signal
		}()
		close(ac.done)
	}
	ac.doneOnce.Do(closeDone)
}
// NewServer creates a new agent execution server with empty connection,
// request and subscription registries. validateToken may be nil, in which case
// every registration's token is accepted.
func NewServer(validateToken func(token string, agentID string) bool) *Server {
	srv := &Server{validateToken: validateToken}
	srv.agents = make(map[string]*agentConn)
	srv.pendingReqs = make(map[string]chan CommandResultPayload)
	srv.deploySubs = make(map[string]chan DeployProgressPayload)
	srv.shutdown = make(chan struct{})
	return srv
}
// isShuttingDown reports, without blocking, whether Shutdown has closed s.shutdown.
func (s *Server) isShuttingDown() bool {
	closed := false
	select {
	case <-s.shutdown:
		closed = true
	default:
	}
	return closed
}
// pendingRequestKey scopes a request ID to the agent it was sent to, so results
// from one agent can never satisfy another agent's request.
func pendingRequestKey(agentID, requestID string) string {
	const sep = "\x00"
	return agentID + sep + requestID
}
// deploySubKey scopes a deploy job ID to its agent for the deploySubs registry.
// The agentID+"\x00" prefix is what readLoop's cleanup matches on.
func deploySubKey(agentID, jobID string) string {
	const sep = "\x00"
	return agentID + sep + jobID
}
// normalizeTarget canonicalizes a (type, id) execution target.
//
// The type is trimmed and lower-cased, defaulting to "agent". For "agent" the
// id is discarded. For "container" and "vm" the trimmed id is required, capped
// at maxTargetIDLength bytes, and must match safeTargetIDPattern. Any other
// type is rejected (the error quotes the caller's original type string).
func normalizeTarget(targetType, targetID string) (string, string, error) {
	kind := strings.ToLower(strings.TrimSpace(targetType))
	id := strings.TrimSpace(targetID)
	if kind == "" {
		kind = "agent"
	}

	if kind == "agent" {
		// Agent-level execution ignores target ID.
		return "agent", "", nil
	}
	if kind != "container" && kind != "vm" {
		return "", "", fmt.Errorf("invalid target type %q", targetType)
	}

	switch {
	case id == "":
		return "", "", fmt.Errorf("target id is required for target type %q", kind)
	case len(id) > maxTargetIDLength:
		return "", "", fmt.Errorf("target id exceeds %d characters", maxTargetIDLength)
	case !safeTargetIDPattern.MatchString(id):
		return "", "", fmt.Errorf("target id contains invalid characters")
	}
	return kind, id, nil
}
// validateExecuteCommandPayload checks an execute request and normalizes its
// target in place.
//
// The command must be non-blank and at most maxExecuteCommandLength bytes. The
// target is canonicalized by normalizeTarget and written back only on success.
// Timeout (seconds) must lie in [0, maxExecuteCommandTimeoutSeconds].
func validateExecuteCommandPayload(cmd *ExecuteCommandPayload) error {
	switch {
	case cmd == nil:
		return fmt.Errorf("command payload is required")
	case strings.TrimSpace(cmd.Command) == "":
		return fmt.Errorf("command is required")
	case len(cmd.Command) > maxExecuteCommandLength:
		return fmt.Errorf("command exceeds %d characters", maxExecuteCommandLength)
	}

	normType, normID, err := normalizeTarget(cmd.TargetType, cmd.TargetID)
	if err != nil {
		return err
	}
	cmd.TargetType, cmd.TargetID = normType, normID

	switch {
	case cmd.Timeout < 0:
		return fmt.Errorf("timeout cannot be negative")
	case cmd.Timeout > maxExecuteCommandTimeoutSeconds:
		return fmt.Errorf("timeout cannot exceed %d seconds", maxExecuteCommandTimeoutSeconds)
	}
	return nil
}
// validateReadFilePayload checks a read_file request and normalizes it in place.
//
// The path is trimmed and must be non-empty, at most maxReadFilePathLength
// bytes, and free of NUL, CR and LF. The target is canonicalized by
// normalizeTarget. MaxBytes must be non-negative. A zero MaxBytes becomes
// defaultReadFileMaxBytes, and it may not exceed maxReadFileMaxBytes.
func validateReadFilePayload(req *ReadFilePayload) error {
	if req == nil {
		return fmt.Errorf("read file payload is required")
	}

	req.Path = strings.TrimSpace(req.Path)
	switch path := req.Path; {
	case path == "":
		return fmt.Errorf("path is required")
	case len(path) > maxReadFilePathLength:
		return fmt.Errorf("path exceeds %d characters", maxReadFilePathLength)
	case strings.ContainsAny(path, "\x00\r\n"):
		return fmt.Errorf("path contains invalid control characters")
	}

	normType, normID, err := normalizeTarget(req.TargetType, req.TargetID)
	if err != nil {
		return err
	}
	req.TargetType, req.TargetID = normType, normID

	if req.MaxBytes < 0 {
		return fmt.Errorf("max bytes cannot be negative")
	}
	if req.MaxBytes == 0 {
		req.MaxBytes = defaultReadFileMaxBytes
	}
	if req.MaxBytes > maxReadFileMaxBytes {
		return fmt.Errorf("max bytes cannot exceed %d", maxReadFileMaxBytes)
	}
	return nil
}
// isAllowedWebSocketOrigin is the upgrader's CheckOrigin policy.
//
// Requests without an Origin header (typical for non-browser agents) are
// allowed. Otherwise the Origin must parse as an http or https URL with a host,
// and that host must equal the request's Host after normalizeOriginHost.
func isAllowedWebSocketOrigin(r *http.Request) bool {
	rawOrigin := strings.TrimSpace(r.Header.Get("Origin"))
	if rawOrigin == "" {
		// Non-browser clients (expected for agents) usually omit Origin.
		return true
	}

	originURL, err := url.Parse(rawOrigin)
	switch {
	case err != nil || originURL.Host == "":
		return false
	case originURL.Scheme != "http" && originURL.Scheme != "https":
		return false
	}
	return normalizeOriginHost(originURL.Host) == normalizeOriginHost(r.Host)
}
// normalizeOriginHost returns a canonical form of a host[:port] string for
// same-origin comparison.
//
// The value is lower-cased and trimmed. An explicit port 80 or 443 is dropped.
// IPv6 literals keep their brackets in every form, so "[::1]:443" and "[::1]"
// compare equal. Values without a parsable port are returned as-is.
func normalizeOriginHost(host string) string {
	normalized := strings.TrimSpace(strings.ToLower(host))
	if normalized == "" {
		return normalized
	}
	hostname, port, err := net.SplitHostPort(normalized)
	if err != nil {
		// No port (or unparsable). Bracketed IPv6 literals such as "[::1]"
		// stay bracketed here.
		return normalized
	}
	if port == "80" || port == "443" {
		if strings.Contains(hostname, ":") {
			// SplitHostPort strips IPv6 brackets; restore them so the result
			// matches the port-less bracketed form above.
			return "[" + hostname + "]"
		}
		return hostname
	}
	return net.JoinHostPort(hostname, port)
}
// HandleWebSocket handles incoming WebSocket connections from agents.
//
// It upgrades the request, runs the registration handshake, and then services
// the connection until it closes. The first frame must be MsgTypeAgentRegister.
// The agent_id must be non-empty after trimming and at most maxAgentIDLength
// bytes. When validateToken is configured, it must accept the token.
//
// Rejected registrations get a MsgTypeRegistered reply with Success=false
// (best effort), and the connection is closed. An accepted agent replaces any
// connection already registered under the same agent ID.
//
// The handler blocks in readLoop for the connection's lifetime instead of
// returning to net/http.
func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	remoteAddr := r.RemoteAddr
	if s.isShuttingDown() {
		http.Error(w, "agent execution server is shutting down", http.StatusServiceUnavailable)
		return
	}

	// CRITICAL: Clear http.Server deadlines BEFORE WebSocket upgrade.
	// The http.Server.ReadTimeout sets a deadline on the underlying connection when
	// the request starts. We must clear it before the upgrade or the connection will
	// be closed when that deadline fires (typically ~15 seconds after connection).
	// Use http.ResponseController (Go 1.20+) to clear the deadline.
	rc := http.NewResponseController(w)
	if err := rc.SetReadDeadline(time.Time{}); err != nil {
		log.Debug().
			Err(err).
			Str("remote_addr", remoteAddr).
			Msg("Failed to clear read deadline via ResponseController")
	}
	if err := rc.SetWriteDeadline(time.Time{}); err != nil {
		log.Debug().
			Err(err).
			Str("remote_addr", remoteAddr).
			Msg("Failed to clear write deadline via ResponseController")
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Error().Err(err).Str("remote_addr", remoteAddr).Msg("Failed to upgrade WebSocket connection")
		return
	}
	conn.SetReadLimit(maxWebSocketMessageBytes)

	// closeConn closes the socket, logging any close error (at debug) with logMsg.
	closeConn := func(logMsg string) {
		if closeErr := conn.Close(); closeErr != nil {
			log.Debug().Err(closeErr).Msg(logMsg)
		}
	}

	if s.isShuttingDown() {
		closeConn("Failed to close connection during shutdown")
		return
	}

	// Also clear on the WebSocket's underlying connection as a safety net
	if netConn := conn.NetConn(); netConn != nil {
		if err := netConn.SetReadDeadline(time.Time{}); err != nil {
			log.Debug().Err(err).Msg("Failed to clear net.Conn read deadline")
		}
		if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
			log.Debug().Err(err).Msg("Failed to clear net.Conn write deadline")
		}
	}

	// Read first message (must be agent_register); the agent gets 30s to send it.
	if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
		log.Warn().Err(err).Msg("Failed to set initial registration read deadline")
	}
	_, msgBytes, err := conn.ReadMessage()
	if err != nil {
		log.Error().Err(err).Str("remote_addr", remoteAddr).Msg("Failed to read registration message")
		closeConn("Failed to close connection after registration read error")
		return
	}

	var msg Message
	if err := json.Unmarshal(msgBytes, &msg); err != nil {
		log.Error().Err(err).Str("remote_addr", remoteAddr).Msg("Failed to parse registration message")
		closeConn("Failed to close connection after registration parse error")
		return
	}
	if msg.Type != MsgTypeAgentRegister {
		log.Error().Str("type", string(msg.Type)).Str("remote_addr", remoteAddr).Msg("First message must be agent_register")
		closeConn("Failed to close connection after invalid first message type")
		return
	}

	// Parse registration payload
	var reg AgentRegisterPayload
	if err := msg.DecodePayload(&reg); err != nil {
		log.Error().Err(err).Str("remote_addr", remoteAddr).Msg("Failed to parse registration payload")
		closeConn("Failed to close connection after registration payload parse error")
		return
	}

	reg.AgentID = strings.TrimSpace(reg.AgentID)
	if reg.AgentID == "" {
		log.Warn().Msg("Agent registration rejected: missing agent_id")
		rejMsg, rejErr := NewMessage(MsgTypeRegistered, "", RegisteredPayload{Success: false, Message: "Invalid agent_id"})
		if rejErr != nil {
			log.Warn().Err(rejErr).Msg("Failed to encode rejection message")
		} else if sendErr := s.sendMessage(conn, rejMsg); sendErr != nil {
			log.Warn().Err(sendErr).Msg("Failed to send rejection to agent with missing agent_id")
		}
		closeConn("Failed to close connection after registration rejection")
		return
	}
	if len(reg.AgentID) > maxAgentIDLength {
		log.Warn().
			Int("agent_id_length", len(reg.AgentID)).
			Msg("Agent registration rejected: agent_id exceeds maximum length")
		rejMsg, rejErr := NewMessage(MsgTypeRegistered, "", RegisteredPayload{Success: false, Message: "Invalid agent_id"})
		if rejErr != nil {
			log.Warn().Err(rejErr).Msg("Failed to encode rejection for oversized agent_id")
		} else if sendErr := s.sendMessage(conn, rejMsg); sendErr != nil {
			log.Warn().Err(sendErr).Msg("Failed to send rejection to agent with oversized agent_id")
		}
		closeConn("Failed to close connection after registration rejection")
		return
	}

	// Validate token
	if s.validateToken != nil && !s.validateToken(reg.Token, reg.AgentID) {
		log.Warn().Str("agent_id", reg.AgentID).Msg("Agent registration rejected: invalid token")
		rejectedMsg, err := NewMessage(MsgTypeRegistered, "", RegisteredPayload{Success: false, Message: "Invalid token"})
		if err != nil {
			log.Warn().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to encode rejection message")
			closeConn("Failed to close connection after registration rejection")
			return
		}
		if err := s.sendMessage(conn, rejectedMsg); err != nil {
			log.Warn().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to send rejection to agent")
		}
		closeConn("Failed to close connection after registration rejection")
		return
	}

	// Create agent connection
	ac := &agentConn{
		conn: conn,
		agent: ConnectedAgent{
			AgentID:     reg.AgentID,
			Hostname:    reg.Hostname,
			Version:     reg.Version,
			Platform:    reg.Platform,
			Tags:        reg.Tags,
			ConnectedAt: time.Now(),
		},
		done: make(chan struct{}),
	}

	// Clear deadline for normal operation - both on the WebSocket and underlying connection
	// This MUST happen BEFORE registering the agent in the map to avoid race conditions
	// where other goroutines could call ExecuteCommand while we're still configuring the connection.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		log.Warn().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to clear read deadline after registration")
	}
	if err := conn.SetWriteDeadline(time.Time{}); err != nil {
		log.Warn().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to clear write deadline after registration")
	}
	if netConn := conn.NetConn(); netConn != nil {
		if err := netConn.SetReadDeadline(time.Time{}); err != nil {
			log.Warn().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to clear net.Conn read deadline after registration")
		}
		if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
			log.Warn().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to clear net.Conn write deadline after registration")
		}
	}

	// Set up ping/pong handlers to keep connection alive
	conn.SetPongHandler(func(appData string) error {
		// Reset read deadline on pong received
		if err := conn.SetReadDeadline(time.Time{}); err != nil {
			return fmt.Errorf("set read deadline on pong: %w", err)
		}
		return nil
	})

	// Encode the ack before publishing the agent, so an encoding failure cannot
	// leave a registered entry pointing at a closed connection.
	ackMsg, ackErr := NewMessage(MsgTypeRegistered, "", RegisteredPayload{Success: true, Message: "Registered"})
	if ackErr != nil {
		log.Warn().Err(ackErr).Str("agent_id", reg.AgentID).Msg("Failed to encode registration ack")
		closeConn("Failed to close connection after registration ack encode error")
		return
	}

	// Hold ac.writeMu from before the agent becomes visible in s.agents until
	// the ack is written. The registered message is then the first frame the
	// agent receives, and commands issued concurrently queue behind it. Lock
	// order is writeMu -> s.mu. Nobody else can reach ac yet, so this cannot
	// deadlock.
	ac.writeMu.Lock()
	s.mu.Lock()
	// Re-check under s.mu. Shutdown closes s.shutdown before draining s.agents
	// under this lock, so either we observe the shutdown here or Shutdown
	// observes (and closes) this connection. It can never be left behind.
	if s.isShuttingDown() {
		s.mu.Unlock()
		ac.writeMu.Unlock()
		closeConn("Failed to close connection during shutdown")
		return
	}
	// Close existing connection if any
	if existing, ok := s.agents[reg.AgentID]; ok {
		log.Info().
			Str("agent_id", reg.AgentID).
			Str("hostname", reg.Hostname).
			Msg("Replacing existing agent connection")
		// signalDone (rather than close) cannot panic if done was already closed.
		existing.signalDone()
		if err := existing.conn.Close(); err != nil {
			log.Debug().Err(err).Str("agent_id", reg.AgentID).Msg("Failed to close existing connection during reconnect")
		}
	}
	s.agents[reg.AgentID] = ac
	s.mu.Unlock()

	log.Info().
		Str("agent_id", reg.AgentID).
		Str("hostname", reg.Hostname).
		Str("version", reg.Version).
		Str("platform", reg.Platform).
		Msg("Agent connected")

	// Send registration success
	if sendErr := s.sendMessage(conn, ackMsg); sendErr != nil {
		log.Warn().
			Err(sendErr).
			Str("agent_id", reg.AgentID).
			Str("hostname", reg.Hostname).
			Msg("Failed to send registration ack")
	}
	ac.writeMu.Unlock()

	// Start server-side ping loop to keep connection alive
	pingDone := make(chan struct{})
	go s.pingLoop(ac, pingDone)
	defer close(pingDone)

	// Run read loop (blocking) - don't use goroutine, or HTTP handler will close connection
	s.readLoop(ac)
}
// readLoop reads and dispatches frames from a registered agent until a read
// fails or ac.done is closed. ac.done is checked between reads.
//
// Dispatch:
//   - MsgTypeAgentPing: answered with MsgTypePong under ac.writeMu.
//   - MsgTypeCommandResult: delivered without blocking to the waiter registered
//     under pendingRequestKey(agentID, requestID). Empty or oversized request
//     IDs are dropped.
//   - MsgTypeDeployProgress: delivered to the subscriber registered under
//     deploySubKey(agentID, jobID). A non-final update is dropped if the
//     subscriber's buffer is full. A final update is retried every 50ms for up
//     to 5s, after which the subscription is force-closed. The retry also stops
//     if ac.done closes.
//
// On exit it unregisters ac (only if it is still the current connection for
// its agent ID), closes every deploy-progress subscription for that agent ID,
// and closes the socket.
func (s *Server) readLoop(ac *agentConn) {
	defer func() {
		agentID := ac.agent.AgentID
		s.mu.Lock()
		// Only unregister if a newer registration has not replaced ac.
		if existing, ok := s.agents[agentID]; ok && existing == ac {
			delete(s.agents, agentID)
		}
		// Close all deploy progress subscriptions for this agent so
		// processPreflightProgress goroutines unblock and detect disconnect.
		// NOTE(review): this also runs when ac was replaced by a newer
		// connection for the same agent ID. Subscriptions made against the
		// replacement are closed too — confirm this is intended.
		var closeChs []chan DeployProgressPayload
		prefix := agentID + "\x00"
		for key, ch := range s.deploySubs {
			if strings.HasPrefix(key, prefix) {
				closeChs = append(closeChs, ch)
				delete(s.deploySubs, key)
			}
		}
		s.mu.Unlock()
		for _, ch := range closeChs {
			close(ch)
		}
		if err := ac.conn.Close(); err != nil {
			log.Debug().Err(err).Str("agent_id", agentID).Msg("Failed to close connection during read-loop cleanup")
		}
		log.Info().Str("agent_id", agentID).Msg("Agent disconnected")
	}()

	log.Debug().Str("agent_id", ac.agent.AgentID).Msg("Starting read loop for agent")
	for {
		select {
		case <-ac.done:
			log.Debug().Str("agent_id", ac.agent.AgentID).Msg("Read loop exiting: done channel closed")
			return
		default:
		}

		_, msgBytes, err := ac.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Error().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Unexpected WebSocket close error")
			} else {
				log.Debug().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Read loop exiting: read error")
			}
			return
		}

		var msg Message
		if err := json.Unmarshal(msgBytes, &msg); err != nil {
			log.Error().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Failed to parse message")
			continue
		}

		switch msg.Type {
		case MsgTypeAgentPing:
			pongMsg, err := NewMessage(MsgTypePong, "", nil)
			if err != nil {
				log.Debug().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Failed to encode pong message")
				continue
			}
			ac.writeMu.Lock()
			if err := s.sendMessage(ac.conn, pongMsg); err != nil {
				log.Debug().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Failed to send pong")
			}
			ac.writeMu.Unlock()

		case MsgTypeCommandResult:
			var result CommandResultPayload
			if err := msg.DecodePayload(&result); err != nil {
				log.Error().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Failed to parse command result")
				continue
			}
			result.RequestID = strings.TrimSpace(result.RequestID)
			if result.RequestID == "" {
				log.Warn().Str("agent_id", ac.agent.AgentID).Msg("Dropping command result with empty request_id")
				continue
			}
			if len(result.RequestID) > maxRequestIDLength {
				log.Warn().
					Str("agent_id", ac.agent.AgentID).
					Int("request_id_length", len(result.RequestID)).
					Msg("Dropping command result with oversized request_id")
				continue
			}
			s.mu.RLock()
			ch, ok := s.pendingReqs[pendingRequestKey(ac.agent.AgentID, result.RequestID)]
			s.mu.RUnlock()
			if ok {
				select {
				case ch <- result:
					log.Debug().
						Str("agent_id", ac.agent.AgentID).
						Str("request_id", result.RequestID).
						Bool("success", result.Success).
						Int("exit_code", result.ExitCode).
						Int64("duration_ms", result.Duration).
						Msg("Received command result from agent")
				default:
					log.Warn().
						Str("agent_id", ac.agent.AgentID).
						Str("request_id", result.RequestID).
						Msg("Result channel full, dropping")
				}
			} else {
				log.Warn().
					Str("agent_id", ac.agent.AgentID).
					Str("request_id", result.RequestID).
					Msg("No pending request for result")
			}

		case MsgTypeDeployProgress:
			var progress DeployProgressPayload
			if err := msg.DecodePayload(&progress); err != nil {
				log.Error().Err(err).Str("agent_id", ac.agent.AgentID).Msg("Failed to parse deploy progress")
				continue
			}
			if progress.JobID == "" {
				log.Warn().Str("agent_id", ac.agent.AgentID).Msg("Dropping deploy progress with empty job_id")
				continue
			}
			subKey := deploySubKey(ac.agent.AgentID, progress.JobID)

			// Hold the read lock across map lookup AND the non-blocking send to
			// prevent UnsubscribeDeployProgress from closing the channel between
			// lookup and send (it needs the write lock to delete + close).
			sent := false
			s.mu.RLock()
			ch, ok := s.deploySubs[subKey]
			if ok {
				select {
				case ch <- progress:
					sent = true
				default:
				}
			}
			s.mu.RUnlock()

			// Final messages must be delivered — retry with backoff if the
			// initial non-blocking send failed (channel was full).
			if ok && !sent && progress.Final {
				deadline := time.After(5 * time.Second)
				ticker := time.NewTicker(50 * time.Millisecond)
			retryLoop:
				for {
					select {
					case <-ac.done:
						// The connection is being replaced or the server is
						// shutting down. The deferred cleanup closes this
						// agent's subscriptions, which unblocks the consumer,
						// so stop waiting instead of stalling for the full 5s.
						log.Warn().
							Str("agent_id", ac.agent.AgentID).
							Str("job_id", progress.JobID).
							Msg("Agent connection closing; abandoning final deploy progress delivery")
						break retryLoop
					case <-deadline:
						log.Error().
							Str("agent_id", ac.agent.AgentID).
							Str("job_id", progress.JobID).
							Msg("Deploy final progress send timed out — force-closing subscription")
						// Force-close the subscription so the consumer goroutine
						// unblocks on channel close and can finalize the job.
						s.mu.Lock()
						if closeCh, exists := s.deploySubs[subKey]; exists {
							delete(s.deploySubs, subKey)
							close(closeCh)
						}
						s.mu.Unlock()
						break retryLoop
					case <-ticker.C:
						s.mu.RLock()
						ch, ok = s.deploySubs[subKey]
						if !ok {
							s.mu.RUnlock()
							break retryLoop // channel was closed/unsubscribed
						}
						select {
						case ch <- progress:
							sent = true
							s.mu.RUnlock()
							break retryLoop
						default:
							s.mu.RUnlock()
						}
					}
				}
				ticker.Stop()
			} else if ok && !sent {
				log.Warn().
					Str("agent_id", ac.agent.AgentID).
					Str("job_id", progress.JobID).
					Msg("Deploy progress channel full, dropping")
			}

			if ok {
				if sent {
					log.Debug().
						Str("agent_id", ac.agent.AgentID).
						Str("job_id", progress.JobID).
						Str("target_id", progress.TargetID).
						Str("phase", string(progress.Phase)).
						Str("status", string(progress.Status)).
						Bool("final", progress.Final).
						Msg("Received deploy progress from agent")
				}
			} else {
				log.Debug().
					Str("agent_id", ac.agent.AgentID).
					Str("job_id", progress.JobID).
					Msg("No subscriber for deploy progress")
			}
		}
	}
}
// pingLoop pings the agent every pingInterval (each write bounded by
// pingWriteWait, under ac.writeMu) until done or ac.done is closed.
//
// A successful ping resets the failure count. After maxConsecutiveFailures
// failed pings in a row, the connection is treated as dead and closed, which
// makes readLoop return and run its cleanup.
func (s *Server) pingLoop(ac *agentConn, done chan struct{}) {
	const maxConsecutiveFailures = 3

	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	failures := 0
	for {
		select {
		case <-done:
			return
		case <-ac.done:
			return
		case <-ticker.C:
		}

		ac.writeMu.Lock()
		pingErr := ac.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(pingWriteWait))
		ac.writeMu.Unlock()

		if pingErr == nil {
			failures = 0
			continue
		}

		failures++
		log.Warn().
			Err(pingErr).
			Str("agent_id", ac.agent.AgentID).
			Str("hostname", ac.agent.Hostname).
			Int("consecutive_failures", failures).
			Msg("Failed to send ping to agent")
		if failures < maxConsecutiveFailures {
			continue
		}

		log.Error().
			Err(pingErr).
			Str("agent_id", ac.agent.AgentID).
			Str("hostname", ac.agent.Hostname).
			Int("failures", failures).
			Msg("Agent connection appears dead after multiple ping failures, closing connection")
		// Closing the socket unblocks readLoop, which performs the cleanup.
		if closeErr := ac.conn.Close(); closeErr != nil {
			log.Debug().Err(closeErr).Str("agent_id", ac.agent.AgentID).Msg("Failed to close dead connection after ping failures")
		}
		return
	}
}
// sendMessage JSON-encodes msg and writes it to conn as one text frame via
// writeTextMessage. It does not lock; callers serialize writes to a shared
// connection themselves.
func (s *Server) sendMessage(conn *websocket.Conn, msg Message) error {
	encoded, marshalErr := json.Marshal(msg)
	if marshalErr != nil {
		return fmt.Errorf("marshal websocket message: %w", marshalErr)
	}
	writeErr := writeTextMessage(conn, encoded)
	if writeErr == nil {
		return nil
	}
	return fmt.Errorf("write websocket message: %w", writeErr)
}
// Shutdown gracefully stops the server by closing all active agent connections.
// The method is idempotent.
//
// It closes s.shutdown first, which wakes every ExecuteCommand/ReadFile waiter
// with errServerShuttingDown. It then detaches the agent registry and, outside
// the lock, signals and closes each connection.
func (s *Server) Shutdown() {
	s.shutdownOnce.Do(func() {
		close(s.shutdown)

		s.mu.Lock()
		drained := s.agents
		s.agents = make(map[string]*agentConn)
		s.mu.Unlock()

		// drained is no longer reachable through s, so ranging it unlocked is safe.
		for _, conn := range drained {
			conn.signalDone()
			_ = conn.conn.Close() // best effort; readLoop logs its own close errors
		}
	})
}
// ExecuteCommand sends a command to an agent and waits for the result.
//
// agentID and cmd.RequestID are trimmed. An empty request ID is replaced by a
// random UUID, and one longer than maxRequestIDLength is rejected. cmd is
// validated and normalized by validateExecuteCommandPayload. A request ID that
// is already awaiting a result for the same agent is rejected, so two callers
// never share (and then race to delete) one response slot.
//
// After the command is written, the call waits cmd.Timeout seconds (60s when
// unset) for the result. It also returns early with ctx.Err() when ctx is done,
// or with errServerShuttingDown when the server shuts down.
func (s *Server) ExecuteCommand(ctx context.Context, agentID string, cmd ExecuteCommandPayload) (*CommandResultPayload, error) {
	agentID = strings.TrimSpace(agentID)
	if agentID == "" {
		return nil, fmt.Errorf("agent id is required")
	}
	cmd.RequestID = strings.TrimSpace(cmd.RequestID)
	if cmd.RequestID == "" {
		cmd.RequestID = uuid.New().String()
	}
	if len(cmd.RequestID) > maxRequestIDLength {
		return nil, fmt.Errorf("request id exceeds %d characters", maxRequestIDLength)
	}
	if err := validateExecuteCommandPayload(&cmd); err != nil {
		return nil, err
	}

	startedAt := time.Now()
	s.mu.RLock()
	ac, ok := s.agents[agentID]
	s.mu.RUnlock()
	if !ok {
		log.Warn().
			Str("agent_id", agentID).
			Str("request_id", cmd.RequestID).
			Msg("Execute command requested for disconnected agent")
		return nil, fmt.Errorf("agent %s not connected", agentID)
	}

	execLog := log.With().
		Str("agent_id", agentID).
		Str("request_id", cmd.RequestID).
		Str("target_type", cmd.TargetType).
		Str("target_id", cmd.TargetID).
		Logger()

	// Register the response slot. Buffered so readLoop's non-blocking send
	// succeeds even if we have not reached the select yet.
	respCh := make(chan CommandResultPayload, 1)
	reqKey := pendingRequestKey(agentID, cmd.RequestID)
	s.mu.Lock()
	if _, inFlight := s.pendingReqs[reqKey]; inFlight {
		s.mu.Unlock()
		return nil, fmt.Errorf("request id %q already in flight for agent %s", cmd.RequestID, agentID)
	}
	s.pendingReqs[reqKey] = respCh
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.pendingReqs, reqKey)
		s.mu.Unlock()
	}()

	// Send command
	execMsg, execErr := NewMessage(MsgTypeExecuteCmd, cmd.RequestID, cmd)
	if execErr != nil {
		return nil, fmt.Errorf("failed to encode command: %w", execErr)
	}
	ac.writeMu.Lock()
	err := s.sendMessage(ac.conn, execMsg)
	ac.writeMu.Unlock()
	if err != nil {
		execLog.Error().
			Err(err).
			Dur("duration", time.Since(startedAt)).
			Msg("Failed to send command to agent")
		return nil, fmt.Errorf("failed to send command: %w", err)
	}

	// Wait for result
	timeout := time.Duration(cmd.Timeout) * time.Second
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case result := <-respCh:
		execLog.Info().
			Bool("success", result.Success).
			Int("exit_code", result.ExitCode).
			Int64("agent_duration_ms", result.Duration).
			Dur("duration", time.Since(startedAt)).
			Msg("Agent command completed")
		return &result, nil
	case <-timer.C:
		execLog.Warn().
			Dur("timeout", timeout).
			Dur("duration", time.Since(startedAt)).
			Msg("Agent command timed out")
		return nil, fmt.Errorf("command timed out after %v", timeout)
	case <-ctx.Done():
		execLog.Warn().
			Err(ctx.Err()).
			Dur("duration", time.Since(startedAt)).
			Msg("Agent command canceled")
		return nil, ctx.Err()
	case <-s.shutdown:
		return nil, errServerShuttingDown
	}
}
// ReadFile reads a file from an agent.
//
// agentID and req.RequestID are trimmed, and an empty request ID is replaced by
// a random UUID. The request ID must be at most maxRequestIDLength bytes,
// because readLoop drops results with longer IDs and such a request could never
// complete. It also must not already be in flight for the same agent. req is
// validated and normalized by validateReadFilePayload.
//
// The call waits up to readFileTimeout for the result. It also returns early
// with a wrapped ctx.Err() when ctx is done, or with errServerShuttingDown when
// the server shuts down.
func (s *Server) ReadFile(ctx context.Context, agentID string, req ReadFilePayload) (*CommandResultPayload, error) {
	agentID = strings.TrimSpace(agentID)
	if agentID == "" {
		return nil, fmt.Errorf("agent id is required")
	}
	req.RequestID = strings.TrimSpace(req.RequestID)
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}
	if len(req.RequestID) > maxRequestIDLength {
		return nil, fmt.Errorf("request id exceeds %d characters", maxRequestIDLength)
	}
	if err := validateReadFilePayload(&req); err != nil {
		return nil, err
	}

	s.mu.RLock()
	ac, ok := s.agents[agentID]
	s.mu.RUnlock()
	if !ok {
		log.Warn().
			Str("agent_id", agentID).
			Str("request_id", req.RequestID).
			Msg("Read file requested for disconnected agent")
		return nil, fmt.Errorf("agent %s not connected", agentID)
	}

	readLog := log.With().
		Str("agent_id", agentID).
		Str("request_id", req.RequestID).
		Str("path", req.Path).
		Str("target_type", req.TargetType).
		Str("target_id", req.TargetID).
		Int64("max_bytes", req.MaxBytes).
		Logger()
	startedAt := time.Now()

	// Register the response slot. Buffered so readLoop's non-blocking send
	// succeeds even if we have not reached the select yet.
	respCh := make(chan CommandResultPayload, 1)
	reqKey := pendingRequestKey(agentID, req.RequestID)
	s.mu.Lock()
	if _, inFlight := s.pendingReqs[reqKey]; inFlight {
		s.mu.Unlock()
		return nil, fmt.Errorf("request id %q already in flight for agent %s", req.RequestID, agentID)
	}
	s.pendingReqs[reqKey] = respCh
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.pendingReqs, reqKey)
		s.mu.Unlock()
	}()

	// Send request
	readPayloadBytes, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to encode read_file request: %w", err)
	}
	msg := Message{
		Type:      MsgTypeReadFile,
		ID:        req.RequestID,
		Timestamp: time.Now(),
		Payload:   readPayloadBytes,
	}
	ac.writeMu.Lock()
	sendErr := s.sendMessage(ac.conn, msg)
	ac.writeMu.Unlock()
	if sendErr != nil {
		readLog.Error().
			Err(sendErr).
			Dur("duration", time.Since(startedAt)).
			Msg("Failed to send read_file request to agent")
		return nil, fmt.Errorf("failed to send read_file request: %w", sendErr)
	}

	// Wait for result
	timeout := readFileTimeout
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case result := <-respCh:
		readLog.Info().
			Bool("success", result.Success).
			Int("exit_code", result.ExitCode).
			Int64("agent_duration_ms", result.Duration).
			Dur("duration", time.Since(startedAt)).
			Msg("Agent read_file completed")
		return &result, nil
	case <-timer.C:
		readLog.Warn().
			Dur("timeout", timeout).
			Dur("duration", time.Since(startedAt)).
			Msg("Agent read_file timed out")
		return nil, fmt.Errorf("read_file timed out after %v", timeout)
	case <-ctx.Done():
		readLog.Warn().
			Err(ctx.Err()).
			Dur("duration", time.Since(startedAt)).
			Msg("Agent read_file canceled")
		return nil, fmt.Errorf("read_file %q on agent %q canceled: %w", req.RequestID, agentID, ctx.Err())
	case <-s.shutdown:
		return nil, errServerShuttingDown
	}
}
// GetConnectedAgents returns a snapshot of the metadata of every currently
// registered agent, in no particular order. The result is never nil.
func (s *Server) GetConnectedAgents() []ConnectedAgent {
	s.mu.RLock()
	snapshot := make([]ConnectedAgent, 0, len(s.agents))
	for _, conn := range s.agents {
		snapshot = append(snapshot, conn.agent)
	}
	s.mu.RUnlock()
	return snapshot
}
// IsAgentConnected reports whether agentID currently has a registered
// connection. The ID is matched exactly; it is not trimmed.
func (s *Server) IsAgentConnected(agentID string) bool {
	s.mu.RLock()
	_, connected := s.agents[agentID]
	s.mu.RUnlock()
	return connected
}
// GetAgentForHost returns the ID of a registered agent whose reported hostname
// equals hostname exactly, and whether one was found. If several agents report
// the same hostname, which one is returned is unspecified (map iteration order).
func (s *Server) GetAgentForHost(hostname string) (string, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for _, candidate := range s.agents {
		if candidate.agent.Hostname != hostname {
			continue
		}
		return candidate.agent.AgentID, true
	}
	return "", false
}
// --- Deploy protocol ---

// SubscribeDeployProgress registers and returns a buffered channel that
// receives deploy progress events for the given agent and job ID. A
// non-positive bufSize means 64. The caller must call UnsubscribeDeployProgress
// when done.
//
// Subscribing again for the same agent and job replaces the previous channel
// in the registry without closing it.
func (s *Server) SubscribeDeployProgress(agentID, jobID string, bufSize int) chan DeployProgressPayload {
	const defaultBufSize = 64
	if bufSize < 1 {
		bufSize = defaultBufSize
	}
	key := deploySubKey(agentID, jobID)
	sub := make(chan DeployProgressPayload, bufSize)

	s.mu.Lock()
	s.deploySubs[key] = sub
	s.mu.Unlock()
	return sub
}
// UnsubscribeDeployProgress removes and closes the progress subscriber for an agent's job.
// Safe to call multiple times — a no-op if already unsubscribed (e.g. by readLoop cleanup).
// The channel is closed after the lock is released. Because it was removed
// under the lock, readLoop can no longer send on it.
func (s *Server) UnsubscribeDeployProgress(agentID, jobID string) {
	key := deploySubKey(agentID, jobID)

	s.mu.Lock()
	sub, found := s.deploySubs[key]
	if found {
		delete(s.deploySubs, key)
	}
	s.mu.Unlock()

	if !found {
		return
	}
	close(sub)
}
// SendDeployPreflight sends a preflight check command to the source agent.
// The caller should subscribe to deploy progress for the job ID before calling
// this method. Results stream back as DeployProgressPayload messages.
// payload.RequestID is trimmed before sending.
func (s *Server) SendDeployPreflight(ctx context.Context, agentID string, payload DeployPreflightPayload) error {
	reqID := strings.TrimSpace(payload.RequestID)
	payload.RequestID = reqID
	return s.sendDeployCommand(ctx, agentID, MsgTypeDeployPreflight, reqID, payload)
}
// SendDeployInstall sends an install command to the source agent.
// The caller should subscribe to deploy progress for the job ID before calling
// this method. Results stream back as DeployProgressPayload messages.
// payload.RequestID is trimmed before sending.
func (s *Server) SendDeployInstall(ctx context.Context, agentID string, payload DeployInstallPayload) error {
	reqID := strings.TrimSpace(payload.RequestID)
	payload.RequestID = reqID
	return s.sendDeployCommand(ctx, agentID, MsgTypeDeployInstall, reqID, payload)
}
// SendDeployCancel sends a cancel command to the source agent.
// payload.RequestID is trimmed before sending.
func (s *Server) SendDeployCancel(ctx context.Context, agentID string, payload DeployCancelPayload) error {
	reqID := strings.TrimSpace(payload.RequestID)
	payload.RequestID = reqID
	return s.sendDeployCommand(ctx, agentID, MsgTypeDeployCancelJob, reqID, payload)
}
// sendDeployCommand encodes payload as a msgType message with the given
// request ID and writes it to the agent under its write lock. It is
// fire-and-forget: progress arrives separately via deploy subscriptions.
//
// Checks run in this order: agent ID present, agent connected, request ID
// non-empty and at most maxRequestIDLength bytes. ctx is currently not consulted.
func (s *Server) sendDeployCommand(ctx context.Context, agentID string, msgType MessageType, requestID string, payload any) error {
	if agentID = strings.TrimSpace(agentID); agentID == "" {
		return fmt.Errorf("agent id is required")
	}

	s.mu.RLock()
	ac, connected := s.agents[agentID]
	s.mu.RUnlock()
	if !connected {
		return fmt.Errorf("agent %s not connected", agentID)
	}

	switch requestID = strings.TrimSpace(requestID); {
	case requestID == "":
		return fmt.Errorf("request id is required for deploy commands")
	case len(requestID) > maxRequestIDLength:
		return fmt.Errorf("request id exceeds %d characters", maxRequestIDLength)
	}

	msg, encodeErr := NewMessage(msgType, requestID, payload)
	if encodeErr != nil {
		return fmt.Errorf("failed to encode deploy command: %w", encodeErr)
	}

	ac.writeMu.Lock()
	sendErr := s.sendMessage(ac.conn, msg)
	ac.writeMu.Unlock()
	if sendErr != nil {
		return fmt.Errorf("failed to send deploy command: %w", sendErr)
	}
	return nil
}