// Pulse/internal/agentexec/server_coverage_test.go
// 2026-03-18 16:06:30 +00:00
//
// 771 lines
// 18 KiB
// Go

package agentexec
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
// noHijackResponseWriter is a bare-bones http.ResponseWriter used by the
// upgrade-failure test. Only Header, Write and WriteHeader are defined in
// this block; no Hijack method is declared here.
type noHijackResponseWriter struct {
	header http.Header
}

// Header returns the header map the writer was constructed with.
func (w *noHijackResponseWriter) Header() http.Header { return w.header }

// Write discards its input and always reports zero bytes written and no error.
func (w *noHijackResponseWriter) Write(_ []byte) (int, error) { return 0, nil }

// WriteHeader ignores the status code.
func (w *noHijackResponseWriter) WriteHeader(_ int) {}
// newConnPair starts a throwaway httptest server whose handler upgrades the
// incoming request with the package-level upgrader, dials it, and returns the
// server-side connection, the client-side connection, and a cleanup function
// that closes both connections and the server.
//
// The test is failed immediately if the dial fails or if the server side of
// the connection is not delivered within two seconds.
func newConnPair(t *testing.T) (*websocket.Conn, *websocket.Conn, func()) {
	t.Helper()

	// Buffered so the handler never blocks handing over the upgraded conn.
	accepted := make(chan *websocket.Conn, 1)
	handler := func(w http.ResponseWriter, r *http.Request) {
		wsConn, upgradeErr := upgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade: %v", upgradeErr)
			return
		}
		accepted <- wsConn
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(srv.URL), nil)
	if dialErr != nil {
		srv.Close()
		t.Fatalf("Dial: %v", dialErr)
	}

	var server *websocket.Conn
	select {
	case server = <-accepted:
	case <-time.After(2 * time.Second):
		client.Close()
		srv.Close()
		t.Fatal("timed out waiting for server connection")
	}

	return server, client, func() {
		client.Close()
		server.Close()
		srv.Close()
	}
}
// TestHandleWebSocket_UpgradeFailureAndDeadlineErrors calls HandleWebSocket
// with a noHijackResponseWriter and a plain GET request. It makes no
// assertions; it passes as long as the handler returns without panicking.
func TestHandleWebSocket_UpgradeFailureAndDeadlineErrors(t *testing.T) {
	srv := NewServer(nil)
	request := httptest.NewRequest(http.MethodGet, "http://example/ws", nil)
	rw := &noHijackResponseWriter{header: make(http.Header)}
	srv.HandleWebSocket(rw, request)
}
// TestHandleWebSocket_RegistrationReadError connects to the websocket
// endpoint and hangs up without sending any frame. It makes no assertions
// beyond the dial succeeding.
func TestHandleWebSocket_RegistrationReadError(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}

	// Close right away, before any registration message is written.
	client.Close()
}
// TestHandleWebSocket_RegistrationMessageJSONError sends the truncated JSON
// document "{" as the first text frame and expects the following read to
// return an error within 500ms.
//
// NOTE(review): a read-deadline expiry also produces a non-nil error, so this
// check does not distinguish "server closed" from "server stayed silent".
func TestHandleWebSocket_RegistrationMessageJSONError(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	malformed := []byte("{")
	if writeErr := client.WriteMessage(websocket.TextMessage, malformed); writeErr != nil {
		t.Fatalf("WriteMessage: %v", writeErr)
	}

	_ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close on invalid JSON")
	}
}
// TestHandleWebSocket_RegistrationPayloadMissing sends an agent-register
// message built with a nil payload and expects the following read to return
// an error within 500ms.
//
// NOTE(review): a read-deadline expiry also satisfies the final check.
func TestHandleWebSocket_RegistrationPayloadMissing(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	emptyRegister := mustNewMessage(t, MsgTypeAgentRegister, "", nil)
	wsWriteMessage(t, client, emptyRegister)

	_ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close on missing payload")
	}
}
// TestHandleWebSocket_RegistrationPayloadUnmarshalError sends an
// agent_register frame whose payload is the JSON string "oops" rather than an
// object, and expects the following read to return an error within 500ms.
//
// NOTE(review): a read-deadline expiry also satisfies the final check.
func TestHandleWebSocket_RegistrationPayloadUnmarshalError(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	frame := []byte(`{"type":"agent_register","payload":"oops"}`)
	if writeErr := client.WriteMessage(websocket.TextMessage, frame); writeErr != nil {
		t.Fatalf("WriteMessage: %v", writeErr)
	}

	_ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close on invalid payload")
	}
}
// TestHandleWebSocket_InvalidTokenRejectionSendFailure swaps the package-level
// writeTextMessage hook for one that always fails, builds a server whose
// validator callback rejects everything, and registers agent "a1". It expects
// the next client read to error within 500ms and the agent not to be reported
// as connected.
//
// NOTE(review): a read-deadline expiry also satisfies the read check.
func TestHandleWebSocket_InvalidTokenRejectionSendFailure(t *testing.T) {
	savedWriter := writeTextMessage
	t.Cleanup(func() { writeTextMessage = savedWriter })
	writeTextMessage = func(*websocket.Conn, []byte) error {
		return errors.New("write failure")
	}

	rejectAll := func(string, string) bool { return false }
	srv := NewServer(rejectAll)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := AgentRegisterPayload{
		AgentID:  "a1",
		Hostname: "host1",
		Token:    "bad",
	}
	wsWriteMessage(t, client, mustNewMessage(t, MsgTypeAgentRegister, "", registration))

	_ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected server to close connection after rejection send failure")
	}

	notConnected := func() bool { return !s2Connected(srv) }
	waitFor(t, 2*time.Second, notConnected)
}

// s2Connected reports whether agent "a1" is currently registered on srv.
func s2Connected(srv interface{ IsAgentConnected(string) bool }) bool {
	return srv.IsAgentConnected("a1")
}
// TestHandleWebSocket_RegistrationAckSendFailure swaps the package-level
// writeTextMessage hook for one that always fails and registers agent "a1".
// It expects the agent to become connected, no message to be readable within
// 500ms, and the agent to be dropped once the client closes the connection.
func TestHandleWebSocket_RegistrationAckSendFailure(t *testing.T) {
	savedWriter := writeTextMessage
	t.Cleanup(func() { writeTextMessage = savedWriter })
	writeTextMessage = func(*websocket.Conn, []byte) error {
		return errors.New("write failure")
	}

	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := AgentRegisterPayload{
		AgentID:  "a1",
		Hostname: "host1",
		Token:    "any",
	}
	wsWriteMessage(t, client, mustNewMessage(t, MsgTypeAgentRegister, "", registration))

	waitFor(t, 2*time.Second, func() bool { return srv.IsAgentConnected("a1") })

	// With the write hook failing, nothing should arrive before the deadline.
	_ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected no registration ack when send fails")
	}

	client.Close()
	waitFor(t, 2*time.Second, func() bool { return !srv.IsAgentConnected("a1") })
}
// TestHandleWebSocket_PongHandler registers agent "a1", reads the
// registration reply, sends a pong control frame, then closes the client and
// waits for the server to stop reporting the agent as connected.
func TestHandleWebSocket_PongHandler(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := AgentRegisterPayload{
		AgentID:  "a1",
		Hostname: "host1",
		Token:    "any",
	}
	wsWriteMessage(t, client, mustNewMessage(t, MsgTypeAgentRegister, "", registration))
	_ = wsReadRegisteredPayload(t, client)

	pongErr := client.WriteControl(websocket.PongMessage, []byte("pong"), time.Now().Add(time.Second))
	if pongErr != nil {
		t.Fatalf("WriteControl pong: %v", pongErr)
	}

	client.Close()
	waitFor(t, 2*time.Second, func() bool { return !srv.IsAgentConnected("a1") })
}
// TestReadLoopDone runs readLoop synchronously on an agentConn whose done
// channel is already closed and expects the agent to be absent from the
// server afterwards.
func TestReadLoopDone(t *testing.T) {
	srv := NewServer(nil)
	serverSide, clientSide, cleanup := newConnPair(t)
	defer cleanup()

	// Pre-closed done channel: readLoop sees the agent as finished already.
	finished := make(chan struct{})
	close(finished)
	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  finished,
	}

	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.mu.Unlock()

	srv.readLoop(agentC)

	if srv.IsAgentConnected("a1") {
		t.Fatalf("expected agent to be removed")
	}
	clientSide.Close()
}
// TestReadLoopUnexpectedCloseError starts readLoop in a goroutine, has the
// client send a close frame with code CloseProtocolError, and expects
// readLoop to return within two seconds.
func TestReadLoopUnexpectedCloseError(t *testing.T) {
	srv := NewServer(nil)
	serverSide, clientSide, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.mu.Unlock()

	loopExited := make(chan struct{})
	go func() {
		srv.readLoop(agentC)
		close(loopExited)
	}()

	// Best-effort write: the outcome is judged by readLoop exiting below.
	closeFrame := websocket.FormatCloseMessage(websocket.CloseProtocolError, "bye")
	_ = clientSide.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("readLoop did not exit")
	}
}
// TestReadLoopCommandResultBranches feeds readLoop a sequence of frames
// (truncated JSON, a command_result with a numeric request_id, one for the
// pre-registered "req-full" request, and one for an unregistered request),
// closes the client, and expects readLoop to return within two seconds.
//
// The "req-full" pending channel is unbuffered and never received from.
func TestReadLoopCommandResultBranches(t *testing.T) {
	srv := NewServer(nil)
	serverSide, clientSide, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}

	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.pendingReqs[pendingRequestKey("a1", "req-full")] = make(chan CommandResultPayload)
	srv.mu.Unlock()

	loopExited := make(chan struct{})
	go func() {
		srv.readLoop(agentC)
		close(loopExited)
	}()

	// Writes are best-effort; the test only asserts that readLoop terminates.
	frames := []string{
		"{",
		`{"type":"command_result","payload":{"request_id":123}}`,
		`{"type":"command_result","payload":{"request_id":"req-full","success":true}}`,
		`{"type":"command_result","payload":{"request_id":"req-missing","success":true}}`,
	}
	for _, frame := range frames {
		_ = clientSide.WriteMessage(websocket.TextMessage, []byte(frame))
	}
	clientSide.Close()

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("readLoop did not exit")
	}

	srv.mu.Lock()
	delete(srv.pendingReqs, pendingRequestKey("a1", "req-full"))
	srv.mu.Unlock()
}
// TestReadLoopAgentPingPongSendFailure swaps the package-level
// writeTextMessage hook for one that always fails, sends an agent ping to
// readLoop, closes the client, and expects readLoop to return within two
// seconds. It also checks that nothing was delivered to a pending-request
// channel registered under a different agent ID ("a2").
func TestReadLoopAgentPingPongSendFailure(t *testing.T) {
	savedWriter := writeTextMessage
	t.Cleanup(func() { writeTextMessage = savedWriter })
	writeTextMessage = func(*websocket.Conn, []byte) error {
		return errors.New("write failure")
	}

	srv := NewServer(nil)
	serverSide, clientSide, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	otherAgentCh := make(chan CommandResultPayload, 1)

	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.pendingReqs[pendingRequestKey("a2", "req-shared")] = otherAgentCh
	srv.mu.Unlock()

	loopExited := make(chan struct{})
	go func() {
		srv.readLoop(agentC)
		close(loopExited)
	}()

	ping := Message{
		Type:      MsgTypeAgentPing,
		Timestamp: time.Now(),
	}
	wsWriteMessage(t, clientSide, ping)
	clientSide.Close()

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("readLoop did not exit")
	}

	// Non-blocking probe: the other agent's channel must still be empty.
	select {
	case <-otherAgentCh:
		t.Fatalf("expected result not to be delivered across agents")
	default:
	}
}
// TestPingLoopSuccessAndStop shortens the package-level pingInterval to 5ms,
// runs pingLoop in a goroutine for two intervals, closes the stop channel,
// and expects pingLoop to return within two seconds.
func TestPingLoopSuccessAndStop(t *testing.T) {
	savedInterval := pingInterval
	t.Cleanup(func() { pingInterval = savedInterval })
	pingInterval = 5 * time.Millisecond

	srv := NewServer(nil)
	serverSide, _, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}

	stopCh := make(chan struct{})
	loopExited := make(chan struct{})
	go func() {
		srv.pingLoop(agentC, stopCh)
		close(loopExited)
	}()

	// Let the loop run for a couple of intervals before asking it to stop.
	time.Sleep(2 * pingInterval)
	close(stopCh)

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("pingLoop did not exit")
	}
}
// TestPingLoopFailuresClose shortens the package-level pingInterval to 5ms,
// closes the server-side connection before starting pingLoop, and expects
// pingLoop to return on its own within two seconds. The stop channel is never
// closed.
func TestPingLoopFailuresClose(t *testing.T) {
	savedInterval := pingInterval
	t.Cleanup(func() { pingInterval = savedInterval })
	pingInterval = 5 * time.Millisecond

	srv := NewServer(nil)
	serverSide, _, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	// The connection is closed up front, before the loop starts.
	serverSide.Close()

	stopCh := make(chan struct{})
	loopExited := make(chan struct{})
	go func() {
		srv.pingLoop(agentC, stopCh)
		close(loopExited)
	}()

	select {
	case <-loopExited:
	case <-time.After(2 * time.Second):
		t.Fatalf("pingLoop did not exit after failures")
	}
}
// TestSendMessageMarshalError calls sendMessage with a nil agent connection
// and a Message whose Payload is the invalid raw JSON "{", and expects an
// error to be returned.
func TestSendMessageMarshalError(t *testing.T) {
	srv := NewServer(nil)
	sendErr := srv.sendMessage(nil, Message{Payload: json.RawMessage("{")})
	if sendErr == nil {
		t.Fatalf("expected marshal error")
	}
}
// TestExecuteCommandSendError registers an agent whose server-side connection
// has already been closed and expects ExecuteCommand to return an error.
func TestExecuteCommandSendError(t *testing.T) {
	srv := NewServer(nil)
	serverSide, _, cleanup := newConnPair(t)
	defer cleanup()

	// Close first so that writing to this connection cannot succeed.
	serverSide.Close()
	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}

	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.mu.Unlock()

	request := ExecuteCommandPayload{
		RequestID: "r1",
		Command:   "echo ok",
		Timeout:   1,
	}
	_, execErr := srv.ExecuteCommand(context.Background(), "a1", request)
	if execErr == nil {
		t.Fatalf("expected send error")
	}
}
// TestExecuteCommandTimeoutAndCancel issues two commands to an agent whose
// peer never replies. The first uses a background context with Timeout 1 and
// must fail with an error containing "timed out"; the second uses an
// already-cancelled context and must fail with some error.
func TestExecuteCommandTimeoutAndCancel(t *testing.T) {
	srv := NewServer(nil)
	serverSide, _, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.mu.Unlock()

	timeoutReq := ExecuteCommandPayload{
		RequestID: "r-timeout",
		Command:   "echo ok",
		Timeout:   1,
	}
	_, timeoutErr := srv.ExecuteCommand(context.Background(), "a1", timeoutReq)
	if timeoutErr == nil || !strings.Contains(timeoutErr.Error(), "timed out") {
		t.Fatalf("expected timeout error, got %v", timeoutErr)
	}

	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()
	cancelReq := ExecuteCommandPayload{
		RequestID: "r-cancel",
		Command:   "echo ok",
		Timeout:   1,
	}
	_, cancelErr := srv.ExecuteCommand(cancelledCtx, "a1", cancelReq)
	if cancelErr == nil {
		t.Fatalf("expected cancel error")
	}
}
// TestExecuteCommandDefaultTimeout calls ExecuteCommand without setting
// Timeout and expects a successful result. A helper goroutine plays the agent:
// it polls pendingReqs for the "r-default" request and, once the channel
// appears, sends a successful CommandResultPayload on it.
//
// The helper goroutine is bound to the test's lifetime through a stop channel
// so it cannot spin or block forever if ExecuteCommand fails before
// registering the request or returns before the result is consumed.
func TestExecuteCommandDefaultTimeout(t *testing.T) {
	s := NewServer(nil)
	serverConn, _, cleanup := newConnPair(t)
	defer cleanup()

	ac := &agentConn{
		conn:  serverConn,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	s.mu.Lock()
	s.agents["a1"] = ac
	s.mu.Unlock()

	// Closed when the test function returns (including via t.Fatalf), which
	// releases the responder goroutine below.
	stop := make(chan struct{})
	defer close(stop)

	go func() {
		poll := time.NewTicker(2 * time.Millisecond)
		defer poll.Stop()
		for {
			s.mu.RLock()
			ch := s.pendingReqs[pendingRequestKey("a1", "r-default")]
			s.mu.RUnlock()
			if ch != nil {
				// Give up on the send if the test has already finished.
				select {
				case ch <- CommandResultPayload{RequestID: "r-default", Success: true}:
				case <-stop:
				}
				return
			}
			select {
			case <-stop:
				return
			case <-poll.C:
			}
		}
	}()

	result, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{
		RequestID: "r-default",
		Command:   "echo ok",
	})
	if err != nil || result == nil || !result.Success {
		t.Fatalf("expected success, got result=%v err=%v", result, err)
	}
}
// TestReadFileRoundTrip registers agent "a1" over a real websocket, runs a
// goroutine that answers the first read-file message with a successful
// command result whose Stdout is "data", and checks that Server.ReadFile
// returns that result.
//
// NOTE(review): mustNewMessage(t, ...) is invoked from the agent goroutine;
// if it can call t.Fatal, that would run off the test goroutine — confirm.
func TestReadFileRoundTrip(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := AgentRegisterPayload{
		AgentID:  "a1",
		Hostname: "host1",
		Token:    "any",
	}
	wsWriteMessage(t, client, mustNewMessage(t, MsgTypeAgentRegister, "", registration))
	_ = wsReadRegisteredPayload(t, client)

	// Buffered so the agent goroutine can always report and exit.
	agentResult := make(chan error, 1)
	go func() {
		for {
			incoming, readErr := wsReadRawMessageWithTimeout(client, 2*time.Second)
			if readErr != nil {
				agentResult <- readErr
				return
			}
			// Skip anything that is not a read-file request with a payload.
			if incoming.Type == MsgTypeReadFile && incoming.Payload != nil {
				var request ReadFilePayload
				if decodeErr := json.Unmarshal(*incoming.Payload, &request); decodeErr != nil {
					agentResult <- decodeErr
					return
				}
				reply := mustNewMessage(t, MsgTypeCommandResult, "", CommandResultPayload{
					RequestID: request.RequestID,
					Success:   true,
					Stdout:    "data",
					ExitCode:  0,
				})
				agentResult <- client.WriteJSON(reply)
				return
			}
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, readFileErr := srv.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-1", Path: "/etc/hosts"})
	if readFileErr != nil || result == nil || result.Stdout != "data" {
		t.Fatalf("unexpected read file result=%v err=%v", result, readFileErr)
	}
	if agentErr := <-agentResult; agentErr != nil {
		t.Fatalf("agent error: %v", agentErr)
	}
}
// TestReadFileTimeoutCancelAndSendError shortens the package-level
// readFileTimeout to 10ms and expects ReadFile to return an error in three
// situations, in order: no reply from the peer, an already-cancelled context,
// and a server-side connection that has been closed.
func TestReadFileTimeoutCancelAndSendError(t *testing.T) {
	savedTimeout := readFileTimeout
	t.Cleanup(func() { readFileTimeout = savedTimeout })
	readFileTimeout = 10 * time.Millisecond

	srv := NewServer(nil)
	serverSide, _, cleanup := newConnPair(t)
	defer cleanup()

	agentC := &agentConn{
		conn:  serverSide,
		agent: ConnectedAgent{AgentID: "a1"},
		done:  make(chan struct{}),
	}
	srv.mu.Lock()
	srv.agents["a1"] = agentC
	srv.mu.Unlock()

	// 1. Nobody answers.
	_, timeoutErr := srv.ReadFile(context.Background(), "a1", ReadFilePayload{
		RequestID: "read-timeout",
		Path:      "/etc/hosts",
	})
	if timeoutErr == nil {
		t.Fatalf("expected timeout error")
	}

	// 2. The context is cancelled before the call.
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()
	_, cancelErr := srv.ReadFile(cancelledCtx, "a1", ReadFilePayload{
		RequestID: "read-cancel",
		Path:      "/etc/hosts",
	})
	if cancelErr == nil {
		t.Fatalf("expected cancel error")
	}

	// 3. The underlying connection is closed.
	serverSide.Close()
	_, sendErr := srv.ReadFile(context.Background(), "a1", ReadFilePayload{
		RequestID: "read-send",
		Path:      "/etc/hosts",
	})
	if sendErr == nil {
		t.Fatalf("expected send error")
	}
}
// TestShutdownRejectsNewWebSocketConnections shuts the server down before
// exposing it over HTTP and expects a websocket dial to fail with an HTTP 503
// (Service Unavailable) response.
func TestShutdownRejectsNewWebSocketConnections(t *testing.T) {
	srv := NewServer(nil)
	srv.Shutdown()

	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, resp, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if client != nil {
		client.Close()
	}

	switch {
	case dialErr == nil:
		t.Fatalf("expected websocket dial to fail during shutdown")
	case resp == nil:
		t.Fatalf("expected HTTP response when dial fails")
	}
	defer resp.Body.Close()

	if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
}
// TestShutdownClosesActiveConnectionsAndIsIdempotent registers agent "a1",
// calls Shutdown twice, and expects the agent to be dropped and the client's
// next read to return an error within 500ms.
//
// NOTE(review): a read-deadline expiry also satisfies the final read check.
func TestShutdownClosesActiveConnectionsAndIsIdempotent(t *testing.T) {
	srv := NewServer(nil)
	httpSrv := newWSServer(t, srv)
	defer httpSrv.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURLForHTTP(httpSrv.URL), nil)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer client.Close()

	registration := AgentRegisterPayload{
		AgentID:  "a1",
		Hostname: "host1",
		Token:    "any",
	}
	wsWriteMessage(t, client, mustNewMessage(t, MsgTypeAgentRegister, "", registration))
	_ = wsReadRegisteredPayload(t, client)

	if connected := srv.IsAgentConnected("a1"); !connected {
		t.Fatalf("expected agent to be connected")
	}

	// The second call must be harmless.
	srv.Shutdown()
	srv.Shutdown()

	waitFor(t, 2*time.Second, func() bool { return !srv.IsAgentConnected("a1") })

	_ = client.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	_, _, readErr := client.ReadMessage()
	if readErr == nil {
		t.Fatalf("expected connection to be closed after shutdown")
	}
}
// TestExecuteCommandAndReadFileReturnShutdownError checks, in two subtests,
// that an in-flight ExecuteCommand and an in-flight ReadFile each return an
// error matching errServerShuttingDown within two seconds of Shutdown being
// called.
//
// NOTE(review): each subtest sleeps 20ms so the request is presumably in
// flight before Shutdown; that ordering is timing-based, not synchronized.
func TestExecuteCommandAndReadFileReturnShutdownError(t *testing.T) {
	t.Run("execute_command", func(t *testing.T) {
		srv := NewServer(nil)
		serverSide, _, cleanup := newConnPair(t)
		defer cleanup()

		agentC := &agentConn{
			conn:  serverSide,
			agent: ConnectedAgent{AgentID: "a1"},
			done:  make(chan struct{}),
		}
		srv.mu.Lock()
		srv.agents["a1"] = agentC
		srv.mu.Unlock()

		// Buffered so the caller goroutine can exit even if nobody receives.
		outcome := make(chan error, 1)
		go func() {
			request := ExecuteCommandPayload{RequestID: "r-shutdown", Command: "echo test", Timeout: 60}
			_, callErr := srv.ExecuteCommand(context.Background(), "a1", request)
			outcome <- callErr
		}()

		time.Sleep(20 * time.Millisecond)
		srv.Shutdown()

		select {
		case got := <-outcome:
			if !errors.Is(got, errServerShuttingDown) {
				t.Fatalf("expected shutdown error, got %v", got)
			}
		case <-time.After(2 * time.Second):
			t.Fatalf("execute command did not unblock on shutdown")
		}
	})

	t.Run("read_file", func(t *testing.T) {
		srv := NewServer(nil)
		serverSide, _, cleanup := newConnPair(t)
		defer cleanup()

		agentC := &agentConn{
			conn:  serverSide,
			agent: ConnectedAgent{AgentID: "a1"},
			done:  make(chan struct{}),
		}
		srv.mu.Lock()
		srv.agents["a1"] = agentC
		srv.mu.Unlock()

		// Buffered so the caller goroutine can exit even if nobody receives.
		outcome := make(chan error, 1)
		go func() {
			request := ReadFilePayload{RequestID: "read-shutdown", Path: "/tmp/test"}
			_, callErr := srv.ReadFile(context.Background(), "a1", request)
			outcome <- callErr
		}()

		time.Sleep(20 * time.Millisecond)
		srv.Shutdown()

		select {
		case got := <-outcome:
			if !errors.Is(got, errServerShuttingDown) {
				t.Fatalf("expected shutdown error, got %v", got)
			}
		case <-time.After(2 * time.Second):
			t.Fatalf("read file did not unblock on shutdown")
		}
	})
}