// Pulse/internal/ai/tools/control_run_command_test.go
// 2026-03-19 14:56:02 +00:00
//
// 378 lines
// 13 KiB
// Go

package tools
import (
"context"
"encoding/json"
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/agentexec"
"github.com/rcourtman/pulse-go-rewrite/internal/ai/approval"
"github.com/rcourtman/pulse-go-rewrite/internal/models"
"github.com/rcourtman/pulse-go-rewrite/internal/unifiedresources"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// mustParseJSONMap decodes text as a JSON object into a generic map and
// fails the calling test immediately if decoding returns an error.
func mustParseJSONMap(t *testing.T, text string) map[string]interface{} {
	t.Helper()

	var parsed map[string]interface{}
	decodeErr := json.Unmarshal([]byte(text), &parsed)
	require.NoError(t, decodeErr)

	return parsed
}
// mustParseApprovalPayload extracts and decodes the JSON payload from a tool
// response of the form "APPROVAL_REQUIRED: <json>".
//
// It fails the calling test immediately if text lacks the prefix or the
// remainder is not a valid JSON object. The prefix failure message includes
// the actual text, so an unexpected response can be diagnosed from test output.
func mustParseApprovalPayload(t *testing.T, text string) map[string]interface{} {
	t.Helper()
	const prefix = "APPROVAL_REQUIRED: "
	require.True(t, strings.HasPrefix(text, prefix), "expected %q prefix, got %q", prefix, text)
	return mustParseJSONMap(t, strings.TrimPrefix(text, prefix))
}
// TestPulseToolExecutor_ExecuteRunCommand exercises executeRunCommand across:
//   - argument validation
//   - policy blocking
//   - target-host disambiguation
//   - approval gating at ControlLevelControlled
//   - a successful run that records an action audit with lifecycle events
//
// Each subtest checks the executor error with require (not assert) and checks
// that result.Content is non-empty before indexing it. A broken call then
// fails the subtest cleanly instead of panicking on a nil/out-of-range access
// that would hide the real cause.
func TestPulseToolExecutor_ExecuteRunCommand(t *testing.T) {
	ctx := context.Background()

	t.Run("MissingCommand", func(t *testing.T) {
		exec := NewPulseToolExecutor(ExecutorConfig{})
		result, err := exec.executeRunCommand(ctx, map[string]interface{}{})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)
		assert.True(t, result.IsError)
		assert.Contains(t, result.Content[0].Text, "command is required")
	})

	t.Run("PolicyBlocked", func(t *testing.T) {
		policy := &mockCommandPolicy{}
		policy.On("Evaluate", "rm -rf /").Return(agentexec.PolicyBlock).Once()
		exec := NewPulseToolExecutor(ExecutorConfig{Policy: policy})
		result, err := exec.executeRunCommand(ctx, map[string]interface{}{
			"command": "rm -rf /",
		})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)
		// A blocked command is expected as a non-error result whose text
		// carries a POLICY_BLOCKED marker.
		assert.False(t, result.IsError)
		assert.Contains(t, result.Content[0].Text, "POLICY_BLOCKED")
		policy.AssertExpectations(t)
	})

	t.Run("TargetHostRequired", func(t *testing.T) {
		// With more than one agent connected and no target_host, the tool is
		// expected to report the ambiguity rather than pick an agent itself.
		agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
			{AgentID: "a1", Hostname: "node1"},
			{AgentID: "a2", Hostname: "node2"},
		}}
		exec := NewPulseToolExecutor(ExecutorConfig{AgentServer: agentSrv})
		result, err := exec.executeRunCommand(ctx, map[string]interface{}{
			"command": "ls",
		})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)
		assert.Contains(t, result.Content[0].Text, "Multiple agents are connected")
	})

	t.Run("ControlledRequiresApproval", func(t *testing.T) {
		// Reset the package-level approval store. In controlled mode the
		// command must still be gated behind APPROVAL_REQUIRED.
		approval.SetStore(nil)
		agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
			{AgentID: "agent-1", Hostname: "tower"},
		}}
		exec := NewPulseToolExecutor(ExecutorConfig{
			AgentServer:  agentSrv,
			ControlLevel: ControlLevelControlled,
		})
		result, err := exec.executeRunCommand(ctx, map[string]interface{}{
			"command":     "ls",
			"target_host": "tower",
		})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)
		assert.Contains(t, result.Content[0].Text, "APPROVAL_REQUIRED")
	})

	t.Run("ControlledApprovalUsesResolvedRoutingTarget", func(t *testing.T) {
		store, err := approval.NewStore(approval.StoreConfig{
			DataDir:            t.TempDir(),
			DisablePersistence: true,
		})
		require.NoError(t, err)
		approval.SetStore(store)
		defer approval.SetStore(nil)

		agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
			{AgentID: "agent-1", Hostname: "tower"},
		}}
		exec := NewPulseToolExecutor(ExecutorConfig{
			AgentServer:  agentSrv,
			ControlLevel: ControlLevelControlled,
		})
		// Session context target must not influence command approval binding.
		exec.SetContext("host", "session-target", false)

		result, err := exec.executeRunCommand(ctx, map[string]interface{}{
			"command":     "uptime",
			"target_host": "tower",
		})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)

		payload := mustParseApprovalPayload(t, result.Content[0].Text)
		approvalID, _ := payload["approval_id"].(string)
		require.NotEmpty(t, approvalID)

		// The stored approval must be bound to the agent resolved from
		// target_host, not to the session context set above.
		req, found := store.GetApproval(approvalID)
		require.True(t, found)
		assert.Equal(t, "agent", req.TargetType)
		assert.Equal(t, "agent-1", req.TargetID)
		assert.Equal(t, "tower", req.TargetName)
	})

	t.Run("ControlledConsumesApprovedCommandWithResolvedRoutingTarget", func(t *testing.T) {
		store, err := approval.NewStore(approval.StoreConfig{
			DataDir:            t.TempDir(),
			DisablePersistence: true,
		})
		require.NoError(t, err)
		approval.SetStore(store)
		defer approval.SetStore(nil)

		// Pre-approve the exact command/target pair the executor will resolve.
		req := &approval.ApprovalRequest{
			ID:         "approval-1",
			Command:    "uptime",
			TargetType: "agent",
			TargetID:   "agent-1",
		}
		require.NoError(t, store.CreateApproval(req))
		_, err = store.Approve("approval-1", "tester")
		require.NoError(t, err)

		agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
			{AgentID: "agent-1", Hostname: "tower"},
		}}
		agentSrv.On("GetConnectedAgents").Return([]agentexec.ConnectedAgent{
			{AgentID: "agent-1", Hostname: "tower"},
		}).Maybe()
		agentSrv.On("ExecuteCommand", mock.Anything, "agent-1", mock.MatchedBy(func(payload agentexec.ExecuteCommandPayload) bool {
			return payload.Command == "uptime" && payload.TargetType == "agent" && payload.TargetID == ""
		})).Return(&agentexec.CommandResultPayload{
			Stdout:   "ok",
			ExitCode: 0,
		}, nil).Once()

		exec := NewPulseToolExecutor(ExecutorConfig{
			AgentServer:  agentSrv,
			ControlLevel: ControlLevelControlled,
		})
		exec.SetContext("host", "different-session-target", false)

		result, err := exec.executeRunCommand(ctx, map[string]interface{}{
			"command":      "uptime",
			"target_host":  "tower",
			"_approval_id": "approval-1",
		})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)

		resp := mustParseJSONMap(t, result.Content[0].Text)
		assert.Equal(t, true, resp["success"])
		assert.Equal(t, float64(0), resp["exit_code"])
		agentSrv.AssertExpectations(t)

		// A successful run must mark the approval as consumed.
		consumed, found := store.GetApproval("approval-1")
		require.True(t, found)
		assert.True(t, consumed.Consumed)
	})

	t.Run("ExecuteSuccess", func(t *testing.T) {
		store := unifiedresources.NewMemoryStore()
		agentSrv := &mockAgentServer{}
		agentSrv.On("GetConnectedAgents").Return([]agentexec.ConnectedAgent{
			{AgentID: "agent1", Hostname: "node1"},
		}).Twice()
		agentSrv.On("ExecuteCommand", mock.Anything, "agent1", mock.MatchedBy(func(payload agentexec.ExecuteCommandPayload) bool {
			// For direct agent targets, TargetID is empty - resolveTargetForCommand returns "" for agent type
			return payload.Command == "uptime" && payload.TargetType == "agent" && payload.TargetID == ""
		})).Return(&agentexec.CommandResultPayload{
			Stdout:   "ok",
			ExitCode: 0,
		}, nil).Once()

		exec := NewPulseToolExecutor(ExecutorConfig{
			AgentServer:      agentSrv,
			ActionAuditStore: store,
		})
		exec.SetContext("host", "host1", false)

		result, err := exec.executeRunCommand(ctx, map[string]interface{}{
			"command":     "uptime",
			"run_on_host": true,
		})
		require.NoError(t, err)
		require.NotEmpty(t, result.Content)

		resp := mustParseJSONMap(t, result.Content[0].Text)
		assert.Equal(t, true, resp["success"])
		assert.Equal(t, float64(0), resp["exit_code"])
		output, ok := resp["output"].(string)
		require.True(t, ok, "expected string output field, got %T", resp["output"])
		assert.Contains(t, output, "ok")
		// Verification is optional in the response; when present it must report success.
		if v, ok := resp["verification"].(map[string]interface{}); ok {
			assert.Equal(t, true, v["ok"])
		}
		agentSrv.AssertExpectations(t)

		audits, err := store.GetActionAudits("", time.Time{}, 10)
		require.NoError(t, err)
		require.Len(t, audits, 1)
		assert.Equal(t, "pulse_control", audits[0].Request.CapabilityName)
		assert.Contains(t, audits[0].Plan.Message, "run command \"uptime\"")
		assert.Empty(t, audits[0].Plan.ResourceVersion)
		assert.Empty(t, audits[0].Plan.PolicyVersion)
		assert.NotEmpty(t, audits[0].Plan.PlanHash)

		// The recorded plan must serialize to a JSON object.
		planJSON, err := json.Marshal(audits[0].Plan)
		require.NoError(t, err)
		_ = mustParseJSONMap(t, string(planJSON))

		// The assertions below expect lifecycle events newest-first: the
		// terminal Completed state at index 0 and the initial Planned state last.
		events, err := store.GetActionLifecycleEvents(audits[0].ID, time.Time{}, 10)
		require.NoError(t, err)
		require.Len(t, events, 3)
		assert.Equal(t, unifiedresources.ActionStatePlanned, events[2].State)
		assert.Equal(t, unifiedresources.ActionStateExecuting, events[1].State)
		assert.Equal(t, unifiedresources.ActionStateCompleted, events[0].State)
	})
}
// TestPulseToolExecutor_RunCommandLXCRouting checks that target_host is
// resolved to the right agent and target type/ID:
//   - LXC containers, resolved by guest name
//   - VMs, resolved by guest name
//   - plain hosts, resolved by agent hostname and canonicalized to "agent"
//
// The tool is expected to forward the raw command unchanged; any sh -c
// wrapping is the agent's job.
func TestPulseToolExecutor_RunCommandLXCRouting(t *testing.T) {
	ctx := context.Background()

	routingCases := []struct {
		name       string
		agent      agentexec.ConnectedAgent
		state      models.StateSnapshot
		command    string
		targetHost string
		stdout     string
		accepts    func(agentexec.ExecuteCommandPayload) bool
	}{
		{
			name:  "LXCCommandRoutedCorrectly",
			agent: agentexec.ConnectedAgent{AgentID: "proxmox-agent", Hostname: "pve-node"},
			state: models.StateSnapshot{
				Containers: []models.Container{
					{VMID: 108, Name: "jellyfin", Node: "pve-node"},
				},
			},
			command:    "grep pattern /var/log/*.log",
			targetHost: "jellyfin",
			stdout:     "matched line",
			accepts: func(p agentexec.ExecuteCommandPayload) bool {
				return p.TargetType == "container" &&
					p.TargetID == "108" &&
					p.Command == "grep pattern /var/log/*.log"
			},
		},
		{
			name:  "VMCommandRoutedCorrectly",
			agent: agentexec.ConnectedAgent{AgentID: "proxmox-agent", Hostname: "pve-node"},
			state: models.StateSnapshot{
				VMs: []models.VM{
					{VMID: 100, Name: "test-vm", Node: "pve-node"},
				},
			},
			command:    "ls /tmp/*.txt",
			targetHost: "test-vm",
			stdout:     "result",
			accepts: func(p agentexec.ExecuteCommandPayload) bool {
				return p.TargetType == "vm" &&
					p.TargetID == "100" &&
					p.Command == "ls /tmp/*.txt"
			},
		},
		{
			name:       "DirectHostRoutedCorrectly",
			agent:      agentexec.ConnectedAgent{AgentID: "agent", Hostname: "tower"},
			state:      models.StateSnapshot{},
			command:    "ls /tmp/*.txt",
			targetHost: "tower",
			stdout:     "files",
			accepts: func(p agentexec.ExecuteCommandPayload) bool {
				return p.TargetType == "agent" &&
					p.Command == "ls /tmp/*.txt"
			},
		},
	}

	for _, tc := range routingCases {
		t.Run(tc.name, func(t *testing.T) {
			server := &mockAgentServer{}
			server.On("GetConnectedAgents").Return([]agentexec.ConnectedAgent{tc.agent})
			server.On("ExecuteCommand", mock.Anything, tc.agent.AgentID, mock.MatchedBy(tc.accepts)).
				Return(&agentexec.CommandResultPayload{ExitCode: 0, Stdout: tc.stdout}, nil)

			executor := NewPulseToolExecutor(ExecutorConfig{
				StateProvider: &mockStateProvider{state: tc.state},
				AgentServer:   server,
				ControlLevel:  ControlLevelAutonomous,
			})

			result, err := executor.executeRunCommand(ctx, map[string]interface{}{
				"command":     tc.command,
				"target_host": tc.targetHost,
			})
			require.NoError(t, err)

			body := mustParseJSONMap(t, result.Content[0].Text)
			assert.Equal(t, true, body["success"])
			assert.Equal(t, tc.targetHost, body["target_host"])
			server.AssertExpectations(t)
		})
	}
}
func TestPulseToolExecutor_FindAgentForCommand(t *testing.T) {
t.Run("NoAgentServer", func(t *testing.T) {
exec := NewPulseToolExecutor(ExecutorConfig{})
assert.Empty(t, exec.findAgentForCommand(false, ""))
})
t.Run("NoAgents", func(t *testing.T) {
agentSrv := &mockAgentServer{}
exec := NewPulseToolExecutor(ExecutorConfig{AgentServer: agentSrv})
assert.Empty(t, exec.findAgentForCommand(false, ""))
})
t.Run("TargetHostMatches", func(t *testing.T) {
agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
{AgentID: "a1", Hostname: "node1"},
}}
exec := NewPulseToolExecutor(ExecutorConfig{AgentServer: agentSrv})
assert.Equal(t, "a1", exec.findAgentForCommand(false, "a1"))
})
t.Run("MultipleAgentsNoTarget", func(t *testing.T) {
agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
{AgentID: "a1", Hostname: "node1"},
{AgentID: "a2", Hostname: "node2"},
}}
exec := NewPulseToolExecutor(ExecutorConfig{AgentServer: agentSrv})
assert.Empty(t, exec.findAgentForCommand(false, ""))
})
t.Run("SingleAgentNoTarget", func(t *testing.T) {
agentSrv := &mockAgentServer{agents: []agentexec.ConnectedAgent{
{AgentID: "a1", Hostname: "node1"},
}}
exec := NewPulseToolExecutor(ExecutorConfig{AgentServer: agentSrv})
assert.Equal(t, "a1", exec.findAgentForCommand(false, ""))
})
}