fix: retry logic for streaming requests

This commit is contained in:
Carl-Robert Linnupuu 2026-01-25 17:51:17 +00:00
parent 6ad5f4e523
commit ebea17ff6b
5 changed files with 228 additions and 60 deletions

View file

@ -21,6 +21,7 @@ interface AgentEvents {
) {
}
fun onRetry(attempt: Int, maxAttempts: Int, reason: String? = null) {}
fun onQueuedMessagesResolved()
fun onTokenUsageAvailable(tokenUsage: Long) {}
fun onTokenUsageUpdated(tokenUsage: TokenUsage) {}

View file

@ -3,10 +3,12 @@ package ee.carlrobert.codegpt.agent
import ai.koog.agents.core.agent.AIAgent
import ai.koog.agents.core.agent.GraphAIAgent
import ai.koog.agents.core.agent.config.AIAgentConfig
import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession
import ai.koog.agents.core.dsl.builder.forwardTo
import ai.koog.agents.core.dsl.builder.strategy
import ai.koog.agents.core.dsl.extension.*
import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession
import ai.koog.agents.core.dsl.extension.nodeExecuteMultipleTools
import ai.koog.agents.core.dsl.extension.onMultipleAssistantMessages
import ai.koog.agents.core.dsl.extension.onMultipleToolCalls
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.result
import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext
@ -18,19 +20,11 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents
import ai.koog.agents.features.tokenizer.feature.MessageTokenizer
import ai.koog.agents.features.tokenizer.feature.tokenizer
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.clients.ConnectionTimeoutConfig
import ai.koog.prompt.executor.clients.LLMClient
import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings
import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient
import ai.koog.prompt.executor.clients.google.GoogleClientSettings
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
import ai.koog.prompt.executor.clients.mistralai.MistralAIClientSettings
import ai.koog.prompt.executor.clients.mistralai.MistralAILLMClient
import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
import ai.koog.prompt.executor.clients.retry.RetryConfig
import ai.koog.prompt.executor.clients.retry.RetryingLLMClient
import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
import ai.koog.prompt.executor.model.PromptExecutor
import ai.koog.prompt.executor.ollama.client.OllamaClient
import ai.koog.prompt.message.Message
@ -39,8 +33,9 @@ import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.agent.clients.CustomOpenAILLMClient
import ee.carlrobert.codegpt.agent.clients.ProxyAIClientSettings
import ee.carlrobert.codegpt.agent.clients.InceptionAILLMClient
import ee.carlrobert.codegpt.agent.clients.ProxyAILLMClient
import ee.carlrobert.codegpt.agent.clients.RetryingPromptExecutor
import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot
import ee.carlrobert.codegpt.agent.tools.*
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
@ -52,8 +47,8 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
import java.time.LocalDate
import kotlin.time.Duration.Companion.seconds
import java.util.concurrent.atomic.AtomicLong
import kotlin.time.Duration.Companion.seconds
object AgentFactory {
@ -68,7 +63,8 @@ object AgentFactory {
extraBehavior: String? = null,
toolOverrides: Set<SubagentTool>? = null,
onCreditsAvailable: ((AgentCreditsEvent) -> Unit)? = null,
tokenCounter: AtomicLong? = null
tokenCounter: AtomicLong? = null,
events: AgentEvents? = null
): AIAgent<String, String> {
val installHandler = buildUsageAwareInstallHandler(
provider,
@ -77,7 +73,7 @@ object AgentFactory {
onAgentToolCallCompleted,
onCreditsAvailable
)
val executor = createExecutor(provider)
val executor = createExecutor(provider, events)
return when (agentType) {
AgentType.GENERAL_PURPOSE -> createGeneralPurposeAgent(
provider,
@ -166,53 +162,30 @@ object AgentFactory {
)
}
fun createExecutor(provider: ServiceType): PromptExecutor {
val timeoutConfig = ConnectionTimeoutConfig(
connectTimeoutMillis = 30_000,
socketTimeoutMillis = 60_000,
)
fun createExecutor(provider: ServiceType, events: AgentEvents? = null): PromptExecutor {
return when (provider) {
ServiceType.OPENAI -> {
val apiKey = getCredential(CredentialKey.OpenaiApiKey) ?: ""
createRetryingExecutor(
OpenAILLMClient(
apiKey, OpenAIClientSettings(timeoutConfig = timeoutConfig)
)
)
createRetryingExecutor(OpenAILLMClient(apiKey), events)
}
ServiceType.ANTHROPIC -> {
val apiKey = getCredential(CredentialKey.AnthropicApiKey) ?: ""
createRetryingExecutor(
AnthropicLLMClient(
apiKey,
AnthropicClientSettings(timeoutConfig = timeoutConfig)
)
)
createRetryingExecutor(AnthropicLLMClient(apiKey), events)
}
ServiceType.GOOGLE -> {
val apiKey = getCredential(CredentialKey.GoogleApiKey) ?: ""
createRetryingExecutor(
GoogleLLMClient(
apiKey,
GoogleClientSettings(timeoutConfig = timeoutConfig)
)
)
createRetryingExecutor(GoogleLLMClient(apiKey), events)
}
ServiceType.OLLAMA -> {
createRetryingExecutor(OllamaClient(timeoutConfig = timeoutConfig))
createRetryingExecutor(OllamaClient(), events)
}
ServiceType.MISTRAL -> {
val apiKey = getCredential(CredentialKey.MistralApiKey) ?: ""
createRetryingExecutor(
MistralAILLMClient(
apiKey,
MistralAIClientSettings(timeoutConfig = timeoutConfig)
)
)
createRetryingExecutor(MistralAILLMClient(apiKey), events)
}
ServiceType.CUSTOM_OPENAI -> {
@ -227,34 +200,34 @@ object AgentFactory {
CustomOpenAILLMClient.fromSettingsState(
apiKey,
state.chatCompletionSettings,
timeoutConfig
)
),
events
)
}
ServiceType.PROXYAI -> {
val apiKey = getCredential(CredentialKey.CodeGptApiKey) ?: ""
createRetryingExecutor(
ProxyAILLMClient(
apiKey,
ProxyAIClientSettings(timeoutConfig = timeoutConfig)
)
)
createRetryingExecutor(ProxyAILLMClient(apiKey), events)
}
ServiceType.INCEPTION -> {
val apiKey = getCredential(CredentialKey.InceptionApiKey) ?: ""
createRetryingExecutor(InceptionAILLMClient(apiKey), events)
}
else -> throw UnsupportedOperationException("Provider not supported: $provider")
}
}
private fun createRetryingExecutor(client: LLMClient): PromptExecutor {
val retryConfig = RetryConfig(
private fun createRetryingExecutor(client: LLMClient, events: AgentEvents?): PromptExecutor {
val policy = RetryingPromptExecutor.RetryPolicy(
maxAttempts = 5,
initialDelay = 1.seconds,
maxDelay = 30.seconds,
backoffMultiplier = 2.0,
jitterFactor = 0.1
)
return SingleLLMPromptExecutor(RetryingLLMClient(client, retryConfig))
return RetryingPromptExecutor.fromClient(client, policy, events)
}
private fun createGeneralPurposeAgent(
@ -442,7 +415,12 @@ object AgentFactory {
appendPrompt { user(input) }
tokenCounter?.addAndGet(tokenizer().tokenCountFor(prompt).toLong())
val responses =
requestResponses(executor, config, { appendPrompt { message(it) } }, tokenCounter)
requestResponses(
executor,
config,
{ appendPrompt { message(it) } },
tokenCounter
)
responses
}
}
@ -454,7 +432,12 @@ object AgentFactory {
}
tokenCounter?.addAndGet(tokenizer().tokenCountFor(prompt).toLong())
val responses =
requestResponses(executor, config, { appendPrompt { message(it) } }, tokenCounter)
requestResponses(
executor,
config,
{ appendPrompt { message(it) } },
tokenCounter
)
responses
}
}

View file

@ -79,11 +79,9 @@ object ProxyAIAgent {
): AIAgent<MessageWithContext, String> {
val modelSelection =
service<ModelSelectionService>().getModelSelectionForFeature(FeatureType.AGENT)
val stream =
provider == ServiceType.ANTHROPIC || provider == ServiceType.OPENAI || provider == ServiceType.PROXYAI
val stream = provider != ServiceType.CUSTOM_OPENAI
val projectInstructions = searchForInstructions(project.basePath)
val executor = AgentFactory.createExecutor(provider)
val executor = AgentFactory.createExecutor(provider, events)
val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
val toolRegistry = createToolRegistry(project, events, sessionId)
val agentModel = service<ModelSelectionService>().getAgentModel()

View file

@ -0,0 +1,175 @@
package ee.carlrobert.codegpt.agent.clients
import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.prompt.dsl.ModerationResult
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.executor.clients.LLMClient
import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
import ai.koog.prompt.executor.model.PromptExecutor
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.message.Message
import ai.koog.prompt.streaming.StreamFrame
import ee.carlrobert.codegpt.agent.AgentEvents
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.onEach
import kotlin.math.pow
import kotlin.random.Random
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
/**
 * A [PromptExecutor] decorator that re-issues failed LLM calls with exponential backoff.
 *
 * Every failure raised by [delegate] is considered retryable (there is no filtering by
 * exception type), up to [RetryPolicy.maxAttempts] attempts in total. Each scheduled retry is
 * reported through [AgentEvents.onRetry] before the backoff wait starts.
 *
 * [moderate] and [models] are forwarded to [delegate] without any retry handling.
 */
class RetryingPromptExecutor(
    private val delegate: PromptExecutor,
    private val retryPolicy: RetryPolicy,
    private val events: AgentEvents?
) : PromptExecutor {

    /**
     * Backoff configuration shared by the streaming and non-streaming paths.
     *
     * @property maxAttempts total number of attempts, including the first one; must be >= 1.
     * @property initialDelay wait before the first retry.
     * @property maxDelay upper bound for the un-jittered wait.
     * @property backoffMultiplier factor applied to the wait after every failed attempt.
     * @property jitterFactor fraction of the wait that may be randomly added on top of it.
     */
    data class RetryPolicy(
        val maxAttempts: Int,
        val initialDelay: Duration,
        val maxDelay: Duration,
        val backoffMultiplier: Double,
        val jitterFactor: Double,
    ) {
        init {
            // A policy without a single attempt would make the streaming path complete
            // silently with no frames, so reject it up front.
            require(maxAttempts >= 1) { "maxAttempts must be at least 1, was $maxAttempts" }
        }
    }

    /**
     * Streams the response, retrying only while nothing has been emitted yet.
     *
     * Once a frame of the current attempt has reached the collector, a later failure is
     * rethrown as-is: restarting would replay frames the collector has already consumed.
     * All retry state lives inside the flow builder, so every collection of the returned
     * (cold) flow starts again from attempt 1.
     */
    override fun executeStreaming(
        prompt: Prompt,
        model: LLModel,
        tools: List<ToolDescriptor>
    ): Flow<StreamFrame> = flow {
        var attempt = 1
        while (true) {
            var hasReceivedData = false
            var failure: Throwable? = null

            // `catch` only sees upstream failures; exceptions thrown by the downstream
            // collector (and its cancellation) pass through untouched.
            delegate.executeStreaming(prompt, model, tools)
                .catch { error -> failure = error }
                .collect { frame ->
                    hasReceivedData = true
                    emit(frame)
                }

            // No recorded failure means the upstream completed normally.
            val error = failure ?: return@flow

            val shouldRetry = !hasReceivedData && attempt < retryPolicy.maxAttempts
            logger.warn {
                "Stream error: ${error.javaClass.simpleName}, " +
                    "hasReceivedData=$hasReceivedData, " +
                    "attempt=$attempt/${retryPolicy.maxAttempts}, " +
                    "willRetry=$shouldRetry"
            }
            if (!shouldRetry) {
                logger.error { "Streaming failed: $error" }
                throw error
            }

            val nextAttempt = attempt + 1
            logger.warn { "Retrying streaming request (attempt $nextAttempt/${retryPolicy.maxAttempts})" }
            events?.onRetry(nextAttempt, retryPolicy.maxAttempts, error.javaClass.simpleName)
            delay(backoffDelayMillis(attempt))
            attempt = nextAttempt
        }
    }

    /**
     * Executes the prompt, retrying after a backoff wait whenever [delegate] throws.
     *
     * @throws CancellationException immediately, without retrying or notifying [events].
     * @throws Throwable the failure of the last attempt once [RetryPolicy.maxAttempts] is reached.
     */
    override suspend fun execute(
        prompt: Prompt,
        model: LLModel,
        tools: List<ToolDescriptor>
    ): List<Message.Response> {
        var attempt = 1
        while (true) {
            try {
                return delegate.execute(prompt, model, tools)
            } catch (cancellation: CancellationException) {
                // Cancellation is not a failure of the call; never swallow or retry it.
                throw cancellation
            } catch (t: Throwable) {
                if (attempt >= retryPolicy.maxAttempts) throw t
                // Notify once per retry, before waiting, so listeners can react during the backoff.
                events?.onRetry(attempt + 1, retryPolicy.maxAttempts, t.javaClass.simpleName)
                delay(backoffDelayMillis(attempt))
                attempt++
            }
        }
    }

    override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult =
        delegate.moderate(prompt, model)

    override suspend fun models(): List<String> = delegate.models()

    override fun close() {
        (delegate as? AutoCloseable)?.close()
    }

    /**
     * Computes the wait (in milliseconds) that follows the failed attempt number [failedAttempt]
     * (1-based): `initialDelay * backoffMultiplier^(failedAttempt - 1)`, capped at
     * [RetryPolicy.maxDelay], plus a random extra of at most `jitterFactor` times that value.
     */
    private fun backoffDelayMillis(failedAttempt: Int): Long {
        val initialMs = retryPolicy.initialDelay.inWholeMilliseconds
        val maxMs = retryPolicy.maxDelay.inWholeMilliseconds
        // Double -> Long conversion saturates, so a huge power still ends up capped at maxMs.
        val baseMs = (initialMs * retryPolicy.backoffMultiplier.pow(failedAttempt - 1))
            .toLong()
            .coerceAtMost(maxMs)
        val jitterMs = (baseMs * retryPolicy.jitterFactor).toLong()
        return if (jitterMs > 0) baseMs + Random.nextLong(jitterMs + 1) else baseMs
    }

    companion object {
        private val logger = KotlinLogging.logger { }

        /**
         * Builds a retrying executor around a [SingleLLMPromptExecutor] created for [client].
         */
        fun fromClient(
            client: LLMClient,
            retryPolicy: RetryPolicy,
            events: AgentEvents?
        ): PromptExecutor {
            return RetryingPromptExecutor(
                delegate = SingleLLMPromptExecutor(client),
                retryPolicy = retryPolicy,
                events = events
            )
        }
    }
}

View file

@ -457,6 +457,17 @@ class AgentEventHandler(
.onCreditsChanged(event)
}
/**
 * Updates the loading indicator with a retry notice of the form "<text> (attempt/maxAttempts)".
 *
 * The text mentions a timeout when [reason] contains "timeout" (case-insensitive); otherwise
 * a generic retry text is used. The indicator update is dispatched through `runInEdt`.
 */
override fun onRetry(attempt: Int, maxAttempts: Int, reason: String?) {
    val timedOut = reason?.contains("timeout", ignoreCase = true) ?: false
    val prefix = if (timedOut) "Request timed out, retrying" else "Retrying"
    val message = "$prefix ($attempt/$maxAttempts)"
    runInEdt { onShowLoading(message) }
}
override fun onHistoryCompressionStateChanged(isCompressing: Boolean) {
val key =
if (isCompressing) "toolwindow.chat.compressingHistory" else "toolwindow.chat.loading"