diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt index 0a9bf0de..e1539ae7 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt @@ -21,6 +21,7 @@ interface AgentEvents { ) { } + fun onRetry(attempt: Int, maxAttempts: Int, reason: String? = null) {} fun onQueuedMessagesResolved() fun onTokenUsageAvailable(tokenUsage: Long) {} fun onTokenUsageUpdated(tokenUsage: TokenUsage) {} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt index b1985e5b..e3ce1845 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt @@ -3,10 +3,12 @@ package ee.carlrobert.codegpt.agent import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.GraphAIAgent import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy -import ai.koog.agents.core.dsl.extension.* -import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession +import ai.koog.agents.core.dsl.extension.nodeExecuteMultipleTools +import ai.koog.agents.core.dsl.extension.onMultipleAssistantMessages +import ai.koog.agents.core.dsl.extension.onMultipleToolCalls import ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.environment.result import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext @@ -18,19 +20,11 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents import ai.koog.agents.features.tokenizer.feature.MessageTokenizer import ai.koog.agents.features.tokenizer.feature.tokenizer import ai.koog.prompt.dsl.prompt -import ai.koog.prompt.executor.clients.ConnectionTimeoutConfig import 
ai.koog.prompt.executor.clients.LLMClient -import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient -import ai.koog.prompt.executor.clients.google.GoogleClientSettings import ai.koog.prompt.executor.clients.google.GoogleLLMClient -import ai.koog.prompt.executor.clients.mistralai.MistralAIClientSettings import ai.koog.prompt.executor.clients.mistralai.MistralAILLMClient -import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings import ai.koog.prompt.executor.clients.openai.OpenAILLMClient -import ai.koog.prompt.executor.clients.retry.RetryConfig -import ai.koog.prompt.executor.clients.retry.RetryingLLMClient -import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.executor.ollama.client.OllamaClient import ai.koog.prompt.message.Message @@ -39,8 +33,9 @@ import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import ee.carlrobert.codegpt.EncodingManager import ee.carlrobert.codegpt.agent.clients.CustomOpenAILLMClient -import ee.carlrobert.codegpt.agent.clients.ProxyAIClientSettings +import ee.carlrobert.codegpt.agent.clients.InceptionAILLMClient import ee.carlrobert.codegpt.agent.clients.ProxyAILLMClient +import ee.carlrobert.codegpt.agent.clients.RetryingPromptExecutor import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot import ee.carlrobert.codegpt.agent.tools.* import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey @@ -52,8 +47,8 @@ import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent import java.time.LocalDate -import kotlin.time.Duration.Companion.seconds import java.util.concurrent.atomic.AtomicLong +import kotlin.time.Duration.Companion.seconds object AgentFactory { @@ -68,7 +63,8 @@ 
object AgentFactory { extraBehavior: String? = null, toolOverrides: Set? = null, onCreditsAvailable: ((AgentCreditsEvent) -> Unit)? = null, - tokenCounter: AtomicLong? = null + tokenCounter: AtomicLong? = null, + events: AgentEvents? = null ): AIAgent { val installHandler = buildUsageAwareInstallHandler( provider, @@ -77,7 +73,7 @@ object AgentFactory { onAgentToolCallCompleted, onCreditsAvailable ) - val executor = createExecutor(provider) + val executor = createExecutor(provider, events) return when (agentType) { AgentType.GENERAL_PURPOSE -> createGeneralPurposeAgent( provider, @@ -166,53 +162,30 @@ object AgentFactory { ) } - fun createExecutor(provider: ServiceType): PromptExecutor { - val timeoutConfig = ConnectionTimeoutConfig( - connectTimeoutMillis = 30_000, - socketTimeoutMillis = 60_000, - ) + fun createExecutor(provider: ServiceType, events: AgentEvents? = null): PromptExecutor { return when (provider) { ServiceType.OPENAI -> { val apiKey = getCredential(CredentialKey.OpenaiApiKey) ?: "" - createRetryingExecutor( - OpenAILLMClient( - apiKey, OpenAIClientSettings(timeoutConfig = timeoutConfig) - ) - ) + createRetryingExecutor(OpenAILLMClient(apiKey), events) } ServiceType.ANTHROPIC -> { val apiKey = getCredential(CredentialKey.AnthropicApiKey) ?: "" - createRetryingExecutor( - AnthropicLLMClient( - apiKey, - AnthropicClientSettings(timeoutConfig = timeoutConfig) - ) - ) + createRetryingExecutor(AnthropicLLMClient(apiKey), events) } ServiceType.GOOGLE -> { val apiKey = getCredential(CredentialKey.GoogleApiKey) ?: "" - createRetryingExecutor( - GoogleLLMClient( - apiKey, - GoogleClientSettings(timeoutConfig = timeoutConfig) - ) - ) + createRetryingExecutor(GoogleLLMClient(apiKey), events) } ServiceType.OLLAMA -> { - createRetryingExecutor(OllamaClient(timeoutConfig = timeoutConfig)) + createRetryingExecutor(OllamaClient(), events) } ServiceType.MISTRAL -> { val apiKey = getCredential(CredentialKey.MistralApiKey) ?: "" - createRetryingExecutor( - 
MistralAILLMClient( - apiKey, - MistralAIClientSettings(timeoutConfig = timeoutConfig) - ) - ) + createRetryingExecutor(MistralAILLMClient(apiKey), events) } ServiceType.CUSTOM_OPENAI -> { @@ -227,34 +200,34 @@ object AgentFactory { CustomOpenAILLMClient.fromSettingsState( apiKey, state.chatCompletionSettings, - timeoutConfig - ) + ), + events ) } ServiceType.PROXYAI -> { val apiKey = getCredential(CredentialKey.CodeGptApiKey) ?: "" - createRetryingExecutor( - ProxyAILLMClient( - apiKey, - ProxyAIClientSettings(timeoutConfig = timeoutConfig) - ) - ) + createRetryingExecutor(ProxyAILLMClient(apiKey), events) + } + + ServiceType.INCEPTION -> { + val apiKey = getCredential(CredentialKey.InceptionApiKey) ?: "" + createRetryingExecutor(InceptionAILLMClient(apiKey), events) } else -> throw UnsupportedOperationException("Provider not supported: $provider") } } - private fun createRetryingExecutor(client: LLMClient): PromptExecutor { - val retryConfig = RetryConfig( + private fun createRetryingExecutor(client: LLMClient, events: AgentEvents?): PromptExecutor { + val policy = RetryingPromptExecutor.RetryPolicy( maxAttempts = 5, initialDelay = 1.seconds, maxDelay = 30.seconds, backoffMultiplier = 2.0, jitterFactor = 0.1 ) - return SingleLLMPromptExecutor(RetryingLLMClient(client, retryConfig)) + return RetryingPromptExecutor.fromClient(client, policy, events) } private fun createGeneralPurposeAgent( @@ -442,7 +415,12 @@ object AgentFactory { appendPrompt { user(input) } tokenCounter?.addAndGet(tokenizer().tokenCountFor(prompt).toLong()) val responses = - requestResponses(executor, config, { appendPrompt { message(it) } }, tokenCounter) + requestResponses( + executor, + config, + { appendPrompt { message(it) } }, + tokenCounter + ) responses } } @@ -454,7 +432,12 @@ object AgentFactory { } tokenCounter?.addAndGet(tokenizer().tokenCountFor(prompt).toLong()) val responses = - requestResponses(executor, config, { appendPrompt { message(it) } }, tokenCounter) + requestResponses( 
+ executor, + config, + { appendPrompt { message(it) } }, + tokenCounter + ) responses } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt index 73bf5419..5a6a6eaf 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt @@ -79,11 +79,9 @@ object ProxyAIAgent { ): AIAgent { val modelSelection = service().getModelSelectionForFeature(FeatureType.AGENT) - val stream = - provider == ServiceType.ANTHROPIC || provider == ServiceType.OPENAI || provider == ServiceType.PROXYAI + val stream = provider != ServiceType.CUSTOM_OPENAI val projectInstructions = searchForInstructions(project.basePath) - - val executor = AgentFactory.createExecutor(provider) + val executor = AgentFactory.createExecutor(provider, events) val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() } val toolRegistry = createToolRegistry(project, events, sessionId) val agentModel = service().getAgentModel() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/RetryingPromptExecutor.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/RetryingPromptExecutor.kt new file mode 100644 index 00000000..e2bc6f1e --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/RetryingPromptExecutor.kt @@ -0,0 +1,169 @@
+package ee.carlrobert.codegpt.agent.clients
+
+import ai.koog.agents.core.tools.ToolDescriptor
+import ai.koog.prompt.dsl.ModerationResult
+import ai.koog.prompt.dsl.Prompt
+import ai.koog.prompt.executor.clients.LLMClient
+import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
+import ai.koog.prompt.executor.model.PromptExecutor
+import ai.koog.prompt.llm.LLModel
+import ai.koog.prompt.message.Message
+import ai.koog.prompt.streaming.StreamFrame
+import ee.carlrobert.codegpt.agent.AgentEvents
+import io.github.oshai.kotlinlogging.KotlinLogging
+import kotlinx.coroutines.currentCoroutineContext
+import kotlinx.coroutines.delay
+import kotlinx.coroutines.ensureActive
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.onEach
+import kotlin.math.pow
+import kotlin.random.Random
+import kotlin.time.Duration
+
+/**
+ * A [PromptExecutor] decorator that re-issues failed LLM calls with exponential backoff.
+ *
+ * A failed call is retried unless the calling coroutine was cancelled or, when streaming, a
+ * frame was already emitted. Each retry is announced through [AgentEvents.onRetry] before the
+ * backoff wait starts, with the number of the attempt about to run and the simple class name of
+ * the error that caused it. The first attempt is always made, even when
+ * [RetryPolicy.maxAttempts] is below 1.
+ */
+class RetryingPromptExecutor(
+    private val delegate: PromptExecutor,
+    private val retryPolicy: RetryPolicy,
+    private val events: AgentEvents?
+) : PromptExecutor {
+
+    /**
+     * Backoff settings shared by [execute] and [executeStreaming].
+     *
+     * @property maxAttempts total number of attempts, the first one included
+     * @property initialDelay wait before the first retry
+     * @property maxDelay upper bound of the wait, before jitter is added
+     * @property backoffMultiplier growth factor applied to the wait after each failed attempt
+     * @property jitterFactor fraction of the wait added as random jitter (0.1 adds up to 10%)
+     */
+    data class RetryPolicy(
+        val maxAttempts: Int,
+        val initialDelay: Duration,
+        val maxDelay: Duration,
+        val backoffMultiplier: Double,
+        val jitterFactor: Double,
+    )
+
+    /**
+     * Runs the request on [delegate], retrying until it succeeds or the attempts are used up.
+     *
+     * @throws Throwable the failure of the final attempt
+     */
+    override suspend fun execute(
+        prompt: Prompt,
+        model: LLModel,
+        tools: List<ToolDescriptor>
+    ): List<Message.Response> {
+        var attempt = 1
+        while (true) {
+            try {
+                return delegate.execute(prompt, model, tools)
+            } catch (t: Throwable) {
+                // Rethrows if this coroutine was cancelled, so cancellation is never retried or
+                // reported. A failure raised while the coroutine is still active falls through.
+                currentCoroutineContext().ensureActive()
+                if (attempt >= retryPolicy.maxAttempts) throw t
+                awaitNextAttempt(attempt, t)
+                attempt++
+            }
+        }
+    }
+
+    /**
+     * Streams the response from [delegate], restarting the stream when it fails before any frame
+     * was emitted. After the first frame the failure is rethrown instead, since a restart would
+     * emit the beginning of the response a second time.
+     */
+    override fun executeStreaming(
+        prompt: Prompt,
+        model: LLModel,
+        tools: List<ToolDescriptor>
+    ): Flow<StreamFrame> = flow {
+        // Retry state lives inside the builder so that every collection starts at attempt 1.
+        var attempt = 1
+        while (true) {
+            var hasReceivedData = false
+            var failure: Throwable? = null
+            delegate.executeStreaming(prompt, model, tools)
+                .onEach { hasReceivedData = true }
+                // Only upstream failures land here; cancellation and collector errors pass through.
+                .catch { failure = it }
+                .collect { emit(it) }
+
+            val error = failure ?: return@flow
+            val shouldRetry = !hasReceivedData && attempt < retryPolicy.maxAttempts
+            logger.warn {
+                "Stream error: ${error.javaClass.simpleName}, " +
+                    "hasReceivedData=$hasReceivedData, " +
+                    "attempt=$attempt/${retryPolicy.maxAttempts}, " +
+                    "willRetry=$shouldRetry"
+            }
+            if (!shouldRetry) {
+                logger.error { "Streaming failed: $error" }
+                throw error
+            }
+
+            logger.warn { "Retrying streaming request (attempt ${attempt + 1}/${retryPolicy.maxAttempts})" }
+            awaitNextAttempt(attempt, error)
+            attempt++
+        }
+    }
+
+    override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult =
+        delegate.moderate(prompt, model)
+
+    // The return type is inferred from the delegate's declaration.
+    override suspend fun models() = delegate.models()
+
+    override fun close() {
+        (delegate as? AutoCloseable)?.close()
+    }
+
+    /** Reports the retry that follows [failedAttempt] to [events], then waits out its backoff. */
+    private suspend fun awaitNextAttempt(failedAttempt: Int, cause: Throwable) {
+        events?.onRetry(failedAttempt + 1, retryPolicy.maxAttempts, cause.javaClass.simpleName)
+        delay(backoffMillis(failedAttempt))
+    }
+
+    /**
+     * Wait in milliseconds before the retry that follows the 1-based [failedAttempt]:
+     * `initialDelay * backoffMultiplier^(failedAttempt - 1)`, capped at `maxDelay`, plus a random
+     * extra of at most `jitterFactor` times that value.
+     */
+    private fun backoffMillis(failedAttempt: Int): Long {
+        val exponent = (failedAttempt - 1).coerceAtLeast(0)
+        val grownMs = retryPolicy.initialDelay.inWholeMilliseconds *
+            retryPolicy.backoffMultiplier.pow(exponent)
+        val baseMs = grownMs.toLong()
+            .coerceAtMost(retryPolicy.maxDelay.inWholeMilliseconds)
+            .coerceAtLeast(0L)
+        val jitterBound = (baseMs * retryPolicy.jitterFactor).toLong()
+        val jitterMs = if (jitterBound > 0) Random.nextLong(jitterBound + 1) else 0L
+        return baseMs + jitterMs
+    }
+
+    companion object {
+        private val logger = KotlinLogging.logger { }
+
+        /** Wraps [client] in a [SingleLLMPromptExecutor] and adds retries on top of it. */
+        fun fromClient(
+            client: LLMClient,
+            retryPolicy: RetryPolicy,
+            events: AgentEvents?
+        ): PromptExecutor = RetryingPromptExecutor(
+            delegate = SingleLLMPromptExecutor(client),
+            retryPolicy = retryPolicy,
+            events = events
+        )
+    }
+}
diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt index f43b07c6..304cdef3 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt @@ -457,6 +457,17 @@ class AgentEventHandler( .onCreditsChanged(event) } + override fun onRetry(attempt: Int, maxAttempts: Int, reason: String?) { + val suffix = "($attempt/$maxAttempts)" + val base = when { + reason?.contains("timeout", ignoreCase = true) == true -> "Request timed out, retrying" + else -> "Retrying" + } + runInEdt { + onShowLoading("$base $suffix") + } + } + override fun onHistoryCompressionStateChanged(isCompressing: Boolean) { val key = if (isCompressing) "toolwindow.chat.compressingHistory" else "toolwindow.chat.loading"