mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-17 03:57:27 +00:00
refactor: use LLMCapability to decide on chat completion vs responses api
This commit is contained in:
parent
8fdeba74da
commit
80b2490ca9
23 changed files with 753 additions and 1241 deletions
|
|
@ -130,12 +130,21 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
|
||||
public ChatMessageResponseBody withResponse(@NotNull String response) {
|
||||
try {
|
||||
for (var item : new CompleteMessageParser().parse(response)) {
|
||||
var parser = new CompleteMessageParser();
|
||||
var segments = parser.parse(response);
|
||||
if (parser.getExtractedThought() != null && !parser.getExtractedThought().isBlank()) {
|
||||
processThinkingOutput(parser.getExtractedThought());
|
||||
}
|
||||
for (var item : segments) {
|
||||
processResponse(item, false);
|
||||
currentlyProcessedTextPane = null;
|
||||
currentlyProcessedEditorPanel = null;
|
||||
currentlyProcessedMermaidPanel = null;
|
||||
}
|
||||
var thoughtProcessPanel = getExistingThoughtProcessPanel();
|
||||
if (thoughtProcessPanel != null && !thoughtProcessPanel.isFinished()) {
|
||||
thoughtProcessPanel.setFinished();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
LOG.error("Something went wrong while processing input", e);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,21 +13,33 @@ import ai.koog.agents.core.environment.ReceivedToolResult
|
|||
import ai.koog.agents.core.environment.result
|
||||
import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext
|
||||
import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext
|
||||
import ai.koog.agents.core.tools.ToolDescriptor
|
||||
import ai.koog.agents.core.tools.ToolRegistry
|
||||
import ai.koog.agents.ext.tool.ExitTool
|
||||
import ai.koog.agents.ext.tool.shell.ShellCommandConfirmation
|
||||
import ai.koog.agents.features.eventHandler.feature.handleEvents
|
||||
import ai.koog.agents.features.tokenizer.feature.MessageTokenizer
|
||||
import ai.koog.agents.features.tokenizer.feature.tokenizer
|
||||
import ai.koog.prompt.dsl.Prompt
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.executor.clients.anthropic.AnthropicParams
|
||||
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicThinking
|
||||
import ai.koog.prompt.executor.clients.LLMClient
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIResponsesParams
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.ReasoningEffort
|
||||
import ai.koog.prompt.executor.clients.openai.models.ReasoningConfig
|
||||
import ai.koog.prompt.executor.clients.openai.models.ReasoningSummary
|
||||
import ai.koog.prompt.executor.model.PromptExecutor
|
||||
import ai.koog.prompt.llm.LLMCapability
|
||||
import ai.koog.prompt.llm.LLMProvider
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.params.LLMParams
|
||||
import ai.koog.prompt.tokenizer.Tokenizer
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.agent.clients.CustomOpenAILLMClient
|
||||
import ee.carlrobert.codegpt.agent.clients.RetryingPromptExecutor
|
||||
import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
|
|
@ -46,6 +58,8 @@ import kotlin.time.Duration.Companion.seconds
|
|||
object AgentFactory {
|
||||
|
||||
private const val MAX_AGENT_ITERATIONS = 250
|
||||
private const val ANTHROPIC_MIN_THINKING_BUDGET = 512
|
||||
private const val ANTHROPIC_DEFAULT_THINKING_BUDGET = 2_048
|
||||
|
||||
fun createAgent(
|
||||
agentType: AgentType,
|
||||
|
|
@ -170,10 +184,6 @@ object AgentFactory {
|
|||
featureType: FeatureType = FeatureType.AGENT
|
||||
): PromptExecutor {
|
||||
val llmClient = LLMClientFactory.createClient(provider, featureType)
|
||||
return createRetryingExecutor(llmClient, events)
|
||||
}
|
||||
|
||||
private fun createRetryingExecutor(client: LLMClient, events: AgentEvents?): PromptExecutor {
|
||||
val policy = RetryingPromptExecutor.RetryPolicy(
|
||||
maxAttempts = 5,
|
||||
initialDelay = 1.seconds,
|
||||
|
|
@ -181,7 +191,104 @@ object AgentFactory {
|
|||
backoffMultiplier = 2.0,
|
||||
jitterFactor = 0.1
|
||||
)
|
||||
return RetryingPromptExecutor.fromClient(client, policy, events)
|
||||
return createRetryingExecutor(llmClient, policy, events)
|
||||
}
|
||||
|
||||
internal fun createRetryingExecutor(
|
||||
client: LLMClient,
|
||||
policy: RetryingPromptExecutor.RetryPolicy,
|
||||
events: AgentEvents?
|
||||
): PromptExecutor {
|
||||
val executor = RetryingPromptExecutor.fromClient(client, policy, events)
|
||||
return object : PromptExecutor {
|
||||
override fun executeStreaming(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
) = executor.executeStreaming(prompt.withReasoningParams(model), model, tools)
|
||||
|
||||
override suspend fun execute(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
) = executor.execute(prompt.withReasoningParams(model), model, tools)
|
||||
|
||||
override suspend fun moderate(prompt: Prompt, model: LLModel) =
|
||||
executor.moderate(prompt, model)
|
||||
|
||||
override suspend fun models() = executor.models()
|
||||
|
||||
override fun close() = executor.close()
|
||||
}
|
||||
}
|
||||
|
||||
private fun Prompt.withReasoningParams(model: LLModel): Prompt {
|
||||
val params = when (model.provider) {
|
||||
LLMProvider.OpenAI -> params.withOpenAIReasoning()
|
||||
CustomOpenAILLMClient.CustomOpenAI -> {
|
||||
if (model.supports(LLMCapability.OpenAIEndpoint.Responses)) {
|
||||
params.withOpenAIReasoning()
|
||||
} else {
|
||||
params
|
||||
}
|
||||
}
|
||||
LLMProvider.Anthropic -> params.withAnthropicReasoning()
|
||||
else -> params
|
||||
}
|
||||
return withParams(params)
|
||||
}
|
||||
|
||||
private fun LLMParams.withOpenAIReasoning(): LLMParams {
|
||||
val base = when (this) {
|
||||
is OpenAIResponsesParams -> this
|
||||
else -> OpenAIResponsesParams(
|
||||
temperature = temperature,
|
||||
maxTokens = maxTokens,
|
||||
numberOfChoices = numberOfChoices,
|
||||
speculation = speculation,
|
||||
schema = schema,
|
||||
toolChoice = toolChoice,
|
||||
user = user,
|
||||
additionalProperties = additionalProperties
|
||||
)
|
||||
}
|
||||
return base.copy(
|
||||
reasoning = base.reasoning ?: ReasoningConfig(
|
||||
effort = ReasoningEffort.MEDIUM,
|
||||
summary = ReasoningSummary.AUTO
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun LLMParams.withAnthropicReasoning(): LLMParams {
|
||||
val base = when (this) {
|
||||
is AnthropicParams -> this
|
||||
else -> AnthropicParams(
|
||||
temperature = temperature,
|
||||
maxTokens = maxTokens,
|
||||
numberOfChoices = numberOfChoices,
|
||||
speculation = speculation,
|
||||
schema = schema,
|
||||
toolChoice = toolChoice,
|
||||
user = user,
|
||||
additionalProperties = additionalProperties
|
||||
)
|
||||
}
|
||||
|
||||
if (base.thinking != null) return base
|
||||
|
||||
val thinkingBudget = resolveAnthropicThinkingBudget(base.maxTokens) ?: return base
|
||||
return base.copy(thinking = AnthropicThinking.Enabled(budgetTokens = thinkingBudget))
|
||||
}
|
||||
|
||||
private fun resolveAnthropicThinkingBudget(maxTokens: Int?): Int? {
|
||||
val limit = maxTokens ?: ANTHROPIC_DEFAULT_THINKING_BUDGET
|
||||
if (limit <= ANTHROPIC_MIN_THINKING_BUDGET) {
|
||||
return null
|
||||
}
|
||||
return (limit / 2)
|
||||
.coerceAtLeast(ANTHROPIC_MIN_THINKING_BUDGET)
|
||||
.coerceAtMost(ANTHROPIC_DEFAULT_THINKING_BUDGET)
|
||||
}
|
||||
|
||||
private fun createGeneralPurposeAgent(
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import com.intellij.openapi.components.service
|
|||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStream
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
|
||||
import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION
|
||||
import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig
|
||||
import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider
|
||||
|
|
@ -35,6 +34,7 @@ import ee.carlrobert.codegpt.settings.hooks.HookManager
|
|||
import ee.carlrobert.codegpt.settings.models.ModelSettings
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest
|
||||
|
|
@ -89,7 +89,7 @@ object ProxyAIAgent {
|
|||
val modelSelection =
|
||||
service<ModelSettings>().getModelSelectionForFeature(FeatureType.AGENT)
|
||||
val skills = project.service<SkillDiscoveryService>().listSkills()
|
||||
val stream = shouldStreamAgentToolLoop(provider)
|
||||
val stream = shouldStreamAgentToolLoop(project, provider)
|
||||
val projectInstructions = loadProjectInstructions(project.basePath)
|
||||
val executor = AgentFactory.createExecutor(provider, events)
|
||||
val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
|
||||
|
|
@ -163,11 +163,20 @@ object ProxyAIAgent {
|
|||
val toolCallToUiId: MutableMap<String, String> = HashMap()
|
||||
val anonymousToolIds: ArrayDeque<String> = ArrayDeque()
|
||||
val frameAdapter = ReasoningFrameTextAdapter()
|
||||
var streamedReasoningForCurrentNode = false
|
||||
|
||||
onLLMStreamingFrameReceived { ctx ->
|
||||
if (!stream) return@onLLMStreamingFrameReceived
|
||||
|
||||
frameAdapter.consume(ctx.streamFrame).forEach { chunk ->
|
||||
val frameType = ctx.streamFrame::class.simpleName
|
||||
?: ctx.streamFrame::class.qualifiedName
|
||||
?: "unknown"
|
||||
val chunks = frameAdapter.consume(ctx.streamFrame)
|
||||
if (frameType.contains("Reasoning") && chunks.isNotEmpty()) {
|
||||
streamedReasoningForCurrentNode = true
|
||||
}
|
||||
|
||||
chunks.forEach { chunk ->
|
||||
if (chunk.isNotEmpty()) {
|
||||
events.onTextReceived(chunk)
|
||||
}
|
||||
|
|
@ -175,9 +184,22 @@ object ProxyAIAgent {
|
|||
}
|
||||
|
||||
onNodeExecutionCompleted { ctx ->
|
||||
if (stream) return@onNodeExecutionCompleted
|
||||
val output = (ctx.output as? List<*>) ?: emptyList<Any?>()
|
||||
if (stream) {
|
||||
if (!streamedReasoningForCurrentNode) {
|
||||
output.forEach { msg ->
|
||||
(msg as? Message.Reasoning)?.let {
|
||||
if (it.content.isNotBlank()) {
|
||||
events.onTextReceived("<think>${it.content}</think>")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
streamedReasoningForCurrentNode = false
|
||||
return@onNodeExecutionCompleted
|
||||
}
|
||||
|
||||
(ctx.output as? List<*>)?.forEach { msg ->
|
||||
output.forEach { msg ->
|
||||
(msg as? Message.Assistant)?.let {
|
||||
events.onTextReceived(it.content)
|
||||
}
|
||||
|
|
@ -268,10 +290,19 @@ object ProxyAIAgent {
|
|||
}
|
||||
|
||||
private fun shouldStreamAgentToolLoop(
|
||||
project: Project,
|
||||
provider: ServiceType,
|
||||
): Boolean {
|
||||
return when (provider) {
|
||||
ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(FeatureType.AGENT)
|
||||
ServiceType.CUSTOM_OPENAI -> {
|
||||
val selectedModel =
|
||||
service<ModelSettings>().getModelSelectionForFeature(FeatureType.AGENT)
|
||||
val selectedServiceId = selectedModel.serviceId
|
||||
val selectedService = service<CustomServicesSettings>().state.services
|
||||
.firstOrNull { it.id == selectedServiceId }
|
||||
selectedService?.chatCompletionSettings?.shouldStream() == true
|
||||
}
|
||||
|
||||
ServiceType.GOOGLE -> false
|
||||
else -> true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ import kotlinx.serialization.json.JsonElement
|
|||
|
||||
@Serializable
|
||||
class CustomOpenAIChatCompletionRequest(
|
||||
val messages: List<OpenAIMessage> = emptyList(),
|
||||
val messages: List<OpenAIMessage>? = null,
|
||||
val input: JsonElement? = null,
|
||||
val prompt: String? = null,
|
||||
override val model: String? = null,
|
||||
override val stream: Boolean? = null,
|
||||
|
|
|
|||
|
|
@ -1,93 +1,68 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.dsl.ModerationResult
|
||||
import ai.koog.agents.core.tools.ToolDescriptor
|
||||
import ai.koog.prompt.dsl.Prompt
|
||||
import ai.koog.prompt.executor.clients.ConnectionTimeoutConfig
|
||||
import ai.koog.prompt.executor.clients.LLMClient
|
||||
import ai.koog.prompt.executor.clients.LLMClientException
|
||||
import ai.koog.prompt.executor.clients.openai.base.AbstractOpenAILLMClient
|
||||
import ai.koog.prompt.executor.clients.openai.base.OpenAIBaseSettings
|
||||
import ai.koog.prompt.executor.clients.openai.base.OpenAICompatibleToolDescriptorSchemaGenerator
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIResponsesParams
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.*
|
||||
import ai.koog.prompt.llm.LLMCapability
|
||||
import ai.koog.prompt.llm.LLMProvider
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ai.koog.prompt.message.LLMChoice
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.message.ResponseMetaInfo
|
||||
import ai.koog.prompt.params.LLMParams
|
||||
import ai.koog.prompt.streaming.StreamFrame
|
||||
import ai.koog.prompt.streaming.buildStreamFrameFlow
|
||||
import com.fasterxml.jackson.core.type.TypeReference
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
|
||||
import ee.carlrobert.codegpt.codecompletions.InfillRequest
|
||||
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
|
||||
import ee.carlrobert.codegpt.settings.Placeholder
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceCodeCompletionSettingsState
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceCodeCompletionSettingsState
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicePlaceholders
|
||||
import ee.carlrobert.codegpt.completions.factory.ResponsesApiUtil
|
||||
import ee.carlrobert.codegpt.util.JsonMapper
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import io.ktor.client.*
|
||||
import io.ktor.client.plugins.*
|
||||
import io.ktor.client.plugins.api.createClientPlugin
|
||||
import io.ktor.client.request.*
|
||||
import io.ktor.http.HttpHeaders
|
||||
import org.apache.commons.text.StringEscapeUtils
|
||||
import io.ktor.http.ContentType
|
||||
import io.ktor.http.content.TextContent
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.KSerializer
|
||||
import kotlinx.serialization.builtins.ListSerializer
|
||||
import kotlinx.serialization.descriptors.elementNames
|
||||
import kotlinx.serialization.json.*
|
||||
import org.apache.commons.text.StringEscapeUtils
|
||||
import java.net.URI
|
||||
import java.util.UUID
|
||||
import kotlin.io.encoding.ExperimentalEncodingApi
|
||||
import kotlin.time.Clock
|
||||
|
||||
/**
|
||||
* Configuration settings for connecting to the CustomOpenAI API.
|
||||
* Implementation of [LLMClient] for OpenAI-compatible custom providers.
|
||||
*
|
||||
* @property baseUrl The base URL of the CustomOpenAI API. Default is "https://CustomOpenAI.ai/api/v1".
|
||||
* @property timeoutConfig Configuration for connection timeouts including request, connection, and socket timeouts.
|
||||
*/
|
||||
class CustomOpenAIClientSettings(
|
||||
baseUrl: String,
|
||||
chatCompletionsPath: String,
|
||||
timeoutConfig: ConnectionTimeoutConfig
|
||||
) : OpenAIBaseSettings(baseUrl, chatCompletionsPath, timeoutConfig)
|
||||
|
||||
/**
|
||||
* Implementation of [LLMClient] for CustomOpenAI API.
|
||||
* CustomOpenAI is an API that routes requests to multiple LLM providers.
|
||||
*
|
||||
* @param apiKey The API key for the CustomOpenAI API
|
||||
* @param settings The base URL and timeouts for the CustomOpenAI API, defaults to "https://CustomOpenAI.ai" and 900s
|
||||
* @param clock Clock instance used for tracking response metadata timestamps.
|
||||
* Chat-completions requests keep the custom body/placeholder behavior.
|
||||
* Responses requests follow Koog's [OpenAILLMClient] path and only add custom params.
|
||||
*/
|
||||
class CustomOpenAILLMClient(
|
||||
private val apiKey: String,
|
||||
private val settings: CustomOpenAIClientSettings,
|
||||
settings: OpenAIClientSettings,
|
||||
private val chatState: CustomServiceChatCompletionSettingsState? = null,
|
||||
private val codeCompletionState: CustomServiceCodeCompletionSettingsState? = null,
|
||||
private val baseClient: HttpClient = HttpClient(),
|
||||
clock: Clock = Clock.System
|
||||
) : AbstractOpenAILLMClient<CustomOpenAIChatCompletionResponse, CustomOpenAIChatCompletionStreamResponse>(
|
||||
apiKey,
|
||||
settings,
|
||||
baseClient,
|
||||
clock,
|
||||
staticLogger,
|
||||
OpenAICompatibleToolDescriptorSchemaGenerator()
|
||||
),
|
||||
CodeCompletionCapable {
|
||||
) : OpenAILLMClient(
|
||||
apiKey = apiKey,
|
||||
settings = settings,
|
||||
baseClient = baseClient,
|
||||
clock = clock,
|
||||
), CodeCompletionCapable {
|
||||
|
||||
data object CustomOpenAI : LLMProvider("custom-openai", "Custom OpenAI")
|
||||
|
||||
private val isResponsesApi: Boolean = ResponsesApiUtil.isResponsesApiUrl(chatState?.url)
|
||||
|
||||
companion object {
|
||||
private val staticLogger = KotlinLogging.logger { }
|
||||
|
||||
init {
|
||||
registerOpenAIJsonSchemaGenerators(CustomOpenAI)
|
||||
}
|
||||
|
|
@ -103,6 +78,7 @@ class CustomOpenAILLMClient(
|
|||
timeoutConfig = timeoutConfig
|
||||
)
|
||||
val clientWithCustomHeaders = baseClient.config {
|
||||
install(FlattenCustomOpenAIAdditionalPropertiesPlugin)
|
||||
defaultRequest {
|
||||
state.headers.forEach { (key, value) ->
|
||||
val normalizedKey = key.trim()
|
||||
|
|
@ -147,12 +123,17 @@ class CustomOpenAILLMClient(
|
|||
private fun createClientSettings(
|
||||
url: String,
|
||||
timeoutConfig: ConnectionTimeoutConfig
|
||||
): CustomOpenAIClientSettings {
|
||||
): OpenAIClientSettings {
|
||||
val uri = URI.create(url)
|
||||
val authority = uri.authority ?: uri.host
|
||||
return CustomOpenAIClientSettings(
|
||||
val path = buildString {
|
||||
append(uri.path)
|
||||
uri.query?.takeIf { it.isNotBlank() }?.let { append("?").append(it) }
|
||||
}
|
||||
return OpenAIClientSettings(
|
||||
baseUrl = "${uri.scheme}://${authority}",
|
||||
chatCompletionsPath = uri.path,
|
||||
chatCompletionsPath = path,
|
||||
responsesAPIPath = path,
|
||||
timeoutConfig = timeoutConfig
|
||||
)
|
||||
}
|
||||
|
|
@ -160,10 +141,67 @@ class CustomOpenAILLMClient(
|
|||
|
||||
override fun llmProvider(): LLMProvider = CustomOpenAI
|
||||
|
||||
override suspend fun getCodeCompletion(infillRequest: InfillRequest): String {
|
||||
val state = requireNotNull(codeCompletionState) {
|
||||
private fun LLModel.hasResponsesEndpointCapability(): Boolean =
|
||||
this.capabilities?.any { it == LLMCapability.OpenAIEndpoint.Responses } == true
|
||||
|
||||
private fun requireChatState(): CustomServiceChatCompletionSettingsState {
|
||||
return requireNotNull(chatState) {
|
||||
"Custom OpenAI chat request requested on a code-completion-only client"
|
||||
}
|
||||
}
|
||||
|
||||
private fun requireCodeCompletionState(): CustomServiceCodeCompletionSettingsState {
|
||||
return requireNotNull(codeCompletionState) {
|
||||
"Custom OpenAI code completion requested on a chat-only client"
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun execute(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): List<Message.Response> {
|
||||
val state = requireChatState()
|
||||
val finalPrompt = prompt.prepareForModel(model, state)
|
||||
return super.execute(finalPrompt, model, tools)
|
||||
}
|
||||
|
||||
override fun executeStreaming(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): Flow<StreamFrame> {
|
||||
val state = requireChatState()
|
||||
val finalPrompt = prompt.prepareForModel(model, state)
|
||||
return super.executeStreaming(finalPrompt, model, tools)
|
||||
}
|
||||
|
||||
override suspend fun executeMultipleChoices(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): List<LLMChoice> {
|
||||
val state = requireChatState()
|
||||
return super.executeMultipleChoices(prompt.prepareForModel(model, state), model, tools)
|
||||
}
|
||||
|
||||
private fun Prompt.prepareForModel(
|
||||
model: LLModel,
|
||||
state: CustomServiceChatCompletionSettingsState
|
||||
): Prompt {
|
||||
return when {
|
||||
model.hasResponsesEndpointCapability() ->
|
||||
withParams(params.toCustomOpenAIResponsesParams(state))
|
||||
|
||||
params is OpenAIResponsesParams ->
|
||||
withParams(params.toCustomOpenAIParams(state))
|
||||
|
||||
else -> this
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun getCodeCompletion(infillRequest: InfillRequest): String {
|
||||
val state = requireCodeCompletionState()
|
||||
val url = requireNotNull(state.url)
|
||||
val payload = postCompletionJson(
|
||||
client = baseClient,
|
||||
|
|
@ -190,14 +228,7 @@ class CustomOpenAILLMClient(
|
|||
params: LLMParams,
|
||||
stream: Boolean
|
||||
): String {
|
||||
val state = requireNotNull(chatState) {
|
||||
"Custom OpenAI chat request requested on a code-completion-only client"
|
||||
}
|
||||
|
||||
if (isResponsesApi) {
|
||||
return serializeResponsesApiRequest(state, messages, model, tools, toolChoice)
|
||||
}
|
||||
|
||||
val state = requireChatState()
|
||||
val customParams: CustomOpenAIParams = params.toCustomOpenAIParams(state)
|
||||
val streamRequest = state.shouldStream()
|
||||
val additionalProperties = buildCustomOpenAIAdditionalProperties(
|
||||
|
|
@ -238,44 +269,6 @@ class CustomOpenAILLMClient(
|
|||
return json.encodeToString(CustomOpenAIChatCompletionRequestSerializer, request)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the Responses API request body directly from the template body configuration.
|
||||
* The template body uses "input" instead of "messages" and "max_output_tokens" instead
|
||||
* of "max_tokens", so we process it as-is to produce the correct format.
|
||||
*/
|
||||
private fun serializeResponsesApiRequest(
|
||||
state: CustomServiceChatCompletionSettingsState,
|
||||
messages: List<OpenAIMessage>,
|
||||
model: LLModel,
|
||||
tools: List<OpenAITool>?,
|
||||
toolChoice: OpenAIToolChoice?
|
||||
): String {
|
||||
val streamRequest = state.shouldStream()
|
||||
val inputJson = messages.toResponsesApiItemsJson()
|
||||
val prompt = renderCustomOpenAIPrompt(messages, json)
|
||||
|
||||
return buildJsonObject {
|
||||
state.body.forEach { (key, value) ->
|
||||
put(
|
||||
key, transformCustomOpenAIBodyValue(
|
||||
key = key,
|
||||
value = value,
|
||||
streamRequest = streamRequest,
|
||||
messagesJson = inputJson,
|
||||
prompt = prompt,
|
||||
credential = apiKey,
|
||||
json = json
|
||||
)
|
||||
)
|
||||
}
|
||||
put("model", JsonPrimitive(model.id))
|
||||
if (!tools.isNullOrEmpty()) {
|
||||
put("tools", JsonArray(tools.map { it.toResponsesApiToolJson() }))
|
||||
}
|
||||
toolChoice?.toResponsesApiToolChoiceJson()?.let { put("tool_choice", it) }
|
||||
}.toString()
|
||||
}
|
||||
|
||||
private fun buildCodeCompletionRequestBody(
|
||||
state: CustomServiceCodeCompletionSettingsState,
|
||||
infillRequest: InfillRequest
|
||||
|
|
@ -285,16 +278,16 @@ class CustomOpenAILLMClient(
|
|||
mapOf(
|
||||
"role" to "system",
|
||||
"content" to (
|
||||
"You are a code completion assistant. Complete the code between the given prefix and suffix. " +
|
||||
"Return only the missing code that should be inserted, without any formatting, explanations, or markdown."
|
||||
)
|
||||
"You are a code completion assistant. Complete the code between the given prefix and suffix. " +
|
||||
"Return only the missing code that should be inserted, without any formatting, explanations, or markdown."
|
||||
)
|
||||
),
|
||||
mapOf(
|
||||
"role" to "user",
|
||||
"content" to (
|
||||
"<PREFIX>\n${infillRequest.prefix}\n</PREFIX>\n\n" +
|
||||
"<SUFFIX>\n${infillRequest.suffix}\n</SUFFIX>\n\nComplete:"
|
||||
)
|
||||
"<PREFIX>\n${infillRequest.prefix}\n</PREFIX>\n\n" +
|
||||
"<SUFFIX>\n${infillRequest.suffix}\n</SUFFIX>\n\nComplete:"
|
||||
)
|
||||
)
|
||||
)
|
||||
val transformedBody = state.body.entries.mapNotNull { (key, value) ->
|
||||
|
|
@ -416,283 +409,6 @@ class CustomOpenAILLMClient(
|
|||
?: replaced
|
||||
}
|
||||
}
|
||||
|
||||
override fun processProviderChatResponse(response: CustomOpenAIChatCompletionResponse): List<LLMChoice> {
|
||||
require(response.choices.isNotEmpty()) { "Empty choices in response" }
|
||||
return response.choices.map {
|
||||
it.message.toResponses(it.finishReason, createMetaInfo(response.usage))
|
||||
}
|
||||
}
|
||||
|
||||
override fun decodeStreamingResponse(data: String): CustomOpenAIChatCompletionStreamResponse {
|
||||
val payload = normalizeSsePayload(data)
|
||||
?: return CustomOpenAIChatCompletionStreamResponse(
|
||||
choices = emptyList(),
|
||||
created = 0,
|
||||
id = "",
|
||||
model = ""
|
||||
)
|
||||
if (!isResponsesApi) {
|
||||
return json.decodeFromString(payload)
|
||||
}
|
||||
return adaptResponsesApiStreamEvent(payload)
|
||||
}
|
||||
|
||||
override fun decodeResponse(data: String): CustomOpenAIChatCompletionResponse {
|
||||
if (!isResponsesApi) {
|
||||
return json.decodeFromString(data)
|
||||
}
|
||||
return adaptResponsesApiResponse(data)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adapts a Responses API SSE event into a [CustomOpenAIChatCompletionStreamResponse].
|
||||
* Maps Responses API event types to the chat completions delta format that
|
||||
* [processStreamingResponse] already knows how to handle.
|
||||
*/
|
||||
private fun adaptResponsesApiStreamEvent(data: String): CustomOpenAIChatCompletionStreamResponse {
|
||||
val event = json.parseToJsonElement(data).jsonObject
|
||||
val type = event["type"]?.jsonPrimitive?.contentOrNull ?: ""
|
||||
|
||||
val choice = when (type) {
|
||||
"response.output_text.delta" -> {
|
||||
val delta = event["delta"]?.jsonPrimitive?.contentOrNull ?: ""
|
||||
CustomOpenAIStreamChoice(
|
||||
delta = CustomOpenAIStreamDelta(content = delta)
|
||||
)
|
||||
}
|
||||
|
||||
"response.function_call_arguments.delta",
|
||||
"response.output_item.added" -> null
|
||||
|
||||
"response.completed" -> {
|
||||
val responseObject = event["response"]?.jsonObject
|
||||
val toolCalls = responseObject?.get("output")
|
||||
?.jsonArray
|
||||
?.mapIndexedNotNull { index, item ->
|
||||
val itemObject = item.jsonObject
|
||||
if (itemObject["type"]?.jsonPrimitive?.contentOrNull != "function_call") {
|
||||
return@mapIndexedNotNull null
|
||||
}
|
||||
|
||||
val rawArguments = when (val arguments = itemObject["arguments"]) {
|
||||
is JsonPrimitive -> arguments.contentOrNull
|
||||
is JsonObject -> arguments.toString()
|
||||
else -> null
|
||||
}
|
||||
CustomOpenAIToolCall(
|
||||
id = itemObject["call_id"]?.jsonPrimitive?.contentOrNull,
|
||||
index = index,
|
||||
function = CustomOpenAIFunction(
|
||||
name = itemObject["name"]?.jsonPrimitive?.contentOrNull,
|
||||
arguments = normalizeToolArgumentsJson(rawArguments) ?: rawArguments.orEmpty()
|
||||
)
|
||||
)
|
||||
}
|
||||
.orEmpty()
|
||||
CustomOpenAIStreamChoice(
|
||||
finishReason = "stop",
|
||||
delta = CustomOpenAIStreamDelta(
|
||||
toolCalls = toolCalls.takeIf { it.isNotEmpty() }
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
|
||||
return CustomOpenAIChatCompletionStreamResponse(
|
||||
choices = listOfNotNull(choice),
|
||||
created = 0,
|
||||
id = "",
|
||||
model = ""
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adapts a non-streaming Responses API response into a [CustomOpenAIChatCompletionResponse].
|
||||
* Builds a synthetic chat-completions-format JSON and deserializes it, avoiding direct
|
||||
* construction of Koog internal types.
|
||||
*/
|
||||
private fun adaptResponsesApiResponse(data: String): CustomOpenAIChatCompletionResponse {
|
||||
val response = json.parseToJsonElement(data).jsonObject
|
||||
val output = response["output"]?.jsonArray ?: JsonArray(emptyList())
|
||||
|
||||
val textContent = StringBuilder()
|
||||
val toolCallsJson = mutableListOf<JsonElement>()
|
||||
|
||||
for (item in output) {
|
||||
val itemObj = item.jsonObject
|
||||
when (itemObj["type"]?.jsonPrimitive?.contentOrNull) {
|
||||
"message" -> {
|
||||
itemObj["content"]?.jsonArray?.forEach { contentPart ->
|
||||
val partObj = contentPart.jsonObject
|
||||
if (partObj["type"]?.jsonPrimitive?.contentOrNull == "output_text") {
|
||||
textContent.append(partObj["text"]?.jsonPrimitive?.contentOrNull ?: "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"function_call" -> {
|
||||
val rawArguments = when (val arguments = itemObj["arguments"]) {
|
||||
is JsonPrimitive -> arguments.contentOrNull
|
||||
is JsonObject -> arguments.toString()
|
||||
else -> null
|
||||
}
|
||||
toolCallsJson.add(buildJsonObject {
|
||||
put("id", itemObj["call_id"] ?: JsonPrimitive(""))
|
||||
put("type", JsonPrimitive("function"))
|
||||
putJsonObject("function") {
|
||||
put("name", itemObj["name"] ?: JsonPrimitive(""))
|
||||
put(
|
||||
"arguments",
|
||||
JsonPrimitive(normalizeToolArgumentsJson(rawArguments) ?: rawArguments ?: "{}")
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val syntheticJson = buildJsonObject {
|
||||
put("id", response["id"] ?: JsonPrimitive(""))
|
||||
put("created", JsonPrimitive(0))
|
||||
put("model", response["model"] ?: JsonPrimitive(""))
|
||||
put("object", JsonPrimitive("chat.completion"))
|
||||
putJsonArray("choices") {
|
||||
addJsonObject {
|
||||
put("finish_reason", JsonPrimitive("stop"))
|
||||
putJsonObject("message") {
|
||||
put("role", JsonPrimitive("assistant"))
|
||||
if (textContent.isNotEmpty()) {
|
||||
put("content", JsonPrimitive(textContent.toString()))
|
||||
}
|
||||
if (toolCallsJson.isNotEmpty()) {
|
||||
put("tool_calls", JsonArray(toolCallsJson))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response["usage"]?.let { put("usage", parseResponsesApiUsageJson(it)) }
|
||||
}
|
||||
|
||||
return json.decodeFromString(syntheticJson.toString())
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts Responses API usage JSON (input_tokens/output_tokens) to
|
||||
* OpenAI-compatible usage JSON (prompt_tokens/completion_tokens/total_tokens).
|
||||
*/
|
||||
private fun parseResponsesApiUsageJson(usageElement: JsonElement): JsonElement {
|
||||
val usageObj = usageElement.jsonObject
|
||||
val inputTokens = usageObj["input_tokens"]?.jsonPrimitive?.intOrNull ?: 0
|
||||
val outputTokens = usageObj["output_tokens"]?.jsonPrimitive?.intOrNull ?: 0
|
||||
return buildJsonObject {
|
||||
put("prompt_tokens", JsonPrimitive(inputTokens))
|
||||
put("completion_tokens", JsonPrimitive(outputTokens))
|
||||
put("total_tokens", JsonPrimitive(inputTokens + outputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
override fun processStreamingResponse(
|
||||
response: Flow<CustomOpenAIChatCompletionStreamResponse>
|
||||
): Flow<StreamFrame> = buildStreamFrameFlow {
|
||||
var finishReason: String? = null
|
||||
var metaInfo: ResponseMetaInfo? = null
|
||||
|
||||
response.collect { chunk ->
|
||||
chunk.choices.firstOrNull()?.let { choice ->
|
||||
choice.delta.content?.let { emitTextDelta(it) }
|
||||
|
||||
choice.delta.toolCalls?.forEach { openAIToolCall ->
|
||||
val index = openAIToolCall.index ?: 0
|
||||
val id = openAIToolCall.id.orEmpty()
|
||||
val functionName = openAIToolCall.function.name.orEmpty()
|
||||
val functionArgs = openAIToolCall.function.arguments.orEmpty()
|
||||
emitToolCallDelta(id, functionName, functionArgs, index)
|
||||
}
|
||||
|
||||
choice.finishReason?.let { finishReason = it }
|
||||
}
|
||||
|
||||
chunk.usage?.let { metaInfo = createMetaInfo(it) }
|
||||
}
|
||||
|
||||
emitEnd(finishReason, metaInfo)
|
||||
}
|
||||
|
||||
override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult {
|
||||
throw UnsupportedOperationException("Moderation not supported.")
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalEncodingApi::class)
|
||||
private fun OpenAIMessage.toResponses(
|
||||
finishReason: String?,
|
||||
metaInfo: ResponseMetaInfo
|
||||
): List<Message.Response> {
|
||||
val contentText = content?.text()
|
||||
|
||||
if (this is OpenAIMessage.Assistant) {
|
||||
val assistantToolCalls = toolCalls
|
||||
if (!assistantToolCalls.isNullOrEmpty()) {
|
||||
return buildList {
|
||||
contentText?.let {
|
||||
add(
|
||||
Message.Assistant(
|
||||
content = it,
|
||||
finishReason = finishReason,
|
||||
metaInfo = metaInfo
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
assistantToolCalls.forEach { toolCall ->
|
||||
val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}"
|
||||
add(
|
||||
Message.Tool.Call(
|
||||
id = toolCall.id,
|
||||
tool = toolCall.function.name,
|
||||
content = arguments,
|
||||
metaInfo = metaInfo
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val reasoning = reasoningContent
|
||||
if (reasoning != null && contentText != null) {
|
||||
return listOf(
|
||||
Message.Reasoning(
|
||||
content = reasoning,
|
||||
metaInfo = metaInfo
|
||||
),
|
||||
Message.Assistant(
|
||||
content = contentText,
|
||||
finishReason = finishReason,
|
||||
metaInfo = metaInfo
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (contentText != null) {
|
||||
return listOf(
|
||||
Message.Assistant(
|
||||
content = contentText,
|
||||
finishReason = finishReason,
|
||||
metaInfo = metaInfo
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
val exception = LLMClientException(
|
||||
clientName,
|
||||
"Unexpected response: no tool calls and no content"
|
||||
)
|
||||
logger.error(exception) { exception.message }
|
||||
throw exception
|
||||
}
|
||||
}
|
||||
|
||||
internal fun buildCustomOpenAIAdditionalProperties(
|
||||
|
|
@ -722,38 +438,22 @@ internal fun transformCustomOpenAIBodyValue(
|
|||
messages: List<OpenAIMessage>,
|
||||
credential: String,
|
||||
json: Json
|
||||
): JsonElement {
|
||||
return transformCustomOpenAIBodyValue(
|
||||
key = key,
|
||||
value = value,
|
||||
streamRequest = streamRequest,
|
||||
messagesJson = json.encodeToJsonElement(
|
||||
ListSerializer(OpenAIMessage.serializer()),
|
||||
messages
|
||||
),
|
||||
prompt = renderCustomOpenAIPrompt(messages, json),
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
}
|
||||
|
||||
private fun transformCustomOpenAIBodyValue(
|
||||
key: String?,
|
||||
value: Any?,
|
||||
streamRequest: Boolean,
|
||||
messagesJson: JsonElement,
|
||||
prompt: String,
|
||||
credential: String,
|
||||
json: Json
|
||||
): JsonElement {
|
||||
return when (value) {
|
||||
null -> JsonNull
|
||||
is JsonElement -> value
|
||||
is String -> when {
|
||||
!streamRequest && key == "stream" -> JsonPrimitive(false)
|
||||
CustomServicePlaceholders.isMessages(value) -> messagesJson
|
||||
CustomServicePlaceholders.isMessages(value) -> json.parseToJsonElement(
|
||||
json.encodeToString(ListSerializer(OpenAIMessage.serializer()), messages)
|
||||
)
|
||||
|
||||
CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(prompt)
|
||||
CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(
|
||||
renderCustomOpenAIPrompt(
|
||||
messages,
|
||||
json
|
||||
)
|
||||
)
|
||||
|
||||
value.contains($$"$CUSTOM_SERVICE_API_KEY") -> {
|
||||
JsonPrimitive(value.replace($$"$CUSTOM_SERVICE_API_KEY", credential))
|
||||
|
|
@ -772,8 +472,7 @@ private fun transformCustomOpenAIBodyValue(
|
|||
key = nestedKey.toString(),
|
||||
value = nestedValue,
|
||||
streamRequest = streamRequest,
|
||||
messagesJson = messagesJson,
|
||||
prompt = prompt,
|
||||
messages = messages,
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -786,8 +485,7 @@ private fun transformCustomOpenAIBodyValue(
|
|||
key = null,
|
||||
value = item,
|
||||
streamRequest = streamRequest,
|
||||
messagesJson = messagesJson,
|
||||
prompt = prompt,
|
||||
messages = messages,
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -800,8 +498,7 @@ private fun transformCustomOpenAIBodyValue(
|
|||
key = null,
|
||||
value = item,
|
||||
streamRequest = streamRequest,
|
||||
messagesJson = messagesJson,
|
||||
prompt = prompt,
|
||||
messages = messages,
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -819,172 +516,48 @@ internal fun renderCustomOpenAIPrompt(messages: List<OpenAIMessage>, json: Json)
|
|||
}
|
||||
}
|
||||
|
||||
private fun List<OpenAIMessage>.toResponsesApiItemsJson(): JsonArray {
|
||||
return JsonArray(
|
||||
buildList {
|
||||
for (message in this@toResponsesApiItemsJson) {
|
||||
addAll(message.toResponsesApiItemsJson())
|
||||
}
|
||||
}
|
||||
)
|
||||
private val customOpenAIRequestTransformJson = Json {
|
||||
ignoreUnknownKeys = true
|
||||
isLenient = true
|
||||
explicitNulls = false
|
||||
}
|
||||
|
||||
private fun OpenAIMessage.toResponsesApiItemsJson(): List<JsonObject> {
|
||||
return when (this) {
|
||||
is OpenAIMessage.System, is OpenAIMessage.Developer -> listOf(
|
||||
buildResponsesApiInputMessageItemJson(
|
||||
role = "developer",
|
||||
content = content.toResponsesApiInputContentJson()
|
||||
)
|
||||
private val FlattenCustomOpenAIAdditionalPropertiesPlugin = createClientPlugin(
|
||||
"FlattenCustomOpenAIAdditionalProperties"
|
||||
) {
|
||||
transformRequestBody { _, content, _ ->
|
||||
val requestBody = content as? String ?: return@transformRequestBody null
|
||||
val flattened = flattenSerializedAdditionalProperties(
|
||||
requestBody,
|
||||
customOpenAIRequestTransformJson
|
||||
)
|
||||
|
||||
is OpenAIMessage.User -> listOf(
|
||||
buildResponsesApiInputMessageItemJson(
|
||||
role = "user",
|
||||
content = content.toResponsesApiInputContentJson()
|
||||
)
|
||||
)
|
||||
|
||||
is OpenAIMessage.Assistant -> buildList {
|
||||
reasoningContent
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { reasoning ->
|
||||
add(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("reasoning"))
|
||||
put("id", JsonPrimitive(UUID.randomUUID().toString()))
|
||||
putJsonArray("summary") {
|
||||
addJsonObject {
|
||||
put("type", JsonPrimitive("summary_text"))
|
||||
put("text", JsonPrimitive(reasoning))
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
content?.text()
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { assistantText ->
|
||||
add(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("message"))
|
||||
put("role", JsonPrimitive("assistant"))
|
||||
putJsonArray("content") {
|
||||
addJsonObject {
|
||||
put("type", JsonPrimitive("output_text"))
|
||||
put("text", JsonPrimitive(assistantText))
|
||||
put("annotations", JsonArray(emptyList()))
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
toolCalls.orEmpty().forEach { toolCall ->
|
||||
val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}"
|
||||
add(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("function_call"))
|
||||
put("arguments", JsonPrimitive(arguments))
|
||||
put("call_id", JsonPrimitive(toolCall.id))
|
||||
put("name", JsonPrimitive(toolCall.function.name))
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
is OpenAIMessage.Tool -> listOf(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("function_call_output"))
|
||||
put("call_id", JsonPrimitive(toolCallId))
|
||||
put("output", JsonPrimitive(content?.text().orEmpty()))
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildResponsesApiInputMessageItemJson(
|
||||
role: String,
|
||||
content: JsonArray
|
||||
): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("type", JsonPrimitive("message"))
|
||||
put("role", JsonPrimitive(role))
|
||||
put("content", content)
|
||||
}
|
||||
}
|
||||
|
||||
private fun Content?.toResponsesApiInputContentJson(): JsonArray {
|
||||
if (this == null) {
|
||||
return JsonArray(emptyList())
|
||||
}
|
||||
|
||||
return when (this) {
|
||||
is Content.Parts -> JsonArray(value.mapNotNull { it.toResponsesApiInputContentJson() })
|
||||
else -> JsonArray(
|
||||
listOf(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("input_text"))
|
||||
put("text", JsonPrimitive(text()))
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun OpenAIContentPart.toResponsesApiInputContentJson(): JsonObject? {
|
||||
return when (this) {
|
||||
is OpenAIContentPart.Text -> buildJsonObject {
|
||||
put("type", JsonPrimitive("input_text"))
|
||||
put("text", JsonPrimitive(text))
|
||||
}
|
||||
|
||||
is OpenAIContentPart.Image -> buildJsonObject {
|
||||
put("type", JsonPrimitive("input_image"))
|
||||
imageUrl.detail?.let { put("detail", JsonPrimitive(it)) }
|
||||
put("imageUrl", JsonPrimitive(imageUrl.url))
|
||||
}
|
||||
|
||||
is OpenAIContentPart.File -> buildJsonObject {
|
||||
put("type", JsonPrimitive("input_file"))
|
||||
file.fileData?.let { put("fileData", JsonPrimitive(it)) }
|
||||
file.fileId?.let { put("fileId", JsonPrimitive(it)) }
|
||||
file.filename?.let { put("filename", JsonPrimitive(it)) }
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
internal fun OpenAITool.toResponsesApiToolJson(): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("type", JsonPrimitive("function"))
|
||||
put("name", JsonPrimitive(function.name))
|
||||
put(
|
||||
"parameters",
|
||||
function.parameters ?: buildJsonObject {
|
||||
put("type", JsonPrimitive("object"))
|
||||
putJsonObject("properties") {}
|
||||
putJsonArray("required") {}
|
||||
}
|
||||
)
|
||||
function.strict?.let { put("strict", JsonPrimitive(it)) }
|
||||
function.description?.takeIf { it.isNotBlank() }?.let {
|
||||
put("description", JsonPrimitive(it))
|
||||
if (flattened == requestBody) {
|
||||
null
|
||||
} else {
|
||||
TextContent(flattened, ContentType.Application.Json)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun OpenAIToolChoice.toResponsesApiToolChoiceJson(): JsonElement {
|
||||
return when (this) {
|
||||
is OpenAIToolChoice.Function -> buildJsonObject {
|
||||
put("type", JsonPrimitive("function"))
|
||||
put("name", JsonPrimitive(function.name))
|
||||
}
|
||||
internal fun flattenSerializedAdditionalProperties(
|
||||
requestBody: String,
|
||||
json: Json
|
||||
): String {
|
||||
val payload = runCatching { json.parseToJsonElement(requestBody) }.getOrNull()?.jsonObject
|
||||
?: return requestBody
|
||||
val additionalProperties = payload["additional_properties"] as? JsonObject
|
||||
?: return requestBody
|
||||
|
||||
is OpenAIToolChoice.Mode -> JsonPrimitive(value)
|
||||
val flattened = buildJsonObject {
|
||||
payload.entries
|
||||
.filterNot { (key, _) -> key == "additional_properties" }
|
||||
.forEach { (key, value) -> put(key, value) }
|
||||
additionalProperties.entries
|
||||
.filterNot { (key, _) -> payload.containsKey(key) }
|
||||
.forEach { (key, value) -> put(key, value) }
|
||||
}
|
||||
|
||||
return flattened.toString()
|
||||
}
|
||||
|
||||
internal object CustomOpenAIChatCompletionRequestSerializer :
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIResponsesParams
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.ReasoningEffort
|
||||
import ai.koog.prompt.executor.clients.openai.models.ReasoningConfig
|
||||
import ai.koog.prompt.executor.clients.openai.models.ReasoningSummary
|
||||
import ai.koog.prompt.params.LLMParams
|
||||
import ai.koog.prompt.params.LLMParams.ToolChoice
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
|
|
@ -19,7 +23,6 @@ internal fun LLMParams.toCustomOpenAIParams(state: CustomServiceChatCompletionSe
|
|||
additionalProperties = mergedAdditionalProperties,
|
||||
temperature = body.findValue("temperature").asDouble() ?: temperature,
|
||||
maxTokens = body.findValue("maxTokens", "max_tokens").asInt() ?: maxTokens,
|
||||
speculation = body.findValue("speculation").asString() ?: speculation,
|
||||
toolChoice = body.findValue("toolChoice", "tool_choice").asToolChoice() ?: toolChoice,
|
||||
frequencyPenalty = body.findValue("frequencyPenalty", "frequency_penalty").asDouble()
|
||||
?: current?.frequencyPenalty,
|
||||
|
|
@ -33,6 +36,37 @@ internal fun LLMParams.toCustomOpenAIParams(state: CustomServiceChatCompletionSe
|
|||
)
|
||||
}
|
||||
|
||||
internal fun LLMParams.toCustomOpenAIResponsesParams(
|
||||
state: CustomServiceChatCompletionSettingsState
|
||||
): OpenAIResponsesParams {
|
||||
val body = state.body
|
||||
val currentReasoning = (this as? OpenAIResponsesParams)?.reasoning
|
||||
val bodyReasoning = body.findValue("reasoning").asReasoningConfig()
|
||||
|
||||
val mergedAdditionalProperties = buildMap {
|
||||
body
|
||||
.filterKeys { it !in CUSTOM_OPENAI_RESPONSES_RESERVED_BODY_KEYS }
|
||||
.forEach { (key, value) -> put(key, value.toJsonElement()) }
|
||||
}.takeIf { it.isNotEmpty() }
|
||||
|
||||
return OpenAIResponsesParams(
|
||||
additionalProperties = mergedAdditionalProperties,
|
||||
temperature = body.findValue("temperature").asDouble() ?: temperature,
|
||||
maxTokens = body.findValue("maxOutputTokens", "max_output_tokens").asInt() ?: maxTokens,
|
||||
numberOfChoices = numberOfChoices,
|
||||
reasoning = when {
|
||||
bodyReasoning == null -> currentReasoning
|
||||
currentReasoning == null -> bodyReasoning
|
||||
else -> ReasoningConfig(
|
||||
effort = bodyReasoning.effort ?: currentReasoning.effort,
|
||||
summary = bodyReasoning.summary ?: currentReasoning.summary
|
||||
)
|
||||
},
|
||||
schema = schema,
|
||||
toolChoice = body.findValue("toolChoice", "tool_choice").asToolChoice() ?: toolChoice,
|
||||
)
|
||||
}
|
||||
|
||||
internal fun CustomServiceChatCompletionSettingsState.shouldStream(): Boolean {
|
||||
return body.findValue("stream").asBoolean() == true
|
||||
}
|
||||
|
|
@ -53,6 +87,18 @@ internal val CUSTOM_OPENAI_RESERVED_BODY_KEYS = setOf(
|
|||
"topP", "top_p",
|
||||
)
|
||||
|
||||
internal val CUSTOM_OPENAI_RESPONSES_RESERVED_BODY_KEYS = setOf(
|
||||
"input",
|
||||
"maxOutputTokens", "max_output_tokens",
|
||||
"messages",
|
||||
"model",
|
||||
"reasoning",
|
||||
"stream",
|
||||
"temperature",
|
||||
"toolChoice", "tool_choice",
|
||||
"tools",
|
||||
)
|
||||
|
||||
private fun Map<String, Any>.findValue(vararg keys: String): Any? {
|
||||
keys.forEach { key ->
|
||||
if (containsKey(key)) {
|
||||
|
|
@ -95,23 +141,6 @@ private fun Any?.asBoolean(): Boolean? = when (this) {
|
|||
else -> null
|
||||
}
|
||||
|
||||
private fun Any?.asJsonElementMap(): Map<String, JsonElement>? {
|
||||
if (this !is Map<*, *>) return null
|
||||
|
||||
return buildMap {
|
||||
this@asJsonElementMap.forEach { (key, value) ->
|
||||
key?.toString()
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { put(it, value.toJsonElement()) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun Any?.asString(): String? = when (this) {
|
||||
is String -> this
|
||||
else -> null
|
||||
}
|
||||
|
||||
private fun Any?.asStopList(): List<String>? = asStringList(allowEmpty = false)
|
||||
|
||||
private fun Any?.asStringList(allowEmpty: Boolean = false): List<String>? = when (this) {
|
||||
|
|
@ -155,6 +184,47 @@ private fun Any?.asToolChoice(): ToolChoice? = when (this) {
|
|||
else -> null
|
||||
}
|
||||
|
||||
private fun Any?.asReasoningConfig(): ReasoningConfig? {
|
||||
val value = this as? Map<*, *> ?: return null
|
||||
val effort = value["effort"].asReasoningEffort()
|
||||
val summary = value["summary"].asReasoningSummary()
|
||||
|
||||
if (effort == null && summary == null) {
|
||||
return null
|
||||
}
|
||||
|
||||
return ReasoningConfig(
|
||||
effort = effort,
|
||||
summary = summary
|
||||
)
|
||||
}
|
||||
|
||||
private fun Any?.asReasoningEffort(): ReasoningEffort? = when (this) {
|
||||
is ReasoningEffort -> this
|
||||
is String -> when (trim().lowercase()) {
|
||||
"none" -> ReasoningEffort.NONE
|
||||
"minimal" -> ReasoningEffort.MINIMAL
|
||||
"low" -> ReasoningEffort.LOW
|
||||
"medium" -> ReasoningEffort.MEDIUM
|
||||
"high" -> ReasoningEffort.HIGH
|
||||
else -> null
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
|
||||
private fun Any?.asReasoningSummary(): ReasoningSummary? = when (this) {
|
||||
is ReasoningSummary -> this
|
||||
is String -> when (trim().lowercase()) {
|
||||
"auto" -> ReasoningSummary.AUTO
|
||||
"concise" -> ReasoningSummary.CONCISE
|
||||
"detailed" -> ReasoningSummary.DETAILED
|
||||
else -> null
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
|
||||
private fun Any?.toJsonElement(): JsonElement = when (this) {
|
||||
null -> JsonNull
|
||||
is JsonElement -> this
|
||||
|
|
|
|||
|
|
@ -180,11 +180,8 @@ public class ProxyAILLMClient(
|
|||
}
|
||||
}
|
||||
|
||||
override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse {
|
||||
val payload = normalizeSsePayload(data)
|
||||
?: return ProxyAIChatCompletionStreamResponse()
|
||||
return json.decodeFromString(payload)
|
||||
}
|
||||
override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse =
|
||||
json.decodeFromString(data)
|
||||
|
||||
override fun decodeResponse(data: String): ProxyAIChatCompletionResponse =
|
||||
json.decodeFromString(data)
|
||||
|
|
|
|||
|
|
@ -5,23 +5,13 @@ import ai.koog.http.client.KoogHttpClientException
|
|||
import ai.koog.prompt.dsl.ModerationResult
|
||||
import ai.koog.prompt.dsl.Prompt
|
||||
import ai.koog.prompt.executor.clients.LLMClient
|
||||
import ai.koog.prompt.executor.clients.anthropic.AnthropicParams
|
||||
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicThinking
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIResponsesParams
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.ReasoningEffort
|
||||
import ai.koog.prompt.executor.clients.openai.models.ReasoningConfig
|
||||
import ai.koog.prompt.executor.clients.openai.models.ReasoningSummary
|
||||
import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
|
||||
import ai.koog.prompt.executor.model.PromptExecutor
|
||||
import ai.koog.prompt.llm.LLMProvider
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.params.LLMParams
|
||||
import ai.koog.prompt.streaming.StreamFrame
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import java.io.IOException
|
||||
import java.net.SocketTimeoutException
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.TimeoutCancellationException
|
||||
import kotlinx.coroutines.delay
|
||||
|
|
@ -29,6 +19,8 @@ import kotlinx.coroutines.flow.Flow
|
|||
import kotlinx.coroutines.flow.catch
|
||||
import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import java.io.IOException
|
||||
import java.net.SocketTimeoutException
|
||||
import kotlin.math.pow
|
||||
import kotlin.random.Random
|
||||
import kotlin.time.Duration
|
||||
|
|
@ -44,8 +36,6 @@ class RetryingPromptExecutor(
|
|||
) : PromptExecutor {
|
||||
|
||||
companion object {
|
||||
private const val ANTHROPIC_MIN_THINKING_BUDGET = 1_024
|
||||
private const val ANTHROPIC_DEFAULT_THINKING_BUDGET = 2_048
|
||||
private val logger = KotlinLogging.logger { }
|
||||
private val RETRYABLE_HTTP_STATUS_CODES = setOf(408, 409, 425, 429, 500, 502, 503, 504)
|
||||
|
||||
|
|
@ -88,7 +78,7 @@ class RetryingPromptExecutor(
|
|||
private fun Throwable.hasTimeoutMessage(): Boolean {
|
||||
val message = message ?: return false
|
||||
return message.contains("timed out", ignoreCase = true)
|
||||
|| message.contains("timeout", ignoreCase = true)
|
||||
|| message.contains("timeout", ignoreCase = true)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,13 +90,12 @@ class RetryingPromptExecutor(
|
|||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): Flow<StreamFrame> {
|
||||
val refinedPrompt = prompt.withReasoningParams(model)
|
||||
var hasReceivedData: Boolean
|
||||
|
||||
fun createStream(attemptNum: Int): Flow<StreamFrame> {
|
||||
hasReceivedData = false
|
||||
|
||||
return delegate.executeStreaming(refinedPrompt, model, tools)
|
||||
return delegate.executeStreaming(prompt, model, tools)
|
||||
.onEach { hasReceivedData = true }
|
||||
.catch { error ->
|
||||
val retryable = isRetryableFailure(error)
|
||||
|
|
@ -187,14 +176,13 @@ class RetryingPromptExecutor(
|
|||
model: LLModel,
|
||||
tools: List<ToolDescriptor>
|
||||
): List<Message.Response> {
|
||||
val promptWithReasoning = prompt.withReasoningParams(model)
|
||||
var attempt = 1
|
||||
var delay = retryPolicy.initialDelay
|
||||
var lastError: Throwable? = null
|
||||
|
||||
while (attempt <= retryPolicy.maxAttempts) {
|
||||
try {
|
||||
return delegate.execute(promptWithReasoning, model, tools)
|
||||
return delegate.execute(prompt, model, tools)
|
||||
} catch (t: Throwable) {
|
||||
lastError = t
|
||||
val retryable = isRetryableFailure(t)
|
||||
|
|
@ -218,66 +206,4 @@ class RetryingPromptExecutor(
|
|||
override fun close() {
|
||||
(delegate as? AutoCloseable)?.close()
|
||||
}
|
||||
|
||||
private fun Prompt.withReasoningParams(model: LLModel): Prompt {
|
||||
val params = when (model.provider) {
|
||||
LLMProvider.OpenAI -> params.withOpenAIReasoning()
|
||||
LLMProvider.Anthropic -> params.withAnthropicReasoning()
|
||||
else -> params
|
||||
}
|
||||
return withParams(params)
|
||||
}
|
||||
|
||||
private fun LLMParams.withOpenAIReasoning(): LLMParams {
|
||||
val base = when (this) {
|
||||
is OpenAIResponsesParams -> this
|
||||
else -> OpenAIResponsesParams(
|
||||
temperature = temperature,
|
||||
maxTokens = maxTokens,
|
||||
numberOfChoices = numberOfChoices,
|
||||
speculation = speculation,
|
||||
schema = schema,
|
||||
toolChoice = toolChoice,
|
||||
user = user,
|
||||
additionalProperties = additionalProperties
|
||||
)
|
||||
}
|
||||
|
||||
val reasoning = base.reasoning ?: ReasoningConfig(
|
||||
effort = ReasoningEffort.MEDIUM,
|
||||
summary = ReasoningSummary.AUTO
|
||||
)
|
||||
return base.copy(reasoning = reasoning)
|
||||
}
|
||||
|
||||
private fun LLMParams.withAnthropicReasoning(): LLMParams {
|
||||
val base = when (this) {
|
||||
is AnthropicParams -> this
|
||||
else -> AnthropicParams(
|
||||
temperature = temperature,
|
||||
maxTokens = maxTokens,
|
||||
numberOfChoices = numberOfChoices,
|
||||
speculation = speculation,
|
||||
schema = schema,
|
||||
toolChoice = toolChoice,
|
||||
user = user,
|
||||
additionalProperties = additionalProperties
|
||||
)
|
||||
}
|
||||
|
||||
if (base.thinking != null) return base
|
||||
|
||||
val thinkingBudget = resolveAnthropicThinkingBudget(base.maxTokens) ?: return base
|
||||
return base.copy(thinking = AnthropicThinking.Enabled(budgetTokens = thinkingBudget))
|
||||
}
|
||||
|
||||
private fun resolveAnthropicThinkingBudget(maxTokens: Int?): Int? {
|
||||
val limit = maxTokens ?: ANTHROPIC_DEFAULT_THINKING_BUDGET
|
||||
if (limit <= ANTHROPIC_MIN_THINKING_BUDGET) {
|
||||
return null
|
||||
}
|
||||
return (limit / 2)
|
||||
.coerceAtLeast(ANTHROPIC_MIN_THINKING_BUDGET)
|
||||
.coerceAtMost(ANTHROPIC_DEFAULT_THINKING_BUDGET)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
internal fun normalizeSsePayload(rawData: String): String? {
|
||||
val normalized = rawData
|
||||
.replace("\r\n", "\n")
|
||||
.replace('\r', '\n')
|
||||
.trim()
|
||||
if (normalized.isEmpty()) {
|
||||
return null
|
||||
}
|
||||
|
||||
val dataLines = normalized.lineSequence()
|
||||
.mapNotNull { line ->
|
||||
val markerIndex = line.indexOf("data:")
|
||||
if (markerIndex >= 0) {
|
||||
line.substring(markerIndex + "data:".length).trimStart()
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
.filter { it.isNotEmpty() }
|
||||
.toList()
|
||||
|
||||
val payload = if (dataLines.isNotEmpty()) {
|
||||
dataLines.joinToString("\n")
|
||||
} else {
|
||||
normalized
|
||||
}
|
||||
|
||||
return payload.takeUnless { it.isBlank() || it == "[DONE]" }
|
||||
}
|
||||
|
|
@ -200,7 +200,7 @@ private suspend fun AIAgentLLMWriteSession.requestResponses(
|
|||
}
|
||||
val appendableResponses = appendableResponses(responses, model.provider)
|
||||
appendableResponses.forEach(appendResponse)
|
||||
return if (stream) responses else appendableResponses
|
||||
return responses
|
||||
}
|
||||
|
||||
private suspend fun AIAgentLLMWriteSession.requestAndPublish(
|
||||
|
|
|
|||
|
|
@ -8,19 +8,21 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents
|
|||
import ai.koog.agents.features.tokenizer.feature.MessageTokenizer
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.tokenizer.Tokenizer
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import ee.carlrobert.codegpt.agent.MessageWithContext
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStream
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
|
||||
import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION
|
||||
import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig
|
||||
import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider
|
||||
import ee.carlrobert.codegpt.mcp.McpTool
|
||||
import ee.carlrobert.codegpt.mcp.McpToolAliasResolver
|
||||
import ee.carlrobert.codegpt.mcp.McpToolCallHandler
|
||||
import ee.carlrobert.codegpt.settings.models.ModelSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import ee.carlrobert.codegpt.util.ReasoningFrameTextAdapter
|
||||
import kotlinx.coroutines.*
|
||||
import java.util.*
|
||||
|
|
@ -49,7 +51,7 @@ internal object AgentCompletionRunner : CompletionRunner {
|
|||
val jobRef = AtomicReference<Job?>()
|
||||
val messageBuilder = StringBuilder()
|
||||
val project = request.callParameters.project!!
|
||||
val stream = shouldStreamAgentToolLoop(request)
|
||||
val stream = shouldStreamAgentToolLoop(project, request)
|
||||
val toolCallHandler = project.let { McpToolCallHandler.getInstance(it) }
|
||||
val toolRegistry = createChatToolRegistry(
|
||||
callParameters = request.callParameters,
|
||||
|
|
@ -174,13 +176,23 @@ internal object AgentCompletionRunner : CompletionRunner {
|
|||
}
|
||||
|
||||
internal fun shouldStreamAgentToolLoop(
|
||||
project: Project,
|
||||
request: CompletionRunnerRequest.Chat,
|
||||
): Boolean {
|
||||
val provider = request.serviceType
|
||||
return when (provider) {
|
||||
ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(
|
||||
request.callParameters.featureType
|
||||
)
|
||||
ServiceType.CUSTOM_OPENAI -> {
|
||||
val selectedServiceId = service<ModelSettings>()
|
||||
.getModelSelectionForFeature(request.callParameters.featureType)
|
||||
.serviceId
|
||||
service<CustomServicesSettings>()
|
||||
.state.services
|
||||
.firstOrNull { it.id == selectedServiceId }
|
||||
?.chatCompletionSettings
|
||||
?.shouldStream()
|
||||
?: false
|
||||
}
|
||||
|
||||
ServiceType.GOOGLE -> false
|
||||
else -> true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,19 +3,24 @@ package ee.carlrobert.codegpt.completions
|
|||
import ai.koog.agents.core.tools.ToolDescriptor
|
||||
import ai.koog.prompt.dsl.Prompt
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.executor.model.PromptExecutor
|
||||
import ai.koog.prompt.llm.LLMCapability
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ee.carlrobert.codegpt.agent.AgentFactory
|
||||
import ee.carlrobert.codegpt.agent.clients.CustomOpenAILLMClient
|
||||
import ee.carlrobert.codegpt.agent.clients.HttpClientProvider
|
||||
import ee.carlrobert.codegpt.agent.clients.RetryingPromptExecutor
|
||||
import ee.carlrobert.codegpt.completions.factory.ResponsesApiUtil
|
||||
import ee.carlrobert.codegpt.settings.models.ModelSelection
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import javax.swing.JPanel
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
object CompletionRequestService {
|
||||
|
||||
|
|
@ -84,34 +89,53 @@ object CompletionRequestService {
|
|||
modelId: String?,
|
||||
eventListener: CompletionStreamEventListener
|
||||
): CancellableRequest {
|
||||
val client = CustomOpenAILLMClient.fromSettingsState(
|
||||
apiKey.orEmpty(),
|
||||
settings,
|
||||
HttpClientProvider.createHttpClient()
|
||||
)
|
||||
val retryPolicy = RetryingPromptExecutor.RetryPolicy(
|
||||
maxAttempts = 2,
|
||||
initialDelay = 1.seconds,
|
||||
maxDelay = 4.seconds,
|
||||
backoffMultiplier = 2.0,
|
||||
jitterFactor = 0.1
|
||||
)
|
||||
val request = CompletionRunnerRequest.Streaming(
|
||||
executor = RetryingPromptExecutor.fromClient(client, retryPolicy, null),
|
||||
model = LLModel(
|
||||
id = modelId?.takeIf { it.isNotBlank() } ?: "gpt-4.1-mini",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
val model = LLModel(
|
||||
id = modelId?.takeIf { it.isNotBlank() } ?: "gpt-4.1-mini",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = listOf(
|
||||
if (ResponsesApiUtil.isResponsesApiUrl(settings.url)) {
|
||||
LLMCapability.OpenAIEndpoint.Responses
|
||||
} else {
|
||||
LLMCapability.OpenAIEndpoint.Completions
|
||||
}
|
||||
),
|
||||
prompt = prompt("custom-service-test-connection") {
|
||||
user("Test connection")
|
||||
},
|
||||
eventListener = eventListener,
|
||||
mode = StreamingMode.SINGLE_RESPONSE,
|
||||
cancellationResultBuilder = { StringBuilder() }
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
)
|
||||
return CompletionRunnerFactory.create(request).run(request)
|
||||
val testPrompt = prompt("custom-service-test-connection") {
|
||||
user("Test connection")
|
||||
}
|
||||
|
||||
val job = scope.launch {
|
||||
val client = CustomOpenAILLMClient.fromSettingsState(
|
||||
apiKey.orEmpty(),
|
||||
settings,
|
||||
HttpClientProvider.createHttpClient()
|
||||
)
|
||||
val messageBuilder = StringBuilder()
|
||||
eventListener.onOpen()
|
||||
try {
|
||||
val responses = client.execute(testPrompt, model, emptyList())
|
||||
val text = CompletionTextExtractor.extract(responses)
|
||||
if (text.isNotBlank()) {
|
||||
messageBuilder.append(text)
|
||||
eventListener.onMessage(text)
|
||||
}
|
||||
eventListener.onComplete(StringBuilder(messageBuilder))
|
||||
} catch (_: CancellationException) {
|
||||
eventListener.onCancelled(StringBuilder(messageBuilder))
|
||||
} catch (exception: Throwable) {
|
||||
eventListener.onError(
|
||||
CompletionError(exception.message ?: "Failed to complete request"),
|
||||
exception
|
||||
)
|
||||
} finally {
|
||||
runCatching { client.close() }
|
||||
scope.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
return CancellableRequest { job.cancel() }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
|
|||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
|
||||
import java.net.URI
|
||||
import javax.swing.Icon
|
||||
|
||||
interface ModelProvider {
|
||||
|
|
@ -703,6 +704,16 @@ private class CustomOpenAIModelProvider : ModelProvider {
|
|||
FeatureType.LOOKUP,
|
||||
)
|
||||
|
||||
private fun isResponsesApiUrl(url: String?): Boolean {
|
||||
if (url.isNullOrBlank()) return false
|
||||
return try {
|
||||
val path = URI(url.trim()).path?.trimEnd('/') ?: return false
|
||||
path.endsWith("/responses")
|
||||
} catch (_: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildModels(
|
||||
modelExtractor: (CustomServiceSettingsState) -> String?
|
||||
): List<ModelSelection> {
|
||||
|
|
@ -713,13 +724,30 @@ private class CustomOpenAIModelProvider : ModelProvider {
|
|||
val modelName =
|
||||
modelExtractor(svc)?.takeIf { it.isNotBlank() } ?: return@mapNotNull null
|
||||
val displayName = formatCustomModelDisplayName(serviceName, modelName)
|
||||
val additionalOpenAICapability =
|
||||
if (isResponsesApiUrl(svc.chatCompletionSettings.url)) {
|
||||
LLMCapability.OpenAIEndpoint.Responses
|
||||
} else {
|
||||
LLMCapability.OpenAIEndpoint.Completions
|
||||
}
|
||||
|
||||
ModelSelection(
|
||||
provider = serviceType,
|
||||
llmModel = virtualModel(
|
||||
CustomOpenAILLMClient.CustomOpenAI,
|
||||
modelName,
|
||||
OPENAI_CAPABILITIES
|
||||
listOf(
|
||||
LLMCapability.Temperature,
|
||||
LLMCapability.Schema.JSON.Basic,
|
||||
LLMCapability.Schema.JSON.Standard,
|
||||
LLMCapability.Speculation,
|
||||
LLMCapability.Tools,
|
||||
LLMCapability.ToolChoice,
|
||||
LLMCapability.Vision.Image,
|
||||
LLMCapability.Document,
|
||||
LLMCapability.Completion,
|
||||
LLMCapability.MultipleChoices,
|
||||
) + additionalOpenAICapability
|
||||
),
|
||||
displayName = displayName,
|
||||
serviceId = serviceId,
|
||||
|
|
|
|||
|
|
@ -164,18 +164,14 @@ class CustomServiceForm(
|
|||
headers = template.chatCompletionTemplate.headers
|
||||
body = template.chatCompletionTemplate.body
|
||||
}
|
||||
if (template.codeCompletionTemplate != null) {
|
||||
template.codeCompletionTemplate?.let {
|
||||
codeCompletionsForm.run {
|
||||
url = template.codeCompletionTemplate.url
|
||||
headers = template.codeCompletionTemplate.headers
|
||||
body = template.codeCompletionTemplate.body
|
||||
url = it.url
|
||||
headers = it.headers
|
||||
body = it.body
|
||||
parseResponseAsChatCompletions =
|
||||
template.codeCompletionTemplate.parseResponseAsChatCompletions
|
||||
it.parseResponseAsChatCompletions
|
||||
}
|
||||
tabbedPane.setEnabledAt(1, true)
|
||||
} else {
|
||||
tabbedPane.selectedIndex = 0
|
||||
tabbedPane.setEnabledAt(1, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,127 +5,109 @@ enum class CustomServiceChatCompletionTemplate(
|
|||
val headers: MutableMap<String, String>,
|
||||
val body: MutableMap<String, Any>
|
||||
) {
|
||||
ANYSCALE(
|
||||
"https://api.endpoints.anyscale.com/v1/chat/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
)
|
||||
),
|
||||
AZURE(
|
||||
"https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version=2023-05-15",
|
||||
getDefaultHeaders("api-key", "\$CUSTOM_SERVICE_API_KEY"),
|
||||
getDefaultBodyParams(emptyMap())
|
||||
"https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version=2024-10-21",
|
||||
getDefaultHeaders("api-key", $$"$CUSTOM_SERVICE_API_KEY"),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
)
|
||||
),
|
||||
DEEP_INFRA(
|
||||
"https://api.deepinfra.com/v1/openai/chat/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "meta-llama/Llama-2-70b-chat-hf",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "deepseek-ai/DeepSeek-V3.2",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
FIREWORKS(
|
||||
"https://api.fireworks.ai/inference/v1/chat/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "accounts/fireworks/models/deepseek-r1-basic",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "accounts/fireworks/models/deepseek-v3p1",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
GROQ(
|
||||
"https://api.groq.com/openai/v1/chat/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "codellama-34b",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "openai/gpt-oss-20b",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
OPENAI(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "gpt-5",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
)
|
||||
),
|
||||
PERPLEXITY(
|
||||
"https://api.perplexity.ai/chat/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "codellama",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
)
|
||||
),
|
||||
TOGETHER(
|
||||
"https://api.together.xyz/v1/chat/completions",
|
||||
getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
)
|
||||
),
|
||||
OLLAMA(
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
getDefaultHeaders(),
|
||||
getDefaultBodyParams(mapOf("model" to "codellama"))
|
||||
),
|
||||
LLAMA_CPP(
|
||||
"http://localhost:8080/v1/chat/completions",
|
||||
getDefaultHeaders(),
|
||||
getDefaultBodyParams(emptyMap())
|
||||
),
|
||||
MISTRAL_AI(
|
||||
"https://api.mistral.ai/v1/chat/completions",
|
||||
getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "open-mistral-7b",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
)
|
||||
),
|
||||
OPEN_ROUTER(
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
getDefaultHeaders(
|
||||
mapOf(
|
||||
"Authorization" to "Bearer \$CUSTOM_SERVICE_API_KEY",
|
||||
"HTTP-Referer" to "https://tryproxy.io",
|
||||
"X-Title" to "ProxyAI"
|
||||
)
|
||||
),
|
||||
getDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "meta-llama/llama-3.1-8b-instruct:free",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
getDefaultHeaders("Authorization", $$"Bearer $CUSTOM_SERVICE_API_KEY"),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "gpt-4.1",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
),
|
||||
OPENAI_RESPONSES(
|
||||
"https://api.openai.com/v1/responses",
|
||||
getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
|
||||
getResponsesApiDefaultBodyParams(
|
||||
mapOf(
|
||||
"model" to "gpt-4.1",
|
||||
"max_output_tokens" to 8192
|
||||
)
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "gpt-5.3-codex",
|
||||
"max_output_tokens" to 32_000
|
||||
)
|
||||
);
|
||||
),
|
||||
TOGETHER(
|
||||
"https://api.together.xyz/v1/chat/completions",
|
||||
getDefaultHeaders("Authorization", $$"Bearer $CUSTOM_SERVICE_API_KEY"),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "zai-org/GLM-5",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
OLLAMA(
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
getDefaultHeaders(),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "gpt-oss:20b",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
LLAMA_CPP(
|
||||
"http://localhost:8080/v1/chat/completions",
|
||||
getDefaultHeaders(),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "gpt-oss:20b",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
MISTRAL_AI(
|
||||
"https://api.mistral.ai/v1/chat/completions",
|
||||
getDefaultHeaders("Authorization", $$"Bearer $CUSTOM_SERVICE_API_KEY"),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "mistral-large-2512",
|
||||
"max_tokens" to 32_000
|
||||
)
|
||||
),
|
||||
OPENROUTER(
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
getDefaultHeaders(
|
||||
mapOf(
|
||||
"Authorization" to $$"Bearer $CUSTOM_SERVICE_API_KEY",
|
||||
"HTTP-Referer" to "https://tryproxy.io",
|
||||
"X-OpenRouter-Title" to "ProxyAI"
|
||||
)
|
||||
),
|
||||
mutableMapOf(
|
||||
"stream" to true,
|
||||
"model" to "moonshotai/kimi-k2.5",
|
||||
"max_tokens" to 8192
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
private fun getDefaultHeadersWithAuthentication(): MutableMap<String, String> {
|
||||
|
|
@ -148,23 +130,3 @@ private fun getDefaultHeaders(additionalHeaders: Map<String, String>): MutableMa
|
|||
defaultHeaders.putAll(additionalHeaders)
|
||||
return defaultHeaders
|
||||
}
|
||||
|
||||
private fun getDefaultBodyParams(additionalParams: Map<String, Any>): MutableMap<String, Any> {
|
||||
val defaultParams = mutableMapOf<String, Any>(
|
||||
"stream" to true,
|
||||
"messages" to "\$OPENAI_MESSAGES",
|
||||
"temperature" to 0.1
|
||||
)
|
||||
defaultParams.putAll(additionalParams)
|
||||
return defaultParams
|
||||
}
|
||||
|
||||
private fun getResponsesApiDefaultBodyParams(additionalParams: Map<String, Any>): MutableMap<String, Any> {
|
||||
val defaultParams = mutableMapOf<String, Any>(
|
||||
"stream" to true,
|
||||
"input" to "\$OPENAI_MESSAGES",
|
||||
"temperature" to 0.1
|
||||
)
|
||||
defaultParams.putAll(additionalParams)
|
||||
return defaultParams
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,24 +7,24 @@ enum class CustomServiceCodeCompletionTemplate(
|
|||
val parseResponseAsChatCompletions: Boolean = false
|
||||
) {
|
||||
ANYSCALE(
|
||||
"https://api.endpoints.anyscale.com/v1/completions",
|
||||
"https://{your-service-endpoint}/v1/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(mapOf("model" to "codellama/CodeLlama-70b-Instruct-hf"))
|
||||
getDefaultBodyParams(mapOf("model" to "{your-model-id}"))
|
||||
),
|
||||
AZURE(
|
||||
"https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/completions?api-version=2023-05-15",
|
||||
"https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/completions?api-version=2024-10-21",
|
||||
getDefaultHeaders("api-key", "\$CUSTOM_SERVICE_API_KEY"),
|
||||
getDefaultBodyParams(emptyMap())
|
||||
),
|
||||
DEEP_INFRA(
|
||||
"https://api.deepinfra.com/v1/inference/codellama/CodeLlama-70b-Instruct-hf",
|
||||
"https://api.deepinfra.com/v1/openai/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
mutableMapOf("input" to "\$FIM_PROMPT")
|
||||
getDefaultBodyParams(mapOf("model" to "deepseek-ai/DeepSeek-V3.2"))
|
||||
),
|
||||
FIREWORKS(
|
||||
"https://api.fireworks.ai/inference/v1/completions",
|
||||
getDefaultHeadersWithAuthentication(),
|
||||
getDefaultBodyParams(mapOf("model" to "accounts/fireworks/models/qwen2p5-coder-32b-instruct"))
|
||||
getDefaultBodyParams(mapOf("model" to "accounts/fireworks/models/kimi-k2-instruct-0905"))
|
||||
),
|
||||
OPENAI(
|
||||
"https://api.openai.com/v1/completions",
|
||||
|
|
@ -45,7 +45,7 @@ enum class CustomServiceCodeCompletionTemplate(
|
|||
"stream" to true,
|
||||
"prompt" to "\$PREFIX",
|
||||
"suffix" to "\$SUFFIX",
|
||||
"model" to "codestral-latest",
|
||||
"model" to "codestral-2508",
|
||||
"temperature" to 0.7,
|
||||
"max_tokens" to 1024
|
||||
),
|
||||
|
|
@ -54,7 +54,7 @@ enum class CustomServiceCodeCompletionTemplate(
|
|||
TOGETHER(
|
||||
"https://api.together.xyz/v1/completions",
|
||||
getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
|
||||
getDefaultBodyParams(mapOf("model" to "codellama/CodeLlama-70b-hf"))
|
||||
getDefaultBodyParams(mapOf("model" to "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8"))
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -84,4 +84,4 @@ private fun getDefaultBodyParams(additionalParams: Map<String, Any>): MutableMap
|
|||
)
|
||||
defaultParams.putAll(additionalParams)
|
||||
return defaultParams
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,12 +6,6 @@ enum class CustomServiceTemplate(
|
|||
val chatCompletionTemplate: CustomServiceChatCompletionTemplate,
|
||||
val codeCompletionTemplate: CustomServiceCodeCompletionTemplate? = null
|
||||
) {
|
||||
ANYSCALE(
|
||||
"Anyscale",
|
||||
"https://docs.endpoints.anyscale.com/",
|
||||
CustomServiceChatCompletionTemplate.ANYSCALE,
|
||||
CustomServiceCodeCompletionTemplate.ANYSCALE,
|
||||
),
|
||||
AZURE(
|
||||
"Azure OpenAI",
|
||||
"https://learn.microsoft.com/en-us/azure/ai-services/openai/reference",
|
||||
|
|
@ -26,25 +20,26 @@ enum class CustomServiceTemplate(
|
|||
),
|
||||
FIREWORKS(
|
||||
"Fireworks",
|
||||
"https://readme.fireworks.ai/reference/createchatcompletion",
|
||||
"https://docs.fireworks.ai/api-reference/post-chatcompletions",
|
||||
CustomServiceChatCompletionTemplate.FIREWORKS,
|
||||
CustomServiceCodeCompletionTemplate.FIREWORKS
|
||||
),
|
||||
GROQ(
|
||||
"Groq",
|
||||
"https://docs.api.groq.com/md/openai.oas.html",
|
||||
"https://console.groq.com/docs/openai",
|
||||
CustomServiceChatCompletionTemplate.GROQ
|
||||
),
|
||||
OPENAI(
|
||||
"OpenAI",
|
||||
"https://platform.openai.com/docs/api-reference/chat",
|
||||
"OpenAI (Chat Completions API)",
|
||||
"https://platform.openai.com/docs/api-reference/chat/create",
|
||||
CustomServiceChatCompletionTemplate.OPENAI,
|
||||
CustomServiceCodeCompletionTemplate.OPENAI
|
||||
),
|
||||
PERPLEXITY(
|
||||
"Perplexity AI",
|
||||
"https://docs.perplexity.ai/reference/post_chat_completions",
|
||||
CustomServiceChatCompletionTemplate.PERPLEXITY
|
||||
OPENAI_RESPONSES(
|
||||
"OpenAI (Responses API)",
|
||||
"https://platform.openai.com/docs/api-reference/responses/create",
|
||||
CustomServiceChatCompletionTemplate.OPENAI_RESPONSES,
|
||||
CustomServiceCodeCompletionTemplate.OPENAI
|
||||
),
|
||||
TOGETHER(
|
||||
"Together AI",
|
||||
|
|
@ -54,32 +49,27 @@ enum class CustomServiceTemplate(
|
|||
),
|
||||
OLLAMA(
|
||||
"Ollama",
|
||||
"https://github.com/ollama/ollama/blob/main/docs/openai.md",
|
||||
"https://docs.ollama.com/api/openai-compatibility",
|
||||
CustomServiceChatCompletionTemplate.OLLAMA
|
||||
),
|
||||
LLAMA_CPP(
|
||||
"LLaMA C/C++",
|
||||
"https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md",
|
||||
"llama.cpp Server",
|
||||
"https://github.com/ggml-org/llama.cpp/blob/master/tools/server/README.md",
|
||||
CustomServiceChatCompletionTemplate.LLAMA_CPP
|
||||
),
|
||||
MISTRAL_AI(
|
||||
"Mistral AI",
|
||||
"https://docs.mistral.ai/getting-started/quickstart",
|
||||
"https://docs.mistral.ai/capabilities/completion/usage",
|
||||
CustomServiceChatCompletionTemplate.MISTRAL_AI,
|
||||
CustomServiceCodeCompletionTemplate.MISTRAL_AI
|
||||
),
|
||||
OPEN_ROUTER(
|
||||
"OpenRouter",
|
||||
"https://openrouter.ai/docs#quick-start",
|
||||
CustomServiceChatCompletionTemplate.OPEN_ROUTER
|
||||
),
|
||||
OPENAI_RESPONSES(
|
||||
"OpenAI (Responses API)",
|
||||
"https://platform.openai.com/docs/api-reference/responses",
|
||||
CustomServiceChatCompletionTemplate.OPENAI_RESPONSES
|
||||
"https://openrouter.ai/docs/quickstart",
|
||||
CustomServiceChatCompletionTemplate.OPENROUTER
|
||||
);
|
||||
|
||||
override fun toString(): String {
|
||||
return providerName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,11 +17,13 @@ class CompleteMessageParser : MessageParser {
|
|||
Pattern.compile("<<<<<<< SEARCH\\n(.*?)(?:\\n=======\\n(.*?))?$", Pattern.DOTALL)
|
||||
|
||||
private const val THINK_OPEN_TAG = "<think>"
|
||||
private const val THINK_CLOSE_TAG = "</think>\n\n"
|
||||
private const val THINK_CLOSE_TAG = "</think>"
|
||||
private const val CODE_HEADER_GROUP_INDEX = 2
|
||||
private const val CODE_CONTENT_GROUP_INDEX = 3
|
||||
private const val SEARCH_CONTENT_GROUP_INDEX = 1
|
||||
private const val REPLACE_CONTENT_GROUP_INDEX = 2
|
||||
private val THINKING_BLOCK_PATTERN =
|
||||
Regex("""^<think>(.*?)</think>\s*""", setOf(RegexOption.DOT_MATCHES_ALL))
|
||||
|
||||
private val TOLERANT_SEARCH_START =
|
||||
Regex("""^\s*<{3,}(\s*SEARCH.*)?$""", RegexOption.IGNORE_CASE)
|
||||
|
|
@ -54,15 +56,10 @@ class CompleteMessageParser : MessageParser {
|
|||
private fun extractThoughtIfPresent(input: String): String {
|
||||
extractedThought = null
|
||||
|
||||
if (!input.startsWith(THINK_OPEN_TAG)) {
|
||||
return input
|
||||
}
|
||||
|
||||
val closeTagIndex = input.indexOf(THINK_CLOSE_TAG)
|
||||
return if (closeTagIndex != -1) {
|
||||
val thoughtStartIndex = THINK_OPEN_TAG.length
|
||||
extractedThought = input.substring(thoughtStartIndex, closeTagIndex).trim()
|
||||
input.substring(closeTagIndex + THINK_CLOSE_TAG.length)
|
||||
val match = THINKING_BLOCK_PATTERN.find(input)
|
||||
return if (match != null) {
|
||||
extractedThought = match.groupValues[1].trim()
|
||||
input.substring(match.range.last + 1)
|
||||
} else {
|
||||
input
|
||||
}
|
||||
|
|
|
|||
|
|
@ -466,7 +466,7 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
"item",
|
||||
jsonMap(
|
||||
e("type", "function_call"),
|
||||
e("id", "fc_1"),
|
||||
e("id", callId),
|
||||
e("call_id", callId),
|
||||
e("name", toolName),
|
||||
e("arguments", "")
|
||||
|
|
@ -477,12 +477,28 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
),
|
||||
jsonMapResponse(
|
||||
e("type", "response.function_call_arguments.delta"),
|
||||
e("item_id", "fc_1"),
|
||||
e("item_id", callId),
|
||||
e("output_index", 0),
|
||||
e("delta", arguments),
|
||||
e("call_id", callId),
|
||||
e("sequence_number", 2)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("type", "response.output_item.done"),
|
||||
e(
|
||||
"item",
|
||||
jsonMap(
|
||||
e("type", "function_call"),
|
||||
e("id", callId),
|
||||
e("call_id", callId),
|
||||
e("name", toolName),
|
||||
e("arguments", arguments),
|
||||
e("status", "completed")
|
||||
)
|
||||
),
|
||||
e("output_index", 0),
|
||||
e("sequence_number", 3)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("type", "response.completed"),
|
||||
e(
|
||||
|
|
@ -497,7 +513,7 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
jsonArray(
|
||||
jsonMap(
|
||||
e("type", "function_call"),
|
||||
e("id", "fc_1"),
|
||||
e("id", callId),
|
||||
e("call_id", callId),
|
||||
e("name", toolName),
|
||||
e("arguments", arguments),
|
||||
|
|
@ -510,7 +526,7 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
e("text", jsonMap())
|
||||
)
|
||||
),
|
||||
e("sequence_number", 3)
|
||||
e("sequence_number", 4)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
|
@ -826,6 +842,8 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
jsonArray(
|
||||
jsonMap(
|
||||
e("finishReason", "stop"),
|
||||
e("finish_reason", "stop"),
|
||||
e("index", 0),
|
||||
e(
|
||||
"message",
|
||||
jsonMap(
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ package ee.carlrobert.codegpt.agent
|
|||
import ai.koog.agents.core.agent.AIAgent
|
||||
import ai.koog.agents.core.agent.AIAgentService
|
||||
import ai.koog.agents.core.agent.config.AIAgentConfig
|
||||
import ai.koog.agents.core.tools.ToolDescriptor
|
||||
import ai.koog.agents.core.tools.ToolRegistry
|
||||
import ai.koog.agents.snapshot.feature.AgentCheckpointData
|
||||
import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider
|
||||
import ai.koog.prompt.dsl.ModerationResult
|
||||
import ai.koog.prompt.dsl.Prompt
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.executor.model.PromptExecutor
|
||||
|
|
@ -327,7 +329,7 @@ private object NoopPromptExecutor : PromptExecutor {
|
|||
override suspend fun execute(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ai.koog.agents.core.tools.ToolDescriptor>
|
||||
tools: List<ToolDescriptor>
|
||||
): List<Message.Response> {
|
||||
throw UnsupportedOperationException("NoopPromptExecutor should not be used in tests")
|
||||
}
|
||||
|
|
@ -335,7 +337,7 @@ private object NoopPromptExecutor : PromptExecutor {
|
|||
override fun executeStreaming(
|
||||
prompt: Prompt,
|
||||
model: LLModel,
|
||||
tools: List<ai.koog.agents.core.tools.ToolDescriptor>
|
||||
tools: List<ToolDescriptor>
|
||||
): Flow<StreamFrame> {
|
||||
throw UnsupportedOperationException("NoopPromptExecutor should not be used in tests")
|
||||
}
|
||||
|
|
@ -343,7 +345,7 @@ private object NoopPromptExecutor : PromptExecutor {
|
|||
override suspend fun moderate(
|
||||
prompt: Prompt,
|
||||
model: LLModel
|
||||
): ai.koog.prompt.dsl.ModerationResult {
|
||||
): ModerationResult {
|
||||
throw UnsupportedOperationException("NoopPromptExecutor should not be used in tests")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,219 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.Content
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ai.koog.prompt.params.LLMParams
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.buildJsonArray
|
||||
import kotlinx.serialization.json.buildJsonObject
|
||||
import kotlinx.serialization.json.boolean
|
||||
import kotlinx.serialization.json.jsonArray
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import kotlinx.serialization.json.jsonPrimitive
|
||||
import kotlinx.serialization.json.put
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class CustomOpenAIResponsesApiSerializationTest {
|
||||
|
||||
private val json = Json {
|
||||
ignoreUnknownKeys = true
|
||||
encodeDefaults = true
|
||||
explicitNulls = false
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api request should encode tools with top level name`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val serializeMethod = client.javaClass.getDeclaredMethod(
|
||||
"serializeProviderChatRequest",
|
||||
List::class.java,
|
||||
LLModel::class.java,
|
||||
List::class.java,
|
||||
OpenAIToolChoice::class.java,
|
||||
LLMParams::class.java,
|
||||
Boolean::class.javaPrimitiveType
|
||||
)
|
||||
serializeMethod.isAccessible = true
|
||||
|
||||
val payload = serializeMethod.invoke(
|
||||
client,
|
||||
listOf(OpenAIMessage.User(content = Content.Text("hello"))),
|
||||
LLModel(
|
||||
id = "gpt-test",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
),
|
||||
listOf(
|
||||
OpenAITool(
|
||||
OpenAIToolFunction(
|
||||
name = "Diagnostics",
|
||||
description = "Read IDE diagnostics",
|
||||
parameters = buildJsonObject {
|
||||
put("type", "object")
|
||||
put("properties", buildJsonObject {})
|
||||
put("required", buildJsonArray {})
|
||||
},
|
||||
strict = true
|
||||
)
|
||||
)
|
||||
),
|
||||
OpenAIToolChoice.Function(OpenAIToolChoice.FunctionName("Diagnostics")),
|
||||
CustomOpenAIParams(),
|
||||
true
|
||||
) as String
|
||||
|
||||
val request = json.parseToJsonElement(payload).jsonObject
|
||||
val tool = request.getValue("tools").jsonArray.single().jsonObject
|
||||
val toolChoice = request.getValue("tool_choice").jsonObject
|
||||
|
||||
assertThat(tool.getValue("type").jsonPrimitive.content).isEqualTo("function")
|
||||
assertThat(tool.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics")
|
||||
assertThat(tool.getValue("description").jsonPrimitive.content).isEqualTo("Read IDE diagnostics")
|
||||
assertThat(tool.getValue("strict").jsonPrimitive.boolean).isTrue()
|
||||
assertThat(tool).doesNotContainKey("function")
|
||||
assertThat(toolChoice.getValue("type").jsonPrimitive.content).isEqualTo("function")
|
||||
assertThat(toolChoice.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api request should encode mode tool choice as lowercase string`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val serializeMethod = client.javaClass.getDeclaredMethod(
|
||||
"serializeProviderChatRequest",
|
||||
List::class.java,
|
||||
LLModel::class.java,
|
||||
List::class.java,
|
||||
OpenAIToolChoice::class.java,
|
||||
LLMParams::class.java,
|
||||
Boolean::class.javaPrimitiveType
|
||||
)
|
||||
serializeMethod.isAccessible = true
|
||||
|
||||
val payload = serializeMethod.invoke(
|
||||
client,
|
||||
listOf(OpenAIMessage.User(content = Content.Text("hello"))),
|
||||
LLModel(
|
||||
id = "gpt-test",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
),
|
||||
emptyList<OpenAITool>(),
|
||||
openAIToolChoiceMode("required"),
|
||||
CustomOpenAIParams(),
|
||||
true
|
||||
) as String
|
||||
|
||||
val request = json.parseToJsonElement(payload).jsonObject
|
||||
|
||||
assertThat(request.getValue("tool_choice").jsonPrimitive.content).isEqualTo("required")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api request should encode agent tool history as input items`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val serializeMethod = client.javaClass.getDeclaredMethod(
|
||||
"serializeProviderChatRequest",
|
||||
List::class.java,
|
||||
LLModel::class.java,
|
||||
List::class.java,
|
||||
OpenAIToolChoice::class.java,
|
||||
LLMParams::class.java,
|
||||
Boolean::class.javaPrimitiveType
|
||||
)
|
||||
serializeMethod.isAccessible = true
|
||||
|
||||
val payload = serializeMethod.invoke(
|
||||
client,
|
||||
listOf(
|
||||
OpenAIMessage.System(content = Content.Text("System instructions")),
|
||||
OpenAIMessage.User(content = Content.Text("Read the fixture")),
|
||||
OpenAIMessage.Assistant(
|
||||
content = Content.Text("Calling Read"),
|
||||
toolCalls = listOf(
|
||||
OpenAIToolCall(
|
||||
"call_read",
|
||||
OpenAIFunction("Read", "\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\"")
|
||||
)
|
||||
)
|
||||
),
|
||||
OpenAIMessage.Tool(
|
||||
content = Content.Text("1\tfixture contents"),
|
||||
toolCallId = "call_read"
|
||||
)
|
||||
),
|
||||
LLModel(
|
||||
id = "gpt-test",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
),
|
||||
emptyList<OpenAITool>(),
|
||||
null,
|
||||
CustomOpenAIParams(),
|
||||
true
|
||||
) as String
|
||||
|
||||
val request = json.parseToJsonElement(payload).jsonObject
|
||||
val input = request.getValue("input").jsonArray
|
||||
|
||||
assertThat(input.map { it.jsonObject.getValue("type").jsonPrimitive.content }).containsExactly(
|
||||
"message",
|
||||
"message",
|
||||
"message",
|
||||
"function_call",
|
||||
"function_call_output"
|
||||
)
|
||||
assertThat(input[0].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("developer")
|
||||
assertThat(input[1].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("user")
|
||||
assertThat(input[2].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("assistant")
|
||||
assertThat(
|
||||
input[2].jsonObject.getValue("content").jsonArray.single().jsonObject.getValue("text").jsonPrimitive.content
|
||||
).isEqualTo("Calling Read")
|
||||
assertThat(input[3].jsonObject.getValue("name").jsonPrimitive.content).isEqualTo("Read")
|
||||
assertThat(input[3].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read")
|
||||
assertThat(input[3].jsonObject.getValue("arguments").jsonPrimitive.content)
|
||||
.isEqualTo("""{"file_path":"/tmp/fixture.txt"}""")
|
||||
assertThat(input[4].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read")
|
||||
assertThat(input[4].jsonObject.getValue("output").jsonPrimitive.content)
|
||||
.isEqualTo("1\tfixture contents")
|
||||
}
|
||||
|
||||
private fun openAIToolChoiceMode(value: String): OpenAIToolChoice {
|
||||
val modeClass = Class.forName(
|
||||
"ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice\$Mode"
|
||||
)
|
||||
val boxMethod = modeClass.getDeclaredMethod("box-impl", String::class.java)
|
||||
return boxMethod.invoke(null, value) as OpenAIToolChoice
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.llm.LLMCapability
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
import testsupport.http.ResponseEntity
|
||||
import testsupport.http.exchange.BasicHttpExchange
|
||||
import testsupport.json.JSONUtil.e
|
||||
import testsupport.json.JSONUtil.jsonArray
|
||||
import testsupport.json.JSONUtil.jsonMap
|
||||
import testsupport.json.JSONUtil.jsonMapResponse
|
||||
|
||||
class CustomOpenAIResponsesSerializationIntegrationTest : IntegrationTest() {

    /**
     * Verifies that arbitrary entries placed in the custom-service body map are
     * serialized as top-level fields of the `/v1/responses` request payload rather
     * than being nested under an `additional_properties` wrapper object.
     */
    fun testResponsesRequestsFlattenAdditionalBodyParameters() {
        runBlocking {
            // Custom service pointing at the Responses API endpoint, carrying two
            // non-standard body parameters alongside the required ones.
            val serviceState = CustomServiceChatCompletionSettingsState().apply {
                url = System.getProperty("customOpenAI.baseUrl") + "/v1/responses"
                headers.clear()
                body.clear()
                body["model"] = "custom-responses-model"
                body["input"] = "\$OPENAI_MESSAGES"
                body["custom_config"] = mapOf("tier" to "gold")
                body["extra_flag"] = "enabled"
            }
            val llmClient = CustomOpenAILLMClient.fromSettingsState("TEST_API_KEY", serviceState)
            val responsesModel = LLModel(
                id = "custom-responses-model",
                provider = CustomOpenAILLMClient.CustomOpenAI,
                capabilities = listOf(LLMCapability.OpenAIEndpoint.Responses),
                contextLength = 128_000,
                maxOutputTokens = 4_096
            )
            val testPrompt = prompt("custom-openai-responses-serialization") {
                user("Test flattened params")
            }

            expectCustomOpenAI(BasicHttpExchange { request ->
                // The extra body entries must surface as top-level request fields.
                assertThat(request.uri.path).isEqualTo("/v1/responses")
                assertThat(request.method).isEqualTo("POST")
                assertThat(request.body)
                    .containsEntry("extra_flag", "enabled")
                    .containsKey("custom_config")
                    .doesNotContainKey("additional_properties")
                assertThat(request.body["custom_config"])
                    .isEqualTo(mapOf("tier" to "gold"))
                ResponseEntity(
                    openAiResponsesResponse(
                        model = "custom-responses-model",
                        text = "Hello with flattened params"
                    )
                )
            })

            val executionResult = llmClient.execute(testPrompt, responsesModel, emptyList())

            assertThat(executionResult)
                .singleElement()
                .extracting("content")
                .isEqualTo("Hello with flattened params")
        }
    }

    /**
     * Builds a minimal, well-formed Responses API payload whose single assistant
     * message contains one `output_text` part with the given [text].
     */
    private fun openAiResponsesResponse(
        model: String,
        text: String
    ): String = jsonMapResponse(
        e("id", "resp-openai-test"),
        e("object", "response"),
        e("created_at", 1),
        e("model", model),
        e(
            "output",
            jsonArray(
                jsonMap(
                    e("type", "message"),
                    e("id", "msg_1"),
                    e("role", "assistant"),
                    e("status", "completed"),
                    e(
                        "content",
                        jsonArray(
                            jsonMap(
                                e("type", "output_text"),
                                e("text", text),
                                e("annotations", jsonArray())
                            )
                        )
                    )
                )
            )
        ),
        e("parallel_tool_calls", true),
        e("status", "completed"),
        e("text", jsonMap()),
        e(
            "usage",
            jsonMap(
                e("input_tokens", 1),
                e("input_tokens_details", jsonMap("cached_tokens", 0)),
                e("output_tokens", 1),
                e("output_tokens_details", jsonMap("reasoning_tokens", 0)),
                e("total_tokens", 2)
            )
        )
    )
}
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.streaming.StreamFrame
|
||||
import ai.koog.prompt.streaming.toMessageResponses
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.asFlow
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class StreamingPayloadNormalizerTest {

    @Test
    fun `should unwrap sse data payload`() {
        // A plain "data:" frame yields its JSON payload.
        assertThat(normalizeSsePayload("data: {\"id\":\"chunk-1\"}"))
            .isEqualTo("{\"id\":\"chunk-1\"}")
    }

    @Test
    fun `should ignore done sentinel`() {
        // The terminal [DONE] marker carries no payload.
        assertThat(normalizeSsePayload("data: [DONE]")).isNull()
    }

    @Test
    fun `should unwrap multiline sse event frames`() {
        // "event:" line followed by a "data:" line on separate lines.
        assertThat(
            normalizeSsePayload("event: response.created\ndata: {\"type\":\"response.created\"}\n")
        ).isEqualTo("{\"type\":\"response.created\"}")
    }

    @Test
    fun `should unwrap compact sse frames where event and data share a line`() {
        // "event:" and "data:" squashed onto a single line.
        assertThat(
            normalizeSsePayload("event: response.created data: {\"type\":\"response.created\"}")
        ).isEqualTo("{\"type\":\"response.created\"}")
    }

    @Test
    fun `custom openai client should decode sse wrapped chat completion chunks`() {
        val settingsState = CustomServiceChatCompletionSettingsState().apply {
            url = "https://example.com/v1/chat/completions"
            body["stream"] = true
        }
        val openAiClient = CustomOpenAILLMClient.fromSettingsState("test-key", settingsState)
        // decodeStreamingResponse is private; reach it via reflection.
        val decode = openAiClient.javaClass
            .getDeclaredMethod("decodeStreamingResponse", String::class.java)
            .apply { isAccessible = true }

        val decoded = decode.invoke(
            openAiClient,
            "data: {\"choices\":[],\"created\":0,\"id\":\"chunk-1\",\"model\":\"test-model\"}"
        ) as CustomOpenAIChatCompletionStreamResponse

        assertThat(decoded.id).isEqualTo("chunk-1")
        assertThat(decoded.model).isEqualTo("test-model")
        assertThat(decoded.choices).isEmpty()
    }

    @Test
    fun `responses api streaming tool call should collapse into one named tool call`() {
        val settingsState = CustomServiceChatCompletionSettingsState().apply {
            url = "https://example.com/v1/responses"
            body.clear()
            body["stream"] = true
            body["input"] = "\$OPENAI_MESSAGES"
        }
        val openAiClient = CustomOpenAILLMClient.fromSettingsState("test-key", settingsState)
        // Both stages of the streaming pipeline are private; access via reflection.
        val decode = openAiClient.javaClass
            .getDeclaredMethod("decodeStreamingResponse", String::class.java)
            .apply { isAccessible = true }
        val process = openAiClient.javaClass
            .getDeclaredMethod("processStreamingResponse", Flow::class.java)
            .apply { isAccessible = true }

        // Three Responses API events: item added, arguments delta, completion.
        val streamEvents = listOf(
            """{"type":"response.output_item.added","item":{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":""},"output_index":0,"sequence_number":1}""",
            """{"type":"response.function_call_arguments.delta","item_id":"fc_1","output_index":0,"delta":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","call_id":"call_diag","sequence_number":2}""",
            """{"type":"response.completed","response":{"id":"resp-tool","object":"response","created_at":1,"model":"test-model","output":[{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","status":"completed"}],"parallel_tool_calls":true,"status":"completed","text":{}},"sequence_number":3}"""
        ).map { payload ->
            decode.invoke(openAiClient, payload) as CustomOpenAIChatCompletionStreamResponse
        }

        val messageResponses = runBlocking {
            @Suppress("UNCHECKED_CAST")
            val frameFlow = process.invoke(openAiClient, streamEvents.asFlow()) as Flow<StreamFrame>
            frameFlow.toList().toMessageResponses()
        }
        val collapsedToolCall = messageResponses.filterIsInstance<Message.Tool.Call>().single()

        // The fragments must collapse into a single named tool call with full arguments.
        assertThat(collapsedToolCall.id).isEqualTo("call_diag")
        assertThat(collapsedToolCall.tool).isEqualTo("Diagnostics")
        assertThat(collapsedToolCall.content).isEqualTo("""{"file_path":"/tmp/mainwindow.cpp","filter":"all"}""")
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue