package agentexec import ( "context" "encoding/json" "errors" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/gorilla/websocket" ) type noHijackResponseWriter struct { header http.Header } func (w *noHijackResponseWriter) Header() http.Header { return w.header } func (w *noHijackResponseWriter) Write([]byte) (int, error) { return 0, nil } func (w *noHijackResponseWriter) WriteHeader(int) {} func newConnPair(t *testing.T) (*websocket.Conn, *websocket.Conn, func()) { t.Helper() serverConnCh := make(chan *websocket.Conn, 1) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Errorf("upgrade: %v", err) return } serverConnCh <- conn })) clientConn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { ts.Close() t.Fatalf("Dial: %v", err) } var serverConn *websocket.Conn select { case serverConn = <-serverConnCh: case <-time.After(2 * time.Second): clientConn.Close() ts.Close() t.Fatal("timed out waiting for server connection") } cleanup := func() { clientConn.Close() serverConn.Close() ts.Close() } return serverConn, clientConn, cleanup } func TestHandleWebSocket_UpgradeFailureAndDeadlineErrors(t *testing.T) { s := NewServer(nil) req := httptest.NewRequest(http.MethodGet, "http://example/ws", nil) s.HandleWebSocket(&noHijackResponseWriter{header: make(http.Header)}, req) } func TestHandleWebSocket_RegistrationReadError(t *testing.T) { s := NewServer(nil) ts := newWSServer(t, s) defer ts.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { t.Fatalf("Dial: %v", err) } conn.Close() } func TestHandleWebSocket_RegistrationMessageJSONError(t *testing.T) { s := NewServer(nil) ts := newWSServer(t, s) defer ts.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { t.Fatalf("Dial: %v", err) } defer conn.Close() if err := conn.WriteMessage(websocket.TextMessage, []byte("{")); err != nil { t.Fatalf("WriteMessage: %v", err) } conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) if _, _, err := conn.ReadMessage(); err == nil { t.Fatalf("expected server to close on invalid JSON") } } func TestHandleWebSocket_RegistrationPayloadMarshalError(t *testing.T) { orig := jsonMarshal t.Cleanup(func() { jsonMarshal = orig }) jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("boom") } s := NewServer(nil) ts := newWSServer(t, s) defer ts.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { t.Fatalf("Dial: %v", err) } defer conn.Close() wsWriteMessage(t, conn, Message{ Type: MsgTypeAgentRegister, Timestamp: time.Now(), Payload: AgentRegisterPayload{ AgentID: "a1", Hostname: "host1", Token: "any", }, }) conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) if _, _, err := conn.ReadMessage(); err == nil { t.Fatalf("expected server to close on marshal error") } } func TestHandleWebSocket_RegistrationPayloadUnmarshalError(t *testing.T) { s := NewServer(nil) ts := newWSServer(t, s) defer ts.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { t.Fatalf("Dial: %v", err) } defer conn.Close() if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"agent_register","payload":"oops"}`)); err != nil { t.Fatalf("WriteMessage: %v", err) } conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) if _, _, err := conn.ReadMessage(); err == nil { t.Fatalf("expected server to close on invalid payload") } } func TestHandleWebSocket_PongHandler(t *testing.T) { s := NewServer(nil) ts := newWSServer(t, s) defer ts.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { t.Fatalf("Dial: %v", err) } defer conn.Close() wsWriteMessage(t, conn, Message{ Type: MsgTypeAgentRegister, Timestamp: time.Now(), Payload: AgentRegisterPayload{ AgentID: "a1", Hostname: "host1", Token: "any", }, }) _ = wsReadRegisteredPayload(t, conn) if err := conn.WriteControl(websocket.PongMessage, []byte("pong"), time.Now().Add(time.Second)); err != nil { t.Fatalf("WriteControl pong: %v", err) } conn.Close() waitFor(t, 2*time.Second, func() bool { return !s.IsAgentConnected("a1") }) } func TestReadLoopDone(t *testing.T) { s := NewServer(nil) serverConn, clientConn, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } close(ac.done) s.mu.Lock() s.agents["a1"] = ac s.mu.Unlock() s.readLoop(ac) if s.IsAgentConnected("a1") { t.Fatalf("expected agent to be removed") } clientConn.Close() } func TestReadLoopUnexpectedCloseError(t *testing.T) { s := NewServer(nil) serverConn, clientConn, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } s.mu.Lock() s.agents["a1"] = ac s.mu.Unlock() done := make(chan struct{}) go func() { s.readLoop(ac) close(done) }() _ = clientConn.WriteControl( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseProtocolError, "bye"), time.Now().Add(time.Second), ) select { case <-done: case <-time.After(2 * time.Second): t.Fatalf("readLoop did not exit") } } func TestReadLoopCommandResultBranches(t *testing.T) { s := NewServer(nil) serverConn, clientConn, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } s.mu.Lock() s.agents["a1"] = ac s.pendingReqs["req-full"] = make(chan CommandResultPayload) s.mu.Unlock() done := make(chan struct{}) go func() { s.readLoop(ac) close(done) }() _ = clientConn.WriteMessage(websocket.TextMessage, []byte("{")) _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":123}}`)) _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-full","success":true}}`)) _ = clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"command_result","payload":{"request_id":"req-missing","success":true}}`)) clientConn.Close() select { case <-done: case <-time.After(2 * time.Second): t.Fatalf("readLoop did not exit") } s.mu.Lock() delete(s.pendingReqs, "req-full") s.mu.Unlock() } func TestPingLoopSuccessAndStop(t *testing.T) { origInterval := pingInterval t.Cleanup(func() { pingInterval = origInterval }) pingInterval = 5 * time.Millisecond s := NewServer(nil) serverConn, _, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } stop := make(chan struct{}) exited := make(chan struct{}) go func() { s.pingLoop(ac, stop) close(exited) }() time.Sleep(2 * pingInterval) close(stop) select { case <-exited: case <-time.After(2 * time.Second): t.Fatalf("pingLoop did not exit") } } func TestPingLoopFailuresClose(t *testing.T) { origInterval := pingInterval t.Cleanup(func() { pingInterval = origInterval }) pingInterval = 5 * time.Millisecond s := NewServer(nil) serverConn, _, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } serverConn.Close() stop := make(chan struct{}) exited := make(chan struct{}) go func() { s.pingLoop(ac, stop) close(exited) }() select { case <-exited: case <-time.After(2 * time.Second): t.Fatalf("pingLoop did not exit after failures") } } func TestSendMessageMarshalError(t *testing.T) { s := NewServer(nil) if err := s.sendMessage(nil, Message{Payload: make(chan int)}); err == nil { t.Fatalf("expected marshal error") } } func TestExecuteCommandSendError(t *testing.T) { s := NewServer(nil) serverConn, _, cleanup := newConnPair(t) defer cleanup() serverConn.Close() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } s.mu.Lock() s.agents["a1"] = ac s.mu.Unlock() _, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r1", Timeout: 1}) if err == nil { t.Fatalf("expected send error") } } func TestExecuteCommandTimeoutAndCancel(t *testing.T) { s := NewServer(nil) serverConn, _, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } s.mu.Lock() s.agents["a1"] = ac s.mu.Unlock() _, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-timeout", Timeout: 1}) if err == nil || !strings.Contains(err.Error(), "timed out") { t.Fatalf("expected timeout error, got %v", err) } ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = s.ExecuteCommand(ctx, "a1", ExecuteCommandPayload{RequestID: "r-cancel", Timeout: 1}) if err == nil { t.Fatalf("expected cancel error") } } func TestExecuteCommandDefaultTimeout(t *testing.T) { s := NewServer(nil) serverConn, _, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } s.mu.Lock() s.agents["a1"] = ac s.mu.Unlock() go func() { for { s.mu.RLock() ch := s.pendingReqs["r-default"] s.mu.RUnlock() if ch != nil { ch <- CommandResultPayload{RequestID: "r-default", Success: true} return } time.Sleep(2 * time.Millisecond) } }() result, err := s.ExecuteCommand(context.Background(), "a1", ExecuteCommandPayload{RequestID: "r-default"}) if err != nil || result == nil || !result.Success { t.Fatalf("expected success, got result=%v err=%v", result, err) } } func TestReadFileRoundTrip(t *testing.T) { s := NewServer(nil) ts := newWSServer(t, s) defer ts.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(ts.URL), nil) if err != nil { t.Fatalf("Dial: %v", err) } defer conn.Close() wsWriteMessage(t, conn, Message{ Type: MsgTypeAgentRegister, Timestamp: time.Now(), Payload: AgentRegisterPayload{ AgentID: "a1", Hostname: "host1", Token: "any", }, }) _ = wsReadRegisteredPayload(t, conn) agentDone := make(chan error, 1) go func() { for { msg, err := wsReadRawMessageWithTimeout(conn, 2*time.Second) if err != nil { agentDone <- err return } if msg.Type != MsgTypeReadFile || msg.Payload == nil { continue } var payload ReadFilePayload if err := json.Unmarshal(*msg.Payload, &payload); err != nil { agentDone <- err return } agentDone <- conn.WriteJSON(Message{ Type: MsgTypeCommandResult, Timestamp: time.Now(), Payload: CommandResultPayload{ RequestID: payload.RequestID, Success: true, Stdout: "data", ExitCode: 0, }, }) return } }() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() result, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-1", Path: "/etc/hosts"}) if err != nil || result == nil || result.Stdout != "data" { t.Fatalf("unexpected read file result=%v err=%v", result, err) } if err := <-agentDone; err != nil { t.Fatalf("agent error: %v", err) } } func TestReadFileTimeoutCancelAndSendError(t *testing.T) { origTimeout := readFileTimeout t.Cleanup(func() { readFileTimeout = origTimeout }) readFileTimeout = 10 * time.Millisecond s := NewServer(nil) serverConn, _, cleanup := newConnPair(t) defer cleanup() ac := &agentConn{ conn: serverConn, agent: ConnectedAgent{AgentID: "a1"}, done: make(chan struct{}), } s.mu.Lock() s.agents["a1"] = ac s.mu.Unlock() if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-timeout"}); err == nil { t.Fatalf("expected timeout error") } ctx, cancel := context.WithCancel(context.Background()) cancel() if _, err := s.ReadFile(ctx, "a1", ReadFilePayload{RequestID: "read-cancel"}); err == nil { t.Fatalf("expected cancel error") } serverConn.Close() if _, err := s.ReadFile(context.Background(), "a1", ReadFilePayload{RequestID: "read-send"}); err == nil { t.Fatalf("expected send error") } }