From 1c37947bc210f55437def9cf83ea34811458759e Mon Sep 17 00:00:00 2001 From: Roman Gromov <35900229+sapphirepro@users.noreply.github.com> Date: Fri, 13 Mar 2026 20:35:31 +0100 Subject: [PATCH 1/3] Added sorting to Custom OpenAI models, so that it's better organized now --- .../service/custom/form/CustomServiceForm.kt | 32 ++++++++++++++- .../settings/models/ModelSettingsTest.kt | 40 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceForm.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceForm.kt index 280ce738..f809f3aa 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceForm.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceForm.kt @@ -327,6 +327,8 @@ class CustomServiceForm( .setPreferredSize(Dimension(220, 0)) .setAddAction { handleAddAction() } .setRemoveAction { handleRemoveAction() } + .setMoveUpAction { handleMoveAction(-1) } + .setMoveDownAction { handleMoveAction(1) } .setRemoveActionUpdater { formState.value.services.size > 1 } @@ -347,7 +349,6 @@ class CustomServiceForm( handleDuplicateAction() } }) - .disableUpDownActions() private fun handleRemoveAction() { val prevSelectedIndex = customProvidersJBList.selectedIndex @@ -373,6 +374,35 @@ class CustomServiceForm( lastSelectedIndex = -1 } + private fun handleMoveAction(offset: Int) { + val selectedIndex = customProvidersJBList.selectedIndex + if (selectedIndex == -1) { + return + } + + val targetIndex = selectedIndex + offset + if (targetIndex !in formState.value.services.indices) { + return + } + + if (lastSelectedIndex != -1 && lastSelectedIndex < formState.value.services.size) { + updateStateFromForm(lastSelectedIndex) + } + + val movedId = formState.value.services[selectedIndex].id + selectedServiceId = movedId + pendingSelectedId = movedId + + formState.update { state -> + val reordered = state.services.toMutableList() + val moved = reordered.removeAt(selectedIndex) + reordered.add(targetIndex, moved) + state.copy(services = reordered.toImmutableList()) + } + + lastSelectedIndex = -1 + } + private fun handleDuplicateAction() { if (lastSelectedIndex >= 0 && lastSelectedIndex < formState.value.services.size) { updateStateFromForm(lastSelectedIndex) diff --git a/src/test/kotlin/ee/carlrobert/codegpt/settings/models/ModelSettingsTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/settings/models/ModelSettingsTest.kt index 246822d3..10e365c7 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/settings/models/ModelSettingsTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/settings/models/ModelSettingsTest.kt @@ -6,6 +6,8 @@ import com.intellij.util.messages.MessageBusConnection import ee.carlrobert.codegpt.settings.service.FeatureType import ee.carlrobert.codegpt.settings.service.ModelChangeNotifier import ee.carlrobert.codegpt.settings.service.ServiceType +import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState +import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings import org.assertj.core.api.Assertions.assertThat import testsupport.IntegrationTest import java.util.concurrent.atomic.AtomicReference @@ -202,6 +204,30 @@ class ModelSettingsTest : IntegrationTest() { assertThat(result).isEqualTo(ServiceType.OPENAI) } + fun `test custom openai models keep configured order for chat and agent`() { + val customServicesSettings = service() + val first = createCustomService(name = "First", model = "chat-first") + val second = createCustomService(name = "Second", model = "chat-second") + + customServicesSettings.state.services.clear() + customServicesSettings.state.services.add(first) + customServicesSettings.state.services.add(second) + + assertThat(customModelNames(FeatureType.CHAT)) + .containsExactly("First (chat-first)", "Second (chat-second)") + assertThat(customModelNames(FeatureType.AGENT)) + .containsExactly("First (chat-first)", "Second (chat-second)") + + customServicesSettings.state.services.clear() + customServicesSettings.state.services.add(second) + customServicesSettings.state.services.add(first) + + assertThat(customModelNames(FeatureType.CHAT)) + .containsExactly("Second (chat-second)", "First (chat-first)") + assertThat(customModelNames(FeatureType.AGENT)) + .containsExactly("Second (chat-second)", "First (chat-first)") + } + fun `test migrateMissingProviderInformation updates missing providers`() { val state = ModelSettingsState() val detailsState = ModelDetailsState() @@ -227,4 +253,18 @@ class ModelSettingsTest : IntegrationTest() { assertThat(modelSettings.getStoredModelForFeature(FeatureType.CHAT)).isEqualTo("unknown-model") assertThat(modelSettings.getStoredProviderForFeature(FeatureType.CHAT)).isNull() } + + private fun customModelNames(featureType: FeatureType): List { + return modelSettings.getAvailableModels(featureType) + .filter { it.provider == ServiceType.CUSTOM_OPENAI } + .map { it.displayName } + } + + private fun createCustomService(name: String, model: String): CustomServiceSettingsState { + return CustomServiceSettingsState().apply { + this.name = name + chatCompletionSettings.body.clear() + chatCompletionSettings.body["model"] = model + } + } } From 23241ec98bbf95b289571a9b5f1a5f671c0ce91b Mon Sep 17 00:00:00 2001 From: Roman Gromov <35900229+sapphirepro@users.noreply.github.com> Date: Fri, 13 Mar 2026 23:10:18 +0100 Subject: [PATCH 2/3] Fixes Fixed diagnostics being broken + added diagnostics as tool for agent. Now it shows relevant files and can filter errors separately. Fixed Custom OpenAI providers being broken in Agent mode. None worked previously, was streaming bug issues for both --- .../carlrobert/codegpt/agent/AgentFactory.kt | 7 + .../carlrobert/codegpt/agent/ProxyAIAgent.kt | 21 +- .../carlrobert/codegpt/agent/SubagentTools.kt | 1 + .../ee/carlrobert/codegpt/agent/ToolSpecs.kt | 8 + .../agent/clients/CustomOpenAILLMClient.kt | 49 +++- .../clients/CustomOpenAIStreamingSupport.kt | 14 ++ .../codegpt/agent/clients/ProxyAILLMClient.kt | 7 +- .../clients/StreamingPayloadNormalizer.kt | 31 +++ .../codegpt/agent/tools/DiagnosticsTool.kt | 123 ++++++++++ .../completions/AgentCompletionRunner.kt | 20 +- .../diagnostics/ProjectDiagnosticsService.kt | 231 ++++++++++++++++++ .../codegpt/inlineedit/InlineEditInlay.kt | 15 +- .../codegpt/ui/textarea/SearchManager.kt | 6 +- .../ui/textarea/TagProcessorFactory.kt | 129 +--------- .../ui/textarea/header/tag/TagDetails.kt | 9 +- .../lookup/action/DiagnosticsActionItem.kt | 41 ---- .../action/DiagnosticsContextSupport.kt | 24 ++ .../action/DiagnosticsFilterActionItem.kt | 57 +++++ .../lookup/group/DiagnosticsGroupItem.kt | 27 ++ .../agent/AgentProviderIntegrationTest.kt | 77 +++++- .../codegpt/agent/DiagnosticsToolTest.kt | 26 ++ ...stomOpenAIResponsesApiSerializationTest.kt | 91 +++++++ .../clients/StreamingPayloadNormalizerTest.kt | 53 ++++ .../ui/textarea/DiagnosticsIntegrationTest.kt | 80 ++++++ 24 files changed, 937 insertions(+), 210 deletions(-) create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIStreamingSupport.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizer.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/agent/tools/DiagnosticsTool.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/diagnostics/ProjectDiagnosticsService.kt delete mode 100644 src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsActionItem.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsContextSupport.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsFilterActionItem.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/group/DiagnosticsGroupItem.kt create mode 100644 src/test/kotlin/ee/carlrobert/codegpt/agent/DiagnosticsToolTest.kt create mode 100644 src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt create mode 100644 src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt create mode 100644 src/test/kotlin/ee/carlrobert/codegpt/ui/textarea/DiagnosticsIntegrationTest.kt diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt index 1f60313b..a65b23bb 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentFactory.kt @@ -534,6 +534,13 @@ object AgentFactory { hookManager = hookManager, ) ) + if (SubagentTool.DIAGNOSTICS in selected) tool( + DiagnosticsTool( + project = project, + sessionId = sessionId, + hookManager = hookManager, + ) + ) if (SubagentTool.WEB_SEARCH in selected) tool( WebSearchTool( workingDirectory = project.basePath ?: System.getProperty("user.dir"), diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt index fd915050..1fff7523 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt @@ -23,6 +23,7 @@ import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import ee.carlrobert.codegpt.EncodingManager import ee.carlrobert.codegpt.agent.clients.shouldStream +import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider @@ -34,7 +35,6 @@ import ee.carlrobert.codegpt.settings.hooks.HookManager import ee.carlrobert.codegpt.settings.models.ModelSettings import ee.carlrobert.codegpt.settings.service.FeatureType import ee.carlrobert.codegpt.settings.service.ServiceType -import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest @@ -89,7 +89,7 @@ object ProxyAIAgent { val modelSelection = service().getModelSelectionForFeature(FeatureType.AGENT) val skills = project.service().listSkills() - val stream = shouldStreamAgentToolLoop(project, provider) + val stream = shouldStreamAgentToolLoop(provider) val projectInstructions = loadProjectInstructions(project.basePath) val executor = AgentFactory.createExecutor(provider, events) val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() } @@ -268,18 +268,10 @@ object ProxyAIAgent { } private fun shouldStreamAgentToolLoop( - project: Project, provider: ServiceType, ): Boolean { return when (provider) { - ServiceType.CUSTOM_OPENAI -> { - val selectedServiceId = - project.service().getStoredModelForFeature(FeatureType.AGENT) - project.service().state.services - .firstOrNull { it.id == selectedServiceId }?.chatCompletionSettings?.shouldStream() - ?: false - } - + ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(FeatureType.AGENT) ServiceType.GOOGLE -> false else -> true } @@ -323,6 +315,13 @@ object ProxyAIAgent { hookManager = hookManager, ) ) + tool( + DiagnosticsTool( + project = project, + sessionId = sessionId, + hookManager = hookManager, + ) + ) tool( WebSearchTool( workingDirectory = workingDirectory, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/SubagentTools.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/SubagentTools.kt index af405f00..f451338d 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/SubagentTools.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/SubagentTools.kt @@ -4,6 +4,7 @@ enum class SubagentTool(val id: String, val displayName: String, val isWrite: Bo READ("read", "Read", false), TODO_WRITE("todowrite", "TodoWrite", false), INTELLIJ_SEARCH("intellijsearch", "IntelliJSearch", false), + DIAGNOSTICS("diagnostics", "Diagnostics", false), WEB_SEARCH("websearch", "WebSearch", false), WEB_FETCH("webfetch", "WebFetch", false), MCP("MCP", "MCP", false), diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt index fad8028c..eb36695e 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt @@ -14,6 +14,7 @@ enum class ToolName(val id: String, val aliases: Set = emptySet()) { BASH_OUTPUT("BashOutput"), KILL_SHELL("KillShell"), INTELLIJ_SEARCH("IntelliJSearch"), + DIAGNOSTICS("Diagnostics"), WEB_SEARCH("WebSearch"), WEB_FETCH("WebFetch"), MCP("MCP"), @@ -94,6 +95,13 @@ object ToolSpecs { IntelliJSearchTool.Result.serializer() ) ) + register( + ToolSpec( + ToolName.DIAGNOSTICS, + DiagnosticsTool.Args.serializer(), + DiagnosticsTool.Result.serializer() + ) + ) register( ToolSpec( ToolName.WEB_SEARCH, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt index b7574151..7e00e5b3 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt @@ -193,7 +193,7 @@ class CustomOpenAILLMClient( } if (isResponsesApi) { - return serializeResponsesApiRequest(state, messages, model, tools) + return serializeResponsesApiRequest(state, messages, model, tools, toolChoice) } val customParams: CustomOpenAIParams = params.toCustomOpenAIParams(state) @@ -245,7 +245,8 @@ class CustomOpenAILLMClient( state: CustomServiceChatCompletionSettingsState, messages: List, model: LLModel, - tools: List? + tools: List?, + toolChoice: OpenAIToolChoice? ): String { val streamRequest = state.shouldStream() @@ -264,8 +265,9 @@ class CustomOpenAILLMClient( } put("model", JsonPrimitive(model.id)) if (!tools.isNullOrEmpty()) { - put("tools", json.encodeToJsonElement(ListSerializer(OpenAITool.serializer()), tools)) + put("tools", JsonArray(tools.map { it.toResponsesApiToolJson() })) } + toolChoice?.toResponsesApiToolChoiceJson()?.let { put("tool_choice", it) } }.toString() } @@ -418,10 +420,17 @@ class CustomOpenAILLMClient( } override fun decodeStreamingResponse(data: String): CustomOpenAIChatCompletionStreamResponse { + val payload = normalizeSsePayload(data) + ?: return CustomOpenAIChatCompletionStreamResponse( + choices = emptyList(), + created = 0, + id = "", + model = "" + ) if (!isResponsesApi) { - return json.decodeFromString(data) + return json.decodeFromString(payload) } - return adaptResponsesApiStreamEvent(data) + return adaptResponsesApiStreamEvent(payload) } override fun decodeResponse(data: String): CustomOpenAIChatCompletionResponse { @@ -784,6 +793,36 @@ internal fun renderCustomOpenAIPrompt(messages: List, json: Json) } } +internal fun OpenAITool.toResponsesApiToolJson(): JsonObject { + return buildJsonObject { + put("type", JsonPrimitive("function")) + put("name", JsonPrimitive(function.name)) + put( + "parameters", + function.parameters ?: buildJsonObject { + put("type", JsonPrimitive("object")) + putJsonObject("properties") {} + putJsonArray("required") {} + } + ) + function.strict?.let { put("strict", JsonPrimitive(it)) } + function.description?.takeIf { it.isNotBlank() }?.let { + put("description", JsonPrimitive(it)) + } + } +} + +internal fun OpenAIToolChoice.toResponsesApiToolChoiceJson(): JsonElement { + return when (this) { + is OpenAIToolChoice.Function -> buildJsonObject { + put("type", JsonPrimitive("function")) + put("name", JsonPrimitive(function.name)) + } + + else -> JsonPrimitive(toString()) + } +} + internal object CustomOpenAIChatCompletionRequestSerializer : CustomOpenAIAdditionalPropertiesFlatteningSerializer(CustomOpenAIChatCompletionRequest.serializer()) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIStreamingSupport.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIStreamingSupport.kt new file mode 100644 index 00000000..a710df3a --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIStreamingSupport.kt @@ -0,0 +1,14 @@ +package ee.carlrobert.codegpt.agent.clients + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.settings.service.FeatureType +import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings + +internal fun shouldStreamCustomOpenAI(featureType: FeatureType): Boolean { + return runCatching { + service() + .customServiceStateForFeatureType(featureType) + .chatCompletionSettings + .shouldStream() + }.getOrDefault(false) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/ProxyAILLMClient.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/ProxyAILLMClient.kt index 4b27250e..1e5837c1 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/ProxyAILLMClient.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/ProxyAILLMClient.kt @@ -180,8 +180,11 @@ public class ProxyAILLMClient( } } - override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse = - json.decodeFromString(data) + override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse { + val payload = normalizeSsePayload(data) + ?: return ProxyAIChatCompletionStreamResponse() + return json.decodeFromString(payload) + } override fun decodeResponse(data: String): ProxyAIChatCompletionResponse = json.decodeFromString(data) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizer.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizer.kt new file mode 100644 index 00000000..4ffd1216 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizer.kt @@ -0,0 +1,31 @@ +package ee.carlrobert.codegpt.agent.clients + +internal fun normalizeSsePayload(rawData: String): String? { + val normalized = rawData + .replace("\r\n", "\n") + .replace('\r', '\n') + .trim() + if (normalized.isEmpty()) { + return null + } + + val dataLines = normalized.lineSequence() + .mapNotNull { line -> + val markerIndex = line.indexOf("data:") + if (markerIndex >= 0) { + line.substring(markerIndex + "data:".length).trimStart() + } else { + null + } + } + .filter { it.isNotEmpty() } + .toList() + + val payload = if (dataLines.isNotEmpty()) { + dataLines.joinToString("\n") + } else { + normalized + } + + return payload.takeUnless { it.isBlank() || it == "[DONE]" } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/DiagnosticsTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/DiagnosticsTool.kt new file mode 100644 index 00000000..60969da5 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/DiagnosticsTool.kt @@ -0,0 +1,123 @@ +package ee.carlrobert.codegpt.agent.tools + +import ai.koog.agents.core.tools.annotations.LLMDescription +import com.intellij.openapi.components.service +import com.intellij.openapi.project.Project +import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter +import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService +import ee.carlrobert.codegpt.settings.ProxyAISettingsService +import ee.carlrobert.codegpt.settings.ToolPermissionPolicy +import ee.carlrobert.codegpt.settings.hooks.HookManager +import ee.carlrobert.codegpt.tokens.truncateToolResult +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +class DiagnosticsTool( + private val project: Project, + private val sessionId: String, + private val hookManager: HookManager, +) : BaseTool( + workingDirectory = project.basePath ?: System.getProperty("user.dir"), + argsSerializer = Args.serializer(), + resultSerializer = Result.serializer(), + name = "Diagnostics", + description = """ + Reads the IDE's current diagnostics for a specific file. + + Use this tool when you need compiler or inspection diagnostics for one file. + - The file_path parameter must be an absolute path. + - filter='errors_only' returns only errors. + - filter='all' returns errors, warnings, weak warnings, and info diagnostics. + - Results reflect diagnostics currently available in the IDE for that file. + """.trimIndent(), + argsClass = Args::class, + resultClass = Result::class, + hookManager = hookManager, + sessionId = sessionId, +) { + + @Serializable + data class Args( + @property:LLMDescription( + "The absolute path to the file to inspect. Must be an absolute path." + ) + @SerialName("file_path") + val filePath: String, + @property:LLMDescription( + "Diagnostics filter: 'errors_only' for errors only, or 'all' for all diagnostics." + ) + val filter: DiagnosticsFilter = DiagnosticsFilter.ERRORS_ONLY + ) + + @Serializable + data class Result( + @SerialName("file_path") + val filePath: String, + val filter: DiagnosticsFilter, + @SerialName("diagnostic_count") + val diagnosticCount: Int = 0, + val output: String = "", + val error: String? = null + ) + + override suspend fun doExecute(args: Args): Result { + val settingsService = project.service() + val decision = settingsService.evaluateToolPermission(this, args.filePath) + if (decision == ToolPermissionPolicy.Decision.DENY) { + return Result( + filePath = args.filePath, + filter = args.filter, + error = "Access denied by permissions.deny for Diagnostics" + ) + } + if (settingsService.hasAllowRulesForTool("Diagnostics") + && decision != ToolPermissionPolicy.Decision.ALLOW + ) { + return Result( + filePath = args.filePath, + filter = args.filter, + error = "Access denied by permissions.allow for Diagnostics" + ) + } + if (settingsService.isPathIgnored(args.filePath)) { + return Result( + filePath = args.filePath, + filter = args.filter, + error = "File not found: ${args.filePath}" + ) + } + + val diagnosticsService = project.service() + val virtualFile = diagnosticsService.findVirtualFile(args.filePath) + ?: return Result( + filePath = args.filePath, + filter = args.filter, + error = "File not found: ${args.filePath}" + ) + + val report = diagnosticsService.collect(virtualFile, args.filter) + return Result( + filePath = args.filePath, + filter = args.filter, + diagnosticCount = report.diagnosticCount, + output = report.content.ifBlank { args.filter.emptyMessage() }, + error = report.error + ) + } + + override fun createDeniedResult(originalArgs: Args, deniedReason: String): Result { + return Result( + filePath = originalArgs.filePath, + filter = originalArgs.filter, + error = deniedReason + ) + } + + override fun encodeResultToString(result: Result): String { + if (result.error != null) { + return "Failed to read diagnostics for '${result.filePath}': ${result.error}" + .truncateToolResult() + } + return result.output.truncateToolResult() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt index 8d05ccb4..0c6902cb 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt @@ -8,21 +8,19 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents import ai.koog.agents.features.tokenizer.feature.MessageTokenizer import ai.koog.prompt.dsl.prompt import ai.koog.prompt.tokenizer.Tokenizer -import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import ee.carlrobert.codegpt.EncodingManager import ee.carlrobert.codegpt.agent.AgentEvents import ee.carlrobert.codegpt.agent.MessageWithContext import ee.carlrobert.codegpt.agent.clients.shouldStream +import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider import ee.carlrobert.codegpt.mcp.McpTool import ee.carlrobert.codegpt.mcp.McpToolAliasResolver import ee.carlrobert.codegpt.mcp.McpToolCallHandler -import ee.carlrobert.codegpt.settings.models.ModelSettings import ee.carlrobert.codegpt.settings.service.ServiceType -import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings import ee.carlrobert.codegpt.util.ReasoningFrameTextAdapter import kotlinx.coroutines.* import java.util.* @@ -51,7 +49,7 @@ internal object AgentCompletionRunner : CompletionRunner { val jobRef = AtomicReference() val messageBuilder = StringBuilder() val project = request.callParameters.project!! - val stream = shouldStreamAgentToolLoop(project, request) + val stream = shouldStreamAgentToolLoop(request) val toolCallHandler = project.let { McpToolCallHandler.getInstance(it) } val toolRegistry = createChatToolRegistry( callParameters = request.callParameters, @@ -176,21 +174,13 @@ internal object AgentCompletionRunner : CompletionRunner { } internal fun shouldStreamAgentToolLoop( - project: Project, request: CompletionRunnerRequest.Chat, ): Boolean { val provider = request.serviceType return when (provider) { - ServiceType.CUSTOM_OPENAI -> { - val selectedServiceId = project.service() - .getStoredModelForFeature(request.callParameters.featureType) - project.service().state.services - .firstOrNull { it.id == selectedServiceId } - ?.chatCompletionSettings - ?.shouldStream() - ?: false - } - + ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI( + request.callParameters.featureType + ) ServiceType.GOOGLE -> false else -> true } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/diagnostics/ProjectDiagnosticsService.kt b/src/main/kotlin/ee/carlrobert/codegpt/diagnostics/ProjectDiagnosticsService.kt new file mode 100644 index 00000000..8652bb5b --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/diagnostics/ProjectDiagnosticsService.kt @@ -0,0 +1,231 @@ +package ee.carlrobert.codegpt.diagnostics + +import com.intellij.codeInsight.daemon.impl.DaemonCodeAnalyzerImpl +import com.intellij.codeInsight.daemon.impl.HighlightInfo +import com.intellij.lang.annotation.HighlightSeverity +import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.components.Service +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.project.DumbService +import com.intellij.openapi.project.Project +import com.intellij.openapi.util.text.StringUtil +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiManager +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +enum class DiagnosticsFilter(val displayName: String) { + @SerialName("errors_only") + ERRORS_ONLY("Errors only"), + + @SerialName("all") + ALL("All"); + + fun includes(severity: HighlightSeverity): Boolean { + return when (this) { + ERRORS_ONLY -> severity == HighlightSeverity.ERROR + ALL -> severity == HighlightSeverity.ERROR || + severity == HighlightSeverity.WARNING || + severity == HighlightSeverity.WEAK_WARNING || + severity == HighlightSeverity.INFORMATION + } + } + + fun minimumSeverity(): HighlightSeverity { + return when (this) { + ERRORS_ONLY -> HighlightSeverity.ERROR + ALL -> HighlightSeverity.INFORMATION + } + } + + fun emptyMessage(): String { + return when (this) { + ERRORS_ONLY -> "No errors found." + ALL -> "No diagnostics found." + } + } +} + +data class DiagnosticsReport( + val filePath: String, + val filter: DiagnosticsFilter, + val content: String = "", + val diagnosticCount: Int = 0, + val error: String? = null +) { + val hasDiagnostics: Boolean + get() = error == null && diagnosticCount > 0 +} + +@Service(Service.Level.PROJECT) +class ProjectDiagnosticsService( + private val project: Project +) { + + fun findVirtualFile(filePath: String): VirtualFile? { + val normalizedPath = filePath.replace('\\', '/') + val fileSystem = LocalFileSystem.getInstance() + return fileSystem.findFileByPath(normalizedPath) + ?: fileSystem.refreshAndFindFileByPath(normalizedPath) + } + + fun collect( + virtualFile: VirtualFile, + filter: DiagnosticsFilter = DiagnosticsFilter.ALL + ): DiagnosticsReport { + return try { + var result = DiagnosticsReport( + filePath = virtualFile.path, + filter = filter + ) + + ApplicationManager.getApplication().invokeAndWait { + result = ApplicationManager.getApplication().runWriteAction { + DumbService.getInstance(project).runReadActionInSmartMode { + val document = FileDocumentManager.getInstance().getDocument(virtualFile) + ?: return@runReadActionInSmartMode DiagnosticsReport( + filePath = virtualFile.path, + filter = filter, + error = "No document found for file." + ) + + PsiDocumentManager.getInstance(project).commitDocument(document) + + val psiFile = PsiManager.getInstance(project).findFile(virtualFile) + ?: return@runReadActionInSmartMode DiagnosticsReport( + filePath = virtualFile.path, + filter = filter, + error = "No PSI file found for: ${virtualFile.path}" + ) + + val rangeHighlights = DaemonCodeAnalyzerImpl.getHighlights( + document, + filter.minimumSeverity(), + project + ) + + val fileLevel = getFileLevelHighlights(psiFile) + val highlights = (rangeHighlights.asSequence() + fileLevel.asSequence()) + .filter { filter.includes(it.severity) } + .mapNotNull { highlight -> + extractMessage(highlight)?.let { message -> + DiagnosticEntry(highlight, message) + } + } + .distinctBy { Triple(it.message, it.highlight.startOffset, it.highlight.severity) } + .sortedWith( + compareBy( + { severityOrder(it.highlight.severity) }, + { it.highlight.startOffset.coerceAtLeast(0) } + ) + ) + .toList() + + if (highlights.isEmpty()) { + return@runReadActionInSmartMode DiagnosticsReport( + filePath = virtualFile.path, + filter = filter + ) + } + + val maxItems = 200 + val overflow = (highlights.size - maxItems).coerceAtLeast(0) + val shown = highlights.take(maxItems) + + val content = buildString { + append("File: ${virtualFile.name}\n") + append("Path: ${virtualFile.path}\n") + append("Filter: ${filter.displayName}\n\n") + + shown.forEach { entry -> + val info = entry.highlight + val startOffset = info.startOffset.coerceIn(0, document.textLength) + val lineColText = + if (info.startOffset >= 0 && document.textLength > 0) { + val line = document.getLineNumber(startOffset) + 1 + val col = + startOffset - document.getLineStartOffset(line - 1) + 1 + "line $line, col $col" + } else { + "file-level" + } + + append("- [${severityLabel(info.severity)}] $lineColText: ${entry.message}\n") + } + + if (overflow > 0) { + append("... ($overflow more not shown)\n") + } + } + + DiagnosticsReport( + filePath = virtualFile.path, + filter = filter, + content = content, + diagnosticCount = highlights.size + ) + } + } + } + + result + } catch (e: Exception) { + DiagnosticsReport( + filePath = virtualFile.path, + filter = filter, + error = "Error retrieving diagnostics: ${e.message}" + ) + } + } + + private fun getFileLevelHighlights(psiFile: com.intellij.psi.PsiFile): List { + return try { + val method = DaemonCodeAnalyzerImpl::class.java.methods.firstOrNull { + it.name == "getFileLevelHighlights" && it.parameterCount == 2 + } + if (method != null) { + @Suppress("UNCHECKED_CAST") + method.invoke(null, project, psiFile) as? List ?: emptyList() + } else { + emptyList() + } + } catch (_: Throwable) { + emptyList() + } + } + + private fun extractMessage(info: HighlightInfo): String? { + val rawMessage = info.description ?: info.toolTip ?: "" + return StringUtil.removeHtmlTags(rawMessage, false) + .trim() + .takeIf { it.isNotBlank() } + } + + private fun severityLabel(severity: HighlightSeverity): String { + return when (severity) { + HighlightSeverity.ERROR -> "ERROR" + HighlightSeverity.WARNING -> "WARNING" + HighlightSeverity.WEAK_WARNING -> "WEAK_WARNING" + HighlightSeverity.INFORMATION -> "INFO" + else -> severity.toString() + } + } + + private fun severityOrder(severity: HighlightSeverity): Int { + return when (severity) { + HighlightSeverity.ERROR -> 0 + HighlightSeverity.WARNING -> 1 + HighlightSeverity.WEAK_WARNING -> 2 + HighlightSeverity.INFORMATION -> 3 + else -> 4 + } + } + + private data class DiagnosticEntry( + val highlight: HighlightInfo, + val message: String + ) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/inlineedit/InlineEditInlay.kt b/src/main/kotlin/ee/carlrobert/codegpt/inlineedit/InlineEditInlay.kt index 03946785..4d050e45 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/inlineedit/InlineEditInlay.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/inlineedit/InlineEditInlay.kt @@ -795,13 +795,18 @@ class InlineEditInlay(private var editor: Editor) : Disposable { private fun collectDiagnosticsInfo(): String? { val tags: Set = tagManager.getTags() - val diagnosticsTag = - tags.firstOrNull { it.selected && it is DiagnosticsTagDetails } as? DiagnosticsTagDetails - ?: return null + val diagnosticsTags = tags + .filter { it.selected && it is DiagnosticsTagDetails } + .filterIsInstance() + if (diagnosticsTags.isEmpty()) { + return null + } - val processor = TagProcessorFactory.getProcessor(project, diagnosticsTag) val stringBuilder = StringBuilder() - processor.process(Message("", ""), stringBuilder) + diagnosticsTags.forEach { diagnosticsTag -> + val processor = TagProcessorFactory.getProcessor(project, diagnosticsTag) + processor.process(Message("", ""), stringBuilder) + } return stringBuilder.toString().takeIf { it.isNotBlank() } } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/SearchManager.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/SearchManager.kt index ab02bcd8..6cd7dd31 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/SearchManager.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/SearchManager.kt @@ -8,7 +8,6 @@ import ee.carlrobert.codegpt.settings.service.FeatureType import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager import ee.carlrobert.codegpt.ui.textarea.lookup.LookupActionItem import ee.carlrobert.codegpt.ui.textarea.lookup.LookupGroupItem -import ee.carlrobert.codegpt.ui.textarea.lookup.action.DiagnosticsActionItem import ee.carlrobert.codegpt.ui.textarea.lookup.action.ImageActionItem import ee.carlrobert.codegpt.ui.textarea.lookup.action.WebActionItem import ee.carlrobert.codegpt.ui.textarea.lookup.action.files.IncludeOpenFilesActionItem @@ -44,7 +43,7 @@ class SearchManager( FoldersGroupItem(project, tagManager), if (GitFeatureAvailability.isAvailable) GitGroupItem(project) else null, HistoryGroupItem(), - DiagnosticsActionItem(tagManager) + DiagnosticsGroupItem(tagManager) ).filter { it.enabled } private fun getAgentGroups() = listOfNotNull( @@ -52,6 +51,7 @@ class SearchManager( FoldersGroupItem(project, tagManager), if (GitFeatureAvailability.isAvailable) GitGroupItem(project) else null, MCPGroupItem(tagManager, FeatureType.AGENT), + DiagnosticsGroupItem(tagManager), ImageActionItem(project, tagManager) ).filter { it.enabled } @@ -62,7 +62,7 @@ class SearchManager( HistoryGroupItem(), PersonasGroupItem(tagManager), MCPGroupItem(tagManager, featureType ?: FeatureType.CHAT), - DiagnosticsActionItem(tagManager), + DiagnosticsGroupItem(tagManager), WebActionItem(tagManager), ImageActionItem(project, tagManager) ).filter { it.enabled } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/TagProcessorFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/TagProcessorFactory.kt index e71541ad..b8955fbe 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/TagProcessorFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/TagProcessorFactory.kt @@ -1,24 +1,16 @@ package ee.carlrobert.codegpt.ui.textarea -import com.intellij.codeInsight.daemon.impl.DaemonCodeAnalyzerImpl -import com.intellij.codeInsight.daemon.impl.HighlightInfo -import com.intellij.lang.annotation.HighlightSeverity -import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.application.runReadAction import com.intellij.openapi.components.service -import com.intellij.openapi.fileEditor.FileDocumentManager import com.intellij.openapi.progress.ProgressManager -import com.intellij.openapi.project.DumbService import com.intellij.openapi.project.Project -import com.intellij.openapi.util.text.StringUtil import com.intellij.openapi.vfs.VirtualFile -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiManager import ee.carlrobert.codegpt.EncodingManager import ee.carlrobert.codegpt.completions.CompletionRequestUtil import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.ConversationsState import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService import ee.carlrobert.codegpt.settings.ProxyAISettingsService import ee.carlrobert.codegpt.ui.textarea.header.tag.* import ee.carlrobert.codegpt.ui.textarea.lookup.action.HistoryActionItem @@ -270,118 +262,17 @@ class DiagnosticsTagProcessor( private val project: Project, private val tagDetails: DiagnosticsTagDetails, ) : TagProcessor { + private val diagnosticsService = project.service() + override fun process(message: Message, promptBuilder: StringBuilder) { + val diagnostics = diagnosticsService.collect(tagDetails.virtualFile, tagDetails.filter) + if (diagnostics.content.isBlank() && diagnostics.error == null) { + return + } + promptBuilder - .append("\n## Current File Problems\n") - .append(getDiagnosticsString(project, tagDetails.virtualFile)) + .append("\n## ${tagDetails.virtualFile.name} Problems (${tagDetails.filter.displayName})\n") + .append(diagnostics.error ?: diagnostics.content) .append("\n") } - - private fun getDiagnosticsString(project: Project, virtualFile: VirtualFile): String { - return try { - var result = "" - ApplicationManager.getApplication().invokeAndWait { - result = ApplicationManager.getApplication().runWriteAction { - DumbService.getInstance(project).runReadActionInSmartMode { - val document = FileDocumentManager.getInstance().getDocument(virtualFile) - ?: return@runReadActionInSmartMode "No document found for file" - - PsiDocumentManager.getInstance(project).commitDocument(document) - - val psiManager = PsiManager.getInstance(project) - val psiFile = psiManager.findFile(virtualFile) - ?: return@runReadActionInSmartMode "No PSI file found for: ${virtualFile.path}" - - val rangeHighlights = - DaemonCodeAnalyzerImpl.getHighlights( - document, - HighlightSeverity.WEAK_WARNING, - project - ) - // TODO: Find a better solution - val fileLevel: List = try { - val method = DaemonCodeAnalyzerImpl::class.java.methods.firstOrNull { - it.name == "getFileLevelHighlights" && it.parameterCount == 2 - } - if (method != null) { - @Suppress("UNCHECKED_CAST") - method.invoke(null, project, psiFile) as? List - ?: emptyList() - } else { - emptyList() - } - } catch (_: Throwable) { - emptyList() - } - - val highlights = (rangeHighlights.asSequence() + fileLevel.asSequence()) - .distinctBy { Triple(it.description, it.startOffset, it.severity) } - .sortedWith( - compareBy( - { severityOrder(it.severity) }, - { it.startOffset.coerceAtLeast(0) } - ) - ) - .toList() - - if (highlights.isEmpty()) { - return@runReadActionInSmartMode "" - } - - val maxItems = 200 - val overflow = (highlights.size - maxItems).coerceAtLeast(0) - val shown = highlights.take(maxItems) - - buildString { - append("File: ${virtualFile.name}\n") - append("Path: ${virtualFile.path}\n\n") - - shown.forEach { info -> - val startOffset = info.startOffset.coerceIn(0, document.textLength) - val lineColText = - if (info.startOffset >= 0 && document.textLength > 0) { - val line = document.getLineNumber(startOffset) + 1 - val col = - startOffset - document.getLineStartOffset(line - 1) + 1 - "line $line, col $col" - } else { - "file-level" - } - - val rawMessage = info.description ?: info.toolTip ?: "" - val message = StringUtil.removeHtmlTags(rawMessage, false).trim() - - val severityLabel = when (info.severity) { - HighlightSeverity.ERROR -> "ERROR" - HighlightSeverity.WARNING -> "WARNING" - HighlightSeverity.WEAK_WARNING -> "WEAK_WARNING" - HighlightSeverity.INFORMATION -> "INFO" - else -> info.severity.toString() - } - - append("- [$severityLabel] $lineColText: $message\n") - } - - if (overflow > 0) { - append("... ($overflow more not shown)\n") - } - } - } - } - } - result - } catch (e: Exception) { - "Error retrieving diagnostics: ${e.message}" - } - } - - private fun severityOrder(severity: HighlightSeverity): Int { - return when (severity) { - HighlightSeverity.ERROR -> 0 - HighlightSeverity.WARNING -> 1 - HighlightSeverity.WEAK_WARNING -> 2 - HighlightSeverity.INFORMATION -> 3 - else -> 4 - } - } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/header/tag/TagDetails.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/header/tag/TagDetails.kt index 5ae52c3f..a899dd32 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/header/tag/TagDetails.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/header/tag/TagDetails.kt @@ -5,6 +5,7 @@ import com.intellij.openapi.editor.SelectionModel import com.intellij.openapi.vfs.VirtualFile import com.intellij.ui.JBColor import ee.carlrobert.codegpt.Icons +import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter import ee.carlrobert.codegpt.mcp.ConnectionStatus import ee.carlrobert.codegpt.mcp.McpResource import ee.carlrobert.codegpt.mcp.McpTool @@ -304,7 +305,9 @@ class CodeAnalyzeTagDetails : TagDetails("Code Analyze", AllIcons.Actions.Depend override fun getTooltipText(): String? = null } -data class DiagnosticsTagDetails(val virtualFile: VirtualFile) : - TagDetails("${virtualFile.name} Problems", AllIcons.General.InspectionsEye) { - override fun getTooltipText(): String = virtualFile.path +data class DiagnosticsTagDetails( + val virtualFile: VirtualFile, + val filter: DiagnosticsFilter = DiagnosticsFilter.ALL +) : TagDetails("${virtualFile.name} Problems (${filter.displayName})", AllIcons.General.InspectionsEye) { + override fun getTooltipText(): String = "${virtualFile.path} (${filter.displayName})" } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsActionItem.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsActionItem.kt deleted file mode 100644 index 08b473c7..00000000 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsActionItem.kt +++ /dev/null @@ -1,41 +0,0 @@ -package ee.carlrobert.codegpt.ui.textarea.lookup.action - -import com.intellij.icons.AllIcons -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.VirtualFile -import ee.carlrobert.codegpt.ui.textarea.UserInputPanel -import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails -import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails -import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails -import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager -import ee.carlrobert.codegpt.util.EditorUtil - -class DiagnosticsActionItem( - private val tagManager: TagManager -) : AbstractLookupActionItem() { - - override val displayName: String = "Diagnostics" - override val icon = AllIcons.General.InspectionsEye - override val enabled: Boolean - get() = tagManager.getTags().none { it is DiagnosticsTagDetails } && - tagManager.getTags().any { it is FileTagDetails || it is EditorTagDetails } - - override fun execute(project: Project, userInputPanel: UserInputPanel) { - val virtualFile = findVirtualFile(project) - virtualFile?.let { file -> - userInputPanel.addTag(DiagnosticsTagDetails(file)) - } - } - - private fun findVirtualFile(project: Project): VirtualFile? { - val existingFile = tagManager.getTags() - .firstNotNullOfOrNull { tag -> - when (tag) { - is FileTagDetails -> tag.virtualFile - is EditorTagDetails -> tag.virtualFile - else -> null - } - } - return existingFile ?: EditorUtil.getSelectedEditor(project)?.virtualFile - } -} \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsContextSupport.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsContextSupport.kt new file mode 100644 index 00000000..e782de80 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsContextSupport.kt @@ -0,0 +1,24 @@ +package ee.carlrobert.codegpt.ui.textarea.lookup.action + +import com.intellij.openapi.vfs.VirtualFile +import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorSelectionTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.SelectionTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails + +internal fun selectedContextFiles(tags: Collection): List { + return tags.asSequence() + .filter { it.selected } + .mapNotNull { tag -> + when (tag) { + is FileTagDetails -> tag.virtualFile + is EditorTagDetails -> tag.virtualFile + is SelectionTagDetails -> tag.virtualFile + is EditorSelectionTagDetails -> tag.virtualFile + else -> null + } + } + .distinctBy { it.path } + .toList() +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsFilterActionItem.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsFilterActionItem.kt new file mode 100644 index 00000000..767f8f43 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/DiagnosticsFilterActionItem.kt @@ -0,0 +1,57 @@ +package ee.carlrobert.codegpt.ui.textarea.lookup.action + +import com.intellij.icons.AllIcons +import com.intellij.notification.NotificationType +import com.intellij.openapi.components.service +import com.intellij.openapi.project.Project +import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter +import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService +import ee.carlrobert.codegpt.ui.OverlayUtil +import ee.carlrobert.codegpt.ui.textarea.UserInputPanel +import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager + +class DiagnosticsFilterActionItem( + private val tagManager: TagManager, + private val filter: DiagnosticsFilter +) : AbstractLookupActionItem() { + + override val displayName: String = filter.displayName + override val icon = AllIcons.General.InspectionsEye + + override fun execute(project: Project, userInputPanel: UserInputPanel) { + val diagnosticsService = project.service() + val files = selectedContextFiles(userInputPanel.getSelectedTags()) + var matched = false + + files.forEach { virtualFile -> + val diagnostics = diagnosticsService.collect(virtualFile, filter) + if (!diagnostics.hasDiagnostics) { + return@forEach + } + matched = true + + val newTag = DiagnosticsTagDetails(virtualFile, filter) + val existing = tagManager.getTags() + .filterIsInstance() + .firstOrNull { it.virtualFile == virtualFile } + + when { + existing == null -> { + userInputPanel.addTag(newTag) + } + + existing != newTag -> { + tagManager.updateTag(existing, newTag) + } + } + } + + if (!matched) { + OverlayUtil.showNotification( + filter.emptyMessage().removeSuffix(".") + " in selected context files.", + NotificationType.INFORMATION + ) + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/group/DiagnosticsGroupItem.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/group/DiagnosticsGroupItem.kt new file mode 100644 index 00000000..4d816f77 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/group/DiagnosticsGroupItem.kt @@ -0,0 +1,27 @@ +package ee.carlrobert.codegpt.ui.textarea.lookup.group + +import com.intellij.icons.AllIcons +import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter +import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager +import ee.carlrobert.codegpt.ui.textarea.lookup.LookupActionItem +import ee.carlrobert.codegpt.ui.textarea.lookup.action.DiagnosticsFilterActionItem +import ee.carlrobert.codegpt.ui.textarea.lookup.action.selectedContextFiles + +class DiagnosticsGroupItem( + private val tagManager: TagManager +) : AbstractLookupGroupItem() { + + override val displayName: String = "Diagnostics" + override val icon = AllIcons.General.InspectionsEye + override val enabled: Boolean + get() = selectedContextFiles(tagManager.getTags()).isNotEmpty() + + override suspend fun getLookupItems(searchText: String): List { + return listOf( + DiagnosticsFilterActionItem(tagManager, DiagnosticsFilter.ERRORS_ONLY), + DiagnosticsFilterActionItem(tagManager, DiagnosticsFilter.ALL) + ).filter { + searchText.isEmpty() || it.displayName.contains(searchText, ignoreCase = true) + } + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt index 3c1b6063..c7b0a8e9 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider import com.intellij.openapi.components.service import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.* import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential +import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI import ee.carlrobert.codegpt.settings.models.ModelCatalog import ee.carlrobert.codegpt.settings.models.ModelSettings import ee.carlrobert.codegpt.settings.service.FeatureType @@ -174,6 +175,60 @@ class AgentProviderIntegrationTest : IntegrationTest() { assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI") } + fun testCustomOpenAIAgentStreamsWhenStoredSelectionUsesModelId() { + val customService = configureCustomOpenAIService(stream = true) + service().setModel( + FeatureType.AGENT, + "custom-agent-model", + ServiceType.CUSTOM_OPENAI + ) + assertThat(shouldStreamCustomOpenAI(FeatureType.AGENT)).isTrue() + expectCustomOpenAI(StreamHttpExchange { request -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.body["stream"]).isEqualTo(true) + assertThat(extractPromptText(request)).contains("Say hello from streamed Custom OpenAI") + chatCompletionChunks("custom-agent-model", "Hello from streamed Custom OpenAI") + }) + + val result = runAgent(ServiceType.CUSTOM_OPENAI, "Say hello from streamed Custom OpenAI") + + assertThat(customService.id).isNotBlank() + assertThat(result.output).isEqualTo("Hello from streamed Custom OpenAI") + assertThat(result.events.text.toString()).isEqualTo("Hello from streamed Custom OpenAI") + } + + fun testCustomOpenAIResponsesAgentStreamsWhenStoredSelectionUsesModelId() { + val customService = configureCustomOpenAIService( + path = "/v1/responses", + model = "custom-responses-model", + stream = true, + useResponsesApiBody = true + ) + service().setModel( + FeatureType.AGENT, + "custom-responses-model", + ServiceType.CUSTOM_OPENAI + ) + assertThat(shouldStreamCustomOpenAI(FeatureType.AGENT)).isTrue() + expectCustomOpenAI(StreamHttpExchange { request -> + assertThat(request.uri.path).isEqualTo("/v1/responses") + assertThat(request.method).isEqualTo("POST") + assertThat(request.body["stream"]).isEqualTo(true) + assertThat(extractPromptText(request)).contains("Say hello from Custom OpenAI Responses") + openAiResponsesChunks( + model = "custom-responses-model", + text = "Hello from Custom OpenAI Responses" + ) + }) + + val result = runAgent(ServiceType.CUSTOM_OPENAI, "Say hello from Custom OpenAI Responses") + + assertThat(customService.id).isNotBlank() + assertThat(result.output).isEqualTo("Hello from Custom OpenAI Responses") + assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI Responses") + } + private fun runAgent( provider: ServiceType, userMessage: String @@ -246,9 +301,10 @@ class AgentProviderIntegrationTest : IntegrationTest() { ) } - private fun openAiResponsesChunks(): List { - val model = "gpt-4.1-mini" - val text = "Hello from OpenAI" + private fun openAiResponsesChunks( + model: String = "gpt-4.1-mini", + text: String = "Hello from OpenAI" + ): List { val chunks = text.chunked(4) return chunks.mapIndexed { index, chunk -> jsonMapResponse( @@ -573,15 +629,24 @@ class AgentProviderIntegrationTest : IntegrationTest() { ) } - private fun configureCustomOpenAIService(): CustomServiceSettingsState { + private fun configureCustomOpenAIService( + path: String = "/v1/chat/completions", + model: String = "custom-agent-model", + stream: Boolean = false, + useResponsesApiBody: Boolean = false + ): CustomServiceSettingsState { val settings = service() val serviceState = CustomServiceSettingsState().apply { name = "Agent Test Custom Service" chatCompletionSettings.url = - System.getProperty("customOpenAI.baseUrl") + "/v1/chat/completions" + System.getProperty("customOpenAI.baseUrl") + path chatCompletionSettings.headers.clear() chatCompletionSettings.body.clear() - chatCompletionSettings.body["model"] = "custom-agent-model" + chatCompletionSettings.body["model"] = model + chatCompletionSettings.body["stream"] = stream + if (useResponsesApiBody) { + chatCompletionSettings.body["input"] = "\$OPENAI_MESSAGES" + } } settings.state.services.clear() settings.state.services.add(serviceState) diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/DiagnosticsToolTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/DiagnosticsToolTest.kt new file mode 100644 index 00000000..d7c91280 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/DiagnosticsToolTest.kt @@ -0,0 +1,26 @@ +package ee.carlrobert.codegpt.agent + +import ee.carlrobert.codegpt.agent.tools.DiagnosticsTool +import ee.carlrobert.codegpt.settings.hooks.HookManager +import kotlinx.coroutines.runBlocking +import org.assertj.core.api.Assertions.assertThat +import testsupport.IntegrationTest +import java.io.File + +class DiagnosticsToolTest : IntegrationTest() { + + fun `test diagnostics tool should be registered in tool specs`() { + assertThat(ToolSpecs.find("Diagnostics")).isNotNull() + } + + fun `test diagnostics tool should return missing file error`() { + val missingPath = File(project.basePath, "does-not-exist.txt").absolutePath + val tool = DiagnosticsTool(project, "test-session-id", HookManager(project)) + + val result = runBlocking { + tool.execute(DiagnosticsTool.Args(missingPath)) + } + + assertThat(result.error).isEqualTo("File not found: $missingPath") + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt new file mode 100644 index 00000000..93223fdd --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt @@ -0,0 +1,91 @@ +package ee.carlrobert.codegpt.agent.clients + +import ai.koog.prompt.executor.clients.openai.base.models.Content +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage +import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.params.LLMParams +import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.boolean +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.put +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class CustomOpenAIResponsesApiSerializationTest { + + private val json = Json { + ignoreUnknownKeys = true + encodeDefaults = true + explicitNulls = false + } + + @Test + fun `responses api request should encode tools with top level name`() { + val state = CustomServiceChatCompletionSettingsState().apply { + url = "https://example.com/v1/responses" + body.clear() + body["stream"] = true + body["input"] = "\$OPENAI_MESSAGES" + } + val client = CustomOpenAILLMClient.fromSettingsState("test-key", state) + val serializeMethod = client.javaClass.getDeclaredMethod( + "serializeProviderChatRequest", + List::class.java, + LLModel::class.java, + List::class.java, + OpenAIToolChoice::class.java, + LLMParams::class.java, + Boolean::class.javaPrimitiveType + ) + serializeMethod.isAccessible = true + + val payload = serializeMethod.invoke( + client, + listOf(OpenAIMessage.User(content = Content.Text("hello"))), + LLModel( + id = "gpt-test", + provider = CustomOpenAILLMClient.CustomOpenAI, + capabilities = emptyList(), + contextLength = 128_000, + maxOutputTokens = 4_096 + ), + listOf( + OpenAITool( + OpenAIToolFunction( + name = "Diagnostics", + description = "Read IDE diagnostics", + parameters = buildJsonObject { + put("type", "object") + put("properties", buildJsonObject {}) + put("required", buildJsonArray {}) + }, + strict = true + ) + ) + ), + OpenAIToolChoice.Function(OpenAIToolChoice.FunctionName("Diagnostics")), + CustomOpenAIParams(), + true + ) as String + + val request = json.parseToJsonElement(payload).jsonObject + val tool = request.getValue("tools").jsonArray.single().jsonObject + val toolChoice = request.getValue("tool_choice").jsonObject + + assertThat(tool.getValue("type").jsonPrimitive.content).isEqualTo("function") + assertThat(tool.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics") + assertThat(tool.getValue("description").jsonPrimitive.content).isEqualTo("Read IDE diagnostics") + assertThat(tool.getValue("strict").jsonPrimitive.boolean).isTrue() + assertThat(tool).doesNotContainKey("function") + assertThat(toolChoice.getValue("type").jsonPrimitive.content).isEqualTo("function") + assertThat(toolChoice.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics") + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt new file mode 100644 index 00000000..9a256075 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt @@ -0,0 +1,53 @@ +package ee.carlrobert.codegpt.agent.clients + +import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class StreamingPayloadNormalizerTest { + + @Test + fun `should unwrap sse data payload`() { + assertThat(normalizeSsePayload("data: {\"id\":\"chunk-1\"}")) + .isEqualTo("{\"id\":\"chunk-1\"}") + } + + @Test + fun `should ignore done sentinel`() { + assertThat(normalizeSsePayload("data: [DONE]")).isNull() + } + + @Test + fun `should unwrap multiline sse event frames`() { + assertThat( + normalizeSsePayload("event: response.created\ndata: {\"type\":\"response.created\"}\n") + ).isEqualTo("{\"type\":\"response.created\"}") + } + + @Test + fun `should unwrap compact sse frames where event and data share a line`() { + assertThat( + normalizeSsePayload("event: response.created data: {\"type\":\"response.created\"}") + ).isEqualTo("{\"type\":\"response.created\"}") + } + + @Test + fun `custom openai client should decode sse wrapped chat completion chunks`() { + val state = CustomServiceChatCompletionSettingsState().apply { + url = "https://example.com/v1/chat/completions" + body["stream"] = true + } + val client = CustomOpenAILLMClient.fromSettingsState("test-key", state) + val decodeMethod = client.javaClass.getDeclaredMethod("decodeStreamingResponse", String::class.java) + decodeMethod.isAccessible = true + + val response = decodeMethod.invoke( + client, + "data: {\"choices\":[],\"created\":0,\"id\":\"chunk-1\",\"model\":\"test-model\"}" + ) as CustomOpenAIChatCompletionStreamResponse + + assertThat(response.id).isEqualTo("chunk-1") + assertThat(response.model).isEqualTo("test-model") + assertThat(response.choices).isEmpty() + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/ui/textarea/DiagnosticsIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/ui/textarea/DiagnosticsIntegrationTest.kt new file mode 100644 index 00000000..fec344b1 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/ui/textarea/DiagnosticsIntegrationTest.kt @@ -0,0 +1,80 @@ +package ee.carlrobert.codegpt.ui.textarea + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter +import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService +import ee.carlrobert.codegpt.settings.service.FeatureType +import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails +import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager +import ee.carlrobert.codegpt.ui.textarea.lookup.action.selectedContextFiles +import ee.carlrobert.codegpt.ui.textarea.lookup.group.DiagnosticsGroupItem +import kotlinx.coroutines.runBlocking +import org.assertj.core.api.Assertions.assertThat +import testsupport.IntegrationTest + +class DiagnosticsIntegrationTest : IntegrationTest() { + + fun `test diagnostics service should return errors for broken file`() { + val brokenFile = myFixture.configureByText( + "Broken.java", + "class Broken { void test() { int value = ; } }" + ).virtualFile + myFixture.doHighlighting() + + val report = project.service() + .collect(brokenFile, DiagnosticsFilter.ERRORS_ONLY) + + assertThat(report.hasDiagnostics).isTrue() + assertThat(report.content).contains("[ERROR]") + assertThat(report.content).contains("Broken.java") + } + + fun `test diagnostics service should return empty for clean file`() { + val cleanFile = myFixture.configureByText( + "Clean.java", + "class Clean { void test() { int value = 1; } }" + ).virtualFile + myFixture.doHighlighting() + + val report = project.service() + .collect(cleanFile, DiagnosticsFilter.ERRORS_ONLY) + + assertThat(report.hasDiagnostics).isFalse() + assertThat(report.content).isBlank() + assertThat(report.error).isNull() + } + + fun `test selectedContextFiles should use selected file context only`() { + val firstFile = myFixture.configureByText("First.java", "class First {}").virtualFile + val secondFile = myFixture.configureByText("Second.java", "class Second {}").virtualFile + val unselectedEditorTag = EditorTagDetails(secondFile).apply { selected = false } + + val contextFiles = selectedContextFiles( + listOf( + FileTagDetails(firstFile), + DiagnosticsTagDetails(firstFile, DiagnosticsFilter.ALL), + EditorTagDetails(secondFile), + unselectedEditorTag + ) + ) + + assertThat(contextFiles.map { it.path }) + .containsExactly(firstFile.path, secondFile.path) + } + + fun `test diagnostics group should be available in agent mode with file context`() { + val file = myFixture.configureByText("AgentFile.java", "class AgentFile {}").virtualFile + val tagManager = TagManager().apply { + addTag(FileTagDetails(file)) + } + + val groups = SearchManager(project, tagManager, FeatureType.AGENT).getDefaultGroups() + val diagnosticsGroup = groups.filterIsInstance().single() + val actions = runBlocking { diagnosticsGroup.getLookupItems("") } + + assertThat(diagnosticsGroup.enabled).isTrue() + assertThat(actions.map { it.displayName }).containsExactly("Errors only", "All") + } +} From 93ff7cda5a6a52150db445bdc1f6ed51b15104b1 Mon Sep 17 00:00:00 2001 From: Roman Gromov <35900229+sapphirepro@users.noreply.github.com> Date: Sat, 14 Mar 2026 00:49:44 +0100 Subject: [PATCH 3/3] Fixed Custom OpenAI (Responses) tool calls Fixed tool calls that were previously broken and crashing --- .../codegpt/agent/ToolArgumentsJson.kt | 47 +++ .../ee/carlrobert/codegpt/agent/ToolSpecs.kt | 6 +- .../agent/clients/CustomOpenAILLMClient.kt | 270 ++++++++++++++---- .../strategy/SingleRunStrategyProvider.kt | 16 +- .../agent/AgentProviderIntegrationTest.kt | 254 ++++++++++++++++ ...stomOpenAIResponsesApiSerializationTest.kt | 128 +++++++++ .../clients/StreamingPayloadNormalizerTest.kt | 41 +++ .../strategy/SingleRunStrategyProviderTest.kt | 8 + 8 files changed, 705 insertions(+), 65 deletions(-) create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/agent/ToolArgumentsJson.kt diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolArgumentsJson.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolArgumentsJson.kt new file mode 100644 index 00000000..aecc5a34 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolArgumentsJson.kt @@ -0,0 +1,47 @@ +package ee.carlrobert.codegpt.agent + +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.contentOrNull + +private val toolArgumentsJson = Json { + ignoreUnknownKeys = true +} + +internal fun normalizeToolArgumentsJson(rawArgs: String?): String? { + val payload = rawArgs?.trim().orEmpty() + if (payload.isBlank()) { + return null + } + + val element = runCatching { + toolArgumentsJson.parseToJsonElement(payload) + }.getOrNull() ?: return null + + return normalizeToolArgumentsElement(element)?.toString() +} + +private tailrec fun normalizeToolArgumentsElement(element: JsonElement): JsonObject? { + return when (element) { + is JsonObject -> element + is JsonPrimitive -> { + if (!element.isString) { + null + } else { + val nestedPayload = element.contentOrNull?.trim().orEmpty() + if (nestedPayload.isBlank()) { + null + } else { + val nested = runCatching { + toolArgumentsJson.parseToJsonElement(nestedPayload) + }.getOrNull() ?: return null + normalizeToolArgumentsElement(nested) + } + } + } + + else -> null + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt index eb36695e..507aa686 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt @@ -201,8 +201,12 @@ object ToolSpecs { if (serializer == null || payload.isBlank()) { return null } + val typedSerializer = serializer as KSerializer return runCatching { - json.decodeFromString(serializer as KSerializer, payload) + json.decodeFromString(typedSerializer, payload) + }.recoverCatching { + val normalized = normalizeToolArgumentsJson(payload) ?: throw it + json.decodeFromString(typedSerializer, normalized) }.getOrNull() } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt index 7e00e5b3..10ed114e 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt @@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.type.TypeReference import com.fasterxml.jackson.databind.ObjectMapper import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate import ee.carlrobert.codegpt.codecompletions.InfillRequest +import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson import ee.carlrobert.codegpt.settings.Placeholder import ee.carlrobert.codegpt.settings.service.custom.CustomServiceCodeCompletionSettingsState import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState @@ -40,6 +41,7 @@ import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.descriptors.elementNames import kotlinx.serialization.json.* import java.net.URI +import java.util.UUID import kotlin.io.encoding.ExperimentalEncodingApi import kotlin.time.Clock @@ -249,6 +251,8 @@ class CustomOpenAILLMClient( toolChoice: OpenAIToolChoice? ): String { val streamRequest = state.shouldStream() + val inputJson = messages.toResponsesApiItemsJson() + val prompt = renderCustomOpenAIPrompt(messages, json) return buildJsonObject { state.body.forEach { (key, value) -> @@ -257,7 +261,8 @@ class CustomOpenAILLMClient( key = key, value = value, streamRequest = streamRequest, - messages = messages, + messagesJson = inputJson, + prompt = prompt, credential = apiKey, json = json ) @@ -457,48 +462,42 @@ class CustomOpenAILLMClient( ) } - "response.function_call_arguments.delta" -> { - val argsDelta = event["delta"]?.jsonPrimitive?.contentOrNull ?: "" - val callId = event["call_id"]?.jsonPrimitive?.contentOrNull ?: "" - val index = event["output_index"]?.jsonPrimitive?.intOrNull ?: 0 - CustomOpenAIStreamChoice( - delta = CustomOpenAIStreamDelta( - toolCalls = listOf( - CustomOpenAIToolCall( - id = callId, - index = index, - function = CustomOpenAIFunction(arguments = argsDelta) + "response.function_call_arguments.delta", + "response.output_item.added" -> null + + "response.completed" -> { + val responseObject = event["response"]?.jsonObject + val toolCalls = responseObject?.get("output") + ?.jsonArray + ?.mapIndexedNotNull { index, item -> + val itemObject = item.jsonObject + if (itemObject["type"]?.jsonPrimitive?.contentOrNull != "function_call") { + return@mapIndexedNotNull null + } + + val rawArguments = when (val arguments = itemObject["arguments"]) { + is JsonPrimitive -> arguments.contentOrNull + is JsonObject -> arguments.toString() + else -> null + } + CustomOpenAIToolCall( + id = itemObject["call_id"]?.jsonPrimitive?.contentOrNull, + index = index, + function = CustomOpenAIFunction( + name = itemObject["name"]?.jsonPrimitive?.contentOrNull, + arguments = normalizeToolArgumentsJson(rawArguments) ?: rawArguments.orEmpty() ) ) + } + .orEmpty() + CustomOpenAIStreamChoice( + finishReason = "stop", + delta = CustomOpenAIStreamDelta( + toolCalls = toolCalls.takeIf { it.isNotEmpty() } ) ) } - "response.output_item.added" -> { - val item = event["item"]?.jsonObject - if (item?.get("type")?.jsonPrimitive?.contentOrNull == "function_call") { - val callId = item["call_id"]?.jsonPrimitive?.contentOrNull ?: "" - val name = item["name"]?.jsonPrimitive?.contentOrNull ?: "" - val index = event["output_index"]?.jsonPrimitive?.intOrNull ?: 0 - CustomOpenAIStreamChoice( - delta = CustomOpenAIStreamDelta( - toolCalls = listOf( - CustomOpenAIToolCall( - id = callId, - index = index, - function = CustomOpenAIFunction(name = name) - ) - ) - ) - ) - } else null - } - - "response.completed" -> CustomOpenAIStreamChoice( - finishReason = "stop", - delta = CustomOpenAIStreamDelta() - ) - else -> null } @@ -535,12 +534,20 @@ class CustomOpenAILLMClient( } "function_call" -> { + val rawArguments = when (val arguments = itemObj["arguments"]) { + is JsonPrimitive -> arguments.contentOrNull + is JsonObject -> arguments.toString() + else -> null + } toolCallsJson.add(buildJsonObject { put("id", itemObj["call_id"] ?: JsonPrimitive("")) put("type", JsonPrimitive("function")) putJsonObject("function") { put("name", itemObj["name"] ?: JsonPrimitive("")) - put("arguments", itemObj["arguments"] ?: JsonPrimitive("{}")) + put( + "arguments", + JsonPrimitive(normalizeToolArgumentsJson(rawArguments) ?: rawArguments ?: "{}") + ) } }) } @@ -640,12 +647,12 @@ class CustomOpenAILLMClient( } assistantToolCalls.forEach { toolCall -> + val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}" add( Message.Tool.Call( id = toolCall.id, tool = toolCall.function.name, - content = toolCall.function.arguments.takeIf { it.isNotEmpty() } - ?: "{}", + content = arguments, metaInfo = metaInfo ) ) @@ -715,22 +722,38 @@ internal fun transformCustomOpenAIBodyValue( messages: List, credential: String, json: Json +): JsonElement { + return transformCustomOpenAIBodyValue( + key = key, + value = value, + streamRequest = streamRequest, + messagesJson = json.encodeToJsonElement( + ListSerializer(OpenAIMessage.serializer()), + messages + ), + prompt = renderCustomOpenAIPrompt(messages, json), + credential = credential, + json = json + ) +} + +private fun transformCustomOpenAIBodyValue( + key: String?, + value: Any?, + streamRequest: Boolean, + messagesJson: JsonElement, + prompt: String, + credential: String, + json: Json ): JsonElement { return when (value) { null -> JsonNull is JsonElement -> value is String -> when { !streamRequest && key == "stream" -> JsonPrimitive(false) - CustomServicePlaceholders.isMessages(value) -> json.parseToJsonElement( - json.encodeToString(ListSerializer(OpenAIMessage.serializer()), messages) - ) + CustomServicePlaceholders.isMessages(value) -> messagesJson - CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive( - renderCustomOpenAIPrompt( - messages, - json - ) - ) + CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(prompt) value.contains($$"$CUSTOM_SERVICE_API_KEY") -> { JsonPrimitive(value.replace($$"$CUSTOM_SERVICE_API_KEY", credential)) @@ -749,7 +772,8 @@ internal fun transformCustomOpenAIBodyValue( key = nestedKey.toString(), value = nestedValue, streamRequest = streamRequest, - messages = messages, + messagesJson = messagesJson, + prompt = prompt, credential = credential, json = json ) @@ -762,7 +786,8 @@ internal fun transformCustomOpenAIBodyValue( key = null, value = item, streamRequest = streamRequest, - messages = messages, + messagesJson = messagesJson, + prompt = prompt, credential = credential, json = json ) @@ -775,7 +800,8 @@ internal fun transformCustomOpenAIBodyValue( key = null, value = item, streamRequest = streamRequest, - messages = messages, + messagesJson = messagesJson, + prompt = prompt, credential = credential, json = json ) @@ -793,6 +819,144 @@ internal fun renderCustomOpenAIPrompt(messages: List, json: Json) } } +private fun List.toResponsesApiItemsJson(): JsonArray { + return JsonArray( + buildList { + for (message in this@toResponsesApiItemsJson) { + addAll(message.toResponsesApiItemsJson()) + } + } + ) +} + +private fun OpenAIMessage.toResponsesApiItemsJson(): List { + return when (this) { + is OpenAIMessage.System, is OpenAIMessage.Developer -> listOf( + buildResponsesApiInputMessageItemJson( + role = "developer", + content = content.toResponsesApiInputContentJson() + ) + ) + + is OpenAIMessage.User -> listOf( + buildResponsesApiInputMessageItemJson( + role = "user", + content = content.toResponsesApiInputContentJson() + ) + ) + + is OpenAIMessage.Assistant -> buildList { + reasoningContent + ?.takeIf { it.isNotBlank() } + ?.let { reasoning -> + add( + buildJsonObject { + put("type", JsonPrimitive("reasoning")) + put("id", JsonPrimitive(UUID.randomUUID().toString())) + putJsonArray("summary") { + addJsonObject { + put("type", JsonPrimitive("summary_text")) + put("text", JsonPrimitive(reasoning)) + } + } + } + ) + } + + content?.text() + ?.takeIf { it.isNotBlank() } + ?.let { assistantText -> + add( + buildJsonObject { + put("type", JsonPrimitive("message")) + put("role", JsonPrimitive("assistant")) + putJsonArray("content") { + addJsonObject { + put("type", JsonPrimitive("output_text")) + put("text", JsonPrimitive(assistantText)) + put("annotations", JsonArray(emptyList())) + } + } + } + ) + } + + toolCalls.orEmpty().forEach { toolCall -> + val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}" + add( + buildJsonObject { + put("type", JsonPrimitive("function_call")) + put("arguments", JsonPrimitive(arguments)) + put("call_id", JsonPrimitive(toolCall.id)) + put("name", JsonPrimitive(toolCall.function.name)) + } + ) + } + } + + is OpenAIMessage.Tool -> listOf( + buildJsonObject { + put("type", JsonPrimitive("function_call_output")) + put("call_id", JsonPrimitive(toolCallId)) + put("output", JsonPrimitive(content?.text().orEmpty())) + } + ) + } +} + +private fun buildResponsesApiInputMessageItemJson( + role: String, + content: JsonArray +): JsonObject { + return buildJsonObject { + put("type", JsonPrimitive("message")) + put("role", JsonPrimitive(role)) + put("content", content) + } +} + +private fun Content?.toResponsesApiInputContentJson(): JsonArray { + if (this == null) { + return JsonArray(emptyList()) + } + + return when (this) { + is Content.Parts -> JsonArray(value.mapNotNull { it.toResponsesApiInputContentJson() }) + else -> JsonArray( + listOf( + buildJsonObject { + put("type", JsonPrimitive("input_text")) + put("text", JsonPrimitive(text())) + } + ) + ) + } +} + +private fun OpenAIContentPart.toResponsesApiInputContentJson(): JsonObject? { + return when (this) { + is OpenAIContentPart.Text -> buildJsonObject { + put("type", JsonPrimitive("input_text")) + put("text", JsonPrimitive(text)) + } + + is OpenAIContentPart.Image -> buildJsonObject { + put("type", JsonPrimitive("input_image")) + imageUrl.detail?.let { put("detail", JsonPrimitive(it)) } + put("imageUrl", JsonPrimitive(imageUrl.url)) + } + + is OpenAIContentPart.File -> buildJsonObject { + put("type", JsonPrimitive("input_file")) + file.fileData?.let { put("fileData", JsonPrimitive(it)) } + file.fileId?.let { put("fileId", JsonPrimitive(it)) } + file.filename?.let { put("filename", JsonPrimitive(it)) } + } + + else -> null + } +} + internal fun OpenAITool.toResponsesApiToolJson(): JsonObject { return buildJsonObject { put("type", JsonPrimitive("function")) @@ -819,7 +983,7 @@ internal fun OpenAIToolChoice.toResponsesApiToolChoiceJson(): JsonElement { put("name", JsonPrimitive(function.name)) } - else -> JsonPrimitive(toString()) + is OpenAIToolChoice.Mode -> JsonPrimitive(value) } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProvider.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProvider.kt index d52bfe83..0d50488a 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProvider.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProvider.kt @@ -23,15 +23,13 @@ import com.intellij.openapi.vfs.LocalFileSystem import ee.carlrobert.codegpt.ReferencedFile import ee.carlrobert.codegpt.agent.AgentEvents import ee.carlrobert.codegpt.agent.MessageWithContext +import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot import ee.carlrobert.codegpt.completions.CompletionRequestUtil import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent import ee.carlrobert.codegpt.ui.textarea.TagProcessorFactory import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails -import kotlinx.serialization.SerializationException -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.jsonObject import java.util.* import ee.carlrobert.codegpt.conversations.message.Message as ChatMessage @@ -186,15 +184,11 @@ private suspend fun AIAgentLLMWriteSession.requestResponses( .toMessageResponses() .map { if (it is Message.Tool.Call) { - try { - // validate json - Json.parseToJsonElement(it.content).jsonObject + val normalizedArgs = normalizeToolArgumentsJson(it.content) ?: "{}" + if (normalizedArgs == it.content) { it - } catch (_: SerializationException) { - // allows agent to retry the request - it.copy(parts = listOf(it.parts[0].copy(text = "{}"))) - } catch (e: Exception) { - throw e + } else { + it.copy(parts = listOf(it.parts[0].copy(text = normalizedArgs))) } } else { it diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt index c7b0a8e9..a1b3cc37 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt @@ -229,6 +229,90 @@ class AgentProviderIntegrationTest : IntegrationTest() { assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI Responses") } + fun testCustomOpenAIResponsesAgentSerializesToolHistoryAsResponsesInput() { + val fixture = createReadFixture("Custom Responses fixture") + configureCustomOpenAIService( + path = "/v1/responses", + model = "custom-responses-model", + stream = false, + useResponsesApiBody = true + ) + expectCustomOpenAI(BasicHttpExchange { request -> + assertThat(request.uri.path).isEqualTo("/v1/responses") + assertThat(request.method).isEqualTo("POST") + assertThat(extractPromptText(request)).contains("Read the fixture and repeat its contents") + ResponseEntity( + openAiResponsesToolCallResponse( + model = "custom-responses-model", + toolName = "Read", + callId = "call_custom_responses_read", + arguments = """{"file_path":"${fixture.path}"}""" + ) + ) + }) + expectCustomOpenAI(BasicHttpExchange { request -> + assertThat(request.uri.path).isEqualTo("/v1/responses") + assertThat(request.method).isEqualTo("POST") + assertThatResponsesToolHistory( + request = request, + toolName = "Read", + callId = "call_custom_responses_read", + fixture = fixture + ) + ResponseEntity( + openAiResponsesResponse( + model = "custom-responses-model", + text = "Custom Responses read: ${fixture.contents}" + ) + ) + }) + + val result = runAgent(ServiceType.CUSTOM_OPENAI, "Read the fixture and repeat its contents") + + assertThat(result.output).isEqualTo("Custom Responses read: ${fixture.contents}") + assertThat(result.events.text.toString()).isEqualTo("Custom Responses read: ${fixture.contents}") + } + + fun testCustomOpenAIResponsesStreamingAgentCompletesToolLoop() { + val fixture = createReadFixture("Custom Responses streaming fixture") + configureCustomOpenAIService( + path = "/v1/responses", + model = "custom-responses-model", + stream = true, + useResponsesApiBody = true + ) + expectCustomOpenAI(StreamHttpExchange { request -> + assertThat(request.uri.path).isEqualTo("/v1/responses") + assertThat(request.method).isEqualTo("POST") + assertThat(extractPromptText(request)).contains("Read the fixture and repeat its contents") + openAiResponsesToolCallChunks( + model = "custom-responses-model", + toolName = "Read", + callId = "call_custom_responses_stream_read", + arguments = """{"file_path":"${fixture.path}"}""" + ) + }) + expectCustomOpenAI(StreamHttpExchange { request -> + assertThat(request.uri.path).isEqualTo("/v1/responses") + assertThat(request.method).isEqualTo("POST") + assertThatResponsesToolHistory( + request = request, + toolName = "Read", + callId = "call_custom_responses_stream_read", + fixture = fixture + ) + openAiResponsesChunks( + model = "custom-responses-model", + text = "Custom Responses streaming read: ${fixture.contents}" + ) + }) + + val result = runAgent(ServiceType.CUSTOM_OPENAI, "Read the fixture and repeat its contents") + + assertThat(result.output).isEqualTo("Custom Responses streaming read: ${fixture.contents}") + assertThat(result.events.text.toString()).isEqualTo("Custom Responses streaming read: ${fixture.contents}") + } + private fun runAgent( provider: ServiceType, userMessage: String @@ -369,6 +453,154 @@ class AgentProviderIntegrationTest : IntegrationTest() { ) } + private fun openAiResponsesToolCallChunks( + model: String, + toolName: String, + callId: String, + arguments: String + ): List { + return listOf( + jsonMapResponse( + e("type", "response.output_item.added"), + e( + "item", + jsonMap( + e("type", "function_call"), + e("id", "fc_1"), + e("call_id", callId), + e("name", toolName), + e("arguments", "") + ) + ), + e("output_index", 0), + e("sequence_number", 1) + ), + jsonMapResponse( + e("type", "response.function_call_arguments.delta"), + e("item_id", "fc_1"), + e("output_index", 0), + e("delta", arguments), + e("call_id", callId), + e("sequence_number", 2) + ), + jsonMapResponse( + e("type", "response.completed"), + e( + "response", + jsonMap( + e("id", "resp-tool-call"), + e("object", "response"), + e("created_at", 1), + e("model", model), + e( + "output", + jsonArray( + jsonMap( + e("type", "function_call"), + e("id", "fc_1"), + e("call_id", callId), + e("name", toolName), + e("arguments", arguments), + e("status", "completed") + ) + ) + ), + e("parallel_tool_calls", true), + e("status", "completed"), + e("text", jsonMap()) + ) + ), + e("sequence_number", 3) + ) + ) + } + + private fun openAiResponsesToolCallResponse( + model: String, + toolName: String, + callId: String, + arguments: String + ): String { + return jsonMapResponse( + e("id", "resp-tool-call"), + e("object", "response"), + e("created_at", 1), + e("model", model), + e( + "output", + jsonArray( + jsonMap( + e("type", "function_call"), + e("id", "fc_1"), + e("call_id", callId), + e("name", toolName), + e("arguments", arguments), + e("status", "completed") + ) + ) + ), + e("parallel_tool_calls", true), + e("status", "completed"), + e("text", jsonMap()), + e( + "usage", + jsonMap( + e("input_tokens", 1), + e("input_tokens_details", jsonMap("cached_tokens", 0)), + e("output_tokens", 1), + e("output_tokens_details", jsonMap("reasoning_tokens", 0)), + e("total_tokens", 2) + ) + ) + ) + } + + private fun openAiResponsesResponse( + model: String, + text: String + ): String { + return jsonMapResponse( + e("id", "resp-openai-test"), + e("object", "response"), + e("created_at", 1), + e("model", model), + e( + "output", + jsonArray( + jsonMap( + e("type", "message"), + e("id", "msg_1"), + e("role", "assistant"), + e("status", "completed"), + e( + "content", + jsonArray( + jsonMap( + e("type", "output_text"), + e("text", text), + e("annotations", jsonArray()) + ) + ) + ) + ) + ) + ), + e("parallel_tool_calls", true), + e("status", "completed"), + e("text", jsonMap()), + e( + "usage", + jsonMap( + e("input_tokens", 1), + e("input_tokens_details", jsonMap("cached_tokens", 0)), + e("output_tokens", 1), + e("output_tokens_details", jsonMap("reasoning_tokens", 0)), + e("total_tokens", 2) + ) + ) + ) + } + private fun anthropicTextChunks(text: String): List { return listOf( jsonMapResponse( @@ -681,6 +913,28 @@ class AgentProviderIntegrationTest : IntegrationTest() { } } + private fun assertThatResponsesToolHistory( + request: RequestEntity, + toolName: String, + callId: String, + fixture: ReadFixture + ) { + val input = request.body["input"] as? List<*> ?: error("Expected responses input list") + val items = input.mapNotNull { it as? Map<*, *> } + + assertThat(items.any { item -> + item["type"] == "function_call" && + item["name"] == toolName && + item["call_id"] == callId && + item["arguments"] == """{"file_path":"${fixture.path}"}""" + }).isTrue() + assertThat(items.any { item -> + item["type"] == "function_call_output" && + item["call_id"] == callId && + (item["output"] as? String).orEmpty().contains(fixture.contents) + }).isTrue() + } + private fun extractGooglePromptText(request: RequestEntity): String { val contents = request.body["contents"] as? List<*> ?: return "" return contents.joinToString("\n") { content -> diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt index 93223fdd..164886fe 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesApiSerializationTest.kt @@ -1,7 +1,9 @@ package ee.carlrobert.codegpt.agent.clients import ai.koog.prompt.executor.clients.openai.base.models.Content +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction @@ -88,4 +90,130 @@ class CustomOpenAIResponsesApiSerializationTest { assertThat(toolChoice.getValue("type").jsonPrimitive.content).isEqualTo("function") assertThat(toolChoice.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics") } + + @Test + fun `responses api request should encode mode tool choice as lowercase string`() { + val state = CustomServiceChatCompletionSettingsState().apply { + url = "https://example.com/v1/responses" + body.clear() + body["stream"] = true + body["input"] = "\$OPENAI_MESSAGES" + } + val client = CustomOpenAILLMClient.fromSettingsState("test-key", state) + val serializeMethod = client.javaClass.getDeclaredMethod( + "serializeProviderChatRequest", + List::class.java, + LLModel::class.java, + List::class.java, + OpenAIToolChoice::class.java, + LLMParams::class.java, + Boolean::class.javaPrimitiveType + ) + serializeMethod.isAccessible = true + + val payload = serializeMethod.invoke( + client, + listOf(OpenAIMessage.User(content = Content.Text("hello"))), + LLModel( + id = "gpt-test", + provider = CustomOpenAILLMClient.CustomOpenAI, + capabilities = emptyList(), + contextLength = 128_000, + maxOutputTokens = 4_096 + ), + emptyList(), + openAIToolChoiceMode("required"), + CustomOpenAIParams(), + true + ) as String + + val request = json.parseToJsonElement(payload).jsonObject + + assertThat(request.getValue("tool_choice").jsonPrimitive.content).isEqualTo("required") + } + + @Test + fun `responses api request should encode agent tool history as input items`() { + val state = CustomServiceChatCompletionSettingsState().apply { + url = "https://example.com/v1/responses" + body.clear() + body["stream"] = true + body["input"] = "\$OPENAI_MESSAGES" + } + val client = CustomOpenAILLMClient.fromSettingsState("test-key", state) + val serializeMethod = client.javaClass.getDeclaredMethod( + "serializeProviderChatRequest", + List::class.java, + LLModel::class.java, + List::class.java, + OpenAIToolChoice::class.java, + LLMParams::class.java, + Boolean::class.javaPrimitiveType + ) + serializeMethod.isAccessible = true + + val payload = serializeMethod.invoke( + client, + listOf( + OpenAIMessage.System(content = Content.Text("System instructions")), + OpenAIMessage.User(content = Content.Text("Read the fixture")), + OpenAIMessage.Assistant( + content = Content.Text("Calling Read"), + toolCalls = listOf( + OpenAIToolCall( + "call_read", + OpenAIFunction("Read", "\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\"") + ) + ) + ), + OpenAIMessage.Tool( + content = Content.Text("1\tfixture contents"), + toolCallId = "call_read" + ) + ), + LLModel( + id = "gpt-test", + provider = CustomOpenAILLMClient.CustomOpenAI, + capabilities = emptyList(), + contextLength = 128_000, + maxOutputTokens = 4_096 + ), + emptyList(), + null, + CustomOpenAIParams(), + true + ) as String + + val request = json.parseToJsonElement(payload).jsonObject + val input = request.getValue("input").jsonArray + + assertThat(input.map { it.jsonObject.getValue("type").jsonPrimitive.content }).containsExactly( + "message", + "message", + "message", + "function_call", + "function_call_output" + ) + assertThat(input[0].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("developer") + assertThat(input[1].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("user") + assertThat(input[2].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("assistant") + assertThat( + input[2].jsonObject.getValue("content").jsonArray.single().jsonObject.getValue("text").jsonPrimitive.content + ).isEqualTo("Calling Read") + assertThat(input[3].jsonObject.getValue("name").jsonPrimitive.content).isEqualTo("Read") + assertThat(input[3].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read") + assertThat(input[3].jsonObject.getValue("arguments").jsonPrimitive.content) + .isEqualTo("""{"file_path":"/tmp/fixture.txt"}""") + assertThat(input[4].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read") + assertThat(input[4].jsonObject.getValue("output").jsonPrimitive.content) + .isEqualTo("1\tfixture contents") + } + + private fun openAIToolChoiceMode(value: String): OpenAIToolChoice { + val modeClass = Class.forName( + "ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice\$Mode" + ) + val boxMethod = modeClass.getDeclaredMethod("box-impl", String::class.java) + return boxMethod.invoke(null, value) as OpenAIToolChoice + } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt index 9a256075..2ce230c5 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/StreamingPayloadNormalizerTest.kt @@ -1,6 +1,13 @@ package ee.carlrobert.codegpt.agent.clients +import ai.koog.prompt.message.Message +import ai.koog.prompt.streaming.StreamFrame +import ai.koog.prompt.streaming.toMessageResponses import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat import org.junit.Test @@ -50,4 +57,38 @@ class StreamingPayloadNormalizerTest { assertThat(response.model).isEqualTo("test-model") assertThat(response.choices).isEmpty() } + + @Test + fun `responses api streaming tool call should collapse into one named tool call`() { + val state = CustomServiceChatCompletionSettingsState().apply { + url = "https://example.com/v1/responses" + body.clear() + body["stream"] = true + body["input"] = "\$OPENAI_MESSAGES" + } + val client = CustomOpenAILLMClient.fromSettingsState("test-key", state) + val decodeMethod = client.javaClass.getDeclaredMethod("decodeStreamingResponse", String::class.java) + decodeMethod.isAccessible = true + val processMethod = client.javaClass.getDeclaredMethod("processStreamingResponse", Flow::class.java) + processMethod.isAccessible = true + + val chunks = listOf( + """{"type":"response.output_item.added","item":{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":""},"output_index":0,"sequence_number":1}""", + """{"type":"response.function_call_arguments.delta","item_id":"fc_1","output_index":0,"delta":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","call_id":"call_diag","sequence_number":2}""", + """{"type":"response.completed","response":{"id":"resp-tool","object":"response","created_at":1,"model":"test-model","output":[{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","status":"completed"}],"parallel_tool_calls":true,"status":"completed","text":{}},"sequence_number":3}""" + ).map { payload -> + decodeMethod.invoke(client, payload) as CustomOpenAIChatCompletionStreamResponse + } + + val responses = runBlocking { + @Suppress("UNCHECKED_CAST") + val frames = processMethod.invoke(client, chunks.asFlow()) as Flow + frames.toList().toMessageResponses() + } + val toolCall = responses.filterIsInstance().single() + + assertThat(toolCall.id).isEqualTo("call_diag") + assertThat(toolCall.tool).isEqualTo("Diagnostics") + assertThat(toolCall.content).isEqualTo("""{"file_path":"/tmp/mainwindow.cpp","filter":"all"}""") + } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProviderTest.kt index b6037155..b9e11352 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/strategy/SingleRunStrategyProviderTest.kt @@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.agent.strategy import ai.koog.prompt.llm.LLMProvider import ai.koog.prompt.message.Message import ai.koog.prompt.message.ResponseMetaInfo +import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson import org.assertj.core.api.Assertions.assertThat import kotlin.test.Test @@ -94,4 +95,11 @@ class SingleRunStrategyProviderTest { assertThat(toolCall.tool).isEqualTo("Read") assertThat(toolCall.content).isEqualTo("""{"path":"build.gradle.kts"}""") } + + @Test + fun `normalize tool arguments unwraps string encoded json object`() { + val actual = normalizeToolArgumentsJson("\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\"") + + assertThat(actual).isEqualTo("""{"file_path":"/tmp/fixture.txt"}""") + } }