fix(patrol): cap per-run tokens and reset patrol session history

This commit is contained in:
rcourtman 2026-02-24 11:29:47 +00:00
parent 82ccb662f9
commit 24f5b1cb31
8 changed files with 232 additions and 29 deletions

View file

@ -31,6 +31,8 @@ type AgenticLoop struct {
tools []providers.Tool
baseSystemPrompt string // Base prompt without mode context
maxTurns int
maxTotalTokens int // Optional hard cap for total tokens in one run (0 = disabled)
stopReason string
// Provider info for telemetry (e.g., "anthropic", "claude-3-sonnet")
providerName string
@ -56,6 +58,11 @@ type AgenticLoop struct {
budgetChecker func() error
}
// Stop reasons recorded by the agentic loop for the most recent run
// (see stopReason field; surfaced to callers via GetStopReason).
const (
// stopReasonNone means the run finished without tripping a guardrail.
stopReasonNone = ""
// stopReasonTokenLimit means the run ended early because the per-run
// token cap (maxTotalTokens) was reached.
stopReasonTokenLimit = "token_limit"
)
// NewAgenticLoop creates a new agentic loop
func NewAgenticLoop(provider providers.StreamingProvider, executor *tools.PulseToolExecutor, systemPrompt string) *AgenticLoop {
// Convert MCP tools to provider format
@ -94,6 +101,8 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
// before calling ExecuteWithTools, and this avoids races with concurrent sessions.
a.mu.Lock()
maxTurns := a.maxTurns
maxTotalTokens := a.maxTotalTokens
a.stopReason = stopReasonNone
a.aborted[sessionID] = false
a.mu.Unlock()
defer func() {
@ -390,6 +399,18 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
toolCalls = nil
}
// Per-run hard token cap (used by Patrol) — stop before executing further tool calls.
tokenLimitReached := maxTotalTokens > 0 && (a.totalInputTokens+a.totalOutputTokens) >= maxTotalTokens
if tokenLimitReached && len(toolCalls) > 0 {
log.Warn().
Int("token_total", a.totalInputTokens+a.totalOutputTokens).
Int("token_limit", maxTotalTokens).
Int("stripped_tool_calls", len(toolCalls)).
Str("session_id", sessionID).
Msg("[AgenticLoop] Token cap reached — stripping pending tool calls for graceful stop")
toolCalls = nil
}
// Check mid-run budget after each turn completes
if a.budgetChecker != nil {
if budgetErr := a.budgetChecker(); budgetErr != nil {
@ -402,6 +423,16 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
// Create assistant message
// Clean DeepSeek artifacts from the content before storing
cleanedContent := cleanDeepSeekArtifacts(contentBuilder.String())
if tokenLimitReached {
note := fmt.Sprintf("Analysis stopped early after reaching per-run token budget (%d tokens). Results above may be partial.", maxTotalTokens)
if cleanedContent == "" {
cleanedContent = note
} else {
cleanedContent = strings.TrimSpace(cleanedContent + "\n\n" + note)
}
jsonData, _ := json.Marshal(ContentData{Text: "\n\n" + note})
callback(StreamEvent{Type: "content", Data: jsonData})
}
assistantMsg := Message{
ID: uuid.New().String(),
Role: "assistant",
@ -440,6 +471,18 @@ func (a *AgenticLoop) executeWithTools(ctx context.Context, sessionID string, me
}
providerMessages = append(providerMessages, providerAssistant)
if tokenLimitReached {
a.mu.Lock()
a.stopReason = stopReasonTokenLimit
a.mu.Unlock()
log.Warn().
Int("token_total", a.totalInputTokens+a.totalOutputTokens).
Int("token_limit", maxTotalTokens).
Str("session_id", sessionID).
Msg("[AgenticLoop] Token cap reached — ending run gracefully")
return resultMessages, nil
}
// If no tool calls, we're done - but first check FSM and phantom execution
if len(toolCalls) == 0 {
// If the user explicitly requested a tool and the model didn't comply, retry once.
@ -1291,6 +1334,14 @@ func (a *AgenticLoop) SetMaxTurns(n int) {
a.mu.Unlock()
}
// SetMaxTotalTokens installs a hard per-run cap on combined input+output
// tokens for a single execution. A value of 0 disables the cap; when the cap
// is reached the loop winds down gracefully and returns partial analysis.
func (a *AgenticLoop) SetMaxTotalTokens(n int) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.maxTotalTokens = n
}
// SetProviderInfo sets the provider/model info for telemetry.
func (a *AgenticLoop) SetProviderInfo(provider, model string) {
a.mu.Lock()
@ -1321,6 +1372,13 @@ func (a *AgenticLoop) ResetTokenCounts() {
a.totalOutputTokens = 0
}
// GetStopReason reports why the most recent execution stopped (e.g.
// "token_limit"), or the empty string when no guardrail was triggered.
func (a *AgenticLoop) GetStopReason() string {
	a.mu.Lock()
	reason := a.stopReason
	a.mu.Unlock()
	return reason
}
// hasPhantomExecution detects when the model claims to have executed something
// but no actual tool calls were made. This catches models that "hallucinate"
// tool execution by writing about it instead of calling tools.

View file

@ -545,11 +545,12 @@ func (s *Service) ExecuteStream(ctx context.Context, req ExecuteRequest, callbac
// PatrolRequest represents a patrol execution request within the chat service
type PatrolRequest struct {
Prompt string `json:"prompt"`
SystemPrompt string `json:"system_prompt"`
SessionID string `json:"session_id,omitempty"`
UseCase string `json:"use_case"`
MaxTurns int `json:"max_turns,omitempty"`
Prompt string `json:"prompt"`
SystemPrompt string `json:"system_prompt"`
SessionID string `json:"session_id,omitempty"`
UseCase string `json:"use_case"`
MaxTurns int `json:"max_turns,omitempty"`
MaxTotalTokens int `json:"max_total_tokens,omitempty"`
}
// PatrolResponse contains the results of a patrol execution
@ -557,6 +558,7 @@ type PatrolResponse struct {
Content string `json:"content"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
StopReason string `json:"stop_reason,omitempty"`
}
// ExecutePatrolStream creates a temporary agentic loop for patrol execution.
@ -606,6 +608,9 @@ func (s *Service) ExecutePatrolStream(ctx context.Context, req PatrolRequest, ca
if req.MaxTurns > 0 {
tempLoop.SetMaxTurns(req.MaxTurns)
}
if req.MaxTotalTokens > 0 {
tempLoop.SetMaxTotalTokens(req.MaxTotalTokens)
}
// Set provider info for telemetry
parts := strings.SplitN(patrolModel, ":", 2)
@ -613,11 +618,21 @@ func (s *Service) ExecutePatrolStream(ctx context.Context, req PatrolRequest, ca
tempLoop.SetProviderInfo(parts[0], parts[1])
}
// Ensure patrol session exists
// Reset patrol session history before each run.
// Patrol prompts already include full seed context, so reusing prior run
// messages only bloats input tokens and can cause severe cost spikes.
sessionID := req.SessionID
if sessionID == "" {
sessionID = "patrol-main"
}
if _, getErr := sessions.Get(sessionID); getErr == nil {
if delErr := sessions.Delete(sessionID); delErr != nil {
log.Warn().
Err(delErr).
Str("session_id", sessionID).
Msg("Failed to reset patrol session history; continuing with existing session")
}
}
session, err := sessions.EnsureSession(sessionID)
if err != nil {
return nil, fmt.Errorf("failed to ensure patrol session: %w", err)
@ -707,6 +722,7 @@ func (s *Service) ExecutePatrolStream(ctx context.Context, req PatrolRequest, ca
Content: contentBuilder.String(),
InputTokens: tempLoop.GetTotalInputTokens(),
OutputTokens: tempLoop.GetTotalOutputTokens(),
StopReason: tempLoop.GetStopReason(),
}, nil
}

View file

@ -20,6 +20,7 @@ func (m mockKnowledgeStore) GetKnowledge(resourceID, category string) []tools.Kn
type mockStreamingProvider struct {
chatStreamFunc func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error
lastRequest providers.ChatRequest
requests []providers.ChatRequest
}
func (m *mockStreamingProvider) Chat(ctx context.Context, req providers.ChatRequest) (*providers.ChatResponse, error) {
@ -28,6 +29,7 @@ func (m *mockStreamingProvider) Chat(ctx context.Context, req providers.ChatRequ
func (m *mockStreamingProvider) ChatStream(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
m.lastRequest = req
m.requests = append(m.requests, req)
if m.chatStreamFunc != nil {
return m.chatStreamFunc(ctx, req, callback)
}
@ -146,6 +148,94 @@ func TestService_ExecutePatrolStream_Success(t *testing.T) {
}
}
// TestService_ExecutePatrolStream_ResetsSessionHistoryEachRun verifies that
// each patrol run starts from a clean session: the provider must receive only
// the current prompt message on both the first and the second run.
func TestService_ExecutePatrolStream_ResetsSessionHistoryEachRun(t *testing.T) {
	sessionStore, err := NewSessionStore(t.TempDir())
	if err != nil {
		t.Fatalf("failed to create session store: %v", err)
	}

	// Minimal streaming stub: one content event, one done event per call.
	streamOnce := func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
		callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "ok"}})
		callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{InputTokens: 1, OutputTokens: 1}})
		return nil
	}
	provider := &mockStreamingProvider{chatStreamFunc: streamOnce}

	service := &Service{
		started:  true,
		sessions: sessionStore,
		executor: tools.NewPulseToolExecutor(tools.ExecutorConfig{}),
		cfg:      &config.AIConfig{PatrolModel: "mock:model"},
	}
	service.providerFactory = func(modelStr string) (providers.StreamingProvider, error) {
		return provider, nil
	}

	// Run the same patrol twice against the same session ID.
	for run := 1; run <= 2; run++ {
		req := PatrolRequest{
			Prompt:    "check status",
			MaxTurns:  1,
			SessionID: "patrol-main",
		}
		if _, runErr := service.ExecutePatrolStream(context.Background(), req, func(StreamEvent) {}); runErr != nil {
			t.Fatalf("patrol run %d failed: %v", run, runErr)
		}
	}

	if len(provider.requests) != 2 {
		t.Fatalf("expected 2 patrol requests, got %d", len(provider.requests))
	}
	if len(provider.requests[0].Messages) != 1 {
		t.Fatalf("expected first run to include only current prompt message, got %d", len(provider.requests[0].Messages))
	}
	if len(provider.requests[1].Messages) != 1 {
		t.Fatalf("expected second run to reset history and include only current prompt message, got %d", len(provider.requests[1].Messages))
	}
}
// TestService_ExecutePatrolStream_TokenCapStopsGracefully verifies that a
// patrol run whose token usage exceeds MaxTotalTokens returns successfully
// (no error) with StopReason set to "token_limit".
func TestService_ExecutePatrolStream_TokenCapStopsGracefully(t *testing.T) {
	sessionStore, err := NewSessionStore(t.TempDir())
	if err != nil {
		t.Fatalf("failed to create session store: %v", err)
	}

	// The stub reports 3 total tokens per turn, which exceeds the cap of 1 below.
	provider := &mockStreamingProvider{
		chatStreamFunc: func(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error {
			callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "initial analysis"}})
			callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{InputTokens: 2, OutputTokens: 1}})
			return nil
		},
	}

	service := &Service{
		started:  true,
		sessions: sessionStore,
		executor: tools.NewPulseToolExecutor(tools.ExecutorConfig{}),
		cfg:      &config.AIConfig{PatrolModel: "mock:model"},
	}
	service.providerFactory = func(modelStr string) (providers.StreamingProvider, error) {
		return provider, nil
	}

	req := PatrolRequest{
		Prompt:         "check status",
		MaxTurns:       5,
		MaxTotalTokens: 1,
		SessionID:      "patrol-main",
	}
	resp, err := service.ExecutePatrolStream(context.Background(), req, func(StreamEvent) {})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp == nil {
		t.Fatalf("expected response")
	}
	if resp.StopReason != "token_limit" {
		t.Fatalf("expected token_limit stop reason, got %q", resp.StopReason)
	}
}
func TestService_ExecutePatrolStream_Errors(t *testing.T) {
store, err := NewSessionStore(t.TempDir())
if err != nil {

View file

@ -127,9 +127,11 @@ type PatrolRunRecord struct {
ErrorCount int `json:"error_count"`
Status string `json:"status"` // "healthy", "issues_found", "error"
// AI Analysis details
AIAnalysis string `json:"ai_analysis,omitempty"` // The AI's raw response/analysis
InputTokens int `json:"input_tokens,omitempty"` // Tokens sent to AI
OutputTokens int `json:"output_tokens,omitempty"` // Tokens received from AI
AIAnalysis string `json:"ai_analysis,omitempty"` // The AI's raw response/analysis
InputTokens int `json:"input_tokens,omitempty"` // Tokens sent to AI
OutputTokens int `json:"output_tokens,omitempty"` // Tokens received from AI
AnalysisStopReason string `json:"analysis_stop_reason,omitempty"` // Why analysis stopped early, if applicable
AnalysisStoppedEarly bool `json:"analysis_stopped_early,omitempty"` // True when patrol hit a hard per-run guardrail
// Tool call traces
ToolCalls []ToolCallRecord `json:"tool_calls,omitempty"`
ToolCallCount int `json:"tool_call_count"`

View file

@ -26,6 +26,8 @@ type AIAnalysisResult struct {
RejectedFindings int // Findings rejected by threshold validation
InputTokens int
OutputTokens int
StopReason string
StoppedEarly bool
ToolCalls []ToolCallRecord // Tool invocations during this analysis
ReportedIDs []string // Finding IDs reported (created/re-reported) this run
ResolvedIDs []string // Finding IDs explicitly resolved by LLM this run
@ -38,6 +40,7 @@ const (
patrolTurnsPer50Devices = 5
patrolQuickMinTurns = 10
patrolQuickMaxTurns = 30
patrolMaxTotalTokens = 250000
)
// CleanThinkingTokens removes model-specific thinking markers from AI responses.
@ -283,11 +286,12 @@ func (p *PatrolService) runAIAnalysis(ctx context.Context, state models.StateSna
var rawToolOutputs []string
chatResp, chatErr := cs.ExecutePatrolStream(ctx, PatrolExecuteRequest{
Prompt: seedContext,
SystemPrompt: p.getPatrolSystemPrompt(),
SessionID: "patrol-main",
UseCase: "patrol",
MaxTurns: maxTurns,
Prompt: seedContext,
SystemPrompt: p.getPatrolSystemPrompt(),
SessionID: "patrol-main",
UseCase: "patrol",
MaxTurns: maxTurns,
MaxTotalTokens: patrolMaxTotalTokens,
}, func(event ChatStreamEvent) {
switch event.Type {
case "content":
@ -536,6 +540,8 @@ func (p *PatrolService) runAIAnalysis(ctx context.Context, state models.StateSna
RejectedFindings: rejectedCount,
InputTokens: inputTokens,
OutputTokens: outputTokens,
StopReason: chatResp.StopReason,
StoppedEarly: chatResp.StopReason == "token_limit",
ToolCalls: collectedToolCalls,
ReportedIDs: adapter.getReportedFindingIDs(),
ResolvedIDs: adapter.getResolvedIDs(),
@ -714,11 +720,12 @@ func (p *PatrolService) runEvaluationPass(ctx context.Context, adapter *patrolFi
Msg("AI Patrol: Running evaluation pass for unmatched signals")
resp, err := cs.ExecutePatrolStream(ctx, PatrolExecuteRequest{
Prompt: userPrompt,
SystemPrompt: systemPrompt,
SessionID: "patrol-eval",
UseCase: "patrol",
MaxTurns: 5,
Prompt: userPrompt,
SystemPrompt: systemPrompt,
SessionID: "patrol-eval",
UseCase: "patrol",
MaxTurns: 5,
MaxTotalTokens: patrolMaxTotalTokens,
}, func(event ChatStreamEvent) {
// Minimal callback — we don't stream eval pass to the frontend
// but findings are still created via the adapter

View file

@ -472,6 +472,17 @@ func (p *PatrolService) runPatrolWithTrigger(ctx context.Context, trigger Trigge
findingsSummaryStr = fmt.Sprintf("Analysis incomplete (%d errors)", runStats.errors)
}
}
if runStats.aiAnalysis != nil && runStats.aiAnalysis.StoppedEarly {
const partialMsg = "Partial analysis (per-run token budget reached)"
if findingsSummaryStr == "All healthy" {
findingsSummaryStr = partialMsg
} else {
findingsSummaryStr = findingsSummaryStr + " • " + partialMsg
}
if status == "healthy" {
status = "issues_found"
}
}
// Create run record
runRecord := PatrolRunRecord{
@ -515,6 +526,8 @@ func (p *PatrolService) runPatrolWithTrigger(ctx context.Context, trigger Trigge
runRecord.AIAnalysis = runStats.aiAnalysis.Response
runRecord.InputTokens = runStats.aiAnalysis.InputTokens
runRecord.OutputTokens = runStats.aiAnalysis.OutputTokens
runRecord.AnalysisStopReason = runStats.aiAnalysis.StopReason
runRecord.AnalysisStoppedEarly = runStats.aiAnalysis.StoppedEarly
toolCalls := runStats.aiAnalysis.ToolCalls
if len(toolCalls) > MaxToolCallsPerRun {
toolCalls = toolCalls[:MaxToolCallsPerRun]
@ -779,6 +792,17 @@ func (p *PatrolService) runScopedPatrol(ctx context.Context, scope PatrolScope)
findingsSummaryStr = fmt.Sprintf("Analysis incomplete (%d errors)", runStats.errors)
}
}
if runStats.aiAnalysis != nil && runStats.aiAnalysis.StoppedEarly {
const partialMsg = "Partial analysis (per-run token budget reached)"
if findingsSummaryStr == "All healthy" {
findingsSummaryStr = partialMsg
} else {
findingsSummaryStr = findingsSummaryStr + " • " + partialMsg
}
if status == "healthy" {
status = "issues_found"
}
}
runRecord := PatrolRunRecord{
ID: fmt.Sprintf("%d", start.UnixNano()),
@ -815,6 +839,8 @@ func (p *PatrolService) runScopedPatrol(ctx context.Context, scope PatrolScope)
runRecord.AIAnalysis = runStats.aiAnalysis.Response
runRecord.InputTokens = runStats.aiAnalysis.InputTokens
runRecord.OutputTokens = runStats.aiAnalysis.OutputTokens
runRecord.AnalysisStopReason = runStats.aiAnalysis.StopReason
runRecord.AnalysisStoppedEarly = runStats.aiAnalysis.StoppedEarly
toolCalls := runStats.aiAnalysis.ToolCalls
if len(toolCalls) > MaxToolCallsPerRun {
toolCalls = toolCalls[:MaxToolCallsPerRun]

View file

@ -124,11 +124,12 @@ type ChatToolResult struct {
// PatrolExecuteRequest represents a patrol execution request via the chat service
type PatrolExecuteRequest struct {
Prompt string `json:"prompt"`
SystemPrompt string `json:"system_prompt"`
SessionID string `json:"session_id,omitempty"`
UseCase string `json:"use_case"` // "patrol" — for model selection
MaxTurns int `json:"max_turns,omitempty"`
Prompt string `json:"prompt"`
SystemPrompt string `json:"system_prompt"`
SessionID string `json:"session_id,omitempty"`
UseCase string `json:"use_case"` // "patrol" — for model selection
MaxTurns int `json:"max_turns,omitempty"`
MaxTotalTokens int `json:"max_total_tokens,omitempty"`
}
// PatrolStreamResponse contains the results of a patrol execution via the chat service
@ -136,6 +137,7 @@ type PatrolStreamResponse struct {
Content string `json:"content"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
StopReason string `json:"stop_reason,omitempty"`
}
// Service orchestrates AI interactions

View file

@ -34,11 +34,12 @@ func (a *chatServiceAdapter) ExecuteStream(ctx context.Context, req ai.ChatExecu
func (a *chatServiceAdapter) ExecutePatrolStream(ctx context.Context, req ai.PatrolExecuteRequest, callback ai.ChatStreamCallback) (*ai.PatrolStreamResponse, error) {
resp, err := a.svc.ExecutePatrolStream(ctx, chat.PatrolRequest{
Prompt: req.Prompt,
SystemPrompt: req.SystemPrompt,
SessionID: req.SessionID,
UseCase: req.UseCase,
MaxTurns: req.MaxTurns,
Prompt: req.Prompt,
SystemPrompt: req.SystemPrompt,
SessionID: req.SessionID,
UseCase: req.UseCase,
MaxTurns: req.MaxTurns,
MaxTotalTokens: req.MaxTotalTokens,
}, adaptCallback(callback))
if err != nil {
return nil, err
@ -47,6 +48,7 @@ func (a *chatServiceAdapter) ExecutePatrolStream(ctx context.Context, req ai.Pat
Content: resp.Content,
InputTokens: resp.InputTokens,
OutputTokens: resp.OutputTokens,
StopReason: resp.StopReason,
}, nil
}