mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-13 15:28:38 +00:00
491 lines
15 KiB
Go
491 lines
15 KiB
Go
package chat
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/ai/providers"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/ai/tools"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/recovery"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/unifiedresources"
|
|
)
|
|
|
|
type stubServiceProvider struct {
|
|
streamFn func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error
|
|
}
|
|
|
|
func (s *stubServiceProvider) Chat(ctx context.Context, req providers.ChatRequest) (*providers.ChatResponse, error) {
|
|
return &providers.ChatResponse{Content: "ok", Model: req.Model}, nil
|
|
}
|
|
|
|
func (s *stubServiceProvider) TestConnection(ctx context.Context) error { return nil }
|
|
|
|
func (s *stubServiceProvider) Name() string { return "stub" }
|
|
|
|
func (s *stubServiceProvider) ListModels(ctx context.Context) ([]providers.ModelInfo, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *stubServiceProvider) ChatStream(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
if s.streamFn != nil {
|
|
return s.streamFn(ctx, req, callback)
|
|
}
|
|
callback(providers.StreamEvent{
|
|
Type: "content",
|
|
Data: providers.ContentEvent{Text: "hello"},
|
|
})
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{InputTokens: 1, OutputTokens: 1},
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (s *stubServiceProvider) SupportsThinking(model string) bool { return false }
|
|
|
|
func TestService_ExecuteStream_Success(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
store, err := NewSessionStore(tmpDir)
|
|
if err != nil {
|
|
t.Fatalf("failed to create session store: %v", err)
|
|
}
|
|
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
provider := &stubServiceProvider{}
|
|
loop := NewAgenticLoop(provider, executor, "system")
|
|
|
|
svc := &Service{
|
|
cfg: &config.AIConfig{ChatModel: "openai:test"},
|
|
sessions: store,
|
|
executor: executor,
|
|
agenticLoop: loop,
|
|
provider: provider, // Required: per-request loops need a provider to create new instances
|
|
started: true,
|
|
}
|
|
|
|
var doneCount int
|
|
callback := func(event StreamEvent) {
|
|
if event.Type == "done" {
|
|
doneCount++
|
|
}
|
|
}
|
|
|
|
req := ExecuteRequest{
|
|
SessionID: "sess-1",
|
|
Prompt: "hello",
|
|
}
|
|
if err := svc.ExecuteStream(context.Background(), req, callback); err != nil {
|
|
t.Fatalf("ExecuteStream failed: %v", err)
|
|
}
|
|
if doneCount != 1 {
|
|
t.Fatalf("expected done callback once, got %d", doneCount)
|
|
}
|
|
|
|
messages, err := store.GetMessages("sess-1")
|
|
if err != nil {
|
|
t.Fatalf("GetMessages failed: %v", err)
|
|
}
|
|
if len(messages) < 2 {
|
|
t.Fatalf("expected at least 2 messages, got %d", len(messages))
|
|
}
|
|
}
|
|
|
|
func TestService_ExecuteStream_PrefetchMentionsAndOverrideModel(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
store, err := NewSessionStore(tmpDir)
|
|
if err != nil {
|
|
t.Fatalf("failed to create session store: %v", err)
|
|
}
|
|
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
provider := &stubServiceProvider{}
|
|
loop := NewAgenticLoop(provider, executor, "system")
|
|
|
|
state := models.StateSnapshot{
|
|
VMs: []models.VM{
|
|
{ID: "vm:node1:101", VMID: 101, Name: "vm1", Node: "node1", Status: "running"},
|
|
},
|
|
}
|
|
readState := unifiedresources.NewRegistry(nil)
|
|
readState.IngestSnapshot(state)
|
|
vmResources := readState.ListByType(unifiedresources.ResourceTypeVM)
|
|
if len(vmResources) != 1 {
|
|
t.Fatalf("expected one canonical VM resource, got %d", len(vmResources))
|
|
}
|
|
vmResourceID := vmResources[0].ID
|
|
|
|
var capturedModel string
|
|
svc := &Service{
|
|
cfg: &config.AIConfig{ChatModel: "openai:primary"},
|
|
sessions: store,
|
|
executor: executor,
|
|
agenticLoop: loop,
|
|
contextPrefetcher: NewContextPrefetcher(readState, nil),
|
|
provider: provider,
|
|
started: true,
|
|
providerFactory: func(modelStr string) (providers.StreamingProvider, error) {
|
|
capturedModel = modelStr
|
|
return provider, nil
|
|
},
|
|
}
|
|
|
|
autonomous := true
|
|
req := ExecuteRequest{
|
|
SessionID: "sess-2",
|
|
Prompt: "check @vm1",
|
|
Model: "openai:override",
|
|
Mentions: []StructuredMention{{ID: "vm:node1:101", Name: "vm1", Type: "vm", Node: "node1"}},
|
|
MaxTurns: 1,
|
|
AutonomousMode: &autonomous,
|
|
}
|
|
|
|
if err := svc.ExecuteStream(context.Background(), req, func(StreamEvent) {}); err != nil {
|
|
t.Fatalf("ExecuteStream failed: %v", err)
|
|
}
|
|
if capturedModel != "openai:override" {
|
|
t.Fatalf("expected override model to be used, got %q", capturedModel)
|
|
}
|
|
|
|
resolved := store.GetResolvedContext("sess-2")
|
|
if !resolved.WasRecentlyAccessed(vmResourceID, time.Minute) {
|
|
t.Fatal("expected explicit access to be recorded for structured mention")
|
|
}
|
|
}
|
|
|
|
func TestService_ExecuteStream_PrefetchMentionsMarksVMwareUnifiedResourceAccess(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
store, err := NewSessionStore(tmpDir)
|
|
if err != nil {
|
|
t.Fatalf("failed to create session store: %v", err)
|
|
}
|
|
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
provider := &stubServiceProvider{}
|
|
loop := NewAgenticLoop(provider, executor, "system")
|
|
|
|
now := time.Now().UTC()
|
|
rr := unifiedresources.NewRegistry(nil)
|
|
rr.IngestRecords(unifiedresources.SourceVMware, []unifiedresources.IngestRecord{{
|
|
SourceID: "vc-1:vm:vm-201",
|
|
Resource: unifiedresources.Resource{
|
|
ID: "vmware-vm-1",
|
|
Type: unifiedresources.ResourceTypeVM,
|
|
Name: "app-01",
|
|
Status: unifiedresources.StatusOnline,
|
|
LastSeen: now,
|
|
UpdatedAt: now,
|
|
ParentName: "esxi-01.lab.local",
|
|
VMware: &unifiedresources.VMwareData{
|
|
ConnectionID: "vc-1",
|
|
ConnectionName: "Lab VC",
|
|
ManagedObjectID: "vm-201",
|
|
EntityType: "vm",
|
|
RuntimeHostName: "esxi-01.lab.local",
|
|
},
|
|
},
|
|
Identity: unifiedresources.ResourceIdentity{Hostnames: []string{"app-01"}},
|
|
}})
|
|
vmResources := rr.ListByType(unifiedresources.ResourceTypeVM)
|
|
if len(vmResources) != 1 {
|
|
t.Fatalf("expected one VMware VM resource, got %d", len(vmResources))
|
|
}
|
|
vmResourceID := vmResources[0].ID
|
|
|
|
svc := &Service{
|
|
cfg: &config.AIConfig{ChatModel: "openai:primary"},
|
|
sessions: store,
|
|
executor: executor,
|
|
agenticLoop: loop,
|
|
contextPrefetcher: NewContextPrefetcher(rr, nil),
|
|
provider: provider,
|
|
started: true,
|
|
providerFactory: func(modelStr string) (providers.StreamingProvider, error) {
|
|
return provider, nil
|
|
},
|
|
}
|
|
|
|
autonomous := true
|
|
req := ExecuteRequest{
|
|
SessionID: "sess-vmware-mentions",
|
|
Prompt: "check @app-01",
|
|
Mentions: []StructuredMention{{ID: vmResourceID, Name: "app-01", Type: "vm"}},
|
|
MaxTurns: 1,
|
|
AutonomousMode: &autonomous,
|
|
}
|
|
|
|
if err := svc.ExecuteStream(context.Background(), req, func(StreamEvent) {}); err != nil {
|
|
t.Fatalf("ExecuteStream failed: %v", err)
|
|
}
|
|
|
|
resolved := store.GetResolvedContext("sess-vmware-mentions")
|
|
if !resolved.WasRecentlyAccessed(vmResourceID, time.Minute) {
|
|
t.Fatal("expected explicit access to be recorded for VMware structured mention")
|
|
}
|
|
}
|
|
|
|
func TestService_ExecuteStream_AgenticPulseStorageSnapshotsToleratesMalformedRecoveryMetadata(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
store, err := NewSessionStore(tmpDir)
|
|
if err != nil {
|
|
t.Fatalf("failed to create session store: %v", err)
|
|
}
|
|
|
|
completedAt := time.Date(2026, 2, 24, 10, 30, 0, 0, time.UTC)
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{
|
|
StateProvider: fakeStateProvider{},
|
|
ReadState: &fakeCanonicalReadState{},
|
|
RecoveryPointsProvider: &fakeRecoveryPointsProvider{points: []recovery.RecoveryPoint{{
|
|
ID: "pve-snapshot:snap-100-before-upgrade",
|
|
Provider: recovery.ProviderProxmoxPVE,
|
|
Kind: recovery.KindSnapshot,
|
|
Mode: recovery.ModeLocal,
|
|
Outcome: recovery.OutcomeSuccess,
|
|
SubjectRef: &recovery.ExternalRef{
|
|
Type: "vm",
|
|
Name: "100",
|
|
ID: "100",
|
|
Namespace: "pve1",
|
|
Class: "node1",
|
|
},
|
|
Display: &recovery.RecoveryPointDisplay{
|
|
SubjectLabel: "100",
|
|
ItemType: "vm",
|
|
ClusterLabel: "pve1",
|
|
NodeHostLabel: "node1",
|
|
EntityIDLabel: "100",
|
|
DetailsSummary: "before-upgrade",
|
|
},
|
|
CompletedAt: &completedAt,
|
|
}}},
|
|
})
|
|
|
|
turn := 0
|
|
provider := &stubServiceProvider{
|
|
streamFn: func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
turn++
|
|
switch turn {
|
|
case 1:
|
|
if !hasProviderTool(req.Tools, "pulse_storage") {
|
|
t.Fatalf("expected pulse_storage tool to be available, got %#v", req.Tools)
|
|
}
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{
|
|
StopReason: "tool_use",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call-snapshots",
|
|
Name: "pulse_storage",
|
|
Input: map[string]interface{}{
|
|
"type": "snapshots",
|
|
"guest_id": "100",
|
|
"instance": "pve1",
|
|
},
|
|
}},
|
|
},
|
|
})
|
|
return nil
|
|
case 2:
|
|
var toolResult string
|
|
for _, msg := range req.Messages {
|
|
if msg.ToolResult != nil && msg.ToolResult.ToolUseID == "call-snapshots" {
|
|
toolResult = msg.ToolResult.Content
|
|
break
|
|
}
|
|
}
|
|
if toolResult == "" {
|
|
t.Fatalf("expected pulse_storage tool result in second turn, got %#v", req.Messages)
|
|
}
|
|
if !strings.Contains(toolResult, "\"snapshot_name\":\"before-upgrade\"") {
|
|
t.Fatalf("expected canonical snapshot name in tool result, got %s", toolResult)
|
|
}
|
|
if !strings.Contains(toolResult, "\"instance\":\"pve1\"") || !strings.Contains(toolResult, "\"node\":\"node1\"") {
|
|
t.Fatalf("expected canonical placement labels in tool result, got %s", toolResult)
|
|
}
|
|
callback(providers.StreamEvent{
|
|
Type: "content",
|
|
Data: providers.ContentEvent{Text: "Recovered snapshot inventory"},
|
|
})
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{
|
|
InputTokens: 10,
|
|
OutputTokens: 12,
|
|
},
|
|
})
|
|
return nil
|
|
default:
|
|
t.Fatalf("unexpected extra provider turn %d", turn)
|
|
return nil
|
|
}
|
|
},
|
|
}
|
|
|
|
svc := &Service{
|
|
cfg: &config.AIConfig{ChatModel: "openai:test"},
|
|
sessions: store,
|
|
executor: executor,
|
|
agenticLoop: NewAgenticLoop(provider, executor, "system"),
|
|
provider: provider,
|
|
started: true,
|
|
}
|
|
|
|
req := ExecuteRequest{
|
|
SessionID: "sess-storage-snapshots",
|
|
Prompt: "List recovery snapshots for guest 100 on pve1.",
|
|
}
|
|
if err := svc.ExecuteStream(context.Background(), req, func(StreamEvent) {}); err != nil {
|
|
t.Fatalf("ExecuteStream failed: %v", err)
|
|
}
|
|
|
|
messages, err := store.GetMessages(req.SessionID)
|
|
if err != nil {
|
|
t.Fatalf("GetMessages failed: %v", err)
|
|
}
|
|
foundToolResult := false
|
|
for _, msg := range messages {
|
|
if msg.ToolResult != nil && strings.Contains(msg.ToolResult.Content, "\"snapshot_name\":\"before-upgrade\"") {
|
|
foundToolResult = true
|
|
break
|
|
}
|
|
}
|
|
if !foundToolResult {
|
|
t.Fatalf("expected stored tool result with canonical snapshot fallback, got %#v", messages)
|
|
}
|
|
}
|
|
|
|
func TestService_ExecuteStream_AgenticPulseStorageBackupTasksToleratesMalformedRecoveryMetadata(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
store, err := NewSessionStore(tmpDir)
|
|
if err != nil {
|
|
t.Fatalf("failed to create session store: %v", err)
|
|
}
|
|
|
|
startedAt := time.Date(2026, 2, 24, 11, 0, 0, 0, time.UTC)
|
|
completedAt := time.Date(2026, 2, 24, 11, 15, 0, 0, time.UTC)
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{
|
|
StateProvider: fakeStateProvider{},
|
|
ReadState: &fakeCanonicalReadState{},
|
|
RecoveryPointsProvider: &fakeRecoveryPointsProvider{points: []recovery.RecoveryPoint{{
|
|
ID: "pve-task:task-101-backup",
|
|
Provider: recovery.ProviderProxmoxPVE,
|
|
Kind: recovery.KindBackup,
|
|
Mode: recovery.ModeLocal,
|
|
Outcome: recovery.OutcomeSuccess,
|
|
SubjectRef: &recovery.ExternalRef{
|
|
Type: "vm",
|
|
Name: "101",
|
|
ID: "101",
|
|
Namespace: "pve1",
|
|
Class: "node1",
|
|
},
|
|
Display: &recovery.RecoveryPointDisplay{
|
|
SubjectLabel: "101",
|
|
ItemType: "vm",
|
|
ClusterLabel: "pve1",
|
|
NodeHostLabel: "node1",
|
|
EntityIDLabel: "101",
|
|
DetailsSummary: "completed successfully",
|
|
},
|
|
StartedAt: &startedAt,
|
|
CompletedAt: &completedAt,
|
|
}}},
|
|
})
|
|
|
|
turn := 0
|
|
provider := &stubServiceProvider{
|
|
streamFn: func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
turn++
|
|
switch turn {
|
|
case 1:
|
|
if !hasProviderTool(req.Tools, "pulse_storage") {
|
|
t.Fatalf("expected pulse_storage tool to be available, got %#v", req.Tools)
|
|
}
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{
|
|
StopReason: "tool_use",
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call-backup-tasks",
|
|
Name: "pulse_storage",
|
|
Input: map[string]interface{}{
|
|
"type": "backup_tasks",
|
|
"guest_id": "101",
|
|
"instance": "pve1",
|
|
"status": "OK",
|
|
},
|
|
}},
|
|
},
|
|
})
|
|
return nil
|
|
case 2:
|
|
var toolResult string
|
|
for _, msg := range req.Messages {
|
|
if msg.ToolResult != nil && msg.ToolResult.ToolUseID == "call-backup-tasks" {
|
|
toolResult = msg.ToolResult.Content
|
|
break
|
|
}
|
|
}
|
|
if toolResult == "" {
|
|
t.Fatalf("expected pulse_storage tool result in second turn, got %#v", req.Messages)
|
|
}
|
|
if !strings.Contains(toolResult, "\"status\":\"OK\"") || !strings.Contains(toolResult, "\"type\":\"vm\"") {
|
|
t.Fatalf("expected canonical backup task fields in tool result, got %s", toolResult)
|
|
}
|
|
if !strings.Contains(toolResult, "\"instance\":\"pve1\"") || !strings.Contains(toolResult, "\"node\":\"node1\"") {
|
|
t.Fatalf("expected canonical placement labels in tool result, got %s", toolResult)
|
|
}
|
|
callback(providers.StreamEvent{
|
|
Type: "content",
|
|
Data: providers.ContentEvent{Text: "Recovered backup task inventory"},
|
|
})
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{
|
|
InputTokens: 10,
|
|
OutputTokens: 12,
|
|
},
|
|
})
|
|
return nil
|
|
default:
|
|
t.Fatalf("unexpected extra provider turn %d", turn)
|
|
return nil
|
|
}
|
|
},
|
|
}
|
|
|
|
svc := &Service{
|
|
cfg: &config.AIConfig{ChatModel: "openai:test"},
|
|
sessions: store,
|
|
executor: executor,
|
|
agenticLoop: NewAgenticLoop(provider, executor, "system"),
|
|
provider: provider,
|
|
started: true,
|
|
}
|
|
|
|
req := ExecuteRequest{
|
|
SessionID: "sess-storage-backup-tasks",
|
|
Prompt: "List recovery backup tasks for guest 101 on pve1.",
|
|
}
|
|
if err := svc.ExecuteStream(context.Background(), req, func(StreamEvent) {}); err != nil {
|
|
t.Fatalf("ExecuteStream failed: %v", err)
|
|
}
|
|
|
|
messages, err := store.GetMessages(req.SessionID)
|
|
if err != nil {
|
|
t.Fatalf("GetMessages failed: %v", err)
|
|
}
|
|
foundToolResult := false
|
|
for _, msg := range messages {
|
|
if msg.ToolResult != nil && strings.Contains(msg.ToolResult.Content, "\"status\":\"OK\"") {
|
|
foundToolResult = true
|
|
break
|
|
}
|
|
}
|
|
if !foundToolResult {
|
|
t.Fatalf("expected stored tool result with canonical backup-task fallback, got %#v", messages)
|
|
}
|
|
}
|