mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 03:20:11 +00:00
Add meaningful tests for host agent and exec websocket
This commit is contained in:
parent
ab480ca489
commit
30f01771ac
10 changed files with 1335 additions and 23 deletions
121
cmd/pulse-sensor-proxy/cleanup_test.go
Normal file
121
cmd/pulse-sensor-proxy/cleanup_test.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestProxy_cleanupRequestPath_UsesConfiguredWorkDir(t *testing.T) {
|
||||
p := &Proxy{workDir: "/tmp/pulse-sensor-proxy-test"}
|
||||
got, err := p.cleanupRequestPath()
|
||||
if err != nil {
|
||||
t.Fatalf("cleanupRequestPath: %v", err)
|
||||
}
|
||||
want := filepath.Join(p.workDir, cleanupRequestFilename)
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_handleRequestCleanup_WritesValidPayloadAndReplacesExisting(t *testing.T) {
|
||||
workDir := t.TempDir()
|
||||
p := &Proxy{workDir: workDir}
|
||||
logger := zerolog.Nop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.handleRequestCleanup(ctx, &RPCRequest{
|
||||
Method: RPCRequestCleanup,
|
||||
Params: map[string]interface{}{
|
||||
"host": "pve-1",
|
||||
"reason": "testing",
|
||||
},
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("handleRequestCleanup: %v", err)
|
||||
}
|
||||
respMap, ok := resp.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected map response, got %#v", resp)
|
||||
}
|
||||
queued, ok := respMap["queued"].(bool)
|
||||
if !ok || !queued {
|
||||
t.Fatalf("expected queued=true response, got %#v", resp)
|
||||
}
|
||||
|
||||
path, err := p.cleanupRequestPath()
|
||||
if err != nil {
|
||||
t.Fatalf("cleanupRequestPath: %v", err)
|
||||
}
|
||||
|
||||
readPayload := func() map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(%s): %v", path, err)
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
t.Fatalf("unmarshal payload: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
payload := readPayload()
|
||||
if payload["host"] != "pve-1" {
|
||||
t.Fatalf("payload host = %#v, want %q", payload["host"], "pve-1")
|
||||
}
|
||||
if payload["reason"] != "testing" {
|
||||
t.Fatalf("payload reason = %#v, want %q", payload["reason"], "testing")
|
||||
}
|
||||
if _, ok := payload["requestedAt"].(string); !ok {
|
||||
t.Fatalf("payload requestedAt missing or not string: %#v", payload["requestedAt"])
|
||||
}
|
||||
if ts, ok := payload["requestedAt"].(string); ok {
|
||||
if _, err := time.Parse(time.RFC3339, ts); err != nil {
|
||||
t.Fatalf("requestedAt not RFC3339: %q: %v", ts, err)
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat(%s): %v", path, err)
|
||||
}
|
||||
if fi.Mode().Perm() != 0o600 {
|
||||
t.Fatalf("payload file mode = %v, want %v", fi.Mode().Perm(), os.FileMode(0o600))
|
||||
}
|
||||
}
|
||||
|
||||
resp, err = p.handleRequestCleanup(ctx, &RPCRequest{
|
||||
Method: RPCRequestCleanup,
|
||||
Params: map[string]interface{}{
|
||||
"host": "pve-1",
|
||||
"reason": "testing-2",
|
||||
},
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("handleRequestCleanup (2): %v", err)
|
||||
}
|
||||
respMap, ok = resp.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected map response (2), got %#v", resp)
|
||||
}
|
||||
queued, ok = respMap["queued"].(bool)
|
||||
if !ok || !queued {
|
||||
t.Fatalf("expected queued=true response (2), got %#v", resp)
|
||||
}
|
||||
|
||||
payload2 := readPayload()
|
||||
if payload2["reason"] != "testing-2" {
|
||||
t.Fatalf("payload reason after replace = %#v, want %q", payload2["reason"], "testing-2")
|
||||
}
|
||||
}
|
||||
364
internal/agentexec/server_websocket_test.go
Normal file
364
internal/agentexec/server_websocket_test.go
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
package agentexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type wsRawMessage struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Payload *json.RawMessage `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
func newWSServer(t *testing.T, s *Server) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.HandleWebSocket(w, r)
|
||||
}))
|
||||
}
|
||||
|
||||
func wsURLForHTTP(serverURL string) string {
|
||||
return "ws" + strings.TrimPrefix(serverURL, "http")
|
||||
}
|
||||
|
||||
func wsWriteMessage(t *testing.T, conn *websocket.Conn, msg Message) {
|
||||
t.Helper()
|
||||
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
t.Fatalf("WriteJSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func wsReadRawMessage(t *testing.T, conn *websocket.Conn) wsRawMessage {
|
||||
t.Helper()
|
||||
msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage: %v", err)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func wsReadRegisteredPayload(t *testing.T, conn *websocket.Conn) RegisteredPayload {
|
||||
t.Helper()
|
||||
msg := wsReadRawMessage(t, conn)
|
||||
if msg.Type != MsgTypeRegistered {
|
||||
t.Fatalf("message type = %q, want %q", msg.Type, MsgTypeRegistered)
|
||||
}
|
||||
if msg.Payload == nil {
|
||||
t.Fatalf("registered payload missing")
|
||||
}
|
||||
var payload RegisteredPayload
|
||||
if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
|
||||
t.Fatalf("unmarshal registered payload: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func wsReadRawMessageWithTimeout(conn *websocket.Conn, timeout time.Duration) (wsRawMessage, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
_, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return wsRawMessage{}, err
|
||||
}
|
||||
var msg wsRawMessage
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return wsRawMessage{}, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func waitFor(t *testing.T, timeout time.Duration, cond func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if cond() {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not met within %v", timeout)
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_RegistrationSuccessAndDisconnectRemovesAgent(t *testing.T) {
|
||||
s := NewServer(func(token string) bool { return token == "ok" })
|
||||
ts := newWSServer(t, s)
|
||||
defer ts.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
|
||||
wsWriteMessage(t, conn, Message{
|
||||
Type: MsgTypeAgentRegister,
|
||||
Timestamp: time.Now(),
|
||||
Payload: AgentRegisterPayload{
|
||||
AgentID: "a1",
|
||||
Hostname: "host1",
|
||||
Version: "1.2.3",
|
||||
Platform: "linux",
|
||||
Tags: []string{"tag1"},
|
||||
Token: "ok",
|
||||
},
|
||||
})
|
||||
|
||||
reg := wsReadRegisteredPayload(t, conn)
|
||||
if !reg.Success {
|
||||
t.Fatalf("registration failed: %q", reg.Message)
|
||||
}
|
||||
|
||||
if !s.IsAgentConnected("a1") {
|
||||
t.Fatalf("expected agent to be connected")
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_InvalidTokenRejected(t *testing.T) {
|
||||
s := NewServer(func(string) bool { return false })
|
||||
ts := newWSServer(t, s)
|
||||
defer ts.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
wsWriteMessage(t, conn, Message{
|
||||
Type: MsgTypeAgentRegister,
|
||||
Timestamp: time.Now(),
|
||||
Payload: AgentRegisterPayload{
|
||||
AgentID: "a1",
|
||||
Hostname: "host1",
|
||||
Version: "1.2.3",
|
||||
Platform: "linux",
|
||||
Token: "bad",
|
||||
},
|
||||
})
|
||||
|
||||
reg := wsReadRegisteredPayload(t, conn)
|
||||
if reg.Success {
|
||||
t.Fatalf("expected registration to be rejected")
|
||||
}
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") })
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatalf("expected connection to be closed by server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_FirstMessageMustBeRegister(t *testing.T) {
|
||||
s := NewServer(nil)
|
||||
ts := newWSServer(t, s)
|
||||
defer ts.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
wsWriteMessage(t, conn, Message{
|
||||
Type: MsgTypeAgentPing,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatalf("expected server to close connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_AgentPingRespondsWithPong(t *testing.T) {
|
||||
s := NewServer(nil)
|
||||
ts := newWSServer(t, s)
|
||||
defer ts.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
wsWriteMessage(t, conn, Message{
|
||||
Type: MsgTypeAgentRegister,
|
||||
Timestamp: time.Now(),
|
||||
Payload: AgentRegisterPayload{
|
||||
AgentID: "a1",
|
||||
Hostname: "host1",
|
||||
Version: "1.2.3",
|
||||
Platform: "linux",
|
||||
Token: "any",
|
||||
},
|
||||
})
|
||||
_ = wsReadRegisteredPayload(t, conn)
|
||||
|
||||
wsWriteMessage(t, conn, Message{
|
||||
Type: MsgTypeAgentPing,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
msg := wsReadRawMessage(t, conn)
|
||||
if msg.Type != MsgTypePong {
|
||||
t.Fatalf("message type = %q, want %q", msg.Type, MsgTypePong)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteCommand_RoundTripViaWebSocket(t *testing.T) {
|
||||
s := NewServer(nil)
|
||||
ts := newWSServer(t, s)
|
||||
defer ts.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
wsWriteMessage(t, conn, Message{
|
||||
Type: MsgTypeAgentRegister,
|
||||
Timestamp: time.Now(),
|
||||
Payload: AgentRegisterPayload{
|
||||
AgentID: "a1",
|
||||
Hostname: "host1",
|
||||
Version: "1.2.3",
|
||||
Platform: "linux",
|
||||
Token: "any",
|
||||
},
|
||||
})
|
||||
_ = wsReadRegisteredPayload(t, conn)
|
||||
|
||||
agentDone := make(chan struct{})
|
||||
agentErr := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(agentDone)
|
||||
for {
|
||||
msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second)
|
||||
if err != nil {
|
||||
agentErr <- err
|
||||
return
|
||||
}
|
||||
if msg.Type != MsgTypeExecuteCmd {
|
||||
continue
|
||||
}
|
||||
if msg.Payload == nil {
|
||||
agentErr <- nil
|
||||
return
|
||||
}
|
||||
var payload ExecuteCommandPayload
|
||||
if err := json.Unmarshal(*msg.Payload, &payload); err != nil {
|
||||
agentErr <- err
|
||||
return
|
||||
}
|
||||
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
if err := conn.WriteJSON(Message{
|
||||
Type: MsgTypeCommandResult,
|
||||
Timestamp: time.Now(),
|
||||
Payload: CommandResultPayload{
|
||||
RequestID: payload.RequestID,
|
||||
Success: true,
|
||||
Stdout: "ok",
|
||||
ExitCode: 0,
|
||||
Duration: 1,
|
||||
},
|
||||
}); err != nil {
|
||||
agentErr <- err
|
||||
return
|
||||
}
|
||||
agentErr <- nil
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{
|
||||
RequestID: "req1",
|
||||
Command: "echo ok",
|
||||
Timeout: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteCommand: %v", err)
|
||||
}
|
||||
if result == nil || !result.Success || result.Stdout != "ok" || result.ExitCode != 0 {
|
||||
t.Fatalf("unexpected result: %#v", result)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-agentDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("agent goroutine did not finish")
|
||||
}
|
||||
|
||||
if err := <-agentErr; err != nil {
|
||||
t.Fatalf("agent error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocket_ReconnectSameAgentIDClosesOldConnection(t *testing.T) {
|
||||
s := NewServer(nil)
|
||||
ts := newWSServer(t, s)
|
||||
defer ts.Close()
|
||||
|
||||
dial := func() *websocket.Conn {
|
||||
t.Helper()
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
c1 := dial()
|
||||
defer c1.Close()
|
||||
wsWriteMessage(t, c1, Message{
|
||||
Type: MsgTypeAgentRegister,
|
||||
Timestamp: time.Now(),
|
||||
Payload: AgentRegisterPayload{
|
||||
AgentID: "a1",
|
||||
Hostname: "host1",
|
||||
Version: "1.2.3",
|
||||
Platform: "linux",
|
||||
Token: "any",
|
||||
},
|
||||
})
|
||||
_ = wsReadRegisteredPayload(t, c1)
|
||||
|
||||
c2 := dial()
|
||||
defer c2.Close()
|
||||
wsWriteMessage(t, c2, Message{
|
||||
Type: MsgTypeAgentRegister,
|
||||
Timestamp: time.Now(),
|
||||
Payload: AgentRegisterPayload{
|
||||
AgentID: "a1",
|
||||
Hostname: "host1",
|
||||
Version: "1.2.3",
|
||||
Platform: "linux",
|
||||
Token: "any",
|
||||
},
|
||||
})
|
||||
_ = wsReadRegisteredPayload(t, c2)
|
||||
|
||||
c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
_, _, err := c1.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatalf("expected old connection to be closed")
|
||||
}
|
||||
}
|
||||
|
|
@ -72,6 +72,17 @@ const defaultInterval = 30 * time.Second
|
|||
|
||||
var readFile = os.ReadFile
|
||||
|
||||
var (
|
||||
hostInfoWithContext = gohost.InfoWithContext
|
||||
hostUptimeWithContext = gohost.UptimeWithContext
|
||||
hostmetricsCollect = hostmetrics.Collect
|
||||
sensorsCollectLocal = sensors.CollectLocal
|
||||
sensorsParse = sensors.Parse
|
||||
mdadmCollectArrays = mdadm.CollectArrays
|
||||
cephCollect = ceph.Collect
|
||||
nowUTC = func() time.Time { return time.Now().UTC() }
|
||||
)
|
||||
|
||||
// New constructs a fully initialised host Agent.
|
||||
func New(cfg Config) (*Agent, error) {
|
||||
if cfg.Interval <= 0 {
|
||||
|
|
@ -107,7 +118,7 @@ func New(cfg Config) (*Agent, error) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
info, err := gohost.InfoWithContext(ctx)
|
||||
info, err := hostInfoWithContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch host info: %w", err)
|
||||
}
|
||||
|
|
@ -311,8 +322,8 @@ func (a *Agent) buildReport(ctx context.Context) (agentshost.Report, error) {
|
|||
collectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
uptime, _ := gohost.UptimeWithContext(collectCtx)
|
||||
snapshot, err := hostmetrics.Collect(collectCtx)
|
||||
uptime, _ := hostUptimeWithContext(collectCtx)
|
||||
snapshot, err := hostmetricsCollect(collectCtx)
|
||||
if err != nil {
|
||||
return agentshost.Report{}, fmt.Errorf("collect metrics: %w", err)
|
||||
}
|
||||
|
|
@ -361,7 +372,7 @@ func (a *Agent) buildReport(ctx context.Context) (agentshost.Report, error) {
|
|||
RAID: raidData,
|
||||
Ceph: cephData,
|
||||
Tags: append([]string(nil), a.cfg.Tags...),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Timestamp: nowUTC(),
|
||||
}
|
||||
|
||||
return report, nil
|
||||
|
|
@ -416,14 +427,14 @@ func (a *Agent) collectTemperatures(ctx context.Context) agentshost.Sensors {
|
|||
}
|
||||
|
||||
// Collect sensor JSON output
|
||||
jsonOutput, err := sensors.CollectLocal(ctx)
|
||||
jsonOutput, err := sensorsCollectLocal(ctx)
|
||||
if err != nil {
|
||||
a.logger.Debug().Err(err).Msg("Failed to collect sensor data (lm-sensors may not be installed)")
|
||||
return agentshost.Sensors{}
|
||||
}
|
||||
|
||||
// Parse the sensor output
|
||||
tempData, err := sensors.Parse(jsonOutput)
|
||||
tempData, err := sensorsParse(jsonOutput)
|
||||
if err != nil {
|
||||
a.logger.Debug().Err(err).Msg("Failed to parse sensor data")
|
||||
return agentshost.Sensors{}
|
||||
|
|
@ -476,7 +487,7 @@ func (a *Agent) collectRAIDArrays(ctx context.Context) []agentshost.RAIDArray {
|
|||
return nil
|
||||
}
|
||||
|
||||
arrays, err := mdadm.CollectArrays(ctx)
|
||||
arrays, err := mdadmCollectArrays(ctx)
|
||||
if err != nil {
|
||||
a.logger.Debug().Err(err).Msg("Failed to collect RAID array data (mdadm may not be installed)")
|
||||
return nil
|
||||
|
|
@ -499,7 +510,7 @@ func (a *Agent) collectCephStatus(ctx context.Context) *agentshost.CephCluster {
|
|||
return nil
|
||||
}
|
||||
|
||||
status, err := ceph.Collect(ctx)
|
||||
status, err := cephCollect(ctx)
|
||||
if err != nil {
|
||||
a.logger.Debug().Err(err).Msg("Failed to collect Ceph status")
|
||||
return nil
|
||||
|
|
|
|||
90
internal/hostagent/agent_flushbuffer_test.go
Normal file
90
internal/hostagent/agent_flushbuffer_test.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/buffer"
|
||||
agentshost "github.com/rcourtman/pulse-go-rewrite/pkg/agents/host"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestAgent_flushBuffer_StopsOnFailureAndDoesNotDropReport(t *testing.T) {
|
||||
var requestCount int32
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
n := atomic.AddInt32(&requestCount, 1)
|
||||
if n == 1 {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
a := &Agent{
|
||||
cfg: Config{APIToken: "token"},
|
||||
logger: zerolog.Nop(),
|
||||
httpClient: server.Client(),
|
||||
trimmedPulseURL: server.URL,
|
||||
reportBuffer: buffer.New[agentshost.Report](10),
|
||||
}
|
||||
|
||||
report1 := agentshost.Report{Agent: agentshost.AgentInfo{ID: "r1"}}
|
||||
report2 := agentshost.Report{Agent: agentshost.AgentInfo{ID: "r2"}}
|
||||
a.reportBuffer.Push(report1)
|
||||
a.reportBuffer.Push(report2)
|
||||
|
||||
a.flushBuffer(context.Background())
|
||||
|
||||
if got := a.reportBuffer.Len(); got != 1 {
|
||||
t.Fatalf("buffer len = %d, want %d", got, 1)
|
||||
}
|
||||
|
||||
peek, ok := a.reportBuffer.Peek()
|
||||
if !ok {
|
||||
t.Fatalf("expected buffered report")
|
||||
}
|
||||
if peek.Agent.ID != "r2" {
|
||||
t.Fatalf("remaining buffered report = %q, want %q", peek.Agent.ID, "r2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_flushBuffer_RetryAfterTransientFailure(t *testing.T) {
|
||||
var fail atomic.Bool
|
||||
fail.Store(true)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if fail.Load() {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
a := &Agent{
|
||||
cfg: Config{APIToken: "token"},
|
||||
logger: zerolog.Nop(),
|
||||
httpClient: server.Client(),
|
||||
trimmedPulseURL: server.URL,
|
||||
reportBuffer: buffer.New[agentshost.Report](10),
|
||||
}
|
||||
a.reportBuffer.Push(agentshost.Report{Agent: agentshost.AgentInfo{ID: "r1"}})
|
||||
a.reportBuffer.Push(agentshost.Report{Agent: agentshost.AgentInfo{ID: "r2"}})
|
||||
|
||||
a.flushBuffer(context.Background())
|
||||
if got := a.reportBuffer.Len(); got != 2 {
|
||||
t.Fatalf("buffer len after failure = %d, want %d", got, 2)
|
||||
}
|
||||
|
||||
fail.Store(false)
|
||||
a.flushBuffer(context.Background())
|
||||
if !a.reportBuffer.IsEmpty() {
|
||||
t.Fatalf("expected buffer to be empty, has %d items", a.reportBuffer.Len())
|
||||
}
|
||||
}
|
||||
|
||||
167
internal/hostagent/agent_new_test.go
Normal file
167
internal/hostagent/agent_new_test.go
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
gohost "github.com/shirou/gopsutil/v4/host"
|
||||
)
|
||||
|
||||
func TestNew_RequiresAPIToken(t *testing.T) {
|
||||
originalHostInfo := hostInfoWithContext
|
||||
t.Cleanup(func() { hostInfoWithContext = originalHostInfo })
|
||||
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
|
||||
return &gohost.InfoStat{Hostname: "host", HostID: "hid", KernelArch: runtime.GOARCH}, nil
|
||||
}
|
||||
|
||||
_, err := New(Config{APIToken: " ", LogLevel: zerolog.InfoLevel})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_NormalizesConfigAndTags(t *testing.T) {
|
||||
originalHostInfo := hostInfoWithContext
|
||||
originalReadFile := readFile
|
||||
t.Cleanup(func() {
|
||||
hostInfoWithContext = originalHostInfo
|
||||
readFile = originalReadFile
|
||||
})
|
||||
|
||||
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
|
||||
return &gohost.InfoStat{
|
||||
Hostname: " host-from-info ",
|
||||
HostID: " gopsutil-id ",
|
||||
Platform: "Darwin",
|
||||
PlatformFamily: "",
|
||||
PlatformVersion: "14.0",
|
||||
KernelVersion: "6.6.0",
|
||||
KernelArch: runtime.GOARCH,
|
||||
}, nil
|
||||
}
|
||||
|
||||
readFile = func(name string) ([]byte, error) {
|
||||
if name == "/etc/machine-id" {
|
||||
return []byte("0123456789abcdef0123456789abcdef\n"), nil
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
originalTags := []string{" tag-a ", "tag-a", "", " tag-b", "tag-b "}
|
||||
cfg := Config{
|
||||
PulseURL: "http://example.com///",
|
||||
APIToken: "token",
|
||||
Interval: 0,
|
||||
Tags: originalTags,
|
||||
InsecureSkipVerify: true,
|
||||
LogLevel: zerolog.InfoLevel,
|
||||
}
|
||||
|
||||
agent, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
if agent.interval != defaultInterval {
|
||||
t.Fatalf("interval = %v, want %v", agent.interval, defaultInterval)
|
||||
}
|
||||
if agent.trimmedPulseURL != "http://example.com" {
|
||||
t.Fatalf("trimmedPulseURL = %q, want %q", agent.trimmedPulseURL, "http://example.com")
|
||||
}
|
||||
if agent.hostname != "host-from-info" {
|
||||
t.Fatalf("hostname = %q, want %q", agent.hostname, "host-from-info")
|
||||
}
|
||||
if agent.displayName != "host-from-info" {
|
||||
t.Fatalf("displayName = %q, want %q", agent.displayName, "host-from-info")
|
||||
}
|
||||
if agent.platform != "macos" {
|
||||
t.Fatalf("platform = %q, want %q", agent.platform, "macos")
|
||||
}
|
||||
if got, want := agent.cfg.Tags, []string{"tag-a", "tag-b"}; len(got) != len(want) || got[0] != want[0] || got[1] != want[1] {
|
||||
t.Fatalf("tags = %#v, want %#v", got, want)
|
||||
}
|
||||
|
||||
// Ensure we don't retain aliasing to the caller-provided tags slice.
|
||||
originalTags[0] = "mutated"
|
||||
if agent.cfg.Tags[0] != "tag-a" {
|
||||
t.Fatalf("agent tags aliased caller slice: %#v", agent.cfg.Tags)
|
||||
}
|
||||
|
||||
httpTransport, ok := agent.httpClient.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", agent.httpClient.Transport)
|
||||
}
|
||||
if httpTransport.TLSClientConfig == nil || httpTransport.TLSClientConfig.MinVersion != tls.VersionTLS12 {
|
||||
t.Fatalf("TLSClientConfig MinVersion = %#v, want TLS1.2", httpTransport.TLSClientConfig)
|
||||
}
|
||||
if !httpTransport.TLSClientConfig.InsecureSkipVerify {
|
||||
t.Fatalf("expected InsecureSkipVerify=true")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
const want = "01234567-89ab-cdef-0123-456789abcdef"
|
||||
if agent.machineID != want {
|
||||
t.Fatalf("machineID = %q, want %q", agent.machineID, want)
|
||||
}
|
||||
if agent.agentID != want {
|
||||
t.Fatalf("agentID = %q, want %q", agent.agentID, want)
|
||||
}
|
||||
} else {
|
||||
if agent.machineID != "gopsutil-id" {
|
||||
t.Fatalf("machineID = %q, want %q", agent.machineID, "gopsutil-id")
|
||||
}
|
||||
if agent.agentID != "gopsutil-id" {
|
||||
t.Fatalf("agentID = %q, want %q", agent.agentID, "gopsutil-id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_DefaultPulseURL(t *testing.T) {
|
||||
originalHostInfo := hostInfoWithContext
|
||||
t.Cleanup(func() { hostInfoWithContext = originalHostInfo })
|
||||
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
|
||||
return &gohost.InfoStat{Hostname: "host", HostID: "hid", KernelArch: runtime.GOARCH}, nil
|
||||
}
|
||||
|
||||
agent, err := New(Config{PulseURL: "", APIToken: "token", LogLevel: zerolog.InfoLevel})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if agent.trimmedPulseURL != "http://localhost:7655" {
|
||||
t.Fatalf("trimmedPulseURL = %q, want %q", agent.trimmedPulseURL, "http://localhost:7655")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_FallsBackToHostnameWhenMachineIDEmpty(t *testing.T) {
|
||||
originalHostInfo := hostInfoWithContext
|
||||
originalReadFile := readFile
|
||||
t.Cleanup(func() {
|
||||
hostInfoWithContext = originalHostInfo
|
||||
readFile = originalReadFile
|
||||
})
|
||||
|
||||
hostInfoWithContext = func(context.Context) (*gohost.InfoStat, error) {
|
||||
return &gohost.InfoStat{
|
||||
Hostname: "host-from-info",
|
||||
HostID: "",
|
||||
KernelArch: runtime.GOARCH,
|
||||
}, nil
|
||||
}
|
||||
readFile = func(string) ([]byte, error) { return nil, os.ErrNotExist }
|
||||
|
||||
agent, err := New(Config{APIToken: "token", LogLevel: zerolog.InfoLevel})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if agent.machineID != "" {
|
||||
t.Fatalf("expected empty machineID, got %q", agent.machineID)
|
||||
}
|
||||
if agent.agentID != "host-from-info" {
|
||||
t.Fatalf("agentID = %q, want %q", agent.agentID, "host-from-info")
|
||||
}
|
||||
}
|
||||
96
internal/hostagent/agent_sensors_test.go
Normal file
96
internal/hostagent/agent_sensors_test.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/sensors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestAgent_collectTemperatures_MapsKeys(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("temperature collection currently only runs on linux")
|
||||
}
|
||||
|
||||
originalCollect := sensorsCollectLocal
|
||||
originalParse := sensorsParse
|
||||
t.Cleanup(func() {
|
||||
sensorsCollectLocal = originalCollect
|
||||
sensorsParse = originalParse
|
||||
})
|
||||
|
||||
sensorsCollectLocal = func(context.Context) (string, error) { return "{}", nil }
|
||||
sensorsParse = func(string) (*sensors.TemperatureData, error) {
|
||||
return &sensors.TemperatureData{
|
||||
Available: true,
|
||||
CPUPackage: 55.5,
|
||||
Cores: map[string]float64{
|
||||
"Core 0": 44,
|
||||
"Core 1": 45,
|
||||
},
|
||||
NVMe: map[string]float64{
|
||||
"nvme0": 40,
|
||||
},
|
||||
GPU: map[string]float64{
|
||||
"amdgpu-pci-0100": 60,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
a := &Agent{logger: zerolog.Nop()}
|
||||
|
||||
got := a.collectTemperatures(context.Background())
|
||||
want := map[string]float64{
|
||||
"cpu_package": 55.5,
|
||||
"cpu_core_0": 44,
|
||||
"cpu_core_1": 45,
|
||||
"nvme0": 40,
|
||||
"amdgpu-pci-0100": 60,
|
||||
}
|
||||
|
||||
if got.TemperatureCelsius == nil {
|
||||
t.Fatalf("expected TemperatureCelsius map to be initialised")
|
||||
}
|
||||
if len(got.TemperatureCelsius) != len(want) {
|
||||
t.Fatalf("temperature keys = %d, want %d", len(got.TemperatureCelsius), len(want))
|
||||
}
|
||||
for k, v := range want {
|
||||
if gotVal, ok := got.TemperatureCelsius[k]; !ok || gotVal != v {
|
||||
t.Fatalf("TemperatureCelsius[%q] = (%v, %v), want (%v, %v)", k, gotVal, ok, v, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_collectTemperatures_BestEffortFailuresReturnEmpty(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("temperature collection currently only runs on linux")
|
||||
}
|
||||
|
||||
originalCollect := sensorsCollectLocal
|
||||
originalParse := sensorsParse
|
||||
t.Cleanup(func() {
|
||||
sensorsCollectLocal = originalCollect
|
||||
sensorsParse = originalParse
|
||||
})
|
||||
|
||||
a := &Agent{logger: zerolog.Nop()}
|
||||
|
||||
sensorsCollectLocal = func(context.Context) (string, error) { return "", errors.New("no sensors") }
|
||||
if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 {
|
||||
t.Fatalf("expected empty sensors on collect error, got %#v", got.TemperatureCelsius)
|
||||
}
|
||||
|
||||
sensorsCollectLocal = func(context.Context) (string, error) { return "{}", nil }
|
||||
sensorsParse = func(string) (*sensors.TemperatureData, error) { return nil, errors.New("bad json") }
|
||||
if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 {
|
||||
t.Fatalf("expected empty sensors on parse error, got %#v", got.TemperatureCelsius)
|
||||
}
|
||||
|
||||
sensorsParse = func(string) (*sensors.TemperatureData, error) { return &sensors.TemperatureData{Available: false}, nil }
|
||||
if got := a.collectTemperatures(context.Background()); len(got.TemperatureCelsius) != 0 {
|
||||
t.Fatalf("expected empty sensors when unavailable, got %#v", got.TemperatureCelsius)
|
||||
}
|
||||
}
|
||||
|
|
@ -361,6 +361,16 @@ func (c *CommandClient) handleExecuteCommand(ctx context.Context, conn *websocke
|
|||
}
|
||||
}
|
||||
|
||||
func wrapCommand(payload executeCommandPayload) string {
|
||||
if payload.TargetType == "container" && payload.TargetID != "" {
|
||||
return fmt.Sprintf("pct exec %s -- %s", payload.TargetID, payload.Command)
|
||||
}
|
||||
if payload.TargetType == "vm" && payload.TargetID != "" {
|
||||
return fmt.Sprintf("qm guest exec %s -- %s", payload.TargetID, payload.Command)
|
||||
}
|
||||
return payload.Command
|
||||
}
|
||||
|
||||
func (c *CommandClient) executeCommand(ctx context.Context, payload executeCommandPayload) commandResultPayload {
|
||||
result := commandResultPayload{
|
||||
RequestID: payload.RequestID,
|
||||
|
|
@ -375,17 +385,7 @@ func (c *CommandClient) executeCommand(ctx context.Context, payload executeComma
|
|||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build the command based on target type
|
||||
command := payload.Command
|
||||
|
||||
// If targeting a container or VM, wrap the command
|
||||
if payload.TargetType == "container" && payload.TargetID != "" {
|
||||
// Use pct exec for LXC containers
|
||||
command = fmt.Sprintf("pct exec %s -- %s", payload.TargetID, payload.Command)
|
||||
} else if payload.TargetType == "vm" && payload.TargetID != "" {
|
||||
// Use qm guest exec for VMs (requires QEMU guest agent)
|
||||
command = fmt.Sprintf("qm guest exec %s -- %s", payload.TargetID, payload.Command)
|
||||
}
|
||||
command := wrapCommand(payload)
|
||||
|
||||
// Execute the command
|
||||
var cmd *exec.Cmd
|
||||
|
|
@ -405,13 +405,13 @@ func (c *CommandClient) executeCommand(ctx context.Context, payload executeComma
|
|||
result.Stderr = stderr.String()
|
||||
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
result.Success = false
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
result.Error = "command timed out"
|
||||
result.ExitCode = -1
|
||||
result.Success = false
|
||||
} else if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
result.Success = false
|
||||
} else {
|
||||
result.Error = err.Error()
|
||||
result.ExitCode = -1
|
||||
|
|
|
|||
132
internal/hostagent/commands_connect_test.go
Normal file
132
internal/hostagent/commands_connect_test.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestCommandClient_connectAndHandle_ExecutesCommandAndReturnsResult(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping websocket integration test in short mode")
|
||||
}
|
||||
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
gotResult := make(chan commandResultPayload, 1)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
var regMsg wsMessage
|
||||
if err := conn.ReadJSON(®Msg); err != nil {
|
||||
t.Errorf("read registration: %v", err)
|
||||
return
|
||||
}
|
||||
if regMsg.Type != msgTypeAgentRegister {
|
||||
t.Errorf("registration type = %q, want %q", regMsg.Type, msgTypeAgentRegister)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
registeredBytes, _ := json.Marshal(registeredPayload{Success: true, Message: "Registered"})
|
||||
if err := conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: registeredBytes}); err != nil {
|
||||
t.Errorf("write registered: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
execPayloadBytes, _ := json.Marshal(executeCommandPayload{
|
||||
RequestID: "req-1",
|
||||
Command: "echo hello",
|
||||
TargetType: "host",
|
||||
Timeout: 5,
|
||||
})
|
||||
if err := conn.WriteJSON(wsMessage{Type: msgTypeExecuteCmd, ID: "req-1", Timestamp: time.Now(), Payload: execPayloadBytes}); err != nil {
|
||||
t.Errorf("write execute_command: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
var resultMsg wsMessage
|
||||
if err := conn.ReadJSON(&resultMsg); err != nil {
|
||||
t.Errorf("read command_result: %v", err)
|
||||
return
|
||||
}
|
||||
if resultMsg.Type != msgTypeCommandResult {
|
||||
t.Errorf("result type = %q, want %q", resultMsg.Type, msgTypeCommandResult)
|
||||
return
|
||||
}
|
||||
|
||||
var result commandResultPayload
|
||||
if err := json.Unmarshal(resultMsg.Payload, &result); err != nil {
|
||||
t.Errorf("unmarshal command_result payload: %v", err)
|
||||
return
|
||||
}
|
||||
gotResult <- result
|
||||
|
||||
<-serverDone
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
client := &CommandClient{
|
||||
pulseURL: strings.TrimRight(server.URL, "/"),
|
||||
apiToken: "token",
|
||||
agentID: "agent-1",
|
||||
hostname: "host-1",
|
||||
platform: "linux",
|
||||
version: "1.2.3",
|
||||
logger: zerolog.Nop(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- client.connectAndHandle(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-gotResult:
|
||||
if result.RequestID != "req-1" {
|
||||
t.Fatalf("result.RequestID = %q, want %q", result.RequestID, "req-1")
|
||||
}
|
||||
if !result.Success || result.ExitCode != 0 {
|
||||
t.Fatalf("unexpected result: %#v", result)
|
||||
}
|
||||
if !strings.Contains(result.Stdout, "hello") {
|
||||
t.Fatalf("stdout = %q, expected to contain %q", result.Stdout, "hello")
|
||||
}
|
||||
|
||||
cancel()
|
||||
close(serverDone)
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("timed out waiting for command result")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err == nil {
|
||||
t.Fatalf("expected error due to context cancellation")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("timed out waiting for connectAndHandle to return")
|
||||
}
|
||||
}
|
||||
|
||||
151
internal/hostagent/commands_execute_test.go
Normal file
151
internal/hostagent/commands_execute_test.go
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWrapCommand_TargetWrapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload executeCommandPayload
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "host command unchanged",
|
||||
payload: executeCommandPayload{
|
||||
Command: "echo ok",
|
||||
TargetType: "host",
|
||||
},
|
||||
want: "echo ok",
|
||||
},
|
||||
{
|
||||
name: "container wraps with pct",
|
||||
payload: executeCommandPayload{
|
||||
Command: "echo ok",
|
||||
TargetType: "container",
|
||||
TargetID: "101",
|
||||
},
|
||||
want: "pct exec 101 -- echo ok",
|
||||
},
|
||||
{
|
||||
name: "vm wraps with qm guest exec",
|
||||
payload: executeCommandPayload{
|
||||
Command: "echo ok",
|
||||
TargetType: "vm",
|
||||
TargetID: "900",
|
||||
},
|
||||
want: "qm guest exec 900 -- echo ok",
|
||||
},
|
||||
{
|
||||
name: "missing target id does not wrap",
|
||||
payload: executeCommandPayload{
|
||||
Command: "echo ok",
|
||||
TargetType: "container",
|
||||
TargetID: "",
|
||||
},
|
||||
want: "echo ok",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := wrapCommand(tt.payload); got != tt.want {
|
||||
t.Fatalf("wrapCommand() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_executeCommand_Success(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("executeCommand uses different shell on windows")
|
||||
}
|
||||
|
||||
c := &CommandClient{}
|
||||
result := c.executeCommand(context.Background(), executeCommandPayload{
|
||||
RequestID: "r1",
|
||||
Command: "echo hello",
|
||||
Timeout: 5,
|
||||
})
|
||||
if !result.Success || result.ExitCode != 0 {
|
||||
t.Fatalf("expected success, got %#v", result)
|
||||
}
|
||||
if !strings.Contains(result.Stdout, "hello") {
|
||||
t.Fatalf("stdout = %q, expected to contain %q", result.Stdout, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_executeCommand_NonZeroExitCode(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("executeCommand uses different shell on windows")
|
||||
}
|
||||
|
||||
c := &CommandClient{}
|
||||
result := c.executeCommand(context.Background(), executeCommandPayload{
|
||||
RequestID: "r1",
|
||||
Command: "echo err 1>&2; exit 3",
|
||||
Timeout: 5,
|
||||
})
|
||||
if result.Success {
|
||||
t.Fatalf("expected failure, got %#v", result)
|
||||
}
|
||||
if result.ExitCode != 3 {
|
||||
t.Fatalf("exit code = %d, want %d", result.ExitCode, 3)
|
||||
}
|
||||
if !strings.Contains(result.Stderr, "err") {
|
||||
t.Fatalf("stderr = %q, expected to contain %q", result.Stderr, "err")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_executeCommand_TimeoutSetsError(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("executeCommand uses different shell on windows")
|
||||
}
|
||||
|
||||
c := &CommandClient{}
|
||||
start := time.Now()
|
||||
result := c.executeCommand(context.Background(), executeCommandPayload{
|
||||
RequestID: "r1",
|
||||
Command: "sleep 2",
|
||||
Timeout: 1,
|
||||
})
|
||||
if time.Since(start) > 3*time.Second {
|
||||
t.Fatalf("timeout path took too long: %v", time.Since(start))
|
||||
}
|
||||
if result.Success {
|
||||
t.Fatalf("expected failure, got %#v", result)
|
||||
}
|
||||
if result.ExitCode != -1 {
|
||||
t.Fatalf("exit code = %d, want %d", result.ExitCode, -1)
|
||||
}
|
||||
if result.Error != "command timed out" {
|
||||
t.Fatalf("error = %q, want %q", result.Error, "command timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_executeCommand_TruncatesLargeOutput(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("executeCommand uses different shell on windows")
|
||||
}
|
||||
|
||||
c := &CommandClient{}
|
||||
result := c.executeCommand(context.Background(), executeCommandPayload{
|
||||
RequestID: "r1",
|
||||
Command: "head -c 1048580 /dev/zero | tr '\\0' 'a'",
|
||||
Timeout: 5,
|
||||
})
|
||||
if !result.Success {
|
||||
t.Fatalf("expected success, got %#v", result)
|
||||
}
|
||||
if !strings.Contains(result.Stdout, "(output truncated)") {
|
||||
t.Fatalf("expected truncation marker, got stdout len=%d", len(result.Stdout))
|
||||
}
|
||||
if len(result.Stdout) > 1024*1024+64 {
|
||||
t.Fatalf("stdout len=%d, expected <= %d", len(result.Stdout), 1024*1024+64)
|
||||
}
|
||||
}
|
||||
|
||||
180
internal/hostagent/commands_registration_test.go
Normal file
180
internal/hostagent/commands_registration_test.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func wsURLForHTTP(serverURL string) string {
|
||||
return "ws" + strings.TrimPrefix(serverURL, "http")
|
||||
}
|
||||
|
||||
func TestCommandClient_sendRegistration_WritesExpectedPayload(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
|
||||
gotMsgCh := make(chan wsMessage, 1)
|
||||
gotPayloadCh := make(chan registerPayload, 1)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
var msg wsMessage
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
t.Errorf("ReadJSON: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var payload registerPayload
|
||||
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
|
||||
t.Errorf("unmarshal payload: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
gotMsgCh <- msg
|
||||
gotPayloadCh <- payload
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := &CommandClient{
|
||||
apiToken: "token-1",
|
||||
agentID: "agent-1",
|
||||
hostname: "host-1",
|
||||
platform: "linux",
|
||||
version: "1.2.3",
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
|
||||
if err := client.sendRegistration(conn); err != nil {
|
||||
t.Fatalf("sendRegistration: %v", err)
|
||||
}
|
||||
|
||||
msg := <-gotMsgCh
|
||||
if msg.Type != msgTypeAgentRegister {
|
||||
t.Fatalf("msg.Type = %q, want %q", msg.Type, msgTypeAgentRegister)
|
||||
}
|
||||
if msg.Timestamp.IsZero() {
|
||||
t.Fatalf("expected non-zero Timestamp")
|
||||
}
|
||||
|
||||
payload := <-gotPayloadCh
|
||||
if payload.AgentID != "agent-1" {
|
||||
t.Fatalf("payload.AgentID = %q, want %q", payload.AgentID, "agent-1")
|
||||
}
|
||||
if payload.Hostname != "host-1" {
|
||||
t.Fatalf("payload.Hostname = %q, want %q", payload.Hostname, "host-1")
|
||||
}
|
||||
if payload.Platform != "linux" {
|
||||
t.Fatalf("payload.Platform = %q, want %q", payload.Platform, "linux")
|
||||
}
|
||||
if payload.Version != "1.2.3" {
|
||||
t.Fatalf("payload.Version = %q, want %q", payload.Version, "1.2.3")
|
||||
}
|
||||
if payload.Token != "token-1" {
|
||||
t.Fatalf("payload.Token = %q, want %q", payload.Token, "token-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_waitForRegistration_AcceptsSuccess(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
payload, _ := json.Marshal(registeredPayload{Success: true, Message: "Registered"})
|
||||
_ = conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: payload})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := &CommandClient{logger: zerolog.Nop()}
|
||||
if err := client.waitForRegistration(conn); err != nil {
|
||||
t.Fatalf("waitForRegistration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_waitForRegistration_RejectsFailure(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
payload, _ := json.Marshal(registeredPayload{Success: false, Message: "Invalid token"})
|
||||
_ = conn.WriteJSON(wsMessage{Type: msgTypeRegistered, Timestamp: time.Now(), Payload: payload})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := &CommandClient{logger: zerolog.Nop()}
|
||||
if err := client.waitForRegistration(conn); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandClient_waitForRegistration_UnexpectedMessageType(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
_ = conn.WriteJSON(wsMessage{Type: msgTypePong, Timestamp: time.Now()})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(server.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := &CommandClient{logger: zerolog.Nop()}
|
||||
if err := client.waitForRegistration(conn); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue