Add meaningful tests for host agent and exec websocket

This commit is contained in:
rcourtman 2025-12-17 17:02:01 +00:00
parent ab480ca489
commit 30f01771ac
10 changed files with 1335 additions and 23 deletions

View file

@ -0,0 +1,121 @@
package main
import (
"context"
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/rs/zerolog"
)
func TestProxy_cleanupRequestPath_UsesConfiguredWorkDir(t *testing.T) {
p := &Proxy{workDir: "/tmp/pulse-sensor-proxy-test"}
got, err := p.cleanupRequestPath()
if err != nil {
t.Fatalf("cleanupRequestPath: %v", err)
}
want := filepath.Join(p.workDir, cleanupRequestFilename)
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
}
func TestProxy_handleRequestCleanup_WritesValidPayloadAndReplacesExisting(t *testing.T) {
workDir := t.TempDir()
p := &Proxy{workDir: workDir}
logger := zerolog.Nop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err := p.handleRequestCleanup(ctx, &RPCRequest{
Method: RPCRequestCleanup,
Params: map[string]interface{}{
"host": "pve-1",
"reason": "testing",
},
}, logger)
if err != nil {
t.Fatalf("handleRequestCleanup: %v", err)
}
respMap, ok := resp.(map[string]any)
if !ok {
t.Fatalf("expected map response, got %#v", resp)
}
queued, ok := respMap["queued"].(bool)
if !ok || !queued {
t.Fatalf("expected queued=true response, got %#v", resp)
}
path, err := p.cleanupRequestPath()
if err != nil {
t.Fatalf("cleanupRequestPath: %v", err)
}
readPayload := func() map[string]any {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile(%s): %v", path, err)
}
var payload map[string]any
if err := json.Unmarshal(data, &payload); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
return payload
}
payload := readPayload()
if payload["host"] != "pve-1" {
t.Fatalf("payload host = %#v, want %q", payload["host"], "pve-1")
}
if payload["reason"] != "testing" {
t.Fatalf("payload reason = %#v, want %q", payload["reason"], "testing")
}
if _, ok := payload["requestedAt"].(string); !ok {
t.Fatalf("payload requestedAt missing or not string: %#v", payload["requestedAt"])
}
if ts, ok := payload["requestedAt"].(string); ok {
if _, err := time.Parse(time.RFC3339, ts); err != nil {
t.Fatalf("requestedAt not RFC3339: %q: %v", ts, err)
}
}
if runtime.GOOS != "windows" {
fi, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat(%s): %v", path, err)
}
if fi.Mode().Perm() != 0o600 {
t.Fatalf("payload file mode = %v, want %v", fi.Mode().Perm(), os.FileMode(0o600))
}
}
resp, err = p.handleRequestCleanup(ctx, &RPCRequest{
Method: RPCRequestCleanup,
Params: map[string]interface{}{
"host": "pve-1",
"reason": "testing-2",
},
}, logger)
if err != nil {
t.Fatalf("handleRequestCleanup (2): %v", err)
}
respMap, ok = resp.(map[string]any)
if !ok {
t.Fatalf("expected map response (2), got %#v", resp)
}
queued, ok = respMap["queued"].(bool)
if !ok || !queued {
t.Fatalf("expected queued=true response (2), got %#v", resp)
}
payload2 := readPayload()
if payload2["reason"] != "testing-2" {
t.Fatalf("payload reason after replace = %#v, want %q", payload2["reason"], "testing-2")
}
}

View file

@ -0,0 +1,364 @@
package agentexec
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
type wsRawMessage struct {
Type MessageType `json:"type"`
ID string `json:"id,omitempty"`
Timestamp time.Time `json:"timestamp"`
Payload *json.RawMessage `json:"payload,omitempty"`
}
func newWSServer(t *testing.T, s *Server) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.HandleWebSocket(w, r)
}))
}
func wsURLForHTTP(serverURL string) string {
return "ws" + strings.TrimPrefix(serverURL, "http")
}
func wsWriteMessage(t *testing.T, conn *websocket.Conn, msg Message) {
t.Helper()
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
if err := conn.WriteJSON(msg); err != nil {
t.Fatalf("WriteJSON: %v", err)
}
}
func wsReadRawMessage(t *testing.T, conn *websocket.Conn) wsRawMessage {
t.Helper()
msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
if err != nil {
t.Fatalf("ReadMessage: %v", err)
}
return msg
}
func wsReadRegisteredPayload(t *testing.T, conn *websocket.Conn) RegisteredPayload {
t.Helper()
msg := wsReadRawMessage(t, conn)
if msg.Type != MsgTypeRegistered {
t.Fatalf("message type = %q, want %q", msg.Type, MsgTypeRegistered)
}
if msg.Payload == nil {
t.Fatalf("registered payload missing")
}
var payload RegisteredPayload
if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
t.Fatalf("unmarshal registered payload: %v", err)
}
return payload
}
func wsReadRawMessageWithTimeout(conn *websocket.Conn, timeout time.Duration) (wsRawMessage, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
_, data, err := conn.ReadMessage()
if err != nil {
return wsRawMessage{}, err
}
var msg wsRawMessage
if err := json.Unmarshal(data, &msg); err != nil {
return wsRawMessage{}, err
}
return msg, nil
}
func waitFor(t *testing.T, timeout time.Duration, cond func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("condition not met within %v", timeout)
}
func TestHandleWebSocket_RegistrationSuccessAndDisconnectRemovesAgent(t *testing.T) {
s := NewServer(func(token string) bool { return token == "ok" })
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Version: "1.2.3",
Platform: "linux",
Tags: []string{"tag1"},
Token: "ok",
},
})
reg := wsReadRegisteredPayload(t, conn)
if !reg.Success {
t.Fatalf("registration failed: %q", reg.Message)
}
if !s.IsAgentConnected("a1") {
t.Fatalf("expected agent to be connected")
}
conn.Close()
waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
}
func TestHandleWebSocket_InvalidTokenRejected(t *testing.T) {
s := NewServer(func(string) bool { return false })
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Version: "1.2.3",
Platform: "linux",
Token: "bad",
},
})
reg := wsReadRegisteredPayload(t, conn)
if reg.Success {
t.Fatalf("expected registration to be rejected")
}
waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
_, _, err = conn.ReadMessage()
if err == nil {
t.Fatalf("expected connection to be closed by server")
}
}
func TestHandleWebSocket_FirstMessageMustBeRegister(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentPing,
Timestamp: time.Now(),
})
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
_, _, err = conn.ReadMessage()
if err == nil {
t.Fatalf("expected server to close connection")
}
}
func TestHandleWebSocket_AgentPingRespondsWithPong(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Version: "1.2.3",
Platform: "linux",
Token: "any",
},
})
_ = wsReadRegisteredPayload(t, conn)
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentPing,
Timestamp: time.Now(),
})
msg := wsReadRawMessage(t, conn)
if msg.Type != MsgTypePong {
t.Fatalf("message type = %q, want %q", msg.Type, MsgTypePong)
}
}
func TestExecuteCommand_RoundTripViaWebSocket(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
wsWriteMessage(t, conn, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Version: "1.2.3",
Platform: "linux",
Token: "any",
},
})
_ = wsReadRegisteredPayload(t, conn)
agentDone := make(chan struct{})
agentErr := make(chan error, 1)
go func() {
defer close(agentDone)
for {
msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
if err != nil {
agentErr <- err
return
}
if msg.Type != MsgTypeExecuteCmd {
continue
}
if msg.Payload == nil {
agentErr <- nil
return
}
var payload ExecuteCommandPayload
if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
agentErr <- err
return
}
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
if err := conn.WriteJSON(Message{
Type: MsgTypeCommandResult,
Timestamp: time.Now(),
Payload: CommandResultPayload{
RequestID: payload.RequestID,
Success: true,
Stdout: "ok",
ExitCode: 0,
Duration: 1,
},
}); err != nil {
agentErr <- err
return
}
agentErr <- nil
return
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, err := s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{
RequestID: "req1",
Command: "echo ok",
Timeout: 1,
})
if err != nil {
t.Fatalf("ExecuteCommand: %v", err)
}
if result == nil || !result.Success || result.Stdout != "ok" || result.ExitCode != 0 {
t.Fatalf("unexpected result: %#v", result)
}
select {
case <-agentDone:
case <-time.After(2 * time.Second):
t.Fatalf("agent goroutine did not finish")
}
if err := <-agentErr; err != nil {
t.Fatalf("agent error: %v", err)
}
}
func TestHandleWebSocket_ReconnectSameAgentIDClosesOldConnection(t *testing.T) {
s := NewServer(nil)
ts := newWSServer(t, s)
defer ts.Close()
dial := func() *websocket.Conn {
t.Helper()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
return conn
}
c1 := dial()
defer c1.Close()
wsWriteMessage(t, c1, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Version: "1.2.3",
Platform: "linux",
Token: "any",
},
})
_ = wsReadRegisteredPayload(t, c1)
c2 := dial()
defer c2.Close()
wsWriteMessage(t, c2, Message{
Type: MsgTypeAgentRegister,
Timestamp: time.Now(),
Payload: AgentRegisterPayload{
AgentID: "a1",
Hostname: "host1",
Version: "1.2.3",
Platform: "linux",
Token: "any",
},
})
_ = wsReadRegisteredPayload(t, c2)
c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
_, _, err := c1.ReadMessage()
if err == nil {
t.Fatalf("expected old connection to be closed")
}
}

View file

@ -72,6 +72,17 @@ const defaultInterval = 30 * time.Second
var readFile = os.ReadFile
var (
hostInfoWithContext = gohost.InfoWithContext
hostUptimeWithContext = gohost.UptimeWithContext
hostmetricsCollect = hostmetrics.Collect
sensorsCollectLocal = sensors.CollectLocal
sensorsParse = sensors.Parse
mdadmCollectArrays = mdadm.CollectArrays
cephCollect = ceph.Collect
nowUTC = func() time.Time { return time.Now().UTC() }
)
// New constructs a fully initialised host Agent.
func New(cfg Config) (*Agent, error) {
if cfg.Interval <= 0 {
@ -107,7 +118,7 @@ func New(cfg Config) (*Agent, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
info, err := gohost.InfoWithContext(ctx)
info, err := hostInfoWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("fetch host info: %w", err)
}
@ -311,8 +322,8 @@ func (a *Agent) buildReport(ctx context.Context) (agentshost.Report, error) {
collectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
uptime, _ := gohost.UptimeWithContext(collectCtx)
snapshot, err := hostmetrics.Collect(collectCtx)
uptime, _ := hostUptimeWithContext(collectCtx)
snapshot, err := hostmetricsCollect(collectCtx)
if err != nil {
return agentshost.Report{}, fmt.Errorf("collect metrics: %w", err)
}
@ -361,7 +372,7 @@ func (a *Agent) buildReport(ctx context.Context) (agentshost.Report, error) {
RAID: raidData,
Ceph: cephData,
Tags: append([]string(nil), a.cfg.Tags...),
Timestamp: time.Now().UTC(),
Timestamp: nowUTC(),
}
return report, nil
@ -416,14 +427,14 @@ func (a *Agent) collectTemperatures(ctx context.Context) agentshost.Sensors {
}
// Collect sensor JSON output
jsonOutput, err := sensors.CollectLocal(ctx)
jsonOutput, err := sensorsCollectLocal(ctx)
if err != nil {
a.logger.Debug().Err(err).Msg("Failed to collect sensor data (lm-sensors may not be installed)")
return agentshost.Sensors{}
}
// Parse the sensor output
tempData, err := sensors.Parse(jsonOutput)
tempData, err := sensorsParse(jsonOutput)
if err != nil {
a.logger.Debug().Err(err).Msg("Failed to parse sensor data")
return agentshost.Sensors{}
@ -476,7 +487,7 @@ func (a *Agent) collectRAIDArrays(ctx context.Context) []agentshost.RAIDArray {
return nil
}
arrays, err := mdadm.CollectArrays(ctx)
arrays, err := mdadmCollectArrays(ctx)
if err != nil {
a.logger.Debug().Err(err).Msg("Failed to collect RAID array data (mdadm may not be installed)")
return nil
@ -499,7 +510,7 @@ func (a *Agent) collectCephStatus(ctx context.Context) *agentshost.CephCluster {
return nil
}
status, err := ceph.Collect(ctx)
status, err := cephCollect(ctx)
if err != nil {
a.logger.Debug().Err(err).Msg("Failed to collect Ceph status")
return nil

View file

@ -0,0 +1,90 @@
package hostagent
import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/buffer"
agentshost "github.com/rcourtman/pulse-go-rewrite/pkg/agents/host"
"github.com/rs/zerolog"
)
func TestAgent_flushBuffer_StopsOnFailureAndDoesNotDropReport(t *testing.T) {
var requestCount int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&requestCount, 1)
if n == 1 {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
a := &Agent{
cfg: Config{APIToken: "token"},
logger: zerolog.Nop(),
httpClient: server.Client(),
trimmedPulseURL: server.URL,
reportBuffer: buffer.New[agentshost.Report](10),
}
report1 := agentshost.Report{Agent: agentshost.AgentInfo{ID: "r1"}}
report2 := agentshost.Report{Agent: agentshost.AgentInfo{ID: "r2"}}
a.reportBuffer.Push(report1)
a.reportBuffer.Push(report2)
a.flushBuffer(context.Background())
if got := a.reportBuffer.Len(); got != 1 {
t.Fatalf("buffer len = %d, want %d", got, 1)
}
peek, ok := a.reportBuffer.Peek()
if !ok {
t.Fatalf("expected buffered report")
}
if peek.Agent.ID != "r2" {
t.Fatalf("remaining buffered report = %q, want %q", peek.Agent.ID, "r2")
}
}
func TestAgent_flushBuffer_RetryAfterTransientFailure(t *testing.T) {
var fail atomic.Bool
fail.Store(true)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if fail.Load() {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
a := &Agent{
cfg: Config{APIToken: "token"},
logger: zerolog.Nop(),
httpClient: server.Client(),
trimmedPulseURL: server.URL,
reportBuffer: buffer.New[agentshost.Report](10),
}
a.reportBuffer.Push(agentshost.Report{Agent: agentshost.AgentInfo{ID: "r1"}})
a.reportBuffer.Push(agentshost.Report{Agent: agentshost.AgentInfo{ID: "r2"}})
a.flushBuffer(context.Background())
if got := a.reportBuffer.Len(); got != 2 {
t.Fatalf("buffer len after failure = %d, want %d", got, 2)
}
fail.Store(false)
a.flushBuffer(context.Background())
if !a.reportBuffer.IsEmpty() {
t.Fatalf("expected buffer to be empty, has %d items", a.reportBuffer.Len())
}
}

View file

@ -0,0 +1,167 @@
package hostagent
import (
"context"
"crypto/tls"
"net/http"
"os"
"runtime"
"testing"
"github.com/rs/zerolog"
gohost "github.com/shirou/gopsutil/v4/host"
)
func TestNew_RequiresAPIToken(t *testing.T) {
originalHostInfo := hostInfoWithContext
t.Cleanup(func() { hostInfoWithContext = originalHostInfo })
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
return &gohost.InfoStat{Hostname: "host", HostID: "hid", KernelArch: runtime.GOARCH}, nil
}
_, err := New(Config{APIToken: " ", LogLevel: zerolog.InfoLevel})
if err == nil {
t.Fatalf("expected error")
}
}
func TestNew_NormalizesConfigAndTags(t *testing.T) {
originalHostInfo := hostInfoWithContext
originalReadFile := readFile
t.Cleanup(func() {
hostInfoWithContext = originalHostInfo
readFile = originalReadFile
})
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
return &gohost.InfoStat{
Hostname: " host-from-info ",
HostID: " gopsutil-id ",
Platform: "Darwin",
PlatformFamily: "",
PlatformVersion: "14.0",
KernelVersion: "6.6.0",
KernelArch: runtime.GOARCH,
}, nil
}
readFile = func(name string) ([]byte, error) {
if name == "/etc/machine-id" {
return []byte("0123456789abcdef0123456789abcdef\n"), nil
}
return nil, os.ErrNotExist
}
originalTags := []string{" tag-a ", "tag-a", "", " tag-b", "tag-b "}
cfg := Config{
PulseURL: "http://example.com///",
APIToken: "token",
Interval: 0,
Tags: originalTags,
InsecureSkipVerify: true,
LogLevel: zerolog.InfoLevel,
}
agent, err := New(cfg)
if err != nil {
t.Fatalf("New: %v", err)
}
if agent.interval != defaultInterval {
t.Fatalf("interval = %v, want %v", agent.interval, defaultInterval)
}
if agent.trimmedPulseURL != "http://example.com" {
t.Fatalf("trimmedPulseURL = %q, want %q", agent.trimmedPulseURL, "http://example.com")
}
if agent.hostname != "host-from-info" {
t.Fatalf("hostname = %q, want %q", agent.hostname, "host-from-info")
}
if agent.displayName != "host-from-info" {
t.Fatalf("displayName = %q, want %q", agent.displayName, "host-from-info")
}
if agent.platform != "macos" {
t.Fatalf("platform = %q, want %q", agent.platform, "macos")
}
if got, want := agent.cfg.Tags, []string{"tag-a", "tag-b"}; len(got) != len(want) || got[0] != want[0] || got[1] != want[1] {
t.Fatalf("tags = %#v, want %#v", got, want)
}
// Ensure we don't retain aliasing to the caller-provided tags slice.
originalTags[0] = "mutated"
if agent.cfg.Tags[0] != "tag-a" {
t.Fatalf("agent tags aliased caller slice: %#v", agent.cfg.Tags)
}
httpTransport, ok := agent.httpClient.Transport.(*http.Transport)
if !ok {
t.Fatalf("expected *http.Transport, got %T", agent.httpClient.Transport)
}
if httpTransport.TLSClientConfig == nil || httpTransport.TLSClientConfig.MinVersion != tls.VersionTLS12 {
t.Fatalf("TLSClientConfig MinVersion = %#v, want TLS1.2", httpTransport.TLSClientConfig)
}
if !httpTransport.TLSClientConfig.InsecureSkipVerify {
t.Fatalf("expected InsecureSkipVerify=true")
}
if runtime.GOOS == "linux" {
const want = "01234567-89ab-cdef-0123-456789abcdef"
if agent.machineID != want {
t.Fatalf("machineID = %q, want %q", agent.machineID, want)
}
if agent.agentID != want {
t.Fatalf("agentID = %q, want %q", agent.agentID, want)
}
} else {
if agent.machineID != "gopsutil-id" {
t.Fatalf("machineID = %q, want %q", agent.machineID, "gopsutil-id")
}
if agent.agentID != "gopsutil-id" {
t.Fatalf("agentID = %q, want %q", agent.agentID, "gopsutil-id")
}
}
}
func TestNew_DefaultPulseURL(t *testing.T) {
originalHostInfo := hostInfoWithContext
t.Cleanup(func() { hostInfoWithContext = originalHostInfo })
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
return &gohost.InfoStat{Hostname: "host", HostID: "hid", KernelArch: runtime.GOARCH}, nil
}
agent, err := New(Config{PulseURL: "", APIToken: "token", LogLevel: zerolog.InfoLevel})
if err != nil {
t.Fatalf("New: %v", err)
}
if agent.trimmedPulseURL != "http://localhost:7655" {
t.Fatalf("trimmedPulseURL = %q, want %q", agent.trimmedPulseURL, "http://localhost:7655")
}
}
func TestNew_FallsBackToHostnameWhenMachineIDEmpty(t *testing.T) {
originalHostInfo := hostInfoWithContext
originalReadFile := readFile
t.Cleanup(func() {
hostInfoWithContext = originalHostInfo
readFile = originalReadFile
})
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
return &gohost.InfoStat{
Hostname: "host-from-info",
HostID: "",
KernelArch: runtime.GOARCH,
}, nil
}
readFile = func(string) ([]byte, error) { return nil, os.ErrNotExist }
agent, err := New(Config{APIToken: "token", LogLevel: zerolog.InfoLevel})
if err != nil {
t.Fatalf("New: %v", err)
}
if agent.machineID != "" {
t.Fatalf("expected empty machineID, got %q", agent.machineID)
}
if agent.agentID != "host-from-info" {
t.Fatalf("agentID = %q, want %q", agent.agentID, "host-from-info")
}
}

View file

@ -0,0 +1,96 @@
package hostagent
import (
"context"
"errors"
"runtime"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/sensors"
"github.com/rs/zerolog"
)
func TestAgent_collectTemperatures_MapsKeys(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("temperature collection currently only runs on linux")
}
originalCollect := sensorsCollectLocal
originalParse := sensorsParse
t.Cleanup(func() {
sensorsCollectLocal = originalCollect
sensorsParse = originalParse
})
sensorsCollectLocal = func(context.Context) (string, error) { return "{}", nil }
sensorsParse = func(string) (*sensors.TemperatureData, error) {
return &sensors.TemperatureData{
Available: true,
CPUPackage: 55.5,
Cores: map[string]float64{
"Core 0": 44,
"Core 1": 45,
},
NVMe: map[string]float64{
"nvme0": 40,
},
GPU: map[string]float64{
"amdgpu-pci-0100": 60,
},
}, nil
}
a := &Agent{logger: zerolog.Nop()}
got := a.collectTemperatures(context.Background())
want := map[string]float64{
"cpu_package": 55.5,
"cpu_core_0": 44,
"cpu_core_1": 45,
"nvme0": 40,
"amdgpu-pci-0100": 60,
}
if got.TemperatureCelsius == nil {
t.Fatalf("expected TemperatureCelsius map to be initialised")
}
if len(got.TemperatureCelsius) != len(want) {
t.Fatalf("temperature keys = %d, want %d", len(got.TemperatureCelsius), len(want))
}
for k, v := range want {
if gotVal, ok := got.TemperatureCelsius[k]; !ok || gotVal != v {
t.Fatalf("TemperatureCelsius[%q] = (%v, %v), want (%v, %v)", k, gotVal, ok, v, true)
}
}
}
func TestAgent_collectTemperatures_BestEffortFailuresReturnEmpty(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("temperature collection currently only runs on linux")
}
originalCollect := sensorsCollectLocal
originalParse := sensorsParse
t.Cleanup(func() {
sensorsCollectLocal = originalCollect
sensorsParse = originalParse
})
a := &Agent{logger: zerolog.Nop()}
sensorsCollectLocal = func(context.Context) (string, error) { return "", errors.New("no sensors") }
if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 {
t.Fatalf("expected empty sensors on collect error, got %#v", got.TemperatureCelsius)
}
sensorsCollectLocal = func(context.Context) (string, error) { return "{}", nil }
sensorsParse = func(string) (*sensors.TemperatureData, error) { return nil, errors.New("bad json") }
if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 {
t.Fatalf("expected empty sensors on parse error, got %#v", got.TemperatureCelsius)
}
sensorsParse = func(string) (*sensors.TemperatureData, error) { return &sensors.TemperatureData{Available: false}, nil }
if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 {
t.Fatalf("expected empty sensors when unavailable, got %#v", got.TemperatureCelsius)
}
}

View file

@ -361,6 +361,16 @@ func (c *CommandClient) handleExecuteCommand(ctx context.Context, conn *websocke
}
}
func wrapCommand(payload executeCommandPayload) string {
if payload.TargetType == "container" && payload.TargetID != "" {
return fmt.Sprintf("pct exec %s -- %s", payload.TargetID, payload.Command)
}
if payload.TargetType == "vm" && payload.TargetID != "" {
return fmt.Sprintf("qm guest exec %s -- %s", payload.TargetID, payload.Command)
}
return payload.Command
}
func (c *CommandClient) executeCommand(ctx context.Context, payload executeCommandPayload) commandResultPayload {
result := commandResultPayload{
RequestID: payload.RequestID,
@ -375,17 +385,7 @@ func (c *CommandClient) executeCommand(ctx context.Context, payload executeComma
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Build the command based on target type
command := payload.Command
// If targeting a container or VM, wrap the command
if payload.TargetType == "container" && payload.TargetID != "" {
// Use pct exec for LXC containers
command = fmt.Sprintf("pct exec %s -- %s", payload.TargetID, payload.Command)
} else if payload.TargetType == "vm" && payload.TargetID != "" {
// Use qm guest exec for VMs (requires QEMU guest agent)
command = fmt.Sprintf("qm guest exec %s -- %s", payload.TargetID, payload.Command)
}
command := wrapCommand(payload)
// Execute the command
var cmd *exec.Cmd
@ -405,13 +405,13 @@ func (c *CommandClient) executeCommand(ctx context.Context, payload executeComma
result.Stderr = stderr.String()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
result.ExitCode = exitErr.ExitCode()
result.Success = false
} else if cmdCtx.Err() == context.DeadlineExceeded {
if cmdCtx.Err() == context.DeadlineExceeded {
result.Error = "command timed out"
result.ExitCode = -1
result.Success = false
} else if exitErr, ok := err.(*exec.ExitError); ok {
result.ExitCode = exitErr.ExitCode()
result.Success = false
} else {
result.Error = err.Error()
result.ExitCode = -1

View file

@ -0,0 +1,132 @@
package hostagent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
func TestCommandClient_connectAndHandle_ExecutesCommandAndReturnsResult(t *testing.T) {
if testing.Short() {
t.Skip("skipping websocket integration test in short mode")
}
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
serverDone := make(chan struct{})
gotResult := make(chan commandResultPayload, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade: %v", err)
return
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var regMsg wsMessage
if err := conn.ReadJSON(&regMsg); err != nil {
t.Errorf("read registration: %v", err)
return
}
if regMsg.Type != msgTypeAgentRegister {
t.Errorf("registration type = %q, want %q", regMsg.Type, msgTypeAgentRegister)
return
}
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
registeredBytes, _ := json.Marshal(registeredPayload{Success: true, Message: "Registered"})
if err := conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: registeredBytes}); err != nil {
t.Errorf("write registered: %v", err)
return
}
execPayloadBytes, _ := json.Marshal(executeCommandPayload{
RequestID: "req-1",
Command: "echo hello",
TargetType: "host",
Timeout: 5,
})
if err := conn.WriteJSON(wsMessage{Type: msgTypeExecuteCmd, ID: "req-1", Timestamp: time.Now(), Payload: execPayloadBytes}); err != nil {
t.Errorf("write execute_command: %v", err)
return
}
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
var resultMsg wsMessage
if err := conn.ReadJSON(&resultMsg); err != nil {
t.Errorf("read command_result: %v", err)
return
}
if resultMsg.Type != msgTypeCommandResult {
t.Errorf("result type = %q, want %q", resultMsg.Type, msgTypeCommandResult)
return
}
var result commandResultPayload
if err := json.Unmarshal(resultMsg.Payload, &result); err != nil {
t.Errorf("unmarshal command_result payload: %v", err)
return
}
gotResult <- result
<-serverDone
}))
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := &CommandClient{
pulseURL: strings.TrimRight(server.URL, "/"),
apiToken: "token",
agentID: "agent-1",
hostname: "host-1",
platform: "linux",
version: "1.2.3",
logger: zerolog.Nop(),
done: make(chan struct{}),
}
errCh := make(chan error, 1)
go func() {
errCh <- client.connectAndHandle(ctx)
}()
select {
case result := <-gotResult:
if result.RequestID != "req-1" {
t.Fatalf("result.RequestID = %q, want %q", result.RequestID, "req-1")
}
if !result.Success || result.ExitCode != 0 {
t.Fatalf("unexpected result: %#v", result)
}
if !strings.Contains(result.Stdout, "hello") {
t.Fatalf("stdout = %q, expected to contain %q", result.Stdout, "hello")
}
cancel()
close(serverDone)
case <-time.After(10 * time.Second):
t.Fatalf("timed out waiting for command result")
}
select {
case err := <-errCh:
if err == nil {
t.Fatalf("expected error due to context cancellation")
}
case <-time.After(5 * time.Second):
t.Fatalf("timed out waiting for connectAndHandle to return")
}
}

View file

@ -0,0 +1,151 @@
package hostagent
import (
"context"
"runtime"
"strings"
"testing"
"time"
)
func TestWrapCommand_TargetWrapping(t *testing.T) {
tests := []struct {
name string
payload executeCommandPayload
want string
}{
{
name: "host command unchanged",
payload: executeCommandPayload{
Command: "echo ok",
TargetType: "host",
},
want: "echo ok",
},
{
name: "container wraps with pct",
payload: executeCommandPayload{
Command: "echo ok",
TargetType: "container",
TargetID: "101",
},
want: "pct exec 101 -- echo ok",
},
{
name: "vm wraps with qm guest exec",
payload: executeCommandPayload{
Command: "echo ok",
TargetType: "vm",
TargetID: "900",
},
want: "qm guest exec 900 -- echo ok",
},
{
name: "missing target id does not wrap",
payload: executeCommandPayload{
Command: "echo ok",
TargetType: "container",
TargetID: "",
},
want: "echo ok",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := wrapCommand(tt.payload); got != tt.want {
t.Fatalf("wrapCommand() = %q, want %q", got, tt.want)
}
})
}
}
func TestCommandClient_executeCommand_Success(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("executeCommand uses different shell on windows")
}
c := &CommandClient{}
result := c.executeCommand(context.Background(), executeCommandPayload{
RequestID: "r1",
Command: "echo hello",
Timeout: 5,
})
if !result.Success || result.ExitCode != 0 {
t.Fatalf("expected success, got %#v", result)
}
if !strings.Contains(result.Stdout, "hello") {
t.Fatalf("stdout = %q, expected to contain %q", result.Stdout, "hello")
}
}
func TestCommandClient_executeCommand_NonZeroExitCode(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("executeCommand uses different shell on windows")
}
c := &CommandClient{}
result := c.executeCommand(context.Background(), executeCommandPayload{
RequestID: "r1",
Command: "echo err 1>&2; exit 3",
Timeout: 5,
})
if result.Success {
t.Fatalf("expected failure, got %#v", result)
}
if result.ExitCode != 3 {
t.Fatalf("exit code = %d, want %d", result.ExitCode, 3)
}
if !strings.Contains(result.Stderr, "err") {
t.Fatalf("stderr = %q, expected to contain %q", result.Stderr, "err")
}
}
func TestCommandClient_executeCommand_TimeoutSetsError(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("executeCommand uses different shell on windows")
}
c := &CommandClient{}
start := time.Now()
result := c.executeCommand(context.Background(), executeCommandPayload{
RequestID: "r1",
Command: "sleep 2",
Timeout: 1,
})
if time.Since(start) > 3*time.Second {
t.Fatalf("timeout path took too long: %v", time.Since(start))
}
if result.Success {
t.Fatalf("expected failure, got %#v", result)
}
if result.ExitCode != -1 {
t.Fatalf("exit code = %d, want %d", result.ExitCode, -1)
}
if result.Error != "command timed out" {
t.Fatalf("error = %q, want %q", result.Error, "command timed out")
}
}
func TestCommandClient_executeCommand_TruncatesLargeOutput(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("executeCommand uses different shell on windows")
}
c := &CommandClient{}
result := c.executeCommand(context.Background(), executeCommandPayload{
RequestID: "r1",
Command: "head -c 1048580 /dev/zero | tr '\\0' 'a'",
Timeout: 5,
})
if !result.Success {
t.Fatalf("expected success, got %#v", result)
}
if !strings.Contains(result.Stdout, "(output truncated)") {
t.Fatalf("expected truncation marker, got stdout len=%d", len(result.Stdout))
}
if len(result.Stdout) > 1024*1024+64 {
t.Fatalf("stdout len=%d, expected <= %d", len(result.Stdout), 1024*1024+64)
}
}

View file

@ -0,0 +1,180 @@
package hostagent
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
func wsURLForHTTP(serverURL string) string {
return "ws" + strings.TrimPrefix(serverURL, "http")
}
func TestCommandClient_sendRegistration_WritesExpectedPayload(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
gotMsgCh := make(chan wsMessage, 1)
gotPayloadCh := make(chan registerPayload, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade: %v", err)
return
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
var msg wsMessage
if err := conn.ReadJSON(&msg); err != nil {
t.Errorf("ReadJSON: %v", err)
return
}
var payload registerPayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
t.Errorf("unmarshal payload: %v", err)
return
}
gotMsgCh <- msg
gotPayloadCh <- payload
}))
defer server.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
client := &CommandClient{
apiToken: "token-1",
agentID: "agent-1",
hostname: "host-1",
platform: "linux",
version: "1.2.3",
logger: zerolog.Nop(),
}
if err := client.sendRegistration(conn); err != nil {
t.Fatalf("sendRegistration: %v", err)
}
msg := <-gotMsgCh
if msg.Type != msgTypeAgentRegister {
t.Fatalf("msg.Type = %q, want %q", msg.Type, msgTypeAgentRegister)
}
if msg.Timestamp.IsZero() {
t.Fatalf("expected non-zero Timestamp")
}
payload := <-gotPayloadCh
if payload.AgentID != "agent-1" {
t.Fatalf("payload.AgentID = %q, want %q", payload.AgentID, "agent-1")
}
if payload.Hostname != "host-1" {
t.Fatalf("payload.Hostname = %q, want %q", payload.Hostname, "host-1")
}
if payload.Platform != "linux" {
t.Fatalf("payload.Platform = %q, want %q", payload.Platform, "linux")
}
if payload.Version != "1.2.3" {
t.Fatalf("payload.Version = %q, want %q", payload.Version, "1.2.3")
}
if payload.Token != "token-1" {
t.Fatalf("payload.Token = %q, want %q", payload.Token, "token-1")
}
}
func TestCommandClient_waitForRegistration_AcceptsSuccess(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade: %v", err)
return
}
defer conn.Close()
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
payload, _ := json.Marshal(registeredPayload{Success: true, Message: "Registered"})
_ = conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: payload})
}))
defer server.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
client := &CommandClient{logger: zerolog.Nop()}
if err := client.waitForRegistration(conn); err != nil {
t.Fatalf("waitForRegistration: %v", err)
}
}
func TestCommandClient_waitForRegistration_RejectsFailure(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade: %v", err)
return
}
defer conn.Close()
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
payload, _ := json.Marshal(registeredPayload{Success: false, Message: "Invalid token"})
_ = conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: payload})
}))
defer server.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
client := &CommandClient{logger: zerolog.Nop()}
if err := client.waitForRegistration(conn); err == nil {
t.Fatalf("expected error")
}
}
func TestCommandClient_waitForRegistration_UnexpectedMessageType(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade: %v", err)
return
}
defer conn.Close()
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
_ = conn.WriteJSON(wsMessage{Type: msgTypePong, Timestamp: time.Now()})
}))
defer server.Close()
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
client := &CommandClient{logger: zerolog.Nop()}
if err := client.waitForRegistration(conn); err == nil {
t.Fatalf("expected error")
}
}