diff --git a/build.gradle.kts b/build.gradle.kts index 4dd4f88d..7eee328e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -80,6 +80,10 @@ dependencies { implementation(libs.jackson.datatype.jsr310) implementation(libs.jackson.module.kotlin) implementation(libs.kotlin.stdlib) + implementation(libs.acp.sdk) { + exclude(group = "org.jetbrains.kotlinx", module = "kotlinx-coroutines-core") + exclude(group = "org.jetbrains.kotlinx", module = "kotlinx-coroutines-core-jvm") + } implementation(libs.flexmark.all) { // vulnerable transitive dependency diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 22e7dd46..37a9538b 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,5 +1,6 @@ [versions] analytics = "3.1.2" +acp = "0.17.0" assertj = "3.27.3" changelog = "2.2.1" checkstyle = "10.15.0" @@ -23,6 +24,7 @@ protobuf-plugin = "0.9.4" [libraries] analytics = { module = "com.rudderstack.sdk.java.analytics:analytics", version.ref = "analytics" } +acp-sdk = { module = "com.agentclientprotocol:acp", version.ref = "acp" } assertj-core = { module = "org.assertj:assertj-core", version.ref = "assertj" } commons-text = { module = "org.apache.commons:commons-text", version.ref = "commons-text" } flexmark-all = { module = "com.vladsch.flexmark:flexmark-all", version.ref = "flexmark" } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java index 030cda63..a558d6c0 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java @@ -85,6 +85,7 @@ public class ChatMessageResponseBody extends JPanel { private final ResponseBodyProgressPanel progressPanel = new ResponseBodyProgressPanel(); private final JPanel contentPanel = new JPanel(new VerticalFlowLayout(VerticalFlowLayout.TOP, 0, 4, true, false)); + private final StringBuilder accumulatedThinking = new StringBuilder(); private ResponseEditorPanel currentlyProcessedEditorPanel; private MermaidResponsePanel currentlyProcessedMermaidPanel; @@ -170,6 +171,20 @@ public class ChatMessageResponseBody extends JPanel { } } + public void appendThinking(String partialThinking) { + if (partialThinking.isEmpty()) { + return; + } + + if (accumulatedThinking.length() > 0) { + accumulatedThinking.append('\n'); + } + accumulatedThinking.append(partialThinking); + processThinkingOutput(accumulatedThinking.toString()); + revalidate(); + repaint(); + } + public void displayMissingCredential() { String message = "API key not provided. Open Settings to set one."; displayErrorMessage(message, e -> { @@ -254,6 +269,7 @@ public class ChatMessageResponseBody extends JPanel { public void clear() { contentPanel.removeAll(); streamOutputParser.clear(); + accumulatedThinking.setLength(0); // Reset for the next incoming message prepareProcessingText(true); diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt index edebca6d..2019bac7 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt @@ -1,5 +1,6 @@ package ee.carlrobert.codegpt.agent +import com.agentclientprotocol.model.PlanEntry import ee.carlrobert.codegpt.agent.history.CheckpointRef import ee.carlrobert.codegpt.agent.tools.AskUserQuestionTool import ee.carlrobert.codegpt.settings.service.ServiceType @@ -9,6 +10,8 @@ import java.util.UUID interface AgentEvents { fun onTextReceived(text: String) {} + fun onThinkingReceived(text: String) {} + fun onPlanUpdated(entries: List) {} fun onAgentCompleted(agentId: String) {} fun onToolStarting(id: String, toolName: String, args: Any?) {} fun onToolCompleted(id: String?, toolName: String, result: Any?) {} @@ -23,8 +26,13 @@ interface AgentEvents { fun onRetry(attempt: Int, maxAttempts: Int) {} fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {} - fun onQueuedMessagesResolved() + fun onQueuedMessagesResolved(message: MessageWithContext? = null) fun onTokenUsageAvailable(tokenUsage: Long) {} + fun onUsageAvailable(event: AgentUsageEvent) { + onTokenUsageAvailable(event.usedTokens) + } + fun onRuntimeOptionsUpdated() {} + fun onSessionInfoUpdated(title: String?, updatedAt: String? = null) {} fun onCreditsAvailable(event: AgentCreditsEvent) {} fun onAgentException(provider: ServiceType, throwable: Throwable) {} fun onHistoryCompressionStateChanged(isCompressing: Boolean) {} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt index c1583682..4b4f4351 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt @@ -85,7 +85,11 @@ class AgentService(private val project: Project) { val queuedMessageProcessed = _queuedMessageProcessed.asSharedFlow() fun addToQueue(message: MessageWithContext, sessionId: String) { - pendingMessages.getOrPut(sessionId) { ArrayDeque() }.add(message) + val queue = pendingMessages.getOrPut(sessionId) { ArrayDeque() } + queue.add(message) + logger.debug( + "Queued agent message for session=$sessionId queueSize=${queue.size} uiVisible=${message.uiVisible} messageId=${message.id}" + ) } suspend fun getCheckpoint(sessionId: String): AgentCheckpointData? { @@ -114,7 +118,7 @@ class AgentService(private val project: Project) { } fun submitMessage(message: MessageWithContext, events: AgentEvents, sessionId: String) { - updateMcpContext(sessionId, message) + val selectedServerIds = updateMcpContext(sessionId, message) if (isSessionRunning(sessionId)) { addToQueue(message, sessionId) @@ -125,6 +129,11 @@ class AgentService(private val project: Project) { val contentManager = project.service() val session = contentManager.getSession(sessionId) ?: return if (!session.externalAgentId.isNullOrBlank()) { + if (session.shouldRecreateExternalAgentSession(selectedServerIds)) { + project.service().closeSession(sessionId) + session.externalAgentSessionId = null + session.externalAgentMcpServerIds = emptySet() + } submitExternalMessage(session, message, events, provider) return } @@ -201,7 +210,11 @@ class AgentService(private val project: Project) { } project.service().closeSession(sessionId) project.service() - .getSession(sessionId)?.externalAgentSessionId = null + .getSession(sessionId) + ?.also { + it.externalAgentSessionId = null + it.externalAgentMcpServerIds = emptySet() + } project.service().clear(sessionId) } @@ -227,22 +240,31 @@ class AgentService(private val project: Project) { events: AgentEvents, provider: ServiceType ) { - val externalAgentService = project.service() + logger.debug( + "Starting external ACP run for session=${session.sessionId} externalAgent=${session.externalAgentId} messageId=${message.id}" + ) sessionJobs[session.sessionId] = CoroutineScope(Dispatchers.IO).launch { try { - externalAgentService.runPromptLoop( - session = session, - firstMessage = message, - events = events, - pollNextQueued = { - val queue = pendingMessages[session.sessionId] ?: return@runPromptLoop null - if (queue.isEmpty()) { - null - } else { - queue.removeFirst() + project.service() + .runPromptLoop( + session = session, + firstMessage = message, + events = events, + pollNextQueued = { + val queue = + pendingMessages[session.sessionId] ?: return@runPromptLoop null + if (queue.isEmpty()) { + logger.debug("No queued ACP follow-up message for session=${session.sessionId}") + null + } else { + queue.removeFirst().also { next -> + logger.debug( + "Dequeued ACP follow-up message for session=${session.sessionId} queueRemaining=${queue.size} messageId=${next.id} uiVisible=${next.uiVisible}" + ) + } + } } - } - ) + ) } catch (_: CancellationException) { return@launch } catch (ex: Throwable) { @@ -255,7 +277,7 @@ class AgentService(private val project: Project) { } } - private fun updateMcpContext(sessionId: String, message: MessageWithContext) { + private fun updateMcpContext(sessionId: String, message: MessageWithContext): Set { val selectedServerIds = message.tags .filterIsInstance() .filter { it.selected } @@ -273,6 +295,7 @@ class AgentService(private val project: Project) { project.service() .update(sessionId, conversationId, selectedServerIds) + return selectedServerIds } suspend fun createSeedCheckpointFromHistory(history: List): CheckpointRef? = diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentUsageEvent.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentUsageEvent.kt new file mode 100644 index 00000000..f8c8b4c9 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentUsageEvent.kt @@ -0,0 +1,8 @@ +package ee.carlrobert.codegpt.agent + +data class AgentUsageEvent( + val usedTokens: Long, + val sizeTokens: Long? = null, + val costAmount: Double? = null, + val costCurrency: String? = null +) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt index 1e689577..4cfc36bd 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt @@ -90,7 +90,7 @@ object ProxyAIAgent { val modelSelection = service().getModelSelectionForFeature(FeatureType.AGENT) val skills = project.service().listSkills() - val stream = shouldStreamAgentToolLoop(project, provider) + val stream = shouldStreamAgentToolLoop(provider) val projectInstructions = loadProjectInstructions(project.basePath) val executor = AgentFactory.createExecutor(provider, events) val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() } @@ -99,7 +99,6 @@ object ProxyAIAgent { project = project, events = events, sessionId = sessionId, - provider = provider, parentModelSelection = modelSelection, hookManager = hookManager ) @@ -198,7 +197,7 @@ object ProxyAIAgent { output.forEach { msg -> (msg as? Message.Reasoning)?.let { if (it.content.isNotBlank()) { - events.onTextReceived("${it.content}") + events.onThinkingReceived(it.content) } } } @@ -213,7 +212,7 @@ object ProxyAIAgent { } (msg as? Message.Reasoning)?.let { if (it.content.isNotBlank()) { - events.onTextReceived("${it.content}") + events.onThinkingReceived(it.content) } } } @@ -297,10 +296,7 @@ object ProxyAIAgent { } } - private fun shouldStreamAgentToolLoop( - project: Project, - provider: ServiceType, - ): Boolean { + private fun shouldStreamAgentToolLoop(provider: ServiceType): Boolean { return when (provider) { ServiceType.CUSTOM_OPENAI -> { val selectedModel = @@ -320,7 +316,6 @@ object ProxyAIAgent { project: Project, events: AgentEvents, sessionId: String, - provider: ServiceType, parentModelSelection: ModelSelection, hookManager: HookManager ): ToolRegistry { @@ -334,84 +329,23 @@ object ProxyAIAgent { tool(ConfirmingEditTool(EditTool(project, sessionId, hookManager), standardApproval)) tool(ConfirmingWriteTool(WriteTool(project, sessionId, hookManager), writeApproval)) tool(TodoWriteTool(project, sessionId, hookManager)) - tool( - AskUserQuestionTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - events = events - ) - ) - createMcpTools( - sessionId = sessionId, - contextService = contextService, - approve = genericApproval - ).forEach { mcpTool -> tool(mcpTool) } + tool(AskUserQuestionTool(workingDirectory, sessionId, hookManager, events)) + createMcpTools(sessionId, contextService, genericApproval).forEach { mcpTool -> + tool(mcpTool) + } tool(ExitTool) - tool( - IntelliJSearchTool( - project = project, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - DiagnosticsTool( - project = project, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - WebSearchTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - WebFetchTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - BashOutputTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - KillShellTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - ResolveLibraryIdTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - ) - ) - tool( - GetLibraryDocsTool( - workingDirectory = workingDirectory, - sessionId = sessionId, - hookManager = hookManager, - ) - ) + tool(IntelliJSearchTool(project, sessionId, hookManager)) + tool(DiagnosticsTool(project, sessionId, hookManager)) + tool(WebSearchTool(workingDirectory, sessionId, hookManager)) + tool(WebFetchTool(workingDirectory, sessionId, hookManager)) + tool(BashOutputTool(workingDirectory, sessionId, hookManager)) + tool(KillShellTool(workingDirectory, sessionId, hookManager)) + tool(ResolveLibraryIdTool(workingDirectory, sessionId, hookManager)) + tool(GetLibraryDocsTool(workingDirectory, sessionId, hookManager)) tool( ConfirmingLoadSkillTool( - LoadSkillTool( - project = project, - sessionId = sessionId, - hookManager = hookManager - ), - project = project + LoadSkillTool(project, sessionId, hookManager), + project ) { name, details -> genericApproval(name, details) } ) tool( @@ -438,15 +372,7 @@ object ProxyAIAgent { hookManager = hookManager ) ) - tool( - TaskTool( - project, - sessionId, - parentModelSelection, - events, - hookManager - ) - ) + tool(TaskTool(project, sessionId, parentModelSelection, events, hookManager)) } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt index aa1f1929..c787e959 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAILLMClient.kt @@ -8,6 +8,7 @@ import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings import ai.koog.prompt.executor.clients.openai.OpenAILLMClient import ai.koog.prompt.executor.clients.openai.OpenAIResponsesParams import ai.koog.prompt.executor.clients.openai.base.models.* +import ai.koog.prompt.executor.clients.openai.models.OpenAIChatCompletionResponse import ai.koog.prompt.llm.LLMCapability import ai.koog.prompt.llm.LLMProvider import ai.koog.prompt.llm.LLModel @@ -26,10 +27,10 @@ import ee.carlrobert.codegpt.settings.service.custom.CustomServicePlaceholders import ee.carlrobert.codegpt.util.JsonMapper import io.ktor.client.* import io.ktor.client.plugins.* -import io.ktor.client.plugins.api.createClientPlugin +import io.ktor.client.plugins.api.* import io.ktor.client.request.* -import io.ktor.http.ContentType -import io.ktor.http.content.TextContent +import io.ktor.http.* +import io.ktor.http.content.* import kotlinx.coroutines.flow.Flow import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer @@ -200,6 +201,10 @@ class CustomOpenAILLMClient( } } + override fun decodeResponse(data: String): OpenAIChatCompletionResponse { + return json.decodeFromString(data) + } + override suspend fun getCodeCompletion(infillRequest: InfillRequest): String { val state = requireCodeCompletionState() val url = requireNotNull(state.url) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentCatalog.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentCatalog.kt index f582e1e5..80245d51 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentCatalog.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentCatalog.kt @@ -6,6 +6,7 @@ data class ExternalAcpAgentPreset( val vendor: String, val command: String, val args: List, + val toolEventFlavor: AcpToolEventFlavor = AcpToolEventFlavor.STANDARD, val env: Map = emptyMap(), val enabledByDefault: Boolean = false, val description: String? = null, @@ -28,6 +29,7 @@ object ExternalAcpAgents { vendor = "OpenAI", command = "npx", args = listOf("-y", "@zed-industries/codex-acp"), + toolEventFlavor = AcpToolEventFlavor.ZED_ADAPTER, enabledByDefault = true, description = "OpenAI Codex via the Zed ACP adapter." ), @@ -54,6 +56,7 @@ object ExternalAcpAgents { vendor = "Anthropic", command = "npx", args = listOf("-y", "@zed-industries/claude-code-acp"), + toolEventFlavor = AcpToolEventFlavor.ZED_ADAPTER, enabledByDefault = true, description = "Anthropic Claude Code via the Zed ACP adapter." ), @@ -63,6 +66,7 @@ object ExternalAcpAgents { vendor = "Google", command = "gemini", args = listOf("--experimental-acp"), + toolEventFlavor = AcpToolEventFlavor.GEMINI_CLI, description = "Google Gemini CLI in experimental ACP mode." ), ExternalAcpAgentPreset( @@ -119,6 +123,7 @@ object ExternalAcpAgents { vendor = "Anthropic", command = "npx", args = listOf("-y", "@zed-industries/claude-agent-acp"), + toolEventFlavor = AcpToolEventFlavor.ZED_ADAPTER, description = "Anthropic Claude Agent via the Zed ACP adapter." ), ExternalAcpAgentPreset( @@ -276,7 +281,7 @@ object ExternalAcpAgents { val message = throwable.message.orEmpty() return when { message.contains("Cannot run program", ignoreCase = true) && - message.contains("No such file or directory", ignoreCase = true) -> + message.contains("No such file or directory", ignoreCase = true) -> "Command not found: $command" message.isNotBlank() -> message diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentService.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentService.kt index 3ddd0ed8..57d5d871 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpAgentService.kt @@ -1,42 +1,51 @@ package ee.carlrobert.codegpt.agent.external +import com.agentclientprotocol.agent.AgentInfo +import com.agentclientprotocol.client.ClientInfo +import com.agentclientprotocol.model.* +import com.agentclientprotocol.transport.StdioTransport import com.intellij.execution.configurations.GeneralCommandLine import com.intellij.openapi.components.Service import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.thisLogger import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.LocalFileSystem -import com.intellij.openapi.vfs.VfsUtil import ee.carlrobert.codegpt.agent.AgentEvents import ee.carlrobert.codegpt.agent.MessageWithContext -import ee.carlrobert.codegpt.agent.ToolSpecs -import ee.carlrobert.codegpt.agent.tools.BashTool -import ee.carlrobert.codegpt.agent.tools.EditTool -import ee.carlrobert.codegpt.agent.tools.WriteTool +import ee.carlrobert.codegpt.agent.external.acpcompat.AcpProtocol +import ee.carlrobert.codegpt.agent.external.acpcompat.JsonRpcException +import ee.carlrobert.codegpt.agent.external.acpcompat.ProtocolOptions +import ee.carlrobert.codegpt.agent.external.acpcompat.invoke +import ee.carlrobert.codegpt.agent.external.acpcompat.setNotificationHandler +import ee.carlrobert.codegpt.agent.external.acpcompat.vendor.AcpCompatibilityRegistry +import ee.carlrobert.codegpt.agent.external.acpcompat.vendor.AcpPeerProfile +import ee.carlrobert.codegpt.agent.external.host.AcpFileHost +import ee.carlrobert.codegpt.agent.external.host.AcpHostCapabilities +import ee.carlrobert.codegpt.agent.external.host.AcpTerminalHost +import ee.carlrobert.codegpt.agent.external.host.DefaultAcpTerminalProcessLauncher import ee.carlrobert.codegpt.conversations.Conversation +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.mcp.McpSettings +import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOption +import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOptions +import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOptionChoice import ee.carlrobert.codegpt.toolwindow.agent.AgentSession import ee.carlrobert.codegpt.toolwindow.agent.AgentToolWindowContentManager -import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.* import ee.carlrobert.codegpt.ui.textarea.header.tag.* +import ee.carlrobert.codegpt.util.CommandRuntimeHelper import kotlinx.coroutines.* import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.serialization.json.* +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement import java.io.File -import java.nio.file.Files import java.nio.file.Path import java.nio.file.Paths +import java.util.* import java.util.concurrent.ConcurrentHashMap -import java.util.UUID -import kotlin.io.path.createDirectories -import kotlin.io.path.notExists - -private data class ExternalToolCall( - val toolName: String, - val args: Any? -) @Service(Service.Level.PROJECT) class ExternalAcpAgentService(private val project: Project) { @@ -46,20 +55,17 @@ class ExternalAcpAgentService(private val project: Project) { const val FULL_ACCESS_MODE_ID = "full-access" val NO_OP_EVENTS = object : AgentEvents { - override fun onQueuedMessagesResolved() = Unit - override fun onAgentException( - provider: ee.carlrobert.codegpt.settings.service.ServiceType, - throwable: Throwable - ) = Unit + override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit + override fun onAgentException(provider: ServiceType, throwable: Throwable) = Unit } } private val logger = thisLogger() private val json = Json { ignoreUnknownKeys = true } - private val sessionConfigAdapter = AcpSessionConfigAdapter(json) private val toolCallDecoder = AcpToolCallDecoder(json) - private val sessionUpdateParser = AcpSessionUpdateParser(toolCallDecoder) private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + private val sessionRoot: Path = + Paths.get(project.basePath ?: System.getProperty("user.dir")).toAbsolutePath().normalize() private val states = ConcurrentHashMap() private val sessionSetupMutexes = ConcurrentHashMap() @@ -72,6 +78,12 @@ class ExternalAcpAgentService(private val project: Project) { val preset = ExternalAcpAgents.find(session.externalAgentId) ?: error("Unsupported external agent: ${session.externalAgentId}") val state = ensureSessionReady(session, preset, events) + val debugModeEnabled = ConfigurationSettings.getState().debugModeEnabled + if (debugModeEnabled) { + logger.info( + "[${preset.displayName}] run/start session=${session.sessionId} externalSessionId=${session.externalAgentSessionId} firstMessageId=${firstMessage.id}" + ) + } var current: MessageWithContext? = firstMessage while (current != null && scope.isActive) { @@ -80,7 +92,23 @@ class ExternalAcpAgentService(private val project: Project) { ?: error("Missing ACP session id for ${session.sessionId}") try { + if (debugModeEnabled) { + logger.info( + "[${preset.displayName}] prompt/send session=${session.sessionId} externalSessionId=$externalSessionId messageId=${promptMessage.id} preview=${promptMessage.text.logPreview()}" + ) + } + logger.debug( + "Sending ACP prompt for session=${session.sessionId} externalSessionId=$externalSessionId messageId=${promptMessage.id} uiVisible=${promptMessage.uiVisible} preview=${promptMessage.text.logPreview()}" + ) sendPrompt(state, externalSessionId, promptMessage) + if (debugModeEnabled) { + logger.info( + "[${preset.displayName}] prompt/sent session=${session.sessionId} externalSessionId=$externalSessionId messageId=${promptMessage.id}" + ) + } + logger.debug( + "ACP prompt completed for session=${session.sessionId} externalSessionId=$externalSessionId messageId=${promptMessage.id}" + ) } catch (cancelled: CancellationException) { cancelSession(state, externalSessionId) throw cancelled @@ -88,11 +116,18 @@ class ExternalAcpAgentService(private val project: Project) { val nextMessage = pollNextQueued() if (nextMessage == null) { + if (debugModeEnabled) { + logger.info("[${preset.displayName}] run/complete session=${session.sessionId}") + } + logger.debug("ACP run finished with no queued follow-up for session=${session.sessionId}") events.onAgentCompleted(preset.displayName) return } - events.onQueuedMessagesResolved() + logger.debug( + "Promoting queued ACP message for session=${session.sessionId} nextMessageId=${nextMessage.id} uiVisible=${nextMessage.uiVisible} preview=${nextMessage.text.logPreview()}" + ) + events.onQueuedMessagesResolved(nextMessage) current = nextMessage delay(50) } @@ -103,7 +138,7 @@ class ExternalAcpAgentService(private val project: Project) { sessionSetupMutexes.remove(sessionId) } - suspend fun cancelSession(sessionId: String, externalSessionId: String?) { + fun cancelSession(sessionId: String, externalSessionId: String?) { val state = states[sessionId] ?: return val activeSessionId = externalSessionId ?: return cancelSession(state, activeSessionId) @@ -142,28 +177,40 @@ class ExternalAcpAgentService(private val project: Project) { val state = ensureSessionReady(session, preset, NO_OP_EVENTS) val externalSessionId = session.externalAgentSessionId ?: error("Missing ACP session id for ${session.sessionId}") - val result = sessionConfigAdapter.updateOption( - request = AcpConfigUpdateRequest( - sessionId = externalSessionId, - optionId = optionId, - value = value - ), - support = state.configUpdateSupport, - sendRequest = state::sendRequest - ) - when (result) { - AcpConfigUpdateResult.Unsupported -> { - state.configUpdateSupport = AcpConfigUpdateSupport.Unsupported - throw IllegalStateException("${preset.displayName} does not support runtime option changes") - } + when (optionId) { + AcpConfigCategories.MODE -> state.setMode( + SessionId(externalSessionId), + SessionModeId(value) + ) - is AcpConfigUpdateResult.Applied -> { - state.configUpdateSupport = result.support - session.externalAgentConfigSelections = - session.externalAgentConfigSelections + (optionId to value) - updateSessionConfigOptions(session, result.response) + AcpConfigCategories.MODEL -> state.setModel( + SessionId(externalSessionId), + ModelId(value) + ) + + else -> { + val option = session.externalAgentConfigOptions.firstOrNull { it.id == optionId } + ?: error("Unknown ACP runtime option '$optionId'") + mergeSessionConfigOptions( + session, + state.setConfigOption( + SessionId(externalSessionId), + option, + value + ) + ) } } + if (optionId == AcpConfigCategories.MODE || optionId == AcpConfigCategories.MODEL) { + session.externalAgentConfigOptions = + session.externalAgentConfigOptions.updateCurrentValue(optionId, value) + } + session.externalAgentConfigSelections = AcpConfigOptions.normalizeSelections( + session.externalAgentConfigOptions, + session.externalAgentConfigSelections + + (optionId to value) + + session.externalAgentConfigOptions.currentSelections() + ) } private suspend fun ensureSessionReady( @@ -194,14 +241,28 @@ class ExternalAcpAgentService(private val project: Project) { } existing?.close() - val resolvedCommand = AcpProcessHelper.resolveCommand( + val resolvedCommand = CommandRuntimeHelper.resolveCommand( command = preset.command, extraEnvironment = preset.env ) ?: throw IllegalStateException( - AcpProcessHelper.getCommandNotFoundMessage(preset.command) + buildString { + append("Command '${preset.command}' not found. ") + when (preset.command) { + "npx", "node" -> { + append("Node.js/npm is required for this ACP runtime. ") + append("Ensure it is installed and available to the IDE process. ") + append("You can also point the runtime to an absolute executable path.") + } + + else -> { + append("Ensure it is installed and available to the IDE process. ") + append("You can also point the runtime to an absolute executable path.") + } + } + } ) - val enhancedEnv = AcpProcessHelper.createEnvironment( + val enhancedEnv = CommandRuntimeHelper.createEnvironment( extraEnvironment = preset.env, resolvedCommand = resolvedCommand ) @@ -220,65 +281,58 @@ class ExternalAcpAgentService(private val project: Project) { val state = AcpProcessState( proxySessionId = session.sessionId, preset = preset, + launchEnv = enhancedEnv, process = process, events = events ) states[session.sessionId] = state - state.startReader() state.startStderrLogger() initialize(state) return state } private suspend fun initialize(state: AcpProcessState) { - val response = state.sendRequest( - method = "initialize", - params = buildJsonObject { - put("protocolVersion", PROTOCOL_VERSION) - putJsonObject("clientCapabilities") { - putJsonObject("fs") { - put("readTextFile", true) - put("writeTextFile", true) - } - } - putJsonObject("clientInfo") { - put("name", "proxyai") - put("title", "ProxyAI") - put("version", "dev") - } - } + val response = state.initialize( + ClientInfo( + protocolVersion = PROTOCOL_VERSION, + capabilities = state.clientCapabilities(), + implementation = Implementation( + name = "ProxyAI", + version = "unknown", + title = "ProxyAI" + ) + ) ) - state.authMethodIds = - response["authMethods"]?.jsonArray?.mapNotNull { it.jsonObject.string("id") }.orEmpty() + state.authMethodIds = response.authMethods.map(AuthMethod::id) + currentSession(state.proxySessionId)?.externalAgentPeerProfileId = state.peerProfile.profileId } private suspend fun createSession(state: AcpProcessState, session: AgentSession): String { return try { + val selectedMcpServerIds = selectedMcpServerIds(session.sessionId) + val mcpServers = buildMcpServers(selectedMcpServerIds) val response = runCatching { - state.sendRequest( - method = "session/new", - params = buildJsonObject { - put("cwd", project.basePath ?: System.getProperty("user.dir")) - put("mcpServers", buildMcpServers(session.sessionId)) - } + state.createSession( + cwd = project.basePath ?: System.getProperty("user.dir"), + mcpServers = mcpServers, + requestMeta = session.externalAgentRequestMeta ) }.recoverCatching { ex -> if (ex.isAuthenticationRequiredError() && state.authMethodIds.isNotEmpty()) { authenticate(state, state.authMethodIds.first()) - state.sendRequest( - method = "session/new", - params = buildJsonObject { - put("cwd", project.basePath ?: System.getProperty("user.dir")) - put("mcpServers", buildMcpServers(session.sessionId)) - } + state.createSession( + cwd = project.basePath ?: System.getProperty("user.dir"), + mcpServers = mcpServers, + requestMeta = session.externalAgentRequestMeta ) } else { throw ex } }.getOrThrow() - updateSessionConfigOptions(session, response) - val externalSessionId = response["sessionId"]?.jsonPrimitive?.content - ?: error("ACP agent did not return a sessionId") + applyRuntimeState(session, response.toRuntimeState()) + session.externalAgentMcpServerIds = selectedMcpServerIds + session.externalAgentPeerProfileId = state.peerProfile.profileId + val externalSessionId = response.sessionId.value applyConfiguredSelections(state, session, externalSessionId) externalSessionId } finally { @@ -286,23 +340,52 @@ class ExternalAcpAgentService(private val project: Project) { } } - private suspend fun authenticate(state: AcpProcessState, methodId: String) { - state.sendRequest( - method = "authenticate", - params = buildJsonObject { - put("methodId", methodId) - } - ) + private suspend fun authenticate(state: AcpProcessState, methodId: AuthMethodId) { + state.authenticate(methodId) } - private fun updateSessionConfigOptions(session: AgentSession, response: JsonObject) { - session.externalAgentConfigOptions = sessionConfigAdapter.merge( - existing = session.externalAgentConfigOptions, - response = response + private fun currentSession(proxySessionId: String): AgentSession? { + return project.service().getSession(proxySessionId) + } + + private fun applyRuntimeState( + session: AgentSession, + runtimeState: AcpRuntimeState + ) { + session.externalAgentConfigOptions = buildConfigOptions( + modes = runtimeState.modes, + models = runtimeState.models, + configOptions = runtimeState.configOptions ) + session.externalAgentConfigSelections = AcpConfigOptions.normalizeSelections( + session.externalAgentConfigOptions, + session.externalAgentConfigSelections + session.externalAgentConfigOptions.currentSelections() + ) + session.externalAgentAvailableCommands = runtimeState.availableCommands + session.externalAgentVendorMeta = runtimeState.vendorMeta + runtimeState.sessionTitle + ?.takeIf(String::isNotBlank) + ?.let { session.externalAgentSessionTitle = it } session.externalAgentConfigLoading = false } + private fun mergeSessionConfigOptions( + session: AgentSession, + configOptions: List + ) { + val standard = session.externalAgentConfigOptions.filter { + it.id == AcpConfigCategories.MODE || it.id == AcpConfigCategories.MODEL + } + session.externalAgentConfigOptions = (standard + configOptions.toAcpConfigOptions()) + .associateByTo(linkedMapOf(), AcpConfigOption::id) + .values + .toList() + session.externalAgentConfigSelections = AcpConfigOptions.normalizeSelections( + session.externalAgentConfigOptions, + session.externalAgentConfigSelections + session.externalAgentConfigOptions.currentSelections() + ) + } + private suspend fun applyConfiguredSelections( state: AcpProcessState, session: AgentSession, @@ -333,8 +416,10 @@ class ExternalAcpAgentService(private val project: Project) { addAll(remaining.entries.map { it.key to it.value }) } + val sessionId = SessionId(externalSessionId) for ((optionId, value) in orderedSelections) { - val option = session.externalAgentConfigOptions.firstOrNull { it.id == optionId } ?: continue + val option = + session.externalAgentConfigOptions.firstOrNull { it.id == optionId } ?: continue if (option.currentValue == value) { continue } @@ -343,33 +428,30 @@ class ExternalAcpAgentService(private val project: Project) { continue } - when (val result = runCatching { - sessionConfigAdapter.updateOption( - request = AcpConfigUpdateRequest( - sessionId = externalSessionId, - optionId = optionId, - value = value - ), - support = state.configUpdateSupport, - sendRequest = state::sendRequest + runCatching { + when (optionId) { + AcpConfigCategories.MODE -> state.setMode(sessionId, SessionModeId(value)) + AcpConfigCategories.MODEL -> state.setModel(sessionId, ModelId(value)) + else -> mergeSessionConfigOptions( + session, + state.setConfigOption(sessionId, option, value) + ) + } + if (optionId == AcpConfigCategories.MODE || optionId == AcpConfigCategories.MODEL) { + session.externalAgentConfigOptions = + session.externalAgentConfigOptions.updateCurrentValue(optionId, value) + } + session.externalAgentConfigSelections = AcpConfigOptions.normalizeSelections( + session.externalAgentConfigOptions, + session.externalAgentConfigSelections + + (optionId to value) + + session.externalAgentConfigOptions.currentSelections() ) - }.getOrElse { error -> + }.onFailure { error -> logger.warn( "Failed to apply ACP subagent option $optionId=$value for ${session.externalAgentId}", error ) - continue - }) { - AcpConfigUpdateResult.Unsupported -> { - state.configUpdateSupport = AcpConfigUpdateSupport.Unsupported - logger.warn("ACP agent ${session.externalAgentId} does not support runtime option changes") - return - } - - is AcpConfigUpdateResult.Applied -> { - state.configUpdateSupport = result.support - updateSessionConfigOptions(session, result.response) - } } } } @@ -379,81 +461,59 @@ class ExternalAcpAgentService(private val project: Project) { externalSessionId: String, message: MessageWithContext ) { - state.sendRequest( - method = "session/prompt", - params = buildJsonObject { - put("sessionId", externalSessionId) - put("prompt", buildPromptBlocks(message)) - } + state.sendPrompt( + sessionId = SessionId(externalSessionId), + prompt = buildPromptBlocks(message) ) } - private suspend fun cancelSession(state: AcpProcessState, externalSessionId: String) { + private fun cancelSession(state: AcpProcessState, externalSessionId: String) { runCatching { - state.sendNotification( - method = "session/cancel", - params = buildJsonObject { - put("sessionId", externalSessionId) - } - ) + state.cancel(SessionId(externalSessionId)) }.onFailure { logger.debug("Failed to cancel ACP session $externalSessionId", it) } } - private fun buildMcpServers(proxySessionId: String): JsonArray { - val selectedServerIds = - project.service() - .get(proxySessionId) - ?.selectedServerIds - .orEmpty() + private fun selectedMcpServerIds(proxySessionId: String): Set { + return project.service() + .get(proxySessionId) + ?.selectedServerIds + .orEmpty() + } + + private fun buildMcpServers(selectedServerIds: Set): List { if (selectedServerIds.isEmpty()) { - return JsonArray(emptyList()) + return emptyList() } val serversById = project.service().state.servers.associateBy { it.id.toString() } - return JsonArray(selectedServerIds.mapNotNull { serverId -> + return selectedServerIds.mapNotNull { serverId -> val server = serversById[serverId] ?: return@mapNotNull null - buildJsonObject { - put("type", "stdio") - put("name", server.name ?: serverId) - put("command", server.command ?: "npx") - putJsonArray("args") { - server.arguments.forEach { add(JsonPrimitive(it)) } + McpServer.Stdio( + name = server.name ?: serverId, + command = server.command ?: "npx", + args = server.arguments, + env = server.environmentVariables.map { (key, value) -> + EnvVariable(name = key, value = value) } - putJsonArray("env") { - server.environmentVariables.forEach { (key, value) -> - add( - buildJsonObject { - put("name", key) - put("value", value) - } - ) - } - } - } - }) + ) + } } - private fun buildPromptBlocks(message: MessageWithContext): JsonArray { - val blocks = mutableListOf() + private fun buildPromptBlocks(message: MessageWithContext): List { + val blocks = mutableListOf() val selectedTags = message.tags.filter { it.selected } if (selectedTags.isNotEmpty()) { val tagSummary = buildTagSummary(selectedTags) if (tagSummary.isNotBlank()) { - blocks += buildJsonObject { - put("type", "text") - put("text", tagSummary) - } + blocks += ContentBlock.Text(tagSummary) } } - blocks += buildJsonObject { - put("type", "text") - put("text", message.text) - } + blocks += ContentBlock.Text(message.text) selectedTags.forEach { tag -> when (tag) { @@ -463,17 +523,16 @@ class ExternalAcpAgentService(private val project: Project) { } } - return JsonArray(blocks) + return blocks } - private fun resourceLinkBlock(path: String): JsonObject { + private fun resourceLinkBlock(path: String): ContentBlock.ResourceLink { val filePath = Paths.get(path) - return buildJsonObject { - put("type", "resource_link") - put("uri", Paths.get(path).toUri().toString()) - put("name", filePath.fileName?.toString() ?: path) - put("mimeType", "text/plain") - } + return ContentBlock.ResourceLink( + uri = Paths.get(path).toUri().toString(), + name = filePath.fileName?.toString() ?: path, + mimeType = "text/plain" + ) } private fun buildTagSummary(tags: List): String { @@ -502,358 +561,391 @@ class ExternalAcpAgentService(private val project: Project) { }.trim() } + private fun buildConfigOptions( + modes: SessionModeState?, + models: SessionModelState?, + configOptions: List + ): List { + val standard = buildList { + models?.let { state -> + add( + AcpConfigOption( + id = AcpConfigCategories.MODEL, + name = "Model", + category = AcpConfigCategories.MODEL, + type = "select", + currentValue = state.currentModelId.value, + options = state.availableModels.map { model -> + AcpConfigOptionChoice( + value = model.modelId.value, + name = model.name, + description = model.description + ) + } + ) + ) + } + modes?.let { state -> + add( + AcpConfigOption( + id = AcpConfigCategories.MODE, + name = "Mode", + category = AcpConfigCategories.MODE, + type = "select", + currentValue = state.currentModeId.value, + options = state.availableModes.map { mode -> + AcpConfigOptionChoice( + value = mode.id.value, + name = mode.name, + description = mode.description + ) + } + ) + ) + } + } + return (standard + configOptions.toAcpConfigOptions()) + .associateByTo(linkedMapOf(), AcpConfigOption::id) + .values + .toList() + } + private inner class AcpProcessState( val proxySessionId: String, val preset: ExternalAcpAgentPreset, + val launchEnv: Map, val process: Process, @Volatile var events: AgentEvents ) { - private val toolCallsById = ConcurrentHashMap() - private val rpcConnection = AcpJsonRpcConnection( - json = json, - process = process, - scope = scope, - logger = logger, - processName = preset.displayName, - onRequest = ::handleRequest, - onNotification = { notification -> handleNotification(notification, events) } + private val compatibilityRegistry = AcpCompatibilityRegistry() + + @Volatile + var peerProfile: AcpPeerProfile = compatibilityRegistry.initialProfile(preset) + + private val transport = StdioTransport( + parentScope = scope, + ioDispatcher = Dispatchers.IO, + input = process.inputStream.asSource().buffered(), + output = process.outputStream.asSink().buffered(), + name = preset.displayName + ) + private val protocol = AcpProtocol( + scope, + transport, + ProtocolOptions( + protocolDebugName = preset.displayName, + outboundPayloadAugmenter = { methodName, payload -> + compatibilityRegistry.augmentOutboundPayload( + profile = peerProfile, + methodName = methodName, + payload = payload, + sessionRequestMeta = when (methodName.name) { + "session/new", "session/load" -> currentSession()?.externalAgentRequestMeta + else -> null + }, + launchEnv = launchEnv + ) + }, + inboundPayloadNormalizer = { methodName, payload -> + compatibilityRegistry.normalizeInboundPayload( + profile = peerProfile, + methodName = methodName, + payload = payload + ) + }, + trace = ::acpTrace + ) + ) + private val hostCapabilities = AcpHostCapabilities( + fileHost = AcpFileHost(), + terminalHost = AcpTerminalHost(DefaultAcpTerminalProcessLauncher(scope)) + ) + private val hostBridge = AcpHostBridge( + proxySessionId = proxySessionId, + displayName = preset.displayName, + toolEventFlavor = preset.toolEventFlavor, + fullAccessModeId = FULL_ACCESS_MODE_ID, + sessionRoot = sessionRoot, + toolCallDecoder = toolCallDecoder, + hostCapabilities = hostCapabilities, + currentSession = ::currentSession, + eventsProvider = { events }, + trace = ::acpTrace + ) + private val sessionUpdateBridge = AcpSessionUpdateBridge( + proxySessionId = proxySessionId, + toolEventFlavor = preset.toolEventFlavor, + toolCallDecoder = toolCallDecoder, + updateModeSelection = ::updateCurrentMode, + updateConfigOptions = ::updateConfigOptions, + updateAvailableCommands = ::updateAvailableCommands, + updateSessionInfo = ::updateSessionInfo, + trace = ::acpTrace ) @Volatile - var authMethodIds: List = emptyList() + var authMethodIds: List = emptyList() - @Volatile - var configUpdateSupport: AcpConfigUpdateSupport = AcpConfigUpdateSupport.Unknown - - fun isAlive(): Boolean = rpcConnection.isAlive() - - fun startReader() = rpcConnection.startReader() - - fun startStderrLogger() = rpcConnection.startStderrLogger() - - suspend fun sendRequest(method: String, params: JsonObject): JsonObject = - rpcConnection.request(method, params) - - suspend fun sendNotification(method: String, params: JsonObject) = - rpcConnection.notify(method, params) - - private suspend fun handleRequest(request: AcpJsonRpcRequest): JsonElement? { - return when (request.method) { - "session/request_permission", "requestPermission" -> - handleRequestPermission(request.params) - - "fs/read_text_file", "readTextFile" -> handleReadTextFile(request.params) - "fs/write_text_file", "writeTextFile" -> handleWriteTextFile(request.params) - else -> null + init { + hostBridge.register(protocol) + protocol.setNotificationHandler(AcpMethod.ClientMethods.SessionUpdate) { notification -> + sessionUpdateBridge.handle(notification, events) } + protocol.start() } - private fun handleNotification(notification: AcpJsonRpcNotification, events: AgentEvents) { - when (val update = sessionUpdateParser.parse(notification)) { - null -> Unit - is AcpSessionUpdate.TextChunk -> events.onTextReceived(update.text) - is AcpSessionUpdate.ThoughtChunk -> events.onTextReceived("${update.text}") - is AcpSessionUpdate.ToolCall -> handleToolCall(update, events) - is AcpSessionUpdate.ToolCallUpdate -> handleToolCallUpdate(update, events) - is AcpSessionUpdate.ConfigOptionUpdate -> handleConfigOptionUpdate(update.update) - } - } + fun isAlive(): Boolean = process.isAlive - private suspend fun handleRequestPermission(params: JsonObject): JsonObject { - val requestData = toolCallDecoder.decodePermissionRequest(params) - val session = currentSession() - val mode = session.currentAcpMode() - if (mode == FULL_ACCESS_MODE_ID) { - logger.debug( - "Auto-approving ${preset.displayName} ACP permission in mode=$mode for tool=${requestData.toolName} title=${requestData.rawTitle}" - ) - return permissionResponse(selectApprovedPermissionOptionId(requestData.options)) - } + fun clientCapabilities(): ClientCapabilities = hostCapabilities.clientCapabilities() - val request = buildApprovalRequest( - requestData.rawTitle, - requestData.details, - requestData.toolName, - requestData.parsedArgs - ) - val approved = runCatching { events.approveToolCall(request) }.getOrDefault(false) - val selectedOptionId = if (approved) { - selectApprovedPermissionOptionId(requestData.options) - } else { - selectRejectedPermissionOptionId(requestData.options) - } - logger.debug( - "Resolved ${preset.displayName} ACP permission in mode=$mode for tool=${requestData.toolName} approved=$approved title=${requestData.rawTitle}" - ) - return permissionResponse(selectedOptionId) - } - - private fun handleToolCall(update: AcpSessionUpdate.ToolCall, events: AgentEvents) { - val toolCall = update.toolCall - toolCallsById[toolCall.id] = ExternalToolCall(toolCall.toolName, toolCall.args) - if (!shouldDeferToolStart(toolCall.toolName, toolCall.args, update.status)) { - events.onToolStarting(toolCall.id, toolCall.toolName, toolCall.args) - } - - if (update.status?.isTerminal == true) { - completeToolCall( - toolCallId = toolCall.id, - toolName = toolCall.toolName, - args = toolCall.args, - status = update.status, - rawOutput = update.rawOutput, - events = events - ) - } - } - - private fun handleToolCallUpdate( - update: AcpSessionUpdate.ToolCallUpdate, - events: AgentEvents - ) { - val updatedToolCall = update.toolCall - val currentToolCall = toolCallsById[update.toolCallId] - val effectiveToolName = updatedToolCall?.toolName ?: currentToolCall?.toolName ?: "Tool" - val effectiveArgs = updatedToolCall?.args ?: currentToolCall?.args - - if (updatedToolCall != null) { - toolCallsById[update.toolCallId] = ExternalToolCall(effectiveToolName, effectiveArgs) - } - - if (currentToolCall?.args == null && effectiveArgs != null) { - events.onToolStarting(update.toolCallId, effectiveToolName, effectiveArgs) - } - - if (!update.status.isTerminal) { - return - } - - completeToolCall( - toolCallId = update.toolCallId, - toolName = effectiveToolName, - args = effectiveArgs, - status = update.status, - rawOutput = update.rawOutput, - events = events - ) - } - - private fun completeToolCall( - toolCallId: String, - toolName: String, - args: Any?, - status: AcpToolCallStatus, - rawOutput: JsonElement?, - events: AgentEvents - ) { - val result = toolCallDecoder.decodeResult( - toolName = toolName, - args = args, - status = status, - rawOutput = rawOutput - ) - toolCallsById.remove(toolCallId) - events.onToolCompleted(toolCallId, toolName, result) - } - - private fun shouldDeferToolStart( - toolName: String, - args: Any?, - status: AcpToolCallStatus? - ): Boolean { - return toolName in setOf("WebSearch", "WebFetch") && - args == null && - status?.isTerminal != true - } - - private fun handleConfigOptionUpdate(update: JsonObject) { - project.service() - .getSession(proxySessionId) - ?.let { session -> - updateSessionConfigOptions(session, update) - } - } - - private fun handleReadTextFile(params: JsonObject): JsonObject { - val path = requestPath(params) - val raw = Files.readString(path) - val line = params["line"]?.jsonPrimitive?.intOrNull - val limit = params["limit"]?.jsonPrimitive?.intOrNull - val content = if (line != null && limit != null && line > 0 && limit > 0) { - raw.lineSequence() - .drop(line - 1) - .take(limit) - .joinToString("\n") - } else { - raw - } - return buildJsonObject { - put("content", content) - } - } - - private fun handleWriteTextFile(params: JsonObject): JsonElement { - val path = requestPath(params) - val content = params["content"]?.jsonPrimitive?.content ?: "" - if (path.parent != null && path.parent.notExists()) { - path.parent.createDirectories() - } - Files.writeString(path, content) - val virtualFile = - LocalFileSystem.getInstance().refreshAndFindFileByIoFile(path.toFile()) - if (virtualFile != null) { - VfsUtil.markDirtyAndRefresh(false, false, false, virtualFile) - } else { - val parent = path.parent?.toFile() - if (parent != null) { - LocalFileSystem.getInstance().refreshAndFindFileByIoFile(parent) - ?.let { parentVf -> - VfsUtil.markDirtyAndRefresh(false, false, true, parentVf) + fun startStderrLogger() { + scope.launch { + process.errorStream.bufferedReader().useLines { lines -> + lines.forEach { line -> + if (line.isNotBlank()) { + logger.info("[${preset.displayName}] $line") } + } } } - return JsonNull } - private fun selectApprovedPermissionOptionId(options: JsonArray): String { - return selectPermissionOptionId( - options = options, - preferredKinds = listOf("allow_once", "allow_always", "trust"), - fallback = "allow" + suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AcpMethod.AgentMethods.Initialize( + protocol, + InitializeRequest( + clientInfo.protocolVersion, + clientInfo.capabilities, + clientInfo.implementation, + clientInfo._meta + ) + ) + .let { + peerProfile = compatibilityRegistry.resolveProfile( + preset = preset, + agentInfo = AgentInfo( + it.protocolVersion, + it.agentCapabilities, + it.authMethods, + it.agentInfo, + it._meta + ) + ) + AgentInfo( + it.protocolVersion, + it.agentCapabilities, + it.authMethods, + it.agentInfo, + it._meta + ) + } + } + + suspend fun authenticate(methodId: AuthMethodId) { + AcpMethod.AgentMethods.Authenticate(protocol, AuthenticateRequest(methodId)) + } + + suspend fun createSession( + cwd: String, + mcpServers: List, + requestMeta: JsonElement? = null + ): NewSessionResponse { + return AcpMethod.AgentMethods.SessionNew( + protocol, + NewSessionRequest(cwd = cwd, mcpServers = mcpServers, _meta = requestMeta) ) } - private fun selectRejectedPermissionOptionId(options: JsonArray): String { - return selectPermissionOptionId( - options = options, - preferredKinds = listOf("reject_once", "reject_always", "deny", "abort", "cancel"), - fallback = "abort", - predicate = { value -> - value.contains("reject") || value.contains("deny") || - value.contains("abort") || value.contains("cancel") - } + suspend fun sendPrompt(sessionId: SessionId, prompt: List) { + AcpMethod.AgentMethods.SessionPrompt( + protocol, + PromptRequest(sessionId = sessionId, prompt = prompt) ) } - private fun selectPermissionOptionId( - options: JsonArray, - preferredKinds: List, - fallback: String, - predicate: (String) -> Boolean = { false } - ): String { - preferredKinds.forEach { preferred -> - options.firstOrNull { optionMatches(it.jsonObject, preferred) }?.let { - return optionIdOf(it.jsonObject) ?: preferred - } - } - - options.firstOrNull { option -> - optionValues(option.jsonObject).any(predicate) - }?.let { option -> - return optionIdOf(option.jsonObject) ?: fallback - } - - return options.firstOrNull()?.jsonObject?.let(::optionIdOf) ?: fallback + fun cancel(sessionId: SessionId) { + AcpMethod.AgentMethods.SessionCancel(protocol, CancelNotification(sessionId)) } - private fun optionMatches(option: JsonObject, expected: String): Boolean { - return optionValues(option).any { it == expected } - } - - private fun optionValues(option: JsonObject): List { - return listOfNotNull( - option["kind"]?.jsonPrimitive?.contentOrNull?.lowercase(), - option["optionId"]?.jsonPrimitive?.contentOrNull?.lowercase(), - option["id"]?.jsonPrimitive?.contentOrNull?.lowercase() + suspend fun setMode(sessionId: SessionId, modeId: SessionModeId) { + AcpMethod.AgentMethods.SessionSetMode( + protocol, + SetSessionModeRequest(sessionId, modeId) ) } - private fun optionIdOf(option: JsonObject): String? { - return option["optionId"]?.jsonPrimitive?.contentOrNull - ?: option["id"]?.jsonPrimitive?.contentOrNull - } - - private fun buildApprovalRequest( - rawTitle: String, - details: String, - toolName: String, - parsedArgs: Any? - ): ToolApprovalRequest { - val approvalType = when (parsedArgs) { - is WriteTool.Args -> ToolApprovalType.WRITE - is EditTool.Args -> ToolApprovalType.EDIT - is BashTool.Args -> ToolApprovalType.BASH - else -> ToolSpecs.approvalTypeFor(toolName) - } - val payload = approvalPayload(parsedArgs) - val title = rawTitle.ifBlank { - when (approvalType) { - ToolApprovalType.BASH -> "Run shell command?" - ToolApprovalType.WRITE -> "Write file?" - ToolApprovalType.EDIT -> "Edit file?" - ToolApprovalType.GENERIC -> "Allow action?" - } - } - return ToolApprovalRequest( - type = approvalType, - title = title, - details = details, - payload = payload + suspend fun setModel(sessionId: SessionId, modelId: ModelId) { + AcpMethod.AgentMethods.SessionSetModel( + protocol, + SetSessionModelRequest(sessionId, modelId) ) } - private fun approvalPayload(parsedArgs: Any?): ToolApprovalPayload? { - return when (parsedArgs) { - is WriteTool.Args -> WritePayload(parsedArgs.filePath, parsedArgs.content) - is EditTool.Args -> EditPayload( - filePath = parsedArgs.filePath, - oldString = parsedArgs.oldString, - newString = parsedArgs.newString, - replaceAll = parsedArgs.replaceAll + suspend fun setConfigOption( + sessionId: SessionId, + option: AcpConfigOption, + value: String + ): List { + val optionValue = when (option.type) { + "boolean" -> SessionConfigOptionValue.BoolValue( + value.toBooleanStrictOrNull() + ?: error("Invalid boolean ACP option value '$value' for ${option.id}") ) - is BashTool.Args -> BashPayload(parsedArgs.command, parsedArgs.description) - else -> null - } - } - - private fun permissionResponse(optionId: String): JsonObject { - return buildJsonObject { - putJsonObject("outcome") { - put("outcome", "selected") - put("optionId", optionId) - } + else -> SessionConfigOptionValue.StringValue(value) } + return AcpMethod.AgentMethods.SessionSetConfigOption( + protocol, + SetSessionConfigOptionRequest( + sessionId = sessionId, + configId = SessionConfigId(option.id), + value = optionValue + ) + ).configOptions } private fun currentSession(): AgentSession? { return project.service().getSession(proxySessionId) } - fun close() = rpcConnection.close() - } + private fun updateCurrentMode(currentModeId: String) { + currentSession()?.let { session -> + session.externalAgentConfigOptions = + session.externalAgentConfigOptions.updateCurrentValue( + AcpConfigCategories.MODE, + currentModeId + ) + session.externalAgentConfigSelections = AcpConfigOptions.normalizeSelections( + session.externalAgentConfigOptions, + session.externalAgentConfigSelections + session.externalAgentConfigOptions.currentSelections() + ) + } + } - private fun AgentSession?.currentAcpMode(): String? { - return this?.externalAgentConfigOptions - ?.firstOrNull(AcpSessionConfigId.MODE::matches) - ?.currentValue - } + private fun updateConfigOptions(configOptions: List) { + currentSession()?.let { session -> + mergeSessionConfigOptions(session, configOptions) + } + } - private fun requestPath(params: JsonObject): Path { - val rawPath = params["path"]?.jsonPrimitive?.contentOrNull - ?: params["uri"]?.jsonPrimitive?.contentOrNull - ?: error("Missing path") - return uriToPath(rawPath) - } + private fun updateAvailableCommands(availableCommands: List) { + currentSession()?.externalAgentAvailableCommands = availableCommands + } - private fun uriToPath(uri: String): Path { - return when { - uri.startsWith("file://") -> Paths.get(java.net.URI.create(uri)) - else -> Paths.get(uri) + private fun updateSessionInfo(title: String?, updatedAt: String?) { + currentSession()?.let { session -> + if (!title.isNullOrBlank()) { + session.externalAgentSessionTitle = title + } + } + } + + private fun acpTrace(message: String) { + if ( + ConfigurationSettings.getState().debugModeEnabled || + preset.toolEventFlavor == AcpToolEventFlavor.GEMINI_CLI + ) { + logger.info("[ACP TRACE][${preset.displayName}/${peerProfile.profileId}] $message") + } + } + + fun close() { + protocol.close() + process.destroy() } } private fun Throwable.isAuthenticationRequiredError(): Boolean { - return message?.contains("Authentication required", ignoreCase = true) == true + return (this as? JsonRpcException)?.message?.contains( + "Authentication required", + ignoreCase = true + ) == true || + message?.contains("Authentication required", ignoreCase = true) == true + } +} + +internal object AcpConfigCategories { + const val MODEL = "model" + const val MODE = "mode" + const val THOUGHT_LEVEL = "thought_level" +} + +internal fun List.updateCurrentValue( + optionId: String, + value: String +): List { + return map { option -> + if (option.id == optionId) { + option.copy(currentValue = value) + } else { + option + } + } +} + +internal fun List.currentSelections(): Map { + return mapNotNull { option -> + option.currentValue?.takeIf { it.isNotBlank() }?.let { option.id to it } + }.toMap(linkedMapOf()) +} + +private fun List.toAcpConfigOptions(): List { + return map { option -> + when (option) { + is SessionConfigOption.Select -> { + val flattenedOptions = when (val selectOptions = option.options) { + is SessionConfigSelectOptions.Flat -> selectOptions.options.map { choice -> + AcpConfigOptionChoice( + value = choice.value.value, + name = choice.name, + description = choice.description + ) + } + + is SessionConfigSelectOptions.Grouped -> selectOptions.groups.flatMap { group -> + group.options.map { choice -> + AcpConfigOptionChoice( + value = choice.value.value, + name = "${group.name}: ${choice.name}", + description = choice.description + ) + } + } + } + AcpConfigOption( + id = option.id.value, + name = option.name, + description = option.description, + category = option.id.value.toAcpConfigCategory(), + type = "select", + currentValue = option.currentValue.value, + options = flattenedOptions + ) + } + + is SessionConfigOption.BooleanOption -> AcpConfigOption( + id = option.id.value, + name = option.name, + description = option.description, + category = option.id.value.toAcpConfigCategory(), + type = "boolean", + currentValue = option.currentValue.toString(), + options = listOf( + AcpConfigOptionChoice("true", "Enabled"), + AcpConfigOptionChoice("false", "Disabled") + ) + ) + } + } +} + +private fun String.toAcpConfigCategory(): String? { + return when (lowercase()) { + AcpConfigCategories.MODEL -> AcpConfigCategories.MODEL + AcpConfigCategories.MODE -> AcpConfigCategories.MODE + AcpConfigCategories.THOUGHT_LEVEL, + "reasoning", + "reasoning_effort" -> AcpConfigCategories.THOUGHT_LEVEL + else -> null } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpHostBridge.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpHostBridge.kt new file mode 100644 index 00000000..a3f937a5 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpHostBridge.kt @@ -0,0 +1,340 @@ +package ee.carlrobert.codegpt.agent.external + +import com.agentclientprotocol.model.* +import ee.carlrobert.codegpt.agent.AgentEvents +import ee.carlrobert.codegpt.agent.ToolSpecs +import ee.carlrobert.codegpt.agent.external.acpcompat.AcpProtocol +import ee.carlrobert.codegpt.agent.external.acpcompat.acpFail +import ee.carlrobert.codegpt.agent.external.acpcompat.setRequestHandler +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs +import ee.carlrobert.codegpt.agent.external.host.AcpHostCapabilities +import ee.carlrobert.codegpt.agent.external.host.AcpHostSessionContext +import ee.carlrobert.codegpt.toolwindow.agent.AgentSession +import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.* +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CancellationException +import java.nio.file.Path +import kotlin.io.path.name + +internal class AcpHostBridge( + private val proxySessionId: String, + private val displayName: String, + private val toolEventFlavor: AcpToolEventFlavor, + private val fullAccessModeId: String, + private val sessionRoot: Path, + private val toolCallDecoder: AcpToolCallDecoder, + private val hostCapabilities: AcpHostCapabilities, + private val currentSession: () -> AgentSession?, + private val eventsProvider: () -> AgentEvents, + private val trace: (String) -> Unit +) { + private val logger = KotlinLogging.logger {} + + fun register(protocol: AcpProtocol) { + protocol.setRequestHandler(AcpMethod.ClientMethods.SessionRequestPermission) { request -> + handleRequestPermission(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.FsReadTextFile) { request -> + handleReadTextFile(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.FsWriteTextFile) { request -> + handleWriteTextFile(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalCreate) { request -> + handleCreateTerminal(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalOutput) { request -> + handleTerminalOutput(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalRelease) { request -> + handleTerminalRelease(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalWaitForExit) { request -> + handleTerminalWaitForExit(request) + } + protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalKill) { request -> + handleTerminalKill(request) + } + } + + private suspend fun handleRequestPermission( + request: RequestPermissionRequest + ): RequestPermissionResponse { + val permissionRequest = toolCallDecoder.decodePermissionRequest(toolEventFlavor, request) + val toolCall = permissionRequest.toolCall + val mode = currentSession().currentAcpMode() + trace( + "permission/request session=$proxySessionId mode=$mode tool=${toolCall.toolName} title=${toolCall.title.logPreview()} options=${permissionRequest.options.joinToString { it.kind.name }} rawInput=${request.toolCall.rawInput.logSummary()} details=${ + permissionRequest.details.logPreview( + 200 + ) + }" + ) + logger.debug { + "Received $displayName ACP permission request for session=$proxySessionId mode=$mode tool=${toolCall.toolName} title=${toolCall.title.logPreview()} details=${permissionRequest.details.logPreview()} options=${permissionRequest.options.joinToString { it.kind.name }}" + } + if (mode == fullAccessModeId) { + trace("permission/auto-approve session=$proxySessionId mode=$mode tool=${toolCall.toolName}") + logger.debug { + "Auto-approving $displayName ACP permission in mode=$mode for tool=${toolCall.toolName} title=${toolCall.title}" + } + return permissionResponse(selectApprovedPermissionOptionId(permissionRequest.options)) + } + + val approvalRequest = buildApprovalRequest( + rawTitle = toolCall.title, + details = permissionRequest.details, + toolName = toolCall.toolName, + parsedArgs = toolCall.args + ) + logger.debug { + "Queueing UI approval for session=$proxySessionId type=${approvalRequest.type} title=${approvalRequest.title.logPreview()} payload=${approvalRequest.payload.logSummary()}" + } + trace( + "permission/await-ui session=$proxySessionId type=${approvalRequest.type} title=${approvalRequest.title.logPreview()} payload=${approvalRequest.payload.logSummary()}" + ) + return try { + val approved = eventsProvider().approveToolCall(approvalRequest) + val selectedOptionId = if (approved) { + selectApprovedPermissionOptionId(permissionRequest.options) + } else { + selectRejectedPermissionOptionId(permissionRequest.options) + } + trace( + "permission/resolved session=$proxySessionId tool=${toolCall.toolName} approved=$approved selectedOption=${selectedOptionId.value}" + ) + logger.debug { + "Resolved $displayName ACP permission in mode=$mode for tool=${toolCall.toolName} approved=$approved title=${toolCall.title}" + } + permissionResponse(selectedOptionId) + } catch (_: CancellationException) { + trace("permission/cancelled session=$proxySessionId tool=${toolCall.toolName}") + permissionCancelledResponse() + } catch (error: Exception) { + logger.warn(error) { + "Permission approval failed for $displayName tool=${toolCall.toolName}; defaulting to reject" + } + permissionResponse(selectRejectedPermissionOptionId(permissionRequest.options)) + } + } + + private suspend fun handleReadTextFile(params: ReadTextFileRequest): ReadTextFileResponse { + trace("fs/read_text_file session=$proxySessionId path=${params.path} line=${params.line} limit=${params.limit}") + return runHostCall( + invalidRequestMessage = "Invalid read_text_file request", + failureTrace = "fs/read_text_file/failed session=$proxySessionId path=${params.path}" + ) { + hostCapabilities.readTextFile(hostSessionContext(params.sessionId), params) + } + } + + private suspend fun handleWriteTextFile(params: WriteTextFileRequest): WriteTextFileResponse { + trace("fs/write_text_file session=$proxySessionId path=${params.path}") + return runHostCall( + invalidRequestMessage = "Invalid write_text_file request", + failureTrace = "fs/write_text_file/failed session=$proxySessionId path=${params.path}" + ) { + hostCapabilities.writeTextFile(hostSessionContext(params.sessionId), params) + } + } + + private suspend fun handleCreateTerminal(params: CreateTerminalRequest): CreateTerminalResponse { + trace( + "terminal/create session=$proxySessionId requestSession=${params.sessionId.value} command=${params.command.logPreview()} cwd=${params.cwd?.logPreview()}" + ) + return runHostCall( + invalidRequestMessage = "Invalid terminal/create request", + failureTrace = "terminal/create/failed session=$proxySessionId command=${params.command.logPreview()}" + ) { + hostCapabilities.createTerminal(hostSessionContext(params.sessionId), params) + } + } + + private suspend fun handleTerminalOutput(params: TerminalOutputRequest): TerminalOutputResponse { + trace( + "terminal/output session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}" + ) + return runHostCall( + invalidRequestMessage = "Invalid terminal/output request", + failureTrace = "terminal/output/failed session=$proxySessionId terminalId=${params.terminalId}" + ) { + hostCapabilities.output(hostSessionContext(params.sessionId), params) + } + } + + private suspend fun handleTerminalRelease(params: ReleaseTerminalRequest): ReleaseTerminalResponse { + trace( + "terminal/release session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}" + ) + return runHostCall( + invalidRequestMessage = "Invalid terminal/release request", + failureTrace = "terminal/release/failed session=$proxySessionId terminalId=${params.terminalId}" + ) { + hostCapabilities.release(hostSessionContext(params.sessionId), params) + } + } + + private suspend fun handleTerminalWaitForExit(params: WaitForTerminalExitRequest): WaitForTerminalExitResponse { + trace( + "terminal/wait_for_exit session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}" + ) + return runHostCall( + invalidRequestMessage = "Invalid terminal/wait_for_exit request", + failureTrace = "terminal/wait_for_exit/failed session=$proxySessionId terminalId=${params.terminalId}" + ) { + hostCapabilities.waitForExit(hostSessionContext(params.sessionId), params) + } + } + + private suspend fun handleTerminalKill(params: KillTerminalCommandRequest): KillTerminalCommandResponse { + trace( + "terminal/kill session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}" + ) + return runHostCall( + invalidRequestMessage = "Invalid terminal/kill request", + failureTrace = "terminal/kill/failed session=$proxySessionId terminalId=${params.terminalId}" + ) { + hostCapabilities.kill(hostSessionContext(params.sessionId), params) + } + } + + private fun selectApprovedPermissionOptionId(options: List): PermissionOptionId { + return selectPermissionOptionId( + options = options, + preferredKinds = listOf( + PermissionOptionKind.ALLOW_ONCE, + PermissionOptionKind.ALLOW_ALWAYS + ) + ) + } + + private fun selectRejectedPermissionOptionId(options: List): PermissionOptionId { + return selectPermissionOptionId( + options = options, + preferredKinds = listOf( + PermissionOptionKind.REJECT_ONCE, + PermissionOptionKind.REJECT_ALWAYS + ) + ) + } + + private fun selectPermissionOptionId( + options: List, + preferredKinds: List + ): PermissionOptionId { + preferredKinds.forEach { preferred -> + options.firstOrNull { it.kind == preferred }?.let { return it.optionId } + } + return options.firstOrNull()?.optionId ?: PermissionOptionId("abort") + } + + private fun buildApprovalRequest( + rawTitle: String, + details: String, + toolName: String, + parsedArgs: AcpToolCallArgs? + ): ToolApprovalRequest { + val approvalType = when (parsedArgs) { + is AcpToolCallArgs.Write -> ToolApprovalType.WRITE + is AcpToolCallArgs.Edit -> ToolApprovalType.EDIT + is AcpToolCallArgs.Bash -> ToolApprovalType.BASH + else -> ToolSpecs.approvalTypeFor(toolName) + } + val payload = approvalPayload(parsedArgs) + val title = approvalTitle(rawTitle, approvalType, parsedArgs) + val resolvedDetails = approvalDetails(details, parsedArgs) + return ToolApprovalRequest( + type = approvalType, + title = title, + details = resolvedDetails, + payload = payload + ) + } + + private fun approvalPayload(parsedArgs: AcpToolCallArgs?): ToolApprovalPayload? { + return when (parsedArgs) { + is AcpToolCallArgs.Write -> WritePayload( + parsedArgs.value.filePath, + parsedArgs.value.content + ) + + is AcpToolCallArgs.Edit -> EditPayload( + filePath = parsedArgs.value.filePath, + oldString = parsedArgs.value.oldString, + newString = parsedArgs.value.newString, + replaceAll = parsedArgs.value.replaceAll + ) + + is AcpToolCallArgs.Bash -> BashPayload( + parsedArgs.value.command, + parsedArgs.value.description + ) + + else -> null + } + } + + private fun approvalTitle( + rawTitle: String, + approvalType: ToolApprovalType, + parsedArgs: AcpToolCallArgs? + ): String { + return when (parsedArgs) { + is AcpToolCallArgs.Write -> "Write ${Path.of(parsedArgs.value.filePath).name}?" + is AcpToolCallArgs.Edit -> "Edit ${Path.of(parsedArgs.value.filePath).name}?" + else -> rawTitle.ifBlank { + when (approvalType) { + ToolApprovalType.BASH -> "Run shell command?" + ToolApprovalType.WRITE -> "Write file?" + ToolApprovalType.EDIT -> "Edit file?" + ToolApprovalType.GENERIC -> "Allow action?" + } + } + } + } + + private fun approvalDetails(details: String, parsedArgs: AcpToolCallArgs?): String { + return when (parsedArgs) { + is AcpToolCallArgs.Write -> parsedArgs.value.filePath + is AcpToolCallArgs.Edit -> parsedArgs.value.filePath + else -> details + } + } + + private fun permissionResponse(optionId: PermissionOptionId): RequestPermissionResponse { + return RequestPermissionResponse( + outcome = RequestPermissionOutcome.Selected(optionId) + ) + } + + private fun permissionCancelledResponse(): RequestPermissionResponse { + return RequestPermissionResponse( + outcome = RequestPermissionOutcome.Cancelled + ) + } + + private fun hostSessionContext(sessionId: SessionId): AcpHostSessionContext { + return AcpHostSessionContext(sessionId = sessionId.value, cwd = sessionRoot) + } + + private suspend inline fun runHostCall( + invalidRequestMessage: String, + failureTrace: String, + crossinline block: suspend () -> T + ): T { + return try { + block() + } catch (ex: IllegalArgumentException) { + acpFail(ex.message ?: invalidRequestMessage) + } catch (ex: Exception) { + trace("$failureTrace error=${(ex.message ?: ex::class.simpleName.orEmpty()).logPreview()}") + throw ex + } + } +} + +private fun AgentSession?.currentAcpMode(): String? { + return this?.externalAgentConfigOptions + ?.firstOrNull { it.id == AcpConfigCategories.MODE || it.category == AcpConfigCategories.MODE } + ?.currentValue +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJson.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJson.kt index 93a0309c..5861589b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJson.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJson.kt @@ -1,81 +1,22 @@ package ee.carlrobert.codegpt.agent.external -import kotlinx.serialization.json.* +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.decodeFromJsonElement -internal fun JsonObject.string(vararg keys: String): String? { - return keys.firstNotNullOfOrNull { key -> - when (val element = this[key]) { - is JsonPrimitive -> element.contentOrNull?.takeIf { it.isNotBlank() } - else -> null - } - } -} - -internal fun JsonObject.commandString(): String? { - return listOf("command", "cmd").firstNotNullOfOrNull { key -> - when (val element = this[key]) { - is JsonPrimitive -> element.contentOrNull?.takeIf { it.isNotBlank() } - is JsonArray -> element.mapNotNull { item -> - (item as? JsonPrimitive)?.contentOrNull - }.takeIf { it.isNotEmpty() }?.joinToString(" ") - - else -> null - } - } -} - -internal fun JsonObject.boolean(vararg keys: String): Boolean? { - return keys.firstNotNullOfOrNull { key -> - (this[key] as? JsonPrimitive)?.booleanOrNull - } -} - -internal fun JsonObject.int(vararg keys: String): Int? { - return keys.firstNotNullOfOrNull { key -> - (this[key] as? JsonPrimitive)?.intOrNull - } -} - -internal fun JsonObject.firstLocationPath(): String? { - return (this["locations"] as? JsonArray) - ?.firstNotNullOfOrNull { (it as? JsonObject)?.string("path") } -} - -internal fun JsonObject.titlePath(): String? { - val title = string("title") ?: return null - val firstSpaceIndex = title.indexOf(' ') - if (firstSpaceIndex < 0 || firstSpaceIndex >= title.lastIndex) { - return null - } - val path = title.substring(firstSpaceIndex + 1).trim() - return path.takeIf { it.startsWith("/") } -} - -internal fun JsonObject.firstChangePath(): String? { - return (this["changes"] as? JsonObject)?.keys?.firstOrNull() -} - -internal fun JsonObject.firstChangeContent(): String? { - return (this["changes"] as? JsonObject)?.values?.firstNotNullOfOrNull { change -> - (change as? JsonObject)?.string("content") - } -} - -internal fun JsonElement?.asJsonArrayOrEmpty(): JsonArray { - return this as? JsonArray ?: JsonArray(emptyList()) -} - -internal fun JsonElement?.asJsonObjectOrNull(json: Json): JsonObject? { +internal inline fun JsonElement?.decodeOrNull(json: Json): T? { return when (this) { - is JsonObject -> this + null -> null is JsonPrimitive -> { if (!isString) { - return null + runCatching { json.decodeFromJsonElement(this) }.getOrNull() + } else { + runCatching { json.decodeFromString(content) }.getOrNull() } - runCatching { json.parseToJsonElement(content) }.getOrNull() as? JsonObject } - else -> null + else -> runCatching { json.decodeFromJsonElement(this) }.getOrNull() } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJsonRpcConnection.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJsonRpcConnection.kt deleted file mode 100644 index c5a44c90..00000000 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpJsonRpcConnection.kt +++ /dev/null @@ -1,257 +0,0 @@ -package ee.carlrobert.codegpt.agent.external - -import com.intellij.openapi.components.service -import com.intellij.openapi.diagnostic.Logger -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings -import kotlinx.coroutines.* -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock -import kotlinx.serialization.json.* -import java.nio.charset.StandardCharsets -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicLong - -internal data class AcpJsonRpcRequest( - val id: JsonElement, - val method: String, - val params: JsonObject -) - -internal data class AcpJsonRpcNotification( - val method: String, - val params: JsonObject -) - -internal data class AcpJsonRpcError( - val code: Int, - val message: String, - val data: JsonElement? = null -) { - companion object { - const val METHOD_NOT_FOUND_CODE = -32601 - } -} - -internal class AcpJsonRpcException( - val error: AcpJsonRpcError -) : IllegalStateException(error.message) - -private sealed interface AcpJsonRpcIncomingMessage { - data class Response( - val id: String, - val result: JsonElement, - val error: AcpJsonRpcError? - ) : AcpJsonRpcIncomingMessage - - data class Request(val value: AcpJsonRpcRequest) : AcpJsonRpcIncomingMessage - - data class Notification(val value: AcpJsonRpcNotification) : AcpJsonRpcIncomingMessage -} - -internal class AcpJsonRpcConnection( - private val json: Json, - private val process: Process, - private val scope: CoroutineScope, - private val logger: Logger, - private val processName: String, - private val onRequest: suspend (AcpJsonRpcRequest) -> JsonElement?, - private val onNotification: suspend (AcpJsonRpcNotification) -> Unit -) { - - private val requestCounter = AtomicLong(0) - private val pendingResponses = ConcurrentHashMap>() - private val writeMutex = Mutex() - - fun isAlive(): Boolean = process.isAlive - - fun startReader() { - scope.launch { - val reader = process.inputStream.bufferedReader(StandardCharsets.UTF_8) - try { - while (isActive && process.isAlive) { - val line = reader.readLine() ?: break - if (line.isBlank()) continue - if (service().state.debugModeEnabled) { - logger.info("[$processName] $line") - } - - when (val message = parseIncomingMessage(line)) { - null -> Unit - is AcpJsonRpcIncomingMessage.Response -> handleResponse(message) - is AcpJsonRpcIncomingMessage.Request -> reply(message.value) - is AcpJsonRpcIncomingMessage.Notification -> onNotification(message.value) - } - } - } catch (cancelled: CancellationException) { - throw cancelled - } catch (t: Throwable) { - logger.warn("ACP reader loop failed for $processName", t) - } finally { - closePendingResponses(IllegalStateException("$processName ACP process exited")) - } - } - } - - fun startStderrLogger() { - scope.launch { - process.errorStream.bufferedReader(StandardCharsets.UTF_8).useLines { lines -> - lines.forEach { line -> - if (line.isNotBlank()) { - logger.info("[$processName] $line") - } - } - } - } - } - - suspend fun request(method: String, params: JsonObject): JsonObject { - val id = requestCounter.incrementAndGet().toString() - val response = CompletableDeferred() - pendingResponses[id] = response - writePayload(requestPayload(id, method, params)) - return response.await().jsonObject - } - - suspend fun notify(method: String, params: JsonObject) { - writePayload(notificationPayload(method, params)) - } - - fun close() { - closePendingResponses(CancellationException("ACP session closed")) - process.destroy() - } - - private suspend fun reply(request: AcpJsonRpcRequest) { - val response = runCatching { onRequest(request) }.fold( - onSuccess = { body -> - if (body == null) { - errorResponse(request.id, -32601, "Method not found: ${request.method}") - } else { - successResponse(request.id, body) - } - }, - onFailure = { error -> - errorResponse(request.id, -32603, error.message ?: "Internal error") - } - ) - writePayload(response) - } - - private fun handleResponse(response: AcpJsonRpcIncomingMessage.Response) { - val pending = pendingResponses.remove(response.id) ?: return - if (response.error != null) { - pending.completeExceptionally(AcpJsonRpcException(response.error)) - return - } - pending.complete(response.result) - } - - private fun parseIncomingMessage(line: String): AcpJsonRpcIncomingMessage? { - val element = runCatching { json.parseToJsonElement(line) } - .onFailure { logger.warn("Ignoring non-JSON ACP output from $processName: $line") } - .getOrNull() ?: return null - val obj = element.jsonObject - return when { - obj["result"] != null || obj["error"] != null -> { - val responseId = obj["id"]?.jsonPrimitive?.content ?: return null - AcpJsonRpcIncomingMessage.Response( - id = responseId, - result = obj["result"] ?: JsonObject(emptyMap()), - error = parseError(obj["error"]) - ) - } - - obj["method"] != null && obj["id"] != null -> { - val method = obj["method"]?.jsonPrimitive?.content ?: return null - AcpJsonRpcIncomingMessage.Request( - AcpJsonRpcRequest( - id = obj["id"] ?: return null, - method = method, - params = obj["params"]?.jsonObject ?: JsonObject(emptyMap()) - ) - ) - } - - obj["method"] != null -> { - val method = obj["method"]?.jsonPrimitive?.content ?: return null - AcpJsonRpcIncomingMessage.Notification( - AcpJsonRpcNotification( - method = method, - params = obj["params"]?.jsonObject ?: JsonObject(emptyMap()) - ) - ) - } - - else -> null - } - } - - private fun parseError(element: JsonElement?): AcpJsonRpcError? { - val error = element as? JsonObject ?: return null - return AcpJsonRpcError( - code = error.int("code") ?: 0, - message = error.string("message") ?: "Unknown JSON-RPC error", - data = error["data"] - ) - } - - private suspend fun writePayload(payload: JsonObject) { - writeMutex.withLock { - val serializedPayload = json.encodeToString(JsonObject.serializer(), payload) - if (service().state.debugModeEnabled) { - logger.info("[$processName] $serializedPayload") - } - - process.outputStream.write( - (serializedPayload + "\n").toByteArray(StandardCharsets.UTF_8) - ) - process.outputStream.flush() - } - } - - private fun closePendingResponses(error: Throwable) { - pendingResponses.values.forEach { pending -> - pending.completeExceptionally(error) - } - pendingResponses.clear() - } - - private fun requestPayload(id: String, method: String, params: JsonObject): JsonObject { - return buildJsonObject { - put("jsonrpc", JsonPrimitive("2.0")) - put("id", JsonPrimitive(id)) - put("method", JsonPrimitive(method)) - put("params", params) - } - } - - private fun notificationPayload(method: String, params: JsonObject): JsonObject { - return buildJsonObject { - put("jsonrpc", JsonPrimitive("2.0")) - put("method", JsonPrimitive(method)) - put("params", params) - } - } - - private fun successResponse(id: JsonElement, result: JsonElement): JsonObject { - return buildJsonObject { - put("jsonrpc", JsonPrimitive("2.0")) - put("id", id) - put("result", result) - } - } - - private fun errorResponse(id: JsonElement, code: Int, message: String): JsonObject { - return buildJsonObject { - put("jsonrpc", JsonPrimitive("2.0")) - put("id", id) - put( - "error", - buildJsonObject { - put("code", JsonPrimitive(code)) - put("message", JsonPrimitive(message)) - } - ) - } - } -} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpLogging.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpLogging.kt new file mode 100644 index 00000000..b5cc83e3 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpLogging.kt @@ -0,0 +1,59 @@ +package ee.carlrobert.codegpt.agent.external + +import com.agentclientprotocol.model.ContentBlock +import com.agentclientprotocol.model.SessionUpdate +import com.agentclientprotocol.model.ToolCallStatus +import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload +import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.EditPayload +import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalPayload +import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.WritePayload +import kotlinx.serialization.json.JsonElement + +internal val ToolCallStatus.wireValue: String + get() = when (this) { + ToolCallStatus.PENDING -> "pending" + ToolCallStatus.IN_PROGRESS -> "in_progress" + ToolCallStatus.COMPLETED -> "completed" + ToolCallStatus.FAILED -> "failed" + } + +internal fun String.logPreview(limit: Int = 120): String { + return replace('\n', ' ') + .replace(Regex("\\s+"), " ") + .trim() + .take(limit) +} + +internal fun ToolApprovalPayload?.logSummary(): String { + return when (this) { + is WritePayload -> "write:${filePath.logPreview(80)}" + is EditPayload -> "edit:${filePath.logPreview(80)}" + is BashPayload -> "bash:${command.logPreview(80)}" + null -> "none" + } +} + +internal fun JsonElement?.logSummary(limit: Int = 400): String { + return this?.toString()?.replace(Regex("\\s+"), " ")?.take(limit) ?: "null" +} + +internal fun Any?.logSummary(limit: Int = 400): String { + return this?.toString()?.replace(Regex("\\s+"), " ")?.take(limit) ?: "null" +} + +internal fun SessionUpdate.logSummary(): String { + return when (this) { + is SessionUpdate.ToolCall -> "toolCallId=${toolCallId.value} title=${title.logPreview()} kind=${kind?.name} status=${status?.wireValue} rawInput=${rawInput.logSummary()}" + is SessionUpdate.ToolCallUpdate -> "toolCallId=${toolCallId.value} title=${title?.logPreview()} kind=${kind?.name} status=${status?.wireValue} rawInput=${rawInput.logSummary()} rawOutput=${rawOutput.logSummary()}" + is SessionUpdate.AgentMessageChunk -> "message=${(content as? ContentBlock.Text)?.text?.logPreview()}" + is SessionUpdate.AgentThoughtChunk -> "thought=${(content as? ContentBlock.Text)?.text?.logPreview()}" + is SessionUpdate.CurrentModeUpdate -> "currentMode=${currentModeId.value}" + is SessionUpdate.PlanUpdate -> "planEntries=${entries.size}" + is SessionUpdate.ConfigOptionUpdate -> "configOptions=${configOptions.size}" + is SessionUpdate.SessionInfoUpdate -> "title=${title?.logPreview().orEmpty()} updatedAt=${updatedAt.orEmpty()}" + is SessionUpdate.UsageUpdate -> "used=$used size=$size cost=${cost.logSummary()}" + is SessionUpdate.UnknownSessionUpdate -> "type=$sessionUpdateType raw=${rawJson.logSummary()}" + is SessionUpdate.AvailableCommandsUpdate -> "availableCommands=${availableCommands.size}" + else -> this::class.simpleName.orEmpty() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpProcessHelper.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpProcessHelper.kt deleted file mode 100644 index 3d0d0a75..00000000 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpProcessHelper.kt +++ /dev/null @@ -1,41 +0,0 @@ -package ee.carlrobert.codegpt.agent.external - -import ee.carlrobert.codegpt.util.CommandRuntimeHelper - -object AcpProcessHelper { - - fun resolveCommand( - command: String, - extraEnvironment: Map = emptyMap() - ): String? { - return CommandRuntimeHelper.resolveCommand(command, extraEnvironment) - } - - fun createEnvironment( - extraEnvironment: Map, - resolvedCommand: String - ): MutableMap { - return CommandRuntimeHelper.createEnvironment( - extraEnvironment = extraEnvironment, - resolvedCommand = resolvedCommand - ) - } - - fun getCommandNotFoundMessage(command: String): String { - return buildString { - append("Command '$command' not found. ") - when (command) { - "npx", "node" -> { - append("Node.js/npm is required for this ACP runtime. ") - append("Ensure it is installed and available to the IDE process. ") - append("You can also point the runtime to an absolute executable path.") - } - - else -> { - append("Ensure it is installed and available to the IDE process. ") - append("You can also point the runtime to an absolute executable path.") - } - } - } - } -} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpRuntimeState.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpRuntimeState.kt new file mode 100644 index 00000000..9fa1458a --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpRuntimeState.kt @@ -0,0 +1,27 @@ +package ee.carlrobert.codegpt.agent.external + +import com.agentclientprotocol.model.AcpCreatedSessionResponse +import com.agentclientprotocol.model.AvailableCommand +import com.agentclientprotocol.model.SessionConfigOption +import com.agentclientprotocol.model.SessionModelState +import com.agentclientprotocol.model.SessionModeState +import kotlinx.serialization.json.JsonElement + +internal data class AcpRuntimeState( + val modes: SessionModeState? = null, + val models: SessionModelState? = null, + val configOptions: List = emptyList(), + val availableCommands: List = emptyList(), + val sessionTitle: String? = null, + val sessionUpdatedAt: String? = null, + val vendorMeta: JsonElement? = null +) + +internal fun AcpCreatedSessionResponse.toRuntimeState(): AcpRuntimeState { + return AcpRuntimeState( + modes = modes, + models = models, + configOptions = configOptions.orEmpty(), + vendorMeta = _meta + ) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionConfigAdapter.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionConfigAdapter.kt deleted file mode 100644 index 0f4b45b0..00000000 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionConfigAdapter.kt +++ /dev/null @@ -1,266 +0,0 @@ -package ee.carlrobert.codegpt.agent.external - -import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOption -import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOptionChoice -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.* - -internal enum class AcpSessionConfigId( - val value: String, - val displayName: String -) { - MODEL("model", "Model"), - MODE("mode", "Mode"); - - fun matches(option: AcpConfigOption): Boolean { - return option.id == value || option.category == value - } -} - -internal data class AcpConfigUpdateRequest( - val sessionId: String, - val optionId: String, - val value: String -) { - fun toJsonObject(): JsonObject { - return buildJsonObject { - put("sessionId", sessionId) - put("configId", optionId) - put("value", value) - } - } -} - -internal enum class AcpConfigUpdateMethod(val wireName: String) { - SNAKE_CASE("session/set_config_option"), - CAMEL_CASE("session/setConfigOption") -} - -internal sealed interface AcpConfigUpdateSupport { - fun candidateMethods(): List - - data object Unknown : AcpConfigUpdateSupport { - override fun candidateMethods(): List = - AcpConfigUpdateMethod.entries - } - - data object Unsupported : AcpConfigUpdateSupport { - override fun candidateMethods(): List = emptyList() - } - - data class Supported(val method: AcpConfigUpdateMethod) : AcpConfigUpdateSupport { - override fun candidateMethods(): List = listOf(method) - } -} - -internal sealed interface AcpConfigUpdateResult { - data class Applied( - val response: JsonObject, - val support: AcpConfigUpdateSupport.Supported - ) : AcpConfigUpdateResult - - data object Unsupported : AcpConfigUpdateResult -} - -internal class AcpSessionConfigAdapter( - private val json: Json -) { - - fun merge( - existing: List, - response: JsonObject - ): List { - val updates = decode(response) - if (updates.isEmpty()) { - return existing - } - - val merged = existing.associateByTo(linkedMapOf()) { it.id } - updates.forEach { option -> - merged[option.id] = option - } - return merged.values.toList() - } - - suspend fun updateOption( - request: AcpConfigUpdateRequest, - support: AcpConfigUpdateSupport, - sendRequest: suspend (String, JsonObject) -> JsonObject - ): AcpConfigUpdateResult { - val candidateMethods = support.candidateMethods() - if (candidateMethods.isEmpty()) { - return AcpConfigUpdateResult.Unsupported - } - - val params = request.toJsonObject() - candidateMethods.forEach { method -> - try { - return AcpConfigUpdateResult.Applied( - response = sendRequest(method.wireName, params), - support = AcpConfigUpdateSupport.Supported(method) - ) - } catch (error: Throwable) { - if (!error.isMethodNotFoundJsonRpcError()) { - throw error - } - } - } - - return AcpConfigUpdateResult.Unsupported - } - - private fun decode(response: JsonObject): List { - val directOptions = buildList { - addAll( - response.decodeField>("configOptions") - .orEmpty() - .mapNotNull(AcpStandardConfigOptionPayload::toConfigOption) - ) - response.decodeField("configOption") - ?.toConfigOption() - ?.let(::add) - } - - val hasDirectModelOption = directOptions.any(AcpSessionConfigId.MODEL::matches) - val hasDirectModeOption = directOptions.any(AcpSessionConfigId.MODE::matches) - - return buildList { - addAll(directOptions) - if (!hasDirectModelOption) { - response.decodeField("models") - ?.toConfigOption() - ?.let(::add) - } - if (!hasDirectModeOption) { - response.decodeField("modes") - ?.toConfigOption() - ?.let(::add) - } - } - } - - private inline fun JsonObject.decodeField(key: String): T? { - val element = this[key] ?: return null - return runCatching { json.decodeFromJsonElement(element) }.getOrNull() - } -} - -@Serializable -private data class AcpStandardConfigOptionPayload( - val id: String? = null, - val name: String? = null, - val description: String? = null, - val category: String? = null, - val type: String? = null, - val currentValue: String? = null, - val current_value: String? = null, - val value: String? = null, - val options: List = emptyList() -) { - fun toConfigOption(): AcpConfigOption? { - val resolvedId = id.nullIfBlank() ?: return null - return AcpConfigOption( - id = resolvedId, - name = name.nullIfBlank() ?: resolvedId, - description = description.nullIfBlank(), - category = category.nullIfBlank(), - type = type.nullIfBlank(), - currentValue = firstNotBlank(currentValue, current_value, value), - options = options.mapNotNull(AcpConfigChoicePayload::toChoice) - ) - } -} - -@Serializable -private data class AcpConfigChoicePayload( - val value: String? = null, - val id: String? = null, - val name: String? = null, - val description: String? = null -) { - fun toChoice(): AcpConfigOptionChoice? { - val resolvedValue = firstNotBlank(value, id) ?: return null - return AcpConfigOptionChoice( - value = resolvedValue, - name = name.nullIfBlank() ?: resolvedValue, - description = description.nullIfBlank() - ) - } -} - -@Serializable -private data class AcpModelsPayload( - val currentModelId: String? = null, - val availableModels: List = emptyList() -) { - fun toConfigOption(): AcpConfigOption? { - return toAlternativeConfigOption( - id = AcpSessionConfigId.MODEL, - currentValue = currentModelId, - entries = availableModels - ) - } -} - -@Serializable -private data class AcpModesPayload( - val currentModeId: String? = null, - val availableModes: List = emptyList() -) { - fun toConfigOption(): AcpConfigOption? { - return toAlternativeConfigOption( - id = AcpSessionConfigId.MODE, - currentValue = currentModeId, - entries = availableModes - ) - } -} - -@Serializable -private data class AcpAlternativeChoicePayload( - val modelId: String? = null, - val modeId: String? = null, - val value: String? = null, - val id: String? = null, - val name: String? = null, - val description: String? = null -) { - fun toChoice(): AcpConfigOptionChoice? { - val resolvedValue = firstNotBlank(modelId, modeId, value, id) ?: return null - return AcpConfigOptionChoice( - value = resolvedValue, - name = name.nullIfBlank() ?: resolvedValue, - description = description.nullIfBlank() - ) - } -} - -private fun toAlternativeConfigOption( - id: AcpSessionConfigId, - currentValue: String?, - entries: List -): AcpConfigOption? { - val options = entries.mapNotNull(AcpAlternativeChoicePayload::toChoice) - val resolvedCurrentValue = currentValue.nullIfBlank() - if (options.isEmpty() && resolvedCurrentValue == null) { - return null - } - return AcpConfigOption( - id = id.value, - name = id.displayName, - category = id.value, - type = "select", - currentValue = resolvedCurrentValue, - options = options - ) -} - -private fun firstNotBlank(vararg values: String?): String? { - return values.firstNotNullOfOrNull(String?::nullIfBlank) -} - -private fun String?.nullIfBlank(): String? = this?.takeIf { it.isNotBlank() } - -private fun Throwable.isMethodNotFoundJsonRpcError(): Boolean { - return (this as? AcpJsonRpcException)?.error?.code == AcpJsonRpcError.METHOD_NOT_FOUND_CODE -} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateBridge.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateBridge.kt new file mode 100644 index 00000000..613e8840 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateBridge.kt @@ -0,0 +1,175 @@ +package ee.carlrobert.codegpt.agent.external + +import com.agentclientprotocol.model.SessionNotification +import com.agentclientprotocol.model.AvailableCommand +import com.agentclientprotocol.model.SessionConfigOption +import com.agentclientprotocol.model.ToolKind +import ee.carlrobert.codegpt.agent.AgentEvents +import ee.carlrobert.codegpt.agent.AgentUsageEvent +import ee.carlrobert.codegpt.agent.external.events.AcpExternalEvent +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallSnapshot +import java.util.concurrent.ConcurrentHashMap + +internal class AcpSessionUpdateBridge( + private val proxySessionId: String, + private val toolEventFlavor: AcpToolEventFlavor, + private val toolCallDecoder: AcpToolCallDecoder, + private val updateModeSelection: (String) -> Unit, + private val updateConfigOptions: (List) -> Unit, + private val updateAvailableCommands: (List) -> Unit, + private val updateSessionInfo: (String?, String?) -> Unit, + private val trace: (String) -> Unit +) { + private val toolCallsById = ConcurrentHashMap() + + fun handle(notification: SessionNotification, events: AgentEvents) { + trace( + "session/update session=$proxySessionId update=${notification.update::class.simpleName} payload=${notification.update.logSummary()}" + ) + when (val event = + toolCallDecoder.decodeExternalEvent(toolEventFlavor, notification.update)) { + is AcpExternalEvent.TextChunk -> events.onTextReceived(event.text) + is AcpExternalEvent.ThinkingChunk -> events.onThinkingReceived(event.text) + is AcpExternalEvent.PlanUpdate -> events.onPlanUpdated(event.entries) + is AcpExternalEvent.UsageUpdate -> events.onUsageAvailable( + AgentUsageEvent( + usedTokens = event.used, + sizeTokens = event.size.takeIf { it > 0L }, + costAmount = event.cost?.amount, + costCurrency = event.cost?.currency + ) + ) + is AcpExternalEvent.ConfigOptionUpdate -> { + updateConfigOptions(event.configOptions) + events.onRuntimeOptionsUpdated() + } + is AcpExternalEvent.SessionInfoUpdate -> { + updateSessionInfo(event.title, event.updatedAt) + events.onSessionInfoUpdated(event.title, event.updatedAt) + } + is AcpExternalEvent.UnknownSessionUpdate -> trace( + "session/update/unknown session=$proxySessionId type=${event.type} payload=${event.rawJson.logSummary()}" + ) + is AcpExternalEvent.AvailableCommandsUpdate -> updateAvailableCommands(event.availableCommands) + is AcpExternalEvent.CurrentModeUpdate -> { + updateModeSelection(event.currentModeId) + events.onRuntimeOptionsUpdated() + } + is AcpExternalEvent.ToolCallStarted -> handleToolCallStarted(event.toolCall, events) + is AcpExternalEvent.ToolCallUpdated -> handleToolCallUpdated(event.toolCall, events) + else -> Unit + } + } + + private fun handleToolCallStarted( + toolCall: AcpToolCallSnapshot, + events: AgentEvents + ) { + trace( + "tool-call/start session=$proxySessionId id=${toolCall.id} title=${toolCall.title.logPreview()} kind=${toolCall.kind?.name} status=${toolCall.status?.wireValue} rawInput=${toolCall.rawInput.logSummary()}" + ) + if (toolCall.kind == ToolKind.THINK && toolCall.title.isNotBlank()) { + events.onThinkingReceived(toolCall.title) + return + } + toolCallsById[toolCall.id] = toolCall + if (!shouldDeferToolStart(toolCall)) { + events.onToolStarting(toolCall.id, toolCall.toolName, toolCall.args.toUiArgs()) + } + + if (toolCall.status?.isTerminal == true) { + completeToolCall(toolCall, events) + } + } + + private fun handleToolCallUpdated( + toolCall: AcpToolCallSnapshot, + events: AgentEvents + ) { + trace( + "tool-call/update session=$proxySessionId id=${toolCall.id} title=${toolCall.title.logPreview()} kind=${toolCall.kind?.name} status=${toolCall.status?.wireValue} rawInput=${toolCall.rawInput.logSummary()} rawOutput=${toolCall.rawOutput.logSummary()}" + ) + val currentToolCall = toolCallsById[toolCall.id] + val effectiveToolCall = mergeToolCallSnapshots(currentToolCall, toolCall) + val status = effectiveToolCall.status ?: return + + toolCallsById[effectiveToolCall.id] = effectiveToolCall + + if ((currentToolCall?.args == null && effectiveToolCall.args != null) || + (currentToolCall == null && !shouldDeferToolStart(effectiveToolCall)) + ) { + events.onToolStarting( + effectiveToolCall.id, + effectiveToolCall.toolName, + effectiveToolCall.args.toUiArgs() + ) + } + + if (!status.isTerminal) { + return + } + + completeToolCall(effectiveToolCall, events) + } + + private fun completeToolCall( + toolCall: AcpToolCallSnapshot, + events: AgentEvents + ) { + val result = toolCallDecoder.decodeResult(toolCall, toolCall.rawOutput) + trace( + "tool-call/complete session=$proxySessionId id=${toolCall.id} tool=${toolCall.toolName} status=${toolCall.status?.wireValue} result=${result.logSummary()}" + ) + toolCallsById.remove(toolCall.id) + events.onToolCompleted(toolCall.id, toolCall.toolName, result) + } + + private fun shouldDeferToolStart(toolCall: AcpToolCallSnapshot): Boolean { + return toolCall.toolName in setOf("WebSearch", "WebFetch", "Write", "Edit") && + toolCall.args == null && + toolCall.status?.isTerminal != true + } + + private fun AcpToolCallArgs?.toUiArgs(): Any? { + return when (this) { + is AcpToolCallArgs.SearchPreview -> value + is AcpToolCallArgs.BashPreview -> value + is AcpToolCallArgs.Mcp -> value + is AcpToolCallArgs.Read -> value + is AcpToolCallArgs.Write -> value + is AcpToolCallArgs.Edit -> value + is AcpToolCallArgs.Bash -> value + is AcpToolCallArgs.WebSearch -> value + is AcpToolCallArgs.WebFetch -> value + is AcpToolCallArgs.IntelliJSearch -> value + is AcpToolCallArgs.Unknown -> value + null -> null + } + } + + private fun mergeToolCallSnapshots( + currentToolCall: AcpToolCallSnapshot?, + updatedToolCall: AcpToolCallSnapshot + ): AcpToolCallSnapshot { + val effectiveTitle = updatedToolCall.title.ifBlank { currentToolCall?.title.orEmpty() } + .ifBlank { "Tool" } + val effectiveToolName = when { + updatedToolCall.toolName.isNotBlank() && updatedToolCall.toolName != "Tool" -> updatedToolCall.toolName + currentToolCall != null -> currentToolCall.toolName + else -> "Tool" + } + return updatedToolCall.copy( + title = effectiveTitle, + toolName = effectiveToolName, + kind = updatedToolCall.kind ?: currentToolCall?.kind, + status = updatedToolCall.status ?: currentToolCall?.status, + args = updatedToolCall.args ?: currentToolCall?.args, + locations = updatedToolCall.locations.ifEmpty { currentToolCall?.locations.orEmpty() }, + content = updatedToolCall.content.ifEmpty { currentToolCall?.content.orEmpty() }, + meta = updatedToolCall.meta ?: currentToolCall?.meta, + rawInput = updatedToolCall.rawInput ?: currentToolCall?.rawInput, + rawOutput = updatedToolCall.rawOutput ?: currentToolCall?.rawOutput + ) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateParser.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateParser.kt deleted file mode 100644 index d2059994..00000000 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateParser.kt +++ /dev/null @@ -1,122 +0,0 @@ -package ee.carlrobert.codegpt.agent.external - -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonObject - -internal sealed interface AcpSessionUpdate { - data class TextChunk(val text: String) : AcpSessionUpdate - data class ThoughtChunk(val text: String) : AcpSessionUpdate - data class ToolCall( - val toolCall: AcpDecodedToolCall, - val status: AcpToolCallStatus?, - val rawOutput: JsonElement? - ) : AcpSessionUpdate - - data class ToolCallUpdate( - val toolCallId: String, - val toolCall: AcpDecodedToolCall?, - val status: AcpToolCallStatus, - val rawOutput: JsonElement? - ) : AcpSessionUpdate - - data class ConfigOptionUpdate(val update: JsonObject) : AcpSessionUpdate -} - -internal enum class AcpToolCallStatus(val wireValue: String) { - IN_PROGRESS("in_progress"), - COMPLETED("completed"), - FAILED("failed"), - CANCELLED("cancelled"); - - val isTerminal: Boolean - get() = this != IN_PROGRESS - - companion object { - fun fromWireValue(value: String?): AcpToolCallStatus? { - return entries.firstOrNull { it.wireValue == value } - } - } -} - -internal class AcpSessionUpdateParser( - private val toolCallDecoder: AcpToolCallDecoder -) { - - fun parse(notification: AcpJsonRpcNotification): AcpSessionUpdate? { - if (notification.method != SESSION_UPDATE_METHOD) { - return null - } - - val update = notification.params["update"] as? JsonObject ?: return null - return when (SessionUpdateKind.fromWireValue(update.string("sessionUpdate"))) { - SessionUpdateKind.AGENT_MESSAGE_CHUNK -> parseAgentMessageChunk(update) - SessionUpdateKind.TOOL_CALL -> parseToolCall(update) - SessionUpdateKind.TOOL_CALL_UPDATE -> parseToolCallUpdate(update) - SessionUpdateKind.CONFIG_OPTION_UPDATE -> AcpSessionUpdate.ConfigOptionUpdate(update) - null -> null - } - } - - private fun parseAgentMessageChunk(update: JsonObject): AcpSessionUpdate? { - val content = update["content"] as? JsonObject ?: return null - return when (MessageChunkType.fromWireValue(content.string("type"))) { - MessageChunkType.TEXT -> - AcpSessionUpdate.TextChunk(content.string("text").orEmpty()) - - MessageChunkType.THOUGHT -> { - val text = content.string("thought").orEmpty() - text.takeIf { it.isNotBlank() }?.let(AcpSessionUpdate::ThoughtChunk) - } - - null -> null - } - } - - private fun parseToolCall(update: JsonObject): AcpSessionUpdate? { - val toolCall = toolCallDecoder.decodeToolCall(update) ?: return null - return AcpSessionUpdate.ToolCall( - toolCall = toolCall, - status = AcpToolCallStatus.fromWireValue(update.string("status")), - rawOutput = update["rawOutput"] ?: update["content"] - ) - } - - private fun parseToolCallUpdate(update: JsonObject): AcpSessionUpdate? { - val toolCallId = update.string("toolCallId") ?: return null - val status = AcpToolCallStatus.fromWireValue(update.string("status")) ?: return null - return AcpSessionUpdate.ToolCallUpdate( - toolCallId = toolCallId, - toolCall = toolCallDecoder.decodeToolCall(update), - status = status, - rawOutput = update["rawOutput"] ?: update["content"] - ) - } - - private enum class SessionUpdateKind(val wireValue: String) { - AGENT_MESSAGE_CHUNK("agent_message_chunk"), - TOOL_CALL("tool_call"), - TOOL_CALL_UPDATE("tool_call_update"), - CONFIG_OPTION_UPDATE("config_option_update"); - - companion object { - fun fromWireValue(value: String?): SessionUpdateKind? { - return entries.firstOrNull { it.wireValue == value } - } - } - } - - private enum class MessageChunkType(val wireValue: String) { - TEXT("text"), - THOUGHT("thought"); - - companion object { - fun fromWireValue(value: String?): MessageChunkType? { - return entries.firstOrNull { it.wireValue == value } - } - } - } - - private companion object { - const val SESSION_UPDATE_METHOD = "session/update" - } -} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecoder.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecoder.kt index 7a8bab2f..6c9f49d9 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecoder.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecoder.kt @@ -1,68 +1,179 @@ package ee.carlrobert.codegpt.agent.external +import com.agentclientprotocol.model.ContentBlock +import com.agentclientprotocol.model.RequestPermissionRequest +import com.agentclientprotocol.model.SessionUpdate +import com.agentclientprotocol.model.ToolCallContent +import com.agentclientprotocol.model.ToolCallStatus +import com.agentclientprotocol.model.ToolCallLocation +import com.agentclientprotocol.model.ToolKind import ee.carlrobert.codegpt.agent.ToolSpecs -import ee.carlrobert.codegpt.agent.tools.* +import ee.carlrobert.codegpt.agent.tools.EditTool +import ee.carlrobert.codegpt.agent.tools.McpTool +import ee.carlrobert.codegpt.agent.tools.WriteTool +import ee.carlrobert.codegpt.agent.external.events.AcpExternalEvent +import ee.carlrobert.codegpt.agent.external.events.AcpPermissionRequestSnapshot +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallSnapshot +import ee.carlrobert.codegpt.agent.external.events.toAcpToolCallContent import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonObject import java.nio.charset.StandardCharsets -internal data class AcpDecodedToolCall( - val id: String, - val toolName: String, - val args: Any? -) - -internal data class AcpPermissionRequestData( - val rawTitle: String, - val toolName: String, - val parsedArgs: Any?, - val details: String, - val options: JsonArray -) - -private data class DiffContent( - val path: String, - val oldText: String?, - val newText: String -) - -private data class ResolvedToolCall( - val rawTitle: String, - val toolName: String, - val args: Any? -) - internal class AcpToolCallDecoder( private val json: Json ) { + private val support = AcpToolCallDecodingSupport(json) + private val standardNormalizer = StandardSemanticToolCallNormalizer() + private val zedNormalizer = ZedAdapterToolCallNormalizer() + private val geminiNormalizer = GeminiCliToolCallNormalizer() + private val fallbackNormalizer = FallbackToolCallNormalizer() - fun decodeToolCall(metadata: JsonObject): AcpDecodedToolCall? { - val toolCallId = metadata.string("toolCallId") ?: return null - val tool = resolveToolCall(metadata) - return AcpDecodedToolCall( - id = toolCallId, - toolName = tool.toolName, - args = tool.args + fun decodeExternalEvent( + flavor: AcpToolEventFlavor, + update: SessionUpdate + ): AcpExternalEvent? { + return when (update) { + is SessionUpdate.UserMessageChunk -> null + is SessionUpdate.AgentMessageChunk -> { + (update.content as? ContentBlock.Text)?.text?.let { + AcpExternalEvent.TextChunk(it) + } + } + + is SessionUpdate.AgentThoughtChunk -> { + (update.content as? ContentBlock.Text)?.text?.let { + AcpExternalEvent.ThinkingChunk(it) + } + } + + is SessionUpdate.PlanUpdate -> AcpExternalEvent.PlanUpdate(update.entries) + is SessionUpdate.UsageUpdate -> AcpExternalEvent.UsageUpdate( + used = update.used, + size = update.size, + cost = update.cost + ) + is SessionUpdate.ConfigOptionUpdate -> AcpExternalEvent.ConfigOptionUpdate(update.configOptions) + is SessionUpdate.SessionInfoUpdate -> AcpExternalEvent.SessionInfoUpdate( + title = update.title, + updatedAt = update.updatedAt + ) + is SessionUpdate.UnknownSessionUpdate -> AcpExternalEvent.UnknownSessionUpdate( + type = update.sessionUpdateType, + rawJson = update.rawJson + ) + is SessionUpdate.AvailableCommandsUpdate -> AcpExternalEvent.AvailableCommandsUpdate(update.availableCommands) + is SessionUpdate.CurrentModeUpdate -> AcpExternalEvent.CurrentModeUpdate(update.currentModeId.value) + is SessionUpdate.ToolCall -> AcpExternalEvent.ToolCallStarted(decodeToolCallSnapshot(flavor, update)) + is SessionUpdate.ToolCallUpdate -> AcpExternalEvent.ToolCallUpdated(decodeToolCallSnapshot(flavor, update)) + } + } + + fun decodePermissionRequest( + flavor: AcpToolEventFlavor, + request: RequestPermissionRequest + ): AcpPermissionRequestSnapshot { + return AcpPermissionRequestSnapshot( + toolCall = decodeToolCallSnapshot( + flavor = flavor, + toolCallId = request.toolCall.toolCallId.value, + rawTitle = request.toolCall.title ?: "Allow action?", + rawKind = request.toolCall.kind?.wireValue, + rawInput = request.toolCall.rawInput, + locations = request.toolCall.locations.orEmpty(), + content = request.toolCall.content.orEmpty(), + rawMeta = request.toolCall._meta, + defaultTitle = "Allow action?" + ), + details = permissionDetails( + locations = request.toolCall.locations.orEmpty(), + rawInput = request.toolCall.rawInput, + fallback = request.toolCall.toString() + ), + options = request.options ) } - fun decodePermissionRequest(params: JsonObject): AcpPermissionRequestData { - val toolCall = params["toolCall"] as? JsonObject ?: JsonObject(emptyMap()) - val tool = resolveToolCall(toolCall, defaultTitle = "Allow action?") - return AcpPermissionRequestData( - rawTitle = tool.rawTitle, - toolName = tool.toolName, - parsedArgs = tool.args, - details = permissionDetails(toolCall), - options = params["options"].asJsonArrayOrEmpty() + private fun decodeToolCallSnapshot( + flavor: AcpToolEventFlavor, + update: SessionUpdate.ToolCall + ): AcpToolCallSnapshot { + return decodeToolCallSnapshot( + flavor = flavor, + toolCallId = update.toolCallId.value, + rawTitle = update.title, + rawKind = update.kind?.wireValue, + rawInput = update.rawInput, + locations = update.locations, + content = update.content, + rawMeta = update._meta, + rawOutput = update.rawOutput, + status = update.status?.toAcpToolCallStatus() + ) + } + + private fun decodeToolCallSnapshot( + flavor: AcpToolEventFlavor, + update: SessionUpdate.ToolCallUpdate + ): AcpToolCallSnapshot { + return decodeToolCallSnapshot( + flavor = flavor, + toolCallId = update.toolCallId.value, + rawTitle = update.title.orEmpty(), + rawKind = update.kind?.wireValue, + rawInput = update.rawInput, + locations = update.locations.orEmpty(), + content = update.content.orEmpty(), + rawMeta = update._meta, + rawOutput = update.rawOutput, + status = update.status?.toAcpToolCallStatus(), + defaultTitle = "" + ) + } + + private fun decodeToolCallSnapshot( + flavor: AcpToolEventFlavor, + toolCallId: String, + rawTitle: String, + rawKind: String?, + rawInput: JsonElement?, + locations: List, + content: List, + rawMeta: JsonElement? = null, + rawOutput: JsonElement? = null, + status: AcpToolCallStatus? = null, + defaultTitle: String = "Tool" + ): AcpToolCallSnapshot { + val resolved = resolveToolCall( + flavor = flavor, + context = AcpToolCallContext( + toolCallId = toolCallId, + rawTitle = rawTitle, + rawKind = rawKind, + rawInput = rawInput, + locations = locations, + content = content, + defaultTitle = defaultTitle + ) + ) + return AcpToolCallSnapshot( + id = toolCallId, + title = resolved.rawTitle, + toolName = resolved.toolName, + kind = rawKind?.toAcpToolKind(), + status = status, + args = resolved.typedArgs, + locations = locations, + content = content.toAcpToolCallContent(), + meta = rawMeta, + rawInput = rawInput, + rawOutput = rawOutput ) } fun decodeResult( toolName: String, - args: Any?, + args: AcpToolCallArgs?, status: AcpToolCallStatus, rawOutput: JsonElement? ): Any? { @@ -71,325 +182,131 @@ internal class AcpToolCallDecoder( if (status == AcpToolCallStatus.COMPLETED) { when (args) { - is McpTool.Args -> return McpTool.Result( - serverId = args.serverId, - serverName = args.serverName, - toolName = args.toolName, + is AcpToolCallArgs.Mcp -> return McpTool.Result( + serverId = args.value.serverId, + serverName = args.value.serverName, + toolName = args.value.toolName, success = true, output = payload.ifBlank { "MCP tool completed" } ) - is EditTool.Args -> return EditTool.Result.Success( - filePath = args.filePath, + is AcpToolCallArgs.Edit -> return EditTool.Result.Success( + filePath = args.value.filePath, replacementsMade = 1, message = "Edit completed" ) - is WriteTool.Args -> return WriteTool.Result.Success( - filePath = args.filePath, - bytesWritten = args.content.toByteArray(StandardCharsets.UTF_8).size, + is AcpToolCallArgs.Write -> return WriteTool.Result.Success( + filePath = args.value.filePath, + bytesWritten = args.value.content.toByteArray(StandardCharsets.UTF_8).size, isNewFile = false, message = "Write completed" ) + + else -> Unit } } if (status == AcpToolCallStatus.FAILED || status == AcpToolCallStatus.CANCELLED) { val message = payload.ifBlank { "Tool ${status.wireValue}" } when (args) { - is McpTool.Args -> return McpTool.Result.error( - toolName = args.toolName, + is AcpToolCallArgs.Mcp -> return McpTool.Result.error( + toolName = args.value.toolName, output = message, - serverId = args.serverId, - serverName = args.serverName + serverId = args.value.serverId, + serverName = args.value.serverName ) - is EditTool.Args -> return EditTool.Result.Error(args.filePath, message) - is WriteTool.Args -> return WriteTool.Result.Error(args.filePath, message) + is AcpToolCallArgs.Edit -> return EditTool.Result.Error(args.value.filePath, message) + is AcpToolCallArgs.Write -> return WriteTool.Result.Error(args.value.filePath, message) + else -> Unit } } return payload.ifBlank { null } } - private fun resolveToolCall( - metadata: JsonObject, - defaultTitle: String = "Tool" - ): ResolvedToolCall { - val rawTitle = metadata.string("title") ?: defaultTitle - val rawKind = metadata.string("kind") - val rawInput = metadata["rawInput"] - val initialToolName = normalizeToolName(rawTitle, rawKind, rawInput) - val parsedArgs = if (initialToolName == "MCP") { - decodeMcpArgs(rawTitle, rawInput) - } else { - decodeToolArgs(initialToolName, rawInput, metadata) - } - return ResolvedToolCall( - rawTitle = rawTitle, - toolName = resolveToolName(initialToolName, parsedArgs), - args = parsedArgs + fun decodeResult( + toolCall: AcpToolCallSnapshot, + rawOutput: JsonElement? + ): Any? { + return decodeResult( + toolName = toolCall.toolName, + args = toolCall.args, + status = toolCall.status ?: AcpToolCallStatus.COMPLETED, + rawOutput = rawOutput ?: toolCall.rawOutput ) } - private fun permissionDetails(toolCall: JsonObject): String { + private fun resolveToolCall( + flavor: AcpToolEventFlavor, + context: AcpToolCallContext + ): AcpResolvedToolCall { + val vendorNormalizers = when (flavor) { + AcpToolEventFlavor.ZED_ADAPTER -> listOf(zedNormalizer, standardNormalizer) + AcpToolEventFlavor.GEMINI_CLI -> listOf(geminiNormalizer, standardNormalizer) + AcpToolEventFlavor.STANDARD -> listOf(standardNormalizer) + } + return vendorNormalizers + .firstNotNullOfOrNull { it.normalize(context, support) } + ?: fallbackNormalizer.normalize(context, support) + } + + private fun permissionDetails( + locations: List, + rawInput: JsonElement?, + fallback: String + ): String { return buildString { - (toolCall["locations"] as? JsonArray) - ?.mapNotNull { (it as? JsonObject)?.string("path") } - ?.takeIf { it.isNotEmpty() } + locations.map(ToolCallLocation::path) + .takeIf { it.isNotEmpty() } ?.let { paths -> appendLine("Locations:") - paths.forEach { path -> appendLine(path) } + paths.forEach(::appendLine) } - - toolCall["rawInput"]?.let { rawInput -> + rawInput?.let { input -> if (isNotBlank()) { appendLine() } appendLine("Input:") - append(rawInput.toString()) + append(input.toString()) } - }.ifBlank { toolCall.toString() } - } - - private fun normalizeToolName( - rawTitle: String, - kind: String?, - rawInput: JsonElement? - ): String { - val normalizedKind = kind?.lowercase().orEmpty() - val rawInputObject = rawInput.asJsonObjectOrNull(json) - val actionType = (rawInputObject?.get("action") as? JsonObject)?.string("type")?.lowercase() - val titleLower = rawTitle.lowercase() - - return when { - looksLikeMcpToolName(rawTitle) -> "MCP" - rawInputObject?.get("command") != null || rawInputObject?.get("cmd") != null -> "Bash" - normalizedKind == "execute" || normalizedKind == "terminal" || normalizedKind == "bash" -> "Bash" - actionType == "search" -> "WebSearch" - actionType == "open_page" || actionType == "fetch" || actionType == "open" -> "WebFetch" - normalizedKind == "edit" -> "Edit" - normalizedKind == "read" -> "Read" - normalizedKind == "search" -> "IntelliJSearch" - normalizedKind == "fetch" && ( - titleLower == "searching the web" || titleLower.startsWith("searching for:") - ) -> "WebSearch" - normalizedKind == "fetch" || titleLower.startsWith("opening:") -> "WebFetch" - else -> rawTitle.ifBlank { kind ?: "Tool" } - } - } - - private fun resolveToolName(initialToolName: String, args: Any?): String { - return when (args) { - is McpTool.Args -> "MCP" - is WriteTool.Args -> "Write" - is EditTool.Args -> "Edit" - is ReadTool.Args -> "Read" - is IntelliJSearchTool.Args -> "IntelliJSearch" - is BashTool.Args -> "Bash" - is WebSearchTool.Args -> "WebSearch" - is WebFetchTool.Args -> "WebFetch" - else -> initialToolName - } - } - - private fun decodeToolArgs( - toolName: String, - rawInput: JsonElement?, - metadata: JsonObject? = null - ): Any? { - val payload = rawInput.toPayloadString() - ToolSpecs.decodeArgsOrNull(json, toolName, payload)?.let { return it } - - val obj = rawInput.asJsonObjectOrNull(json) ?: JsonObject(emptyMap()) - return when (toolName) { - "Edit" -> decodeEditOrWriteArgs(obj, metadata) ?: payload.ifBlank { null } - "Write" -> decodeWriteArgs(obj, metadata) ?: payload.ifBlank { null } - "Read" -> decodeReadArgs(obj, metadata) ?: payload.ifBlank { null } - "IntelliJSearch" -> decodeSearchArgs(obj) ?: payload.ifBlank { null } - "Bash" -> decodeBashArgs(obj) ?: payload.ifBlank { null } - "WebSearch" -> decodeWebSearchArgs(obj, metadata) ?: payload.ifBlank { null } - "WebFetch" -> decodeWebFetchArgs(obj, rawInput, metadata) ?: payload.ifBlank { null } - else -> payload.ifBlank { null } - } - } - - private fun decodeEditOrWriteArgs(obj: JsonObject, metadata: JsonObject?): Any? { - decodeDiffContent(metadata)?.let { diff -> - return if (diff.oldText == null) { - WriteTool.Args( - filePath = diff.path, - content = diff.newText - ) - } else { - EditTool.Args( - filePath = diff.path, - oldString = diff.oldText, - newString = diff.newText, - shortDescription = metadata?.string("title") ?: "ACP edit", - replaceAll = false - ) - } - } - - decodeWriteArgs(obj, metadata)?.let { return it } - return decodeEditArgs(obj, metadata) - } - - private fun decodeEditArgs(obj: JsonObject, metadata: JsonObject? = null): EditTool.Args? { - val filePath = obj.string("file_path", "filePath", "path") ?: return null - val oldString = obj.string("old_string", "oldString", "old_text", "oldText") ?: return null - val newString = obj.string("new_string", "newString", "new_text", "newText") ?: return null - val shortDescription = obj.string("short_description", "shortDescription", "description") - ?: metadata?.string("title") - ?: "ACP edit" - val replaceAll = obj.boolean("replace_all", "replaceAll") ?: false - return EditTool.Args(filePath, oldString, newString, shortDescription, replaceAll) - } - - private fun decodeWriteArgs( - obj: JsonObject, - metadata: JsonObject? = null - ): WriteTool.Args? { - val filePath = obj.string("file_path", "filePath", "path") - ?: metadata?.firstLocationPath() - ?: metadata?.titlePath() - ?: obj.firstChangePath() - ?: return null - val content = obj.string("content", "text") - ?: obj.firstChangeContent() - ?: decodeDiffContent(metadata)?.takeIf { it.oldText == null }?.newText - ?: return null - return WriteTool.Args(filePath, content) - } - - private fun decodeReadArgs(obj: JsonObject, metadata: JsonObject? = null): ReadTool.Args? { - val filePath = obj.string("file_path", "filePath", "path") - ?: metadata?.firstLocationPath() - ?: metadata?.titlePath() - ?: return null - return ReadTool.Args( - filePath = filePath, - offset = obj.int("offset", "line") ?: metadata?.int("line"), - limit = obj.int("limit", "maxLinesCount") - ) - } - - private fun decodeSearchArgs(obj: JsonObject): IntelliJSearchTool.Args? { - val pattern = obj.string("pattern", "searchText", "query", "nameKeyword") ?: return null - return IntelliJSearchTool.Args( - pattern = pattern, - scope = obj.string("scope"), - path = obj.string("path", "directoryToSearch"), - fileType = obj.string("fileType", "fileMask"), - context = null, - caseSensitive = obj.boolean("caseSensitive"), - regex = obj.boolean("regex"), - wholeWords = null, - outputMode = null, - limit = obj.int("limit", "maxUsageCount", "fileCountLimit") - ) - } - - private fun decodeBashArgs(obj: JsonObject): BashTool.Args? { - val command = obj.commandString() ?: return null - return BashTool.Args( - command = command, - workingDirectory = obj.string("workingDirectory", "workdir", "cwd"), - timeout = obj.int("timeout") ?: 60_000, - description = obj.string("description", "title"), - runInBackground = obj.boolean("run_in_background", "runInBackground") - ) - } - - private fun decodeWebSearchArgs( - obj: JsonObject, - metadata: JsonObject? = null - ): WebSearchTool.Args? { - val action = obj["action"] as? JsonObject - val query = obj.string("query", "q") - ?: action?.string("query") - ?: metadata?.string("query", "q") - ?: return null - return WebSearchTool.Args(query = query) - } - - private fun decodeWebFetchArgs( - obj: JsonObject, - rawInput: JsonElement?, - metadata: JsonObject? = null - ): WebFetchTool.Args? { - val action = obj["action"] as? JsonObject - val payload = rawInput.toPayloadString() - val url = obj.string("url", "uri", "href", "link") - ?: action?.string("url", "uri", "href", "link") - ?: metadata?.string("url", "uri") - ?: extractFirstUrl(payload) - ?: metadata?.string("title")?.let(::extractFirstUrl) - ?: return null - - return WebFetchTool.Args( - url = url, - selector = obj.string("selector", "css_selector", "cssSelector") - ?: action?.string("selector", "css_selector", "cssSelector"), - timeoutMs = obj.int("timeout_ms", "timeoutMs", "timeout") - ?: action?.int("timeout_ms", "timeoutMs", "timeout") - ?: 10_000, - offset = obj.int("offset", "start_line", "startLine") - ?: action?.int("offset", "start_line", "startLine"), - limit = obj.int("limit", "max_lines", "maxLines", "count") - ?: action?.int("limit", "max_lines", "maxLines", "count") - ) - } - - private fun decodeMcpArgs(rawTitle: String, rawInput: JsonElement?): McpTool.Args { - val obj = rawInput.asJsonObjectOrNull(json) ?: JsonObject(emptyMap()) - val callName = rawTitle.ifBlank { obj.string("tool_name", "toolName") ?: "unknown" } - val slashIndex = callName.indexOf('/') - val serverName = if (slashIndex > 0) callName.substring(0, slashIndex) else obj.string( - "server_name", - "serverName" - ) - val toolName = if (slashIndex > 0 && slashIndex < callName.length - 1) { - callName.substring(slashIndex + 1) - } else { - obj.string("tool_name", "toolName") ?: callName.ifBlank { "unknown" } - } - val arguments = (obj["arguments"] as? JsonObject)?.toMap() ?: obj.toMap() - return McpTool.Args( - toolName = toolName, - serverId = obj.string("server_id", "serverId"), - serverName = serverName, - arguments = arguments - ) - } - - private fun decodeDiffContent(metadata: JsonObject?): DiffContent? { - val diff = (metadata?.get("content") as? JsonArray) - ?.firstOrNull { entry -> - (entry as? JsonObject)?.string("type") == "diff" - } as? JsonObject - ?: return null - val path = diff.string("path") ?: return null - val newText = diff.string("newText", "new_text") ?: return null - val oldText = diff.string("oldText", "old_text") - return DiffContent(path, oldText, newText) - } - - private fun looksLikeMcpToolName(rawTitle: String): Boolean { - val candidate = rawTitle.trim() - if (candidate.isBlank() || ' ' in candidate || candidate.startsWith("/")) { - return false - } - val slashIndex = candidate.indexOf('/') - return slashIndex > 0 && slashIndex < candidate.length - 1 - } - - private fun extractFirstUrl(text: String): String? { - return URL_REGEX.find(text)?.value - } - - private companion object { - val URL_REGEX = Regex("""https?://[^\s"'<>]+""") + }.ifBlank { fallback } + } +} + +private val ToolKind.wireValue: String + get() = when (this) { + ToolKind.READ -> "read" + ToolKind.EDIT -> "edit" + ToolKind.EXECUTE -> "execute" + ToolKind.SEARCH -> "search" + ToolKind.FETCH -> "fetch" + else -> name.lowercase() + } + +private fun ToolCallStatus.toAcpToolCallStatus(): AcpToolCallStatus { + return when (this) { + ToolCallStatus.PENDING, + ToolCallStatus.IN_PROGRESS -> AcpToolCallStatus.IN_PROGRESS + + ToolCallStatus.COMPLETED -> AcpToolCallStatus.COMPLETED + ToolCallStatus.FAILED -> AcpToolCallStatus.FAILED + } +} + +private fun String.toAcpToolKind(): ToolKind? { + return when (this.lowercase()) { + "read" -> ToolKind.READ + "edit" -> ToolKind.EDIT + "delete" -> ToolKind.DELETE + "move" -> ToolKind.MOVE + "search" -> ToolKind.SEARCH + "execute", "terminal", "bash" -> ToolKind.EXECUTE + "think" -> ToolKind.THINK + "fetch" -> ToolKind.FETCH + "switch_mode" -> ToolKind.SWITCH_MODE + "other" -> ToolKind.OTHER + else -> null } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecodingSupport.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecodingSupport.kt new file mode 100644 index 00000000..97139284 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecodingSupport.kt @@ -0,0 +1,606 @@ +package ee.carlrobert.codegpt.agent.external + +import com.agentclientprotocol.model.ToolCallContent +import com.agentclientprotocol.model.ToolCallLocation +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs +import ee.carlrobert.codegpt.agent.external.events.AcpBashPreviewArgs +import ee.carlrobert.codegpt.agent.external.events.AcpSearchPreviewArgs +import ee.carlrobert.codegpt.agent.tools.BashTool +import ee.carlrobert.codegpt.agent.tools.EditTool +import ee.carlrobert.codegpt.agent.tools.IntelliJSearchTool +import ee.carlrobert.codegpt.agent.tools.McpTool +import ee.carlrobert.codegpt.agent.tools.ReadTool +import ee.carlrobert.codegpt.agent.tools.WebFetchTool +import ee.carlrobert.codegpt.agent.tools.WebSearchTool +import ee.carlrobert.codegpt.agent.tools.WriteTool +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement + +internal data class AcpToolCallContext( + val toolCallId: String, + val rawTitle: String, + val rawKind: String?, + val rawInput: JsonElement?, + val locations: List, + val content: List, + val defaultTitle: String = "Tool" +) + +internal data class AcpResolvedToolCall( + val rawTitle: String, + val toolName: String, + val typedArgs: AcpToolCallArgs? = null +) + +internal interface AcpToolCallNormalizer { + fun normalize(context: AcpToolCallContext, support: AcpToolCallDecodingSupport): AcpResolvedToolCall? +} + +internal data class DiffContent( + val path: String, + val oldText: String?, + val newText: String +) + +@Serializable +internal data class AcpMcpToolCallPayload( + val server: String, + val tool: String, + val arguments: Map = emptyMap(), + @SerialName("server_id") val serverIdSnake: String? = null, + val serverId: String? = null, + @SerialName("server_name") val serverNameSnake: String? = null, + val serverName: String? = null +) { + val resolvedServerId: String? + get() = serverId ?: serverIdSnake + + val resolvedServerName: String + get() = serverName ?: serverNameSnake ?: server +} + +@Serializable +internal data class AcpQueryPayload( + val query: String? = null, + val q: String? = null +) { + val resolvedQuery: String? + get() = query ?: q +} + +@Serializable +internal data class AcpActionEnvelope( + val action: AcpActionPayload +) + +@Serializable +internal data class AcpActionPayload( + val type: String, + val query: String? = null, + val q: String? = null, + val url: String? = null, + val uri: String? = null, + val href: String? = null, + val link: String? = null, + val selector: String? = null, + @SerialName("css_selector") val cssSelectorSnake: String? = null, + val cssSelector: String? = null, + @SerialName("timeout_ms") val timeoutMsSnake: Int? = null, + val timeoutMs: Int? = null, + val timeout: Int? = null, + val offset: Int? = null, + @SerialName("start_line") val startLineSnake: Int? = null, + val startLine: Int? = null, + val limit: Int? = null, + @SerialName("max_lines") val maxLinesSnake: Int? = null, + val maxLines: Int? = null, + val count: Int? = null +) { + val resolvedQuery: String? + get() = query ?: q + + val resolvedUrl: String? + get() = url ?: uri ?: href ?: link + + val resolvedSelector: String? + get() = selector ?: cssSelector ?: cssSelectorSnake + + val resolvedTimeoutMs: Int + get() = timeoutMs ?: timeoutMsSnake ?: timeout ?: 10_000 + + val resolvedOffset: Int? + get() = offset ?: startLine ?: startLineSnake + + val resolvedLimit: Int? + get() = limit ?: maxLines ?: maxLinesSnake ?: count +} + +@Serializable +internal data class AcpReadPayload( + @SerialName("file_path") val filePathSnake: String? = null, + val filePath: String? = null, + val path: String? = null, + val offset: Int? = null, + val line: Int? = null, + val limit: Int? = null, + @SerialName("maxLinesCount") val maxLinesCount: Int? = null +) { + val resolvedPath: String? + get() = filePath ?: filePathSnake ?: path + + val resolvedOffset: Int? + get() = offset ?: line + + val resolvedLimit: Int? + get() = limit ?: maxLinesCount +} + +@Serializable +internal data class AcpWritePayload( + @SerialName("file_path") val filePathSnake: String? = null, + val filePath: String? = null, + val path: String? = null, + val content: String? = null, + val text: String? = null +) { + val resolvedPath: String? + get() = filePath ?: filePathSnake ?: path + + val resolvedContent: String? + get() = content ?: text +} + +@Serializable +internal data class AcpEditPayload( + @SerialName("file_path") val filePathSnake: String? = null, + val filePath: String? = null, + val path: String? = null, + @SerialName("old_string") val oldStringSnake: String? = null, + val oldString: String? = null, + @SerialName("new_string") val newStringSnake: String? = null, + val newString: String? = null, + @SerialName("short_description") val shortDescriptionSnake: String? = null, + val shortDescription: String? = null, + @SerialName("replace_all") val replaceAllSnake: Boolean? = null, + val replaceAll: Boolean? = null +) { + val resolvedPath: String? + get() = filePath ?: filePathSnake ?: path + + val resolvedOldString: String? + get() = oldString ?: oldStringSnake + + val resolvedNewString: String? + get() = newString ?: newStringSnake + + val resolvedDescription: String + get() = shortDescription ?: shortDescriptionSnake ?: "ACP edit" + + val resolvedReplaceAll: Boolean + get() = replaceAll ?: replaceAllSnake ?: false +} + +@Serializable +internal data class AcpChangeSetPayload( + val changes: Map = emptyMap() +) { + fun firstWriteArgs(): AcpToolCallArgs? { + val entry = changes.entries.firstOrNull() ?: return null + val content = entry.value.content ?: return null + return AcpToolCallArgs.Write(WriteTool.Args(entry.key, content)) + } +} + +@Serializable +internal data class AcpChangePayload( + val content: String? = null +) + +@Serializable +internal data class AcpSearchPayload( + val pattern: String? = null, + @SerialName("searchText") val searchText: String? = null, + val query: String? = null, + @SerialName("nameKeyword") val nameKeyword: String? = null, + val scope: String? = null, + val path: String? = null, + @SerialName("directoryToSearch") val directoryToSearch: String? = null, + @SerialName("fileType") val fileType: String? = null, + @SerialName("fileMask") val fileMask: String? = null, + val caseSensitive: Boolean? = null, + val regex: Boolean? = null, + val limit: Int? = null, + @SerialName("maxUsageCount") val maxUsageCount: Int? = null, + @SerialName("fileCountLimit") val fileCountLimit: Int? = null +) { + val resolvedPattern: String? + get() = pattern ?: searchText ?: query ?: nameKeyword + + val resolvedPath: String? + get() = path ?: directoryToSearch + + val resolvedFileType: String? + get() = fileType ?: fileMask + + val resolvedLimit: Int? + get() = limit ?: maxUsageCount ?: fileCountLimit +} + +@Serializable +internal data class AcpCommandStringPayload( + val command: String, + val description: String? = null, + val workingDirectory: String? = null, + val workdir: String? = null, + val cwd: String? = null, + val timeout: Int? = null, + @SerialName("run_in_background") val runInBackgroundSnake: Boolean? = null, + val runInBackground: Boolean? = null +) { + val resolvedWorkingDirectory: String? + get() = workingDirectory ?: workdir ?: cwd + + val resolvedRunInBackground: Boolean + get() = runInBackground ?: runInBackgroundSnake ?: false +} + +@Serializable +internal data class AcpCommandArrayPayload( + val command: List, + val description: String? = null, + val workingDirectory: String? = null, + val workdir: String? = null, + val cwd: String? = null, + val timeout: Int? = null, + @SerialName("run_in_background") val runInBackgroundSnake: Boolean? = null, + val runInBackground: Boolean? = null +) { + val resolvedCommand: String + get() = command.joinToString(" ") + + val resolvedWorkingDirectory: String? + get() = workingDirectory ?: workdir ?: cwd + + val resolvedRunInBackground: Boolean + get() = runInBackground ?: runInBackgroundSnake ?: false +} + +@Serializable +internal data class AcpParsedCommandEnvelope( + @SerialName("parsed_cmd") val parsedCommands: List = emptyList() +) { + val firstParsedCommand: AcpParsedCommandPayload? + get() = parsedCommands.firstOrNull() +} + +@Serializable +internal data class AcpParsedCommandPayload( + val type: String, + val cmd: String? = null, + val name: String? = null, + val path: String? = null, + val query: String? = null, + val offset: Int? = null, + val line: Int? = null +) { + val resolvedOffset: Int? + get() = offset ?: line +} + +internal class AcpToolCallDecodingSupport( + private val json: Json +) { + inline fun decode(rawInput: JsonElement?): T? = rawInput.decodeOrNull(json) + + fun decodeMcpArgs(payload: AcpMcpToolCallPayload): McpTool.Args { + return McpTool.Args( + toolName = payload.tool, + serverId = payload.resolvedServerId, + serverName = payload.resolvedServerName, + arguments = payload.arguments + ) + } + + fun decodeMcpPayload(rawInput: JsonElement?): AcpMcpToolCallPayload? { + return decode(rawInput) + } + + fun decodeParsedCommand(rawInput: JsonElement?): AcpParsedCommandPayload? { + return decode(rawInput)?.firstParsedCommand + } + + fun decodeDiffContent(content: List): DiffContent? { + val diff = content.filterIsInstance().firstOrNull() ?: return null + return DiffContent(diff.path, diff.oldText, diff.newText) + } + + fun decodeEditOrWriteArgs( + rawInput: JsonElement?, + locations: List, + content: List + ): AcpToolCallArgs? { + decodeDiffContent(content)?.let { diff -> + return if (diff.oldText == null) { + AcpToolCallArgs.Write(WriteTool.Args(diff.path, diff.newText)) + } else { + AcpToolCallArgs.Edit( + EditTool.Args( + filePath = diff.path, + oldString = diff.oldText, + newString = diff.newText, + shortDescription = "ACP edit", + replaceAll = false + ) + ) + } + } + + decode(rawInput)?.firstWriteArgs()?.let { return it } + + decode(rawInput)?.let { payload -> + val path = payload.resolvedPath ?: locations.firstOrNull()?.path ?: return@let + val text = payload.resolvedContent ?: return@let + return AcpToolCallArgs.Write(WriteTool.Args(path, text)) + } + + decode(rawInput)?.let { payload -> + val path = payload.resolvedPath ?: return@let + val oldString = payload.resolvedOldString ?: return@let + val newString = payload.resolvedNewString ?: return@let + return AcpToolCallArgs.Edit( + EditTool.Args( + filePath = path, + oldString = oldString, + newString = newString, + shortDescription = payload.resolvedDescription, + replaceAll = payload.resolvedReplaceAll + ) + ) + } + + return null + } + + fun decodeReadArgs( + rawInput: JsonElement?, + locations: List + ): AcpToolCallArgs? { + decode(rawInput)?.let { payload -> + val path = payload.resolvedPath ?: locations.firstOrNull()?.path ?: return@let + return AcpToolCallArgs.Read( + ReadTool.Args( + filePath = path, + offset = payload.resolvedOffset ?: locations.firstOrNull()?.line?.toInt(), + limit = payload.resolvedLimit + ) + ) + } + + return locations.firstOrNull()?.path?.let { path -> + AcpToolCallArgs.Read( + ReadTool.Args( + filePath = path, + offset = locations.firstOrNull()?.line?.toInt(), + limit = null + ) + ) + } + } + + fun decodeReadArgsFromParsedCommand( + rawInput: JsonElement?, + locations: List + ): AcpToolCallArgs? { + val parsed = decode(rawInput)?.firstParsedCommand ?: return null + if (!parsed.type.equals("read", ignoreCase = true)) { + return null + } + val path = parsed.path ?: locations.firstOrNull()?.path ?: return null + return AcpToolCallArgs.Read( + ReadTool.Args( + filePath = path, + offset = parsed.resolvedOffset ?: locations.firstOrNull()?.line?.toInt(), + limit = null + ) + ) + } + + fun decodeSearchArgs(rawInput: JsonElement?): AcpToolCallArgs? { + val payload = decode(rawInput) ?: return null + val pattern = payload.resolvedPattern ?: return null + return AcpToolCallArgs.IntelliJSearch( + IntelliJSearchTool.Args( + pattern = pattern, + scope = payload.scope, + path = payload.resolvedPath, + fileType = payload.resolvedFileType, + context = null, + caseSensitive = payload.caseSensitive, + regex = payload.regex, + wholeWords = null, + outputMode = null, + limit = payload.resolvedLimit + ) + ) + } + + fun decodeSearchArgsOrPreview( + rawTitle: String, + rawInput: JsonElement?, + locations: List + ): AcpToolCallArgs { + decodeSearchArgsFromParsedCommand(rawInput)?.let { return it } + decodeSearchArgs(rawInput)?.let { return it } + + val normalizedTitle = rawTitle.trim().ifBlank { "Search" } + val path = locations.firstOrNull()?.path + val pattern = extractGeminiSearchPattern(normalizedTitle) + + return AcpToolCallArgs.SearchPreview( + AcpSearchPreviewArgs( + title = normalizedTitle, + path = path, + pattern = pattern + ) + ) + } + + fun decodeReadArgsFromParsedCommandOrFallback( + rawInput: JsonElement?, + locations: List + ): AcpToolCallArgs? { + return decodeReadArgsFromParsedCommand(rawInput, locations) + ?: decodeReadArgs(rawInput, locations) + } + + fun decodeSearchArgsFromParsedCommand(rawInput: JsonElement?): AcpToolCallArgs? { + val parsed = decode(rawInput)?.firstParsedCommand ?: return null + if (!parsed.type.equals("search", ignoreCase = true)) { + return null + } + val pattern = parsed.query ?: return null + return AcpToolCallArgs.IntelliJSearch( + IntelliJSearchTool.Args( + pattern = pattern, + scope = null, + path = parsed.path, + fileType = null, + context = null, + caseSensitive = null, + regex = null, + wholeWords = null, + outputMode = null, + limit = null + ) + ) + } + + fun decodeBashArgs(rawInput: JsonElement?): AcpToolCallArgs? { + decode(rawInput)?.let { payload -> + return AcpToolCallArgs.Bash( + BashTool.Args( + command = payload.command, + workingDirectory = payload.resolvedWorkingDirectory, + timeout = payload.timeout ?: 60_000, + description = payload.description, + runInBackground = payload.resolvedRunInBackground + ) + ) + } + + decode(rawInput)?.let { payload -> + return AcpToolCallArgs.Bash( + BashTool.Args( + command = payload.resolvedCommand, + workingDirectory = payload.resolvedWorkingDirectory, + timeout = payload.timeout ?: 60_000, + description = payload.description, + runInBackground = payload.resolvedRunInBackground + ) + ) + } + + return null + } + + fun decodeBashArgsOrPreview( + rawTitle: String, + rawInput: JsonElement?, + defaultTitle: String = "Run shell command" + ): AcpToolCallArgs { + decodeBashArgs(rawInput)?.let { return it } + + val normalizedTitle = rawTitle.trim().ifBlank { defaultTitle } + return AcpToolCallArgs.BashPreview( + AcpBashPreviewArgs( + title = normalizedTitle, + command = extractGeminiCommand(normalizedTitle) + ) + ) + } + + fun decodeWebSearchArgs(rawInput: JsonElement?): AcpToolCallArgs? { + decode(rawInput)?.action?.let { action -> + if (action.type.equals("search", ignoreCase = true)) { + action.resolvedQuery?.let { return AcpToolCallArgs.WebSearch(WebSearchTool.Args(it)) } + } + } + + decode(rawInput)?.resolvedQuery?.let { query -> + return AcpToolCallArgs.WebSearch(WebSearchTool.Args(query)) + } + + return null + } + + fun decodeWebFetchArgs(rawInput: JsonElement?): AcpToolCallArgs? { + decode(rawInput)?.action?.let { action -> + if (action.type.lowercase() in setOf("open_page", "fetch", "open")) { + val url = action.resolvedUrl ?: return@let + return AcpToolCallArgs.WebFetch( + WebFetchTool.Args( + url = url, + selector = action.resolvedSelector, + timeoutMs = action.resolvedTimeoutMs, + offset = action.resolvedOffset, + limit = action.resolvedLimit + ) + ) + } + } + + return null + } + + fun resolveWebSearchCall(rawTitle: String, rawInput: JsonElement?): AcpResolvedToolCall? { + return decodeWebSearchArgs(rawInput)?.let { args -> + fixedCall(rawTitle, "WebSearch", "WebSearch", args) + } + } + + fun resolveWebFetchCall(rawTitle: String, rawInput: JsonElement?): AcpResolvedToolCall? { + return decodeWebFetchArgs(rawInput)?.let { args -> + fixedCall(rawTitle, "WebFetch", "WebFetch", args) + } + } + + fun resolveWebCall(rawTitle: String, rawInput: JsonElement?): AcpResolvedToolCall? { + return resolveWebSearchCall(rawTitle, rawInput) + ?: resolveWebFetchCall(rawTitle, rawInput) + } +} + +internal fun fixedCall( + rawTitle: String, + defaultTitle: String, + toolName: String, + typedArgs: AcpToolCallArgs? = null +): AcpResolvedToolCall { + return AcpResolvedToolCall( + rawTitle = rawTitle.ifBlank { defaultTitle }, + toolName = toolName, + typedArgs = typedArgs + ) +} + +internal fun editOrWriteCall(rawTitle: String, args: AcpToolCallArgs?): AcpResolvedToolCall? { + return when (args) { + is AcpToolCallArgs.Write -> fixedCall(rawTitle, "Write", "Write", args) + is AcpToolCallArgs.Edit -> fixedCall(rawTitle, "Edit", "Edit", args) + else -> null + } +} + +private fun extractGeminiSearchPattern(title: String): String? { + val match = Regex("(?i)^(?:search|grep)(?:\\s+for)?\\s*:?\\s+(.+)$").find(title) ?: return null + return match.groupValues.getOrNull(1)?.trim()?.takeIf { it.isNotBlank() } +} + +private fun extractGeminiCommand(title: String): String? { + val match = Regex( + "(?i)^(?:run(?:\\s+shell)?\\s+command|execute(?:\\s+shell)?\\s+command)\\s*:?\\s+(.+)$" + ).find(title) ?: return null + return match.groupValues.getOrNull(1)?.trim()?.takeIf { it.isNotBlank() } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallNormalizers.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallNormalizers.kt new file mode 100644 index 00000000..9a08078e --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallNormalizers.kt @@ -0,0 +1,273 @@ +package ee.carlrobert.codegpt.agent.external + +import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs +import ee.carlrobert.codegpt.agent.external.events.AcpSearchPreviewArgs + +internal class ZedAdapterToolCallNormalizer : AcpToolCallNormalizer { + override fun normalize( + context: AcpToolCallContext, + support: AcpToolCallDecodingSupport + ): AcpResolvedToolCall? { + support.decodeMcpPayload(context.rawInput)?.let { payload -> + return fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "MCP", + toolName = "MCP", + typedArgs = AcpToolCallArgs.Mcp(support.decodeMcpArgs(payload)) + ) + } + + if (context.rawKind.equals("fetch", ignoreCase = true)) { + support.resolveWebSearchCall(context.rawTitle, context.rawInput)?.let { return it } + } + + support.decodeParsedCommand(context.rawInput)?.let { parsed -> + return when { + parsed.type.equals("read", ignoreCase = true) -> + fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Read", + toolName = "Read", + typedArgs = support.decodeReadArgsFromParsedCommandOrFallback( + context.rawInput, + context.locations + ) + ) + + parsed.type.equals("search", ignoreCase = true) -> + fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Search", + toolName = "IntelliJSearch", + typedArgs = support.decodeSearchArgsOrPreview( + context.rawTitle, + context.rawInput, + context.locations + ) + ) + + parsed.type.equals("list_files", ignoreCase = true) -> + fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Bash", + toolName = "Bash", + typedArgs = support.decodeBashArgsOrPreview( + context.rawTitle, + context.rawInput, + defaultTitle = "List files" + ) + ) + + else -> null + } + } + + support.decodeBashArgs(context.rawInput)?.let { args -> + return fixedCall(context.rawTitle, "Bash", "Bash", args) + } + + if (context.rawKind.equals("execute", ignoreCase = true)) { + return fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Bash", + toolName = "Bash", + typedArgs = support.decodeBashArgsOrPreview( + context.rawTitle, + context.rawInput, + defaultTitle = "Bash" + ) + ) + } + + support.resolveWebCall(context.rawTitle, context.rawInput)?.let { return it } + + return editOrWriteCall( + rawTitle = context.rawTitle, + args = support.decodeEditOrWriteArgs(context.rawInput, context.locations, context.content) + ) + } +} + +internal class GeminiCliToolCallNormalizer : AcpToolCallNormalizer { + override fun normalize( + context: AcpToolCallContext, + support: AcpToolCallDecodingSupport + ): AcpResolvedToolCall? { + return when { + context.toolCallId.startsWith("google_web_search-") -> + fixedCall( + context.rawTitle, + "Google web search", + "WebSearch", + support.decodeWebSearchArgs(context.rawInput) + ) + + context.toolCallId.startsWith("grep_search-") -> + fixedCall( + context.rawTitle, + "Search", + "IntelliJSearch", + support.decodeSearchArgsOrPreview( + context.rawTitle.ifBlank { "Search" }, + context.rawInput, + context.locations + ) + ) + + context.toolCallId.startsWith("read_file-") -> + fixedCall( + context.rawTitle, + "Read file", + "Read", + support.decodeReadArgs(context.rawInput, context.locations) + ) + + context.toolCallId.startsWith("list_directory-") -> + fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "List directory", + toolName = "ListDirectory", + typedArgs = AcpToolCallArgs.SearchPreview( + AcpSearchPreviewArgs( + title = "List directory", + path = context.rawTitle.trim().ifBlank { null } + ) + ) + ) + + context.toolCallId.startsWith("glob-") -> + fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Find files", + toolName = "Glob", + typedArgs = AcpToolCallArgs.SearchPreview( + AcpSearchPreviewArgs( + title = "Find files", + pattern = context.rawTitle.trim().trim('\'').ifBlank { null } + ) + ) + ) + + context.toolCallId.startsWith("write_file-") -> + editOrWriteCall( + rawTitle = context.rawTitle, + args = support.decodeEditOrWriteArgs( + context.rawInput, + context.locations, + context.content + ) + ) ?: fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Write file", + toolName = "Write" + ) + + context.toolCallId.startsWith("replace-") -> + editOrWriteCall( + rawTitle = context.rawTitle, + args = support.decodeEditOrWriteArgs( + context.rawInput, + context.locations, + context.content + ) + ) ?: fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Replace text", + toolName = "Edit" + ) + + context.toolCallId.startsWith("run_shell_command-") -> + fixedCall( + context.rawTitle, + "Run shell command", + "Bash", + support.decodeBashArgsOrPreview( + context.rawTitle, + context.rawInput, + defaultTitle = "Run shell command" + ) + ) + + else -> null + } + } +} + +internal class StandardSemanticToolCallNormalizer : AcpToolCallNormalizer { + override fun normalize( + context: AcpToolCallContext, + support: AcpToolCallDecodingSupport + ): AcpResolvedToolCall? { + support.decodeMcpPayload(context.rawInput)?.let { payload -> + return fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "MCP", + toolName = "MCP", + typedArgs = AcpToolCallArgs.Mcp(support.decodeMcpArgs(payload)) + ) + } + + support.resolveWebCall(context.rawTitle, context.rawInput)?.let { return it } + + return when (context.rawKind?.lowercase().orEmpty()) { + "read" -> fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Read", + toolName = "Read", + typedArgs = support.decodeReadArgs(context.rawInput, context.locations) + ) + + "edit" -> editOrWriteCall( + rawTitle = context.rawTitle, + args = support.decodeEditOrWriteArgs(context.rawInput, context.locations, context.content) + ) + + "execute", "terminal", "bash" -> fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Bash", + toolName = "Bash", + typedArgs = support.decodeBashArgsOrPreview( + context.rawTitle, + context.rawInput, + defaultTitle = "Bash" + ) + ) + + "search" -> fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "Search", + toolName = "IntelliJSearch", + typedArgs = support.decodeSearchArgsOrPreview( + context.rawTitle, + context.rawInput, + context.locations + ) + ) + + "fetch" -> fixedCall( + rawTitle = context.rawTitle, + defaultTitle = "WebFetch", + toolName = "WebFetch", + typedArgs = support.decodeWebFetchArgs(context.rawInput) + ) + + else -> null + } + } +} + +internal class FallbackToolCallNormalizer : AcpToolCallNormalizer { + override fun normalize( + context: AcpToolCallContext, + support: AcpToolCallDecodingSupport + ): AcpResolvedToolCall { + val fallbackName = context.rawTitle.ifBlank { + context.rawKind?.replaceFirstChar { it.uppercase() } ?: context.defaultTitle + } + return fixedCall( + rawTitle = context.rawTitle, + defaultTitle = fallbackName, + toolName = fallbackName + ) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallStatus.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallStatus.kt new file mode 100644 index 00000000..6fd91312 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallStatus.kt @@ -0,0 +1,11 @@ +package ee.carlrobert.codegpt.agent.external + +internal enum class AcpToolCallStatus(val wireValue: String) { + IN_PROGRESS("in_progress"), + COMPLETED("completed"), + FAILED("failed"), + CANCELLED("cancelled"); + + val isTerminal: Boolean + get() = this != IN_PROGRESS +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolEventFlavor.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolEventFlavor.kt new file mode 100644 index 00000000..ebe99117 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolEventFlavor.kt @@ -0,0 +1,7 @@ +package ee.carlrobert.codegpt.agent.external + +enum class AcpToolEventFlavor { + STANDARD, + ZED_ADAPTER, + GEMINI_CLI +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocol.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocol.kt new file mode 100644 index 00000000..deebf701 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocol.kt @@ -0,0 +1,133 @@ +package ee.carlrobert.codegpt.agent.external.acpcompat + +import com.agentclientprotocol.model.AcpMethod +import com.agentclientprotocol.model.AcpNotification +import com.agentclientprotocol.model.AcpRequest +import com.agentclientprotocol.model.AcpResponse +import com.agentclientprotocol.rpc.ACPJson +import com.agentclientprotocol.rpc.MethodName +import com.agentclientprotocol.transport.Transport +import ee.carlrobert.codegpt.agent.external.runtime.AcpProtocolCore +import kotlinx.coroutines.CoroutineScope +import kotlinx.serialization.json.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +class AcpExpectedError(message: String) : Exception(message) + +fun acpFail(message: String): Nothing = throw AcpExpectedError(message) + +class JsonRpcException( + val code: Int, + message: String, + val data: JsonElement? = null +) : Exception(message) + +open class ProtocolOptions( + val gracefulRequestCancellationTimeout: Duration = 1.seconds, + val protocolDebugName: String = AcpProtocol::class.simpleName!!, + val outboundPayloadAugmenter: (MethodName, JsonElement?) -> JsonElement? = { _, payload -> payload }, + val inboundPayloadNormalizer: (MethodName, JsonElement?) -> JsonElement? = { _, payload -> payload }, + val trace: ((String) -> Unit)? = null +) + +class AcpProtocol( + parentScope: CoroutineScope, + transport: Transport, + val options: ProtocolOptions = ProtocolOptions() +) { + internal val json: Json = ACPJson + + internal val core = AcpProtocolCore(parentScope, transport, json, options) + + fun start() = core.start() + + fun close() = core.close() + + override fun toString(): String = "Protocol(${options.protocolDebugName})" +} + +internal suspend inline fun AcpProtocol.sendRequest( + method: AcpMethod.AcpRequestResponseMethod, + request: TRequest? +): TResponse { + val params = options.outboundPayloadAugmenter( + method.methodName, + request?.let { json.encodeToJsonElement(request) } + ) + val responseJson = core.sendRequestRaw(method.methodName, params) + return json.decodeFromJsonElement( + options.inboundPayloadNormalizer(method.methodName, responseJson) ?: JsonNull + ) +} + +internal inline fun AcpProtocol.sendNotification( + method: AcpMethod.AcpNotificationMethod, + notification: TNotification? = null +) { + val params = notification?.let { json.encodeToJsonElement(notification) } + core.sendNotificationRaw(method, params) +} + +internal inline fun AcpProtocol.setRequestHandler( + method: AcpMethod.AcpRequestResponseMethod, + additionalContext: CoroutineContext = EmptyCoroutineContext, + noinline handler: suspend (TRequest) -> TResponse +) { + core.setRequestHandlerRaw(method, additionalContext) { request -> + val requestParams = decodeAcpPayload( + json, + options.inboundPayloadNormalizer(method.methodName, request.params) + ) + val responseObject = handler(requestParams) + json.encodeToJsonElement(responseObject) + } +} + +internal inline fun AcpProtocol.setNotificationHandler( + method: AcpMethod.AcpNotificationMethod, + additionalContext: CoroutineContext = EmptyCoroutineContext, + noinline handler: suspend (TNotification) -> Unit +) { + core.setNotificationHandlerRaw(method, additionalContext) { notification -> + val notificationParams = decodeAcpPayload( + json, + options.inboundPayloadNormalizer(method.methodName, notification.params) + ) + handler(notificationParams) + } +} + +internal suspend inline operator fun AcpMethod.AcpRequestResponseMethod.invoke( + protocol: AcpProtocol, + request: TRequest +): TResponse { + return protocol.sendRequest(this, request) +} + +internal inline operator fun AcpMethod.AcpNotificationMethod.invoke( + protocol: AcpProtocol, + notification: TNotification +) { + return protocol.sendNotification(this, notification) +} + +internal inline fun decodeAcpPayload( + json: Json, + payload: JsonElement? +): T { + return when (payload) { + null -> json.decodeFromJsonElement(JsonNull) + is JsonPrimitive -> { + if (payload.isString) { + json.decodeFromString(payload.content) + } else { + json.decodeFromJsonElement(payload) + } + } + + else -> json.decodeFromJsonElement(payload) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpCompatibilityRegistry.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpCompatibilityRegistry.kt new file mode 100644 index 00000000..ac2e954e --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpCompatibilityRegistry.kt @@ -0,0 +1,175 @@ +package ee.carlrobert.codegpt.agent.external.acpcompat.vendor + +import com.agentclientprotocol.agent.AgentInfo +import com.agentclientprotocol.rpc.MethodName +import ee.carlrobert.codegpt.agent.external.ExternalAcpAgentPreset +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +internal class AcpCompatibilityRegistry { + + fun initialProfile(preset: ExternalAcpAgentPreset): AcpPeerProfile { + return profileForPresetId(preset.id) + } + + fun resolveProfile( + preset: ExternalAcpAgentPreset, + agentInfo: AgentInfo? + ): AcpPeerProfile { + val presetProfile = initialProfile(preset) + val implementationName = buildString { + append(agentInfo?.implementation?.name.orEmpty()) + append(' ') + append(agentInfo?.implementation?.title.orEmpty()) + }.trim().lowercase() + + return when { + "codex" in implementationName -> AcpPeerProfile.CODEX + "gemini" in implementationName -> AcpPeerProfile.GEMINI + "opencode" in implementationName -> AcpPeerProfile.OPENCODE + "claude" in implementationName -> AcpPeerProfile.CLAUDE_CODE + else -> presetProfile + } + } + + fun augmentOutboundPayload( + profile: AcpPeerProfile, + methodName: MethodName, + payload: JsonElement?, + sessionRequestMeta: JsonElement? = null, + launchEnv: Map = emptyMap() + ): JsonElement? { + return when (methodName.name) { + "initialize" -> augmentInitializeRequest(profile, payload) + "authenticate" -> augmentAuthenticateRequest(profile, payload, launchEnv) + "session/new", "session/load" -> mergeTopLevelMeta(payload, sessionRequestMeta) + else -> payload + } + } + + fun normalizeInboundPayload( + profile: AcpPeerProfile, + methodName: MethodName, + payload: JsonElement? + ): JsonElement? { + return when (methodName.name) { + "initialize" -> normalizeInitializeResponse(profile, payload) + "session/new", "session/load", "session/update" -> payload + else -> payload + } + } + + private fun profileForPresetId(presetId: String): AcpPeerProfile { + return when (presetId) { + "codex" -> AcpPeerProfile.CODEX + "gemini-cli" -> AcpPeerProfile.GEMINI + "opencode" -> AcpPeerProfile.OPENCODE + "claude-code" -> AcpPeerProfile.CLAUDE_CODE + else -> AcpPeerProfile.STANDARD + } + } + + private fun augmentInitializeRequest( + profile: AcpPeerProfile, + payload: JsonElement? + ): JsonElement? { + val root = payload as? JsonObject ?: return payload + if (!profile.supportsTerminalAuthMeta) { + return payload + } + + val capabilities = root["clientCapabilities"] as? JsonObject ?: return payload + val capabilityMeta = mergeMeta( + capabilities["_meta"], + JsonObject(mapOf("terminal-auth" to JsonPrimitive(true))) + ) ?: return payload + + return JsonObject( + root + ("clientCapabilities" to JsonObject(capabilities + ("_meta" to capabilityMeta))) + ) + } + + private fun augmentAuthenticateRequest( + profile: AcpPeerProfile, + payload: JsonElement?, + launchEnv: Map + ): JsonElement? { + if (profile != AcpPeerProfile.GEMINI) { + return payload + } + + val metaEntries = linkedMapOf() + val apiKey = launchEnv["GEMINI_API_KEY"] ?: launchEnv["GOOGLE_API_KEY"] + val gateway = launchEnv["GEMINI_GATEWAY"] + if (!apiKey.isNullOrBlank()) { + metaEntries["api-key"] = JsonPrimitive(apiKey) + } + if (!gateway.isNullOrBlank()) { + metaEntries["gateway"] = JsonPrimitive(gateway) + } + + return mergeTopLevelMeta( + payload, + metaEntries.takeIf { it.isNotEmpty() }?.let(::JsonObject) + ) + } + + private fun normalizeInitializeResponse( + profile: AcpPeerProfile, + payload: JsonElement? + ): JsonElement? { + if (profile != AcpPeerProfile.CODEX) { + return payload + } + + val root = payload as? JsonObject ?: return payload + val authMethods = root["authMethods"]?.jsonArray ?: return payload + val normalizedAuthMethods = authMethods.map(::normalizeAuthMethod) + return JsonObject(root + ("authMethods" to kotlinx.serialization.json.JsonArray(normalizedAuthMethods))) + } + + private fun normalizeAuthMethod(payload: JsonElement): JsonElement { + val authMethod = payload as? JsonObject ?: return payload + val type = authMethod["type"]?.jsonPrimitive?.contentOrNull + if (type != "env_var" || authMethod["varName"] != null) { + return payload + } + + val firstVarName = authMethod["vars"] + ?.jsonArray + ?.firstOrNull() + ?.jsonObject + ?.get("name") + ?.jsonPrimitive + ?.contentOrNull + ?: return payload + + return JsonObject(authMethod + ("varName" to JsonPrimitive(firstVarName))) + } + + private fun mergeTopLevelMeta( + payload: JsonElement?, + extraMeta: JsonElement? + ): JsonElement? { + val root = payload as? JsonObject ?: return payload + val mergedMeta = mergeMeta(root["_meta"], extraMeta) ?: return payload + return JsonObject(root + ("_meta" to mergedMeta)) + } + + private fun mergeMeta( + base: JsonElement?, + extra: JsonElement? + ): JsonElement? { + return when { + base == null -> extra + extra == null -> base + base is JsonObject && extra is JsonObject -> JsonObject(base + extra) + else -> extra + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpPeerProfile.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpPeerProfile.kt new file mode 100644 index 00000000..7bc3feb2 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpPeerProfile.kt @@ -0,0 +1,29 @@ +package ee.carlrobert.codegpt.agent.external.acpcompat.vendor + +internal enum class AcpPeerProfile( + val profileId: String, + val displayName: String, + val supportsTerminalAuthMeta: Boolean = false +) { + STANDARD( + profileId = "standard", + displayName = "Standard ACP" + ), + CODEX( + profileId = "codex", + displayName = "Codex ACP" + ), + GEMINI( + profileId = "gemini", + displayName = "Gemini CLI ACP" + ), + OPENCODE( + profileId = "opencode", + displayName = "OpenCode ACP" + ), + CLAUDE_CODE( + profileId = "claude-code", + displayName = "Claude Code ACP", + supportsTerminalAuthMeta = true + ) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpExternalEvent.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpExternalEvent.kt new file mode 100644 index 00000000..3e056b84 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpExternalEvent.kt @@ -0,0 +1,88 @@ +package ee.carlrobert.codegpt.agent.external.events + +import com.agentclientprotocol.model.* +import ee.carlrobert.codegpt.agent.external.AcpToolCallStatus +import ee.carlrobert.codegpt.agent.tools.* +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject + +internal sealed interface AcpExternalEvent { + data class TextChunk(val text: String) : AcpExternalEvent + data class ThinkingChunk(val text: String) : AcpExternalEvent + data class PlanUpdate(val entries: List) : AcpExternalEvent + data class UsageUpdate(val used: Long, val size: Long, val cost: Cost?) : AcpExternalEvent + data class ConfigOptionUpdate(val configOptions: List) : AcpExternalEvent + data class SessionInfoUpdate(val title: String?, val updatedAt: String?) : AcpExternalEvent + data class UnknownSessionUpdate(val type: String, val rawJson: JsonObject) : AcpExternalEvent + data class AvailableCommandsUpdate(val availableCommands: List) : + AcpExternalEvent + + data class CurrentModeUpdate(val currentModeId: String) : AcpExternalEvent + data class ToolCallStarted(val toolCall: AcpToolCallSnapshot) : AcpExternalEvent + data class ToolCallUpdated(val toolCall: AcpToolCallSnapshot) : AcpExternalEvent +} + +internal sealed interface AcpToolCallArgs { + data class SearchPreview(val value: AcpSearchPreviewArgs) : AcpToolCallArgs + data class BashPreview(val value: AcpBashPreviewArgs) : AcpToolCallArgs + data class Mcp(val value: McpTool.Args) : AcpToolCallArgs + data class Read(val value: ReadTool.Args) : AcpToolCallArgs + data class Write(val value: WriteTool.Args) : AcpToolCallArgs + data class Edit(val value: EditTool.Args) : AcpToolCallArgs + data class Bash(val value: BashTool.Args) : AcpToolCallArgs + data class WebSearch(val value: WebSearchTool.Args) : AcpToolCallArgs + data class WebFetch(val value: WebFetchTool.Args) : AcpToolCallArgs + data class IntelliJSearch(val value: IntelliJSearchTool.Args) : AcpToolCallArgs + data class Unknown(val value: JsonElement?) : AcpToolCallArgs +} + +internal data class AcpSearchPreviewArgs( + val title: String, + val path: String? = null, + val pattern: String? = null +) + +internal data class AcpBashPreviewArgs( + val title: String, + val command: String? = null +) + +internal sealed interface AcpToolCallContent { + data class Block(val content: ContentBlock) : AcpToolCallContent + data class Diff(val path: String, val oldText: String?, val newText: String) : + AcpToolCallContent + + data class Terminal(val terminalId: String) : AcpToolCallContent +} + +internal data class AcpToolCallSnapshot( + val id: String, + val title: String, + val toolName: String, + val kind: ToolKind? = null, + val status: AcpToolCallStatus? = null, + val args: AcpToolCallArgs? = null, + val locations: List = emptyList(), + val content: List = emptyList(), + val meta: JsonElement? = null, + val rawInput: JsonElement? = null, + val rawOutput: JsonElement? = null +) + +internal data class AcpPermissionRequestSnapshot( + val toolCall: AcpToolCallSnapshot, + val details: String, + val options: List +) + +internal fun ToolCallContent.toAcpToolCallContent(): AcpToolCallContent { + return when (this) { + is ToolCallContent.Content -> AcpToolCallContent.Block(content) + is ToolCallContent.Diff -> AcpToolCallContent.Diff(path, oldText, newText) + is ToolCallContent.Terminal -> AcpToolCallContent.Terminal(terminalId) + } +} + +internal fun List.toAcpToolCallContent(): List { + return map(ToolCallContent::toAcpToolCallContent) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHost.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHost.kt new file mode 100644 index 00000000..fd45f103 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHost.kt @@ -0,0 +1,114 @@ +package ee.carlrobert.codegpt.agent.external.host + +import com.agentclientprotocol.model.* +import com.intellij.openapi.application.runReadAction +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.openapi.vfs.VfsUtil +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.Path +import kotlin.io.path.createDirectories +import kotlin.io.path.notExists + +class AcpFileHost( + private val pathPolicy: AcpPathPolicy = AcpPathPolicy(), + private val openDocumentReader: AcpOpenDocumentReader = IntelliJOpenDocumentReader(), + private val writer: AcpTextFileWriter = IntelliJTextFileWriter() +) { + + fun clientCapabilities(includeTerminal: Boolean = false): ClientCapabilities { + return ClientCapabilities( + fs = FileSystemCapability( + readTextFile = true, + writeTextFile = true + ), + terminal = includeTerminal + ) + } + + fun readTextFile( + session: AcpHostSessionContext, + request: ReadTextFileRequest + ): ReadTextFileResponse { + require(request.sessionId.value == session.sessionId) { + "Read request session does not match host session" + } + val resolvedPath = pathPolicy.resolveWithinCwd(request.path, session.cwd) + val content = readContent(resolvedPath) + val trimmed = applyLineWindow( + content = content.content, + line = request.line, + limit = request.limit + ) + return ReadTextFileResponse(trimmed) + } + + fun writeTextFile( + session: AcpHostSessionContext, + request: WriteTextFileRequest + ): WriteTextFileResponse { + require(request.sessionId.value == session.sessionId) { + "Write request session does not match host session" + } + val resolvedPath = pathPolicy.resolveWithinCwd(request.path, session.cwd) + writer.write(resolvedPath, request.content) + return WriteTextFileResponse() + } + + private fun readContent(path: Path): AcpTextFileReadResult { + openDocumentReader.read(path)?.let { editorText -> + return AcpTextFileReadResult(editorText, fromEditor = true) + } + return AcpTextFileReadResult(Files.readString(path), fromEditor = false) + } + + private fun applyLineWindow( + content: String, + line: UInt?, + limit: UInt? + ): String { + val startLine = line?.toInt() + val maxLines = limit?.toInt() + if (startLine == null || maxLines == null || startLine <= 0 || maxLines <= 0) { + return content + } + + return content.lineSequence() + .drop(startLine - 1) + .take(maxLines) + .joinToString("\n") + } +} + +private class IntelliJOpenDocumentReader : AcpOpenDocumentReader { + override fun read(path: Path): String? { + return runReadAction { + val virtualFile = + LocalFileSystem.getInstance().refreshAndFindFileByIoFile(path.toFile()) + ?: return@runReadAction null + FileDocumentManager.getInstance().getDocument(virtualFile)?.text + } + } +} + +private class IntelliJTextFileWriter : AcpTextFileWriter { + override fun write(path: Path, content: String) { + if (path.parent != null && path.parent.notExists()) { + path.parent.createDirectories() + } + Files.writeString(path, content, StandardCharsets.UTF_8) + + val virtualFile = LocalFileSystem.getInstance().refreshAndFindFileByIoFile(path.toFile()) + if (virtualFile != null) { + VfsUtil.markDirtyAndRefresh(false, false, false, virtualFile) + } else { + path.parent?.toFile()?.let { parent -> + LocalFileSystem.getInstance().refreshAndFindFileByIoFile(parent) + ?.let { parentVf -> + VfsUtil.markDirtyAndRefresh(false, false, true, parentVf) + } + } + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpHostModels.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpHostModels.kt new file mode 100644 index 00000000..003a33e4 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpHostModels.kt @@ -0,0 +1,52 @@ +package ee.carlrobert.codegpt.agent.external.host + +import java.nio.file.Path + +data class AcpHostSessionContext( + val sessionId: String, + val cwd: Path +) + +data class AcpTextFileReadResult( + val content: String, + val fromEditor: Boolean +) + +data class AcpTerminalExitStatus( + val exitCode: UInt? = null, + val signal: String? = null +) + +data class AcpTerminalOutputSnapshot( + val output: String, + val truncated: Boolean, + val exitStatus: AcpTerminalExitStatus? = null +) + +fun interface AcpOpenDocumentReader { + fun read(path: Path): String? +} + +fun interface AcpTextFileWriter { + fun write(path: Path, content: String) +} + +interface AcpTerminalProcess { + val terminalId: String + + fun output(): AcpTerminalOutputSnapshot + suspend fun waitForExit(): AcpTerminalExitStatus + fun release() + fun kill() +} + +fun interface AcpTerminalProcessLauncher { + fun launch( + command: String, + args: List, + cwd: Path, + env: Map, + outputByteLimit: ULong? + ): AcpTerminalProcess +} + diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicy.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicy.kt new file mode 100644 index 00000000..21420851 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicy.kt @@ -0,0 +1,57 @@ +package ee.carlrobert.codegpt.agent.external.host + +import java.net.URI +import java.nio.file.Path + +class AcpHostPathBoundaryException(message: String) : IllegalArgumentException(message) + +class AcpPathPolicy { + + fun resolveWithinCwd(rawPath: String, cwd: Path): Path { + val normalizedCwd = cwd.toAbsolutePath().normalize() + val checkedCwd = canonicalizeForBoundaryCheck(normalizedCwd) + val requestedPath = parsePath(rawPath) + val candidate = if (requestedPath.isAbsolute) { + requestedPath + } else { + normalizedCwd.resolve(requestedPath) + }.toAbsolutePath().normalize() + val checkedCandidate = canonicalizeForBoundaryCheck(candidate) + + if (!checkedCandidate.startsWith(checkedCwd)) { + throw AcpHostPathBoundaryException( + "Path escapes the session cwd: $rawPath" + ) + } + + return candidate + } + + private fun parsePath(rawPath: String): Path { + return if (rawPath.startsWith("file://")) { + Path.of(URI.create(rawPath)) + } else { + Path.of(rawPath) + } + } + + private fun canonicalizeForBoundaryCheck(path: Path): Path { + val absolute = path.toAbsolutePath().normalize() + val deferredSegments = mutableListOf() + var current: Path? = absolute + + while (current != null) { + val realPath = runCatching { current.toRealPath() }.getOrNull() + if (realPath != null) { + return deferredSegments.foldRight(realPath) { segment, acc -> + acc.resolve(segment) + }.normalize() + } + + current.fileName?.toString()?.let(deferredSegments::add) + current = current.parent + } + + return absolute + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHost.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHost.kt new file mode 100644 index 00000000..ad71bfd7 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHost.kt @@ -0,0 +1,149 @@ +package ee.carlrobert.codegpt.agent.external.host + +import com.agentclientprotocol.model.CreateTerminalRequest +import com.agentclientprotocol.model.CreateTerminalResponse +import com.agentclientprotocol.model.KillTerminalCommandRequest +import com.agentclientprotocol.model.KillTerminalCommandResponse +import com.agentclientprotocol.model.ReadTextFileRequest +import com.agentclientprotocol.model.ReleaseTerminalRequest +import com.agentclientprotocol.model.ReleaseTerminalResponse +import com.agentclientprotocol.model.TerminalExitStatus +import com.agentclientprotocol.model.TerminalOutputRequest +import com.agentclientprotocol.model.TerminalOutputResponse +import com.agentclientprotocol.model.WaitForTerminalExitRequest +import com.agentclientprotocol.model.WaitForTerminalExitResponse +import com.agentclientprotocol.model.WriteTextFileRequest +import java.nio.file.Path +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +class AcpHostTerminalNotFoundException(message: String) : IllegalArgumentException(message) + +class AcpTerminalHost( + private val launcher: AcpTerminalProcessLauncher, + private val pathPolicy: AcpPathPolicy = AcpPathPolicy() +) { + private val sessions = ConcurrentHashMap() + + fun createTerminal( + session: AcpHostSessionContext, + request: CreateTerminalRequest + ): CreateTerminalResponse { + require(request.sessionId.value == session.sessionId) { + "Terminal request session does not match host session" + } + + val terminalId = "terminal-${UUID.randomUUID()}" + val process = launcher.launch( + command = request.command, + args = request.args, + cwd = resolveWorkingDirectory(session.cwd, request.cwd), + env = request.env.associate { it.name to it.value }, + outputByteLimit = request.outputByteLimit + ) + sessions[terminalId] = process + return CreateTerminalResponse(terminalId = terminalId) + } + + fun output( + session: AcpHostSessionContext, + request: TerminalOutputRequest + ): TerminalOutputResponse { + require(request.sessionId.value == session.sessionId) { + "Terminal request session does not match host session" + } + val process = processFor(session, request.terminalId) + val snapshot = process.output() + return TerminalOutputResponse( + output = snapshot.output, + truncated = snapshot.truncated, + exitStatus = snapshot.exitStatus?.let { + TerminalExitStatus( + exitCode = it.exitCode, + signal = it.signal + ) + } + ) + } + + fun release( + session: AcpHostSessionContext, + request: ReleaseTerminalRequest + ): ReleaseTerminalResponse { + require(request.sessionId.value == session.sessionId) { + "Terminal request session does not match host session" + } + val process = processFor(session, request.terminalId) + process.release() + sessions.remove(request.terminalId) + return ReleaseTerminalResponse() + } + + suspend fun waitForExit( + session: AcpHostSessionContext, + request: WaitForTerminalExitRequest + ): WaitForTerminalExitResponse { + require(request.sessionId.value == session.sessionId) { + "Terminal request session does not match host session" + } + val process = processFor(session, request.terminalId) + val exit = process.waitForExit() + return WaitForTerminalExitResponse( + exitCode = exit.exitCode, + signal = exit.signal + ) + } + + fun kill( + session: AcpHostSessionContext, + request: KillTerminalCommandRequest + ): KillTerminalCommandResponse { + require(request.sessionId.value == session.sessionId) { + "Terminal request session does not match host session" + } + val process = processFor(session, request.terminalId) + process.kill() + return KillTerminalCommandResponse() + } + + private fun processFor(session: AcpHostSessionContext, terminalId: String): AcpTerminalProcess { + return sessions[terminalId] + ?: throw AcpHostTerminalNotFoundException( + "Unknown terminal '$terminalId' for session ${session.sessionId}" + ) + } + + private fun resolveWorkingDirectory(cwd: Path, overrideCwd: String?): Path { + return overrideCwd?.takeIf { it.isNotBlank() } + ?.let { pathPolicy.resolveWithinCwd(it, cwd) } + ?: cwd.toAbsolutePath().normalize() + } +} + +class AcpHostCapabilities( + private val fileHost: AcpFileHost, + private val terminalHost: AcpTerminalHost +) { + fun clientCapabilities(includeTerminal: Boolean = true) = fileHost.clientCapabilities(includeTerminal) + + fun readTextFile(session: AcpHostSessionContext, request: ReadTextFileRequest) = + fileHost.readTextFile(session, request) + + fun writeTextFile(session: AcpHostSessionContext, request: WriteTextFileRequest) = + fileHost.writeTextFile(session, request) + + fun createTerminal(session: AcpHostSessionContext, request: CreateTerminalRequest) = + terminalHost.createTerminal(session, request) + + fun output(session: AcpHostSessionContext, request: TerminalOutputRequest) = + terminalHost.output(session, request) + + fun release(session: AcpHostSessionContext, request: ReleaseTerminalRequest) = + terminalHost.release(session, request) + + suspend fun waitForExit(session: AcpHostSessionContext, request: WaitForTerminalExitRequest) = + terminalHost.waitForExit(session, request) + + fun kill(session: AcpHostSessionContext, request: KillTerminalCommandRequest) = + terminalHost.kill(session, request) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/DefaultAcpTerminalProcessLauncher.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/DefaultAcpTerminalProcessLauncher.kt new file mode 100644 index 00000000..a08b890d --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/DefaultAcpTerminalProcessLauncher.kt @@ -0,0 +1,142 @@ +package ee.carlrobert.codegpt.agent.external.host + +import com.intellij.execution.configurations.GeneralCommandLine +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.InputStream +import java.nio.charset.StandardCharsets +import java.nio.file.Path +import java.util.UUID +import java.util.concurrent.atomic.AtomicReference + +class DefaultAcpTerminalProcessLauncher( + private val scope: CoroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) +) : AcpTerminalProcessLauncher { + override fun launch( + command: String, + args: List, + cwd: Path, + env: Map, + outputByteLimit: ULong? + ): AcpTerminalProcess { + val process = GeneralCommandLine().apply { + withExePath(command) + withParameters(args) + withWorkDirectory(cwd.toFile()) + withEnvironment(env) + withParentEnvironmentType(GeneralCommandLine.ParentEnvironmentType.NONE) + withRedirectErrorStream(true) + }.createProcess() + + return DefaultAcpTerminalProcess( + terminalId = UUID.randomUUID().toString(), + process = process, + scope = scope, + outputByteLimit = outputByteLimit + ) + } +} + +private class DefaultAcpTerminalProcess( + override val terminalId: String, + private val process: Process, + scope: CoroutineScope, + outputByteLimit: ULong? +) : AcpTerminalProcess { + + private val outputBuffer = BoundedOutputBuffer(outputByteLimit) + private val exitStatus = AtomicReference(null) + private val stdoutJob = scope.launch { + collect(process.inputStream) + } + private val stderrJob = scope.launch { + collect(process.errorStream) + } + + override fun output(): AcpTerminalOutputSnapshot { + if (exitStatus.get() == null && !process.isAlive) { + exitStatus.compareAndSet( + null, + AcpTerminalExitStatus(exitCode = process.exitValue().toUInt()) + ) + } + return outputBuffer.snapshot(exitStatus.get()) + } + + override suspend fun waitForExit(): AcpTerminalExitStatus { + val exitCode = withContext(Dispatchers.IO) { + process.waitFor() + } + joinAll(stdoutJob, stderrJob) + val status = AcpTerminalExitStatus(exitCode = exitCode.toUInt()) + exitStatus.set(status) + return status + } + + override fun release() { + } + + override fun kill() { + process.destroyForcibly() + } + + private suspend fun collect(stream: InputStream) { + withContext(Dispatchers.IO) { + stream.bufferedReader(StandardCharsets.UTF_8).useLines { lines -> + lines.forEach { line -> + outputBuffer.appendLine(line) + } + } + } + } +} + +private class BoundedOutputBuffer( + private val byteLimit: ULong? +) { + private val text = StringBuilder() + private var bytesUsed: ULong = 0u + private var truncated = false + + @Synchronized + fun appendLine(line: String) { + if (truncated) return + + val encoded = (line + '\n').toByteArray(StandardCharsets.UTF_8) + val limit = byteLimit + if (limit == null) { + text.appendLine(line) + return + } + + val remaining = limit - bytesUsed + if (remaining == 0uL) { + truncated = true + return + } + + if (encoded.size.toUInt().toULong() <= remaining) { + text.appendLine(line) + bytesUsed += encoded.size.toUInt().toULong() + return + } + + val allowed = remaining.toInt().coerceAtMost(encoded.size) + text.append(String(encoded, 0, allowed, StandardCharsets.UTF_8)) + bytesUsed = limit + truncated = true + } + + @Synchronized + fun snapshot(exitStatus: AcpTerminalExitStatus? = null): AcpTerminalOutputSnapshot { + return AcpTerminalOutputSnapshot( + output = text.toString(), + truncated = truncated, + exitStatus = exitStatus + ) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpIncomingRequestManager.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpIncomingRequestManager.kt new file mode 100644 index 00000000..a16e25dc --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpIncomingRequestManager.kt @@ -0,0 +1,175 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.model.AcpMethod +import com.agentclientprotocol.model.CancelRequestNotification +import com.agentclientprotocol.rpc.* +import com.agentclientprotocol.transport.Transport +import ee.carlrobert.codegpt.agent.external.acpcompat.AcpExpectedError +import ee.carlrobert.codegpt.agent.external.decodeOrNull +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlin.coroutines.CoroutineContext + +private val incomingLogger = KotlinLogging.logger {} + +internal class AcpIncomingRequestManager( + private val state: AcpProtocolState, + private val transport: Transport, + private val json: Json, + private val requestsScope: CoroutineScope, + private val trace: ((String) -> Unit)? +) { + + fun installMetaHandlers() { + state.notificationHandlers[AcpMethod.MetaMethods.CancelRequest.methodName] = + { notification -> + val request = notification.params.decodeOrNull(json) + when { + request == null -> incomingLogger.warn { "Received CancelRequest with invalid payload" } + else -> { + val requestJob = state.pendingIncomingRequests.remove(request.requestId) + if (requestJob == null) { + incomingLogger.warn { "Received CancelRequest for unknown request: ${request.requestId}" } + } else { + requestJob.cancel( + JsonRpcIncomingRequestCanceledException( + request.message ?: "Cancelled by the counterpart", + ) + ) + } + } + } + } + } + + fun setRequestHandlerRaw( + method: AcpMethod.AcpRequestResponseMethod<*, *>, + additionalContext: CoroutineContext, + handler: suspend (JsonRpcRequest) -> JsonElement? + ) { + state.requestHandlers[method.methodName] = { request -> + withContext(additionalContext) { + handler(request) + } + } + } + + fun setNotificationHandlerRaw( + method: AcpMethod.AcpNotificationMethod<*>, + additionalContext: CoroutineContext, + handler: suspend (JsonRpcNotification) -> Unit + ) { + state.notificationHandlers[method.methodName] = { notification -> + withContext(additionalContext) { + handler(notification) + } + } + } + + fun handleRequest(request: JsonRpcRequest) { + val requestId = request.id + requestsScope.launch { + processRequest(request) + }.also { job -> + state.pendingIncomingRequests[requestId] = job + }.invokeOnCompletion { + state.pendingIncomingRequests.remove(requestId) + } + } + + suspend fun handleNotification(notification: JsonRpcNotification) { + val handler = state.notificationHandlers[notification.method] + if (handler != null) { + try { + handler(notification) + } catch (e: Exception) { + incomingLogger.error(e) { "Error handling notification ${notification.method}" } + } + } else { + incomingLogger.debug { "No handler for notification: ${notification.method}" } + } + } + + fun cancelPendingIncomingRequests(ce: CancellationException? = null) { + val requests = state.pendingIncomingRequests.toMap() + state.pendingIncomingRequests.clear() + for ((requestId, job) in requests) { + incomingLogger.trace { "Canceling pending incoming request: $requestId" } + job.cancel(ce) + } + } + + private suspend fun processRequest(request: JsonRpcRequest) { + val handler = state.requestHandlers[request.method] + val response = if (handler == null) { + JsonRpcResponse( + request.id, + null, + JsonRpcError( + JsonRpcErrorCode.METHOD_NOT_FOUND.code, + "Method not supported: ${request.method}" + ) + ) + } else { + buildResponse(request, handler) + } + if (response != null) { + trace?.invoke("-> ${response.traceSummary()}") + transport.send(response) + } + } + + private suspend fun buildResponse( + request: JsonRpcRequest, + handler: suspend (JsonRpcRequest) -> JsonElement? + ): JsonRpcResponse? { + return try { + val result = withContext(JsonRpcRequestContextElement()) { + handler(request) + } + JsonRpcResponse(request.id, result, null) + } catch (e: AcpExpectedError) { + incomingLogger.trace(e) { "Expected error on '${request.method}'" } + errorResponse( + request.id, + JsonRpcErrorCode.INVALID_PARAMS, + e.message ?: "Invalid params" + ) + } catch (e: SerializationException) { + incomingLogger.trace(e) { "Serialization error on ${request.method}" } + errorResponse( + request.id, + JsonRpcErrorCode.PARSE_ERROR, + e.message ?: "Serialization error" + ) + } catch (ce: CancellationException) { + incomingLogger.trace(ce) { "Incoming request cancelled: ${request.method}" } + if (ce is JsonRpcIncomingRequestCanceledException) { + null + } else { + errorResponse(request.id, JsonRpcErrorCode.CANCELLED, ce.message ?: "Cancelled") + } + } catch (e: Exception) { + incomingLogger.error(e) { "Exception on ${request.method}" } + errorResponse( + request.id, + JsonRpcErrorCode.INTERNAL_ERROR, + e.message ?: "Internal error" + ) + } + } + + private fun errorResponse( + requestId: RequestId, + code: JsonRpcErrorCode, + message: String + ): JsonRpcResponse { + return JsonRpcResponse(requestId, null, JsonRpcError(code.code, message)) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpMessagePump.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpMessagePump.kt new file mode 100644 index 00000000..c6dfa24d --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpMessagePump.kt @@ -0,0 +1,36 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.rpc.JsonRpcMessage +import com.agentclientprotocol.transport.Transport +import com.agentclientprotocol.transport.asMessageChannel +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.launch + +private val pumpLogger = KotlinLogging.logger {} + +internal class AcpMessagePump( + private val scope: CoroutineScope, + private val transport: Transport, + private val protocolDebugName: String, + private val onMessage: suspend (JsonRpcMessage) -> Unit +) { + + fun start() { + scope.launch(CoroutineName("$protocolDebugName.read-messages")) { + try { + for (message in transport.asMessageChannel()) { + try { + onMessage(message) + } catch (e: Exception) { + pumpLogger.error(e) { "Error processing incoming message: $message" } + } + } + } catch (e: Exception) { + pumpLogger.error(e) { "Error processing incoming messages" } + } + } + transport.start() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpOutgoingRequestManager.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpOutgoingRequestManager.kt new file mode 100644 index 00000000..943597f1 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpOutgoingRequestManager.kt @@ -0,0 +1,129 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.model.AcpMethod +import com.agentclientprotocol.model.CancelRequestNotification +import com.agentclientprotocol.rpc.* +import com.agentclientprotocol.transport.Transport +import ee.carlrobert.codegpt.agent.external.acpcompat.JsonRpcException +import ee.carlrobert.codegpt.agent.external.acpcompat.ProtocolOptions +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.* +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.encodeToJsonElement + +private val outgoingLogger = KotlinLogging.logger {} + +internal class AcpOutgoingRequestManager( + private val state: AcpProtocolState, + private val transport: Transport, + private val json: Json, + private val options: ProtocolOptions, + private val trace: ((String) -> Unit)? +) { + + suspend fun sendRequestRaw( + method: MethodName, + params: JsonElement? = null + ): JsonElement { + val requestId = state.nextRequestId() + val deferred = CompletableDeferred() + state.pendingOutgoingRequests[requestId] = deferred + + try { + val request = JsonRpcRequest(requestId, method, params) + trace?.invoke("-> ${request.traceSummary()}") + transport.send(request) + return deferred.await() + } catch (jsonRpcException: JsonRpcException) { + throw convertJsonRpcExceptionIfPossible(jsonRpcException) + } catch (ce: CancellationException) { + outgoingLogger.trace(ce) { + "Request cancelled on this side. Sending CancelRequest notification." + } + withContext(NonCancellable) { + sendCancellationNotification(requestId, ce.message) + + if (!deferred.isCancelled) { + waitForGracefulCancellation(requestId, deferred) + deferred.cancel() + } + } + throw ce + } finally { + state.pendingOutgoingRequests.remove(requestId) + } + } + + fun handleResponse(response: JsonRpcResponse) { + val deferred = state.pendingOutgoingRequests.remove(response.id) + if (deferred != null) { + val responseError = response.error + if (responseError != null) { + deferred.completeExceptionally( + JsonRpcException(responseError.code, responseError.message, responseError.data) + ) + } else { + deferred.complete(response.result ?: JsonNull) + } + } else { + outgoingLogger.warn { "Received response for unknown request ID: ${response.id}" } + } + } + + fun cancelPendingOutgoingRequests(ce: CancellationException? = null) { + val requests = state.pendingOutgoingRequests.toMap() + state.pendingOutgoingRequests.clear() + for ((requestId, deferred) in requests) { + outgoingLogger.trace { "Canceling pending outgoing request: $requestId" } + deferred.cancel(ce) + } + } + + private fun sendCancellationNotification( + requestId: RequestId, + message: String? + ) { + val notification = JsonRpcNotification( + method = AcpMethod.MetaMethods.CancelRequest.methodName, + params = json.encodeToJsonElement(CancelRequestNotification(requestId, message)) + ) + trace?.invoke("-> ${notification.traceSummary()}") + transport.send(notification) + } + + private suspend fun waitForGracefulCancellation( + requestId: RequestId, + deferred: CompletableDeferred + ) { + try { + withTimeout(options.gracefulRequestCancellationTimeout) { + deferred.await() + } + } catch (e: TimeoutCancellationException) { + outgoingLogger.trace(e) { + "Timed out waiting for graceful cancellation response for request: $requestId" + } + } catch (ce: CancellationException) { + outgoingLogger.trace(ce) { + "Graceful cancellation response received for request: $requestId" + } + } catch (e: JsonRpcException) { + val convertedException = convertJsonRpcExceptionIfPossible(e) + if (convertedException is CancellationException) { + outgoingLogger.trace(convertedException) { + "Graceful cancellation response received for request: $requestId" + } + } else { + outgoingLogger.warn(convertedException) { + "Unexpected error while waiting for graceful cancellation response for request: $requestId" + } + } + } catch (e: Exception) { + outgoingLogger.warn(e) { + "Unexpected error while waiting for graceful cancellation response for request: $requestId" + } + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolCore.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolCore.kt new file mode 100644 index 00000000..d6a9c302 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolCore.kt @@ -0,0 +1,102 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.model.AcpMethod +import com.agentclientprotocol.rpc.* +import com.agentclientprotocol.transport.Transport +import ee.carlrobert.codegpt.agent.external.acpcompat.ProtocolOptions +import kotlinx.coroutines.* +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import java.util.concurrent.Executors +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext + +internal class AcpProtocolCore( + parentScope: CoroutineScope, + private val transport: Transport, + json: Json, + private val options: ProtocolOptions +) { + private val trace = options.trace + private val scope = CoroutineScope( + parentScope.coroutineContext + + SupervisorJob(parentScope.coroutineContext[Job]) + + CoroutineName(options.protocolDebugName) + ) + private val requestsExecutor = Executors.newSingleThreadExecutor { runnable -> + Thread(runnable, "${options.protocolDebugName}.requests").apply { + isDaemon = true + } + } + private val requestsDispatcher = requestsExecutor.asCoroutineDispatcher() + private val requestsScope = CoroutineScope( + scope.coroutineContext + + SupervisorJob(scope.coroutineContext[Job]) + + requestsDispatcher + + CoroutineName("${options.protocolDebugName}.requests") + ) + private val state = AcpProtocolState() + private val incomingRequests = AcpIncomingRequestManager(state, transport, json, requestsScope, trace) + private val outgoingRequests = AcpOutgoingRequestManager(state, transport, json, options, trace) + private val messagePump = AcpMessagePump( + scope, + transport, + options.protocolDebugName, + onMessage = ::handleIncomingMessage + ) + + fun start() { + incomingRequests.installMetaHandlers() + messagePump.start() + } + + suspend fun sendRequestRaw( + method: MethodName, + params: JsonElement? = null + ): JsonElement { + return outgoingRequests.sendRequestRaw(method, params) + } + + fun sendNotificationRaw( + method: AcpMethod.AcpNotificationMethod<*>, + params: JsonElement? = null + ) { + val notification = JsonRpcNotification(method = method.methodName, params = params) + trace?.invoke("-> ${notification.traceSummary()}") + transport.send(notification) + } + + fun setRequestHandlerRaw( + method: AcpMethod.AcpRequestResponseMethod<*, *>, + additionalContext: CoroutineContext = EmptyCoroutineContext, + handler: suspend (JsonRpcRequest) -> JsonElement? + ) { + incomingRequests.setRequestHandlerRaw(method, additionalContext, handler) + } + + fun setNotificationHandlerRaw( + method: AcpMethod.AcpNotificationMethod<*>, + additionalContext: CoroutineContext = EmptyCoroutineContext, + handler: suspend (JsonRpcNotification) -> Unit + ) { + incomingRequests.setNotificationHandlerRaw(method, additionalContext, handler) + } + + fun close() { + transport.close() + val message = "Protocol closed" + incomingRequests.cancelPendingIncomingRequests(CancellationException(message)) + outgoingRequests.cancelPendingOutgoingRequests(CancellationException(message)) + scope.cancel(message) + requestsDispatcher.close() + } + + private suspend fun handleIncomingMessage(message: JsonRpcMessage) { + trace?.invoke("<- ${message.traceSummary()}") + when (message) { + is JsonRpcNotification -> incomingRequests.handleNotification(message) + is JsonRpcRequest -> incomingRequests.handleRequest(message) + is JsonRpcResponse -> outgoingRequests.handleResponse(message) + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolState.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolState.kt new file mode 100644 index 00000000..12fa8fd2 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolState.kt @@ -0,0 +1,23 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.rpc.JsonRpcNotification +import com.agentclientprotocol.rpc.JsonRpcRequest +import com.agentclientprotocol.rpc.MethodName +import com.agentclientprotocol.rpc.RequestId +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Job +import kotlinx.serialization.json.JsonElement +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +internal class AcpProtocolState { + private val requestIdCounter = AtomicInteger(0) + + val pendingOutgoingRequests = ConcurrentHashMap>() + val pendingIncomingRequests = ConcurrentHashMap() + val requestHandlers = ConcurrentHashMap JsonElement?>() + val notificationHandlers = + ConcurrentHashMap Unit>() + + fun nextRequestId(): RequestId = RequestId.Companion.create(requestIdCounter.incrementAndGet()) +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolSupport.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolSupport.kt new file mode 100644 index 00000000..fdcaa302 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolSupport.kt @@ -0,0 +1,40 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.rpc.JsonRpcErrorCode +import ee.carlrobert.codegpt.agent.external.acpcompat.AcpExpectedError +import ee.carlrobert.codegpt.agent.external.acpcompat.JsonRpcException +import kotlinx.coroutines.CancellationException +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonElement +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext + +internal class JsonRpcRequestContextElement : + AbstractCoroutineContextElement(Key) { + object Key : CoroutineContext.Key +} + +internal class JsonRpcIncomingRequestCanceledException( + message: String, + val data: JsonElement? = null +) : CancellationException(message) + +internal fun convertJsonRpcExceptionIfPossible(jsonRpcException: JsonRpcException): Exception { + return when (jsonRpcException.code) { + JsonRpcErrorCode.PARSE_ERROR.code -> SerializationException( + jsonRpcException.message, + jsonRpcException + ) + + JsonRpcErrorCode.INVALID_PARAMS.code -> AcpExpectedError( + jsonRpcException.message ?: "Invalid params" + ) + + JsonRpcErrorCode.CANCELLED.code -> CancellationException( + jsonRpcException.message ?: "Cancelled on the counterpart side", + jsonRpcException + ) + + else -> jsonRpcException + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolTrace.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolTrace.kt new file mode 100644 index 00000000..fa2d0286 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolTrace.kt @@ -0,0 +1,32 @@ +package ee.carlrobert.codegpt.agent.external.runtime + +import com.agentclientprotocol.rpc.JsonRpcError +import com.agentclientprotocol.rpc.JsonRpcMessage +import com.agentclientprotocol.rpc.JsonRpcNotification +import com.agentclientprotocol.rpc.JsonRpcRequest +import com.agentclientprotocol.rpc.JsonRpcResponse +import kotlinx.serialization.json.JsonElement + +internal fun JsonRpcMessage.traceSummary(): String { + return when (this) { + is JsonRpcRequest -> "request id=$id method=$method params=${params.traceSummary()}" + is JsonRpcNotification -> "notification method=$method params=${params.traceSummary()}" + is JsonRpcResponse -> "response id=$id ${error.traceSummary(result)}" + } +} + +private fun JsonRpcError?.traceSummary(result: JsonElement?): String { + return if (this == null) { + "result=${result.traceSummary()}" + } else { + "errorCode=$code errorMessage=${message.orEmpty().singleLine()} errorData=${data.traceSummary()}" + } +} + +private fun JsonElement?.traceSummary(limit: Int = 400): String { + return this?.toString()?.singleLine()?.take(limit) ?: "null" +} + +private fun String.singleLine(): String { + return replace(Regex("\\s+"), " ").trim() +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashOutputTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashOutputTool.kt index 0bed29fc..1528b048 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashOutputTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashOutputTool.kt @@ -13,8 +13,8 @@ import kotlinx.serialization.Serializable class BashOutputTool( workingDirectory: String, - hookManager: HookManager, private val sessionId: String, + hookManager: HookManager, ) : BaseTool( workingDirectory = workingDirectory, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/IntelliJSearchTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/IntelliJSearchTool.kt index 2887dc33..560600f0 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/IntelliJSearchTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/IntelliJSearchTool.kt @@ -32,9 +32,9 @@ import java.nio.file.Paths * Enhanced search tool using IntelliJ's native SearchService and FindModel. */ class IntelliJSearchTool( - hookManager: HookManager, - sessionId: String, private val project: Project, + sessionId: String, + hookManager: HookManager, ) : BaseTool( workingDirectory = project.basePath ?: System.getProperty("user.dir"), argsSerializer = Args.serializer(), diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/TaskTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/TaskTool.kt index 03a301fd..900ea6a9 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/TaskTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/TaskTool.kt @@ -527,7 +527,7 @@ private class ExternalSubagentEventsAdapter( delegate.onSubAgentToolCompleted(parentId, childId, toolName, result) } - override fun onQueuedMessagesResolved() = Unit + override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit override suspend fun approveToolCall(request: ToolApprovalRequest): Boolean { return when (request.type) { diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebFetchTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebFetchTool.kt index 11a3474e..70e9f2b6 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebFetchTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebFetchTool.kt @@ -13,9 +13,9 @@ import java.net.URI class WebFetchTool( workingDirectory: String, - private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)", sessionId: String, hookManager: HookManager, + private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)", ) : BaseTool( workingDirectory = workingDirectory, argsSerializer = Args.serializer(), diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebSearchTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebSearchTool.kt index fd6d8037..c1f61c22 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebSearchTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WebSearchTool.kt @@ -17,9 +17,9 @@ import java.time.format.DateTimeFormatter class WebSearchTool( workingDirectory: String, - private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)", sessionId: String, hookManager: HookManager, + private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)", ) : BaseTool( workingDirectory = workingDirectory, argsSerializer = Args.serializer(), diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt index 2e8ce1bf..3139567a 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/AgentCompletionRunner.kt @@ -68,7 +68,7 @@ internal object AgentCompletionRunner : CompletionRunner { } } - override fun onQueuedMessagesResolved() = Unit + override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit } val cancellableRequest = CancellableRequest { @@ -137,7 +137,7 @@ internal object AgentCompletionRunner : CompletionRunner { } (msg as? ai.koog.prompt.message.Message.Reasoning)?.let { if (it.content.isNotBlank()) { - events.onTextReceived("${it.content}") + events.onThinkingReceived(it.content) } } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/ProxyAISettingsService.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/ProxyAISettingsService.kt index 086511d9..f20c356c 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/ProxyAISettingsService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/ProxyAISettingsService.kt @@ -6,6 +6,7 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.diagnostic.thisLogger import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile +import ee.carlrobert.codegpt.agent.external.ExternalAcpAgents import ee.carlrobert.codegpt.settings.agents.SubagentDefaults import ee.carlrobert.codegpt.settings.hooks.HookConfiguration import ee.carlrobert.codegpt.settings.service.ServiceType diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt index 878a0499..f54017a2 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt @@ -2,6 +2,8 @@ package ee.carlrobert.codegpt.toolwindow.agent import ai.koog.agents.core.agent.exception.AIAgentStuckInTheNodeException import ai.koog.http.client.KoogHttpClientException +import com.agentclientprotocol.model.PlanEntry +import com.agentclientprotocol.model.PlanEntryStatus import com.intellij.openapi.Disposable import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.application.runInEdt @@ -18,6 +20,7 @@ import ee.carlrobert.codegpt.agent.history.CheckpointRef import ee.carlrobert.codegpt.agent.rollback.RollbackService import ee.carlrobert.codegpt.agent.tools.* import ee.carlrobert.codegpt.settings.agents.SubagentRuntimeResolver +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTApiException import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.codegpt.toolwindow.agent.ui.* @@ -297,6 +300,12 @@ class AgentEventHandler( } else { request } + logger.debug( + "Enqueueing agent approval for session=$sessionId type=${resolvedRequest.type} title=${resolvedRequest.title.logPreview()} payload=${resolvedRequest.payload.logSummary()}" + ) + approvalTrace( + "approval/enqueue session=$sessionId type=${resolvedRequest.type} title=${resolvedRequest.title.logPreview()} payload=${resolvedRequest.payload.logSummary()}" + ) runInEdt { approvalQueue.addLast(ApprovalRequest(resolvedRequest, deferred)) maybeShowNextApproval() @@ -318,6 +327,12 @@ class AgentEventHandler( } val decision = CompletableDeferred() + logger.debug( + "Enqueueing agent approval for session=$sessionId type=${request.type} title=${request.title.logPreview()} payload=${request.payload.logSummary()}" + ) + approvalTrace( + "approval/enqueue session=$sessionId type=${request.type} title=${request.title.logPreview()} payload=${request.payload.logSummary()}" + ) runInEdt { approvalQueue.addLast(ApprovalRequest(request, decision)) maybeShowNextApproval() @@ -344,6 +359,23 @@ class AgentEventHandler( } } + override fun onThinkingReceived(text: String) { + runInEdt { + currentResponseBody?.appendThinking(text) + scrollablePanel.update() + scrollablePanel.scrollToBottom() + } + } + + override fun onPlanUpdated(entries: List) { + runInEdt { + todoListPanel.updateTodos(entries.toTodoItems()) + todoListPanel.isVisible = entries.isNotEmpty() + scrollablePanel.update() + scrollablePanel.scrollToBottom() + } + } + override fun onToolStarting(id: String, toolName: String, args: Any?) { when (args) { is TodoWriteTool.Args -> { @@ -511,10 +543,29 @@ class AgentEventHandler( } } - override fun onQueuedMessagesResolved() { - val pendingMessage = project.service() - .getPendingMessages(sessionId) - .firstOrNull { it.uiVisible } ?: return + private fun List.toTodoItems(): List { + return map { entry -> + TodoWriteTool.TodoItem( + content = entry.content, + status = when (entry.status) { + PlanEntryStatus.PENDING -> TodoWriteTool.TodoStatus.PENDING + PlanEntryStatus.IN_PROGRESS -> TodoWriteTool.TodoStatus.IN_PROGRESS + PlanEntryStatus.COMPLETED -> TodoWriteTool.TodoStatus.COMPLETED + }, + activeForm = entry.content + ) + } + } + + override fun onQueuedMessagesResolved(message: MessageWithContext?) { + val pendingMessage = message + ?: project.service() + .getPendingMessages(sessionId) + .firstOrNull { it.uiVisible } + ?: return + logger.debug( + "Resolving queued message in UI for session=$sessionId messageId=${pendingMessage.id} uiVisible=${pendingMessage.uiVisible} preview=${pendingMessage.text.logPreview()}" + ) onQueuedMessagesResolved(pendingMessage) } @@ -525,11 +576,48 @@ class AgentEventHandler( } override fun onTokenUsageAvailable(tokenUsage: Long) { - lastReportedPromptTokens = tokenUsage + onUsageAvailable(AgentUsageEvent(usedTokens = tokenUsage)) + } - val event = TokenUsageEvent(sessionId, tokenUsage) + override fun onUsageAvailable(event: AgentUsageEvent) { + lastReportedPromptTokens = event.usedTokens project.messageBus.syncPublisher(TokenUsageListener.TOKEN_USAGE_TOPIC) - .onTokenUsageChanged(event) + .onTokenUsageChanged( + TokenUsageEvent( + sessionId = sessionId, + totalTokens = event.usedTokens, + sizeTokens = event.sizeTokens, + costAmount = event.costAmount, + costCurrency = event.costCurrency + ) + ) + } + + override fun onRuntimeOptionsUpdated() { + runInEdt { + userInputPanel.refreshModelDependentState() + } + } + + override fun onSessionInfoUpdated(title: String?, updatedAt: String?) { + val normalizedTitle = title?.trim().orEmpty() + if (normalizedTitle.isEmpty()) { + return + } + + val contentManager = project.service() + val session = contentManager.getSession(sessionId) ?: return + val previousExternalTitle = session.externalAgentSessionTitle + session.externalAgentSessionTitle = normalizedTitle + + val shouldRename = session.displayName.isBlank() || + session.displayName == previousExternalTitle || + session.displayName.matches(Regex("""Agent \d+( \(\d+\))?""")) + + if (shouldRename) { + project.messageBus.syncPublisher(AgentTabTitleNotifier.AGENT_TAB_TITLE_TOPIC) + .updateTabTitle(sessionId, normalizedTitle) + } } override fun onCreditsAvailable(event: AgentCreditsEvent) { @@ -587,6 +675,12 @@ class AgentEventHandler( val contentManager = project.service() if (contentManager.isSessionAutoApproved(sessionId)) { + approvalTrace( + "approval/auto-approve session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}" + ) + logger.debug( + "Auto-approving queued agent approval for session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}" + ) next.deferred.complete(true) currentApproval = null maybeShowNextApproval() @@ -615,6 +709,13 @@ class AgentEventHandler( updateEditToolCardPreview(next.model) } + logger.debug( + "Showing agent approval for session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()} payload=${next.model.payload.logSummary()} queueRemaining=${approvalQueue.size}" + ) + approvalTrace( + "approval/show session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()} payload=${next.model.payload.logSummary()} queueRemaining=${approvalQueue.size}" + ) + runCatching { project.service() .setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.APPROVAL) @@ -629,6 +730,12 @@ class AgentEventHandler( .markSessionAsAutoApproved(sessionId) } + logger.debug( + "Approved agent approval for session=$sessionId type=${next.model.type} auto=$auto title=${next.model.title.logPreview()}" + ) + approvalTrace( + "approval/approved session=$sessionId type=${next.model.type} auto=$auto title=${next.model.title.logPreview()}" + ) next.deferred.complete(true) currentApproval = null clearApprovalContainer() @@ -639,6 +746,12 @@ class AgentEventHandler( maybeShowNextApproval() }, onReject = { + logger.debug( + "Rejected agent approval for session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}" + ) + approvalTrace( + "approval/rejected session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}" + ) next.deferred.complete(false) currentApproval = null clearApprovalContainer() @@ -758,6 +871,28 @@ class AgentEventHandler( return viewHolder } + private fun String.logPreview(limit: Int = 120): String { + return replace('\n', ' ') + .replace(Regex("\\s+"), " ") + .trim() + .take(limit) + } + + private fun ToolApprovalPayload?.logSummary(): String { + return when (this) { + is WritePayload -> "write:${filePath.logPreview(80)}" + is EditPayload -> "edit:${filePath.logPreview(80)}" + is BashPayload -> "bash:${command.logPreview(80)}" + null -> "none" + } + } + + private fun approvalTrace(message: String) { + if (ConfigurationSettings.getState().debugModeEnabled) { + logger.info("[APPROVAL TRACE] $message") + } + } + class RunViewHolder( private val vm: AgentRunViewModel, private val view: AgentRunDslPanel, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSession.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSession.kt index d91df3d5..ffe7492d 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSession.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSession.kt @@ -1,9 +1,11 @@ package ee.carlrobert.codegpt.toolwindow.agent +import com.agentclientprotocol.model.AvailableCommand import com.intellij.openapi.vfs.VirtualFile import ee.carlrobert.codegpt.agent.history.CheckpointRef import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.settings.service.ServiceType +import kotlinx.serialization.json.JsonElement data class AcpConfigOptionChoice( val value: String, @@ -25,7 +27,7 @@ object AcpConfigOptions { fun selectable(options: List): List { return options .asSequence() - .filter { it.type == "select" && it.options.isNotEmpty() } + .filter { it.type in setOf("select", "boolean") && it.options.isNotEmpty() } .sortedBy { categoryOrder(it.category) } .toList() } @@ -40,7 +42,9 @@ object AcpConfigOptions { } fun selectedValueName(options: List, category: String): String? { - val option = selectable(options).firstOrNull { it.category == category } ?: return null + val option = selectable(options).firstOrNull { + it.category == category || it.id == category + } ?: return null return option.options .firstOrNull { it.value == option.currentValue } ?.name @@ -73,6 +77,24 @@ object AcpConfigOptions { .toMap(linkedMapOf()) } + fun summaryParts( + options: List, + maxEntries: Int = 3 + ): List { + return selectable(options) + .mapNotNull { option -> + selectedValueName(options, option.category ?: option.id)?.let { value -> + if (option.category in setOf("model", "mode", "thought_level")) { + value + } else { + "${label(option)}: $value" + } + } + } + .distinct() + .take(maxEntries) + } + private fun categoryOrder(category: String?): Int { return when (category.orEmpty()) { "model" -> 0 @@ -95,6 +117,7 @@ data class AgentSession( var modelCode: String? = null, var externalAgentId: String? = null, var externalAgentSessionId: String? = null, + var externalAgentMcpServerIds: Set = emptySet(), var externalAgentConfigOptions: List = emptyList(), var externalAgentConfigSelections: Map = emptyMap(), var externalAgentConfigLoading: Boolean = false, @@ -105,4 +128,14 @@ data class AgentSession( var lastActiveAt: Long = System.currentTimeMillis() ) { var externalAgentErrorMessage: String? = null + var externalAgentPeerProfileId: String? = null + var externalAgentAvailableCommands: List = emptyList() + var externalAgentVendorMeta: JsonElement? = null + var externalAgentRequestMeta: JsonElement? = null + var externalAgentSessionTitle: String? = null + + fun shouldRecreateExternalAgentSession(selectedMcpServerIds: Set): Boolean { + return !externalAgentSessionId.isNullOrBlank() && + externalAgentMcpServerIds != selectedMcpServerIds + } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowPanel.kt index 7e002129..3bfdc955 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowPanel.kt @@ -28,6 +28,7 @@ class AgentToolWindowPanel( private val tabbedPane = contentManager.initializeTabbedPane() private val centerLayout = CardLayout() private val centerPanel = JPanel(centerLayout) + private val creditsLabel = AgentCreditsToolbarLabel(project, ::currentSession) private var landingPanel: AgentToolWindowTabPanel? = null init { @@ -75,7 +76,6 @@ class AgentToolWindowPanel( val toolbar = ActionManager.getInstance() .createActionToolbar("AgentToolWindow", actionGroup, true) toolbar.targetComponent = tabbedPane - val creditsLabel = AgentCreditsToolbarLabel(project) Disposer.register(this, creditsLabel) return BorderLayoutPanel() .addToLeft(toolbar.component) @@ -87,6 +87,7 @@ class AgentToolWindowPanel( private fun showTabsView() { disposeLandingPanel() centerLayout.show(centerPanel, TABS_CARD) + creditsLabel.refresh() } private fun showLandingView() { @@ -97,6 +98,7 @@ class AgentToolWindowPanel( landingPanel?.requestFocusForTextArea() centerPanel.revalidate() centerPanel.repaint() + creditsLabel.refresh() } private fun createLandingPanel(): AgentToolWindowTabPanel { @@ -122,6 +124,11 @@ class AgentToolWindowPanel( landingPanel = null } + private fun currentSession(): AgentSession? { + return landingPanel?.getAgentSession() + ?: contentManager.getActiveTabPanel()?.getAgentSession() + } + override fun dispose() { tabbedPane.setTabLifecycleCallbacks(onTabsOpened = {}, onAllTabsClosed = {}) disposeLandingPanel() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt index 23da22bf..3ce065d3 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt @@ -113,6 +113,7 @@ class AgentToolWindowTabPanel( ).apply { isVisible = false } + private val tokenUsageCounterPanel = TokenUsageCounterPanel(project, sessionId) private val userInputPanel = UserInputPanel( project, @@ -128,8 +129,8 @@ class AgentToolWindowTabPanel( onSubmit = ::handleSubmit, onStop = ::handleCancel, withRemovableSelectedEditorTag = true, - agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId), - agentTokenCounterVisibilityProvider = { agentSession.externalAgentId.isNullOrBlank() }, + agentTokenCounterPanel = tokenUsageCounterPanel, + agentTokenCounterVisibilityProvider = { tokenUsageCounterPanel.hasReportedUsage() }, sessionIdProvider = { sessionId }, conversationIdProvider = { conversation.id }, onStartSessionTimeline = ::showSessionStartTimelineDialog, @@ -300,8 +301,7 @@ class AgentToolWindowTabPanel( if (message.text.isBlank()) return disposeLandingPanelIfPresent() scrollablePanel.clearLandingViewIfVisible() - val modelSettings = ModelSettings.getInstance() - val agentModelSelection = modelSettings.getModelSelectionForFeature(FeatureType.AGENT) + val agentModelSelection = service().getModelSelectionForFeature(FeatureType.AGENT) agentSession.serviceType = agentModelSelection.provider agentSession.modelCode = agentModelSelection.selectionId @@ -409,10 +409,13 @@ class AgentToolWindowTabPanel( project.service().closeSession(sessionId) agentSession.externalAgentId = externalAgentId agentSession.externalAgentSessionId = null + agentSession.externalAgentMcpServerIds = emptySet() agentSession.externalAgentConfigOptions = emptyList() agentSession.externalAgentConfigSelections = emptyMap() agentSession.externalAgentErrorMessage = null agentSession.externalAgentConfigLoading = !externalAgentId.isNullOrBlank() + project.messageBus.syncPublisher(AgentUiStateNotifier.AGENT_UI_STATE_TOPIC) + .sessionRuntimeChanged(sessionId) if (!externalAgentId.isNullOrBlank()) { backgroundScope.launch { runCatching { @@ -474,15 +477,18 @@ class AgentToolWindowTabPanel( withContext(Dispatchers.EDT) { inputPanel.refreshModelDependentState() } - runCatching { - project.service() - .setSessionConfigOption(agentSession, optionId, value) - }.onFailure { ex -> + try { + runCatching { + project.service() + .setSessionConfigOption(agentSession, optionId, value) + }.onFailure { ex -> + OverlayUtil.showNotification( + "${displayExternalAgentName(agentSession.externalAgentId ?: "agent")} option update failed. ${buildExternalAgentConfigFailureMessage(ex)}", + NotificationType.ERROR + ) + } + } finally { agentSession.externalAgentConfigLoading = false - OverlayUtil.showNotification( - "${displayExternalAgentName(agentSession.externalAgentId ?: "agent")} option update failed. ${buildExternalAgentConfigFailureMessage(ex)}", - NotificationType.ERROR - ) } withContext(Dispatchers.EDT) { inputPanel.refreshModelDependentState() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabbedPane.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabbedPane.kt index 42497498..7e606708 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabbedPane.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabbedPane.kt @@ -48,7 +48,11 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(), init { tabComponentInsets = null setComponentPopupMenu(TabPopupMenu()) - addChangeListener { refreshTabState() } + addChangeListener { + refreshTabState() + project.messageBus.syncPublisher(AgentUiStateNotifier.AGENT_UI_STATE_TOPIC) + .activeSessionChanged() + } } enum class TabStatus { diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentUiContextTopic.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentUiContextTopic.kt new file mode 100644 index 00000000..05e3c6bc --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentUiContextTopic.kt @@ -0,0 +1,11 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import com.intellij.util.messages.Topic + +interface AgentUiContextListener { + fun onUiContextChanged() + + companion object { + val AGENT_UI_CONTEXT_TOPIC = Topic("Agent UI Context Changes", AgentUiContextListener::class.java) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentUiStateNotifier.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentUiStateNotifier.kt new file mode 100644 index 00000000..4c987f76 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentUiStateNotifier.kt @@ -0,0 +1,16 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import com.intellij.util.messages.Topic + +interface AgentUiStateNotifier { + + fun activeSessionChanged() + + fun sessionRuntimeChanged(sessionId: String) + + companion object { + @JvmField + val AGENT_UI_STATE_TOPIC: Topic = + Topic.create("agentUiState", AgentUiStateNotifier::class.java) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/TokenUsageTopic.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/TokenUsageTopic.kt index 2eb3458f..f3915258 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/TokenUsageTopic.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/TokenUsageTopic.kt @@ -5,6 +5,9 @@ import com.intellij.util.messages.Topic data class TokenUsageEvent( val sessionId: String, val totalTokens: Long, + val sizeTokens: Long? = null, + val costAmount: Double? = null, + val costCurrency: String? = null, ) interface TokenUsageListener { @@ -13,4 +16,4 @@ interface TokenUsageListener { companion object { val TOKEN_USAGE_TOPIC = Topic("Token Usage Changes", TokenUsageListener::class.java) } -} \ No newline at end of file +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentCreditsToolbarLabel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentCreditsToolbarLabel.kt index 7720cb08..6fea069f 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentCreditsToolbarLabel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentCreditsToolbarLabel.kt @@ -10,17 +10,23 @@ import com.intellij.util.ui.JBUI import ee.carlrobert.codegpt.CodeGPTBundle import ee.carlrobert.codegpt.CodeGPTKeys import ee.carlrobert.codegpt.settings.models.ModelSettings -import ee.carlrobert.codegpt.settings.service.* +import ee.carlrobert.codegpt.settings.service.FeatureType +import ee.carlrobert.codegpt.settings.service.ModelChangeNotifier +import ee.carlrobert.codegpt.settings.service.ModelChangeNotifierAdapter +import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTService import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTUserDetails import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTUserDetailsNotifier +import ee.carlrobert.codegpt.toolwindow.agent.AgentSession import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsListener +import ee.carlrobert.codegpt.toolwindow.agent.AgentUiStateNotifier import java.text.NumberFormat import java.util.* class AgentCreditsToolbarLabel( - private val project: Project + private val project: Project, + private val sessionProvider: () -> AgentSession? ) : JBLabel(), Disposable { private val numberFormat = NumberFormat.getNumberInstance(Locale.US).apply { @@ -67,13 +73,31 @@ class AgentCreditsToolbarLabel( } } ) + messageBusConnection.subscribe( + AgentUiStateNotifier.AGENT_UI_STATE_TOPIC, + object : AgentUiStateNotifier { + override fun activeSessionChanged() { + updateDisplay() + } + + override fun sessionRuntimeChanged(sessionId: String) { + updateDisplay() + } + } + ) + } + + fun refresh() { + updateDisplay() } private fun updateDisplay() { ApplicationManager.getApplication().invokeLater { + val activeSession = sessionProvider() val provider = ModelSettings.getInstance() .getServiceForFeature(FeatureType.AGENT) - if (provider != ServiceType.PROXYAI) { + val isExternalAgentSelected = !activeSession?.externalAgentId.isNullOrBlank() + if (provider != ServiceType.PROXYAI || isExternalAgentSelected) { text = null toolTipText = null isVisible = false diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentModelComboBoxAction.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentModelComboBoxAction.kt index c7a78834..db232e0c 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentModelComboBoxAction.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentModelComboBoxAction.kt @@ -316,6 +316,7 @@ class AgentModelComboBoxAction( override fun actionPerformed(event: AnActionEvent) { agentSession.externalAgentId = preset.id agentSession.externalAgentSessionId = null + agentSession.externalAgentMcpServerIds = emptySet() onAgentRuntimeChanged(preset.id) updateExternalAgentPresentation(preset.id) onModelChange(modelSettings.getServiceForFeature(FeatureType.AGENT)) @@ -484,6 +485,7 @@ class AgentModelComboBoxAction( private fun clearExternalAgentSelection() { agentSession.externalAgentId = null agentSession.externalAgentSessionId = null + agentSession.externalAgentMcpServerIds = emptySet() onAgentRuntimeChanged(null) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentRuntimeOptionsComboBoxAction.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentRuntimeOptionsComboBoxAction.kt index d82e4e38..69ca8374 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentRuntimeOptionsComboBoxAction.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentRuntimeOptionsComboBoxAction.kt @@ -144,13 +144,7 @@ class AgentRuntimeOptionsComboBoxAction( ?.takeIf(String::isNotBlank) ?.let { return it } - val parts = listOfNotNull( - AcpConfigOptions.selectedValueName(agentSession.externalAgentConfigOptions, "model"), - AcpConfigOptions.selectedValueName(agentSession.externalAgentConfigOptions, "thought_level") - ).ifEmpty { - listOfNotNull(AcpConfigOptions.selectedValueName(agentSession.externalAgentConfigOptions, "mode")) - } - + val parts = AcpConfigOptions.summaryParts(agentSession.externalAgentConfigOptions) return parts.takeIf { it.isNotEmpty() }?.joinToString(" ยท ") ?: "Options" } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/descriptor/ToolCallDescriptorFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/descriptor/ToolCallDescriptorFactory.kt index e43e710c..4c58b11b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/descriptor/ToolCallDescriptorFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/descriptor/ToolCallDescriptorFactory.kt @@ -6,6 +6,8 @@ import com.intellij.ui.JBColor import com.intellij.util.ui.JBUI import com.intellij.util.ui.components.BorderLayoutPanel import ee.carlrobert.codegpt.Icons +import ee.carlrobert.codegpt.agent.external.events.AcpBashPreviewArgs +import ee.carlrobert.codegpt.agent.external.events.AcpSearchPreviewArgs import ee.carlrobert.codegpt.agent.tools.* import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter import ee.carlrobert.codegpt.toolwindow.agent.ui.AgentUiConfig @@ -66,11 +68,11 @@ object ToolCallDescriptorFactory { private fun detectToolKind(toolName: String, args: Any, result: Any?): ToolKind { return when { - toolName == "IntelliJSearch" || args is IntelliJSearchTool.Args -> ToolKind.SEARCH + toolName == "IntelliJSearch" || args is IntelliJSearchTool.Args || args is AcpSearchPreviewArgs -> ToolKind.SEARCH toolName == "Read" || args is ReadTool.Args -> ToolKind.READ toolName == "Write" || args is WriteTool.Args -> ToolKind.WRITE toolName == "Edit" || args is EditTool.Args -> ToolKind.EDIT - toolName == "Bash" || args is BashTool.Args -> ToolKind.BASH + toolName == "Bash" || args is BashTool.Args || args is AcpBashPreviewArgs -> ToolKind.BASH toolName == "BashOutput" || args is BashOutputTool.Args -> ToolKind.BASH_OUTPUT toolName == "KillShell" || args is KillShellTool.Args -> ToolKind.KILL_SHELL toolName == "WebSearch" || args is WebSearchTool.Args || result is WebSearchTool.Result -> ToolKind.WEB @@ -465,19 +467,24 @@ object ToolCallDescriptorFactory { ): ToolCallDescriptor { val command = when (args) { is BashTool.Args -> args.command + is AcpBashPreviewArgs -> args.command ?: args.title is BashOutputTool.Args -> args.bashId is KillShellTool.Args -> "kill_shell" else -> "Unknown" } + val isGenericBashPreview = args is AcpBashPreviewArgs && + args.command == null && + args.title.equals("Run shell command", ignoreCase = true) + val titleMain = if (isGenericBashPreview) "Pending command" else truncateCommand(command) + val tooltip = if (isGenericBashPreview) "Command pending approval" else "Command: $command" - val truncatedCommand = truncateCommand(command) return ToolCallDescriptor( kind = ToolKind.BASH, icon = AllIcons.Nodes.Console, titlePrefix = "Bash:", - titleMain = truncatedCommand, - tooltip = "Command: $command", + titleMain = titleMain, + tooltip = tooltip, supportsStreaming = true, args = args, result = result, @@ -504,16 +511,29 @@ object ToolCallDescriptorFactory { projectId: String? ): ToolCallDescriptor { val searchArgs = args as? IntelliJSearchTool.Args - val pattern = searchArgs?.pattern ?: "" - val scopeOrPath = searchArgs?.path?.substringAfterLast('/') ?: (searchArgs?.scope ?: "") - val titleMain = buildSearchDisplay(truncatePattern(pattern), scopeOrPath) + val searchPreviewArgs = args as? AcpSearchPreviewArgs + val pattern = searchArgs?.pattern ?: searchPreviewArgs?.pattern.orEmpty() + val scopeOrPath = searchArgs?.path?.substringAfterLast('/') + ?: searchPreviewArgs?.path?.substringAfterLast('/') + ?: (searchArgs?.scope ?: "") + val titleMain = if (pattern.isBlank()) { + scopeOrPath.ifBlank { + searchPreviewArgs?.title?.takeIf { + it.isNotBlank() && !it.equals("search", ignoreCase = true) + } ?: "Search" + } + } else { + buildSearchDisplay(truncatePattern(pattern), scopeOrPath) + } return ToolCallDescriptor( kind = ToolKind.SEARCH, icon = AllIcons.Actions.Search, titlePrefix = "Search:", titleMain = titleMain, - tooltip = if (scopeOrPath.isBlank()) { + tooltip = if (pattern.isBlank()) { + searchPreviewArgs?.path?.let { "Search in $it" } ?: "Search" + } else if (scopeOrPath.isBlank()) { "Search: \"$pattern\"" } else { "Search: \"$pattern\" in $scopeOrPath" diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/components/TokenUsageCounterPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/components/TokenUsageCounterPanel.kt index b480ed5a..ed9034be 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/components/TokenUsageCounterPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/components/TokenUsageCounterPanel.kt @@ -40,6 +40,10 @@ class TokenUsageCounterPanel( private var messageBusConnection: MessageBusConnection? = null private var lastEstimatedPromptTokens: Long = 0L private var calibrationOffset: Long = 0L + private var lastReportedSizeTokens: Long? = null + private var lastReportedCostAmount: Double? = null + private var lastReportedCostCurrency: String? = null + private var hasReportedUsage: Boolean = false init { Disposer.register(ApplicationManager.getApplication(), scope) @@ -47,7 +51,7 @@ class TokenUsageCounterPanel( isOpaque = false font = JBFont.small() border = JBUI.Borders.empty(5, 7) - text = "100% context left" + isVisible = false setupMessageBusConnection() } @@ -61,7 +65,11 @@ class TokenUsageCounterPanel( if (event.sessionId == sessionId) { val model = getAgentModelForSession(event.sessionId) ?: getSelectedAgentModel() - ?: return + hasReportedUsage = true + isVisible = true + lastReportedSizeTokens = event.sizeTokens?.takeIf { it > 0L } + lastReportedCostAmount = event.costAmount + lastReportedCostCurrency = event.costCurrency?.takeIf { it.isNotBlank() } calibrationOffset = event.totalTokens - lastEstimatedPromptTokens updateDisplay(event.totalTokens, model) } @@ -71,13 +79,18 @@ class TokenUsageCounterPanel( } fun updateFromTotalTokens(totalTokens: Long) { + if (!hasReportedUsage) { + return + } lastEstimatedPromptTokens = totalTokens - val model = getSelectedAgentModel() ?: return + val model = getSelectedAgentModel() val calibratedTotal = (totalTokens + calibrationOffset).coerceAtLeast(0L) updateDisplay(calibratedTotal, model) } - private fun updateDisplay(totalTokens: Long, model: LLModel) { + fun hasReportedUsage(): Boolean = hasReportedUsage + + private fun updateDisplay(totalTokens: Long, model: LLModel?) { scope.launch { withEdt { updateColorAndText(totalTokens, model) @@ -88,10 +101,14 @@ class TokenUsageCounterPanel( } } - private fun updateColorAndText(usedPromptTokens: Long, model: LLModel) { - val budget = computeBudget(model) + private fun updateColorAndText(usedPromptTokens: Long, model: LLModel?) { + val effectiveCapacity = ( + lastReportedSizeTokens + ?: model?.let { computeBudget(it).inputBudget } + ?: return + ).coerceAtLeast(1L) val percentageUsed = - ((usedPromptTokens.coerceAtLeast(0).toDouble() / budget.inputBudget.toDouble()) * 100.0) + ((usedPromptTokens.coerceAtLeast(0).toDouble() / effectiveCapacity.toDouble()) * 100.0) .coerceIn(0.0, 100.0) val percentageLeft = (100.0 - percentageUsed).coerceIn(0.0, 100.0) @@ -105,14 +122,33 @@ class TokenUsageCounterPanel( text = "${percentageLeft.toInt()}% context left" } - private fun updateTooltipText(totalTokens: Long, model: LLModel) { - val budget = computeBudget(model) + private fun updateTooltipText(totalTokens: Long, model: LLModel?) { toolTipText = buildString { append("") append("Usage Details
") - append("Input size: ${numberFormat.format(totalTokens)} tokens
") - append("Max output size: ${numberFormat.format(budget.reservedOutput)} tokens
") - append("Max context size: ${numberFormat.format(budget.contextLength)} tokens
") + append("Current usage: ${numberFormat.format(totalTokens)} tokens
") + lastReportedSizeTokens?.let { + append("Reported context size: ${numberFormat.format(it)} tokens
") + append( + "Context remaining: ${ + numberFormat.format((it - totalTokens).coerceAtLeast(0L)) + } tokens
" + ) + } + if (lastReportedSizeTokens == null) { + model?.let { + val budget = computeBudget(it) + append("Max output size: ${numberFormat.format(budget.reservedOutput)} tokens
") + append("Max context size: ${numberFormat.format(budget.contextLength)} tokens
") + } + } + if (lastReportedCostAmount != null && lastReportedCostCurrency != null) { + append( + "Reported cost: ${ + numberFormat.format(lastReportedCostAmount) + } ${lastReportedCostCurrency}
" + ) + } append("") } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/hover/PsiLinkHoverPreview.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/hover/PsiLinkHoverPreview.kt index e6ef95e1..c660b76c 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/hover/PsiLinkHoverPreview.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/hover/PsiLinkHoverPreview.kt @@ -1,8 +1,6 @@ package ee.carlrobert.codegpt.ui.hover -import com.intellij.codeInsight.documentation.DocumentationHintEditorPane import com.intellij.codeInsight.documentation.DocumentationManager -import com.intellij.lang.documentation.DocumentationImageResolver import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.application.ModalityState import com.intellij.openapi.application.ReadAction @@ -21,7 +19,6 @@ import com.intellij.util.ui.UIUtil import ee.carlrobert.codegpt.util.NavigationResolverFactory import ee.carlrobert.codegpt.util.EditorUtil import java.awt.Dimension -import java.awt.Image import java.awt.MouseInfo import java.awt.Point import java.awt.event.MouseAdapter @@ -30,6 +27,7 @@ import java.net.URLDecoder import java.nio.charset.StandardCharsets import java.util.concurrent.ScheduledFuture import java.util.concurrent.TimeUnit +import javax.swing.JEditorPane import javax.swing.JComponent import javax.swing.JTextPane import javax.swing.SwingUtilities @@ -73,6 +71,10 @@ object PsiLinkHoverPreview { private val executor = AppExecutorUtil.getAppExecutorService() override fun hyperlinkUpdate(e: HyperlinkEvent) { + if (isAnchorLink(e.description)) { + handleAnchorEvent(e) + return + } when (e.eventType) { HyperlinkEvent.EventType.ENTERED -> onEntered(e) HyperlinkEvent.EventType.EXITED -> onExited() @@ -80,6 +82,29 @@ object PsiLinkHoverPreview { } } + private fun handleAnchorEvent(e: HyperlinkEvent) { + when (e.eventType) { + HyperlinkEvent.EventType.ENTERED -> { + scheduledClose?.cancel(true) + scheduledClose = null + pending?.cancel(true) + pending = null + lastDesc = null + } + + HyperlinkEvent.EventType.EXITED -> onExited() + HyperlinkEvent.EventType.ACTIVATED -> { + scheduledClose?.cancel(true) + scheduledClose = null + pending?.cancel(true) + pending = null + scrollToReference(e.source, e.description) + } + + else -> Unit + } + } + private fun onExited() { scheduledClose?.cancel(true) scheduledClose = scheduledExecutor.schedule({ @@ -89,28 +114,24 @@ object PsiLinkHoverPreview { }, EXIT_CLOSE_DELAY_MS, TimeUnit.MILLISECONDS) } - fun cancelAndHide() { + fun cancelAndHide(force: Boolean = false) { scheduledClose?.cancel(true) scheduledClose = null val comp = popupComponent - if (comp != null && isMouseOverComponent(comp)) return + if (!force && comp != null && isMouseOverComponent(comp)) return pending?.cancel(true) pending = null + val popupToCancel = popup ?: return + popup = null + popupComponent = null + try { - if (SwingUtilities.isEventDispatchThread()) { - popup?.cancel() - popup = null - popupComponent = null - } else { - ApplicationManager.getApplication().invokeLater({ - popup?.cancel() - popup = null - popupComponent = null - }, ModalityState.any()) - } + ApplicationManager.getApplication().invokeLater({ + popupToCancel.cancel() + }, ModalityState.any()) } catch (t: Throwable) { logger.warn("Failed while cancelling popup", t) } @@ -170,11 +191,7 @@ object PsiLinkHoverPreview { private fun showDocHint(where: RelativePoint, html: String) { scheduledClose?.cancel(true) scheduledClose = null - cancelAndHide() - - val imageResolver = object : DocumentationImageResolver { - override fun resolveImage(p0: String): Image? = null - } + cancelAndHide(force = true) val safeHtml = if (html.contains("$html" } - val editor = DocumentationHintEditorPane(project, emptyMap(), imageResolver).apply { + val editor = JEditorPane().apply { isEditable = false contentType = "text/html" text = safeHtml @@ -196,6 +213,11 @@ object PsiLinkHoverPreview { isOpaque = true margin = JBUI.insets(8) } + editor.addHyperlinkListener { event -> + if (event.eventType == HyperlinkEvent.EventType.ACTIVATED && isAnchorLink(event.description)) { + scrollToReference(editor, event.description) + } + } val scroll = ScrollPaneFactory.createScrollPane(editor).apply { border = JBUI.Borders.empty() @@ -218,8 +240,6 @@ object PsiLinkHoverPreview { .setCancelKeyEnabled(true) .createPopup() - editor.setHint(newPopup) - popup = newPopup popupComponent = scroll @@ -252,6 +272,25 @@ object PsiLinkHoverPreview { } } + private fun isAnchorLink(description: String?): Boolean { + return !description.isNullOrBlank() && description.startsWith("#") + } + + private fun scrollToReference(source: Any?, description: String?) { + val anchor = description?.removePrefix("#")?.takeIf { it.isNotBlank() } ?: return + val editor = source as? JEditorPane ?: return + ApplicationManager.getApplication().invokeLater({ + if (!editor.isDisplayable) { + return@invokeLater + } + runCatching { + editor.scrollToReference(anchor) + }.onFailure { error -> + logger.debug("Failed to scroll to anchor #$anchor", error) + } + }, ModalityState.any()) + } + private fun decode(value: String): String = try { URLDecoder.decode(value, StandardCharsets.UTF_8) } catch (_: Throwable) { @@ -303,7 +342,7 @@ object PsiLinkHoverPreview { } private fun computeHoverSize( - editor: DocumentationHintEditorPane, + editor: JEditorPane, maxW: Int, maxH: Int ): Dimension { diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt index 17f3ee60..9b036aab 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentProviderIntegrationTest.kt @@ -981,6 +981,6 @@ class AgentProviderIntegrationTest : IntegrationTest() { this.text.append(text) } - override fun onQueuedMessagesResolved() = Unit + override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt index 13becf68..66a38481 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt @@ -43,7 +43,7 @@ class AgentServiceIntegrationTest : IntegrationTest() { private lateinit var originalRuntimeFactory: AgentRuntimeFactory private val createdSessions = mutableListOf() private val noopEvents = object : AgentEvents { - override fun onQueuedMessagesResolved() = Unit + override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit } private val runTimeoutMillis = 5_000L @@ -119,7 +119,7 @@ class AgentServiceIntegrationTest : IntegrationTest() { var callbackMessageId: UUID? = null var callbackRef: CheckpointRef? = null val events = object : AgentEvents { - override fun onQueuedMessagesResolved() = Unit + override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) { callbackMessageId = runMessageId diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesSerializationIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesSerializationIntegrationTest.kt index 6045395c..6249694f 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesSerializationIntegrationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/clients/CustomOpenAIResponsesSerializationIntegrationTest.kt @@ -110,4 +110,5 @@ class CustomOpenAIResponsesSerializationIntegrationTest : IntegrationTest() { ) ) } + } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocolJsonTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocolJsonTest.kt new file mode 100644 index 00000000..e6bcd92f --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocolJsonTest.kt @@ -0,0 +1,102 @@ +package ee.carlrobert.codegpt.agent.external.acpcompat + +import com.agentclientprotocol.model.AcpMethod +import com.agentclientprotocol.model.McpServer +import com.agentclientprotocol.model.NewSessionRequest +import com.agentclientprotocol.rpc.JsonRpcMessage +import com.agentclientprotocol.transport.Transport +import ee.carlrobert.codegpt.agent.external.acpcompat.vendor.AcpCompatibilityRegistry +import ee.carlrobert.codegpt.agent.external.acpcompat.vendor.AcpPeerProfile +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import org.junit.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals + +class AcpProtocolJsonTest { + private val parser = Json { ignoreUnknownKeys = true } + private val compatibilityRegistry = AcpCompatibilityRegistry() + + @Test + fun serializesMcpServerTypeInformation() { + val protocol = AcpProtocol( + parentScope = CoroutineScope(SupervisorJob() + Dispatchers.Unconfined), + transport = NoOpTransport() + ) + + val encoded = protocol.json.encodeToString( + NewSessionRequest( + cwd = "/tmp", + mcpServers = listOf( + McpServer.Stdio( + name = "local", + command = "npx", + args = listOf("server"), + env = emptyList() + ) + ) + ) + ) + + assertContains(encoded, "\"type\":\"stdio\"") + assertContains(encoded, "\"mcpServers\"") + } + + @Test + fun normalizesEnvVarAuthMethodFromVarsArray() { + val payload = parser.parseToJsonElement( + """ + { + "protocolVersion": 1, + "agentCapabilities": {}, + "authMethods": [ + { + "type": "env_var", + "id": "openai-api-key", + "name": "Use OPENAI_API_KEY", + "vars": [ + { "name": "OPENAI_API_KEY" } + ] + } + ] + } + """.trimIndent() + ) + + val normalized = compatibilityRegistry.normalizeInboundPayload( + profile = AcpPeerProfile.CODEX, + methodName = AcpMethod.AgentMethods.Initialize.methodName, + payload = payload + ).toString() + + assertContains(normalized, "\"varName\":\"OPENAI_API_KEY\"") + assertEquals(1, Regex("\"varName\"").findAll(normalized).count()) + } +} + +private class NoOpTransport : Transport { + private val currentState = MutableStateFlow(Transport.State.CREATED) + + override val state: StateFlow = currentState + + override fun start() { + currentState.value = Transport.State.STARTED + } + + override fun send(message: JsonRpcMessage) = Unit + + override fun onMessage(handler: (JsonRpcMessage) -> Unit) = Unit + + override fun onError(handler: (Throwable) -> Unit) = Unit + + override fun onClose(handler: () -> Unit) = Unit + + override fun close() { + currentState.value = Transport.State.CLOSED + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpToolCallDecoderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpToolCallDecoderTest.kt new file mode 100644 index 00000000..a1a117b3 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpToolCallDecoderTest.kt @@ -0,0 +1,331 @@ +package ee.carlrobert.codegpt.agent.external.events + +import com.agentclientprotocol.model.* +import ee.carlrobert.codegpt.agent.external.AcpToolCallDecoder +import ee.carlrobert.codegpt.agent.external.AcpToolCallStatus +import ee.carlrobert.codegpt.agent.external.AcpToolEventFlavor +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonObject +import kotlin.test.* + +class AcpToolCallDecoderTest { + + private val json = Json { ignoreUnknownKeys = true } + private val decoder = AcpToolCallDecoder(json) + + @Test + fun decodeToolCallStartedProducesTypedTerminalContent() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("tool-1"), + title = "Run shell command", + kind = ToolKind.EXECUTE, + status = ToolCallStatus.IN_PROGRESS, + content = listOf(ToolCallContent.Terminal("terminal-1")), + locations = listOf(ToolCallLocation("src/Main.kt")), + rawInput = json.parseToJsonElement("""{"command":"npm test"}""") + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update) + + val started = assertIs(event) + assertEquals("tool-1", started.toolCall.id) + assertEquals("Bash", started.toolCall.toolName) + val bashArgs = assertIs(started.toolCall.args) + assertEquals("npm test", bashArgs.value.command) + assertEquals(ToolKind.EXECUTE, started.toolCall.kind) + assertEquals(AcpToolCallStatus.IN_PROGRESS, started.toolCall.status) + val terminal = started.toolCall.content.single() + assertIs(terminal) + assertEquals("terminal-1", terminal.terminalId) + } + + @Test + fun decodeToolCallUpdateProducesTypedEditArgs() { + val update = SessionUpdate.ToolCallUpdate( + toolCallId = ToolCallId("tool-2"), + title = "Edit file", + kind = ToolKind.EDIT, + status = ToolCallStatus.COMPLETED, + content = listOf(ToolCallContent.Diff("src/Main.kt", "old", "new")), + rawInput = json.parseToJsonElement( + """{"file_path":"src/Main.kt","old_string":"old","new_string":"new"}""" + ) + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update) + + val updated = assertIs(event) + assertEquals("Edit", updated.toolCall.toolName) + val editArgs = assertIs(updated.toolCall.args) + assertEquals("src/Main.kt", editArgs.value.filePath) + assertEquals(AcpToolCallStatus.COMPLETED, updated.toolCall.status) + val diff = updated.toolCall.content.single() + assertIs(diff) + assertEquals("src/Main.kt", diff.path) + } + + @Test + fun decodePermissionRequestPreservesTypedToolCallSnapshot() { + val request = RequestPermissionRequest( + sessionId = SessionId("session-1"), + toolCall = SessionUpdate.ToolCallUpdate( + toolCallId = ToolCallId("tool-3"), + title = "Allow write", + kind = ToolKind.EDIT, + status = ToolCallStatus.IN_PROGRESS, + content = listOf(ToolCallContent.Diff("src/Main.kt", "old", "new")), + locations = listOf(ToolCallLocation("src/Main.kt")), + rawInput = json.parseToJsonElement( + """{"file_path":"src/Main.kt","old_string":"old","new_string":"new"}""" + ) + ), + options = listOf( + PermissionOption( + optionId = PermissionOptionId("allow_once"), + name = "Allow once", + kind = PermissionOptionKind.ALLOW_ONCE + ) + ) + ) + + val permission = decoder.decodePermissionRequest(AcpToolEventFlavor.STANDARD, request) + + assertEquals("Allow write", permission.toolCall.title) + assertEquals("Edit", permission.toolCall.toolName) + val editArgs = assertIs(permission.toolCall.args) + assertEquals("src/Main.kt", editArgs.value.filePath) + assertTrue(permission.details.contains("src/Main.kt")) + assertEquals(1, permission.options.size) + } + + @Test + fun decodeUsageUpdatePreservesUsageAndCostFields() { + val update = SessionUpdate.UsageUpdate( + used = 1_024, + size = 128_000, + cost = Cost(amount = 0.02, currency = "USD") + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update) + + val usage = assertIs(event) + assertEquals(1_024, usage.used) + assertEquals(128_000, usage.size) + assertEquals(0.02, usage.cost?.amount) + assertEquals("USD", usage.cost?.currency) + } + + @Test + fun decodeConfigOptionUpdatePreservesSelectAndBooleanOptions() { + val update = SessionUpdate.ConfigOptionUpdate( + configOptions = listOf( + SessionConfigOption.select( + id = "reasoning_effort", + name = "Reasoning Effort", + currentValue = "high", + options = SessionConfigSelectOptions.Flat( + listOf( + SessionConfigSelectOption(SessionConfigValueId("medium"), "Medium"), + SessionConfigSelectOption(SessionConfigValueId("high"), "High") + ) + ) + ), + SessionConfigOption.boolean( + id = "sandbox", + name = "Sandbox", + currentValue = true + ) + ) + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update) + + val configUpdate = assertIs(event) + assertEquals(2, configUpdate.configOptions.size) + assertIs(configUpdate.configOptions[0]) + assertIs(configUpdate.configOptions[1]) + } + + @Test + fun decodeSessionInfoUpdatePreservesTitleAndTimestamp() { + val update = SessionUpdate.SessionInfoUpdate( + title = "Fix ACP runtime", + updatedAt = "2026-03-22T18:30:00Z" + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update) + + val sessionInfo = assertIs(event) + assertEquals("Fix ACP runtime", sessionInfo.title) + assertEquals("2026-03-22T18:30:00Z", sessionInfo.updatedAt) + } + + @Test + fun decodeUnknownSessionUpdatePreservesRawPayload() { + val rawJson = json.parseToJsonElement( + """{"sessionUpdate":"future_update","flag":true,"count":3}""" + ).jsonObject + val update = SessionUpdate.UnknownSessionUpdate( + sessionUpdateType = "future_update", + rawJson = rawJson + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update) + + val unknown = assertIs(event) + assertEquals("future_update", unknown.type) + assertEquals(rawJson, unknown.rawJson) + } + + @Test + fun decodeGeminiSearchWithoutRawInputFallsBackToPreviewArgs() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("grep_search-1"), + title = "Search", + kind = ToolKind.SEARCH, + status = ToolCallStatus.IN_PROGRESS, + locations = listOf(ToolCallLocation("/tmp/README.md")) + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update) + + val started = assertIs(event) + assertEquals("IntelliJSearch", started.toolCall.toolName) + val preview = assertIs(started.toolCall.args).value + assertEquals("Search", preview.title) + assertEquals("/tmp/README.md", preview.path) + assertEquals(null, preview.pattern) + } + + @Test + fun decodeGeminiShellWithoutRawInputKeepsBashToolIdentity() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("run_shell_command-1"), + title = "", + kind = ToolKind.EXECUTE, + status = ToolCallStatus.IN_PROGRESS + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update) + + val started = assertIs(event) + assertEquals("Bash", started.toolCall.toolName) + val preview = assertIs(started.toolCall.args).value + assertEquals("Run shell command", preview.title) + assertEquals(null, preview.command) + } + + @Test + fun decodeGeminiWriteUpdateFromDiffMapsToWriteTool() { + val update = SessionUpdate.ToolCallUpdate( + toolCallId = ToolCallId("write_file-1"), + title = "", + kind = null, + status = ToolCallStatus.COMPLETED, + content = listOf(ToolCallContent.Diff("src/Main.kt", "new content")) + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update) + + val updated = assertIs(event) + assertEquals("Write", updated.toolCall.toolName) + val writeArgs = assertIs(updated.toolCall.args).value + assertEquals("src/Main.kt", writeArgs.filePath) + assertEquals("new content", writeArgs.content) + } + + @Test + fun decodeGeminiGlobMapsToGlobTool() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("glob-1"), + title = "'**/*.java'", + kind = ToolKind.SEARCH, + status = ToolCallStatus.IN_PROGRESS + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update) + + val started = assertIs(event) + assertEquals("Glob", started.toolCall.toolName) + val preview = assertIs(started.toolCall.args).value + assertEquals("Find files", preview.title) + assertEquals("**/*.java", preview.pattern) + } + + @Test + fun decodeGeminiListDirectoryMapsToListDirectoryTool() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("list_directory-1"), + title = ".", + kind = ToolKind.SEARCH, + status = ToolCallStatus.IN_PROGRESS + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update) + + val started = assertIs(event) + assertEquals("ListDirectory", started.toolCall.toolName) + val preview = assertIs(started.toolCall.args).value + assertEquals("List directory", preview.title) + assertEquals(".", preview.path) + } + + @Test + fun decodeGeminiReplaceUpdateFromDiffMapsToEditTool() { + val update = SessionUpdate.ToolCallUpdate( + toolCallId = ToolCallId("replace-1"), + title = "", + kind = null, + status = ToolCallStatus.COMPLETED, + content = listOf(ToolCallContent.Diff("src/Main.kt", "new", "old")) + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update) + + val updated = assertIs(event) + assertEquals("Edit", updated.toolCall.toolName) + val editArgs = assertIs(updated.toolCall.args).value + assertEquals("src/Main.kt", editArgs.filePath) + assertEquals("old", editArgs.oldString) + assertEquals("new", editArgs.newString) + } + + @Test + fun decodeZedSearchWithoutQueryablePayloadFallsBackToPreviewArgs() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("zed-search-1"), + title = "Search", + kind = ToolKind.SEARCH, + status = ToolCallStatus.IN_PROGRESS, + locations = listOf(ToolCallLocation("/tmp/src")), + rawInput = json.parseToJsonElement("""{"parsed_cmd":[{"type":"search","path":"/tmp/src"}]}""") + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.ZED_ADAPTER, update) + + val started = assertIs(event) + assertEquals("IntelliJSearch", started.toolCall.toolName) + val preview = assertIs(started.toolCall.args).value + assertEquals("Search", preview.title) + assertEquals("/tmp/src", preview.path) + assertEquals(null, preview.pattern) + } + + @Test + fun decodeZedExecuteWithoutPayloadKeepsBashToolIdentity() { + val update = SessionUpdate.ToolCall( + toolCallId = ToolCallId("zed-bash-1"), + title = "Run shell command", + kind = ToolKind.EXECUTE, + status = ToolCallStatus.IN_PROGRESS + ) + + val event = decoder.decodeExternalEvent(AcpToolEventFlavor.ZED_ADAPTER, update) + + val started = assertIs(event) + assertEquals("Bash", started.toolCall.toolName) + val preview = assertIs(started.toolCall.args).value + assertEquals("Run shell command", preview.title) + assertEquals(null, preview.command) + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHostTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHostTest.kt new file mode 100644 index 00000000..8c3cf1cb --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHostTest.kt @@ -0,0 +1,58 @@ +package ee.carlrobert.codegpt.agent.external.host + +import com.agentclientprotocol.model.ReadTextFileRequest +import com.agentclientprotocol.model.SessionId +import java.nio.file.Files +import kotlin.test.Test +import kotlin.test.assertEquals + +class AcpFileHostTest { + + @Test + fun `readTextFile prefers editor content over disk content`() { + val cwd = Files.createTempDirectory("acp-host-file") + val file = cwd.resolve("Example.txt") + Files.writeString(file, "disk content") + val host = AcpFileHost( + openDocumentReader = AcpOpenDocumentReader { path -> + if (path == file) "editor content" else null + }, + writer = AcpTextFileWriter { _, _ -> error("unexpected write") } + ) + + val response = host.readTextFile( + AcpHostSessionContext(sessionId = "session-1", cwd = cwd), + ReadTextFileRequest( + sessionId = SessionId("session-1"), + path = file.toString() + ) + ) + + assertEquals("editor content", response.content) + } + + @Test + fun `readTextFile applies line window after resolving editor content`() { + val cwd = Files.createTempDirectory("acp-host-file") + val file = cwd.resolve("Example.txt") + val host = AcpFileHost( + openDocumentReader = AcpOpenDocumentReader { path -> + if (path == file) "line1\nline2\nline3\nline4" else null + }, + writer = AcpTextFileWriter { _, _ -> error("unexpected write") } + ) + + val response = host.readTextFile( + AcpHostSessionContext(sessionId = "session-1", cwd = cwd), + ReadTextFileRequest( + sessionId = SessionId("session-1"), + path = file.toString(), + line = 2u, + limit = 2u + ) + ) + + assertEquals("line2\nline3", response.content) + } + +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicyTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicyTest.kt new file mode 100644 index 00000000..6d1ebdf8 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicyTest.kt @@ -0,0 +1,39 @@ +package ee.carlrobert.codegpt.agent.external.host + +import java.nio.file.Files +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class AcpPathPolicyTest { + + private val policy = AcpPathPolicy() + + @Test + fun `resolveWithinCwd keeps relative paths inside cwd`() { + val cwd = Files.createTempDirectory("acp-host-policy") + + val resolved = policy.resolveWithinCwd("src/Main.kt", cwd) + + assertEquals(cwd.resolve("src/Main.kt").toAbsolutePath().normalize(), resolved) + } + + @Test + fun `resolveWithinCwd allows absolute path inside cwd`() { + val cwd = Files.createTempDirectory("acp-host-policy") + val inside = cwd.resolve("nested/file.txt") + + val resolved = policy.resolveWithinCwd(inside.toString(), cwd) + + assertEquals(inside.toAbsolutePath().normalize(), resolved) + } + + @Test + fun `resolveWithinCwd rejects path traversal`() { + val cwd = Files.createTempDirectory("acp-host-policy") + + assertFailsWith { + policy.resolveWithinCwd("../secret.txt", cwd) + } + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHostTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHostTest.kt new file mode 100644 index 00000000..1a09e003 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHostTest.kt @@ -0,0 +1,101 @@ +package ee.carlrobert.codegpt.agent.external.host + +import com.agentclientprotocol.model.CreateTerminalRequest +import com.agentclientprotocol.model.EnvVariable +import com.agentclientprotocol.model.SessionId +import java.nio.file.Files +import java.nio.file.Path +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class AcpTerminalHostTest { + + @Test + fun `createTerminal resolves cwd and stores terminal process`() { + val cwd = Files.createTempDirectory("acp-host-terminal") + val launcher = FakeTerminalProcessLauncher() + val host = AcpTerminalHost(launcher) + val session = AcpHostSessionContext(sessionId = "session-1", cwd = cwd) + + val response = host.createTerminal( + session, + CreateTerminalRequest( + sessionId = SessionId("session-1"), + command = "git", + args = listOf("status"), + cwd = "worktree", + env = listOf(EnvVariable(name = "FOO", value = "bar")), + outputByteLimit = 128uL + ) + ) + + assertTrue(response.terminalId.isNotBlank()) + assertEquals(cwd.resolve("worktree").toAbsolutePath().normalize(), launcher.lastCwd) + assertEquals(mapOf("FOO" to "bar"), launcher.nextEnv) + assertEquals(128uL, launcher.process.outputByteLimit) + } + + @Test + fun `createTerminal rejects session mismatch`() { + val cwd = Files.createTempDirectory("acp-host-terminal") + val host = AcpTerminalHost(FakeTerminalProcessLauncher()) + + assertFailsWith { + host.createTerminal( + AcpHostSessionContext(sessionId = "session-1", cwd = cwd), + CreateTerminalRequest( + sessionId = SessionId("different-session"), + command = "echo" + ) + ) + } + } +} + +private class FakeTerminalProcessLauncher : AcpTerminalProcessLauncher { + val process = FakeTerminalProcess() + var lastCwd: Path? = null + var nextEnv: Map = emptyMap() + + override fun launch( + command: String, + args: List, + cwd: Path, + env: Map, + outputByteLimit: ULong? + ): AcpTerminalProcess { + lastCwd = cwd + nextEnv = env + process.command = command + process.args = args + process.cwd = cwd + process.outputByteLimit = outputByteLimit + return process + } +} + +private class FakeTerminalProcess : AcpTerminalProcess { + override val terminalId: String = "fake-terminal" + var command: String = "" + var args: List = emptyList() + var cwd: Path? = null + var outputByteLimit: ULong? = null + var snapshot: AcpTerminalOutputSnapshot = AcpTerminalOutputSnapshot("", truncated = false) + var waitResult: AcpTerminalExitStatus = AcpTerminalExitStatus(exitCode = 0u) + var killed: Boolean = false + var released: Boolean = false + + override fun output(): AcpTerminalOutputSnapshot = snapshot + + override suspend fun waitForExit(): AcpTerminalExitStatus = waitResult + + override fun release() { + released = true + } + + override fun kill() { + killed = true + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AcpConfigOptionsTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AcpConfigOptionsTest.kt new file mode 100644 index 00000000..01af0a5b --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AcpConfigOptionsTest.kt @@ -0,0 +1,46 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import kotlin.test.Test +import kotlin.test.assertEquals + +class AcpConfigOptionsTest { + + @Test + fun summaryPartsIncludesBooleanAndCustomOptions() { + val options = listOf( + AcpConfigOption( + id = "model", + name = "Model", + category = "model", + type = "select", + currentValue = "gpt-5", + options = listOf(AcpConfigOptionChoice("gpt-5", "GPT-5")) + ), + AcpConfigOption( + id = "sandbox", + name = "Sandbox", + type = "boolean", + currentValue = "true", + options = listOf( + AcpConfigOptionChoice("true", "Enabled"), + AcpConfigOptionChoice("false", "Disabled") + ) + ), + AcpConfigOption( + id = "approval_mode", + name = "Approval", + type = "select", + currentValue = "manual", + options = listOf( + AcpConfigOptionChoice("manual", "Manual"), + AcpConfigOptionChoice("auto", "Automatic") + ) + ) + ) + + assertEquals( + listOf("GPT-5", "Sandbox: Enabled", "Approval: Manual"), + AcpConfigOptions.summaryParts(options) + ) + } +}