mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 03:20:11 +00:00
fix(ai): discovery transient error handling, agentic loop detection, and read-only classification
- Discovery: classify transient errors (429, timeout, connection refused, etc.) and return IsError:true so models stop retrying rate-limited calls - Agentic loop: detect identical tool calls repeated >3 times and block with LOOP_DETECTED error, forcing the model to try a different approach - OpenAI provider: skip tool_choice for DeepSeek Reasoner which doesn't support it - Read-only classifier: fix curl -I case sensitivity (uppercase flags lowered), add iostat/vmstat/mpstat/sar/lxc-ls/lxc-info/nc -z to allowlist, fix 2>&1 false positive in input redirect detection
This commit is contained in:
parent
f0a356c016
commit
e85ec858fd
6 changed files with 535 additions and 30 deletions
|
|
@ -39,6 +39,9 @@ type AgenticLoop struct {
|
|||
|
||||
// Per-session FSMs for workflow enforcement (set before each execution)
|
||||
sessionFSM *SessionFSM
|
||||
|
||||
// Budget checker called after each turn to enforce token spending limits
|
||||
budgetChecker func() error
|
||||
}
|
||||
|
||||
// NewAgenticLoop creates a new agentic loop
|
||||
|
|
@ -98,6 +101,13 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
toolsSucceededThisEpisode := false // Track if any tool executed successfully this episode
|
||||
preferredToolName := ""
|
||||
preferredToolRetried := false
|
||||
singleToolRequested := isSingleToolRequest(providerMessages)
|
||||
singleToolEnforced := false
|
||||
|
||||
// Loop detection: track identical tool calls (name + serialized input).
|
||||
// After maxIdenticalCalls identical invocations, the next one is blocked.
|
||||
const maxIdenticalCalls = 3
|
||||
recentCallCounts := make(map[string]int)
|
||||
|
||||
for turn < a.maxTurns {
|
||||
// Check if aborted
|
||||
|
|
@ -149,6 +159,9 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
}
|
||||
if preferredToolName != "" {
|
||||
req.ToolChoice = &providers.ToolChoice{Type: providers.ToolChoiceTool, Name: preferredToolName}
|
||||
if singleToolRequested {
|
||||
singleToolEnforced = true
|
||||
}
|
||||
log.Debug().
|
||||
Str("session_id", sessionID).
|
||||
Str("tool", preferredToolName).
|
||||
|
|
@ -268,6 +281,15 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
return resultMessages, fmt.Errorf("provider error: %w", err)
|
||||
}
|
||||
|
||||
// Check mid-run budget after each turn completes
|
||||
if a.budgetChecker != nil {
|
||||
if budgetErr := a.budgetChecker(); budgetErr != nil {
|
||||
log.Warn().Err(budgetErr).Int("turn", turn).Str("session_id", sessionID).
|
||||
Msg("[AgenticLoop] Budget exceeded mid-run, stopping")
|
||||
return resultMessages, fmt.Errorf("budget exceeded: %w", budgetErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Create assistant message
|
||||
assistantMsg := Message{
|
||||
ID: uuid.New().String(),
|
||||
|
|
@ -419,6 +441,7 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
// Clear preferred tool once the model has used any tool.
|
||||
preferredToolName = ""
|
||||
}
|
||||
firstToolResultText := ""
|
||||
for _, tc := range toolCalls {
|
||||
// Check for abort
|
||||
a.mu.Lock()
|
||||
|
|
@ -501,6 +524,49 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
}
|
||||
}
|
||||
|
||||
// === LOOP DETECTION: Block identical repeated tool calls ===
|
||||
callKey := toolCallKey(tc.Name, tc.Input)
|
||||
recentCallCounts[callKey]++
|
||||
if recentCallCounts[callKey] > maxIdenticalCalls {
|
||||
log.Warn().
|
||||
Str("tool", tc.Name).
|
||||
Int("count", recentCallCounts[callKey]).
|
||||
Str("session_id", sessionID).
|
||||
Msg("[AgenticLoop] LOOP_DETECTED: blocking repeated identical tool call")
|
||||
|
||||
loopMsg := fmt.Sprintf("LOOP_DETECTED: You have called %s with the same arguments %d times. This call is blocked. Try a different tool or approach.", tc.Name, recentCallCounts[callKey])
|
||||
|
||||
jsonData, _ := json.Marshal(ToolEndData{
|
||||
ID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Input: "",
|
||||
Output: loopMsg,
|
||||
Success: false,
|
||||
})
|
||||
callback(StreamEvent{Type: "tool_end", Data: jsonData})
|
||||
|
||||
toolResultMsg := Message{
|
||||
ID: uuid.New().String(),
|
||||
Role: "user",
|
||||
Timestamp: time.Now(),
|
||||
ToolResult: &ToolResult{
|
||||
ToolUseID: tc.ID,
|
||||
Content: loopMsg,
|
||||
IsError: true,
|
||||
},
|
||||
}
|
||||
resultMessages = append(resultMessages, toolResultMsg)
|
||||
providerMessages = append(providerMessages, providers.Message{
|
||||
Role: "user",
|
||||
ToolResult: &providers.ToolResult{
|
||||
ToolUseID: tc.ID,
|
||||
Content: loopMsg,
|
||||
IsError: true,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
result, err := a.executor.ExecuteTool(ctx, tc.Name, tc.Input)
|
||||
|
||||
|
|
@ -521,6 +587,10 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
}
|
||||
}
|
||||
|
||||
if firstToolResultText == "" {
|
||||
firstToolResultText = resultText
|
||||
}
|
||||
|
||||
// Track pending recovery for strict resolution blocks
|
||||
// (FSM blocks are tracked above; strict resolution blocks come from the executor)
|
||||
if isError && fsm != nil && strings.Contains(resultText, "STRICT_RESOLUTION") {
|
||||
|
|
@ -770,6 +840,26 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
|
|||
})
|
||||
}
|
||||
|
||||
if singleToolEnforced && len(toolCalls) > 0 {
|
||||
summary := firstToolResultText
|
||||
if strings.TrimSpace(summary) == "" {
|
||||
if preferredToolName != "" {
|
||||
summary = fmt.Sprintf("Tool %s completed.", preferredToolName)
|
||||
} else {
|
||||
summary = "Tool call completed."
|
||||
}
|
||||
}
|
||||
if len(resultMessages) > 0 {
|
||||
lastIdx := len(resultMessages) - 1
|
||||
if resultMessages[lastIdx].Role == "assistant" && strings.TrimSpace(resultMessages[lastIdx].Content) == "" {
|
||||
resultMessages[lastIdx].Content = summary
|
||||
}
|
||||
}
|
||||
jsonData, _ := json.Marshal(ContentData{Text: summary})
|
||||
callback(StreamEvent{Type: "content", Data: jsonData})
|
||||
return resultMessages, nil
|
||||
}
|
||||
|
||||
turn++
|
||||
}
|
||||
|
||||
|
|
@ -800,6 +890,13 @@ func (a *AgenticLoop) SetSessionFSM(fsm *SessionFSM) {
|
|||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetMaxTurns overrides the maximum number of agentic turns for this loop.
|
||||
func (a *AgenticLoop) SetMaxTurns(n int) {
|
||||
a.mu.Lock()
|
||||
a.maxTurns = n
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetProviderInfo sets the provider/model info for telemetry.
|
||||
func (a *AgenticLoop) SetProviderInfo(provider, model string) {
|
||||
a.mu.Lock()
|
||||
|
|
@ -808,6 +905,12 @@ func (a *AgenticLoop) SetProviderInfo(provider, model string) {
|
|||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetBudgetChecker sets a function called after each agentic turn to enforce
|
||||
// token spending limits. If the checker returns an error, the loop stops.
|
||||
func (a *AgenticLoop) SetBudgetChecker(fn func() error) {
|
||||
a.budgetChecker = fn
|
||||
}
|
||||
|
||||
// GetTotalInputTokens returns the accumulated input tokens across all turns.
|
||||
func (a *AgenticLoop) GetTotalInputTokens() int {
|
||||
return a.totalInputTokens
|
||||
|
|
@ -973,6 +1076,16 @@ func tryAutoRecovery(result tools.CallToolResult, tc providers.ToolCall, executo
|
|||
return "", false
|
||||
}
|
||||
|
||||
// toolCallKey builds the identity string used by loop detection: the tool
// name, a colon, and the JSON encoding of the input. If the input cannot be
// encoded, the bare tool name is returned instead.
func toolCallKey(name string, input map[string]interface{}) string {
	key := name
	if encoded, err := json.Marshal(input); err == nil {
		key += ":" + string(encoded)
	}
	return key
}
|
||||
|
||||
// getCommandFromInput extracts the command from tool input for logging.
|
||||
func getCommandFromInput(input map[string]interface{}) string {
|
||||
if cmd, ok := input["command"].(string); ok {
|
||||
|
|
@ -1165,6 +1278,45 @@ func getPreferredTool(messages []providers.Message, tools []providers.Tool) stri
|
|||
return ""
|
||||
}
|
||||
|
||||
// isSingleToolRequest detects user instructions to use exactly one tool call.
|
||||
func isSingleToolRequest(messages []providers.Message) bool {
|
||||
var lastUserContent string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" && messages[i].ToolResult == nil {
|
||||
lastUserContent = strings.ToLower(messages[i].Content)
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastUserContent == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
patterns := []string{
|
||||
"only that tool once",
|
||||
"only this tool once",
|
||||
"call only that tool once",
|
||||
"call only this tool once",
|
||||
"call only that tool",
|
||||
"call only this tool",
|
||||
"call only one tool",
|
||||
"only one tool",
|
||||
"single tool",
|
||||
"use only that tool",
|
||||
"use only this tool",
|
||||
"do not call any other tools",
|
||||
"don't call any other tools",
|
||||
"no other tools",
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(lastUserContent, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getSystemPrompt builds the full system prompt including the current mode context.
|
||||
// This is called at request time so the prompt reflects the current mode.
|
||||
func (a *AgenticLoop) getSystemPrompt() string {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ package chat
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -375,6 +377,165 @@ func TestHasPhantomExecution(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestToolCallKey(t *testing.T) {
|
||||
t.Run("same name and input produce same key", func(t *testing.T) {
|
||||
input := map[string]interface{}{"action": "get", "resource_type": "lxc", "host_id": "node1"}
|
||||
k1 := toolCallKey("pulse_discovery", input)
|
||||
k2 := toolCallKey("pulse_discovery", input)
|
||||
assert.Equal(t, k1, k2)
|
||||
})
|
||||
|
||||
t.Run("different input produces different key", func(t *testing.T) {
|
||||
input1 := map[string]interface{}{"action": "get", "resource_id": "100"}
|
||||
input2 := map[string]interface{}{"action": "get", "resource_id": "200"}
|
||||
k1 := toolCallKey("pulse_discovery", input1)
|
||||
k2 := toolCallKey("pulse_discovery", input2)
|
||||
assert.NotEqual(t, k1, k2)
|
||||
})
|
||||
|
||||
t.Run("different tool name produces different key", func(t *testing.T) {
|
||||
input := map[string]interface{}{"action": "get"}
|
||||
k1 := toolCallKey("pulse_discovery", input)
|
||||
k2 := toolCallKey("pulse_query", input)
|
||||
assert.NotEqual(t, k1, k2)
|
||||
})
|
||||
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
k := toolCallKey("pulse_discovery", nil)
|
||||
assert.Contains(t, k, "pulse_discovery")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoopDetection(t *testing.T) {
|
||||
// Simulate the loop detection logic from executeWithTools
|
||||
const maxIdenticalCalls = 3
|
||||
|
||||
t.Run("allows up to maxIdenticalCalls", func(t *testing.T) {
|
||||
recentCallCounts := make(map[string]int)
|
||||
input := map[string]interface{}{"action": "get", "resource_type": "lxc"}
|
||||
key := toolCallKey("pulse_discovery", input)
|
||||
|
||||
for i := 0; i < maxIdenticalCalls; i++ {
|
||||
recentCallCounts[key]++
|
||||
assert.LessOrEqual(t, recentCallCounts[key], maxIdenticalCalls,
|
||||
"call %d should be allowed", i+1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blocks call exceeding maxIdenticalCalls", func(t *testing.T) {
|
||||
recentCallCounts := make(map[string]int)
|
||||
input := map[string]interface{}{"action": "get", "resource_type": "lxc"}
|
||||
key := toolCallKey("pulse_discovery", input)
|
||||
|
||||
// Simulate 3 allowed calls
|
||||
for i := 0; i < maxIdenticalCalls; i++ {
|
||||
recentCallCounts[key]++
|
||||
}
|
||||
|
||||
// 4th call should be blocked
|
||||
recentCallCounts[key]++
|
||||
assert.Greater(t, recentCallCounts[key], maxIdenticalCalls,
|
||||
"4th identical call should exceed limit")
|
||||
})
|
||||
|
||||
t.Run("different calls tracked independently", func(t *testing.T) {
|
||||
recentCallCounts := make(map[string]int)
|
||||
input1 := map[string]interface{}{"action": "get", "resource_id": "100"}
|
||||
input2 := map[string]interface{}{"action": "get", "resource_id": "200"}
|
||||
key1 := toolCallKey("pulse_discovery", input1)
|
||||
key2 := toolCallKey("pulse_discovery", input2)
|
||||
|
||||
// Call key1 three times
|
||||
for i := 0; i < maxIdenticalCalls; i++ {
|
||||
recentCallCounts[key1]++
|
||||
}
|
||||
|
||||
// key2 should still be fine
|
||||
recentCallCounts[key2]++
|
||||
assert.Equal(t, 1, recentCallCounts[key2])
|
||||
assert.Equal(t, maxIdenticalCalls, recentCallCounts[key1])
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoopDetectionIntegration(t *testing.T) {
|
||||
// Integration test: run the agentic loop with a provider that keeps
|
||||
// calling the same tool, and verify the 4th identical call is blocked.
|
||||
executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{})
|
||||
mockProvider := &MockProvider{}
|
||||
loop := NewAgenticLoop(mockProvider, executor, "You are a helper")
|
||||
ctx := context.Background()
|
||||
sessionID := "loop-detect-session"
|
||||
messages := []Message{{Role: "user", Content: "discover lxc 100"}}
|
||||
|
||||
callCount := 0
|
||||
// The provider will keep requesting the same tool call up to 5 times
|
||||
mockProvider.On("ChatStream", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||
callback := args.Get(2).(providers.StreamCallback)
|
||||
callCount++
|
||||
|
||||
// Check if we got a LOOP_DETECTED error in the messages — if so, stop calling tools
|
||||
req := args.Get(1).(providers.ChatRequest)
|
||||
for _, msg := range req.Messages {
|
||||
if msg.ToolResult != nil && strings.Contains(msg.ToolResult.Content, "LOOP_DETECTED") {
|
||||
// Model should stop — emit content and no tool calls
|
||||
callback(providers.StreamEvent{
|
||||
Type: "content",
|
||||
Data: providers.ContentEvent{Text: "I'll try a different approach."},
|
||||
})
|
||||
callback(providers.StreamEvent{
|
||||
Type: "done",
|
||||
Data: providers.DoneEvent{},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Keep calling the same tool
|
||||
callback(providers.StreamEvent{
|
||||
Type: "tool_start",
|
||||
Data: providers.ToolStartEvent{ID: fmt.Sprintf("call_%d", callCount), Name: "pulse_discovery"},
|
||||
})
|
||||
callback(providers.StreamEvent{
|
||||
Type: "done",
|
||||
Data: providers.DoneEvent{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: fmt.Sprintf("call_%d", callCount),
|
||||
Name: "pulse_discovery",
|
||||
Input: map[string]interface{}{
|
||||
"action": "get",
|
||||
"resource_type": "lxc",
|
||||
"resource_id": "100",
|
||||
"host_id": "node1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
var events []StreamEvent
|
||||
results, err := loop.Execute(ctx, sessionID, messages, func(event StreamEvent) {
|
||||
events = append(events, event)
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify LOOP_DETECTED appears in at least one tool result
|
||||
foundLoopDetected := false
|
||||
for _, msg := range results {
|
||||
if msg.ToolResult != nil && strings.Contains(msg.ToolResult.Content, "LOOP_DETECTED") {
|
||||
foundLoopDetected = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundLoopDetected, "expected LOOP_DETECTED in tool results")
|
||||
|
||||
// The loop should have stopped (model returned content after seeing LOOP_DETECTED)
|
||||
// Total calls: 4 tool-calling turns (3 allowed + 1 blocked) + 1 final content turn = 5
|
||||
assert.LessOrEqual(t, callCount, 6, "loop should terminate after detection")
|
||||
}
|
||||
|
||||
func TestTruncateForLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
|
|
@ -308,9 +308,10 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
|
|||
},
|
||||
})
|
||||
}
|
||||
if len(openaiReq.Tools) > 0 {
|
||||
if len(openaiReq.Tools) > 0 && !c.isDeepSeekReasoner() {
|
||||
// Map ToolChoice to OpenAI format
|
||||
// OpenAI uses "required" instead of Anthropic's "any"
|
||||
// DeepSeek Reasoner does not support tool_choice — it decides tool use via reasoning
|
||||
openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
|
@ -380,6 +381,7 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
|
|||
if err := json.Unmarshal(respBody, &errResp); err == nil && errResp.Error.Message != "" {
|
||||
errMsg = errResp.Error.Message
|
||||
}
|
||||
errMsg = appendRateLimitInfo(errMsg, resp)
|
||||
lastErr = fmt.Errorf("API error (%d): %s", resp.StatusCode, errMsg)
|
||||
continue
|
||||
}
|
||||
|
|
@ -388,9 +390,11 @@ func (c *OpenAIClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse
|
|||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp openaiError
|
||||
if err := json.Unmarshal(respBody, &errResp); err == nil && errResp.Error.Message != "" {
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
errMsg := appendRateLimitInfo(errResp.Error.Message, resp)
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
||||
errMsg := appendRateLimitInfo(string(respBody), resp)
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, errMsg)
|
||||
}
|
||||
|
||||
// Success - break out of retry loop
|
||||
|
|
@ -641,8 +645,9 @@ func (c *OpenAIClient) ChatStream(ctx context.Context, req ChatRequest, callback
|
|||
},
|
||||
})
|
||||
}
|
||||
if len(openaiReq.Tools) > 0 {
|
||||
if len(openaiReq.Tools) > 0 && !c.isDeepSeekReasoner() {
|
||||
// Map ToolChoice to OpenAI format (same as non-streaming)
|
||||
// DeepSeek Reasoner does not support tool_choice — it decides tool use via reasoning
|
||||
openaiReq.ToolChoice = convertToolChoiceToOpenAI(req.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
|
@ -671,9 +676,11 @@ func (c *OpenAIClient) ChatStream(ctx context.Context, req ChatRequest, callback
|
|||
respBody, _ := io.ReadAll(resp.Body)
|
||||
var errResp openaiError
|
||||
if err := json.Unmarshal(respBody, &errResp); err == nil && errResp.Error.Message != "" {
|
||||
return fmt.Errorf("API error (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
errMsg := appendRateLimitInfo(errResp.Error.Message, resp)
|
||||
return fmt.Errorf("API error (%d): %s", resp.StatusCode, errMsg)
|
||||
}
|
||||
return fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
||||
errMsg := appendRateLimitInfo(string(respBody), resp)
|
||||
return fmt.Errorf("API error (%d): %s", resp.StatusCode, errMsg)
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
|
|
@ -837,7 +844,8 @@ func (c *OpenAIClient) ListModels(ctx context.Context) ([]ModelInfo, error) {
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(body))
|
||||
errMsg := appendRateLimitInfo(string(body), resp)
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, errMsg)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ func (e *PulseToolExecutor) executeGetDiscovery(ctx context.Context, args map[st
|
|||
|
||||
if resourceType == "lxc" {
|
||||
for _, c := range state.Containers {
|
||||
if strings.EqualFold(c.Name, resourceID) && c.Node == hostID {
|
||||
if strings.EqualFold(c.Name, resourceID) && nodeMatchesHostID(c.Node, hostID) {
|
||||
resourceID = fmt.Sprintf("%d", c.VMID)
|
||||
resolved = true
|
||||
break
|
||||
|
|
@ -261,7 +261,7 @@ func (e *PulseToolExecutor) executeGetDiscovery(ctx context.Context, args map[st
|
|||
}
|
||||
} else if resourceType == "vm" {
|
||||
for _, vm := range state.VMs {
|
||||
if strings.EqualFold(vm.Name, resourceID) && vm.Node == hostID {
|
||||
if strings.EqualFold(vm.Name, resourceID) && nodeMatchesHostID(vm.Node, hostID) {
|
||||
resourceID = fmt.Sprintf("%d", vm.VMID)
|
||||
resolved = true
|
||||
break
|
||||
|
|
@ -288,7 +288,19 @@ func (e *PulseToolExecutor) executeGetDiscovery(ctx context.Context, args map[st
|
|||
if discovery == nil {
|
||||
discovery, err = e.discoveryProvider.TriggerDiscovery(ctx, resourceType, hostID, resourceID)
|
||||
if err != nil {
|
||||
// Even on failure, provide cli_access so AI can investigate manually
|
||||
// Distinguish transient errors (rate limits, timeouts) from genuine not-found.
|
||||
// Transient errors must surface as IsError so the model stops retrying.
|
||||
if isTransientError(err) {
|
||||
return CallToolResult{
|
||||
Content: []Content{{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Discovery temporarily unavailable: %v. Do NOT retry this call. Use pulse_control or a different approach to investigate the resource.", err),
|
||||
}},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Genuine failure (e.g. resource doesn't exist) — keep existing behavior
|
||||
return NewJSONResult(map[string]interface{}{
|
||||
"found": false,
|
||||
"resource_type": resourceType,
|
||||
|
|
@ -408,6 +420,56 @@ func (e *PulseToolExecutor) executeGetDiscovery(ctx context.Context, args map[st
|
|||
return NewJSONResult(response), nil
|
||||
}
|
||||
|
||||
// isTransientError reports whether err looks like a transient
// API/infrastructure failure (rate limit, timeout, temporary unavailability,
// dropped connection) rather than a genuine "not found" style failure.
// Classification is by case-insensitive inspection of err.Error().
// A nil error is never transient.
//
// The HTTP status codes 429 and 503 only count when they stand alone, i.e.
// are not embedded in a longer digit run. Error messages here routinely carry
// numeric resource IDs, and a bare substring match would classify
// "resource 14290 not found" as a rate limit.
func isTransientError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())

	for _, code := range [...]string{"429", "503"} {
		if containsStandaloneNumber(msg, code) {
			return true
		}
	}

	transientPhrases := [...]string{
		"rate_limit",
		"rate limit",
		"ratelimit",
		"too many requests",
		"timeout",
		"context deadline exceeded",
		"failed after", // "failed after N retries"
		"temporarily",  // "temporarily unavailable"
		"server overloaded",
		"service unavailable",
		"connection refused",
		"connection reset",
		"broken pipe",
		"i/o timeout",
		"network unreachable",
	}
	for _, phrase := range transientPhrases {
		if strings.Contains(msg, phrase) {
			return true
		}
	}
	return false
}

// containsStandaloneNumber reports whether num occurs in s without an ASCII
// digit immediately before or after it, so "429" matches "status 429" and
// "(429)" but not "14290".
func containsStandaloneNumber(s, num string) bool {
	for from := 0; from < len(s); {
		rel := strings.Index(s[from:], num)
		if rel < 0 {
			return false
		}
		start := from + rel
		end := start + len(num)

		digitBefore := start > 0 && s[start-1] >= '0' && s[start-1] <= '9'
		digitAfter := end < len(s) && s[end] >= '0' && s[end] <= '9'
		if !digitBefore && !digitAfter {
			return true
		}
		// Embedded in a longer number; keep scanning past this occurrence.
		from = start + 1
	}
	return false
}
|
||||
|
||||
// nodeMatchesHostID reports whether nodeName identifies the same node as
// hostID. hostID may be a plain node name ("delly") or a composite
// "instance-node" ID whose final dash-separated part is the node name
// ("homelab-delly"). Both comparisons ignore case.
//
// An empty nodeName only matches an empty hostID. Without that guard the
// suffix check would reduce to "-" and accept any hostID ending in a dash.
func nodeMatchesHostID(nodeName, hostID string) bool {
	if strings.EqualFold(nodeName, hostID) {
		return true
	}
	if nodeName == "" {
		return false
	}
	// Composite form: hostID must end with "-<nodeName>".
	return strings.HasSuffix(strings.ToLower(hostID), "-"+strings.ToLower(nodeName))
}
|
||||
|
||||
func (e *PulseToolExecutor) executeListDiscoveries(_ context.Context, args map[string]interface{}) (CallToolResult, error) {
|
||||
if e.discoveryProvider == nil {
|
||||
return NewTextResult("Discovery service not available."), nil
|
||||
|
|
|
|||
59
internal/ai/tools/tools_discovery_test.go
Normal file
59
internal/ai/tools/tools_discovery_test.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsTransientError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
// Transient errors — should return true
|
||||
{"nil error", nil, false},
|
||||
{"rate limit 429", errors.New("API returned status 429"), true},
|
||||
{"503 service unavailable", errors.New("HTTP 503 Service Unavailable"), true},
|
||||
{"rate_limit underscore", errors.New("rate_limit: too many requests"), true},
|
||||
{"rate limit space", errors.New("rate limit exceeded"), true},
|
||||
{"ratelimit single word", errors.New("ratelimit error from provider"), true},
|
||||
{"too many requests", errors.New("too many requests, slow down"), true},
|
||||
{"timeout", errors.New("request timeout after 30s"), true},
|
||||
{"context deadline", errors.New("context deadline exceeded"), true},
|
||||
{"failed after retries", errors.New("failed after 3 retries"), true},
|
||||
{"temporarily unavailable", errors.New("service temporarily unavailable"), true},
|
||||
{"server overloaded", errors.New("server overloaded, try later"), true},
|
||||
{"service unavailable text", errors.New("the service is service unavailable"), true},
|
||||
{"connection refused", errors.New("dial tcp: connection refused"), true},
|
||||
{"connection reset", errors.New("connection reset by peer"), true},
|
||||
{"broken pipe", errors.New("write: broken pipe"), true},
|
||||
{"i/o timeout", errors.New("i/o timeout"), true},
|
||||
{"network unreachable", errors.New("network unreachable"), true},
|
||||
|
||||
// Anthropic-style rate limit
|
||||
{"anthropic rate limit", errors.New("Error: 429 {\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\"}}"), true},
|
||||
// OpenAI-style
|
||||
{"openai rate limit", errors.New("Rate limit reached for gpt-4"), true},
|
||||
// Gemini-style
|
||||
{"gemini quota", errors.New("429 Too Many Requests"), true},
|
||||
|
||||
// Non-transient errors — should return false
|
||||
{"resource not found", errors.New("resource not found"), false},
|
||||
{"permission denied", errors.New("permission denied"), false},
|
||||
{"invalid argument", errors.New("invalid resource_type: foo"), false},
|
||||
{"generic error", errors.New("something went wrong"), false},
|
||||
{"empty error", errors.New(""), false},
|
||||
{"auth error", errors.New("authentication failed"), false},
|
||||
{"not found", errors.New("404 not found"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isTransientError(tt.err)
|
||||
assert.Equal(t, tt.expected, result, "error: %v", tt.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -339,11 +339,12 @@ func checkMutationCapabilityGuards(command, cmdLower string) string {
|
|||
|
||||
// Input redirection means we can't inspect the content.
|
||||
// This catches: < (redirect), << (heredoc), <<< (here-string)
|
||||
// But NOT stderr-to-stdout redirections like 2>&1 which are harmless.
|
||||
// Examples blocked:
|
||||
// sqlite3 db < script.sql
|
||||
// psql <<EOF ... EOF
|
||||
// sqlite3 db <<< "SELECT ..."
|
||||
if strings.Contains(command, "<") {
|
||||
if hasInputRedirect(command) {
|
||||
return "input redirection prevents content inspection"
|
||||
}
|
||||
|
||||
|
|
@ -929,6 +930,23 @@ func isInteractiveREPL(cmdLower string) bool {
|
|||
}
|
||||
|
||||
// hasInputRedirect reports whether command contains an input redirection:
// "<" (file redirect), "<<" (heredoc) or "<<<" (here-string). Such commands
// feed content the classifier cannot inspect.
//
// The file-descriptor duplication "2<&1" is the one construct containing '<'
// that is tolerated, because it reuses an already-open descriptor instead of
// reading new content. It is stripped before the check. Output-side forms
// such as "2>&1", "1>&2" and "&>" contain no '<' and therefore never trigger
// this check, so they need no special handling.
func hasInputRedirect(command string) bool {
	if !strings.Contains(command, "<") {
		return false
	}
	remaining := strings.ReplaceAll(command, "2<&1", "")
	return strings.Contains(remaining, "<")
}

// hasStdoutRedirect checks for dangerous output redirects while allowing safe stderr redirects.
func hasStdoutRedirect(command string) bool {
|
||||
if !strings.Contains(command, ">") {
|
||||
return false
|
||||
|
|
@ -994,35 +1012,78 @@ func isReadOnlyByConstruction(cmdLower string) bool {
|
|||
readOnlyCommands := []string{
|
||||
"cat", "head", "tail",
|
||||
"ls", "ll", "dir",
|
||||
"ps", "free", "df", "du",
|
||||
"ps", "free", "df", "du", "iostat", "vmstat", "mpstat", "sar",
|
||||
"grep", "awk", "sed", "find", "locate", "which", "whereis",
|
||||
"journalctl", "dmesg",
|
||||
"uname", "hostname", "whoami", "id", "groups",
|
||||
"date", "uptime", "env", "printenv", "locale",
|
||||
"netstat", "ss", "ifconfig", "route",
|
||||
"netstat", "ss", "ifconfig", "route", "arp",
|
||||
"ping", "traceroute", "tracepath", "nslookup", "dig", "host",
|
||||
"file", "stat", "wc", "sort", "uniq", "cut", "tr",
|
||||
"lsof", "fuser",
|
||||
"getent", "nproc", "lscpu", "lsmem", "lsblk", "blkid",
|
||||
"lxc-ls", "lxc-info",
|
||||
"zcat", "zgrep", "bzcat", "xzcat",
|
||||
"md5sum", "sha256sum", "sha1sum",
|
||||
"test",
|
||||
// Process inspection
|
||||
"pgrep", "pidof", "pstree",
|
||||
// Login/session info
|
||||
"last", "lastlog", "who", "w",
|
||||
// Hardware inspection
|
||||
"lspci", "lsusb", "dmidecode", "hwinfo", "inxi",
|
||||
"sensors", "hddtemp", "smartctl", "nvme",
|
||||
// Media inspection tools
|
||||
"ffprobe", "mediainfo", "exiftool",
|
||||
// Proxmox version
|
||||
"pveversion",
|
||||
}
|
||||
|
||||
// Multi-word patterns that must appear at the start
|
||||
multiWordPatterns := []string{
|
||||
"curl -s", "curl --silent", "curl -I", "curl --head",
|
||||
// Curl read-only variants (various flag combinations)
|
||||
"curl -s", "curl --silent", "curl -i", "curl --head",
|
||||
"curl -k", "curl --insecure",
|
||||
"curl -sk", "curl -ks", "curl -ki", "curl -ik",
|
||||
"curl http", "curl https",
|
||||
"wget -q", "wget --spider",
|
||||
"docker ps", "docker logs", "docker inspect", "docker stats", "docker images", "docker info",
|
||||
"systemctl status", "systemctl is-active", "systemctl is-enabled", "systemctl list", "systemctl show",
|
||||
"pct list", "pct status",
|
||||
"qm list", "qm status",
|
||||
"ip addr", "ip route", "ip link",
|
||||
// Docker read-only
|
||||
"docker ps", "docker logs", "docker inspect", "docker stats",
|
||||
"docker images", "docker info", "docker version",
|
||||
"docker top", "docker port",
|
||||
"docker network ls", "docker network inspect",
|
||||
"docker volume ls", "docker volume inspect",
|
||||
"docker-compose ps", "docker compose ps",
|
||||
// Systemd read-only
|
||||
"systemctl status", "systemctl is-active", "systemctl is-enabled",
|
||||
"systemctl list", "systemctl show",
|
||||
"service status", "service --status-all",
|
||||
// Proxmox read-only
|
||||
"pct list", "pct status", "pct config", "pct df",
|
||||
"qm list", "qm status", "qm config", "qm guest cmd",
|
||||
"pvesh get",
|
||||
"pvecm status", "pvecm nodes",
|
||||
"pvesm status", "pvesm list",
|
||||
// ZFS/ZPool read-only
|
||||
"zpool status", "zpool list", "zpool get",
|
||||
"zfs list", "zfs get",
|
||||
// RAID inspection
|
||||
"mdadm --detail", "mdadm -D",
|
||||
// Network: ip with optional protocol flags (-4, -6)
|
||||
"ip addr", "ip route", "ip link", "ip neigh", "ip neighbor",
|
||||
"ip -4 addr", "ip -4 route", "ip -4 link", "ip -4 neigh", "ip -4 neighbor",
|
||||
"ip -6 addr", "ip -6 route", "ip -6 link", "ip -6 neigh", "ip -6 neighbor",
|
||||
"ip a", "ip r", "ip n",
|
||||
// Package info (read-only)
|
||||
"apt list", "apt show", "apt-cache",
|
||||
"dpkg -l", "dpkg --list", "dpkg -s",
|
||||
"rpm -q", "rpm -qa",
|
||||
"yum list", "dnf list",
|
||||
// Kubectl read-only commands
|
||||
"kubectl get", "kubectl describe", "kubectl logs", "kubectl top", "kubectl cluster-info",
|
||||
"kubectl api-resources", "kubectl api-versions", "kubectl version", "kubectl config view",
|
||||
// Network connectivity check (scan-only, no data sent)
|
||||
"nc -z", "nc -vz", "nc -zv",
|
||||
// Timeout wrapper (makes any command bounded)
|
||||
"timeout ",
|
||||
}
|
||||
|
|
@ -1064,8 +1125,8 @@ func matchesWritePatterns(cmdLower string) string {
|
|||
"shutdown": "system shutdown", "reboot": "system reboot", "poweroff": "system poweroff", "halt": "system halt",
|
||||
"systemctl restart": "service restart", "systemctl stop": "service stop", "systemctl start": "service start",
|
||||
"systemctl enable": "service enable", "systemctl disable": "service disable",
|
||||
"service ": "service control", "init ": "init control",
|
||||
"apt ": "package management", "apt-get ": "package management", "yum ": "package management",
|
||||
"init ": "init control",
|
||||
"apt ": "package management", "apt-get ": "package management", "yum ": "package management",
|
||||
"dnf ": "package management", "pacman ": "package management", "apk ": "package management", "brew ": "package management",
|
||||
"pip install": "package install", "pip uninstall": "package uninstall",
|
||||
"npm install": "package install", "npm uninstall": "package uninstall", "cargo install": "package install",
|
||||
|
|
@ -1090,6 +1151,12 @@ func matchesWritePatterns(cmdLower string) string {
|
|||
}
|
||||
}
|
||||
|
||||
// Command-start-only patterns: these must be the first word to avoid matching
|
||||
// substrings in arguments (e.g., "pve-daily-utils.service" contains "service").
|
||||
if strings.HasPrefix(cmdLower, "service ") {
|
||||
return "service control"
|
||||
}
|
||||
|
||||
// Medium-risk patterns
|
||||
mediumRiskPatterns := map[string]string{
|
||||
"mv ": "file move", "cp ": "file copy",
|
||||
|
|
@ -2001,7 +2068,6 @@ func (e *PulseToolExecutor) executeListInfrastructure(_ context.Context, args ma
|
|||
response.Nodes = append(response.Nodes, NodeSummary{
|
||||
Name: node.Name,
|
||||
Status: node.Status,
|
||||
ID: node.ID,
|
||||
AgentConnected: connectedAgentHostnames[node.Name],
|
||||
})
|
||||
count++
|
||||
|
|
@ -2220,7 +2286,6 @@ func (e *PulseToolExecutor) executeGetTopology(_ context.Context, args map[strin
|
|||
hasAgent := connectedAgentHostnames[node.Name]
|
||||
nodeMap[node.Name] = &ProxmoxNodeTopology{
|
||||
Name: node.Name,
|
||||
ID: node.ID,
|
||||
Status: node.Status,
|
||||
AgentConnected: hasAgent,
|
||||
CanExecute: hasAgent && controlEnabled,
|
||||
|
|
@ -2230,7 +2295,7 @@ func (e *PulseToolExecutor) executeGetTopology(_ context.Context, args map[strin
|
|||
}
|
||||
}
|
||||
|
||||
ensureNode := func(name, id, status string) *ProxmoxNodeTopology {
|
||||
ensureNode := func(name, status string) *ProxmoxNodeTopology {
|
||||
if !includeProxmox || summaryOnly {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -2243,7 +2308,6 @@ func (e *PulseToolExecutor) executeGetTopology(_ context.Context, args map[strin
|
|||
hasAgent := connectedAgentHostnames[name]
|
||||
nodeMap[name] = &ProxmoxNodeTopology{
|
||||
Name: name,
|
||||
ID: id,
|
||||
Status: status,
|
||||
AgentConnected: hasAgent,
|
||||
CanExecute: hasAgent && controlEnabled,
|
||||
|
|
@ -2259,7 +2323,7 @@ func (e *PulseToolExecutor) executeGetTopology(_ context.Context, args map[strin
|
|||
summary.RunningVMs++
|
||||
}
|
||||
|
||||
nodeTopology := ensureNode(vm.Node, "", "unknown")
|
||||
nodeTopology := ensureNode(vm.Node, "unknown")
|
||||
if nodeTopology == nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -2284,7 +2348,7 @@ func (e *PulseToolExecutor) executeGetTopology(_ context.Context, args map[strin
|
|||
summary.RunningLXC++
|
||||
}
|
||||
|
||||
nodeTopology := ensureNode(ct.Node, "", "unknown")
|
||||
nodeTopology := ensureNode(ct.Node, "unknown")
|
||||
if nodeTopology == nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -2937,7 +3001,6 @@ func (e *PulseToolExecutor) executeSearchResources(_ context.Context, args map[s
|
|||
}
|
||||
addMatch(ResourceMatch{
|
||||
Type: "node",
|
||||
ID: node.ID,
|
||||
Name: node.Name,
|
||||
Status: node.Status,
|
||||
AgentConnected: connectedAgentHostnames[node.Name],
|
||||
|
|
@ -3072,9 +3135,9 @@ func (e *PulseToolExecutor) executeSearchResources(_ context.Context, args map[s
|
|||
case "node":
|
||||
reg = ResourceRegistration{
|
||||
Kind: "node",
|
||||
ProviderUID: match.ID, // Node ID is the provider UID
|
||||
ProviderUID: match.Name, // Node name is the identifier used for routing
|
||||
Name: match.Name,
|
||||
Aliases: []string{match.Name, match.ID},
|
||||
Aliases: []string{match.Name},
|
||||
HostName: match.Name,
|
||||
LocationChain: []string{"node:" + match.Name},
|
||||
Executors: []ExecutorRegistration{{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue