mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 17:19:57 +00:00
556 lines
20 KiB
Go
556 lines
20 KiB
Go
package chat
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/ai/providers"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/ai/tools"
|
|
)
|
|
|
|
func TestAgenticLoop_Setters(t *testing.T) {
|
|
loop := &AgenticLoop{}
|
|
loop.SetMaxTurns(7)
|
|
if loop.maxTurns != 7 {
|
|
t.Fatalf("expected maxTurns=7, got %d", loop.maxTurns)
|
|
}
|
|
|
|
loop.SetProviderInfo("provider", "model")
|
|
if loop.providerName != "provider" || loop.modelName != "model" {
|
|
t.Fatalf("expected provider/model to be set")
|
|
}
|
|
|
|
called := false
|
|
loop.SetBudgetChecker(func() error {
|
|
called = true
|
|
return nil
|
|
})
|
|
if loop.budgetChecker == nil {
|
|
t.Fatalf("expected budgetChecker to be set")
|
|
}
|
|
_ = loop.budgetChecker()
|
|
if !called {
|
|
t.Fatalf("expected budgetChecker to be invoked")
|
|
}
|
|
}
|
|
|
|
func TestPruneMessagesForModel_Stateless(t *testing.T) {
|
|
prev := StatelessContext
|
|
StatelessContext = true
|
|
defer func() { StatelessContext = prev }()
|
|
|
|
messages := []Message{
|
|
{Role: "user", Content: "first"},
|
|
{Role: "assistant", Content: "ok"},
|
|
{Role: "user", Content: "second"},
|
|
{Role: "assistant", Content: "done"},
|
|
}
|
|
|
|
pruned := pruneMessagesForModel(messages)
|
|
if len(pruned) != 1 {
|
|
t.Fatalf("expected 1 message, got %d", len(pruned))
|
|
}
|
|
if pruned[0].Content != "second" {
|
|
t.Fatalf("expected last user message to be kept")
|
|
}
|
|
}
|
|
|
|
func TestPruneMessagesForModel_SkipsOrphanedToolResults(t *testing.T) {
|
|
// Build a message list longer than MaxContextMessagesLimit so pruning occurs.
|
|
messages := make([]Message, 0, MaxContextMessagesLimit+2)
|
|
messages = append(messages, Message{Role: "user", Content: "a"})
|
|
messages = append(messages, Message{Role: "assistant", Content: "b"})
|
|
// This tool result should be dropped if it becomes the first item after pruning.
|
|
messages = append(messages, Message{Role: "assistant", ToolResult: &ToolResult{Content: "tool"}})
|
|
for i := 0; i < MaxContextMessagesLimit; i++ {
|
|
messages = append(messages, Message{Role: "user", Content: "msg"})
|
|
}
|
|
|
|
pruned := pruneMessagesForModel(messages)
|
|
if len(pruned) == 0 {
|
|
t.Fatalf("expected pruned messages")
|
|
}
|
|
if pruned[0].ToolResult != nil {
|
|
t.Fatalf("expected leading tool result to be pruned")
|
|
}
|
|
}
|
|
|
|
func TestPruneMessagesForModel_SkipsAssistantWithToolCalls(t *testing.T) {
|
|
// Ensure assistant tool calls at the start of the pruned window are skipped.
|
|
messages := make([]Message, 0, MaxContextMessagesLimit+3)
|
|
messages = append(messages, Message{Role: "user", Content: "seed"})
|
|
messages = append(messages, Message{Role: "assistant", Content: "seed"})
|
|
messages = append(messages, Message{Role: "assistant", ToolCalls: []ToolCall{{Name: "pulse_query"}}})
|
|
messages = append(messages, Message{Role: "assistant", ToolResult: &ToolResult{Content: "result"}})
|
|
for i := 0; i < MaxContextMessagesLimit; i++ {
|
|
messages = append(messages, Message{Role: "user", Content: "msg"})
|
|
}
|
|
|
|
pruned := pruneMessagesForModel(messages)
|
|
if len(pruned) == 0 {
|
|
t.Fatalf("expected pruned messages")
|
|
}
|
|
if pruned[0].Role == "assistant" && len(pruned[0].ToolCalls) > 0 {
|
|
t.Fatalf("expected assistant tool-call message to be pruned")
|
|
}
|
|
if pruned[0].ToolResult != nil {
|
|
t.Fatalf("expected tool result following pruned tool call to be removed")
|
|
}
|
|
}
|
|
|
|
type stubStreamingProvider struct {
|
|
lastRequest providers.ChatRequest
|
|
chatStream func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error
|
|
}
|
|
|
|
func (s *stubStreamingProvider) Chat(ctx context.Context, req providers.ChatRequest) (*providers.ChatResponse, error) {
|
|
return &providers.ChatResponse{Content: "ok"}, nil
|
|
}
|
|
|
|
func (s *stubStreamingProvider) ChatStream(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
s.lastRequest = req
|
|
if s.chatStream != nil {
|
|
return s.chatStream(ctx, req, callback)
|
|
}
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "summary"}})
|
|
callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{InputTokens: 1, OutputTokens: 1}})
|
|
return nil
|
|
}
|
|
|
|
func (s *stubStreamingProvider) SupportsThinking(model string) bool { return false }
|
|
func (s *stubStreamingProvider) TestConnection(ctx context.Context) error { return nil }
|
|
func (s *stubStreamingProvider) Name() string { return "stub" }
|
|
func (s *stubStreamingProvider) ListModels(ctx context.Context) ([]providers.ModelInfo, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func TestEnsureFinalTextResponse(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
loop := &AgenticLoop{provider: provider, baseSystemPrompt: "prompt"}
|
|
|
|
result := loop.ensureFinalTextResponse(
|
|
context.Background(),
|
|
"session-1",
|
|
[]Message{{Role: "assistant", Content: ""}},
|
|
[]providers.Message{{Role: "assistant", Content: ""}},
|
|
func(event StreamEvent) {},
|
|
)
|
|
if len(result) != 2 {
|
|
t.Fatalf("expected summary message to be appended")
|
|
}
|
|
if provider.lastRequest.ToolChoice == nil || provider.lastRequest.ToolChoice.Type != providers.ToolChoiceNone {
|
|
t.Fatalf("expected summary call to enforce text-only tool choice")
|
|
}
|
|
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
return errors.New("boom")
|
|
}
|
|
result2 := loop.ensureFinalTextResponse(
|
|
context.Background(),
|
|
"session-2",
|
|
[]Message{
|
|
{Role: "assistant", Content: ""},
|
|
{Role: "user", ToolResult: &ToolResult{ToolUseID: "pulse_query_0", Content: "{\"nodes\":1}", IsError: false}},
|
|
},
|
|
[]providers.Message{{Role: "assistant", Content: ""}},
|
|
func(event StreamEvent) {},
|
|
)
|
|
if len(result2) != 3 {
|
|
t.Fatalf("expected fallback summary message when provider errors")
|
|
}
|
|
if !strings.Contains(result2[len(result2)-1].Content, "automatic summary") {
|
|
t.Fatalf("expected deterministic fallback summary, got %q", result2[len(result2)-1].Content)
|
|
}
|
|
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{InputTokens: 1, OutputTokens: 1}})
|
|
return nil
|
|
}
|
|
result3 := loop.ensureFinalTextResponse(
|
|
context.Background(),
|
|
"session-3",
|
|
[]Message{
|
|
{Role: "assistant", Content: ""},
|
|
{Role: "user", ToolResult: &ToolResult{ToolUseID: "pulse_metrics_1", Content: "cpu ok", IsError: false}},
|
|
},
|
|
[]providers.Message{{Role: "assistant", Content: ""}},
|
|
func(event StreamEvent) {},
|
|
)
|
|
if len(result3) != 3 {
|
|
t.Fatalf("expected fallback summary when provider returns empty content")
|
|
}
|
|
if !strings.Contains(result3[len(result3)-1].Content, "Latest successful result snippet") {
|
|
t.Fatalf("expected fallback summary to include tool snippet")
|
|
}
|
|
}
|
|
|
|
func TestBuildAutomaticFallbackSummary(t *testing.T) {
|
|
summary := buildAutomaticFallbackSummary([]Message{
|
|
{Role: "user", ToolResult: &ToolResult{ToolUseID: "pulse_query_0", Content: "nodes ok", IsError: false}},
|
|
{Role: "user", ToolResult: &ToolResult{ToolUseID: "pulse_query_1", Content: "containers ok", IsError: false}},
|
|
{Role: "user", ToolResult: &ToolResult{ToolUseID: "pulse_read_0", Content: "read failed", IsError: true}},
|
|
})
|
|
if !strings.Contains(summary, "2 successful check(s)") {
|
|
t.Fatalf("unexpected fallback summary: %q", summary)
|
|
}
|
|
if !strings.Contains(summary, "pulse_query") {
|
|
t.Fatalf("expected tool name in fallback summary")
|
|
}
|
|
}
|
|
|
|
func TestExecuteToolSafely_RecoversPanic(t *testing.T) {
|
|
exec := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
exec.RegisterTool(tools.RegisteredTool{
|
|
Definition: tools.Tool{
|
|
Name: "panic_tool",
|
|
InputSchema: tools.InputSchema{
|
|
Type: "object",
|
|
Properties: map[string]tools.PropertySchema{},
|
|
},
|
|
},
|
|
Handler: func(_ context.Context, _ *tools.PulseToolExecutor, _ map[string]interface{}) (tools.CallToolResult, error) {
|
|
panic("boom")
|
|
},
|
|
})
|
|
|
|
loop := &AgenticLoop{executor: exec}
|
|
result, err := loop.executeToolSafely(context.Background(), "panic_tool", map[string]interface{}{})
|
|
if err == nil {
|
|
t.Fatalf("expected panic recovery error")
|
|
}
|
|
if !result.IsError {
|
|
t.Fatalf("expected error result after panic recovery")
|
|
}
|
|
if len(result.Content) == 0 || !strings.Contains(result.Content[0].Text, "tool panic in panic_tool") {
|
|
t.Fatalf("unexpected panic recovery result: %+v", result)
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_RetriesProviderStreamBeforeEvents(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
|
|
callCount := 0
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callCount++
|
|
if callCount == 1 {
|
|
return errors.New("connection reset by peer")
|
|
}
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "hello"}})
|
|
callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{}})
|
|
return nil
|
|
}
|
|
|
|
results, err := loop.Execute(context.Background(), "retry-before-events", []Message{{Role: "user", Content: "hi"}}, func(event StreamEvent) {})
|
|
if err != nil {
|
|
t.Fatalf("expected retry to recover stream failure, got error: %v", err)
|
|
}
|
|
if callCount != 2 {
|
|
t.Fatalf("expected 2 provider attempts, got %d", callCount)
|
|
}
|
|
if len(results) != 1 || results[0].Content != "hello" {
|
|
t.Fatalf("unexpected results: %+v", results)
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_DoesNotRetryAfterPartialEvents(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
|
|
callCount := 0
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callCount++
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "partial"}})
|
|
return errors.New("connection reset by peer")
|
|
}
|
|
|
|
_, err := loop.Execute(context.Background(), "no-retry-partial", []Message{{Role: "user", Content: "hi"}}, func(event StreamEvent) {})
|
|
if err == nil {
|
|
t.Fatalf("expected provider error when stream fails after partial output")
|
|
}
|
|
if callCount != 1 {
|
|
t.Fatalf("expected no retry after partial output, got %d attempts", callCount)
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_EmitsFallbackErrorEventOnTransportFailure(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "partial"}})
|
|
return errors.New("connection reset by peer")
|
|
}
|
|
|
|
var events []StreamEvent
|
|
_, err := loop.Execute(context.Background(), "emit-fallback-error", []Message{{Role: "user", Content: "hi"}}, func(event StreamEvent) {
|
|
events = append(events, event)
|
|
})
|
|
if err == nil {
|
|
t.Fatalf("expected provider error when stream fails after partial output")
|
|
}
|
|
|
|
var foundError bool
|
|
for _, event := range events {
|
|
if event.Type != "error" {
|
|
continue
|
|
}
|
|
foundError = true
|
|
var data ErrorData
|
|
if decodeErr := json.Unmarshal(event.Data, &data); decodeErr != nil {
|
|
t.Fatalf("failed to decode error event payload: %v", decodeErr)
|
|
}
|
|
if strings.TrimSpace(data.Message) == "" {
|
|
t.Fatalf("expected non-empty fallback error message")
|
|
}
|
|
}
|
|
if !foundError {
|
|
t.Fatalf("expected fallback error event to be emitted")
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_IgnoresErrorAfterDoneEvent(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
|
|
callCount := 0
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callCount++
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "complete"}})
|
|
callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{}})
|
|
return errors.New("EOF")
|
|
}
|
|
|
|
results, err := loop.Execute(context.Background(), "ignore-after-done", []Message{{Role: "user", Content: "hi"}}, func(event StreamEvent) {})
|
|
if err != nil {
|
|
t.Fatalf("expected post-done error to be ignored, got: %v", err)
|
|
}
|
|
if callCount != 1 {
|
|
t.Fatalf("expected single provider attempt, got %d", callCount)
|
|
}
|
|
if len(results) != 1 || results[0].Content != "complete" {
|
|
t.Fatalf("unexpected results: %+v", results)
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_RetriesOnErrorEventBeforeVisibleOutput(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
|
|
callCount := 0
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callCount++
|
|
if callCount == 1 {
|
|
callback(providers.StreamEvent{Type: "error", Data: providers.ErrorEvent{Message: "connection reset by peer"}})
|
|
return nil
|
|
}
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "recovered"}})
|
|
callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{}})
|
|
return nil
|
|
}
|
|
|
|
results, err := loop.Execute(context.Background(), "retry-error-event", []Message{{Role: "user", Content: "hi"}}, func(event StreamEvent) {})
|
|
if err != nil {
|
|
t.Fatalf("expected recovery after transient error event, got: %v", err)
|
|
}
|
|
if callCount != 2 {
|
|
t.Fatalf("expected 2 provider attempts, got %d", callCount)
|
|
}
|
|
if len(results) != 1 || results[0].Content != "recovered" {
|
|
t.Fatalf("unexpected results: %+v", results)
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_DoesNotRetryErrorEventAfterVisibleOutput(t *testing.T) {
|
|
provider := &stubStreamingProvider{}
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
|
|
callCount := 0
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callCount++
|
|
callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "partial"}})
|
|
callback(providers.StreamEvent{Type: "error", Data: providers.ErrorEvent{Message: "connection reset by peer"}})
|
|
return nil
|
|
}
|
|
|
|
_, err := loop.Execute(context.Background(), "no-retry-error-after-content", []Message{{Role: "user", Content: "hi"}}, func(event StreamEvent) {})
|
|
if err == nil {
|
|
t.Fatalf("expected error when stream emits error after visible output")
|
|
}
|
|
if callCount != 1 {
|
|
t.Fatalf("expected no retry after visible output, got %d attempts", callCount)
|
|
}
|
|
}
|
|
|
|
func TestTryAutoRecoveryAndCommandExtraction(t *testing.T) {
|
|
result := tools.CallToolResult{
|
|
Content: []tools.Content{{Type: "text", Text: `{"error":{"code":"READ_ONLY_VIOLATION","details":{"auto_recoverable":true,"suggested_rewrite":"uptime","category":"strict"}}}`}},
|
|
IsError: true,
|
|
}
|
|
plan, attempted := tryAutoRecovery(result, providers.ToolCall{
|
|
Name: "pulse_read",
|
|
Input: map[string]interface{}{"action": "exec"},
|
|
}, nil, context.Background())
|
|
if attempted || plan == nil {
|
|
t.Fatalf("expected recovery plan and no prior attempt")
|
|
}
|
|
if plan.ToolName != "pulse_read" || plan.Input["command"] != "uptime" {
|
|
t.Fatalf("unexpected legacy recovery plan: %+v", plan)
|
|
}
|
|
if plan.ErrorCode != "READ_ONLY_VIOLATION" {
|
|
t.Fatalf("expected READ_ONLY_VIOLATION error code, got %+v", plan)
|
|
}
|
|
|
|
structured := tools.CallToolResult{
|
|
Content: []tools.Content{{Type: "text", Text: `{"error":{"code":"ACTION_NOT_ALLOWED","details":{"auto_recoverable":true,"suggested_tool":"pulse_query","suggested_arguments":{"action":"get","resource_type":"vm","resource_id":"app-01"}}}}`}},
|
|
IsError: true,
|
|
}
|
|
plan, _ = tryAutoRecovery(structured, providers.ToolCall{Input: map[string]interface{}{}}, nil, context.Background())
|
|
if plan == nil {
|
|
t.Fatalf("expected structured recovery plan")
|
|
}
|
|
if plan.ToolName != "pulse_query" || plan.Input["action"] != "get" || plan.Input["resource_id"] != "app-01" {
|
|
t.Fatalf("unexpected structured recovery plan: %+v", plan)
|
|
}
|
|
if plan.ErrorCode != "ACTION_NOT_ALLOWED" {
|
|
t.Fatalf("expected ACTION_NOT_ALLOWED error code, got %+v", plan)
|
|
}
|
|
|
|
retryTool := tools.CallToolResult{
|
|
Content: []tools.Content{{Type: "text", Text: `{"error":5,"auto_recoverable":true,"suggested_tool":"pulse_query","suggested_arguments":{"action":"list"}}`}},
|
|
IsError: true,
|
|
}
|
|
plan, _ = tryAutoRecovery(retryTool, providers.ToolCall{Input: map[string]interface{}{}}, nil, context.Background())
|
|
if plan == nil || plan.ToolName != "pulse_query" || plan.Input["action"] != "list" {
|
|
t.Fatalf("expected alternate-format structured recovery plan, got %+v", plan)
|
|
}
|
|
|
|
plan, attempted = tryAutoRecovery(result, providers.ToolCall{Input: map[string]interface{}{"_auto_recovery_attempt": true}}, nil, context.Background())
|
|
if !attempted || plan != nil {
|
|
t.Fatalf("expected auto recovery to be skipped when already attempted")
|
|
}
|
|
|
|
if cmd := getCommandFromInput(map[string]interface{}{"command": "ls"}); cmd != "ls" {
|
|
t.Fatalf("expected command to be extracted")
|
|
}
|
|
if cmd := getCommandFromInput(map[string]interface{}{}); cmd != "<unknown>" {
|
|
t.Fatalf("expected fallback command string")
|
|
}
|
|
}
|
|
|
|
func TestAgenticLoop_AutoRecoveryExecutesStructuredToolCall(t *testing.T) {
|
|
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
|
failCalls := 0
|
|
recoveryCalls := 0
|
|
|
|
executor.RegisterTool(tools.RegisteredTool{
|
|
Definition: tools.Tool{
|
|
Name: "fail_tool",
|
|
InputSchema: tools.InputSchema{
|
|
Type: "object",
|
|
Properties: map[string]tools.PropertySchema{},
|
|
},
|
|
},
|
|
Handler: func(ctx context.Context, exec *tools.PulseToolExecutor, args map[string]interface{}) (tools.CallToolResult, error) {
|
|
failCalls++
|
|
return tools.NewToolResponseResult(tools.NewToolBlockedError(
|
|
tools.ErrCodeActionNotAllowed,
|
|
"blocked",
|
|
map[string]interface{}{
|
|
"auto_recoverable": true,
|
|
"suggested_tool": "recovery_tool",
|
|
"suggested_arguments": map[string]interface{}{
|
|
"value": "recovered through query",
|
|
},
|
|
},
|
|
)), nil
|
|
},
|
|
})
|
|
executor.RegisterTool(tools.RegisteredTool{
|
|
Definition: tools.Tool{
|
|
Name: "recovery_tool",
|
|
InputSchema: tools.InputSchema{
|
|
Type: "object",
|
|
Properties: map[string]tools.PropertySchema{
|
|
"value": {Type: "string"},
|
|
},
|
|
},
|
|
},
|
|
Handler: func(ctx context.Context, exec *tools.PulseToolExecutor, args map[string]interface{}) (tools.CallToolResult, error) {
|
|
recoveryCalls++
|
|
value, _ := args["value"].(string)
|
|
return tools.NewTextResult(value), nil
|
|
},
|
|
})
|
|
|
|
provider := &stubStreamingProvider{}
|
|
loop := NewAgenticLoop(provider, executor, "prompt")
|
|
callCount := 0
|
|
provider.chatStream = func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
|
|
callCount++
|
|
switch callCount {
|
|
case 1:
|
|
callback(providers.StreamEvent{
|
|
Type: "tool_start",
|
|
Data: providers.ToolStartEvent{ID: "call_1", Name: "fail_tool"},
|
|
})
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{
|
|
ToolCalls: []providers.ToolCall{{
|
|
ID: "call_1",
|
|
Name: "fail_tool",
|
|
Input: map[string]interface{}{},
|
|
}},
|
|
},
|
|
})
|
|
case 2:
|
|
if len(req.Messages) != 3 {
|
|
t.Fatalf("expected assistant tool call plus tool result, got %d messages", len(req.Messages))
|
|
}
|
|
if req.Messages[2].ToolResult == nil || req.Messages[2].ToolResult.IsError {
|
|
t.Fatalf("expected recovered tool result, got %+v", req.Messages[2].ToolResult)
|
|
}
|
|
if !strings.Contains(req.Messages[2].ToolResult.Content, "recovered through query") {
|
|
t.Fatalf("expected recovered output in provider follow-up, got %+v", req.Messages[2].ToolResult)
|
|
}
|
|
callback(providers.StreamEvent{
|
|
Type: "content",
|
|
Data: providers.ContentEvent{Text: "Recovered."},
|
|
})
|
|
callback(providers.StreamEvent{
|
|
Type: "done",
|
|
Data: providers.DoneEvent{},
|
|
})
|
|
default:
|
|
t.Fatalf("unexpected provider call %d", callCount)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
results, err := loop.Execute(context.Background(), "structured-auto-recovery", []Message{{Role: "user", Content: "help"}}, func(event StreamEvent) {})
|
|
if err != nil {
|
|
t.Fatalf("expected successful auto-recovery, got %v", err)
|
|
}
|
|
if failCalls != 1 || recoveryCalls != 1 {
|
|
t.Fatalf("expected one fail call and one recovery call, got fail=%d recovery=%d", failCalls, recoveryCalls)
|
|
}
|
|
if len(results) != 3 {
|
|
t.Fatalf("expected assistant tool call, recovered tool result, and final response, got %+v", results)
|
|
}
|
|
if results[1].ToolResult == nil || !strings.Contains(results[1].ToolResult.Content, "recovered through query") {
|
|
t.Fatalf("expected recovered tool result in transcript, got %+v", results[1])
|
|
}
|
|
if results[2].Content != "Recovered." {
|
|
t.Fatalf("unexpected final response: %+v", results[2])
|
|
}
|
|
}
|