mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 16:28:46 +00:00
fix: retry logic for streaming requests
This commit is contained in:
parent
6ad5f4e523
commit
ebea17ff6b
5 changed files with 228 additions and 60 deletions
|
|
@ -21,6 +21,7 @@ interface AgentEvents {
|
|||
) {
|
||||
}
|
||||
|
||||
fun onRetry(attempt: Int, maxAttempts: Int, reason: String? = null) {}
|
||||
fun onQueuedMessagesResolved()
|
||||
fun onTokenUsageAvailable(tokenUsage: Long) {}
|
||||
fun onTokenUsageUpdated(tokenUsage: TokenUsage) {}
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ package ee.carlrobert.codegpt.agent
|
|||
import ai.koog.agents.core.agent.AIAgent
|
||||
import ai.koog.agents.core.agent.GraphAIAgent
|
||||
import ai.koog.agents.core.agent.config.AIAgentConfig
|
||||
import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession
|
||||
import ai.koog.agents.core.dsl.builder.forwardTo
|
||||
import ai.koog.agents.core.dsl.builder.strategy
|
||||
import ai.koog.agents.core.dsl.extension.*
|
||||
import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession
|
||||
import ai.koog.agents.core.dsl.extension.nodeExecuteMultipleTools
|
||||
import ai.koog.agents.core.dsl.extension.onMultipleAssistantMessages
|
||||
import ai.koog.agents.core.dsl.extension.onMultipleToolCalls
|
||||
import ai.koog.agents.core.environment.ReceivedToolResult
|
||||
import ai.koog.agents.core.environment.result
|
||||
import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext
|
||||
|
|
@ -18,19 +20,11 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents
|
|||
import ai.koog.agents.features.tokenizer.feature.MessageTokenizer
|
||||
import ai.koog.agents.features.tokenizer.feature.tokenizer
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.executor.clients.ConnectionTimeoutConfig
|
||||
import ai.koog.prompt.executor.clients.LLMClient
|
||||
import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings
|
||||
import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient
|
||||
import ai.koog.prompt.executor.clients.google.GoogleClientSettings
|
||||
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
|
||||
import ai.koog.prompt.executor.clients.mistralai.MistralAIClientSettings
|
||||
import ai.koog.prompt.executor.clients.mistralai.MistralAILLMClient
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
|
||||
import ai.koog.prompt.executor.clients.retry.RetryConfig
|
||||
import ai.koog.prompt.executor.clients.retry.RetryingLLMClient
|
||||
import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
|
||||
import ai.koog.prompt.executor.model.PromptExecutor
|
||||
import ai.koog.prompt.executor.ollama.client.OllamaClient
|
||||
import ai.koog.prompt.message.Message
|
||||
|
|
@ -39,8 +33,9 @@ import com.intellij.openapi.components.service
|
|||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.agent.clients.CustomOpenAILLMClient
|
||||
import ee.carlrobert.codegpt.agent.clients.ProxyAIClientSettings
|
||||
import ee.carlrobert.codegpt.agent.clients.InceptionAILLMClient
|
||||
import ee.carlrobert.codegpt.agent.clients.ProxyAILLMClient
|
||||
import ee.carlrobert.codegpt.agent.clients.RetryingPromptExecutor
|
||||
import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
|
||||
|
|
@ -52,8 +47,8 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
|
|||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
|
||||
import java.time.LocalDate
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
object AgentFactory {
|
||||
|
||||
|
|
@ -68,7 +63,8 @@ object AgentFactory {
|
|||
extraBehavior: String? = null,
|
||||
toolOverrides: Set<SubagentTool>? = null,
|
||||
onCreditsAvailable: ((AgentCreditsEvent) -> Unit)? = null,
|
||||
tokenCounter: AtomicLong? = null
|
||||
tokenCounter: AtomicLong? = null,
|
||||
events: AgentEvents? = null
|
||||
): AIAgent<String, String> {
|
||||
val installHandler = buildUsageAwareInstallHandler(
|
||||
provider,
|
||||
|
|
@ -77,7 +73,7 @@ object AgentFactory {
|
|||
onAgentToolCallCompleted,
|
||||
onCreditsAvailable
|
||||
)
|
||||
val executor = createExecutor(provider)
|
||||
val executor = createExecutor(provider, events)
|
||||
return when (agentType) {
|
||||
AgentType.GENERAL_PURPOSE -> createGeneralPurposeAgent(
|
||||
provider,
|
||||
|
|
@ -166,53 +162,30 @@ object AgentFactory {
|
|||
)
|
||||
}
|
||||
|
||||
fun createExecutor(provider: ServiceType): PromptExecutor {
|
||||
val timeoutConfig = ConnectionTimeoutConfig(
|
||||
connectTimeoutMillis = 30_000,
|
||||
socketTimeoutMillis = 60_000,
|
||||
)
|
||||
fun createExecutor(provider: ServiceType, events: AgentEvents? = null): PromptExecutor {
|
||||
return when (provider) {
|
||||
ServiceType.OPENAI -> {
|
||||
val apiKey = getCredential(CredentialKey.OpenaiApiKey) ?: ""
|
||||
createRetryingExecutor(
|
||||
OpenAILLMClient(
|
||||
apiKey, OpenAIClientSettings(timeoutConfig = timeoutConfig)
|
||||
)
|
||||
)
|
||||
createRetryingExecutor(OpenAILLMClient(apiKey), events)
|
||||
}
|
||||
|
||||
ServiceType.ANTHROPIC -> {
|
||||
val apiKey = getCredential(CredentialKey.AnthropicApiKey) ?: ""
|
||||
createRetryingExecutor(
|
||||
AnthropicLLMClient(
|
||||
apiKey,
|
||||
AnthropicClientSettings(timeoutConfig = timeoutConfig)
|
||||
)
|
||||
)
|
||||
createRetryingExecutor(AnthropicLLMClient(apiKey), events)
|
||||
}
|
||||
|
||||
ServiceType.GOOGLE -> {
|
||||
val apiKey = getCredential(CredentialKey.GoogleApiKey) ?: ""
|
||||
createRetryingExecutor(
|
||||
GoogleLLMClient(
|
||||
apiKey,
|
||||
GoogleClientSettings(timeoutConfig = timeoutConfig)
|
||||
)
|
||||
)
|
||||
createRetryingExecutor(GoogleLLMClient(apiKey), events)
|
||||
}
|
||||
|
||||
ServiceType.OLLAMA -> {
|
||||
createRetryingExecutor(OllamaClient(timeoutConfig = timeoutConfig))
|
||||
createRetryingExecutor(OllamaClient(), events)
|
||||
}
|
||||
|
||||
ServiceType.MISTRAL -> {
|
||||
val apiKey = getCredential(CredentialKey.MistralApiKey) ?: ""
|
||||
createRetryingExecutor(
|
||||
MistralAILLMClient(
|
||||
apiKey,
|
||||
MistralAIClientSettings(timeoutConfig = timeoutConfig)
|
||||
)
|
||||
)
|
||||
createRetryingExecutor(MistralAILLMClient(apiKey), events)
|
||||
}
|
||||
|
||||
ServiceType.CUSTOM_OPENAI -> {
|
||||
|
|
@ -227,34 +200,34 @@ object AgentFactory {
|
|||
CustomOpenAILLMClient.fromSettingsState(
|
||||
apiKey,
|
||||
state.chatCompletionSettings,
|
||||
timeoutConfig
|
||||
)
|
||||
),
|
||||
events
|
||||
)
|
||||
}
|
||||
|
||||
ServiceType.PROXYAI -> {
|
||||
val apiKey = getCredential(CredentialKey.CodeGptApiKey) ?: ""
|
||||
createRetryingExecutor(
|
||||
ProxyAILLMClient(
|
||||
apiKey,
|
||||
ProxyAIClientSettings(timeoutConfig = timeoutConfig)
|
||||
)
|
||||
)
|
||||
createRetryingExecutor(ProxyAILLMClient(apiKey), events)
|
||||
}
|
||||
|
||||
ServiceType.INCEPTION -> {
|
||||
val apiKey = getCredential(CredentialKey.InceptionApiKey) ?: ""
|
||||
createRetryingExecutor(InceptionAILLMClient(apiKey), events)
|
||||
}
|
||||
|
||||
else -> throw UnsupportedOperationException("Provider not supported: $provider")
|
||||
}
|
||||
}
|
||||
|
||||
private fun createRetryingExecutor(client: LLMClient): PromptExecutor {
|
||||
val retryConfig = RetryConfig(
|
||||
private fun createRetryingExecutor(client: LLMClient, events: AgentEvents?): PromptExecutor {
|
||||
val policy = RetryingPromptExecutor.RetryPolicy(
|
||||
maxAttempts = 5,
|
||||
initialDelay = 1.seconds,
|
||||
maxDelay = 30.seconds,
|
||||
backoffMultiplier = 2.0,
|
||||
jitterFactor = 0.1
|
||||
)
|
||||
return SingleLLMPromptExecutor(RetryingLLMClient(client, retryConfig))
|
||||
return RetryingPromptExecutor.fromClient(client, policy, events)
|
||||
}
|
||||
|
||||
private fun createGeneralPurposeAgent(
|
||||
|
|
@ -442,7 +415,12 @@ object AgentFactory {
|
|||
appendPrompt { user(input) }
|
||||
tokenCounter?.addAndGet(tokenizer().tokenCountFor(prompt).toLong())
|
||||
val responses =
|
||||
requestResponses(executor, config, { appendPrompt { message(it) } }, tokenCounter)
|
||||
requestResponses(
|
||||
executor,
|
||||
config,
|
||||
{ appendPrompt { message(it) } },
|
||||
tokenCounter
|
||||
)
|
||||
responses
|
||||
}
|
||||
}
|
||||
|
|
@ -454,7 +432,12 @@ object AgentFactory {
|
|||
}
|
||||
tokenCounter?.addAndGet(tokenizer().tokenCountFor(prompt).toLong())
|
||||
val responses =
|
||||
requestResponses(executor, config, { appendPrompt { message(it) } }, tokenCounter)
|
||||
requestResponses(
|
||||
executor,
|
||||
config,
|
||||
{ appendPrompt { message(it) } },
|
||||
tokenCounter
|
||||
)
|
||||
responses
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,11 +79,9 @@ object ProxyAIAgent {
|
|||
): AIAgent<MessageWithContext, String> {
|
||||
val modelSelection =
|
||||
service<ModelSelectionService>().getModelSelectionForFeature(FeatureType.AGENT)
|
||||
val stream =
|
||||
provider == ServiceType.ANTHROPIC || provider == ServiceType.OPENAI || provider == ServiceType.PROXYAI
|
||||
val stream = provider != ServiceType.CUSTOM_OPENAI
|
||||
val projectInstructions = searchForInstructions(project.basePath)
|
||||
|
||||
val executor = AgentFactory.createExecutor(provider)
|
||||
val executor = AgentFactory.createExecutor(provider, events)
|
||||
val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
|
||||
val toolRegistry = createToolRegistry(project, events, sessionId)
|
||||
val agentModel = service<ModelSelectionService>().getAgentModel()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,175 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.agents.core.tools.ToolDescriptor
|
||||
import ai.koog.prompt.dsl.ModerationResult
|
||||
import ai.koog.prompt.dsl.Prompt
|
||||
import ai.koog.prompt.executor.clients.LLMClient
|
||||
import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
|
||||
import ai.koog.prompt.executor.model.PromptExecutor
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.streaming.StreamFrame
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.catch
|
||||
import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlin.math.pow
|
||||
import kotlin.random.Random
|
||||
import kotlin.time.Duration
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
|
||||
/**
|
||||
* Wraps a [PromptExecutor] and retries the LLM call on network/timeouts.
|
||||
*/
|
||||
class RetryingPromptExecutor(
|
||||
private val delegate: PromptExecutor,
|
||||
private val retryPolicy: RetryPolicy,
|
||||
private val events: AgentEvents?
|
||||
) : PromptExecutor {
|
||||
|
||||
companion object {
|
||||
private val logger = KotlinLogging.logger { }
|
||||
|
||||
fun fromClient(
|
||||
client: LLMClient,
|
||||
retryPolicy: RetryPolicy,
|
||||
events: AgentEvents?
|
||||
): PromptExecutor {
|
||||
return RetryingPromptExecutor(
|
||||
delegate = SingleLLMPromptExecutor(client),
|
||||
retryPolicy = retryPolicy,
|
||||
events = events
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun executeStreaming(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): Flow<StreamFrame> {
|
||||
var attempt = 1
|
||||
var hasReceivedData: Boolean
|
||||
|
||||
fun createStream(attemptNum: Int): Flow<StreamFrame> {
|
||||
hasReceivedData = false
|
||||
|
||||
return delegate.executeStreaming(prompt, model, tools)
|
||||
.onEach { hasReceivedData = true }
|
||||
.catch { error ->
|
||||
val shouldRetry = !hasReceivedData && attemptNum < retryPolicy.maxAttempts
|
||||
logger.warn {
|
||||
"Stream error: ${error.javaClass.simpleName}, " +
|
||||
"hasReceivedData=$hasReceivedData, " +
|
||||
"attempt=$attemptNum/${retryPolicy.maxAttempts}, " +
|
||||
"willRetry=$shouldRetry"
|
||||
}
|
||||
|
||||
if (shouldRetry) {
|
||||
attempt++
|
||||
logger.warn { "Retrying streaming request (attempt $attempt/${retryPolicy.maxAttempts})" }
|
||||
events?.onRetry(
|
||||
attempt,
|
||||
retryPolicy.maxAttempts,
|
||||
error.javaClass.simpleName
|
||||
)
|
||||
|
||||
val delayMs = if (attemptNum == 1) {
|
||||
retryPolicy.initialDelay.inWholeMilliseconds
|
||||
} else {
|
||||
val exponential = (retryPolicy.initialDelay.inWholeMilliseconds *
|
||||
retryPolicy.backoffMultiplier.pow(attemptNum.toDouble())).toLong()
|
||||
exponential.coerceAtMost(retryPolicy.maxDelay.inWholeMilliseconds)
|
||||
}
|
||||
val jitterMs = (delayMs * retryPolicy.jitterFactor).toLong()
|
||||
val jitteredDelay =
|
||||
delayMs + (if (jitterMs > 0) Random.nextInt(jitterMs.toInt()) else 0)
|
||||
|
||||
delay(jitteredDelay)
|
||||
throw CancellationException("Retrying streaming request", error)
|
||||
} else {
|
||||
logger.error { "Streaming failed: $error" }
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val retryLoop = flow {
|
||||
var currentAttempt = attempt
|
||||
while (currentAttempt <= retryPolicy.maxAttempts) {
|
||||
try {
|
||||
createStream(currentAttempt).collect { emit(it) }
|
||||
break
|
||||
} catch (ce: CancellationException) {
|
||||
if (ce.message == "Retrying streaming request" && currentAttempt < retryPolicy.maxAttempts) {
|
||||
currentAttempt++
|
||||
continue
|
||||
} else {
|
||||
throw ce
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return retryLoop
|
||||
}
|
||||
|
||||
override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult =
|
||||
delegate.moderate(prompt, model)
|
||||
|
||||
override suspend fun models(): List<String> = delegate.models()
|
||||
|
||||
data class RetryPolicy(
|
||||
val maxAttempts: Int,
|
||||
val initialDelay: Duration,
|
||||
val maxDelay: Duration,
|
||||
val backoffMultiplier: Double,
|
||||
val jitterFactor: Double,
|
||||
)
|
||||
|
||||
override suspend fun execute(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): List<Message.Response> {
|
||||
var attempt = 1
|
||||
var delay = retryPolicy.initialDelay
|
||||
var lastError: Throwable? = null
|
||||
|
||||
while (attempt <= retryPolicy.maxAttempts) {
|
||||
try {
|
||||
if (attempt > 1) {
|
||||
events?.onRetry(
|
||||
attempt,
|
||||
retryPolicy.maxAttempts,
|
||||
lastError?.javaClass?.simpleName
|
||||
)
|
||||
}
|
||||
return delegate.execute(prompt, model, tools)
|
||||
} catch (t: Throwable) {
|
||||
lastError = t
|
||||
if (attempt >= retryPolicy.maxAttempts) throw t
|
||||
|
||||
events?.onRetry(attempt + 1, retryPolicy.maxAttempts, t.javaClass.simpleName)
|
||||
|
||||
val jitter = (delay.inWholeMilliseconds * retryPolicy.jitterFactor).toLong()
|
||||
val jitteredMs =
|
||||
delay.inWholeMilliseconds + (if (jitter > 0) (0..jitter).random() else 0)
|
||||
delay(jitteredMs)
|
||||
|
||||
val nextMs = (delay.inWholeMilliseconds * retryPolicy.backoffMultiplier).toLong()
|
||||
delay = nextMs.milliseconds.coerceAtMost(retryPolicy.maxDelay)
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
throw lastError ?: IllegalStateException("Retry loop ended without result")
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
(delegate as? AutoCloseable)?.close()
|
||||
}
|
||||
}
|
||||
|
|
@ -457,6 +457,17 @@ class AgentEventHandler(
|
|||
.onCreditsChanged(event)
|
||||
}
|
||||
|
||||
override fun onRetry(attempt: Int, maxAttempts: Int, reason: String?) {
|
||||
val suffix = "($attempt/$maxAttempts)"
|
||||
val base = when {
|
||||
reason?.contains("timeout", ignoreCase = true) == true -> "Request timed out, retrying"
|
||||
else -> "Retrying"
|
||||
}
|
||||
runInEdt {
|
||||
onShowLoading("$base $suffix")
|
||||
}
|
||||
}
|
||||
|
||||
override fun onHistoryCompressionStateChanged(isCompressing: Boolean) {
|
||||
val key =
|
||||
if (isCompressing) "toolwindow.chat.compressingHistory" else "toolwindow.chat.loading"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue