From 30f01771accf8d01e9a484c8b2e4aeb0d5580b9e Mon Sep 17 00:00:00 2001 From: rcourtman Date: Wed, 17 Dec 2025 17:02:01 +0000 Subject: [PATCH] Add meaningful tests for host agent and exec websocket --- cmd/pulse-sensor-proxy/cleanup_test.go | 121 ++++++ internal/agentexec/server_websocket_test.go | 364 ++++++++++++++++++ internal/hostagent/agent.go | 27 +- internal/hostagent/agent_flushbuffer_test.go | 90 +++++ internal/hostagent/agent_new_test.go | 167 ++++++++ internal/hostagent/agent_sensors_test.go | 96 +++++ internal/hostagent/commands.go | 30 +- internal/hostagent/commands_connect_test.go | 132 +++++++ internal/hostagent/commands_execute_test.go | 151 ++++++++ .../hostagent/commands_registration_test.go | 180 +++++++++ 10 files changed, 1335 insertions(+), 23 deletions(-) create mode 100644 cmd/pulse-sensor-proxy/cleanup_test.go create mode 100644 internal/agentexec/server_websocket_test.go create mode 100644 internal/hostagent/agent_flushbuffer_test.go create mode 100644 internal/hostagent/agent_new_test.go create mode 100644 internal/hostagent/agent_sensors_test.go create mode 100644 internal/hostagent/commands_connect_test.go create mode 100644 internal/hostagent/commands_execute_test.go create mode 100644 internal/hostagent/commands_registration_test.go diff --git a/cmd/pulse-sensor-proxy/cleanup_test.go b/cmd/pulse-sensor-proxy/cleanup_test.go new file mode 100644 index 000000000..78b29fade --- /dev/null +++ b/cmd/pulse-sensor-proxy/cleanup_test.go @@ -0,0 +1,121 @@ +package main + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/rs/zerolog" +) + +func TestProxy_cleanupRequestPath_UsesConfiguredWorkDir(t *testing.T) { + p := &Proxy{workDir: "/tmp/pulse-sensor-proxy-test"} + got, err := p.cleanupRequestPath() + if err != nil { + t.Fatalf("cleanupRequestPath: %v", err) + } + want := filepath.Join(p.workDir, cleanupRequestFilename) + if got != want { + t.Fatalf("path = %q, want %q", got, want) 
+ } +} + +func TestProxy_handleRequestCleanup_WritesValidPayloadAndReplacesExisting(t *testing.T) { + workDir := t.TempDir() + p := &Proxy{workDir: workDir} + logger := zerolog.Nop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resp, err := p.handleRequestCleanup(ctx, &RPCRequest{ + Method: RPCRequestCleanup, + Params: map[string]interface{}{ + "host": "pve-1", + "reason": "testing", + }, + }, logger) + if err != nil { + t.Fatalf("handleRequestCleanup: %v", err) + } + respMap, ok := resp.(map[string]any) + if !ok { + t.Fatalf("expected map response, got %#v", resp) + } + queued, ok := respMap["queued"].(bool) + if !ok || !queued { + t.Fatalf("expected queued=true response, got %#v", resp) + } + + path, err := p.cleanupRequestPath() + if err != nil { + t.Fatalf("cleanupRequestPath: %v", err) + } + + readPayload := func() map[string]any { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile(%s): %v", path, err) + } + var payload map[string]any + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + return payload + } + + payload := readPayload() + if payload["host"] != "pve-1" { + t.Fatalf("payload host = %#v, want %q", payload["host"], "pve-1") + } + if payload["reason"] != "testing" { + t.Fatalf("payload reason = %#v, want %q", payload["reason"], "testing") + } + if _, ok := payload["requestedAt"].(string); !ok { + t.Fatalf("payload requestedAt missing or not string: %#v", payload["requestedAt"]) + } + if ts, ok := payload["requestedAt"].(string); ok { + if _, err := time.Parse(time.RFC3339, ts); err != nil { + t.Fatalf("requestedAt not RFC3339: %q: %v", ts, err) + } + } + + if runtime.GOOS != "windows" { + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat(%s): %v", path, err) + } + if fi.Mode().Perm() != 0o600 { + t.Fatalf("payload file mode = %v, want %v", fi.Mode().Perm(), os.FileMode(0o600)) + } + } + + resp, err = 
p.handleRequestCleanup(ctx, &RPCRequest{ + Method: RPCRequestCleanup, + Params: map[string]interface{}{ + "host": "pve-1", + "reason": "testing-2", + }, + }, logger) + if err != nil { + t.Fatalf("handleRequestCleanup (2): %v", err) + } + respMap, ok = resp.(map[string]any) + if !ok { + t.Fatalf("expected map response (2), got %#v", resp) + } + queued, ok = respMap["queued"].(bool) + if !ok || !queued { + t.Fatalf("expected queued=true response (2), got %#v", resp) + } + + payload2 := readPayload() + if payload2["reason"] != "testing-2" { + t.Fatalf("payload reason after replace = %#v, want %q", payload2["reason"], "testing-2") + } +} diff --git a/internal/agentexec/server_websocket_test.go b/internal/agentexec/server_websocket_test.go new file mode 100644 index 000000000..fb0f5a671 --- /dev/null +++ b/internal/agentexec/server_websocket_test.go @@ -0,0 +1,364 @@ +package agentexec + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type wsRawMessage struct { + Type MessageType `json:"type"` + ID string `json:"id,omitempty"` + Timestamp time.Time `json:"timestamp"` + Payload *json.RawMessage `json:"payload,omitempty"` +} + +func newWSServer(t *testing.T, s *Server) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.HandleWebSocket(w, r) + })) +} + +func wsURLForHTTP(serverURL string) string { + return "ws" + strings.TrimPrefix(serverURL, "http") +} + +func wsWriteMessage(t *testing.T, conn *websocket.Conn, msg Message) { + t.Helper() + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + if err := conn.WriteJSON(msg); err != nil { + t.Fatalf("WriteJSON: %v", err) + } +} + +func wsReadRawMessage(t *testing.T, conn *websocket.Conn) wsRawMessage { + t.Helper() + msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second) + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + return msg 
+} + +func wsReadRegisteredPayload(t *testing.T, conn *websocket.Conn) RegisteredPayload { + t.Helper() + msg := wsReadRawMessage(t, conn) + if msg.Type != MsgTypeRegistered { + t.Fatalf("message type = %q, want %q", msg.Type, MsgTypeRegistered) + } + if msg.Payload == nil { + t.Fatalf("registered payload missing") + } + var payload RegisteredPayload + if err := json.Unmarshal(*msg.Payload, &payload); err != nil { + t.Fatalf("unmarshal registered payload: %v", err) + } + return payload +} + +func wsReadRawMessageWithTimeout(conn *websocket.Conn, timeout time.Duration) (wsRawMessage, error) { + conn.SetReadDeadline(time.Now().Add(timeout)) + _, data, err := conn.ReadMessage() + if err != nil { + return wsRawMessage{}, err + } + var msg wsRawMessage + if err := json.Unmarshal(data, &msg); err != nil { + return wsRawMessage{}, err + } + return msg, nil +} + +func waitFor(t *testing.T, timeout time.Duration, cond func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("condition not met within %v", timeout) +} + +func TestHandleWebSocket_RegistrationSuccessAndDisconnectRemovesAgent(t *testing.T) { + s := NewServer(func(token string) bool { return token == "ok" }) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Version: "1.2.3", + Platform: "linux", + Tags: []string{"tag1"}, + Token: "ok", + }, + }) + + reg := wsReadRegisteredPayload(t, conn) + if !reg.Success { + t.Fatalf("registration failed: %q", reg.Message) + } + + if !s.IsAgentConnected("a1") { + t.Fatalf("expected agent to be connected") + } + + conn.Close() + + waitFor(t, 2*time.Second, func() bool { return 
!s.IsAgentConnected("a1") }) +} + +func TestHandleWebSocket_InvalidTokenRejected(t *testing.T) { + s := NewServer(func(string) bool { return false }) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Version: "1.2.3", + Platform: "linux", + Token: "bad", + }, + }) + + reg := wsReadRegisteredPayload(t, conn) + if reg.Success { + t.Fatalf("expected registration to be rejected") + } + + waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") }) + + conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + _, _, err = conn.ReadMessage() + if err == nil { + t.Fatalf("expected connection to be closed by server") + } +} + +func TestHandleWebSocket_FirstMessageMustBeRegister(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentPing, + Timestamp: time.Now(), + }) + + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, _, err = conn.ReadMessage() + if err == nil { + t.Fatalf("expected server to close connection") + } +} + +func TestHandleWebSocket_AgentPingRespondsWithPong(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Version: "1.2.3", + Platform: "linux", + Token: 
"any", + }, + }) + _ = wsReadRegisteredPayload(t, conn) + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentPing, + Timestamp: time.Now(), + }) + + msg := wsReadRawMessage(t, conn) + if msg.Type != MsgTypePong { + t.Fatalf("message type = %q, want %q", msg.Type, MsgTypePong) + } +} + +func TestExecuteCommand_RoundTripViaWebSocket(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + wsWriteMessage(t, conn, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Version: "1.2.3", + Platform: "linux", + Token: "any", + }, + }) + _ = wsReadRegisteredPayload(t, conn) + + agentDone := make(chan struct{}) + agentErr := make(chan error, 1) + go func() { + defer close(agentDone) + for { + msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second) + if err != nil { + agentErr <- err + return + } + if msg.Type != MsgTypeExecuteCmd { + continue + } + if msg.Payload == nil { + agentErr <- nil + return + } + var payload ExecuteCommandPayload + if err := json.Unmarshal(*msg.Payload, &payload); err != nil { + agentErr <- err + return + } + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + if err := conn.WriteJSON(Message{ + Type: MsgTypeCommandResult, + Timestamp: time.Now(), + Payload: CommandResultPayload{ + RequestID: payload.RequestID, + Success: true, + Stdout: "ok", + ExitCode: 0, + Duration: 1, + }, + }); err != nil { + agentErr <- err + return + } + agentErr <- nil + return + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, err := s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{ + RequestID: "req1", + Command: "echo ok", + Timeout: 1, + }) + if err != nil { + t.Fatalf("ExecuteCommand: %v", err) + } + if result == nil || !result.Success 
|| result.Stdout != "ok" || result.ExitCode != 0 { + t.Fatalf("unexpected result: %#v", result) + } + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatalf("agent goroutine did not finish") + } + + if err := <-agentErr; err != nil { + t.Fatalf("agent error: %v", err) + } +} + +func TestHandleWebSocket_ReconnectSameAgentIDClosesOldConnection(t *testing.T) { + s := NewServer(nil) + ts := newWSServer(t, s) + defer ts.Close() + + dial := func() *websocket.Conn { + t.Helper() + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + return conn + } + + c1 := dial() + defer c1.Close() + wsWriteMessage(t, c1, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Version: "1.2.3", + Platform: "linux", + Token: "any", + }, + }) + _ = wsReadRegisteredPayload(t, c1) + + c2 := dial() + defer c2.Close() + wsWriteMessage(t, c2, Message{ + Type: MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: AgentRegisterPayload{ + AgentID: "a1", + Hostname: "host1", + Version: "1.2.3", + Platform: "linux", + Token: "any", + }, + }) + _ = wsReadRegisteredPayload(t, c2) + + c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, _, err := c1.ReadMessage() + if err == nil { + t.Fatalf("expected old connection to be closed") + } +} diff --git a/internal/hostagent/agent.go b/internal/hostagent/agent.go index 1880fadd4..ea6053d7d 100644 --- a/internal/hostagent/agent.go +++ b/internal/hostagent/agent.go @@ -72,6 +72,17 @@ const defaultInterval = 30 * time.Second var readFile = os.ReadFile +var ( + hostInfoWithContext = gohost.InfoWithContext + hostUptimeWithContext = gohost.UptimeWithContext + hostmetricsCollect = hostmetrics.Collect + sensorsCollectLocal = sensors.CollectLocal + sensorsParse = sensors.Parse + mdadmCollectArrays = mdadm.CollectArrays + cephCollect = ceph.Collect + nowUTC = func() time.Time { 
return time.Now().UTC() } +) + // New constructs a fully initialised host Agent. func New(cfg Config) (*Agent, error) { if cfg.Interval <= 0 { @@ -107,7 +118,7 @@ func New(cfg Config) (*Agent, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - info, err := gohost.InfoWithContext(ctx) + info, err := hostInfoWithContext(ctx) if err != nil { return nil, fmt.Errorf("fetch host info: %w", err) } @@ -311,8 +322,8 @@ func (a *Agent) buildReport(ctx context.Context) (agentshost.Report, error) { collectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - uptime, _ := gohost.UptimeWithContext(collectCtx) - snapshot, err := hostmetrics.Collect(collectCtx) + uptime, _ := hostUptimeWithContext(collectCtx) + snapshot, err := hostmetricsCollect(collectCtx) if err != nil { return agentshost.Report{}, fmt.Errorf("collect metrics: %w", err) } @@ -361,7 +372,7 @@ func (a *Agent) buildReport(ctx context.Context) (agentshost.Report, error) { RAID: raidData, Ceph: cephData, Tags: append([]string(nil), a.cfg.Tags...), - Timestamp: time.Now().UTC(), + Timestamp: nowUTC(), } return report, nil @@ -416,14 +427,14 @@ func (a *Agent) collectTemperatures(ctx context.Context) agentshost.Sensors { } // Collect sensor JSON output - jsonOutput, err := sensors.CollectLocal(ctx) + jsonOutput, err := sensorsCollectLocal(ctx) if err != nil { a.logger.Debug().Err(err).Msg("Failed to collect sensor data (lm-sensors may not be installed)") return agentshost.Sensors{} } // Parse the sensor output - tempData, err := sensors.Parse(jsonOutput) + tempData, err := sensorsParse(jsonOutput) if err != nil { a.logger.Debug().Err(err).Msg("Failed to parse sensor data") return agentshost.Sensors{} @@ -476,7 +487,7 @@ func (a *Agent) collectRAIDArrays(ctx context.Context) []agentshost.RAIDArray { return nil } - arrays, err := mdadm.CollectArrays(ctx) + arrays, err := mdadmCollectArrays(ctx) if err != nil { a.logger.Debug().Err(err).Msg("Failed to 
collect RAID array data (mdadm may not be installed)") return nil @@ -499,7 +510,7 @@ func (a *Agent) collectCephStatus(ctx context.Context) *agentshost.CephCluster { return nil } - status, err := ceph.Collect(ctx) + status, err := cephCollect(ctx) if err != nil { a.logger.Debug().Err(err).Msg("Failed to collect Ceph status") return nil diff --git a/internal/hostagent/agent_flushbuffer_test.go b/internal/hostagent/agent_flushbuffer_test.go new file mode 100644 index 000000000..d574d0536 --- /dev/null +++ b/internal/hostagent/agent_flushbuffer_test.go @@ -0,0 +1,90 @@ +package hostagent + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/buffer" + agentshost "github.com/rcourtman/pulse-go-rewrite/pkg/agents/host" + "github.com/rs/zerolog" +) + +func TestAgent_flushBuffer_StopsOnFailureAndDoesNotDropReport(t *testing.T) { + var requestCount int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&requestCount, 1) + if n == 1 { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + a := &Agent{ + cfg: Config{APIToken: "token"}, + logger: zerolog.Nop(), + httpClient: server.Client(), + trimmedPulseURL: server.URL, + reportBuffer: buffer.New[agentshost.Report](10), + } + + report1 := agentshost.Report{Agent: agentshost.AgentInfo{ID: "r1"}} + report2 := agentshost.Report{Agent: agentshost.AgentInfo{ID: "r2"}} + a.reportBuffer.Push(report1) + a.reportBuffer.Push(report2) + + a.flushBuffer(context.Background()) + + if got := a.reportBuffer.Len(); got != 1 { + t.Fatalf("buffer len = %d, want %d", got, 1) + } + + peek, ok := a.reportBuffer.Peek() + if !ok { + t.Fatalf("expected buffered report") + } + if peek.Agent.ID != "r2" { + t.Fatalf("remaining buffered report = %q, want %q", peek.Agent.ID, "r2") + } +} + +func 
TestAgent_flushBuffer_RetryAfterTransientFailure(t *testing.T) { + var fail atomic.Bool + fail.Store(true) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if fail.Load() { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + a := &Agent{ + cfg: Config{APIToken: "token"}, + logger: zerolog.Nop(), + httpClient: server.Client(), + trimmedPulseURL: server.URL, + reportBuffer: buffer.New[agentshost.Report](10), + } + a.reportBuffer.Push(agentshost.Report{Agent: agentshost.AgentInfo{ID: "r1"}}) + a.reportBuffer.Push(agentshost.Report{Agent: agentshost.AgentInfo{ID: "r2"}}) + + a.flushBuffer(context.Background()) + if got := a.reportBuffer.Len(); got != 2 { + t.Fatalf("buffer len after failure = %d, want %d", got, 2) + } + + fail.Store(false) + a.flushBuffer(context.Background()) + if !a.reportBuffer.IsEmpty() { + t.Fatalf("expected buffer to be empty, has %d items", a.reportBuffer.Len()) + } +} + diff --git a/internal/hostagent/agent_new_test.go b/internal/hostagent/agent_new_test.go new file mode 100644 index 000000000..09a08bdcc --- /dev/null +++ b/internal/hostagent/agent_new_test.go @@ -0,0 +1,167 @@ +package hostagent + +import ( + "context" + "crypto/tls" + "net/http" + "os" + "runtime" + "testing" + + "github.com/rs/zerolog" + gohost "github.com/shirou/gopsutil/v4/host" +) + +func TestNew_RequiresAPIToken(t *testing.T) { + originalHostInfo := hostInfoWithContext + t.Cleanup(func() { hostInfoWithContext = originalHostInfo }) + hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) { + return &gohost.InfoStat{Hostname: "host", HostID: "hid", KernelArch: runtime.GOARCH}, nil + } + + _, err := New(Config{APIToken: " ", LogLevel: zerolog.InfoLevel}) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestNew_NormalizesConfigAndTags(t *testing.T) { + originalHostInfo := hostInfoWithContext + originalReadFile := readFile + 
t.Cleanup(func() { + hostInfoWithContext = originalHostInfo + readFile = originalReadFile + }) + + hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) { + return &gohost.InfoStat{ + Hostname: " host-from-info ", + HostID: " gopsutil-id ", + Platform: "Darwin", + PlatformFamily: "", + PlatformVersion: "14.0", + KernelVersion: "6.6.0", + KernelArch: runtime.GOARCH, + }, nil + } + + readFile = func(name string) ([]byte, error) { + if name == "/etc/machine-id" { + return []byte("0123456789abcdef0123456789abcdef\n"), nil + } + return nil, os.ErrNotExist + } + + originalTags := []string{" tag-a ", "tag-a", "", " tag-b", "tag-b "} + cfg := Config{ + PulseURL: "http://example.com///", + APIToken: "token", + Interval: 0, + Tags: originalTags, + InsecureSkipVerify: true, + LogLevel: zerolog.InfoLevel, + } + + agent, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + + if agent.interval != defaultInterval { + t.Fatalf("interval = %v, want %v", agent.interval, defaultInterval) + } + if agent.trimmedPulseURL != "http://example.com" { + t.Fatalf("trimmedPulseURL = %q, want %q", agent.trimmedPulseURL, "http://example.com") + } + if agent.hostname != "host-from-info" { + t.Fatalf("hostname = %q, want %q", agent.hostname, "host-from-info") + } + if agent.displayName != "host-from-info" { + t.Fatalf("displayName = %q, want %q", agent.displayName, "host-from-info") + } + if agent.platform != "macos" { + t.Fatalf("platform = %q, want %q", agent.platform, "macos") + } + if got, want := agent.cfg.Tags, []string{"tag-a", "tag-b"}; len(got) != len(want) || got[0] != want[0] || got[1] != want[1] { + t.Fatalf("tags = %#v, want %#v", got, want) + } + + // Ensure we don't retain aliasing to the caller-provided tags slice. 
+ originalTags[0] = "mutated" + if agent.cfg.Tags[0] != "tag-a" { + t.Fatalf("agent tags aliased caller slice: %#v", agent.cfg.Tags) + } + + httpTransport, ok := agent.httpClient.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", agent.httpClient.Transport) + } + if httpTransport.TLSClientConfig == nil || httpTransport.TLSClientConfig.MinVersion != tls.VersionTLS12 { + t.Fatalf("TLSClientConfig MinVersion = %#v, want TLS1.2", httpTransport.TLSClientConfig) + } + if !httpTransport.TLSClientConfig.InsecureSkipVerify { + t.Fatalf("expected InsecureSkipVerify=true") + } + + if runtime.GOOS == "linux" { + const want = "01234567-89ab-cdef-0123-456789abcdef" + if agent.machineID != want { + t.Fatalf("machineID = %q, want %q", agent.machineID, want) + } + if agent.agentID != want { + t.Fatalf("agentID = %q, want %q", agent.agentID, want) + } + } else { + if agent.machineID != "gopsutil-id" { + t.Fatalf("machineID = %q, want %q", agent.machineID, "gopsutil-id") + } + if agent.agentID != "gopsutil-id" { + t.Fatalf("agentID = %q, want %q", agent.agentID, "gopsutil-id") + } + } +} + +func TestNew_DefaultPulseURL(t *testing.T) { + originalHostInfo := hostInfoWithContext + t.Cleanup(func() { hostInfoWithContext = originalHostInfo }) + hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) { + return &gohost.InfoStat{Hostname: "host", HostID: "hid", KernelArch: runtime.GOARCH}, nil + } + + agent, err := New(Config{PulseURL: "", APIToken: "token", LogLevel: zerolog.InfoLevel}) + if err != nil { + t.Fatalf("New: %v", err) + } + if agent.trimmedPulseURL != "http://localhost:7655" { + t.Fatalf("trimmedPulseURL = %q, want %q", agent.trimmedPulseURL, "http://localhost:7655") + } +} + +func TestNew_FallsBackToHostnameWhenMachineIDEmpty(t *testing.T) { + originalHostInfo := hostInfoWithContext + originalReadFile := readFile + t.Cleanup(func() { + hostInfoWithContext = originalHostInfo + readFile = originalReadFile + }) + + 
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) { + return &gohost.InfoStat{ + Hostname: "host-from-info", + HostID: "", + KernelArch: runtime.GOARCH, + }, nil + } + readFile = func(string) ([]byte, error) { return nil, os.ErrNotExist } + + agent, err := New(Config{APIToken: "token", LogLevel: zerolog.InfoLevel}) + if err != nil { + t.Fatalf("New: %v", err) + } + if agent.machineID != "" { + t.Fatalf("expected empty machineID, got %q", agent.machineID) + } + if agent.agentID != "host-from-info" { + t.Fatalf("agentID = %q, want %q", agent.agentID, "host-from-info") + } +} diff --git a/internal/hostagent/agent_sensors_test.go b/internal/hostagent/agent_sensors_test.go new file mode 100644 index 000000000..dd3f0ff8b --- /dev/null +++ b/internal/hostagent/agent_sensors_test.go @@ -0,0 +1,96 @@ +package hostagent + +import ( + "context" + "errors" + "runtime" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/sensors" + "github.com/rs/zerolog" +) + +func TestAgent_collectTemperatures_MapsKeys(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("temperature collection currently only runs on linux") + } + + originalCollect := sensorsCollectLocal + originalParse := sensorsParse + t.Cleanup(func() { + sensorsCollectLocal = originalCollect + sensorsParse = originalParse + }) + + sensorsCollectLocal = func(context.Context) (string, error) { return "{}", nil } + sensorsParse = func(string) (*sensors.TemperatureData, error) { + return &sensors.TemperatureData{ + Available: true, + CPUPackage: 55.5, + Cores: map[string]float64{ + "Core 0": 44, + "Core 1": 45, + }, + NVMe: map[string]float64{ + "nvme0": 40, + }, + GPU: map[string]float64{ + "amdgpu-pci-0100": 60, + }, + }, nil + } + + a := &Agent{logger: zerolog.Nop()} + + got := a.collectTemperatures(context.Background()) + want := map[string]float64{ + "cpu_package": 55.5, + "cpu_core_0": 44, + "cpu_core_1": 45, + "nvme0": 40, + "amdgpu-pci-0100": 60, + } + + if got.TemperatureCelsius == 
nil { + t.Fatalf("expected TemperatureCelsius map to be initialised") + } + if len(got.TemperatureCelsius) != len(want) { + t.Fatalf("temperature keys = %d, want %d", len(got.TemperatureCelsius), len(want)) + } + for k, v := range want { + if gotVal, ok := got.TemperatureCelsius[k]; !ok || gotVal != v { + t.Fatalf("TemperatureCelsius[%q] = (%v, %v), want (%v, %v)", k, gotVal, ok, v, true) + } + } +} + +func TestAgent_collectTemperatures_BestEffortFailuresReturnEmpty(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("temperature collection currently only runs on linux") + } + + originalCollect := sensorsCollectLocal + originalParse := sensorsParse + t.Cleanup(func() { + sensorsCollectLocal = originalCollect + sensorsParse = originalParse + }) + + a := &Agent{logger: zerolog.Nop()} + + sensorsCollectLocal = func(context.Context) (string, error) { return "", errors.New("no sensors") } + if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 { + t.Fatalf("expected empty sensors on collect error, got %#v", got.TemperatureCelsius) + } + + sensorsCollectLocal = func(context.Context) (string, error) { return "{}", nil } + sensorsParse = func(string) (*sensors.TemperatureData, error) { return nil, errors.New("bad json") } + if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 { + t.Fatalf("expected empty sensors on parse error, got %#v", got.TemperatureCelsius) + } + + sensorsParse = func(string) (*sensors.TemperatureData, error) { return &sensors.TemperatureData{Available: false}, nil } + if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 { + t.Fatalf("expected empty sensors when unavailable, got %#v", got.TemperatureCelsius) + } +} diff --git a/internal/hostagent/commands.go b/internal/hostagent/commands.go index 53e82d6f3..39fb0bf15 100644 --- a/internal/hostagent/commands.go +++ b/internal/hostagent/commands.go @@ -361,6 +361,16 @@ func (c *CommandClient) 
handleExecuteCommand(ctx context.Context, conn *websocke } } +func wrapCommand(payload executeCommandPayload) string { + if payload.TargetType == "container" && payload.TargetID != "" { + return fmt.Sprintf("pct exec %s -- %s", payload.TargetID, payload.Command) + } + if payload.TargetType == "vm" && payload.TargetID != "" { + return fmt.Sprintf("qm guest exec %s -- %s", payload.TargetID, payload.Command) + } + return payload.Command +} + func (c *CommandClient) executeCommand(ctx context.Context, payload executeCommandPayload) commandResultPayload { result := commandResultPayload{ RequestID: payload.RequestID, @@ -375,17 +385,7 @@ func (c *CommandClient) executeCommand(ctx context.Context, payload executeComma cmdCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - // Build the command based on target type - command := payload.Command - - // If targeting a container or VM, wrap the command - if payload.TargetType == "container" && payload.TargetID != "" { - // Use pct exec for LXC containers - command = fmt.Sprintf("pct exec %s -- %s", payload.TargetID, payload.Command) - } else if payload.TargetType == "vm" && payload.TargetID != "" { - // Use qm guest exec for VMs (requires QEMU guest agent) - command = fmt.Sprintf("qm guest exec %s -- %s", payload.TargetID, payload.Command) - } + command := wrapCommand(payload) // Execute the command var cmd *exec.Cmd @@ -405,13 +405,13 @@ func (c *CommandClient) executeCommand(ctx context.Context, payload executeComma result.Stderr = stderr.String() if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - result.ExitCode = exitErr.ExitCode() - result.Success = false - } else if cmdCtx.Err() == context.DeadlineExceeded { + if cmdCtx.Err() == context.DeadlineExceeded { result.Error = "command timed out" result.ExitCode = -1 result.Success = false + } else if exitErr, ok := err.(*exec.ExitError); ok { + result.ExitCode = exitErr.ExitCode() + result.Success = false } else { result.Error = err.Error() 
result.ExitCode = -1 diff --git a/internal/hostagent/commands_connect_test.go b/internal/hostagent/commands_connect_test.go new file mode 100644 index 000000000..88cb2f3f2 --- /dev/null +++ b/internal/hostagent/commands_connect_test.go @@ -0,0 +1,132 @@ +package hostagent + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" +) + +func TestCommandClient_connectAndHandle_ExecutesCommandAndReturnsResult(t *testing.T) { + if testing.Short() { + t.Skip("skipping websocket integration test in short mode") + } + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + serverDone := make(chan struct{}) + gotResult := make(chan commandResultPayload, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + defer conn.Close() + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + + var regMsg wsMessage + if err := conn.ReadJSON(&regMsg); err != nil { + t.Errorf("read registration: %v", err) + return + } + if regMsg.Type != msgTypeAgentRegister { + t.Errorf("registration type = %q, want %q", regMsg.Type, msgTypeAgentRegister) + return + } + + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + registeredBytes, _ := json.Marshal(registeredPayload{Success: true, Message: "Registered"}) + if err := conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: registeredBytes}); err != nil { + t.Errorf("write registered: %v", err) + return + } + + execPayloadBytes, _ := json.Marshal(executeCommandPayload{ + RequestID: "req-1", + Command: "echo hello", + TargetType: "host", + Timeout: 5, + }) + if err := conn.WriteJSON(wsMessage{Type: msgTypeExecuteCmd, ID: "req-1", Timestamp: time.Now(), Payload: execPayloadBytes}); err != nil { + t.Errorf("write 
execute_command: %v", err) + return + } + + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + var resultMsg wsMessage + if err := conn.ReadJSON(&resultMsg); err != nil { + t.Errorf("read command_result: %v", err) + return + } + if resultMsg.Type != msgTypeCommandResult { + t.Errorf("result type = %q, want %q", resultMsg.Type, msgTypeCommandResult) + return + } + + var result commandResultPayload + if err := json.Unmarshal(resultMsg.Payload, &result); err != nil { + t.Errorf("unmarshal command_result payload: %v", err) + return + } + gotResult <- result + + <-serverDone + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := &CommandClient{ + pulseURL: strings.TrimRight(server.URL, "/"), + apiToken: "token", + agentID: "agent-1", + hostname: "host-1", + platform: "linux", + version: "1.2.3", + logger: zerolog.Nop(), + done: make(chan struct{}), + } + + errCh := make(chan error, 1) + go func() { + errCh <- client.connectAndHandle(ctx) + }() + + select { + case result := <-gotResult: + if result.RequestID != "req-1" { + t.Fatalf("result.RequestID = %q, want %q", result.RequestID, "req-1") + } + if !result.Success || result.ExitCode != 0 { + t.Fatalf("unexpected result: %#v", result) + } + if !strings.Contains(result.Stdout, "hello") { + t.Fatalf("stdout = %q, expected to contain %q", result.Stdout, "hello") + } + + cancel() + close(serverDone) + case <-time.After(10 * time.Second): + t.Fatalf("timed out waiting for command result") + } + + select { + case err := <-errCh: + if err == nil { + t.Fatalf("expected error due to context cancellation") + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for connectAndHandle to return") + } +} + diff --git a/internal/hostagent/commands_execute_test.go b/internal/hostagent/commands_execute_test.go new file mode 100644 index 000000000..48b5c7c33 --- /dev/null +++ b/internal/hostagent/commands_execute_test.go @@ -0,0 +1,151 @@ +package 
hostagent + +import ( + "context" + "runtime" + "strings" + "testing" + "time" +) + +func TestWrapCommand_TargetWrapping(t *testing.T) { /* Table-driven check of wrapCommand: a "host" command is returned unchanged, "container" and "vm" targets with a TargetID are prefixed with "pct exec <id> -- " and "qm guest exec <id> -- ", and an empty TargetID leaves the command unwrapped. */ + tests := []struct { + name string + payload executeCommandPayload + want string + }{ + { + name: "host command unchanged", + payload: executeCommandPayload{ + Command: "echo ok", + TargetType: "host", + }, + want: "echo ok", + }, + { + name: "container wraps with pct", + payload: executeCommandPayload{ + Command: "echo ok", + TargetType: "container", + TargetID: "101", + }, + want: "pct exec 101 -- echo ok", + }, + { + name: "vm wraps with qm guest exec", + payload: executeCommandPayload{ + Command: "echo ok", + TargetType: "vm", + TargetID: "900", + }, + want: "qm guest exec 900 -- echo ok", + }, + { + name: "missing target id does not wrap", + payload: executeCommandPayload{ + Command: "echo ok", + TargetType: "container", + TargetID: "", + }, + want: "echo ok", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapCommand(tt.payload); got != tt.want { + t.Fatalf("wrapCommand() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestCommandClient_executeCommand_Success(t *testing.T) { /* Runs "echo hello" through executeCommand on a zero-value CommandClient and expects Success, exit code 0 and "hello" on stdout; skipped on windows. */ + if runtime.GOOS == "windows" { + t.Skip("executeCommand uses different shell on windows") + } + + c := &CommandClient{} + result := c.executeCommand(context.Background(), executeCommandPayload{ + RequestID: "r1", + Command: "echo hello", + Timeout: 5, + }) + if !result.Success || result.ExitCode != 0 { + t.Fatalf("expected success, got %#v", result) + } + if !strings.Contains(result.Stdout, "hello") { + t.Fatalf("stdout = %q, expected to contain %q", result.Stdout, "hello") + } +} + +func TestCommandClient_executeCommand_NonZeroExitCode(t *testing.T) { /* A command that writes to stderr and exits 3 must yield Success=false, ExitCode 3 and the stderr text; skipped on windows. */ + if runtime.GOOS == "windows" { + t.Skip("executeCommand uses different shell on windows") + } + + c := &CommandClient{} + result := c.executeCommand(context.Background(), executeCommandPayload{ + RequestID: "r1", + Command: "echo err 1>&2; exit 3", + 
Timeout: 5, + }) + if result.Success { + t.Fatalf("expected failure, got %#v", result) + } + if result.ExitCode != 3 { + t.Fatalf("exit code = %d, want %d", result.ExitCode, 3) + } + if !strings.Contains(result.Stderr, "err") { + t.Fatalf("stderr = %q, expected to contain %q", result.Stderr, "err") + } +} + +func TestCommandClient_executeCommand_TimeoutSetsError(t *testing.T) { /* "sleep 2" with Timeout: 1 (presumably seconds - confirm against executeCommand) must return within 3s with Success=false, ExitCode -1 and Error "command timed out"; skipped on windows. */ + if runtime.GOOS == "windows" { + t.Skip("executeCommand uses different shell on windows") + } + + c := &CommandClient{} + start := time.Now() + result := c.executeCommand(context.Background(), executeCommandPayload{ + RequestID: "r1", + Command: "sleep 2", + Timeout: 1, + }) + if time.Since(start) > 3*time.Second { + t.Fatalf("timeout path took too long: %v", time.Since(start)) + } + if result.Success { + t.Fatalf("expected failure, got %#v", result) + } + if result.ExitCode != -1 { + t.Fatalf("exit code = %d, want %d", result.ExitCode, -1) + } + if result.Error != "command timed out" { + t.Fatalf("error = %q, want %q", result.Error, "command timed out") + } +} + +func TestCommandClient_executeCommand_TruncatesLargeOutput(t *testing.T) { /* Emits 1048580 bytes of 'a' (4 more than 1024*1024) and expects a successful result whose stdout carries the "(output truncated)" marker and is no longer than 1024*1024+64 bytes; skipped on windows. */ + if runtime.GOOS == "windows" { + t.Skip("executeCommand uses different shell on windows") + } + + c := &CommandClient{} + result := c.executeCommand(context.Background(), executeCommandPayload{ + RequestID: "r1", + Command: "head -c 1048580 /dev/zero | tr '\\0' 'a'", + Timeout: 5, + }) + if !result.Success { + t.Fatalf("expected success, got %#v", result) + } + if !strings.Contains(result.Stdout, "(output truncated)") { + t.Fatalf("expected truncation marker, got stdout len=%d", len(result.Stdout)) + } + if len(result.Stdout) > 1024*1024+64 { + t.Fatalf("stdout len=%d, expected <= %d", len(result.Stdout), 1024*1024+64) + } +} + diff --git a/internal/hostagent/commands_registration_test.go b/internal/hostagent/commands_registration_test.go new file mode 100644 index 000000000..d075342d6 --- /dev/null +++ b/internal/hostagent/commands_registration_test.go 
@@ -0,0 +1,180 @@ +package hostagent + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" +) + +func wsURLForHTTP(serverURL string) string { /* Replaces the leading "http" of serverURL with "ws", so "http://..." becomes "ws://..." and "https://..." becomes "wss://...". */ + return "ws" + strings.TrimPrefix(serverURL, "http") +} + +func TestCommandClient_sendRegistration_WritesExpectedPayload(t *testing.T) { /* Dials a test websocket server, calls sendRegistration, and checks the message type, the non-zero timestamp and every registerPayload field the server decoded. */ + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + gotMsgCh := make(chan wsMessage, 1) + gotPayloadCh := make(chan registerPayload, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + defer conn.Close() + + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + var msg wsMessage + if err := conn.ReadJSON(&msg); err != nil { + t.Errorf("ReadJSON: %v", err) + return + } + + var payload registerPayload + if err := json.Unmarshal(msg.Payload, &payload); err != nil { + t.Errorf("unmarshal payload: %v", err) + return + } + + gotMsgCh <- msg + gotPayloadCh <- payload + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + client := &CommandClient{ + apiToken: "token-1", + agentID: "agent-1", + hostname: "host-1", + platform: "linux", + version: "1.2.3", + logger: zerolog.Nop(), + } + + if err := client.sendRegistration(conn); err != nil { + t.Fatalf("sendRegistration: %v", err) + } + + msg := <-gotMsgCh + if msg.Type != msgTypeAgentRegister { + t.Fatalf("msg.Type = %q, want %q", msg.Type, msgTypeAgentRegister) + } + if msg.Timestamp.IsZero() { + t.Fatalf("expected non-zero Timestamp") + } + + payload := <-gotPayloadCh + if payload.AgentID != "agent-1" { + t.Fatalf("payload.AgentID = %q, want %q", payload.AgentID, "agent-1") + } + if payload.Hostname != "host-1" { + 
t.Fatalf("payload.Hostname = %q, want %q", payload.Hostname, "host-1") + } + if payload.Platform != "linux" { + t.Fatalf("payload.Platform = %q, want %q", payload.Platform, "linux") + } + if payload.Version != "1.2.3" { + t.Fatalf("payload.Version = %q, want %q", payload.Version, "1.2.3") + } + if payload.Token != "token-1" { + t.Fatalf("payload.Token = %q, want %q", payload.Token, "token-1") + } +} + +func TestCommandClient_waitForRegistration_AcceptsSuccess(t *testing.T) { /* The server replies with a msgTypeRegistered message carrying registeredPayload{Success: true}; waitForRegistration must return nil. */ + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + defer conn.Close() + + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + payload, _ := json.Marshal(registeredPayload{Success: true, Message: "Registered"}) + _ = conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: payload}) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + client := &CommandClient{logger: zerolog.Nop()} + if err := client.waitForRegistration(conn); err != nil { + t.Fatalf("waitForRegistration: %v", err) + } +} + +func TestCommandClient_waitForRegistration_RejectsFailure(t *testing.T) { /* The server replies with registeredPayload{Success: false, Message: "Invalid token"}; waitForRegistration must return an error. */ + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + defer conn.Close() + + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + payload, _ := json.Marshal(registeredPayload{Success: false, Message: "Invalid token"}) + _ = conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: 
payload}) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + client := &CommandClient{logger: zerolog.Nop()} + if err := client.waitForRegistration(conn); err == nil { + t.Fatalf("expected error") + } +} + +func TestCommandClient_waitForRegistration_UnexpectedMessageType(t *testing.T) { /* The server sends a msgTypePong message instead of msgTypeRegistered; waitForRegistration must return an error. */ + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade: %v", err) + return + } + defer conn.Close() + + conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + _ = conn.WriteJSON(wsMessage{Type: msgTypePong, Timestamp: time.Now()}) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + client := &CommandClient{logger: zerolog.Nop()} + if err := client.waitForRegistration(conn); err == nil { + t.Fatalf("expected error") + } +}