mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 07:54:46 +00:00
refactor: replace custom JSON-RPC layer with official ACP SDK runtime
This commit is contained in:
parent
f6c22b0313
commit
5dca4b5d17
70 changed files with 5009 additions and 1765 deletions
|
|
@ -85,6 +85,7 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
private final ResponseBodyProgressPanel progressPanel = new ResponseBodyProgressPanel();
|
||||
private final JPanel contentPanel =
|
||||
new JPanel(new VerticalFlowLayout(VerticalFlowLayout.TOP, 0, 4, true, false));
|
||||
private final StringBuilder accumulatedThinking = new StringBuilder();
|
||||
|
||||
private ResponseEditorPanel currentlyProcessedEditorPanel;
|
||||
private MermaidResponsePanel currentlyProcessedMermaidPanel;
|
||||
|
|
@ -170,6 +171,20 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
}
|
||||
}
|
||||
|
||||
public void appendThinking(String partialThinking) {
|
||||
if (partialThinking.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (accumulatedThinking.length() > 0) {
|
||||
accumulatedThinking.append('\n');
|
||||
}
|
||||
accumulatedThinking.append(partialThinking);
|
||||
processThinkingOutput(accumulatedThinking.toString());
|
||||
revalidate();
|
||||
repaint();
|
||||
}
|
||||
|
||||
public void displayMissingCredential() {
|
||||
String message = "API key not provided. Open <a href=\"#\">Settings</a> to set one.";
|
||||
displayErrorMessage(message, e -> {
|
||||
|
|
@ -254,6 +269,7 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
public void clear() {
|
||||
contentPanel.removeAll();
|
||||
streamOutputParser.clear();
|
||||
accumulatedThinking.setLength(0);
|
||||
|
||||
// Reset for the next incoming message
|
||||
prepareProcessingText(true);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import com.agentclientprotocol.model.PlanEntry
|
||||
import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
||||
import ee.carlrobert.codegpt.agent.tools.AskUserQuestionTool
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
|
|
@ -9,6 +10,8 @@ import java.util.UUID
|
|||
|
||||
interface AgentEvents {
|
||||
fun onTextReceived(text: String) {}
|
||||
fun onThinkingReceived(text: String) {}
|
||||
fun onPlanUpdated(entries: List<PlanEntry>) {}
|
||||
fun onAgentCompleted(agentId: String) {}
|
||||
fun onToolStarting(id: String, toolName: String, args: Any?) {}
|
||||
fun onToolCompleted(id: String?, toolName: String, result: Any?) {}
|
||||
|
|
@ -23,8 +26,13 @@ interface AgentEvents {
|
|||
|
||||
fun onRetry(attempt: Int, maxAttempts: Int) {}
|
||||
fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {}
|
||||
fun onQueuedMessagesResolved()
|
||||
fun onQueuedMessagesResolved(message: MessageWithContext? = null)
|
||||
fun onTokenUsageAvailable(tokenUsage: Long) {}
|
||||
fun onUsageAvailable(event: AgentUsageEvent) {
|
||||
onTokenUsageAvailable(event.usedTokens)
|
||||
}
|
||||
fun onRuntimeOptionsUpdated() {}
|
||||
fun onSessionInfoUpdated(title: String?, updatedAt: String? = null) {}
|
||||
fun onCreditsAvailable(event: AgentCreditsEvent) {}
|
||||
fun onAgentException(provider: ServiceType, throwable: Throwable) {}
|
||||
fun onHistoryCompressionStateChanged(isCompressing: Boolean) {}
|
||||
|
|
|
|||
|
|
@ -85,7 +85,11 @@ class AgentService(private val project: Project) {
|
|||
val queuedMessageProcessed = _queuedMessageProcessed.asSharedFlow()
|
||||
|
||||
fun addToQueue(message: MessageWithContext, sessionId: String) {
|
||||
pendingMessages.getOrPut(sessionId) { ArrayDeque() }.add(message)
|
||||
val queue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
|
||||
queue.add(message)
|
||||
logger.debug(
|
||||
"Queued agent message for session=$sessionId queueSize=${queue.size} uiVisible=${message.uiVisible} messageId=${message.id}"
|
||||
)
|
||||
}
|
||||
|
||||
suspend fun getCheckpoint(sessionId: String): AgentCheckpointData? {
|
||||
|
|
@ -114,7 +118,7 @@ class AgentService(private val project: Project) {
|
|||
}
|
||||
|
||||
fun submitMessage(message: MessageWithContext, events: AgentEvents, sessionId: String) {
|
||||
updateMcpContext(sessionId, message)
|
||||
val selectedServerIds = updateMcpContext(sessionId, message)
|
||||
|
||||
if (isSessionRunning(sessionId)) {
|
||||
addToQueue(message, sessionId)
|
||||
|
|
@ -125,6 +129,11 @@ class AgentService(private val project: Project) {
|
|||
val contentManager = project.service<AgentToolWindowContentManager>()
|
||||
val session = contentManager.getSession(sessionId) ?: return
|
||||
if (!session.externalAgentId.isNullOrBlank()) {
|
||||
if (session.shouldRecreateExternalAgentSession(selectedServerIds)) {
|
||||
project.service<ExternalAcpAgentService>().closeSession(sessionId)
|
||||
session.externalAgentSessionId = null
|
||||
session.externalAgentMcpServerIds = emptySet()
|
||||
}
|
||||
submitExternalMessage(session, message, events, provider)
|
||||
return
|
||||
}
|
||||
|
|
@ -201,7 +210,11 @@ class AgentService(private val project: Project) {
|
|||
}
|
||||
project.service<ExternalAcpAgentService>().closeSession(sessionId)
|
||||
project.service<AgentToolWindowContentManager>()
|
||||
.getSession(sessionId)?.externalAgentSessionId = null
|
||||
.getSession(sessionId)
|
||||
?.also {
|
||||
it.externalAgentSessionId = null
|
||||
it.externalAgentMcpServerIds = emptySet()
|
||||
}
|
||||
project.service<AgentMcpContextService>().clear(sessionId)
|
||||
}
|
||||
|
||||
|
|
@ -227,22 +240,31 @@ class AgentService(private val project: Project) {
|
|||
events: AgentEvents,
|
||||
provider: ServiceType
|
||||
) {
|
||||
val externalAgentService = project.service<ExternalAcpAgentService>()
|
||||
logger.debug(
|
||||
"Starting external ACP run for session=${session.sessionId} externalAgent=${session.externalAgentId} messageId=${message.id}"
|
||||
)
|
||||
sessionJobs[session.sessionId] = CoroutineScope(Dispatchers.IO).launch {
|
||||
try {
|
||||
externalAgentService.runPromptLoop(
|
||||
session = session,
|
||||
firstMessage = message,
|
||||
events = events,
|
||||
pollNextQueued = {
|
||||
val queue = pendingMessages[session.sessionId] ?: return@runPromptLoop null
|
||||
if (queue.isEmpty()) {
|
||||
null
|
||||
} else {
|
||||
queue.removeFirst()
|
||||
project.service<ExternalAcpAgentService>()
|
||||
.runPromptLoop(
|
||||
session = session,
|
||||
firstMessage = message,
|
||||
events = events,
|
||||
pollNextQueued = {
|
||||
val queue =
|
||||
pendingMessages[session.sessionId] ?: return@runPromptLoop null
|
||||
if (queue.isEmpty()) {
|
||||
logger.debug("No queued ACP follow-up message for session=${session.sessionId}")
|
||||
null
|
||||
} else {
|
||||
queue.removeFirst().also { next ->
|
||||
logger.debug(
|
||||
"Dequeued ACP follow-up message for session=${session.sessionId} queueRemaining=${queue.size} messageId=${next.id} uiVisible=${next.uiVisible}"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
} catch (_: CancellationException) {
|
||||
return@launch
|
||||
} catch (ex: Throwable) {
|
||||
|
|
@ -255,7 +277,7 @@ class AgentService(private val project: Project) {
|
|||
}
|
||||
}
|
||||
|
||||
private fun updateMcpContext(sessionId: String, message: MessageWithContext) {
|
||||
private fun updateMcpContext(sessionId: String, message: MessageWithContext): Set<String> {
|
||||
val selectedServerIds = message.tags
|
||||
.filterIsInstance<McpTagDetails>()
|
||||
.filter { it.selected }
|
||||
|
|
@ -273,6 +295,7 @@ class AgentService(private val project: Project) {
|
|||
|
||||
project.service<AgentMcpContextService>()
|
||||
.update(sessionId, conversationId, selectedServerIds)
|
||||
return selectedServerIds
|
||||
}
|
||||
|
||||
suspend fun createSeedCheckpointFromHistory(history: List<PromptMessage>): CheckpointRef? =
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
data class AgentUsageEvent(
|
||||
val usedTokens: Long,
|
||||
val sizeTokens: Long? = null,
|
||||
val costAmount: Double? = null,
|
||||
val costCurrency: String? = null
|
||||
)
|
||||
|
|
@ -90,7 +90,7 @@ object ProxyAIAgent {
|
|||
val modelSelection =
|
||||
service<ModelSettings>().getModelSelectionForFeature(FeatureType.AGENT)
|
||||
val skills = project.service<SkillDiscoveryService>().listSkills()
|
||||
val stream = shouldStreamAgentToolLoop(project, provider)
|
||||
val stream = shouldStreamAgentToolLoop(provider)
|
||||
val projectInstructions = loadProjectInstructions(project.basePath)
|
||||
val executor = AgentFactory.createExecutor(provider, events)
|
||||
val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
|
||||
|
|
@ -99,7 +99,6 @@ object ProxyAIAgent {
|
|||
project = project,
|
||||
events = events,
|
||||
sessionId = sessionId,
|
||||
provider = provider,
|
||||
parentModelSelection = modelSelection,
|
||||
hookManager = hookManager
|
||||
)
|
||||
|
|
@ -198,7 +197,7 @@ object ProxyAIAgent {
|
|||
output.forEach { msg ->
|
||||
(msg as? Message.Reasoning)?.let {
|
||||
if (it.content.isNotBlank()) {
|
||||
events.onTextReceived("<think>${it.content}</think>")
|
||||
events.onThinkingReceived(it.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -213,7 +212,7 @@ object ProxyAIAgent {
|
|||
}
|
||||
(msg as? Message.Reasoning)?.let {
|
||||
if (it.content.isNotBlank()) {
|
||||
events.onTextReceived("<think>${it.content}</think>")
|
||||
events.onThinkingReceived(it.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -297,10 +296,7 @@ object ProxyAIAgent {
|
|||
}
|
||||
}
|
||||
|
||||
private fun shouldStreamAgentToolLoop(
|
||||
project: Project,
|
||||
provider: ServiceType,
|
||||
): Boolean {
|
||||
private fun shouldStreamAgentToolLoop(provider: ServiceType): Boolean {
|
||||
return when (provider) {
|
||||
ServiceType.CUSTOM_OPENAI -> {
|
||||
val selectedModel =
|
||||
|
|
@ -320,7 +316,6 @@ object ProxyAIAgent {
|
|||
project: Project,
|
||||
events: AgentEvents,
|
||||
sessionId: String,
|
||||
provider: ServiceType,
|
||||
parentModelSelection: ModelSelection,
|
||||
hookManager: HookManager
|
||||
): ToolRegistry {
|
||||
|
|
@ -334,84 +329,23 @@ object ProxyAIAgent {
|
|||
tool(ConfirmingEditTool(EditTool(project, sessionId, hookManager), standardApproval))
|
||||
tool(ConfirmingWriteTool(WriteTool(project, sessionId, hookManager), writeApproval))
|
||||
tool(TodoWriteTool(project, sessionId, hookManager))
|
||||
tool(
|
||||
AskUserQuestionTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
events = events
|
||||
)
|
||||
)
|
||||
createMcpTools(
|
||||
sessionId = sessionId,
|
||||
contextService = contextService,
|
||||
approve = genericApproval
|
||||
).forEach { mcpTool -> tool(mcpTool) }
|
||||
tool(AskUserQuestionTool(workingDirectory, sessionId, hookManager, events))
|
||||
createMcpTools(sessionId, contextService, genericApproval).forEach { mcpTool ->
|
||||
tool(mcpTool)
|
||||
}
|
||||
tool(ExitTool)
|
||||
tool(
|
||||
IntelliJSearchTool(
|
||||
project = project,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
DiagnosticsTool(
|
||||
project = project,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
WebSearchTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
WebFetchTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
BashOutputTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
KillShellTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
ResolveLibraryIdTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
GetLibraryDocsTool(
|
||||
workingDirectory = workingDirectory,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(IntelliJSearchTool(project, sessionId, hookManager))
|
||||
tool(DiagnosticsTool(project, sessionId, hookManager))
|
||||
tool(WebSearchTool(workingDirectory, sessionId, hookManager))
|
||||
tool(WebFetchTool(workingDirectory, sessionId, hookManager))
|
||||
tool(BashOutputTool(workingDirectory, sessionId, hookManager))
|
||||
tool(KillShellTool(workingDirectory, sessionId, hookManager))
|
||||
tool(ResolveLibraryIdTool(workingDirectory, sessionId, hookManager))
|
||||
tool(GetLibraryDocsTool(workingDirectory, sessionId, hookManager))
|
||||
tool(
|
||||
ConfirmingLoadSkillTool(
|
||||
LoadSkillTool(
|
||||
project = project,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager
|
||||
),
|
||||
project = project
|
||||
LoadSkillTool(project, sessionId, hookManager),
|
||||
project
|
||||
) { name, details -> genericApproval(name, details) }
|
||||
)
|
||||
tool(
|
||||
|
|
@ -438,15 +372,7 @@ object ProxyAIAgent {
|
|||
hookManager = hookManager
|
||||
)
|
||||
)
|
||||
tool(
|
||||
TaskTool(
|
||||
project,
|
||||
sessionId,
|
||||
parentModelSelection,
|
||||
events,
|
||||
hookManager
|
||||
)
|
||||
)
|
||||
tool(TaskTool(project, sessionId, parentModelSelection, events, hookManager))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings
|
|||
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
|
||||
import ai.koog.prompt.executor.clients.openai.OpenAIResponsesParams
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.*
|
||||
import ai.koog.prompt.executor.clients.openai.models.OpenAIChatCompletionResponse
|
||||
import ai.koog.prompt.llm.LLMCapability
|
||||
import ai.koog.prompt.llm.LLMProvider
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
|
|
@ -26,10 +27,10 @@ import ee.carlrobert.codegpt.settings.service.custom.CustomServicePlaceholders
|
|||
import ee.carlrobert.codegpt.util.JsonMapper
|
||||
import io.ktor.client.*
|
||||
import io.ktor.client.plugins.*
|
||||
import io.ktor.client.plugins.api.createClientPlugin
|
||||
import io.ktor.client.plugins.api.*
|
||||
import io.ktor.client.request.*
|
||||
import io.ktor.http.ContentType
|
||||
import io.ktor.http.content.TextContent
|
||||
import io.ktor.http.*
|
||||
import io.ktor.http.content.*
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.KSerializer
|
||||
|
|
@ -200,6 +201,10 @@ class CustomOpenAILLMClient(
|
|||
}
|
||||
}
|
||||
|
||||
override fun decodeResponse(data: String): OpenAIChatCompletionResponse {
|
||||
return json.decodeFromString(data)
|
||||
}
|
||||
|
||||
override suspend fun getCodeCompletion(infillRequest: InfillRequest): String {
|
||||
val state = requireCodeCompletionState()
|
||||
val url = requireNotNull(state.url)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ data class ExternalAcpAgentPreset(
|
|||
val vendor: String,
|
||||
val command: String,
|
||||
val args: List<String>,
|
||||
val toolEventFlavor: AcpToolEventFlavor = AcpToolEventFlavor.STANDARD,
|
||||
val env: Map<String, String> = emptyMap(),
|
||||
val enabledByDefault: Boolean = false,
|
||||
val description: String? = null,
|
||||
|
|
@ -28,6 +29,7 @@ object ExternalAcpAgents {
|
|||
vendor = "OpenAI",
|
||||
command = "npx",
|
||||
args = listOf("-y", "@zed-industries/codex-acp"),
|
||||
toolEventFlavor = AcpToolEventFlavor.ZED_ADAPTER,
|
||||
enabledByDefault = true,
|
||||
description = "OpenAI Codex via the Zed ACP adapter."
|
||||
),
|
||||
|
|
@ -54,6 +56,7 @@ object ExternalAcpAgents {
|
|||
vendor = "Anthropic",
|
||||
command = "npx",
|
||||
args = listOf("-y", "@zed-industries/claude-code-acp"),
|
||||
toolEventFlavor = AcpToolEventFlavor.ZED_ADAPTER,
|
||||
enabledByDefault = true,
|
||||
description = "Anthropic Claude Code via the Zed ACP adapter."
|
||||
),
|
||||
|
|
@ -63,6 +66,7 @@ object ExternalAcpAgents {
|
|||
vendor = "Google",
|
||||
command = "gemini",
|
||||
args = listOf("--experimental-acp"),
|
||||
toolEventFlavor = AcpToolEventFlavor.GEMINI_CLI,
|
||||
description = "Google Gemini CLI in experimental ACP mode."
|
||||
),
|
||||
ExternalAcpAgentPreset(
|
||||
|
|
@ -119,6 +123,7 @@ object ExternalAcpAgents {
|
|||
vendor = "Anthropic",
|
||||
command = "npx",
|
||||
args = listOf("-y", "@zed-industries/claude-agent-acp"),
|
||||
toolEventFlavor = AcpToolEventFlavor.ZED_ADAPTER,
|
||||
description = "Anthropic Claude Agent via the Zed ACP adapter."
|
||||
),
|
||||
ExternalAcpAgentPreset(
|
||||
|
|
@ -276,7 +281,7 @@ object ExternalAcpAgents {
|
|||
val message = throwable.message.orEmpty()
|
||||
return when {
|
||||
message.contains("Cannot run program", ignoreCase = true) &&
|
||||
message.contains("No such file or directory", ignoreCase = true) ->
|
||||
message.contains("No such file or directory", ignoreCase = true) ->
|
||||
"Command not found: $command"
|
||||
|
||||
message.isNotBlank() -> message
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
340
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpHostBridge.kt
vendored
Normal file
340
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpHostBridge.kt
vendored
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.agentclientprotocol.model.*
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import ee.carlrobert.codegpt.agent.ToolSpecs
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.AcpProtocol
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.acpFail
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.setRequestHandler
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs
|
||||
import ee.carlrobert.codegpt.agent.external.host.AcpHostCapabilities
|
||||
import ee.carlrobert.codegpt.agent.external.host.AcpHostSessionContext
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentSession
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.*
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import java.nio.file.Path
|
||||
import kotlin.io.path.name
|
||||
|
||||
internal class AcpHostBridge(
|
||||
private val proxySessionId: String,
|
||||
private val displayName: String,
|
||||
private val toolEventFlavor: AcpToolEventFlavor,
|
||||
private val fullAccessModeId: String,
|
||||
private val sessionRoot: Path,
|
||||
private val toolCallDecoder: AcpToolCallDecoder,
|
||||
private val hostCapabilities: AcpHostCapabilities,
|
||||
private val currentSession: () -> AgentSession?,
|
||||
private val eventsProvider: () -> AgentEvents,
|
||||
private val trace: (String) -> Unit
|
||||
) {
|
||||
private val logger = KotlinLogging.logger {}
|
||||
|
||||
fun register(protocol: AcpProtocol) {
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.SessionRequestPermission) { request ->
|
||||
handleRequestPermission(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.FsReadTextFile) { request ->
|
||||
handleReadTextFile(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.FsWriteTextFile) { request ->
|
||||
handleWriteTextFile(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalCreate) { request ->
|
||||
handleCreateTerminal(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalOutput) { request ->
|
||||
handleTerminalOutput(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalRelease) { request ->
|
||||
handleTerminalRelease(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalWaitForExit) { request ->
|
||||
handleTerminalWaitForExit(request)
|
||||
}
|
||||
protocol.setRequestHandler(AcpMethod.ClientMethods.TerminalKill) { request ->
|
||||
handleTerminalKill(request)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleRequestPermission(
|
||||
request: RequestPermissionRequest
|
||||
): RequestPermissionResponse {
|
||||
val permissionRequest = toolCallDecoder.decodePermissionRequest(toolEventFlavor, request)
|
||||
val toolCall = permissionRequest.toolCall
|
||||
val mode = currentSession().currentAcpMode()
|
||||
trace(
|
||||
"permission/request session=$proxySessionId mode=$mode tool=${toolCall.toolName} title=${toolCall.title.logPreview()} options=${permissionRequest.options.joinToString { it.kind.name }} rawInput=${request.toolCall.rawInput.logSummary()} details=${
|
||||
permissionRequest.details.logPreview(
|
||||
200
|
||||
)
|
||||
}"
|
||||
)
|
||||
logger.debug {
|
||||
"Received $displayName ACP permission request for session=$proxySessionId mode=$mode tool=${toolCall.toolName} title=${toolCall.title.logPreview()} details=${permissionRequest.details.logPreview()} options=${permissionRequest.options.joinToString { it.kind.name }}"
|
||||
}
|
||||
if (mode == fullAccessModeId) {
|
||||
trace("permission/auto-approve session=$proxySessionId mode=$mode tool=${toolCall.toolName}")
|
||||
logger.debug {
|
||||
"Auto-approving $displayName ACP permission in mode=$mode for tool=${toolCall.toolName} title=${toolCall.title}"
|
||||
}
|
||||
return permissionResponse(selectApprovedPermissionOptionId(permissionRequest.options))
|
||||
}
|
||||
|
||||
val approvalRequest = buildApprovalRequest(
|
||||
rawTitle = toolCall.title,
|
||||
details = permissionRequest.details,
|
||||
toolName = toolCall.toolName,
|
||||
parsedArgs = toolCall.args
|
||||
)
|
||||
logger.debug {
|
||||
"Queueing UI approval for session=$proxySessionId type=${approvalRequest.type} title=${approvalRequest.title.logPreview()} payload=${approvalRequest.payload.logSummary()}"
|
||||
}
|
||||
trace(
|
||||
"permission/await-ui session=$proxySessionId type=${approvalRequest.type} title=${approvalRequest.title.logPreview()} payload=${approvalRequest.payload.logSummary()}"
|
||||
)
|
||||
return try {
|
||||
val approved = eventsProvider().approveToolCall(approvalRequest)
|
||||
val selectedOptionId = if (approved) {
|
||||
selectApprovedPermissionOptionId(permissionRequest.options)
|
||||
} else {
|
||||
selectRejectedPermissionOptionId(permissionRequest.options)
|
||||
}
|
||||
trace(
|
||||
"permission/resolved session=$proxySessionId tool=${toolCall.toolName} approved=$approved selectedOption=${selectedOptionId.value}"
|
||||
)
|
||||
logger.debug {
|
||||
"Resolved $displayName ACP permission in mode=$mode for tool=${toolCall.toolName} approved=$approved title=${toolCall.title}"
|
||||
}
|
||||
permissionResponse(selectedOptionId)
|
||||
} catch (_: CancellationException) {
|
||||
trace("permission/cancelled session=$proxySessionId tool=${toolCall.toolName}")
|
||||
permissionCancelledResponse()
|
||||
} catch (error: Exception) {
|
||||
logger.warn(error) {
|
||||
"Permission approval failed for $displayName tool=${toolCall.toolName}; defaulting to reject"
|
||||
}
|
||||
permissionResponse(selectRejectedPermissionOptionId(permissionRequest.options))
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleReadTextFile(params: ReadTextFileRequest): ReadTextFileResponse {
|
||||
trace("fs/read_text_file session=$proxySessionId path=${params.path} line=${params.line} limit=${params.limit}")
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid read_text_file request",
|
||||
failureTrace = "fs/read_text_file/failed session=$proxySessionId path=${params.path}"
|
||||
) {
|
||||
hostCapabilities.readTextFile(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleWriteTextFile(params: WriteTextFileRequest): WriteTextFileResponse {
|
||||
trace("fs/write_text_file session=$proxySessionId path=${params.path}")
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid write_text_file request",
|
||||
failureTrace = "fs/write_text_file/failed session=$proxySessionId path=${params.path}"
|
||||
) {
|
||||
hostCapabilities.writeTextFile(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleCreateTerminal(params: CreateTerminalRequest): CreateTerminalResponse {
|
||||
trace(
|
||||
"terminal/create session=$proxySessionId requestSession=${params.sessionId.value} command=${params.command.logPreview()} cwd=${params.cwd?.logPreview()}"
|
||||
)
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid terminal/create request",
|
||||
failureTrace = "terminal/create/failed session=$proxySessionId command=${params.command.logPreview()}"
|
||||
) {
|
||||
hostCapabilities.createTerminal(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleTerminalOutput(params: TerminalOutputRequest): TerminalOutputResponse {
|
||||
trace(
|
||||
"terminal/output session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}"
|
||||
)
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid terminal/output request",
|
||||
failureTrace = "terminal/output/failed session=$proxySessionId terminalId=${params.terminalId}"
|
||||
) {
|
||||
hostCapabilities.output(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleTerminalRelease(params: ReleaseTerminalRequest): ReleaseTerminalResponse {
|
||||
trace(
|
||||
"terminal/release session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}"
|
||||
)
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid terminal/release request",
|
||||
failureTrace = "terminal/release/failed session=$proxySessionId terminalId=${params.terminalId}"
|
||||
) {
|
||||
hostCapabilities.release(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleTerminalWaitForExit(params: WaitForTerminalExitRequest): WaitForTerminalExitResponse {
|
||||
trace(
|
||||
"terminal/wait_for_exit session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}"
|
||||
)
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid terminal/wait_for_exit request",
|
||||
failureTrace = "terminal/wait_for_exit/failed session=$proxySessionId terminalId=${params.terminalId}"
|
||||
) {
|
||||
hostCapabilities.waitForExit(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun handleTerminalKill(params: KillTerminalCommandRequest): KillTerminalCommandResponse {
|
||||
trace(
|
||||
"terminal/kill session=$proxySessionId requestSession=${params.sessionId.value} terminalId=${params.terminalId}"
|
||||
)
|
||||
return runHostCall(
|
||||
invalidRequestMessage = "Invalid terminal/kill request",
|
||||
failureTrace = "terminal/kill/failed session=$proxySessionId terminalId=${params.terminalId}"
|
||||
) {
|
||||
hostCapabilities.kill(hostSessionContext(params.sessionId), params)
|
||||
}
|
||||
}
|
||||
|
||||
private fun selectApprovedPermissionOptionId(options: List<PermissionOption>): PermissionOptionId {
|
||||
return selectPermissionOptionId(
|
||||
options = options,
|
||||
preferredKinds = listOf(
|
||||
PermissionOptionKind.ALLOW_ONCE,
|
||||
PermissionOptionKind.ALLOW_ALWAYS
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun selectRejectedPermissionOptionId(options: List<PermissionOption>): PermissionOptionId {
|
||||
return selectPermissionOptionId(
|
||||
options = options,
|
||||
preferredKinds = listOf(
|
||||
PermissionOptionKind.REJECT_ONCE,
|
||||
PermissionOptionKind.REJECT_ALWAYS
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun selectPermissionOptionId(
|
||||
options: List<PermissionOption>,
|
||||
preferredKinds: List<PermissionOptionKind>
|
||||
): PermissionOptionId {
|
||||
preferredKinds.forEach { preferred ->
|
||||
options.firstOrNull { it.kind == preferred }?.let { return it.optionId }
|
||||
}
|
||||
return options.firstOrNull()?.optionId ?: PermissionOptionId("abort")
|
||||
}
|
||||
|
||||
private fun buildApprovalRequest(
|
||||
rawTitle: String,
|
||||
details: String,
|
||||
toolName: String,
|
||||
parsedArgs: AcpToolCallArgs?
|
||||
): ToolApprovalRequest {
|
||||
val approvalType = when (parsedArgs) {
|
||||
is AcpToolCallArgs.Write -> ToolApprovalType.WRITE
|
||||
is AcpToolCallArgs.Edit -> ToolApprovalType.EDIT
|
||||
is AcpToolCallArgs.Bash -> ToolApprovalType.BASH
|
||||
else -> ToolSpecs.approvalTypeFor(toolName)
|
||||
}
|
||||
val payload = approvalPayload(parsedArgs)
|
||||
val title = approvalTitle(rawTitle, approvalType, parsedArgs)
|
||||
val resolvedDetails = approvalDetails(details, parsedArgs)
|
||||
return ToolApprovalRequest(
|
||||
type = approvalType,
|
||||
title = title,
|
||||
details = resolvedDetails,
|
||||
payload = payload
|
||||
)
|
||||
}
|
||||
|
||||
private fun approvalPayload(parsedArgs: AcpToolCallArgs?): ToolApprovalPayload? {
|
||||
return when (parsedArgs) {
|
||||
is AcpToolCallArgs.Write -> WritePayload(
|
||||
parsedArgs.value.filePath,
|
||||
parsedArgs.value.content
|
||||
)
|
||||
|
||||
is AcpToolCallArgs.Edit -> EditPayload(
|
||||
filePath = parsedArgs.value.filePath,
|
||||
oldString = parsedArgs.value.oldString,
|
||||
newString = parsedArgs.value.newString,
|
||||
replaceAll = parsedArgs.value.replaceAll
|
||||
)
|
||||
|
||||
is AcpToolCallArgs.Bash -> BashPayload(
|
||||
parsedArgs.value.command,
|
||||
parsedArgs.value.description
|
||||
)
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
private fun approvalTitle(
|
||||
rawTitle: String,
|
||||
approvalType: ToolApprovalType,
|
||||
parsedArgs: AcpToolCallArgs?
|
||||
): String {
|
||||
return when (parsedArgs) {
|
||||
is AcpToolCallArgs.Write -> "Write ${Path.of(parsedArgs.value.filePath).name}?"
|
||||
is AcpToolCallArgs.Edit -> "Edit ${Path.of(parsedArgs.value.filePath).name}?"
|
||||
else -> rawTitle.ifBlank {
|
||||
when (approvalType) {
|
||||
ToolApprovalType.BASH -> "Run shell command?"
|
||||
ToolApprovalType.WRITE -> "Write file?"
|
||||
ToolApprovalType.EDIT -> "Edit file?"
|
||||
ToolApprovalType.GENERIC -> "Allow action?"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun approvalDetails(details: String, parsedArgs: AcpToolCallArgs?): String {
|
||||
return when (parsedArgs) {
|
||||
is AcpToolCallArgs.Write -> parsedArgs.value.filePath
|
||||
is AcpToolCallArgs.Edit -> parsedArgs.value.filePath
|
||||
else -> details
|
||||
}
|
||||
}
|
||||
|
||||
private fun permissionResponse(optionId: PermissionOptionId): RequestPermissionResponse {
|
||||
return RequestPermissionResponse(
|
||||
outcome = RequestPermissionOutcome.Selected(optionId)
|
||||
)
|
||||
}
|
||||
|
||||
private fun permissionCancelledResponse(): RequestPermissionResponse {
|
||||
return RequestPermissionResponse(
|
||||
outcome = RequestPermissionOutcome.Cancelled
|
||||
)
|
||||
}
|
||||
|
||||
private fun hostSessionContext(sessionId: SessionId): AcpHostSessionContext {
|
||||
return AcpHostSessionContext(sessionId = sessionId.value, cwd = sessionRoot)
|
||||
}
|
||||
|
||||
private suspend inline fun <T> runHostCall(
|
||||
invalidRequestMessage: String,
|
||||
failureTrace: String,
|
||||
crossinline block: suspend () -> T
|
||||
): T {
|
||||
return try {
|
||||
block()
|
||||
} catch (ex: IllegalArgumentException) {
|
||||
acpFail(ex.message ?: invalidRequestMessage)
|
||||
} catch (ex: Exception) {
|
||||
trace("$failureTrace error=${(ex.message ?: ex::class.simpleName.orEmpty()).logPreview()}")
|
||||
throw ex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun AgentSession?.currentAcpMode(): String? {
|
||||
return this?.externalAgentConfigOptions
|
||||
?.firstOrNull { it.id == AcpConfigCategories.MODE || it.category == AcpConfigCategories.MODE }
|
||||
?.currentValue
|
||||
}
|
||||
|
|
@ -1,81 +1,22 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import kotlinx.serialization.json.*
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import kotlinx.serialization.json.decodeFromJsonElement
|
||||
|
||||
internal fun JsonObject.string(vararg keys: String): String? {
|
||||
return keys.firstNotNullOfOrNull { key ->
|
||||
when (val element = this[key]) {
|
||||
is JsonPrimitive -> element.contentOrNull?.takeIf { it.isNotBlank() }
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun JsonObject.commandString(): String? {
|
||||
return listOf("command", "cmd").firstNotNullOfOrNull { key ->
|
||||
when (val element = this[key]) {
|
||||
is JsonPrimitive -> element.contentOrNull?.takeIf { it.isNotBlank() }
|
||||
is JsonArray -> element.mapNotNull { item ->
|
||||
(item as? JsonPrimitive)?.contentOrNull
|
||||
}.takeIf { it.isNotEmpty() }?.joinToString(" ")
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun JsonObject.boolean(vararg keys: String): Boolean? {
|
||||
return keys.firstNotNullOfOrNull { key ->
|
||||
(this[key] as? JsonPrimitive)?.booleanOrNull
|
||||
}
|
||||
}
|
||||
|
||||
internal fun JsonObject.int(vararg keys: String): Int? {
|
||||
return keys.firstNotNullOfOrNull { key ->
|
||||
(this[key] as? JsonPrimitive)?.intOrNull
|
||||
}
|
||||
}
|
||||
|
||||
internal fun JsonObject.firstLocationPath(): String? {
|
||||
return (this["locations"] as? JsonArray)
|
||||
?.firstNotNullOfOrNull { (it as? JsonObject)?.string("path") }
|
||||
}
|
||||
|
||||
internal fun JsonObject.titlePath(): String? {
|
||||
val title = string("title") ?: return null
|
||||
val firstSpaceIndex = title.indexOf(' ')
|
||||
if (firstSpaceIndex < 0 || firstSpaceIndex >= title.lastIndex) {
|
||||
return null
|
||||
}
|
||||
val path = title.substring(firstSpaceIndex + 1).trim()
|
||||
return path.takeIf { it.startsWith("/") }
|
||||
}
|
||||
|
||||
internal fun JsonObject.firstChangePath(): String? {
|
||||
return (this["changes"] as? JsonObject)?.keys?.firstOrNull()
|
||||
}
|
||||
|
||||
internal fun JsonObject.firstChangeContent(): String? {
|
||||
return (this["changes"] as? JsonObject)?.values?.firstNotNullOfOrNull { change ->
|
||||
(change as? JsonObject)?.string("content")
|
||||
}
|
||||
}
|
||||
|
||||
internal fun JsonElement?.asJsonArrayOrEmpty(): JsonArray {
|
||||
return this as? JsonArray ?: JsonArray(emptyList())
|
||||
}
|
||||
|
||||
internal fun JsonElement?.asJsonObjectOrNull(json: Json): JsonObject? {
|
||||
internal inline fun <reified T> JsonElement?.decodeOrNull(json: Json): T? {
|
||||
return when (this) {
|
||||
is JsonObject -> this
|
||||
null -> null
|
||||
is JsonPrimitive -> {
|
||||
if (!isString) {
|
||||
return null
|
||||
runCatching { json.decodeFromJsonElement<T>(this) }.getOrNull()
|
||||
} else {
|
||||
runCatching { json.decodeFromString<T>(content) }.getOrNull()
|
||||
}
|
||||
runCatching { json.parseToJsonElement(content) }.getOrNull() as? JsonObject
|
||||
}
|
||||
|
||||
else -> null
|
||||
else -> runCatching { json.decodeFromJsonElement<T>(this) }.getOrNull()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,257 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.diagnostic.Logger
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import kotlinx.serialization.json.*
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
internal data class AcpJsonRpcRequest(
|
||||
val id: JsonElement,
|
||||
val method: String,
|
||||
val params: JsonObject
|
||||
)
|
||||
|
||||
internal data class AcpJsonRpcNotification(
|
||||
val method: String,
|
||||
val params: JsonObject
|
||||
)
|
||||
|
||||
internal data class AcpJsonRpcError(
|
||||
val code: Int,
|
||||
val message: String,
|
||||
val data: JsonElement? = null
|
||||
) {
|
||||
companion object {
|
||||
const val METHOD_NOT_FOUND_CODE = -32601
|
||||
}
|
||||
}
|
||||
|
||||
internal class AcpJsonRpcException(
|
||||
val error: AcpJsonRpcError
|
||||
) : IllegalStateException(error.message)
|
||||
|
||||
private sealed interface AcpJsonRpcIncomingMessage {
|
||||
data class Response(
|
||||
val id: String,
|
||||
val result: JsonElement,
|
||||
val error: AcpJsonRpcError?
|
||||
) : AcpJsonRpcIncomingMessage
|
||||
|
||||
data class Request(val value: AcpJsonRpcRequest) : AcpJsonRpcIncomingMessage
|
||||
|
||||
data class Notification(val value: AcpJsonRpcNotification) : AcpJsonRpcIncomingMessage
|
||||
}
|
||||
|
||||
internal class AcpJsonRpcConnection(
|
||||
private val json: Json,
|
||||
private val process: Process,
|
||||
private val scope: CoroutineScope,
|
||||
private val logger: Logger,
|
||||
private val processName: String,
|
||||
private val onRequest: suspend (AcpJsonRpcRequest) -> JsonElement?,
|
||||
private val onNotification: suspend (AcpJsonRpcNotification) -> Unit
|
||||
) {
|
||||
|
||||
private val requestCounter = AtomicLong(0)
|
||||
private val pendingResponses = ConcurrentHashMap<String, CompletableDeferred<JsonElement>>()
|
||||
private val writeMutex = Mutex()
|
||||
|
||||
fun isAlive(): Boolean = process.isAlive
|
||||
|
||||
fun startReader() {
|
||||
scope.launch {
|
||||
val reader = process.inputStream.bufferedReader(StandardCharsets.UTF_8)
|
||||
try {
|
||||
while (isActive && process.isAlive) {
|
||||
val line = reader.readLine() ?: break
|
||||
if (line.isBlank()) continue
|
||||
if (service<ConfigurationSettings>().state.debugModeEnabled) {
|
||||
logger.info("[$processName] $line")
|
||||
}
|
||||
|
||||
when (val message = parseIncomingMessage(line)) {
|
||||
null -> Unit
|
||||
is AcpJsonRpcIncomingMessage.Response -> handleResponse(message)
|
||||
is AcpJsonRpcIncomingMessage.Request -> reply(message.value)
|
||||
is AcpJsonRpcIncomingMessage.Notification -> onNotification(message.value)
|
||||
}
|
||||
}
|
||||
} catch (cancelled: CancellationException) {
|
||||
throw cancelled
|
||||
} catch (t: Throwable) {
|
||||
logger.warn("ACP reader loop failed for $processName", t)
|
||||
} finally {
|
||||
closePendingResponses(IllegalStateException("$processName ACP process exited"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun startStderrLogger() {
|
||||
scope.launch {
|
||||
process.errorStream.bufferedReader(StandardCharsets.UTF_8).useLines { lines ->
|
||||
lines.forEach { line ->
|
||||
if (line.isNotBlank()) {
|
||||
logger.info("[$processName] $line")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun request(method: String, params: JsonObject): JsonObject {
|
||||
val id = requestCounter.incrementAndGet().toString()
|
||||
val response = CompletableDeferred<JsonElement>()
|
||||
pendingResponses[id] = response
|
||||
writePayload(requestPayload(id, method, params))
|
||||
return response.await().jsonObject
|
||||
}
|
||||
|
||||
suspend fun notify(method: String, params: JsonObject) {
|
||||
writePayload(notificationPayload(method, params))
|
||||
}
|
||||
|
||||
fun close() {
|
||||
closePendingResponses(CancellationException("ACP session closed"))
|
||||
process.destroy()
|
||||
}
|
||||
|
||||
private suspend fun reply(request: AcpJsonRpcRequest) {
|
||||
val response = runCatching { onRequest(request) }.fold(
|
||||
onSuccess = { body ->
|
||||
if (body == null) {
|
||||
errorResponse(request.id, -32601, "Method not found: ${request.method}")
|
||||
} else {
|
||||
successResponse(request.id, body)
|
||||
}
|
||||
},
|
||||
onFailure = { error ->
|
||||
errorResponse(request.id, -32603, error.message ?: "Internal error")
|
||||
}
|
||||
)
|
||||
writePayload(response)
|
||||
}
|
||||
|
||||
private fun handleResponse(response: AcpJsonRpcIncomingMessage.Response) {
|
||||
val pending = pendingResponses.remove(response.id) ?: return
|
||||
if (response.error != null) {
|
||||
pending.completeExceptionally(AcpJsonRpcException(response.error))
|
||||
return
|
||||
}
|
||||
pending.complete(response.result)
|
||||
}
|
||||
|
||||
private fun parseIncomingMessage(line: String): AcpJsonRpcIncomingMessage? {
|
||||
val element = runCatching { json.parseToJsonElement(line) }
|
||||
.onFailure { logger.warn("Ignoring non-JSON ACP output from $processName: $line") }
|
||||
.getOrNull() ?: return null
|
||||
val obj = element.jsonObject
|
||||
return when {
|
||||
obj["result"] != null || obj["error"] != null -> {
|
||||
val responseId = obj["id"]?.jsonPrimitive?.content ?: return null
|
||||
AcpJsonRpcIncomingMessage.Response(
|
||||
id = responseId,
|
||||
result = obj["result"] ?: JsonObject(emptyMap()),
|
||||
error = parseError(obj["error"])
|
||||
)
|
||||
}
|
||||
|
||||
obj["method"] != null && obj["id"] != null -> {
|
||||
val method = obj["method"]?.jsonPrimitive?.content ?: return null
|
||||
AcpJsonRpcIncomingMessage.Request(
|
||||
AcpJsonRpcRequest(
|
||||
id = obj["id"] ?: return null,
|
||||
method = method,
|
||||
params = obj["params"]?.jsonObject ?: JsonObject(emptyMap())
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
obj["method"] != null -> {
|
||||
val method = obj["method"]?.jsonPrimitive?.content ?: return null
|
||||
AcpJsonRpcIncomingMessage.Notification(
|
||||
AcpJsonRpcNotification(
|
||||
method = method,
|
||||
params = obj["params"]?.jsonObject ?: JsonObject(emptyMap())
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
private fun parseError(element: JsonElement?): AcpJsonRpcError? {
|
||||
val error = element as? JsonObject ?: return null
|
||||
return AcpJsonRpcError(
|
||||
code = error.int("code") ?: 0,
|
||||
message = error.string("message") ?: "Unknown JSON-RPC error",
|
||||
data = error["data"]
|
||||
)
|
||||
}
|
||||
|
||||
private suspend fun writePayload(payload: JsonObject) {
|
||||
writeMutex.withLock {
|
||||
val serializedPayload = json.encodeToString(JsonObject.serializer(), payload)
|
||||
if (service<ConfigurationSettings>().state.debugModeEnabled) {
|
||||
logger.info("[$processName] $serializedPayload")
|
||||
}
|
||||
|
||||
process.outputStream.write(
|
||||
(serializedPayload + "\n").toByteArray(StandardCharsets.UTF_8)
|
||||
)
|
||||
process.outputStream.flush()
|
||||
}
|
||||
}
|
||||
|
||||
private fun closePendingResponses(error: Throwable) {
|
||||
pendingResponses.values.forEach { pending ->
|
||||
pending.completeExceptionally(error)
|
||||
}
|
||||
pendingResponses.clear()
|
||||
}
|
||||
|
||||
private fun requestPayload(id: String, method: String, params: JsonObject): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("jsonrpc", JsonPrimitive("2.0"))
|
||||
put("id", JsonPrimitive(id))
|
||||
put("method", JsonPrimitive(method))
|
||||
put("params", params)
|
||||
}
|
||||
}
|
||||
|
||||
private fun notificationPayload(method: String, params: JsonObject): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("jsonrpc", JsonPrimitive("2.0"))
|
||||
put("method", JsonPrimitive(method))
|
||||
put("params", params)
|
||||
}
|
||||
}
|
||||
|
||||
private fun successResponse(id: JsonElement, result: JsonElement): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("jsonrpc", JsonPrimitive("2.0"))
|
||||
put("id", id)
|
||||
put("result", result)
|
||||
}
|
||||
}
|
||||
|
||||
private fun errorResponse(id: JsonElement, code: Int, message: String): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("jsonrpc", JsonPrimitive("2.0"))
|
||||
put("id", id)
|
||||
put(
|
||||
"error",
|
||||
buildJsonObject {
|
||||
put("code", JsonPrimitive(code))
|
||||
put("message", JsonPrimitive(message))
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
59
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpLogging.kt
vendored
Normal file
59
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpLogging.kt
vendored
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.agentclientprotocol.model.ContentBlock
|
||||
import com.agentclientprotocol.model.SessionUpdate
|
||||
import com.agentclientprotocol.model.ToolCallStatus
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.EditPayload
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalPayload
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.WritePayload
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
|
||||
internal val ToolCallStatus.wireValue: String
|
||||
get() = when (this) {
|
||||
ToolCallStatus.PENDING -> "pending"
|
||||
ToolCallStatus.IN_PROGRESS -> "in_progress"
|
||||
ToolCallStatus.COMPLETED -> "completed"
|
||||
ToolCallStatus.FAILED -> "failed"
|
||||
}
|
||||
|
||||
internal fun String.logPreview(limit: Int = 120): String {
|
||||
return replace('\n', ' ')
|
||||
.replace(Regex("\\s+"), " ")
|
||||
.trim()
|
||||
.take(limit)
|
||||
}
|
||||
|
||||
internal fun ToolApprovalPayload?.logSummary(): String {
|
||||
return when (this) {
|
||||
is WritePayload -> "write:${filePath.logPreview(80)}"
|
||||
is EditPayload -> "edit:${filePath.logPreview(80)}"
|
||||
is BashPayload -> "bash:${command.logPreview(80)}"
|
||||
null -> "none"
|
||||
}
|
||||
}
|
||||
|
||||
internal fun JsonElement?.logSummary(limit: Int = 400): String {
|
||||
return this?.toString()?.replace(Regex("\\s+"), " ")?.take(limit) ?: "null"
|
||||
}
|
||||
|
||||
internal fun Any?.logSummary(limit: Int = 400): String {
|
||||
return this?.toString()?.replace(Regex("\\s+"), " ")?.take(limit) ?: "null"
|
||||
}
|
||||
|
||||
internal fun SessionUpdate.logSummary(): String {
|
||||
return when (this) {
|
||||
is SessionUpdate.ToolCall -> "toolCallId=${toolCallId.value} title=${title.logPreview()} kind=${kind?.name} status=${status?.wireValue} rawInput=${rawInput.logSummary()}"
|
||||
is SessionUpdate.ToolCallUpdate -> "toolCallId=${toolCallId.value} title=${title?.logPreview()} kind=${kind?.name} status=${status?.wireValue} rawInput=${rawInput.logSummary()} rawOutput=${rawOutput.logSummary()}"
|
||||
is SessionUpdate.AgentMessageChunk -> "message=${(content as? ContentBlock.Text)?.text?.logPreview()}"
|
||||
is SessionUpdate.AgentThoughtChunk -> "thought=${(content as? ContentBlock.Text)?.text?.logPreview()}"
|
||||
is SessionUpdate.CurrentModeUpdate -> "currentMode=${currentModeId.value}"
|
||||
is SessionUpdate.PlanUpdate -> "planEntries=${entries.size}"
|
||||
is SessionUpdate.ConfigOptionUpdate -> "configOptions=${configOptions.size}"
|
||||
is SessionUpdate.SessionInfoUpdate -> "title=${title?.logPreview().orEmpty()} updatedAt=${updatedAt.orEmpty()}"
|
||||
is SessionUpdate.UsageUpdate -> "used=$used size=$size cost=${cost.logSummary()}"
|
||||
is SessionUpdate.UnknownSessionUpdate -> "type=$sessionUpdateType raw=${rawJson.logSummary()}"
|
||||
is SessionUpdate.AvailableCommandsUpdate -> "availableCommands=${availableCommands.size}"
|
||||
else -> this::class.simpleName.orEmpty()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import ee.carlrobert.codegpt.util.CommandRuntimeHelper
|
||||
|
||||
object AcpProcessHelper {
|
||||
|
||||
fun resolveCommand(
|
||||
command: String,
|
||||
extraEnvironment: Map<String, String> = emptyMap()
|
||||
): String? {
|
||||
return CommandRuntimeHelper.resolveCommand(command, extraEnvironment)
|
||||
}
|
||||
|
||||
fun createEnvironment(
|
||||
extraEnvironment: Map<String, String>,
|
||||
resolvedCommand: String
|
||||
): MutableMap<String, String> {
|
||||
return CommandRuntimeHelper.createEnvironment(
|
||||
extraEnvironment = extraEnvironment,
|
||||
resolvedCommand = resolvedCommand
|
||||
)
|
||||
}
|
||||
|
||||
fun getCommandNotFoundMessage(command: String): String {
|
||||
return buildString {
|
||||
append("Command '$command' not found. ")
|
||||
when (command) {
|
||||
"npx", "node" -> {
|
||||
append("Node.js/npm is required for this ACP runtime. ")
|
||||
append("Ensure it is installed and available to the IDE process. ")
|
||||
append("You can also point the runtime to an absolute executable path.")
|
||||
}
|
||||
|
||||
else -> {
|
||||
append("Ensure it is installed and available to the IDE process. ")
|
||||
append("You can also point the runtime to an absolute executable path.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
27
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpRuntimeState.kt
vendored
Normal file
27
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpRuntimeState.kt
vendored
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.agentclientprotocol.model.AcpCreatedSessionResponse
|
||||
import com.agentclientprotocol.model.AvailableCommand
|
||||
import com.agentclientprotocol.model.SessionConfigOption
|
||||
import com.agentclientprotocol.model.SessionModelState
|
||||
import com.agentclientprotocol.model.SessionModeState
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
|
||||
internal data class AcpRuntimeState(
|
||||
val modes: SessionModeState? = null,
|
||||
val models: SessionModelState? = null,
|
||||
val configOptions: List<SessionConfigOption> = emptyList(),
|
||||
val availableCommands: List<AvailableCommand> = emptyList(),
|
||||
val sessionTitle: String? = null,
|
||||
val sessionUpdatedAt: String? = null,
|
||||
val vendorMeta: JsonElement? = null
|
||||
)
|
||||
|
||||
internal fun AcpCreatedSessionResponse.toRuntimeState(): AcpRuntimeState {
|
||||
return AcpRuntimeState(
|
||||
modes = modes,
|
||||
models = models,
|
||||
configOptions = configOptions.orEmpty(),
|
||||
vendorMeta = _meta
|
||||
)
|
||||
}
|
||||
|
|
@ -1,266 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOption
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AcpConfigOptionChoice
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.json.*
|
||||
|
||||
internal enum class AcpSessionConfigId(
|
||||
val value: String,
|
||||
val displayName: String
|
||||
) {
|
||||
MODEL("model", "Model"),
|
||||
MODE("mode", "Mode");
|
||||
|
||||
fun matches(option: AcpConfigOption): Boolean {
|
||||
return option.id == value || option.category == value
|
||||
}
|
||||
}
|
||||
|
||||
internal data class AcpConfigUpdateRequest(
|
||||
val sessionId: String,
|
||||
val optionId: String,
|
||||
val value: String
|
||||
) {
|
||||
fun toJsonObject(): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("sessionId", sessionId)
|
||||
put("configId", optionId)
|
||||
put("value", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal enum class AcpConfigUpdateMethod(val wireName: String) {
|
||||
SNAKE_CASE("session/set_config_option"),
|
||||
CAMEL_CASE("session/setConfigOption")
|
||||
}
|
||||
|
||||
internal sealed interface AcpConfigUpdateSupport {
|
||||
fun candidateMethods(): List<AcpConfigUpdateMethod>
|
||||
|
||||
data object Unknown : AcpConfigUpdateSupport {
|
||||
override fun candidateMethods(): List<AcpConfigUpdateMethod> =
|
||||
AcpConfigUpdateMethod.entries
|
||||
}
|
||||
|
||||
data object Unsupported : AcpConfigUpdateSupport {
|
||||
override fun candidateMethods(): List<AcpConfigUpdateMethod> = emptyList()
|
||||
}
|
||||
|
||||
data class Supported(val method: AcpConfigUpdateMethod) : AcpConfigUpdateSupport {
|
||||
override fun candidateMethods(): List<AcpConfigUpdateMethod> = listOf(method)
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed interface AcpConfigUpdateResult {
|
||||
data class Applied(
|
||||
val response: JsonObject,
|
||||
val support: AcpConfigUpdateSupport.Supported
|
||||
) : AcpConfigUpdateResult
|
||||
|
||||
data object Unsupported : AcpConfigUpdateResult
|
||||
}
|
||||
|
||||
internal class AcpSessionConfigAdapter(
|
||||
private val json: Json
|
||||
) {
|
||||
|
||||
fun merge(
|
||||
existing: List<AcpConfigOption>,
|
||||
response: JsonObject
|
||||
): List<AcpConfigOption> {
|
||||
val updates = decode(response)
|
||||
if (updates.isEmpty()) {
|
||||
return existing
|
||||
}
|
||||
|
||||
val merged = existing.associateByTo(linkedMapOf()) { it.id }
|
||||
updates.forEach { option ->
|
||||
merged[option.id] = option
|
||||
}
|
||||
return merged.values.toList()
|
||||
}
|
||||
|
||||
suspend fun updateOption(
|
||||
request: AcpConfigUpdateRequest,
|
||||
support: AcpConfigUpdateSupport,
|
||||
sendRequest: suspend (String, JsonObject) -> JsonObject
|
||||
): AcpConfigUpdateResult {
|
||||
val candidateMethods = support.candidateMethods()
|
||||
if (candidateMethods.isEmpty()) {
|
||||
return AcpConfigUpdateResult.Unsupported
|
||||
}
|
||||
|
||||
val params = request.toJsonObject()
|
||||
candidateMethods.forEach { method ->
|
||||
try {
|
||||
return AcpConfigUpdateResult.Applied(
|
||||
response = sendRequest(method.wireName, params),
|
||||
support = AcpConfigUpdateSupport.Supported(method)
|
||||
)
|
||||
} catch (error: Throwable) {
|
||||
if (!error.isMethodNotFoundJsonRpcError()) {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AcpConfigUpdateResult.Unsupported
|
||||
}
|
||||
|
||||
private fun decode(response: JsonObject): List<AcpConfigOption> {
|
||||
val directOptions = buildList {
|
||||
addAll(
|
||||
response.decodeField<List<AcpStandardConfigOptionPayload>>("configOptions")
|
||||
.orEmpty()
|
||||
.mapNotNull(AcpStandardConfigOptionPayload::toConfigOption)
|
||||
)
|
||||
response.decodeField<AcpStandardConfigOptionPayload>("configOption")
|
||||
?.toConfigOption()
|
||||
?.let(::add)
|
||||
}
|
||||
|
||||
val hasDirectModelOption = directOptions.any(AcpSessionConfigId.MODEL::matches)
|
||||
val hasDirectModeOption = directOptions.any(AcpSessionConfigId.MODE::matches)
|
||||
|
||||
return buildList {
|
||||
addAll(directOptions)
|
||||
if (!hasDirectModelOption) {
|
||||
response.decodeField<AcpModelsPayload>("models")
|
||||
?.toConfigOption()
|
||||
?.let(::add)
|
||||
}
|
||||
if (!hasDirectModeOption) {
|
||||
response.decodeField<AcpModesPayload>("modes")
|
||||
?.toConfigOption()
|
||||
?.let(::add)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private inline fun <reified T> JsonObject.decodeField(key: String): T? {
|
||||
val element = this[key] ?: return null
|
||||
return runCatching { json.decodeFromJsonElement<T>(element) }.getOrNull()
|
||||
}
|
||||
}
|
||||
|
||||
@Serializable
|
||||
private data class AcpStandardConfigOptionPayload(
|
||||
val id: String? = null,
|
||||
val name: String? = null,
|
||||
val description: String? = null,
|
||||
val category: String? = null,
|
||||
val type: String? = null,
|
||||
val currentValue: String? = null,
|
||||
val current_value: String? = null,
|
||||
val value: String? = null,
|
||||
val options: List<AcpConfigChoicePayload> = emptyList()
|
||||
) {
|
||||
fun toConfigOption(): AcpConfigOption? {
|
||||
val resolvedId = id.nullIfBlank() ?: return null
|
||||
return AcpConfigOption(
|
||||
id = resolvedId,
|
||||
name = name.nullIfBlank() ?: resolvedId,
|
||||
description = description.nullIfBlank(),
|
||||
category = category.nullIfBlank(),
|
||||
type = type.nullIfBlank(),
|
||||
currentValue = firstNotBlank(currentValue, current_value, value),
|
||||
options = options.mapNotNull(AcpConfigChoicePayload::toChoice)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Serializable
|
||||
private data class AcpConfigChoicePayload(
|
||||
val value: String? = null,
|
||||
val id: String? = null,
|
||||
val name: String? = null,
|
||||
val description: String? = null
|
||||
) {
|
||||
fun toChoice(): AcpConfigOptionChoice? {
|
||||
val resolvedValue = firstNotBlank(value, id) ?: return null
|
||||
return AcpConfigOptionChoice(
|
||||
value = resolvedValue,
|
||||
name = name.nullIfBlank() ?: resolvedValue,
|
||||
description = description.nullIfBlank()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Serializable
|
||||
private data class AcpModelsPayload(
|
||||
val currentModelId: String? = null,
|
||||
val availableModels: List<AcpAlternativeChoicePayload> = emptyList()
|
||||
) {
|
||||
fun toConfigOption(): AcpConfigOption? {
|
||||
return toAlternativeConfigOption(
|
||||
id = AcpSessionConfigId.MODEL,
|
||||
currentValue = currentModelId,
|
||||
entries = availableModels
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Serializable
|
||||
private data class AcpModesPayload(
|
||||
val currentModeId: String? = null,
|
||||
val availableModes: List<AcpAlternativeChoicePayload> = emptyList()
|
||||
) {
|
||||
fun toConfigOption(): AcpConfigOption? {
|
||||
return toAlternativeConfigOption(
|
||||
id = AcpSessionConfigId.MODE,
|
||||
currentValue = currentModeId,
|
||||
entries = availableModes
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Serializable
|
||||
private data class AcpAlternativeChoicePayload(
|
||||
val modelId: String? = null,
|
||||
val modeId: String? = null,
|
||||
val value: String? = null,
|
||||
val id: String? = null,
|
||||
val name: String? = null,
|
||||
val description: String? = null
|
||||
) {
|
||||
fun toChoice(): AcpConfigOptionChoice? {
|
||||
val resolvedValue = firstNotBlank(modelId, modeId, value, id) ?: return null
|
||||
return AcpConfigOptionChoice(
|
||||
value = resolvedValue,
|
||||
name = name.nullIfBlank() ?: resolvedValue,
|
||||
description = description.nullIfBlank()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun toAlternativeConfigOption(
|
||||
id: AcpSessionConfigId,
|
||||
currentValue: String?,
|
||||
entries: List<AcpAlternativeChoicePayload>
|
||||
): AcpConfigOption? {
|
||||
val options = entries.mapNotNull(AcpAlternativeChoicePayload::toChoice)
|
||||
val resolvedCurrentValue = currentValue.nullIfBlank()
|
||||
if (options.isEmpty() && resolvedCurrentValue == null) {
|
||||
return null
|
||||
}
|
||||
return AcpConfigOption(
|
||||
id = id.value,
|
||||
name = id.displayName,
|
||||
category = id.value,
|
||||
type = "select",
|
||||
currentValue = resolvedCurrentValue,
|
||||
options = options
|
||||
)
|
||||
}
|
||||
|
||||
private fun firstNotBlank(vararg values: String?): String? {
|
||||
return values.firstNotNullOfOrNull(String?::nullIfBlank)
|
||||
}
|
||||
|
||||
private fun String?.nullIfBlank(): String? = this?.takeIf { it.isNotBlank() }
|
||||
|
||||
private fun Throwable.isMethodNotFoundJsonRpcError(): Boolean {
|
||||
return (this as? AcpJsonRpcException)?.error?.code == AcpJsonRpcError.METHOD_NOT_FOUND_CODE
|
||||
}
|
||||
175
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateBridge.kt
vendored
Normal file
175
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpSessionUpdateBridge.kt
vendored
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.agentclientprotocol.model.SessionNotification
|
||||
import com.agentclientprotocol.model.AvailableCommand
|
||||
import com.agentclientprotocol.model.SessionConfigOption
|
||||
import com.agentclientprotocol.model.ToolKind
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import ee.carlrobert.codegpt.agent.AgentUsageEvent
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpExternalEvent
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallSnapshot
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
internal class AcpSessionUpdateBridge(
|
||||
private val proxySessionId: String,
|
||||
private val toolEventFlavor: AcpToolEventFlavor,
|
||||
private val toolCallDecoder: AcpToolCallDecoder,
|
||||
private val updateModeSelection: (String) -> Unit,
|
||||
private val updateConfigOptions: (List<SessionConfigOption>) -> Unit,
|
||||
private val updateAvailableCommands: (List<AvailableCommand>) -> Unit,
|
||||
private val updateSessionInfo: (String?, String?) -> Unit,
|
||||
private val trace: (String) -> Unit
|
||||
) {
|
||||
private val toolCallsById = ConcurrentHashMap<String, AcpToolCallSnapshot>()
|
||||
|
||||
fun handle(notification: SessionNotification, events: AgentEvents) {
|
||||
trace(
|
||||
"session/update session=$proxySessionId update=${notification.update::class.simpleName} payload=${notification.update.logSummary()}"
|
||||
)
|
||||
when (val event =
|
||||
toolCallDecoder.decodeExternalEvent(toolEventFlavor, notification.update)) {
|
||||
is AcpExternalEvent.TextChunk -> events.onTextReceived(event.text)
|
||||
is AcpExternalEvent.ThinkingChunk -> events.onThinkingReceived(event.text)
|
||||
is AcpExternalEvent.PlanUpdate -> events.onPlanUpdated(event.entries)
|
||||
is AcpExternalEvent.UsageUpdate -> events.onUsageAvailable(
|
||||
AgentUsageEvent(
|
||||
usedTokens = event.used,
|
||||
sizeTokens = event.size.takeIf { it > 0L },
|
||||
costAmount = event.cost?.amount,
|
||||
costCurrency = event.cost?.currency
|
||||
)
|
||||
)
|
||||
is AcpExternalEvent.ConfigOptionUpdate -> {
|
||||
updateConfigOptions(event.configOptions)
|
||||
events.onRuntimeOptionsUpdated()
|
||||
}
|
||||
is AcpExternalEvent.SessionInfoUpdate -> {
|
||||
updateSessionInfo(event.title, event.updatedAt)
|
||||
events.onSessionInfoUpdated(event.title, event.updatedAt)
|
||||
}
|
||||
is AcpExternalEvent.UnknownSessionUpdate -> trace(
|
||||
"session/update/unknown session=$proxySessionId type=${event.type} payload=${event.rawJson.logSummary()}"
|
||||
)
|
||||
is AcpExternalEvent.AvailableCommandsUpdate -> updateAvailableCommands(event.availableCommands)
|
||||
is AcpExternalEvent.CurrentModeUpdate -> {
|
||||
updateModeSelection(event.currentModeId)
|
||||
events.onRuntimeOptionsUpdated()
|
||||
}
|
||||
is AcpExternalEvent.ToolCallStarted -> handleToolCallStarted(event.toolCall, events)
|
||||
is AcpExternalEvent.ToolCallUpdated -> handleToolCallUpdated(event.toolCall, events)
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleToolCallStarted(
|
||||
toolCall: AcpToolCallSnapshot,
|
||||
events: AgentEvents
|
||||
) {
|
||||
trace(
|
||||
"tool-call/start session=$proxySessionId id=${toolCall.id} title=${toolCall.title.logPreview()} kind=${toolCall.kind?.name} status=${toolCall.status?.wireValue} rawInput=${toolCall.rawInput.logSummary()}"
|
||||
)
|
||||
if (toolCall.kind == ToolKind.THINK && toolCall.title.isNotBlank()) {
|
||||
events.onThinkingReceived(toolCall.title)
|
||||
return
|
||||
}
|
||||
toolCallsById[toolCall.id] = toolCall
|
||||
if (!shouldDeferToolStart(toolCall)) {
|
||||
events.onToolStarting(toolCall.id, toolCall.toolName, toolCall.args.toUiArgs())
|
||||
}
|
||||
|
||||
if (toolCall.status?.isTerminal == true) {
|
||||
completeToolCall(toolCall, events)
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleToolCallUpdated(
|
||||
toolCall: AcpToolCallSnapshot,
|
||||
events: AgentEvents
|
||||
) {
|
||||
trace(
|
||||
"tool-call/update session=$proxySessionId id=${toolCall.id} title=${toolCall.title.logPreview()} kind=${toolCall.kind?.name} status=${toolCall.status?.wireValue} rawInput=${toolCall.rawInput.logSummary()} rawOutput=${toolCall.rawOutput.logSummary()}"
|
||||
)
|
||||
val currentToolCall = toolCallsById[toolCall.id]
|
||||
val effectiveToolCall = mergeToolCallSnapshots(currentToolCall, toolCall)
|
||||
val status = effectiveToolCall.status ?: return
|
||||
|
||||
toolCallsById[effectiveToolCall.id] = effectiveToolCall
|
||||
|
||||
if ((currentToolCall?.args == null && effectiveToolCall.args != null) ||
|
||||
(currentToolCall == null && !shouldDeferToolStart(effectiveToolCall))
|
||||
) {
|
||||
events.onToolStarting(
|
||||
effectiveToolCall.id,
|
||||
effectiveToolCall.toolName,
|
||||
effectiveToolCall.args.toUiArgs()
|
||||
)
|
||||
}
|
||||
|
||||
if (!status.isTerminal) {
|
||||
return
|
||||
}
|
||||
|
||||
completeToolCall(effectiveToolCall, events)
|
||||
}
|
||||
|
||||
private fun completeToolCall(
|
||||
toolCall: AcpToolCallSnapshot,
|
||||
events: AgentEvents
|
||||
) {
|
||||
val result = toolCallDecoder.decodeResult(toolCall, toolCall.rawOutput)
|
||||
trace(
|
||||
"tool-call/complete session=$proxySessionId id=${toolCall.id} tool=${toolCall.toolName} status=${toolCall.status?.wireValue} result=${result.logSummary()}"
|
||||
)
|
||||
toolCallsById.remove(toolCall.id)
|
||||
events.onToolCompleted(toolCall.id, toolCall.toolName, result)
|
||||
}
|
||||
|
||||
private fun shouldDeferToolStart(toolCall: AcpToolCallSnapshot): Boolean {
|
||||
return toolCall.toolName in setOf("WebSearch", "WebFetch", "Write", "Edit") &&
|
||||
toolCall.args == null &&
|
||||
toolCall.status?.isTerminal != true
|
||||
}
|
||||
|
||||
private fun AcpToolCallArgs?.toUiArgs(): Any? {
|
||||
return when (this) {
|
||||
is AcpToolCallArgs.SearchPreview -> value
|
||||
is AcpToolCallArgs.BashPreview -> value
|
||||
is AcpToolCallArgs.Mcp -> value
|
||||
is AcpToolCallArgs.Read -> value
|
||||
is AcpToolCallArgs.Write -> value
|
||||
is AcpToolCallArgs.Edit -> value
|
||||
is AcpToolCallArgs.Bash -> value
|
||||
is AcpToolCallArgs.WebSearch -> value
|
||||
is AcpToolCallArgs.WebFetch -> value
|
||||
is AcpToolCallArgs.IntelliJSearch -> value
|
||||
is AcpToolCallArgs.Unknown -> value
|
||||
null -> null
|
||||
}
|
||||
}
|
||||
|
||||
private fun mergeToolCallSnapshots(
|
||||
currentToolCall: AcpToolCallSnapshot?,
|
||||
updatedToolCall: AcpToolCallSnapshot
|
||||
): AcpToolCallSnapshot {
|
||||
val effectiveTitle = updatedToolCall.title.ifBlank { currentToolCall?.title.orEmpty() }
|
||||
.ifBlank { "Tool" }
|
||||
val effectiveToolName = when {
|
||||
updatedToolCall.toolName.isNotBlank() && updatedToolCall.toolName != "Tool" -> updatedToolCall.toolName
|
||||
currentToolCall != null -> currentToolCall.toolName
|
||||
else -> "Tool"
|
||||
}
|
||||
return updatedToolCall.copy(
|
||||
title = effectiveTitle,
|
||||
toolName = effectiveToolName,
|
||||
kind = updatedToolCall.kind ?: currentToolCall?.kind,
|
||||
status = updatedToolCall.status ?: currentToolCall?.status,
|
||||
args = updatedToolCall.args ?: currentToolCall?.args,
|
||||
locations = updatedToolCall.locations.ifEmpty { currentToolCall?.locations.orEmpty() },
|
||||
content = updatedToolCall.content.ifEmpty { currentToolCall?.content.orEmpty() },
|
||||
meta = updatedToolCall.meta ?: currentToolCall?.meta,
|
||||
rawInput = updatedToolCall.rawInput ?: currentToolCall?.rawInput,
|
||||
rawOutput = updatedToolCall.rawOutput ?: currentToolCall?.rawOutput
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
|
||||
internal sealed interface AcpSessionUpdate {
|
||||
data class TextChunk(val text: String) : AcpSessionUpdate
|
||||
data class ThoughtChunk(val text: String) : AcpSessionUpdate
|
||||
data class ToolCall(
|
||||
val toolCall: AcpDecodedToolCall,
|
||||
val status: AcpToolCallStatus?,
|
||||
val rawOutput: JsonElement?
|
||||
) : AcpSessionUpdate
|
||||
|
||||
data class ToolCallUpdate(
|
||||
val toolCallId: String,
|
||||
val toolCall: AcpDecodedToolCall?,
|
||||
val status: AcpToolCallStatus,
|
||||
val rawOutput: JsonElement?
|
||||
) : AcpSessionUpdate
|
||||
|
||||
data class ConfigOptionUpdate(val update: JsonObject) : AcpSessionUpdate
|
||||
}
|
||||
|
||||
internal enum class AcpToolCallStatus(val wireValue: String) {
|
||||
IN_PROGRESS("in_progress"),
|
||||
COMPLETED("completed"),
|
||||
FAILED("failed"),
|
||||
CANCELLED("cancelled");
|
||||
|
||||
val isTerminal: Boolean
|
||||
get() = this != IN_PROGRESS
|
||||
|
||||
companion object {
|
||||
fun fromWireValue(value: String?): AcpToolCallStatus? {
|
||||
return entries.firstOrNull { it.wireValue == value }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class AcpSessionUpdateParser(
|
||||
private val toolCallDecoder: AcpToolCallDecoder
|
||||
) {
|
||||
|
||||
fun parse(notification: AcpJsonRpcNotification): AcpSessionUpdate? {
|
||||
if (notification.method != SESSION_UPDATE_METHOD) {
|
||||
return null
|
||||
}
|
||||
|
||||
val update = notification.params["update"] as? JsonObject ?: return null
|
||||
return when (SessionUpdateKind.fromWireValue(update.string("sessionUpdate"))) {
|
||||
SessionUpdateKind.AGENT_MESSAGE_CHUNK -> parseAgentMessageChunk(update)
|
||||
SessionUpdateKind.TOOL_CALL -> parseToolCall(update)
|
||||
SessionUpdateKind.TOOL_CALL_UPDATE -> parseToolCallUpdate(update)
|
||||
SessionUpdateKind.CONFIG_OPTION_UPDATE -> AcpSessionUpdate.ConfigOptionUpdate(update)
|
||||
null -> null
|
||||
}
|
||||
}
|
||||
|
||||
private fun parseAgentMessageChunk(update: JsonObject): AcpSessionUpdate? {
|
||||
val content = update["content"] as? JsonObject ?: return null
|
||||
return when (MessageChunkType.fromWireValue(content.string("type"))) {
|
||||
MessageChunkType.TEXT ->
|
||||
AcpSessionUpdate.TextChunk(content.string("text").orEmpty())
|
||||
|
||||
MessageChunkType.THOUGHT -> {
|
||||
val text = content.string("thought").orEmpty()
|
||||
text.takeIf { it.isNotBlank() }?.let(AcpSessionUpdate::ThoughtChunk)
|
||||
}
|
||||
|
||||
null -> null
|
||||
}
|
||||
}
|
||||
|
||||
private fun parseToolCall(update: JsonObject): AcpSessionUpdate? {
|
||||
val toolCall = toolCallDecoder.decodeToolCall(update) ?: return null
|
||||
return AcpSessionUpdate.ToolCall(
|
||||
toolCall = toolCall,
|
||||
status = AcpToolCallStatus.fromWireValue(update.string("status")),
|
||||
rawOutput = update["rawOutput"] ?: update["content"]
|
||||
)
|
||||
}
|
||||
|
||||
private fun parseToolCallUpdate(update: JsonObject): AcpSessionUpdate? {
|
||||
val toolCallId = update.string("toolCallId") ?: return null
|
||||
val status = AcpToolCallStatus.fromWireValue(update.string("status")) ?: return null
|
||||
return AcpSessionUpdate.ToolCallUpdate(
|
||||
toolCallId = toolCallId,
|
||||
toolCall = toolCallDecoder.decodeToolCall(update),
|
||||
status = status,
|
||||
rawOutput = update["rawOutput"] ?: update["content"]
|
||||
)
|
||||
}
|
||||
|
||||
private enum class SessionUpdateKind(val wireValue: String) {
|
||||
AGENT_MESSAGE_CHUNK("agent_message_chunk"),
|
||||
TOOL_CALL("tool_call"),
|
||||
TOOL_CALL_UPDATE("tool_call_update"),
|
||||
CONFIG_OPTION_UPDATE("config_option_update");
|
||||
|
||||
companion object {
|
||||
fun fromWireValue(value: String?): SessionUpdateKind? {
|
||||
return entries.firstOrNull { it.wireValue == value }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private enum class MessageChunkType(val wireValue: String) {
|
||||
TEXT("text"),
|
||||
THOUGHT("thought");
|
||||
|
||||
companion object {
|
||||
fun fromWireValue(value: String?): MessageChunkType? {
|
||||
return entries.firstOrNull { it.wireValue == value }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private companion object {
|
||||
const val SESSION_UPDATE_METHOD = "session/update"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,68 +1,179 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.agentclientprotocol.model.ContentBlock
|
||||
import com.agentclientprotocol.model.RequestPermissionRequest
|
||||
import com.agentclientprotocol.model.SessionUpdate
|
||||
import com.agentclientprotocol.model.ToolCallContent
|
||||
import com.agentclientprotocol.model.ToolCallStatus
|
||||
import com.agentclientprotocol.model.ToolCallLocation
|
||||
import com.agentclientprotocol.model.ToolKind
|
||||
import ee.carlrobert.codegpt.agent.ToolSpecs
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import ee.carlrobert.codegpt.agent.tools.EditTool
|
||||
import ee.carlrobert.codegpt.agent.tools.McpTool
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpExternalEvent
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpPermissionRequestSnapshot
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallSnapshot
|
||||
import ee.carlrobert.codegpt.agent.external.events.toAcpToolCallContent
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonArray
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
internal data class AcpDecodedToolCall(
|
||||
val id: String,
|
||||
val toolName: String,
|
||||
val args: Any?
|
||||
)
|
||||
|
||||
internal data class AcpPermissionRequestData(
|
||||
val rawTitle: String,
|
||||
val toolName: String,
|
||||
val parsedArgs: Any?,
|
||||
val details: String,
|
||||
val options: JsonArray
|
||||
)
|
||||
|
||||
private data class DiffContent(
|
||||
val path: String,
|
||||
val oldText: String?,
|
||||
val newText: String
|
||||
)
|
||||
|
||||
private data class ResolvedToolCall(
|
||||
val rawTitle: String,
|
||||
val toolName: String,
|
||||
val args: Any?
|
||||
)
|
||||
|
||||
internal class AcpToolCallDecoder(
|
||||
private val json: Json
|
||||
) {
|
||||
private val support = AcpToolCallDecodingSupport(json)
|
||||
private val standardNormalizer = StandardSemanticToolCallNormalizer()
|
||||
private val zedNormalizer = ZedAdapterToolCallNormalizer()
|
||||
private val geminiNormalizer = GeminiCliToolCallNormalizer()
|
||||
private val fallbackNormalizer = FallbackToolCallNormalizer()
|
||||
|
||||
fun decodeToolCall(metadata: JsonObject): AcpDecodedToolCall? {
|
||||
val toolCallId = metadata.string("toolCallId") ?: return null
|
||||
val tool = resolveToolCall(metadata)
|
||||
return AcpDecodedToolCall(
|
||||
id = toolCallId,
|
||||
toolName = tool.toolName,
|
||||
args = tool.args
|
||||
fun decodeExternalEvent(
|
||||
flavor: AcpToolEventFlavor,
|
||||
update: SessionUpdate
|
||||
): AcpExternalEvent? {
|
||||
return when (update) {
|
||||
is SessionUpdate.UserMessageChunk -> null
|
||||
is SessionUpdate.AgentMessageChunk -> {
|
||||
(update.content as? ContentBlock.Text)?.text?.let {
|
||||
AcpExternalEvent.TextChunk(it)
|
||||
}
|
||||
}
|
||||
|
||||
is SessionUpdate.AgentThoughtChunk -> {
|
||||
(update.content as? ContentBlock.Text)?.text?.let {
|
||||
AcpExternalEvent.ThinkingChunk(it)
|
||||
}
|
||||
}
|
||||
|
||||
is SessionUpdate.PlanUpdate -> AcpExternalEvent.PlanUpdate(update.entries)
|
||||
is SessionUpdate.UsageUpdate -> AcpExternalEvent.UsageUpdate(
|
||||
used = update.used,
|
||||
size = update.size,
|
||||
cost = update.cost
|
||||
)
|
||||
is SessionUpdate.ConfigOptionUpdate -> AcpExternalEvent.ConfigOptionUpdate(update.configOptions)
|
||||
is SessionUpdate.SessionInfoUpdate -> AcpExternalEvent.SessionInfoUpdate(
|
||||
title = update.title,
|
||||
updatedAt = update.updatedAt
|
||||
)
|
||||
is SessionUpdate.UnknownSessionUpdate -> AcpExternalEvent.UnknownSessionUpdate(
|
||||
type = update.sessionUpdateType,
|
||||
rawJson = update.rawJson
|
||||
)
|
||||
is SessionUpdate.AvailableCommandsUpdate -> AcpExternalEvent.AvailableCommandsUpdate(update.availableCommands)
|
||||
is SessionUpdate.CurrentModeUpdate -> AcpExternalEvent.CurrentModeUpdate(update.currentModeId.value)
|
||||
is SessionUpdate.ToolCall -> AcpExternalEvent.ToolCallStarted(decodeToolCallSnapshot(flavor, update))
|
||||
is SessionUpdate.ToolCallUpdate -> AcpExternalEvent.ToolCallUpdated(decodeToolCallSnapshot(flavor, update))
|
||||
}
|
||||
}
|
||||
|
||||
fun decodePermissionRequest(
|
||||
flavor: AcpToolEventFlavor,
|
||||
request: RequestPermissionRequest
|
||||
): AcpPermissionRequestSnapshot {
|
||||
return AcpPermissionRequestSnapshot(
|
||||
toolCall = decodeToolCallSnapshot(
|
||||
flavor = flavor,
|
||||
toolCallId = request.toolCall.toolCallId.value,
|
||||
rawTitle = request.toolCall.title ?: "Allow action?",
|
||||
rawKind = request.toolCall.kind?.wireValue,
|
||||
rawInput = request.toolCall.rawInput,
|
||||
locations = request.toolCall.locations.orEmpty(),
|
||||
content = request.toolCall.content.orEmpty(),
|
||||
rawMeta = request.toolCall._meta,
|
||||
defaultTitle = "Allow action?"
|
||||
),
|
||||
details = permissionDetails(
|
||||
locations = request.toolCall.locations.orEmpty(),
|
||||
rawInput = request.toolCall.rawInput,
|
||||
fallback = request.toolCall.toString()
|
||||
),
|
||||
options = request.options
|
||||
)
|
||||
}
|
||||
|
||||
fun decodePermissionRequest(params: JsonObject): AcpPermissionRequestData {
|
||||
val toolCall = params["toolCall"] as? JsonObject ?: JsonObject(emptyMap())
|
||||
val tool = resolveToolCall(toolCall, defaultTitle = "Allow action?")
|
||||
return AcpPermissionRequestData(
|
||||
rawTitle = tool.rawTitle,
|
||||
toolName = tool.toolName,
|
||||
parsedArgs = tool.args,
|
||||
details = permissionDetails(toolCall),
|
||||
options = params["options"].asJsonArrayOrEmpty()
|
||||
private fun decodeToolCallSnapshot(
|
||||
flavor: AcpToolEventFlavor,
|
||||
update: SessionUpdate.ToolCall
|
||||
): AcpToolCallSnapshot {
|
||||
return decodeToolCallSnapshot(
|
||||
flavor = flavor,
|
||||
toolCallId = update.toolCallId.value,
|
||||
rawTitle = update.title,
|
||||
rawKind = update.kind?.wireValue,
|
||||
rawInput = update.rawInput,
|
||||
locations = update.locations,
|
||||
content = update.content,
|
||||
rawMeta = update._meta,
|
||||
rawOutput = update.rawOutput,
|
||||
status = update.status?.toAcpToolCallStatus()
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeToolCallSnapshot(
|
||||
flavor: AcpToolEventFlavor,
|
||||
update: SessionUpdate.ToolCallUpdate
|
||||
): AcpToolCallSnapshot {
|
||||
return decodeToolCallSnapshot(
|
||||
flavor = flavor,
|
||||
toolCallId = update.toolCallId.value,
|
||||
rawTitle = update.title.orEmpty(),
|
||||
rawKind = update.kind?.wireValue,
|
||||
rawInput = update.rawInput,
|
||||
locations = update.locations.orEmpty(),
|
||||
content = update.content.orEmpty(),
|
||||
rawMeta = update._meta,
|
||||
rawOutput = update.rawOutput,
|
||||
status = update.status?.toAcpToolCallStatus(),
|
||||
defaultTitle = ""
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeToolCallSnapshot(
|
||||
flavor: AcpToolEventFlavor,
|
||||
toolCallId: String,
|
||||
rawTitle: String,
|
||||
rawKind: String?,
|
||||
rawInput: JsonElement?,
|
||||
locations: List<ToolCallLocation>,
|
||||
content: List<ToolCallContent>,
|
||||
rawMeta: JsonElement? = null,
|
||||
rawOutput: JsonElement? = null,
|
||||
status: AcpToolCallStatus? = null,
|
||||
defaultTitle: String = "Tool"
|
||||
): AcpToolCallSnapshot {
|
||||
val resolved = resolveToolCall(
|
||||
flavor = flavor,
|
||||
context = AcpToolCallContext(
|
||||
toolCallId = toolCallId,
|
||||
rawTitle = rawTitle,
|
||||
rawKind = rawKind,
|
||||
rawInput = rawInput,
|
||||
locations = locations,
|
||||
content = content,
|
||||
defaultTitle = defaultTitle
|
||||
)
|
||||
)
|
||||
return AcpToolCallSnapshot(
|
||||
id = toolCallId,
|
||||
title = resolved.rawTitle,
|
||||
toolName = resolved.toolName,
|
||||
kind = rawKind?.toAcpToolKind(),
|
||||
status = status,
|
||||
args = resolved.typedArgs,
|
||||
locations = locations,
|
||||
content = content.toAcpToolCallContent(),
|
||||
meta = rawMeta,
|
||||
rawInput = rawInput,
|
||||
rawOutput = rawOutput
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeResult(
|
||||
toolName: String,
|
||||
args: Any?,
|
||||
args: AcpToolCallArgs?,
|
||||
status: AcpToolCallStatus,
|
||||
rawOutput: JsonElement?
|
||||
): Any? {
|
||||
|
|
@ -71,325 +182,131 @@ internal class AcpToolCallDecoder(
|
|||
|
||||
if (status == AcpToolCallStatus.COMPLETED) {
|
||||
when (args) {
|
||||
is McpTool.Args -> return McpTool.Result(
|
||||
serverId = args.serverId,
|
||||
serverName = args.serverName,
|
||||
toolName = args.toolName,
|
||||
is AcpToolCallArgs.Mcp -> return McpTool.Result(
|
||||
serverId = args.value.serverId,
|
||||
serverName = args.value.serverName,
|
||||
toolName = args.value.toolName,
|
||||
success = true,
|
||||
output = payload.ifBlank { "MCP tool completed" }
|
||||
)
|
||||
|
||||
is EditTool.Args -> return EditTool.Result.Success(
|
||||
filePath = args.filePath,
|
||||
is AcpToolCallArgs.Edit -> return EditTool.Result.Success(
|
||||
filePath = args.value.filePath,
|
||||
replacementsMade = 1,
|
||||
message = "Edit completed"
|
||||
)
|
||||
|
||||
is WriteTool.Args -> return WriteTool.Result.Success(
|
||||
filePath = args.filePath,
|
||||
bytesWritten = args.content.toByteArray(StandardCharsets.UTF_8).size,
|
||||
is AcpToolCallArgs.Write -> return WriteTool.Result.Success(
|
||||
filePath = args.value.filePath,
|
||||
bytesWritten = args.value.content.toByteArray(StandardCharsets.UTF_8).size,
|
||||
isNewFile = false,
|
||||
message = "Write completed"
|
||||
)
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
if (status == AcpToolCallStatus.FAILED || status == AcpToolCallStatus.CANCELLED) {
|
||||
val message = payload.ifBlank { "Tool ${status.wireValue}" }
|
||||
when (args) {
|
||||
is McpTool.Args -> return McpTool.Result.error(
|
||||
toolName = args.toolName,
|
||||
is AcpToolCallArgs.Mcp -> return McpTool.Result.error(
|
||||
toolName = args.value.toolName,
|
||||
output = message,
|
||||
serverId = args.serverId,
|
||||
serverName = args.serverName
|
||||
serverId = args.value.serverId,
|
||||
serverName = args.value.serverName
|
||||
)
|
||||
|
||||
is EditTool.Args -> return EditTool.Result.Error(args.filePath, message)
|
||||
is WriteTool.Args -> return WriteTool.Result.Error(args.filePath, message)
|
||||
is AcpToolCallArgs.Edit -> return EditTool.Result.Error(args.value.filePath, message)
|
||||
is AcpToolCallArgs.Write -> return WriteTool.Result.Error(args.value.filePath, message)
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
return payload.ifBlank { null }
|
||||
}
|
||||
|
||||
private fun resolveToolCall(
|
||||
metadata: JsonObject,
|
||||
defaultTitle: String = "Tool"
|
||||
): ResolvedToolCall {
|
||||
val rawTitle = metadata.string("title") ?: defaultTitle
|
||||
val rawKind = metadata.string("kind")
|
||||
val rawInput = metadata["rawInput"]
|
||||
val initialToolName = normalizeToolName(rawTitle, rawKind, rawInput)
|
||||
val parsedArgs = if (initialToolName == "MCP") {
|
||||
decodeMcpArgs(rawTitle, rawInput)
|
||||
} else {
|
||||
decodeToolArgs(initialToolName, rawInput, metadata)
|
||||
}
|
||||
return ResolvedToolCall(
|
||||
rawTitle = rawTitle,
|
||||
toolName = resolveToolName(initialToolName, parsedArgs),
|
||||
args = parsedArgs
|
||||
fun decodeResult(
|
||||
toolCall: AcpToolCallSnapshot,
|
||||
rawOutput: JsonElement?
|
||||
): Any? {
|
||||
return decodeResult(
|
||||
toolName = toolCall.toolName,
|
||||
args = toolCall.args,
|
||||
status = toolCall.status ?: AcpToolCallStatus.COMPLETED,
|
||||
rawOutput = rawOutput ?: toolCall.rawOutput
|
||||
)
|
||||
}
|
||||
|
||||
private fun permissionDetails(toolCall: JsonObject): String {
|
||||
private fun resolveToolCall(
|
||||
flavor: AcpToolEventFlavor,
|
||||
context: AcpToolCallContext
|
||||
): AcpResolvedToolCall {
|
||||
val vendorNormalizers = when (flavor) {
|
||||
AcpToolEventFlavor.ZED_ADAPTER -> listOf(zedNormalizer, standardNormalizer)
|
||||
AcpToolEventFlavor.GEMINI_CLI -> listOf(geminiNormalizer, standardNormalizer)
|
||||
AcpToolEventFlavor.STANDARD -> listOf(standardNormalizer)
|
||||
}
|
||||
return vendorNormalizers
|
||||
.firstNotNullOfOrNull { it.normalize(context, support) }
|
||||
?: fallbackNormalizer.normalize(context, support)
|
||||
}
|
||||
|
||||
private fun permissionDetails(
|
||||
locations: List<ToolCallLocation>,
|
||||
rawInput: JsonElement?,
|
||||
fallback: String
|
||||
): String {
|
||||
return buildString {
|
||||
(toolCall["locations"] as? JsonArray)
|
||||
?.mapNotNull { (it as? JsonObject)?.string("path") }
|
||||
?.takeIf { it.isNotEmpty() }
|
||||
locations.map(ToolCallLocation::path)
|
||||
.takeIf { it.isNotEmpty() }
|
||||
?.let { paths ->
|
||||
appendLine("Locations:")
|
||||
paths.forEach { path -> appendLine(path) }
|
||||
paths.forEach(::appendLine)
|
||||
}
|
||||
|
||||
toolCall["rawInput"]?.let { rawInput ->
|
||||
rawInput?.let { input ->
|
||||
if (isNotBlank()) {
|
||||
appendLine()
|
||||
}
|
||||
appendLine("Input:")
|
||||
append(rawInput.toString())
|
||||
append(input.toString())
|
||||
}
|
||||
}.ifBlank { toolCall.toString() }
|
||||
}
|
||||
|
||||
private fun normalizeToolName(
|
||||
rawTitle: String,
|
||||
kind: String?,
|
||||
rawInput: JsonElement?
|
||||
): String {
|
||||
val normalizedKind = kind?.lowercase().orEmpty()
|
||||
val rawInputObject = rawInput.asJsonObjectOrNull(json)
|
||||
val actionType = (rawInputObject?.get("action") as? JsonObject)?.string("type")?.lowercase()
|
||||
val titleLower = rawTitle.lowercase()
|
||||
|
||||
return when {
|
||||
looksLikeMcpToolName(rawTitle) -> "MCP"
|
||||
rawInputObject?.get("command") != null || rawInputObject?.get("cmd") != null -> "Bash"
|
||||
normalizedKind == "execute" || normalizedKind == "terminal" || normalizedKind == "bash" -> "Bash"
|
||||
actionType == "search" -> "WebSearch"
|
||||
actionType == "open_page" || actionType == "fetch" || actionType == "open" -> "WebFetch"
|
||||
normalizedKind == "edit" -> "Edit"
|
||||
normalizedKind == "read" -> "Read"
|
||||
normalizedKind == "search" -> "IntelliJSearch"
|
||||
normalizedKind == "fetch" && (
|
||||
titleLower == "searching the web" || titleLower.startsWith("searching for:")
|
||||
) -> "WebSearch"
|
||||
normalizedKind == "fetch" || titleLower.startsWith("opening:") -> "WebFetch"
|
||||
else -> rawTitle.ifBlank { kind ?: "Tool" }
|
||||
}
|
||||
}
|
||||
|
||||
private fun resolveToolName(initialToolName: String, args: Any?): String {
|
||||
return when (args) {
|
||||
is McpTool.Args -> "MCP"
|
||||
is WriteTool.Args -> "Write"
|
||||
is EditTool.Args -> "Edit"
|
||||
is ReadTool.Args -> "Read"
|
||||
is IntelliJSearchTool.Args -> "IntelliJSearch"
|
||||
is BashTool.Args -> "Bash"
|
||||
is WebSearchTool.Args -> "WebSearch"
|
||||
is WebFetchTool.Args -> "WebFetch"
|
||||
else -> initialToolName
|
||||
}
|
||||
}
|
||||
|
||||
private fun decodeToolArgs(
|
||||
toolName: String,
|
||||
rawInput: JsonElement?,
|
||||
metadata: JsonObject? = null
|
||||
): Any? {
|
||||
val payload = rawInput.toPayloadString()
|
||||
ToolSpecs.decodeArgsOrNull(json, toolName, payload)?.let { return it }
|
||||
|
||||
val obj = rawInput.asJsonObjectOrNull(json) ?: JsonObject(emptyMap())
|
||||
return when (toolName) {
|
||||
"Edit" -> decodeEditOrWriteArgs(obj, metadata) ?: payload.ifBlank { null }
|
||||
"Write" -> decodeWriteArgs(obj, metadata) ?: payload.ifBlank { null }
|
||||
"Read" -> decodeReadArgs(obj, metadata) ?: payload.ifBlank { null }
|
||||
"IntelliJSearch" -> decodeSearchArgs(obj) ?: payload.ifBlank { null }
|
||||
"Bash" -> decodeBashArgs(obj) ?: payload.ifBlank { null }
|
||||
"WebSearch" -> decodeWebSearchArgs(obj, metadata) ?: payload.ifBlank { null }
|
||||
"WebFetch" -> decodeWebFetchArgs(obj, rawInput, metadata) ?: payload.ifBlank { null }
|
||||
else -> payload.ifBlank { null }
|
||||
}
|
||||
}
|
||||
|
||||
private fun decodeEditOrWriteArgs(obj: JsonObject, metadata: JsonObject?): Any? {
|
||||
decodeDiffContent(metadata)?.let { diff ->
|
||||
return if (diff.oldText == null) {
|
||||
WriteTool.Args(
|
||||
filePath = diff.path,
|
||||
content = diff.newText
|
||||
)
|
||||
} else {
|
||||
EditTool.Args(
|
||||
filePath = diff.path,
|
||||
oldString = diff.oldText,
|
||||
newString = diff.newText,
|
||||
shortDescription = metadata?.string("title") ?: "ACP edit",
|
||||
replaceAll = false
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
decodeWriteArgs(obj, metadata)?.let { return it }
|
||||
return decodeEditArgs(obj, metadata)
|
||||
}
|
||||
|
||||
private fun decodeEditArgs(obj: JsonObject, metadata: JsonObject? = null): EditTool.Args? {
|
||||
val filePath = obj.string("file_path", "filePath", "path") ?: return null
|
||||
val oldString = obj.string("old_string", "oldString", "old_text", "oldText") ?: return null
|
||||
val newString = obj.string("new_string", "newString", "new_text", "newText") ?: return null
|
||||
val shortDescription = obj.string("short_description", "shortDescription", "description")
|
||||
?: metadata?.string("title")
|
||||
?: "ACP edit"
|
||||
val replaceAll = obj.boolean("replace_all", "replaceAll") ?: false
|
||||
return EditTool.Args(filePath, oldString, newString, shortDescription, replaceAll)
|
||||
}
|
||||
|
||||
private fun decodeWriteArgs(
|
||||
obj: JsonObject,
|
||||
metadata: JsonObject? = null
|
||||
): WriteTool.Args? {
|
||||
val filePath = obj.string("file_path", "filePath", "path")
|
||||
?: metadata?.firstLocationPath()
|
||||
?: metadata?.titlePath()
|
||||
?: obj.firstChangePath()
|
||||
?: return null
|
||||
val content = obj.string("content", "text")
|
||||
?: obj.firstChangeContent()
|
||||
?: decodeDiffContent(metadata)?.takeIf { it.oldText == null }?.newText
|
||||
?: return null
|
||||
return WriteTool.Args(filePath, content)
|
||||
}
|
||||
|
||||
private fun decodeReadArgs(obj: JsonObject, metadata: JsonObject? = null): ReadTool.Args? {
|
||||
val filePath = obj.string("file_path", "filePath", "path")
|
||||
?: metadata?.firstLocationPath()
|
||||
?: metadata?.titlePath()
|
||||
?: return null
|
||||
return ReadTool.Args(
|
||||
filePath = filePath,
|
||||
offset = obj.int("offset", "line") ?: metadata?.int("line"),
|
||||
limit = obj.int("limit", "maxLinesCount")
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeSearchArgs(obj: JsonObject): IntelliJSearchTool.Args? {
|
||||
val pattern = obj.string("pattern", "searchText", "query", "nameKeyword") ?: return null
|
||||
return IntelliJSearchTool.Args(
|
||||
pattern = pattern,
|
||||
scope = obj.string("scope"),
|
||||
path = obj.string("path", "directoryToSearch"),
|
||||
fileType = obj.string("fileType", "fileMask"),
|
||||
context = null,
|
||||
caseSensitive = obj.boolean("caseSensitive"),
|
||||
regex = obj.boolean("regex"),
|
||||
wholeWords = null,
|
||||
outputMode = null,
|
||||
limit = obj.int("limit", "maxUsageCount", "fileCountLimit")
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeBashArgs(obj: JsonObject): BashTool.Args? {
|
||||
val command = obj.commandString() ?: return null
|
||||
return BashTool.Args(
|
||||
command = command,
|
||||
workingDirectory = obj.string("workingDirectory", "workdir", "cwd"),
|
||||
timeout = obj.int("timeout") ?: 60_000,
|
||||
description = obj.string("description", "title"),
|
||||
runInBackground = obj.boolean("run_in_background", "runInBackground")
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeWebSearchArgs(
|
||||
obj: JsonObject,
|
||||
metadata: JsonObject? = null
|
||||
): WebSearchTool.Args? {
|
||||
val action = obj["action"] as? JsonObject
|
||||
val query = obj.string("query", "q")
|
||||
?: action?.string("query")
|
||||
?: metadata?.string("query", "q")
|
||||
?: return null
|
||||
return WebSearchTool.Args(query = query)
|
||||
}
|
||||
|
||||
private fun decodeWebFetchArgs(
|
||||
obj: JsonObject,
|
||||
rawInput: JsonElement?,
|
||||
metadata: JsonObject? = null
|
||||
): WebFetchTool.Args? {
|
||||
val action = obj["action"] as? JsonObject
|
||||
val payload = rawInput.toPayloadString()
|
||||
val url = obj.string("url", "uri", "href", "link")
|
||||
?: action?.string("url", "uri", "href", "link")
|
||||
?: metadata?.string("url", "uri")
|
||||
?: extractFirstUrl(payload)
|
||||
?: metadata?.string("title")?.let(::extractFirstUrl)
|
||||
?: return null
|
||||
|
||||
return WebFetchTool.Args(
|
||||
url = url,
|
||||
selector = obj.string("selector", "css_selector", "cssSelector")
|
||||
?: action?.string("selector", "css_selector", "cssSelector"),
|
||||
timeoutMs = obj.int("timeout_ms", "timeoutMs", "timeout")
|
||||
?: action?.int("timeout_ms", "timeoutMs", "timeout")
|
||||
?: 10_000,
|
||||
offset = obj.int("offset", "start_line", "startLine")
|
||||
?: action?.int("offset", "start_line", "startLine"),
|
||||
limit = obj.int("limit", "max_lines", "maxLines", "count")
|
||||
?: action?.int("limit", "max_lines", "maxLines", "count")
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeMcpArgs(rawTitle: String, rawInput: JsonElement?): McpTool.Args {
|
||||
val obj = rawInput.asJsonObjectOrNull(json) ?: JsonObject(emptyMap())
|
||||
val callName = rawTitle.ifBlank { obj.string("tool_name", "toolName") ?: "unknown" }
|
||||
val slashIndex = callName.indexOf('/')
|
||||
val serverName = if (slashIndex > 0) callName.substring(0, slashIndex) else obj.string(
|
||||
"server_name",
|
||||
"serverName"
|
||||
)
|
||||
val toolName = if (slashIndex > 0 && slashIndex < callName.length - 1) {
|
||||
callName.substring(slashIndex + 1)
|
||||
} else {
|
||||
obj.string("tool_name", "toolName") ?: callName.ifBlank { "unknown" }
|
||||
}
|
||||
val arguments = (obj["arguments"] as? JsonObject)?.toMap() ?: obj.toMap()
|
||||
return McpTool.Args(
|
||||
toolName = toolName,
|
||||
serverId = obj.string("server_id", "serverId"),
|
||||
serverName = serverName,
|
||||
arguments = arguments
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeDiffContent(metadata: JsonObject?): DiffContent? {
|
||||
val diff = (metadata?.get("content") as? JsonArray)
|
||||
?.firstOrNull { entry ->
|
||||
(entry as? JsonObject)?.string("type") == "diff"
|
||||
} as? JsonObject
|
||||
?: return null
|
||||
val path = diff.string("path") ?: return null
|
||||
val newText = diff.string("newText", "new_text") ?: return null
|
||||
val oldText = diff.string("oldText", "old_text")
|
||||
return DiffContent(path, oldText, newText)
|
||||
}
|
||||
|
||||
private fun looksLikeMcpToolName(rawTitle: String): Boolean {
|
||||
val candidate = rawTitle.trim()
|
||||
if (candidate.isBlank() || ' ' in candidate || candidate.startsWith("/")) {
|
||||
return false
|
||||
}
|
||||
val slashIndex = candidate.indexOf('/')
|
||||
return slashIndex > 0 && slashIndex < candidate.length - 1
|
||||
}
|
||||
|
||||
private fun extractFirstUrl(text: String): String? {
|
||||
return URL_REGEX.find(text)?.value
|
||||
}
|
||||
|
||||
private companion object {
|
||||
val URL_REGEX = Regex("""https?://[^\s"'<>]+""")
|
||||
}.ifBlank { fallback }
|
||||
}
|
||||
}
|
||||
|
||||
private val ToolKind.wireValue: String
|
||||
get() = when (this) {
|
||||
ToolKind.READ -> "read"
|
||||
ToolKind.EDIT -> "edit"
|
||||
ToolKind.EXECUTE -> "execute"
|
||||
ToolKind.SEARCH -> "search"
|
||||
ToolKind.FETCH -> "fetch"
|
||||
else -> name.lowercase()
|
||||
}
|
||||
|
||||
private fun ToolCallStatus.toAcpToolCallStatus(): AcpToolCallStatus {
|
||||
return when (this) {
|
||||
ToolCallStatus.PENDING,
|
||||
ToolCallStatus.IN_PROGRESS -> AcpToolCallStatus.IN_PROGRESS
|
||||
|
||||
ToolCallStatus.COMPLETED -> AcpToolCallStatus.COMPLETED
|
||||
ToolCallStatus.FAILED -> AcpToolCallStatus.FAILED
|
||||
}
|
||||
}
|
||||
|
||||
private fun String.toAcpToolKind(): ToolKind? {
|
||||
return when (this.lowercase()) {
|
||||
"read" -> ToolKind.READ
|
||||
"edit" -> ToolKind.EDIT
|
||||
"delete" -> ToolKind.DELETE
|
||||
"move" -> ToolKind.MOVE
|
||||
"search" -> ToolKind.SEARCH
|
||||
"execute", "terminal", "bash" -> ToolKind.EXECUTE
|
||||
"think" -> ToolKind.THINK
|
||||
"fetch" -> ToolKind.FETCH
|
||||
"switch_mode" -> ToolKind.SWITCH_MODE
|
||||
"other" -> ToolKind.OTHER
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
|
|
|||
606
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecodingSupport.kt
vendored
Normal file
606
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallDecodingSupport.kt
vendored
Normal file
|
|
@ -0,0 +1,606 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import com.agentclientprotocol.model.ToolCallContent
|
||||
import com.agentclientprotocol.model.ToolCallLocation
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpBashPreviewArgs
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpSearchPreviewArgs
|
||||
import ee.carlrobert.codegpt.agent.tools.BashTool
|
||||
import ee.carlrobert.codegpt.agent.tools.EditTool
|
||||
import ee.carlrobert.codegpt.agent.tools.IntelliJSearchTool
|
||||
import ee.carlrobert.codegpt.agent.tools.McpTool
|
||||
import ee.carlrobert.codegpt.agent.tools.ReadTool
|
||||
import ee.carlrobert.codegpt.agent.tools.WebFetchTool
|
||||
import ee.carlrobert.codegpt.agent.tools.WebSearchTool
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
|
||||
internal data class AcpToolCallContext(
|
||||
val toolCallId: String,
|
||||
val rawTitle: String,
|
||||
val rawKind: String?,
|
||||
val rawInput: JsonElement?,
|
||||
val locations: List<ToolCallLocation>,
|
||||
val content: List<ToolCallContent>,
|
||||
val defaultTitle: String = "Tool"
|
||||
)
|
||||
|
||||
internal data class AcpResolvedToolCall(
|
||||
val rawTitle: String,
|
||||
val toolName: String,
|
||||
val typedArgs: AcpToolCallArgs? = null
|
||||
)
|
||||
|
||||
internal interface AcpToolCallNormalizer {
|
||||
fun normalize(context: AcpToolCallContext, support: AcpToolCallDecodingSupport): AcpResolvedToolCall?
|
||||
}
|
||||
|
||||
internal data class DiffContent(
|
||||
val path: String,
|
||||
val oldText: String?,
|
||||
val newText: String
|
||||
)
|
||||
|
||||
@Serializable
|
||||
internal data class AcpMcpToolCallPayload(
|
||||
val server: String,
|
||||
val tool: String,
|
||||
val arguments: Map<String, JsonElement> = emptyMap(),
|
||||
@SerialName("server_id") val serverIdSnake: String? = null,
|
||||
val serverId: String? = null,
|
||||
@SerialName("server_name") val serverNameSnake: String? = null,
|
||||
val serverName: String? = null
|
||||
) {
|
||||
val resolvedServerId: String?
|
||||
get() = serverId ?: serverIdSnake
|
||||
|
||||
val resolvedServerName: String
|
||||
get() = serverName ?: serverNameSnake ?: server
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpQueryPayload(
|
||||
val query: String? = null,
|
||||
val q: String? = null
|
||||
) {
|
||||
val resolvedQuery: String?
|
||||
get() = query ?: q
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpActionEnvelope(
|
||||
val action: AcpActionPayload
|
||||
)
|
||||
|
||||
@Serializable
|
||||
internal data class AcpActionPayload(
|
||||
val type: String,
|
||||
val query: String? = null,
|
||||
val q: String? = null,
|
||||
val url: String? = null,
|
||||
val uri: String? = null,
|
||||
val href: String? = null,
|
||||
val link: String? = null,
|
||||
val selector: String? = null,
|
||||
@SerialName("css_selector") val cssSelectorSnake: String? = null,
|
||||
val cssSelector: String? = null,
|
||||
@SerialName("timeout_ms") val timeoutMsSnake: Int? = null,
|
||||
val timeoutMs: Int? = null,
|
||||
val timeout: Int? = null,
|
||||
val offset: Int? = null,
|
||||
@SerialName("start_line") val startLineSnake: Int? = null,
|
||||
val startLine: Int? = null,
|
||||
val limit: Int? = null,
|
||||
@SerialName("max_lines") val maxLinesSnake: Int? = null,
|
||||
val maxLines: Int? = null,
|
||||
val count: Int? = null
|
||||
) {
|
||||
val resolvedQuery: String?
|
||||
get() = query ?: q
|
||||
|
||||
val resolvedUrl: String?
|
||||
get() = url ?: uri ?: href ?: link
|
||||
|
||||
val resolvedSelector: String?
|
||||
get() = selector ?: cssSelector ?: cssSelectorSnake
|
||||
|
||||
val resolvedTimeoutMs: Int
|
||||
get() = timeoutMs ?: timeoutMsSnake ?: timeout ?: 10_000
|
||||
|
||||
val resolvedOffset: Int?
|
||||
get() = offset ?: startLine ?: startLineSnake
|
||||
|
||||
val resolvedLimit: Int?
|
||||
get() = limit ?: maxLines ?: maxLinesSnake ?: count
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpReadPayload(
|
||||
@SerialName("file_path") val filePathSnake: String? = null,
|
||||
val filePath: String? = null,
|
||||
val path: String? = null,
|
||||
val offset: Int? = null,
|
||||
val line: Int? = null,
|
||||
val limit: Int? = null,
|
||||
@SerialName("maxLinesCount") val maxLinesCount: Int? = null
|
||||
) {
|
||||
val resolvedPath: String?
|
||||
get() = filePath ?: filePathSnake ?: path
|
||||
|
||||
val resolvedOffset: Int?
|
||||
get() = offset ?: line
|
||||
|
||||
val resolvedLimit: Int?
|
||||
get() = limit ?: maxLinesCount
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpWritePayload(
|
||||
@SerialName("file_path") val filePathSnake: String? = null,
|
||||
val filePath: String? = null,
|
||||
val path: String? = null,
|
||||
val content: String? = null,
|
||||
val text: String? = null
|
||||
) {
|
||||
val resolvedPath: String?
|
||||
get() = filePath ?: filePathSnake ?: path
|
||||
|
||||
val resolvedContent: String?
|
||||
get() = content ?: text
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpEditPayload(
|
||||
@SerialName("file_path") val filePathSnake: String? = null,
|
||||
val filePath: String? = null,
|
||||
val path: String? = null,
|
||||
@SerialName("old_string") val oldStringSnake: String? = null,
|
||||
val oldString: String? = null,
|
||||
@SerialName("new_string") val newStringSnake: String? = null,
|
||||
val newString: String? = null,
|
||||
@SerialName("short_description") val shortDescriptionSnake: String? = null,
|
||||
val shortDescription: String? = null,
|
||||
@SerialName("replace_all") val replaceAllSnake: Boolean? = null,
|
||||
val replaceAll: Boolean? = null
|
||||
) {
|
||||
val resolvedPath: String?
|
||||
get() = filePath ?: filePathSnake ?: path
|
||||
|
||||
val resolvedOldString: String?
|
||||
get() = oldString ?: oldStringSnake
|
||||
|
||||
val resolvedNewString: String?
|
||||
get() = newString ?: newStringSnake
|
||||
|
||||
val resolvedDescription: String
|
||||
get() = shortDescription ?: shortDescriptionSnake ?: "ACP edit"
|
||||
|
||||
val resolvedReplaceAll: Boolean
|
||||
get() = replaceAll ?: replaceAllSnake ?: false
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpChangeSetPayload(
|
||||
val changes: Map<String, AcpChangePayload> = emptyMap()
|
||||
) {
|
||||
fun firstWriteArgs(): AcpToolCallArgs? {
|
||||
val entry = changes.entries.firstOrNull() ?: return null
|
||||
val content = entry.value.content ?: return null
|
||||
return AcpToolCallArgs.Write(WriteTool.Args(entry.key, content))
|
||||
}
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpChangePayload(
|
||||
val content: String? = null
|
||||
)
|
||||
|
||||
@Serializable
|
||||
internal data class AcpSearchPayload(
|
||||
val pattern: String? = null,
|
||||
@SerialName("searchText") val searchText: String? = null,
|
||||
val query: String? = null,
|
||||
@SerialName("nameKeyword") val nameKeyword: String? = null,
|
||||
val scope: String? = null,
|
||||
val path: String? = null,
|
||||
@SerialName("directoryToSearch") val directoryToSearch: String? = null,
|
||||
@SerialName("fileType") val fileType: String? = null,
|
||||
@SerialName("fileMask") val fileMask: String? = null,
|
||||
val caseSensitive: Boolean? = null,
|
||||
val regex: Boolean? = null,
|
||||
val limit: Int? = null,
|
||||
@SerialName("maxUsageCount") val maxUsageCount: Int? = null,
|
||||
@SerialName("fileCountLimit") val fileCountLimit: Int? = null
|
||||
) {
|
||||
val resolvedPattern: String?
|
||||
get() = pattern ?: searchText ?: query ?: nameKeyword
|
||||
|
||||
val resolvedPath: String?
|
||||
get() = path ?: directoryToSearch
|
||||
|
||||
val resolvedFileType: String?
|
||||
get() = fileType ?: fileMask
|
||||
|
||||
val resolvedLimit: Int?
|
||||
get() = limit ?: maxUsageCount ?: fileCountLimit
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpCommandStringPayload(
|
||||
val command: String,
|
||||
val description: String? = null,
|
||||
val workingDirectory: String? = null,
|
||||
val workdir: String? = null,
|
||||
val cwd: String? = null,
|
||||
val timeout: Int? = null,
|
||||
@SerialName("run_in_background") val runInBackgroundSnake: Boolean? = null,
|
||||
val runInBackground: Boolean? = null
|
||||
) {
|
||||
val resolvedWorkingDirectory: String?
|
||||
get() = workingDirectory ?: workdir ?: cwd
|
||||
|
||||
val resolvedRunInBackground: Boolean
|
||||
get() = runInBackground ?: runInBackgroundSnake ?: false
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpCommandArrayPayload(
|
||||
val command: List<String>,
|
||||
val description: String? = null,
|
||||
val workingDirectory: String? = null,
|
||||
val workdir: String? = null,
|
||||
val cwd: String? = null,
|
||||
val timeout: Int? = null,
|
||||
@SerialName("run_in_background") val runInBackgroundSnake: Boolean? = null,
|
||||
val runInBackground: Boolean? = null
|
||||
) {
|
||||
val resolvedCommand: String
|
||||
get() = command.joinToString(" ")
|
||||
|
||||
val resolvedWorkingDirectory: String?
|
||||
get() = workingDirectory ?: workdir ?: cwd
|
||||
|
||||
val resolvedRunInBackground: Boolean
|
||||
get() = runInBackground ?: runInBackgroundSnake ?: false
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpParsedCommandEnvelope(
|
||||
@SerialName("parsed_cmd") val parsedCommands: List<AcpParsedCommandPayload> = emptyList()
|
||||
) {
|
||||
val firstParsedCommand: AcpParsedCommandPayload?
|
||||
get() = parsedCommands.firstOrNull()
|
||||
}
|
||||
|
||||
@Serializable
|
||||
internal data class AcpParsedCommandPayload(
|
||||
val type: String,
|
||||
val cmd: String? = null,
|
||||
val name: String? = null,
|
||||
val path: String? = null,
|
||||
val query: String? = null,
|
||||
val offset: Int? = null,
|
||||
val line: Int? = null
|
||||
) {
|
||||
val resolvedOffset: Int?
|
||||
get() = offset ?: line
|
||||
}
|
||||
|
||||
internal class AcpToolCallDecodingSupport(
|
||||
private val json: Json
|
||||
) {
|
||||
inline fun <reified T> decode(rawInput: JsonElement?): T? = rawInput.decodeOrNull(json)
|
||||
|
||||
fun decodeMcpArgs(payload: AcpMcpToolCallPayload): McpTool.Args {
|
||||
return McpTool.Args(
|
||||
toolName = payload.tool,
|
||||
serverId = payload.resolvedServerId,
|
||||
serverName = payload.resolvedServerName,
|
||||
arguments = payload.arguments
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeMcpPayload(rawInput: JsonElement?): AcpMcpToolCallPayload? {
|
||||
return decode(rawInput)
|
||||
}
|
||||
|
||||
fun decodeParsedCommand(rawInput: JsonElement?): AcpParsedCommandPayload? {
|
||||
return decode<AcpParsedCommandEnvelope>(rawInput)?.firstParsedCommand
|
||||
}
|
||||
|
||||
fun decodeDiffContent(content: List<ToolCallContent>): DiffContent? {
|
||||
val diff = content.filterIsInstance<ToolCallContent.Diff>().firstOrNull() ?: return null
|
||||
return DiffContent(diff.path, diff.oldText, diff.newText)
|
||||
}
|
||||
|
||||
fun decodeEditOrWriteArgs(
|
||||
rawInput: JsonElement?,
|
||||
locations: List<ToolCallLocation>,
|
||||
content: List<ToolCallContent>
|
||||
): AcpToolCallArgs? {
|
||||
decodeDiffContent(content)?.let { diff ->
|
||||
return if (diff.oldText == null) {
|
||||
AcpToolCallArgs.Write(WriteTool.Args(diff.path, diff.newText))
|
||||
} else {
|
||||
AcpToolCallArgs.Edit(
|
||||
EditTool.Args(
|
||||
filePath = diff.path,
|
||||
oldString = diff.oldText,
|
||||
newString = diff.newText,
|
||||
shortDescription = "ACP edit",
|
||||
replaceAll = false
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
decode<AcpChangeSetPayload>(rawInput)?.firstWriteArgs()?.let { return it }
|
||||
|
||||
decode<AcpWritePayload>(rawInput)?.let { payload ->
|
||||
val path = payload.resolvedPath ?: locations.firstOrNull()?.path ?: return@let
|
||||
val text = payload.resolvedContent ?: return@let
|
||||
return AcpToolCallArgs.Write(WriteTool.Args(path, text))
|
||||
}
|
||||
|
||||
decode<AcpEditPayload>(rawInput)?.let { payload ->
|
||||
val path = payload.resolvedPath ?: return@let
|
||||
val oldString = payload.resolvedOldString ?: return@let
|
||||
val newString = payload.resolvedNewString ?: return@let
|
||||
return AcpToolCallArgs.Edit(
|
||||
EditTool.Args(
|
||||
filePath = path,
|
||||
oldString = oldString,
|
||||
newString = newString,
|
||||
shortDescription = payload.resolvedDescription,
|
||||
replaceAll = payload.resolvedReplaceAll
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
fun decodeReadArgs(
|
||||
rawInput: JsonElement?,
|
||||
locations: List<ToolCallLocation>
|
||||
): AcpToolCallArgs? {
|
||||
decode<AcpReadPayload>(rawInput)?.let { payload ->
|
||||
val path = payload.resolvedPath ?: locations.firstOrNull()?.path ?: return@let
|
||||
return AcpToolCallArgs.Read(
|
||||
ReadTool.Args(
|
||||
filePath = path,
|
||||
offset = payload.resolvedOffset ?: locations.firstOrNull()?.line?.toInt(),
|
||||
limit = payload.resolvedLimit
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
return locations.firstOrNull()?.path?.let { path ->
|
||||
AcpToolCallArgs.Read(
|
||||
ReadTool.Args(
|
||||
filePath = path,
|
||||
offset = locations.firstOrNull()?.line?.toInt(),
|
||||
limit = null
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun decodeReadArgsFromParsedCommand(
|
||||
rawInput: JsonElement?,
|
||||
locations: List<ToolCallLocation>
|
||||
): AcpToolCallArgs? {
|
||||
val parsed = decode<AcpParsedCommandEnvelope>(rawInput)?.firstParsedCommand ?: return null
|
||||
if (!parsed.type.equals("read", ignoreCase = true)) {
|
||||
return null
|
||||
}
|
||||
val path = parsed.path ?: locations.firstOrNull()?.path ?: return null
|
||||
return AcpToolCallArgs.Read(
|
||||
ReadTool.Args(
|
||||
filePath = path,
|
||||
offset = parsed.resolvedOffset ?: locations.firstOrNull()?.line?.toInt(),
|
||||
limit = null
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeSearchArgs(rawInput: JsonElement?): AcpToolCallArgs? {
|
||||
val payload = decode<AcpSearchPayload>(rawInput) ?: return null
|
||||
val pattern = payload.resolvedPattern ?: return null
|
||||
return AcpToolCallArgs.IntelliJSearch(
|
||||
IntelliJSearchTool.Args(
|
||||
pattern = pattern,
|
||||
scope = payload.scope,
|
||||
path = payload.resolvedPath,
|
||||
fileType = payload.resolvedFileType,
|
||||
context = null,
|
||||
caseSensitive = payload.caseSensitive,
|
||||
regex = payload.regex,
|
||||
wholeWords = null,
|
||||
outputMode = null,
|
||||
limit = payload.resolvedLimit
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeSearchArgsOrPreview(
|
||||
rawTitle: String,
|
||||
rawInput: JsonElement?,
|
||||
locations: List<ToolCallLocation>
|
||||
): AcpToolCallArgs {
|
||||
decodeSearchArgsFromParsedCommand(rawInput)?.let { return it }
|
||||
decodeSearchArgs(rawInput)?.let { return it }
|
||||
|
||||
val normalizedTitle = rawTitle.trim().ifBlank { "Search" }
|
||||
val path = locations.firstOrNull()?.path
|
||||
val pattern = extractGeminiSearchPattern(normalizedTitle)
|
||||
|
||||
return AcpToolCallArgs.SearchPreview(
|
||||
AcpSearchPreviewArgs(
|
||||
title = normalizedTitle,
|
||||
path = path,
|
||||
pattern = pattern
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeReadArgsFromParsedCommandOrFallback(
|
||||
rawInput: JsonElement?,
|
||||
locations: List<ToolCallLocation>
|
||||
): AcpToolCallArgs? {
|
||||
return decodeReadArgsFromParsedCommand(rawInput, locations)
|
||||
?: decodeReadArgs(rawInput, locations)
|
||||
}
|
||||
|
||||
fun decodeSearchArgsFromParsedCommand(rawInput: JsonElement?): AcpToolCallArgs? {
|
||||
val parsed = decode<AcpParsedCommandEnvelope>(rawInput)?.firstParsedCommand ?: return null
|
||||
if (!parsed.type.equals("search", ignoreCase = true)) {
|
||||
return null
|
||||
}
|
||||
val pattern = parsed.query ?: return null
|
||||
return AcpToolCallArgs.IntelliJSearch(
|
||||
IntelliJSearchTool.Args(
|
||||
pattern = pattern,
|
||||
scope = null,
|
||||
path = parsed.path,
|
||||
fileType = null,
|
||||
context = null,
|
||||
caseSensitive = null,
|
||||
regex = null,
|
||||
wholeWords = null,
|
||||
outputMode = null,
|
||||
limit = null
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeBashArgs(rawInput: JsonElement?): AcpToolCallArgs? {
|
||||
decode<AcpCommandStringPayload>(rawInput)?.let { payload ->
|
||||
return AcpToolCallArgs.Bash(
|
||||
BashTool.Args(
|
||||
command = payload.command,
|
||||
workingDirectory = payload.resolvedWorkingDirectory,
|
||||
timeout = payload.timeout ?: 60_000,
|
||||
description = payload.description,
|
||||
runInBackground = payload.resolvedRunInBackground
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
decode<AcpCommandArrayPayload>(rawInput)?.let { payload ->
|
||||
return AcpToolCallArgs.Bash(
|
||||
BashTool.Args(
|
||||
command = payload.resolvedCommand,
|
||||
workingDirectory = payload.resolvedWorkingDirectory,
|
||||
timeout = payload.timeout ?: 60_000,
|
||||
description = payload.description,
|
||||
runInBackground = payload.resolvedRunInBackground
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
fun decodeBashArgsOrPreview(
|
||||
rawTitle: String,
|
||||
rawInput: JsonElement?,
|
||||
defaultTitle: String = "Run shell command"
|
||||
): AcpToolCallArgs {
|
||||
decodeBashArgs(rawInput)?.let { return it }
|
||||
|
||||
val normalizedTitle = rawTitle.trim().ifBlank { defaultTitle }
|
||||
return AcpToolCallArgs.BashPreview(
|
||||
AcpBashPreviewArgs(
|
||||
title = normalizedTitle,
|
||||
command = extractGeminiCommand(normalizedTitle)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
fun decodeWebSearchArgs(rawInput: JsonElement?): AcpToolCallArgs? {
|
||||
decode<AcpActionEnvelope>(rawInput)?.action?.let { action ->
|
||||
if (action.type.equals("search", ignoreCase = true)) {
|
||||
action.resolvedQuery?.let { return AcpToolCallArgs.WebSearch(WebSearchTool.Args(it)) }
|
||||
}
|
||||
}
|
||||
|
||||
decode<AcpQueryPayload>(rawInput)?.resolvedQuery?.let { query ->
|
||||
return AcpToolCallArgs.WebSearch(WebSearchTool.Args(query))
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
fun decodeWebFetchArgs(rawInput: JsonElement?): AcpToolCallArgs? {
|
||||
decode<AcpActionEnvelope>(rawInput)?.action?.let { action ->
|
||||
if (action.type.lowercase() in setOf("open_page", "fetch", "open")) {
|
||||
val url = action.resolvedUrl ?: return@let
|
||||
return AcpToolCallArgs.WebFetch(
|
||||
WebFetchTool.Args(
|
||||
url = url,
|
||||
selector = action.resolvedSelector,
|
||||
timeoutMs = action.resolvedTimeoutMs,
|
||||
offset = action.resolvedOffset,
|
||||
limit = action.resolvedLimit
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
fun resolveWebSearchCall(rawTitle: String, rawInput: JsonElement?): AcpResolvedToolCall? {
|
||||
return decodeWebSearchArgs(rawInput)?.let { args ->
|
||||
fixedCall(rawTitle, "WebSearch", "WebSearch", args)
|
||||
}
|
||||
}
|
||||
|
||||
fun resolveWebFetchCall(rawTitle: String, rawInput: JsonElement?): AcpResolvedToolCall? {
|
||||
return decodeWebFetchArgs(rawInput)?.let { args ->
|
||||
fixedCall(rawTitle, "WebFetch", "WebFetch", args)
|
||||
}
|
||||
}
|
||||
|
||||
fun resolveWebCall(rawTitle: String, rawInput: JsonElement?): AcpResolvedToolCall? {
|
||||
return resolveWebSearchCall(rawTitle, rawInput)
|
||||
?: resolveWebFetchCall(rawTitle, rawInput)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun fixedCall(
|
||||
rawTitle: String,
|
||||
defaultTitle: String,
|
||||
toolName: String,
|
||||
typedArgs: AcpToolCallArgs? = null
|
||||
): AcpResolvedToolCall {
|
||||
return AcpResolvedToolCall(
|
||||
rawTitle = rawTitle.ifBlank { defaultTitle },
|
||||
toolName = toolName,
|
||||
typedArgs = typedArgs
|
||||
)
|
||||
}
|
||||
|
||||
internal fun editOrWriteCall(rawTitle: String, args: AcpToolCallArgs?): AcpResolvedToolCall? {
|
||||
return when (args) {
|
||||
is AcpToolCallArgs.Write -> fixedCall(rawTitle, "Write", "Write", args)
|
||||
is AcpToolCallArgs.Edit -> fixedCall(rawTitle, "Edit", "Edit", args)
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
private fun extractGeminiSearchPattern(title: String): String? {
|
||||
val match = Regex("(?i)^(?:search|grep)(?:\\s+for)?\\s*:?\\s+(.+)$").find(title) ?: return null
|
||||
return match.groupValues.getOrNull(1)?.trim()?.takeIf { it.isNotBlank() }
|
||||
}
|
||||
|
||||
private fun extractGeminiCommand(title: String): String? {
|
||||
val match = Regex(
|
||||
"(?i)^(?:run(?:\\s+shell)?\\s+command|execute(?:\\s+shell)?\\s+command)\\s*:?\\s+(.+)$"
|
||||
).find(title) ?: return null
|
||||
return match.groupValues.getOrNull(1)?.trim()?.takeIf { it.isNotBlank() }
|
||||
}
|
||||
273
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallNormalizers.kt
vendored
Normal file
273
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallNormalizers.kt
vendored
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpToolCallArgs
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpSearchPreviewArgs
|
||||
|
||||
internal class ZedAdapterToolCallNormalizer : AcpToolCallNormalizer {
|
||||
override fun normalize(
|
||||
context: AcpToolCallContext,
|
||||
support: AcpToolCallDecodingSupport
|
||||
): AcpResolvedToolCall? {
|
||||
support.decodeMcpPayload(context.rawInput)?.let { payload ->
|
||||
return fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "MCP",
|
||||
toolName = "MCP",
|
||||
typedArgs = AcpToolCallArgs.Mcp(support.decodeMcpArgs(payload))
|
||||
)
|
||||
}
|
||||
|
||||
if (context.rawKind.equals("fetch", ignoreCase = true)) {
|
||||
support.resolveWebSearchCall(context.rawTitle, context.rawInput)?.let { return it }
|
||||
}
|
||||
|
||||
support.decodeParsedCommand(context.rawInput)?.let { parsed ->
|
||||
return when {
|
||||
parsed.type.equals("read", ignoreCase = true) ->
|
||||
fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Read",
|
||||
toolName = "Read",
|
||||
typedArgs = support.decodeReadArgsFromParsedCommandOrFallback(
|
||||
context.rawInput,
|
||||
context.locations
|
||||
)
|
||||
)
|
||||
|
||||
parsed.type.equals("search", ignoreCase = true) ->
|
||||
fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Search",
|
||||
toolName = "IntelliJSearch",
|
||||
typedArgs = support.decodeSearchArgsOrPreview(
|
||||
context.rawTitle,
|
||||
context.rawInput,
|
||||
context.locations
|
||||
)
|
||||
)
|
||||
|
||||
parsed.type.equals("list_files", ignoreCase = true) ->
|
||||
fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Bash",
|
||||
toolName = "Bash",
|
||||
typedArgs = support.decodeBashArgsOrPreview(
|
||||
context.rawTitle,
|
||||
context.rawInput,
|
||||
defaultTitle = "List files"
|
||||
)
|
||||
)
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
support.decodeBashArgs(context.rawInput)?.let { args ->
|
||||
return fixedCall(context.rawTitle, "Bash", "Bash", args)
|
||||
}
|
||||
|
||||
if (context.rawKind.equals("execute", ignoreCase = true)) {
|
||||
return fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Bash",
|
||||
toolName = "Bash",
|
||||
typedArgs = support.decodeBashArgsOrPreview(
|
||||
context.rawTitle,
|
||||
context.rawInput,
|
||||
defaultTitle = "Bash"
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
support.resolveWebCall(context.rawTitle, context.rawInput)?.let { return it }
|
||||
|
||||
return editOrWriteCall(
|
||||
rawTitle = context.rawTitle,
|
||||
args = support.decodeEditOrWriteArgs(context.rawInput, context.locations, context.content)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GeminiCliToolCallNormalizer : AcpToolCallNormalizer {
|
||||
override fun normalize(
|
||||
context: AcpToolCallContext,
|
||||
support: AcpToolCallDecodingSupport
|
||||
): AcpResolvedToolCall? {
|
||||
return when {
|
||||
context.toolCallId.startsWith("google_web_search-") ->
|
||||
fixedCall(
|
||||
context.rawTitle,
|
||||
"Google web search",
|
||||
"WebSearch",
|
||||
support.decodeWebSearchArgs(context.rawInput)
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("grep_search-") ->
|
||||
fixedCall(
|
||||
context.rawTitle,
|
||||
"Search",
|
||||
"IntelliJSearch",
|
||||
support.decodeSearchArgsOrPreview(
|
||||
context.rawTitle.ifBlank { "Search" },
|
||||
context.rawInput,
|
||||
context.locations
|
||||
)
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("read_file-") ->
|
||||
fixedCall(
|
||||
context.rawTitle,
|
||||
"Read file",
|
||||
"Read",
|
||||
support.decodeReadArgs(context.rawInput, context.locations)
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("list_directory-") ->
|
||||
fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "List directory",
|
||||
toolName = "ListDirectory",
|
||||
typedArgs = AcpToolCallArgs.SearchPreview(
|
||||
AcpSearchPreviewArgs(
|
||||
title = "List directory",
|
||||
path = context.rawTitle.trim().ifBlank { null }
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("glob-") ->
|
||||
fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Find files",
|
||||
toolName = "Glob",
|
||||
typedArgs = AcpToolCallArgs.SearchPreview(
|
||||
AcpSearchPreviewArgs(
|
||||
title = "Find files",
|
||||
pattern = context.rawTitle.trim().trim('\'').ifBlank { null }
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("write_file-") ->
|
||||
editOrWriteCall(
|
||||
rawTitle = context.rawTitle,
|
||||
args = support.decodeEditOrWriteArgs(
|
||||
context.rawInput,
|
||||
context.locations,
|
||||
context.content
|
||||
)
|
||||
) ?: fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Write file",
|
||||
toolName = "Write"
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("replace-") ->
|
||||
editOrWriteCall(
|
||||
rawTitle = context.rawTitle,
|
||||
args = support.decodeEditOrWriteArgs(
|
||||
context.rawInput,
|
||||
context.locations,
|
||||
context.content
|
||||
)
|
||||
) ?: fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Replace text",
|
||||
toolName = "Edit"
|
||||
)
|
||||
|
||||
context.toolCallId.startsWith("run_shell_command-") ->
|
||||
fixedCall(
|
||||
context.rawTitle,
|
||||
"Run shell command",
|
||||
"Bash",
|
||||
support.decodeBashArgsOrPreview(
|
||||
context.rawTitle,
|
||||
context.rawInput,
|
||||
defaultTitle = "Run shell command"
|
||||
)
|
||||
)
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class StandardSemanticToolCallNormalizer : AcpToolCallNormalizer {
|
||||
override fun normalize(
|
||||
context: AcpToolCallContext,
|
||||
support: AcpToolCallDecodingSupport
|
||||
): AcpResolvedToolCall? {
|
||||
support.decodeMcpPayload(context.rawInput)?.let { payload ->
|
||||
return fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "MCP",
|
||||
toolName = "MCP",
|
||||
typedArgs = AcpToolCallArgs.Mcp(support.decodeMcpArgs(payload))
|
||||
)
|
||||
}
|
||||
|
||||
support.resolveWebCall(context.rawTitle, context.rawInput)?.let { return it }
|
||||
|
||||
return when (context.rawKind?.lowercase().orEmpty()) {
|
||||
"read" -> fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Read",
|
||||
toolName = "Read",
|
||||
typedArgs = support.decodeReadArgs(context.rawInput, context.locations)
|
||||
)
|
||||
|
||||
"edit" -> editOrWriteCall(
|
||||
rawTitle = context.rawTitle,
|
||||
args = support.decodeEditOrWriteArgs(context.rawInput, context.locations, context.content)
|
||||
)
|
||||
|
||||
"execute", "terminal", "bash" -> fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Bash",
|
||||
toolName = "Bash",
|
||||
typedArgs = support.decodeBashArgsOrPreview(
|
||||
context.rawTitle,
|
||||
context.rawInput,
|
||||
defaultTitle = "Bash"
|
||||
)
|
||||
)
|
||||
|
||||
"search" -> fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "Search",
|
||||
toolName = "IntelliJSearch",
|
||||
typedArgs = support.decodeSearchArgsOrPreview(
|
||||
context.rawTitle,
|
||||
context.rawInput,
|
||||
context.locations
|
||||
)
|
||||
)
|
||||
|
||||
"fetch" -> fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = "WebFetch",
|
||||
toolName = "WebFetch",
|
||||
typedArgs = support.decodeWebFetchArgs(context.rawInput)
|
||||
)
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class FallbackToolCallNormalizer : AcpToolCallNormalizer {
|
||||
override fun normalize(
|
||||
context: AcpToolCallContext,
|
||||
support: AcpToolCallDecodingSupport
|
||||
): AcpResolvedToolCall {
|
||||
val fallbackName = context.rawTitle.ifBlank {
|
||||
context.rawKind?.replaceFirstChar { it.uppercase() } ?: context.defaultTitle
|
||||
}
|
||||
return fixedCall(
|
||||
rawTitle = context.rawTitle,
|
||||
defaultTitle = fallbackName,
|
||||
toolName = fallbackName
|
||||
)
|
||||
}
|
||||
}
|
||||
11
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallStatus.kt
vendored
Normal file
11
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolCallStatus.kt
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
internal enum class AcpToolCallStatus(val wireValue: String) {
|
||||
IN_PROGRESS("in_progress"),
|
||||
COMPLETED("completed"),
|
||||
FAILED("failed"),
|
||||
CANCELLED("cancelled");
|
||||
|
||||
val isTerminal: Boolean
|
||||
get() = this != IN_PROGRESS
|
||||
}
|
||||
7
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolEventFlavor.kt
vendored
Normal file
7
src/main/kotlin/ee/carlrobert/codegpt/agent/external/AcpToolEventFlavor.kt
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
package ee.carlrobert.codegpt.agent.external
|
||||
|
||||
enum class AcpToolEventFlavor {
|
||||
STANDARD,
|
||||
ZED_ADAPTER,
|
||||
GEMINI_CLI
|
||||
}
|
||||
133
src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocol.kt
vendored
Normal file
133
src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocol.kt
vendored
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
package ee.carlrobert.codegpt.agent.external.acpcompat
|
||||
|
||||
import com.agentclientprotocol.model.AcpMethod
|
||||
import com.agentclientprotocol.model.AcpNotification
|
||||
import com.agentclientprotocol.model.AcpRequest
|
||||
import com.agentclientprotocol.model.AcpResponse
|
||||
import com.agentclientprotocol.rpc.ACPJson
|
||||
import com.agentclientprotocol.rpc.MethodName
|
||||
import com.agentclientprotocol.transport.Transport
|
||||
import ee.carlrobert.codegpt.agent.external.runtime.AcpProtocolCore
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.serialization.json.*
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.EmptyCoroutineContext
|
||||
import kotlin.time.Duration
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
class AcpExpectedError(message: String) : Exception(message)
|
||||
|
||||
fun acpFail(message: String): Nothing = throw AcpExpectedError(message)
|
||||
|
||||
class JsonRpcException(
|
||||
val code: Int,
|
||||
message: String,
|
||||
val data: JsonElement? = null
|
||||
) : Exception(message)
|
||||
|
||||
open class ProtocolOptions(
|
||||
val gracefulRequestCancellationTimeout: Duration = 1.seconds,
|
||||
val protocolDebugName: String = AcpProtocol::class.simpleName!!,
|
||||
val outboundPayloadAugmenter: (MethodName, JsonElement?) -> JsonElement? = { _, payload -> payload },
|
||||
val inboundPayloadNormalizer: (MethodName, JsonElement?) -> JsonElement? = { _, payload -> payload },
|
||||
val trace: ((String) -> Unit)? = null
|
||||
)
|
||||
|
||||
class AcpProtocol(
|
||||
parentScope: CoroutineScope,
|
||||
transport: Transport,
|
||||
val options: ProtocolOptions = ProtocolOptions()
|
||||
) {
|
||||
internal val json: Json = ACPJson
|
||||
|
||||
internal val core = AcpProtocolCore(parentScope, transport, json, options)
|
||||
|
||||
fun start() = core.start()
|
||||
|
||||
fun close() = core.close()
|
||||
|
||||
override fun toString(): String = "Protocol(${options.protocolDebugName})"
|
||||
}
|
||||
|
||||
internal suspend inline fun <reified TRequest : AcpRequest, reified TResponse : AcpResponse> AcpProtocol.sendRequest(
|
||||
method: AcpMethod.AcpRequestResponseMethod<TRequest, TResponse>,
|
||||
request: TRequest?
|
||||
): TResponse {
|
||||
val params = options.outboundPayloadAugmenter(
|
||||
method.methodName,
|
||||
request?.let { json.encodeToJsonElement(request) }
|
||||
)
|
||||
val responseJson = core.sendRequestRaw(method.methodName, params)
|
||||
return json.decodeFromJsonElement(
|
||||
options.inboundPayloadNormalizer(method.methodName, responseJson) ?: JsonNull
|
||||
)
|
||||
}
|
||||
|
||||
internal inline fun <reified TNotification : AcpNotification> AcpProtocol.sendNotification(
|
||||
method: AcpMethod.AcpNotificationMethod<TNotification>,
|
||||
notification: TNotification? = null
|
||||
) {
|
||||
val params = notification?.let { json.encodeToJsonElement(notification) }
|
||||
core.sendNotificationRaw(method, params)
|
||||
}
|
||||
|
||||
internal inline fun <reified TRequest : AcpRequest, reified TResponse : AcpResponse> AcpProtocol.setRequestHandler(
|
||||
method: AcpMethod.AcpRequestResponseMethod<TRequest, TResponse>,
|
||||
additionalContext: CoroutineContext = EmptyCoroutineContext,
|
||||
noinline handler: suspend (TRequest) -> TResponse
|
||||
) {
|
||||
core.setRequestHandlerRaw(method, additionalContext) { request ->
|
||||
val requestParams = decodeAcpPayload<TRequest>(
|
||||
json,
|
||||
options.inboundPayloadNormalizer(method.methodName, request.params)
|
||||
)
|
||||
val responseObject = handler(requestParams)
|
||||
json.encodeToJsonElement(responseObject)
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun <reified TNotification : AcpNotification> AcpProtocol.setNotificationHandler(
|
||||
method: AcpMethod.AcpNotificationMethod<TNotification>,
|
||||
additionalContext: CoroutineContext = EmptyCoroutineContext,
|
||||
noinline handler: suspend (TNotification) -> Unit
|
||||
) {
|
||||
core.setNotificationHandlerRaw(method, additionalContext) { notification ->
|
||||
val notificationParams = decodeAcpPayload<TNotification>(
|
||||
json,
|
||||
options.inboundPayloadNormalizer(method.methodName, notification.params)
|
||||
)
|
||||
handler(notificationParams)
|
||||
}
|
||||
}
|
||||
|
||||
internal suspend inline operator fun <reified TRequest : AcpRequest, reified TResponse : AcpResponse> AcpMethod.AcpRequestResponseMethod<TRequest, TResponse>.invoke(
|
||||
protocol: AcpProtocol,
|
||||
request: TRequest
|
||||
): TResponse {
|
||||
return protocol.sendRequest(this, request)
|
||||
}
|
||||
|
||||
internal inline operator fun <reified TNotification : AcpNotification> AcpMethod.AcpNotificationMethod<TNotification>.invoke(
|
||||
protocol: AcpProtocol,
|
||||
notification: TNotification
|
||||
) {
|
||||
return protocol.sendNotification(this, notification)
|
||||
}
|
||||
|
||||
internal inline fun <reified T> decodeAcpPayload(
|
||||
json: Json,
|
||||
payload: JsonElement?
|
||||
): T {
|
||||
return when (payload) {
|
||||
null -> json.decodeFromJsonElement(JsonNull)
|
||||
is JsonPrimitive -> {
|
||||
if (payload.isString) {
|
||||
json.decodeFromString(payload.content)
|
||||
} else {
|
||||
json.decodeFromJsonElement(payload)
|
||||
}
|
||||
}
|
||||
|
||||
else -> json.decodeFromJsonElement(payload)
|
||||
}
|
||||
}
|
||||
175
src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpCompatibilityRegistry.kt
vendored
Normal file
175
src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpCompatibilityRegistry.kt
vendored
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package ee.carlrobert.codegpt.agent.external.acpcompat.vendor
|
||||
|
||||
import com.agentclientprotocol.agent.AgentInfo
|
||||
import com.agentclientprotocol.rpc.MethodName
|
||||
import ee.carlrobert.codegpt.agent.external.ExternalAcpAgentPreset
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import kotlinx.serialization.json.contentOrNull
|
||||
import kotlinx.serialization.json.jsonArray
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import kotlinx.serialization.json.jsonPrimitive
|
||||
|
||||
internal class AcpCompatibilityRegistry {
|
||||
|
||||
fun initialProfile(preset: ExternalAcpAgentPreset): AcpPeerProfile {
|
||||
return profileForPresetId(preset.id)
|
||||
}
|
||||
|
||||
fun resolveProfile(
|
||||
preset: ExternalAcpAgentPreset,
|
||||
agentInfo: AgentInfo?
|
||||
): AcpPeerProfile {
|
||||
val presetProfile = initialProfile(preset)
|
||||
val implementationName = buildString {
|
||||
append(agentInfo?.implementation?.name.orEmpty())
|
||||
append(' ')
|
||||
append(agentInfo?.implementation?.title.orEmpty())
|
||||
}.trim().lowercase()
|
||||
|
||||
return when {
|
||||
"codex" in implementationName -> AcpPeerProfile.CODEX
|
||||
"gemini" in implementationName -> AcpPeerProfile.GEMINI
|
||||
"opencode" in implementationName -> AcpPeerProfile.OPENCODE
|
||||
"claude" in implementationName -> AcpPeerProfile.CLAUDE_CODE
|
||||
else -> presetProfile
|
||||
}
|
||||
}
|
||||
|
||||
fun augmentOutboundPayload(
|
||||
profile: AcpPeerProfile,
|
||||
methodName: MethodName,
|
||||
payload: JsonElement?,
|
||||
sessionRequestMeta: JsonElement? = null,
|
||||
launchEnv: Map<String, String> = emptyMap()
|
||||
): JsonElement? {
|
||||
return when (methodName.name) {
|
||||
"initialize" -> augmentInitializeRequest(profile, payload)
|
||||
"authenticate" -> augmentAuthenticateRequest(profile, payload, launchEnv)
|
||||
"session/new", "session/load" -> mergeTopLevelMeta(payload, sessionRequestMeta)
|
||||
else -> payload
|
||||
}
|
||||
}
|
||||
|
||||
fun normalizeInboundPayload(
|
||||
profile: AcpPeerProfile,
|
||||
methodName: MethodName,
|
||||
payload: JsonElement?
|
||||
): JsonElement? {
|
||||
return when (methodName.name) {
|
||||
"initialize" -> normalizeInitializeResponse(profile, payload)
|
||||
"session/new", "session/load", "session/update" -> payload
|
||||
else -> payload
|
||||
}
|
||||
}
|
||||
|
||||
private fun profileForPresetId(presetId: String): AcpPeerProfile {
|
||||
return when (presetId) {
|
||||
"codex" -> AcpPeerProfile.CODEX
|
||||
"gemini-cli" -> AcpPeerProfile.GEMINI
|
||||
"opencode" -> AcpPeerProfile.OPENCODE
|
||||
"claude-code" -> AcpPeerProfile.CLAUDE_CODE
|
||||
else -> AcpPeerProfile.STANDARD
|
||||
}
|
||||
}
|
||||
|
||||
private fun augmentInitializeRequest(
|
||||
profile: AcpPeerProfile,
|
||||
payload: JsonElement?
|
||||
): JsonElement? {
|
||||
val root = payload as? JsonObject ?: return payload
|
||||
if (!profile.supportsTerminalAuthMeta) {
|
||||
return payload
|
||||
}
|
||||
|
||||
val capabilities = root["clientCapabilities"] as? JsonObject ?: return payload
|
||||
val capabilityMeta = mergeMeta(
|
||||
capabilities["_meta"],
|
||||
JsonObject(mapOf("terminal-auth" to JsonPrimitive(true)))
|
||||
) ?: return payload
|
||||
|
||||
return JsonObject(
|
||||
root + ("clientCapabilities" to JsonObject(capabilities + ("_meta" to capabilityMeta)))
|
||||
)
|
||||
}
|
||||
|
||||
private fun augmentAuthenticateRequest(
|
||||
profile: AcpPeerProfile,
|
||||
payload: JsonElement?,
|
||||
launchEnv: Map<String, String>
|
||||
): JsonElement? {
|
||||
if (profile != AcpPeerProfile.GEMINI) {
|
||||
return payload
|
||||
}
|
||||
|
||||
val metaEntries = linkedMapOf<String, JsonElement>()
|
||||
val apiKey = launchEnv["GEMINI_API_KEY"] ?: launchEnv["GOOGLE_API_KEY"]
|
||||
val gateway = launchEnv["GEMINI_GATEWAY"]
|
||||
if (!apiKey.isNullOrBlank()) {
|
||||
metaEntries["api-key"] = JsonPrimitive(apiKey)
|
||||
}
|
||||
if (!gateway.isNullOrBlank()) {
|
||||
metaEntries["gateway"] = JsonPrimitive(gateway)
|
||||
}
|
||||
|
||||
return mergeTopLevelMeta(
|
||||
payload,
|
||||
metaEntries.takeIf { it.isNotEmpty() }?.let(::JsonObject)
|
||||
)
|
||||
}
|
||||
|
||||
private fun normalizeInitializeResponse(
|
||||
profile: AcpPeerProfile,
|
||||
payload: JsonElement?
|
||||
): JsonElement? {
|
||||
if (profile != AcpPeerProfile.CODEX) {
|
||||
return payload
|
||||
}
|
||||
|
||||
val root = payload as? JsonObject ?: return payload
|
||||
val authMethods = root["authMethods"]?.jsonArray ?: return payload
|
||||
val normalizedAuthMethods = authMethods.map(::normalizeAuthMethod)
|
||||
return JsonObject(root + ("authMethods" to kotlinx.serialization.json.JsonArray(normalizedAuthMethods)))
|
||||
}
|
||||
|
||||
private fun normalizeAuthMethod(payload: JsonElement): JsonElement {
|
||||
val authMethod = payload as? JsonObject ?: return payload
|
||||
val type = authMethod["type"]?.jsonPrimitive?.contentOrNull
|
||||
if (type != "env_var" || authMethod["varName"] != null) {
|
||||
return payload
|
||||
}
|
||||
|
||||
val firstVarName = authMethod["vars"]
|
||||
?.jsonArray
|
||||
?.firstOrNull()
|
||||
?.jsonObject
|
||||
?.get("name")
|
||||
?.jsonPrimitive
|
||||
?.contentOrNull
|
||||
?: return payload
|
||||
|
||||
return JsonObject(authMethod + ("varName" to JsonPrimitive(firstVarName)))
|
||||
}
|
||||
|
||||
private fun mergeTopLevelMeta(
|
||||
payload: JsonElement?,
|
||||
extraMeta: JsonElement?
|
||||
): JsonElement? {
|
||||
val root = payload as? JsonObject ?: return payload
|
||||
val mergedMeta = mergeMeta(root["_meta"], extraMeta) ?: return payload
|
||||
return JsonObject(root + ("_meta" to mergedMeta))
|
||||
}
|
||||
|
||||
private fun mergeMeta(
|
||||
base: JsonElement?,
|
||||
extra: JsonElement?
|
||||
): JsonElement? {
|
||||
return when {
|
||||
base == null -> extra
|
||||
extra == null -> base
|
||||
base is JsonObject && extra is JsonObject -> JsonObject(base + extra)
|
||||
else -> extra
|
||||
}
|
||||
}
|
||||
}
|
||||
29
src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpPeerProfile.kt
vendored
Normal file
29
src/main/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/vendor/AcpPeerProfile.kt
vendored
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
package ee.carlrobert.codegpt.agent.external.acpcompat.vendor
|
||||
|
||||
internal enum class AcpPeerProfile(
|
||||
val profileId: String,
|
||||
val displayName: String,
|
||||
val supportsTerminalAuthMeta: Boolean = false
|
||||
) {
|
||||
STANDARD(
|
||||
profileId = "standard",
|
||||
displayName = "Standard ACP"
|
||||
),
|
||||
CODEX(
|
||||
profileId = "codex",
|
||||
displayName = "Codex ACP"
|
||||
),
|
||||
GEMINI(
|
||||
profileId = "gemini",
|
||||
displayName = "Gemini CLI ACP"
|
||||
),
|
||||
OPENCODE(
|
||||
profileId = "opencode",
|
||||
displayName = "OpenCode ACP"
|
||||
),
|
||||
CLAUDE_CODE(
|
||||
profileId = "claude-code",
|
||||
displayName = "Claude Code ACP",
|
||||
supportsTerminalAuthMeta = true
|
||||
)
|
||||
}
|
||||
88
src/main/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpExternalEvent.kt
vendored
Normal file
88
src/main/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpExternalEvent.kt
vendored
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package ee.carlrobert.codegpt.agent.external.events
|
||||
|
||||
import com.agentclientprotocol.model.*
|
||||
import ee.carlrobert.codegpt.agent.external.AcpToolCallStatus
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
|
||||
internal sealed interface AcpExternalEvent {
|
||||
data class TextChunk(val text: String) : AcpExternalEvent
|
||||
data class ThinkingChunk(val text: String) : AcpExternalEvent
|
||||
data class PlanUpdate(val entries: List<PlanEntry>) : AcpExternalEvent
|
||||
data class UsageUpdate(val used: Long, val size: Long, val cost: Cost?) : AcpExternalEvent
|
||||
data class ConfigOptionUpdate(val configOptions: List<SessionConfigOption>) : AcpExternalEvent
|
||||
data class SessionInfoUpdate(val title: String?, val updatedAt: String?) : AcpExternalEvent
|
||||
data class UnknownSessionUpdate(val type: String, val rawJson: JsonObject) : AcpExternalEvent
|
||||
data class AvailableCommandsUpdate(val availableCommands: List<AvailableCommand>) :
|
||||
AcpExternalEvent
|
||||
|
||||
data class CurrentModeUpdate(val currentModeId: String) : AcpExternalEvent
|
||||
data class ToolCallStarted(val toolCall: AcpToolCallSnapshot) : AcpExternalEvent
|
||||
data class ToolCallUpdated(val toolCall: AcpToolCallSnapshot) : AcpExternalEvent
|
||||
}
|
||||
|
||||
internal sealed interface AcpToolCallArgs {
|
||||
data class SearchPreview(val value: AcpSearchPreviewArgs) : AcpToolCallArgs
|
||||
data class BashPreview(val value: AcpBashPreviewArgs) : AcpToolCallArgs
|
||||
data class Mcp(val value: McpTool.Args) : AcpToolCallArgs
|
||||
data class Read(val value: ReadTool.Args) : AcpToolCallArgs
|
||||
data class Write(val value: WriteTool.Args) : AcpToolCallArgs
|
||||
data class Edit(val value: EditTool.Args) : AcpToolCallArgs
|
||||
data class Bash(val value: BashTool.Args) : AcpToolCallArgs
|
||||
data class WebSearch(val value: WebSearchTool.Args) : AcpToolCallArgs
|
||||
data class WebFetch(val value: WebFetchTool.Args) : AcpToolCallArgs
|
||||
data class IntelliJSearch(val value: IntelliJSearchTool.Args) : AcpToolCallArgs
|
||||
data class Unknown(val value: JsonElement?) : AcpToolCallArgs
|
||||
}
|
||||
|
||||
internal data class AcpSearchPreviewArgs(
|
||||
val title: String,
|
||||
val path: String? = null,
|
||||
val pattern: String? = null
|
||||
)
|
||||
|
||||
internal data class AcpBashPreviewArgs(
|
||||
val title: String,
|
||||
val command: String? = null
|
||||
)
|
||||
|
||||
internal sealed interface AcpToolCallContent {
|
||||
data class Block(val content: ContentBlock) : AcpToolCallContent
|
||||
data class Diff(val path: String, val oldText: String?, val newText: String) :
|
||||
AcpToolCallContent
|
||||
|
||||
data class Terminal(val terminalId: String) : AcpToolCallContent
|
||||
}
|
||||
|
||||
internal data class AcpToolCallSnapshot(
|
||||
val id: String,
|
||||
val title: String,
|
||||
val toolName: String,
|
||||
val kind: ToolKind? = null,
|
||||
val status: AcpToolCallStatus? = null,
|
||||
val args: AcpToolCallArgs? = null,
|
||||
val locations: List<ToolCallLocation> = emptyList(),
|
||||
val content: List<AcpToolCallContent> = emptyList(),
|
||||
val meta: JsonElement? = null,
|
||||
val rawInput: JsonElement? = null,
|
||||
val rawOutput: JsonElement? = null
|
||||
)
|
||||
|
||||
internal data class AcpPermissionRequestSnapshot(
|
||||
val toolCall: AcpToolCallSnapshot,
|
||||
val details: String,
|
||||
val options: List<PermissionOption>
|
||||
)
|
||||
|
||||
internal fun ToolCallContent.toAcpToolCallContent(): AcpToolCallContent {
|
||||
return when (this) {
|
||||
is ToolCallContent.Content -> AcpToolCallContent.Block(content)
|
||||
is ToolCallContent.Diff -> AcpToolCallContent.Diff(path, oldText, newText)
|
||||
is ToolCallContent.Terminal -> AcpToolCallContent.Terminal(terminalId)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun List<ToolCallContent>.toAcpToolCallContent(): List<AcpToolCallContent> {
|
||||
return map(ToolCallContent::toAcpToolCallContent)
|
||||
}
|
||||
114
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHost.kt
vendored
Normal file
114
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHost.kt
vendored
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import com.agentclientprotocol.model.*
|
||||
import com.intellij.openapi.application.runReadAction
|
||||
import com.intellij.openapi.fileEditor.FileDocumentManager
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.openapi.vfs.VfsUtil
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import kotlin.io.path.createDirectories
|
||||
import kotlin.io.path.notExists
|
||||
|
||||
class AcpFileHost(
|
||||
private val pathPolicy: AcpPathPolicy = AcpPathPolicy(),
|
||||
private val openDocumentReader: AcpOpenDocumentReader = IntelliJOpenDocumentReader(),
|
||||
private val writer: AcpTextFileWriter = IntelliJTextFileWriter()
|
||||
) {
|
||||
|
||||
fun clientCapabilities(includeTerminal: Boolean = false): ClientCapabilities {
|
||||
return ClientCapabilities(
|
||||
fs = FileSystemCapability(
|
||||
readTextFile = true,
|
||||
writeTextFile = true
|
||||
),
|
||||
terminal = includeTerminal
|
||||
)
|
||||
}
|
||||
|
||||
fun readTextFile(
|
||||
session: AcpHostSessionContext,
|
||||
request: ReadTextFileRequest
|
||||
): ReadTextFileResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Read request session does not match host session"
|
||||
}
|
||||
val resolvedPath = pathPolicy.resolveWithinCwd(request.path, session.cwd)
|
||||
val content = readContent(resolvedPath)
|
||||
val trimmed = applyLineWindow(
|
||||
content = content.content,
|
||||
line = request.line,
|
||||
limit = request.limit
|
||||
)
|
||||
return ReadTextFileResponse(trimmed)
|
||||
}
|
||||
|
||||
fun writeTextFile(
|
||||
session: AcpHostSessionContext,
|
||||
request: WriteTextFileRequest
|
||||
): WriteTextFileResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Write request session does not match host session"
|
||||
}
|
||||
val resolvedPath = pathPolicy.resolveWithinCwd(request.path, session.cwd)
|
||||
writer.write(resolvedPath, request.content)
|
||||
return WriteTextFileResponse()
|
||||
}
|
||||
|
||||
private fun readContent(path: Path): AcpTextFileReadResult {
|
||||
openDocumentReader.read(path)?.let { editorText ->
|
||||
return AcpTextFileReadResult(editorText, fromEditor = true)
|
||||
}
|
||||
return AcpTextFileReadResult(Files.readString(path), fromEditor = false)
|
||||
}
|
||||
|
||||
private fun applyLineWindow(
|
||||
content: String,
|
||||
line: UInt?,
|
||||
limit: UInt?
|
||||
): String {
|
||||
val startLine = line?.toInt()
|
||||
val maxLines = limit?.toInt()
|
||||
if (startLine == null || maxLines == null || startLine <= 0 || maxLines <= 0) {
|
||||
return content
|
||||
}
|
||||
|
||||
return content.lineSequence()
|
||||
.drop(startLine - 1)
|
||||
.take(maxLines)
|
||||
.joinToString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
private class IntelliJOpenDocumentReader : AcpOpenDocumentReader {
|
||||
override fun read(path: Path): String? {
|
||||
return runReadAction<String?> {
|
||||
val virtualFile =
|
||||
LocalFileSystem.getInstance().refreshAndFindFileByIoFile(path.toFile())
|
||||
?: return@runReadAction null
|
||||
FileDocumentManager.getInstance().getDocument(virtualFile)?.text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class IntelliJTextFileWriter : AcpTextFileWriter {
|
||||
override fun write(path: Path, content: String) {
|
||||
if (path.parent != null && path.parent.notExists()) {
|
||||
path.parent.createDirectories()
|
||||
}
|
||||
Files.writeString(path, content, StandardCharsets.UTF_8)
|
||||
|
||||
val virtualFile = LocalFileSystem.getInstance().refreshAndFindFileByIoFile(path.toFile())
|
||||
if (virtualFile != null) {
|
||||
VfsUtil.markDirtyAndRefresh(false, false, false, virtualFile)
|
||||
} else {
|
||||
path.parent?.toFile()?.let { parent ->
|
||||
LocalFileSystem.getInstance().refreshAndFindFileByIoFile(parent)
|
||||
?.let { parentVf ->
|
||||
VfsUtil.markDirtyAndRefresh(false, false, true, parentVf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
52
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpHostModels.kt
vendored
Normal file
52
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpHostModels.kt
vendored
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import java.nio.file.Path
|
||||
|
||||
data class AcpHostSessionContext(
|
||||
val sessionId: String,
|
||||
val cwd: Path
|
||||
)
|
||||
|
||||
data class AcpTextFileReadResult(
|
||||
val content: String,
|
||||
val fromEditor: Boolean
|
||||
)
|
||||
|
||||
data class AcpTerminalExitStatus(
|
||||
val exitCode: UInt? = null,
|
||||
val signal: String? = null
|
||||
)
|
||||
|
||||
data class AcpTerminalOutputSnapshot(
|
||||
val output: String,
|
||||
val truncated: Boolean,
|
||||
val exitStatus: AcpTerminalExitStatus? = null
|
||||
)
|
||||
|
||||
fun interface AcpOpenDocumentReader {
|
||||
fun read(path: Path): String?
|
||||
}
|
||||
|
||||
fun interface AcpTextFileWriter {
|
||||
fun write(path: Path, content: String)
|
||||
}
|
||||
|
||||
interface AcpTerminalProcess {
|
||||
val terminalId: String
|
||||
|
||||
fun output(): AcpTerminalOutputSnapshot
|
||||
suspend fun waitForExit(): AcpTerminalExitStatus
|
||||
fun release()
|
||||
fun kill()
|
||||
}
|
||||
|
||||
fun interface AcpTerminalProcessLauncher {
|
||||
fun launch(
|
||||
command: String,
|
||||
args: List<String>,
|
||||
cwd: Path,
|
||||
env: Map<String, String>,
|
||||
outputByteLimit: ULong?
|
||||
): AcpTerminalProcess
|
||||
}
|
||||
|
||||
57
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicy.kt
vendored
Normal file
57
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicy.kt
vendored
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import java.net.URI
|
||||
import java.nio.file.Path
|
||||
|
||||
class AcpHostPathBoundaryException(message: String) : IllegalArgumentException(message)
|
||||
|
||||
class AcpPathPolicy {
|
||||
|
||||
fun resolveWithinCwd(rawPath: String, cwd: Path): Path {
|
||||
val normalizedCwd = cwd.toAbsolutePath().normalize()
|
||||
val checkedCwd = canonicalizeForBoundaryCheck(normalizedCwd)
|
||||
val requestedPath = parsePath(rawPath)
|
||||
val candidate = if (requestedPath.isAbsolute) {
|
||||
requestedPath
|
||||
} else {
|
||||
normalizedCwd.resolve(requestedPath)
|
||||
}.toAbsolutePath().normalize()
|
||||
val checkedCandidate = canonicalizeForBoundaryCheck(candidate)
|
||||
|
||||
if (!checkedCandidate.startsWith(checkedCwd)) {
|
||||
throw AcpHostPathBoundaryException(
|
||||
"Path escapes the session cwd: $rawPath"
|
||||
)
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
|
||||
private fun parsePath(rawPath: String): Path {
|
||||
return if (rawPath.startsWith("file://")) {
|
||||
Path.of(URI.create(rawPath))
|
||||
} else {
|
||||
Path.of(rawPath)
|
||||
}
|
||||
}
|
||||
|
||||
private fun canonicalizeForBoundaryCheck(path: Path): Path {
|
||||
val absolute = path.toAbsolutePath().normalize()
|
||||
val deferredSegments = mutableListOf<String>()
|
||||
var current: Path? = absolute
|
||||
|
||||
while (current != null) {
|
||||
val realPath = runCatching { current.toRealPath() }.getOrNull()
|
||||
if (realPath != null) {
|
||||
return deferredSegments.foldRight(realPath) { segment, acc ->
|
||||
acc.resolve(segment)
|
||||
}.normalize()
|
||||
}
|
||||
|
||||
current.fileName?.toString()?.let(deferredSegments::add)
|
||||
current = current.parent
|
||||
}
|
||||
|
||||
return absolute
|
||||
}
|
||||
}
|
||||
149
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHost.kt
vendored
Normal file
149
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHost.kt
vendored
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import com.agentclientprotocol.model.CreateTerminalRequest
|
||||
import com.agentclientprotocol.model.CreateTerminalResponse
|
||||
import com.agentclientprotocol.model.KillTerminalCommandRequest
|
||||
import com.agentclientprotocol.model.KillTerminalCommandResponse
|
||||
import com.agentclientprotocol.model.ReadTextFileRequest
|
||||
import com.agentclientprotocol.model.ReleaseTerminalRequest
|
||||
import com.agentclientprotocol.model.ReleaseTerminalResponse
|
||||
import com.agentclientprotocol.model.TerminalExitStatus
|
||||
import com.agentclientprotocol.model.TerminalOutputRequest
|
||||
import com.agentclientprotocol.model.TerminalOutputResponse
|
||||
import com.agentclientprotocol.model.WaitForTerminalExitRequest
|
||||
import com.agentclientprotocol.model.WaitForTerminalExitResponse
|
||||
import com.agentclientprotocol.model.WriteTextFileRequest
|
||||
import java.nio.file.Path
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
class AcpHostTerminalNotFoundException(message: String) : IllegalArgumentException(message)
|
||||
|
||||
class AcpTerminalHost(
|
||||
private val launcher: AcpTerminalProcessLauncher,
|
||||
private val pathPolicy: AcpPathPolicy = AcpPathPolicy()
|
||||
) {
|
||||
private val sessions = ConcurrentHashMap<String, AcpTerminalProcess>()
|
||||
|
||||
fun createTerminal(
|
||||
session: AcpHostSessionContext,
|
||||
request: CreateTerminalRequest
|
||||
): CreateTerminalResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Terminal request session does not match host session"
|
||||
}
|
||||
|
||||
val terminalId = "terminal-${UUID.randomUUID()}"
|
||||
val process = launcher.launch(
|
||||
command = request.command,
|
||||
args = request.args,
|
||||
cwd = resolveWorkingDirectory(session.cwd, request.cwd),
|
||||
env = request.env.associate { it.name to it.value },
|
||||
outputByteLimit = request.outputByteLimit
|
||||
)
|
||||
sessions[terminalId] = process
|
||||
return CreateTerminalResponse(terminalId = terminalId)
|
||||
}
|
||||
|
||||
fun output(
|
||||
session: AcpHostSessionContext,
|
||||
request: TerminalOutputRequest
|
||||
): TerminalOutputResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Terminal request session does not match host session"
|
||||
}
|
||||
val process = processFor(session, request.terminalId)
|
||||
val snapshot = process.output()
|
||||
return TerminalOutputResponse(
|
||||
output = snapshot.output,
|
||||
truncated = snapshot.truncated,
|
||||
exitStatus = snapshot.exitStatus?.let {
|
||||
TerminalExitStatus(
|
||||
exitCode = it.exitCode,
|
||||
signal = it.signal
|
||||
)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fun release(
|
||||
session: AcpHostSessionContext,
|
||||
request: ReleaseTerminalRequest
|
||||
): ReleaseTerminalResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Terminal request session does not match host session"
|
||||
}
|
||||
val process = processFor(session, request.terminalId)
|
||||
process.release()
|
||||
sessions.remove(request.terminalId)
|
||||
return ReleaseTerminalResponse()
|
||||
}
|
||||
|
||||
suspend fun waitForExit(
|
||||
session: AcpHostSessionContext,
|
||||
request: WaitForTerminalExitRequest
|
||||
): WaitForTerminalExitResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Terminal request session does not match host session"
|
||||
}
|
||||
val process = processFor(session, request.terminalId)
|
||||
val exit = process.waitForExit()
|
||||
return WaitForTerminalExitResponse(
|
||||
exitCode = exit.exitCode,
|
||||
signal = exit.signal
|
||||
)
|
||||
}
|
||||
|
||||
fun kill(
|
||||
session: AcpHostSessionContext,
|
||||
request: KillTerminalCommandRequest
|
||||
): KillTerminalCommandResponse {
|
||||
require(request.sessionId.value == session.sessionId) {
|
||||
"Terminal request session does not match host session"
|
||||
}
|
||||
val process = processFor(session, request.terminalId)
|
||||
process.kill()
|
||||
return KillTerminalCommandResponse()
|
||||
}
|
||||
|
||||
private fun processFor(session: AcpHostSessionContext, terminalId: String): AcpTerminalProcess {
|
||||
return sessions[terminalId]
|
||||
?: throw AcpHostTerminalNotFoundException(
|
||||
"Unknown terminal '$terminalId' for session ${session.sessionId}"
|
||||
)
|
||||
}
|
||||
|
||||
private fun resolveWorkingDirectory(cwd: Path, overrideCwd: String?): Path {
|
||||
return overrideCwd?.takeIf { it.isNotBlank() }
|
||||
?.let { pathPolicy.resolveWithinCwd(it, cwd) }
|
||||
?: cwd.toAbsolutePath().normalize()
|
||||
}
|
||||
}
|
||||
|
||||
class AcpHostCapabilities(
|
||||
private val fileHost: AcpFileHost,
|
||||
private val terminalHost: AcpTerminalHost
|
||||
) {
|
||||
fun clientCapabilities(includeTerminal: Boolean = true) = fileHost.clientCapabilities(includeTerminal)
|
||||
|
||||
fun readTextFile(session: AcpHostSessionContext, request: ReadTextFileRequest) =
|
||||
fileHost.readTextFile(session, request)
|
||||
|
||||
fun writeTextFile(session: AcpHostSessionContext, request: WriteTextFileRequest) =
|
||||
fileHost.writeTextFile(session, request)
|
||||
|
||||
fun createTerminal(session: AcpHostSessionContext, request: CreateTerminalRequest) =
|
||||
terminalHost.createTerminal(session, request)
|
||||
|
||||
fun output(session: AcpHostSessionContext, request: TerminalOutputRequest) =
|
||||
terminalHost.output(session, request)
|
||||
|
||||
fun release(session: AcpHostSessionContext, request: ReleaseTerminalRequest) =
|
||||
terminalHost.release(session, request)
|
||||
|
||||
suspend fun waitForExit(session: AcpHostSessionContext, request: WaitForTerminalExitRequest) =
|
||||
terminalHost.waitForExit(session, request)
|
||||
|
||||
fun kill(session: AcpHostSessionContext, request: KillTerminalCommandRequest) =
|
||||
terminalHost.kill(session, request)
|
||||
}
|
||||
142
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/DefaultAcpTerminalProcessLauncher.kt
vendored
Normal file
142
src/main/kotlin/ee/carlrobert/codegpt/agent/external/host/DefaultAcpTerminalProcessLauncher.kt
vendored
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import com.intellij.execution.configurations.GeneralCommandLine
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.joinAll
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.InputStream
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.Path
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
class DefaultAcpTerminalProcessLauncher(
|
||||
private val scope: CoroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
) : AcpTerminalProcessLauncher {
|
||||
override fun launch(
|
||||
command: String,
|
||||
args: List<String>,
|
||||
cwd: Path,
|
||||
env: Map<String, String>,
|
||||
outputByteLimit: ULong?
|
||||
): AcpTerminalProcess {
|
||||
val process = GeneralCommandLine().apply {
|
||||
withExePath(command)
|
||||
withParameters(args)
|
||||
withWorkDirectory(cwd.toFile())
|
||||
withEnvironment(env)
|
||||
withParentEnvironmentType(GeneralCommandLine.ParentEnvironmentType.NONE)
|
||||
withRedirectErrorStream(true)
|
||||
}.createProcess()
|
||||
|
||||
return DefaultAcpTerminalProcess(
|
||||
terminalId = UUID.randomUUID().toString(),
|
||||
process = process,
|
||||
scope = scope,
|
||||
outputByteLimit = outputByteLimit
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private class DefaultAcpTerminalProcess(
|
||||
override val terminalId: String,
|
||||
private val process: Process,
|
||||
scope: CoroutineScope,
|
||||
outputByteLimit: ULong?
|
||||
) : AcpTerminalProcess {
|
||||
|
||||
private val outputBuffer = BoundedOutputBuffer(outputByteLimit)
|
||||
private val exitStatus = AtomicReference<AcpTerminalExitStatus?>(null)
|
||||
private val stdoutJob = scope.launch {
|
||||
collect(process.inputStream)
|
||||
}
|
||||
private val stderrJob = scope.launch {
|
||||
collect(process.errorStream)
|
||||
}
|
||||
|
||||
override fun output(): AcpTerminalOutputSnapshot {
|
||||
if (exitStatus.get() == null && !process.isAlive) {
|
||||
exitStatus.compareAndSet(
|
||||
null,
|
||||
AcpTerminalExitStatus(exitCode = process.exitValue().toUInt())
|
||||
)
|
||||
}
|
||||
return outputBuffer.snapshot(exitStatus.get())
|
||||
}
|
||||
|
||||
override suspend fun waitForExit(): AcpTerminalExitStatus {
|
||||
val exitCode = withContext(Dispatchers.IO) {
|
||||
process.waitFor()
|
||||
}
|
||||
joinAll(stdoutJob, stderrJob)
|
||||
val status = AcpTerminalExitStatus(exitCode = exitCode.toUInt())
|
||||
exitStatus.set(status)
|
||||
return status
|
||||
}
|
||||
|
||||
override fun release() {
|
||||
}
|
||||
|
||||
override fun kill() {
|
||||
process.destroyForcibly()
|
||||
}
|
||||
|
||||
private suspend fun collect(stream: InputStream) {
|
||||
withContext(Dispatchers.IO) {
|
||||
stream.bufferedReader(StandardCharsets.UTF_8).useLines { lines ->
|
||||
lines.forEach { line ->
|
||||
outputBuffer.appendLine(line)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class BoundedOutputBuffer(
|
||||
private val byteLimit: ULong?
|
||||
) {
|
||||
private val text = StringBuilder()
|
||||
private var bytesUsed: ULong = 0u
|
||||
private var truncated = false
|
||||
|
||||
@Synchronized
|
||||
fun appendLine(line: String) {
|
||||
if (truncated) return
|
||||
|
||||
val encoded = (line + '\n').toByteArray(StandardCharsets.UTF_8)
|
||||
val limit = byteLimit
|
||||
if (limit == null) {
|
||||
text.appendLine(line)
|
||||
return
|
||||
}
|
||||
|
||||
val remaining = limit - bytesUsed
|
||||
if (remaining == 0uL) {
|
||||
truncated = true
|
||||
return
|
||||
}
|
||||
|
||||
if (encoded.size.toUInt().toULong() <= remaining) {
|
||||
text.appendLine(line)
|
||||
bytesUsed += encoded.size.toUInt().toULong()
|
||||
return
|
||||
}
|
||||
|
||||
val allowed = remaining.toInt().coerceAtMost(encoded.size)
|
||||
text.append(String(encoded, 0, allowed, StandardCharsets.UTF_8))
|
||||
bytesUsed = limit
|
||||
truncated = true
|
||||
}
|
||||
|
||||
@Synchronized
|
||||
fun snapshot(exitStatus: AcpTerminalExitStatus? = null): AcpTerminalOutputSnapshot {
|
||||
return AcpTerminalOutputSnapshot(
|
||||
output = text.toString(),
|
||||
truncated = truncated,
|
||||
exitStatus = exitStatus
|
||||
)
|
||||
}
|
||||
}
|
||||
175
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpIncomingRequestManager.kt
vendored
Normal file
175
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpIncomingRequestManager.kt
vendored
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.model.AcpMethod
|
||||
import com.agentclientprotocol.model.CancelRequestNotification
|
||||
import com.agentclientprotocol.rpc.*
|
||||
import com.agentclientprotocol.transport.Transport
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.AcpExpectedError
|
||||
import ee.carlrobert.codegpt.agent.external.decodeOrNull
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.SerializationException
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
private val incomingLogger = KotlinLogging.logger {}
|
||||
|
||||
internal class AcpIncomingRequestManager(
|
||||
private val state: AcpProtocolState,
|
||||
private val transport: Transport,
|
||||
private val json: Json,
|
||||
private val requestsScope: CoroutineScope,
|
||||
private val trace: ((String) -> Unit)?
|
||||
) {
|
||||
|
||||
fun installMetaHandlers() {
|
||||
state.notificationHandlers[AcpMethod.MetaMethods.CancelRequest.methodName] =
|
||||
{ notification ->
|
||||
val request = notification.params.decodeOrNull<CancelRequestNotification>(json)
|
||||
when {
|
||||
request == null -> incomingLogger.warn { "Received CancelRequest with invalid payload" }
|
||||
else -> {
|
||||
val requestJob = state.pendingIncomingRequests.remove(request.requestId)
|
||||
if (requestJob == null) {
|
||||
incomingLogger.warn { "Received CancelRequest for unknown request: ${request.requestId}" }
|
||||
} else {
|
||||
requestJob.cancel(
|
||||
JsonRpcIncomingRequestCanceledException(
|
||||
request.message ?: "Cancelled by the counterpart",
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun setRequestHandlerRaw(
|
||||
method: AcpMethod.AcpRequestResponseMethod<*, *>,
|
||||
additionalContext: CoroutineContext,
|
||||
handler: suspend (JsonRpcRequest) -> JsonElement?
|
||||
) {
|
||||
state.requestHandlers[method.methodName] = { request ->
|
||||
withContext(additionalContext) {
|
||||
handler(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun setNotificationHandlerRaw(
|
||||
method: AcpMethod.AcpNotificationMethod<*>,
|
||||
additionalContext: CoroutineContext,
|
||||
handler: suspend (JsonRpcNotification) -> Unit
|
||||
) {
|
||||
state.notificationHandlers[method.methodName] = { notification ->
|
||||
withContext(additionalContext) {
|
||||
handler(notification)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun handleRequest(request: JsonRpcRequest) {
|
||||
val requestId = request.id
|
||||
requestsScope.launch {
|
||||
processRequest(request)
|
||||
}.also { job ->
|
||||
state.pendingIncomingRequests[requestId] = job
|
||||
}.invokeOnCompletion {
|
||||
state.pendingIncomingRequests.remove(requestId)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun handleNotification(notification: JsonRpcNotification) {
|
||||
val handler = state.notificationHandlers[notification.method]
|
||||
if (handler != null) {
|
||||
try {
|
||||
handler(notification)
|
||||
} catch (e: Exception) {
|
||||
incomingLogger.error(e) { "Error handling notification ${notification.method}" }
|
||||
}
|
||||
} else {
|
||||
incomingLogger.debug { "No handler for notification: ${notification.method}" }
|
||||
}
|
||||
}
|
||||
|
||||
fun cancelPendingIncomingRequests(ce: CancellationException? = null) {
|
||||
val requests = state.pendingIncomingRequests.toMap()
|
||||
state.pendingIncomingRequests.clear()
|
||||
for ((requestId, job) in requests) {
|
||||
incomingLogger.trace { "Canceling pending incoming request: $requestId" }
|
||||
job.cancel(ce)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun processRequest(request: JsonRpcRequest) {
|
||||
val handler = state.requestHandlers[request.method]
|
||||
val response = if (handler == null) {
|
||||
JsonRpcResponse(
|
||||
request.id,
|
||||
null,
|
||||
JsonRpcError(
|
||||
JsonRpcErrorCode.METHOD_NOT_FOUND.code,
|
||||
"Method not supported: ${request.method}"
|
||||
)
|
||||
)
|
||||
} else {
|
||||
buildResponse(request, handler)
|
||||
}
|
||||
if (response != null) {
|
||||
trace?.invoke("-> ${response.traceSummary()}")
|
||||
transport.send(response)
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun buildResponse(
|
||||
request: JsonRpcRequest,
|
||||
handler: suspend (JsonRpcRequest) -> JsonElement?
|
||||
): JsonRpcResponse? {
|
||||
return try {
|
||||
val result = withContext(JsonRpcRequestContextElement()) {
|
||||
handler(request)
|
||||
}
|
||||
JsonRpcResponse(request.id, result, null)
|
||||
} catch (e: AcpExpectedError) {
|
||||
incomingLogger.trace(e) { "Expected error on '${request.method}'" }
|
||||
errorResponse(
|
||||
request.id,
|
||||
JsonRpcErrorCode.INVALID_PARAMS,
|
||||
e.message ?: "Invalid params"
|
||||
)
|
||||
} catch (e: SerializationException) {
|
||||
incomingLogger.trace(e) { "Serialization error on ${request.method}" }
|
||||
errorResponse(
|
||||
request.id,
|
||||
JsonRpcErrorCode.PARSE_ERROR,
|
||||
e.message ?: "Serialization error"
|
||||
)
|
||||
} catch (ce: CancellationException) {
|
||||
incomingLogger.trace(ce) { "Incoming request cancelled: ${request.method}" }
|
||||
if (ce is JsonRpcIncomingRequestCanceledException) {
|
||||
null
|
||||
} else {
|
||||
errorResponse(request.id, JsonRpcErrorCode.CANCELLED, ce.message ?: "Cancelled")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
incomingLogger.error(e) { "Exception on ${request.method}" }
|
||||
errorResponse(
|
||||
request.id,
|
||||
JsonRpcErrorCode.INTERNAL_ERROR,
|
||||
e.message ?: "Internal error"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun errorResponse(
|
||||
requestId: RequestId,
|
||||
code: JsonRpcErrorCode,
|
||||
message: String
|
||||
): JsonRpcResponse {
|
||||
return JsonRpcResponse(requestId, null, JsonRpcError(code.code, message))
|
||||
}
|
||||
}
|
||||
36
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpMessagePump.kt
vendored
Normal file
36
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpMessagePump.kt
vendored
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.rpc.JsonRpcMessage
|
||||
import com.agentclientprotocol.transport.Transport
|
||||
import com.agentclientprotocol.transport.asMessageChannel
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import kotlinx.coroutines.CoroutineName
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private val pumpLogger = KotlinLogging.logger {}
|
||||
|
||||
internal class AcpMessagePump(
|
||||
private val scope: CoroutineScope,
|
||||
private val transport: Transport,
|
||||
private val protocolDebugName: String,
|
||||
private val onMessage: suspend (JsonRpcMessage) -> Unit
|
||||
) {
|
||||
|
||||
fun start() {
|
||||
scope.launch(CoroutineName("$protocolDebugName.read-messages")) {
|
||||
try {
|
||||
for (message in transport.asMessageChannel()) {
|
||||
try {
|
||||
onMessage(message)
|
||||
} catch (e: Exception) {
|
||||
pumpLogger.error(e) { "Error processing incoming message: $message" }
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
pumpLogger.error(e) { "Error processing incoming messages" }
|
||||
}
|
||||
}
|
||||
transport.start()
|
||||
}
|
||||
}
|
||||
129
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpOutgoingRequestManager.kt
vendored
Normal file
129
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpOutgoingRequestManager.kt
vendored
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.model.AcpMethod
|
||||
import com.agentclientprotocol.model.CancelRequestNotification
|
||||
import com.agentclientprotocol.rpc.*
|
||||
import com.agentclientprotocol.transport.Transport
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.JsonRpcException
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.ProtocolOptions
|
||||
import io.github.oshai.kotlinlogging.KotlinLogging
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonNull
|
||||
import kotlinx.serialization.json.encodeToJsonElement
|
||||
|
||||
private val outgoingLogger = KotlinLogging.logger {}
|
||||
|
||||
internal class AcpOutgoingRequestManager(
|
||||
private val state: AcpProtocolState,
|
||||
private val transport: Transport,
|
||||
private val json: Json,
|
||||
private val options: ProtocolOptions,
|
||||
private val trace: ((String) -> Unit)?
|
||||
) {
|
||||
|
||||
suspend fun sendRequestRaw(
|
||||
method: MethodName,
|
||||
params: JsonElement? = null
|
||||
): JsonElement {
|
||||
val requestId = state.nextRequestId()
|
||||
val deferred = CompletableDeferred<JsonElement>()
|
||||
state.pendingOutgoingRequests[requestId] = deferred
|
||||
|
||||
try {
|
||||
val request = JsonRpcRequest(requestId, method, params)
|
||||
trace?.invoke("-> ${request.traceSummary()}")
|
||||
transport.send(request)
|
||||
return deferred.await()
|
||||
} catch (jsonRpcException: JsonRpcException) {
|
||||
throw convertJsonRpcExceptionIfPossible(jsonRpcException)
|
||||
} catch (ce: CancellationException) {
|
||||
outgoingLogger.trace(ce) {
|
||||
"Request cancelled on this side. Sending CancelRequest notification."
|
||||
}
|
||||
withContext(NonCancellable) {
|
||||
sendCancellationNotification(requestId, ce.message)
|
||||
|
||||
if (!deferred.isCancelled) {
|
||||
waitForGracefulCancellation(requestId, deferred)
|
||||
deferred.cancel()
|
||||
}
|
||||
}
|
||||
throw ce
|
||||
} finally {
|
||||
state.pendingOutgoingRequests.remove(requestId)
|
||||
}
|
||||
}
|
||||
|
||||
fun handleResponse(response: JsonRpcResponse) {
|
||||
val deferred = state.pendingOutgoingRequests.remove(response.id)
|
||||
if (deferred != null) {
|
||||
val responseError = response.error
|
||||
if (responseError != null) {
|
||||
deferred.completeExceptionally(
|
||||
JsonRpcException(responseError.code, responseError.message, responseError.data)
|
||||
)
|
||||
} else {
|
||||
deferred.complete(response.result ?: JsonNull)
|
||||
}
|
||||
} else {
|
||||
outgoingLogger.warn { "Received response for unknown request ID: ${response.id}" }
|
||||
}
|
||||
}
|
||||
|
||||
fun cancelPendingOutgoingRequests(ce: CancellationException? = null) {
|
||||
val requests = state.pendingOutgoingRequests.toMap()
|
||||
state.pendingOutgoingRequests.clear()
|
||||
for ((requestId, deferred) in requests) {
|
||||
outgoingLogger.trace { "Canceling pending outgoing request: $requestId" }
|
||||
deferred.cancel(ce)
|
||||
}
|
||||
}
|
||||
|
||||
private fun sendCancellationNotification(
|
||||
requestId: RequestId,
|
||||
message: String?
|
||||
) {
|
||||
val notification = JsonRpcNotification(
|
||||
method = AcpMethod.MetaMethods.CancelRequest.methodName,
|
||||
params = json.encodeToJsonElement(CancelRequestNotification(requestId, message))
|
||||
)
|
||||
trace?.invoke("-> ${notification.traceSummary()}")
|
||||
transport.send(notification)
|
||||
}
|
||||
|
||||
private suspend fun waitForGracefulCancellation(
|
||||
requestId: RequestId,
|
||||
deferred: CompletableDeferred<JsonElement>
|
||||
) {
|
||||
try {
|
||||
withTimeout(options.gracefulRequestCancellationTimeout) {
|
||||
deferred.await()
|
||||
}
|
||||
} catch (e: TimeoutCancellationException) {
|
||||
outgoingLogger.trace(e) {
|
||||
"Timed out waiting for graceful cancellation response for request: $requestId"
|
||||
}
|
||||
} catch (ce: CancellationException) {
|
||||
outgoingLogger.trace(ce) {
|
||||
"Graceful cancellation response received for request: $requestId"
|
||||
}
|
||||
} catch (e: JsonRpcException) {
|
||||
val convertedException = convertJsonRpcExceptionIfPossible(e)
|
||||
if (convertedException is CancellationException) {
|
||||
outgoingLogger.trace(convertedException) {
|
||||
"Graceful cancellation response received for request: $requestId"
|
||||
}
|
||||
} else {
|
||||
outgoingLogger.warn(convertedException) {
|
||||
"Unexpected error while waiting for graceful cancellation response for request: $requestId"
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
outgoingLogger.warn(e) {
|
||||
"Unexpected error while waiting for graceful cancellation response for request: $requestId"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
102
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolCore.kt
vendored
Normal file
102
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolCore.kt
vendored
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.model.AcpMethod
|
||||
import com.agentclientprotocol.rpc.*
|
||||
import com.agentclientprotocol.transport.Transport
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.ProtocolOptions
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import java.util.concurrent.Executors
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.EmptyCoroutineContext
|
||||
|
||||
internal class AcpProtocolCore(
|
||||
parentScope: CoroutineScope,
|
||||
private val transport: Transport,
|
||||
json: Json,
|
||||
private val options: ProtocolOptions
|
||||
) {
|
||||
private val trace = options.trace
|
||||
private val scope = CoroutineScope(
|
||||
parentScope.coroutineContext
|
||||
+ SupervisorJob(parentScope.coroutineContext[Job])
|
||||
+ CoroutineName(options.protocolDebugName)
|
||||
)
|
||||
private val requestsExecutor = Executors.newSingleThreadExecutor { runnable ->
|
||||
Thread(runnable, "${options.protocolDebugName}.requests").apply {
|
||||
isDaemon = true
|
||||
}
|
||||
}
|
||||
private val requestsDispatcher = requestsExecutor.asCoroutineDispatcher()
|
||||
private val requestsScope = CoroutineScope(
|
||||
scope.coroutineContext
|
||||
+ SupervisorJob(scope.coroutineContext[Job])
|
||||
+ requestsDispatcher
|
||||
+ CoroutineName("${options.protocolDebugName}.requests")
|
||||
)
|
||||
private val state = AcpProtocolState()
|
||||
private val incomingRequests = AcpIncomingRequestManager(state, transport, json, requestsScope, trace)
|
||||
private val outgoingRequests = AcpOutgoingRequestManager(state, transport, json, options, trace)
|
||||
private val messagePump = AcpMessagePump(
|
||||
scope,
|
||||
transport,
|
||||
options.protocolDebugName,
|
||||
onMessage = ::handleIncomingMessage
|
||||
)
|
||||
|
||||
fun start() {
|
||||
incomingRequests.installMetaHandlers()
|
||||
messagePump.start()
|
||||
}
|
||||
|
||||
suspend fun sendRequestRaw(
|
||||
method: MethodName,
|
||||
params: JsonElement? = null
|
||||
): JsonElement {
|
||||
return outgoingRequests.sendRequestRaw(method, params)
|
||||
}
|
||||
|
||||
fun sendNotificationRaw(
|
||||
method: AcpMethod.AcpNotificationMethod<*>,
|
||||
params: JsonElement? = null
|
||||
) {
|
||||
val notification = JsonRpcNotification(method = method.methodName, params = params)
|
||||
trace?.invoke("-> ${notification.traceSummary()}")
|
||||
transport.send(notification)
|
||||
}
|
||||
|
||||
fun setRequestHandlerRaw(
|
||||
method: AcpMethod.AcpRequestResponseMethod<*, *>,
|
||||
additionalContext: CoroutineContext = EmptyCoroutineContext,
|
||||
handler: suspend (JsonRpcRequest) -> JsonElement?
|
||||
) {
|
||||
incomingRequests.setRequestHandlerRaw(method, additionalContext, handler)
|
||||
}
|
||||
|
||||
fun setNotificationHandlerRaw(
|
||||
method: AcpMethod.AcpNotificationMethod<*>,
|
||||
additionalContext: CoroutineContext = EmptyCoroutineContext,
|
||||
handler: suspend (JsonRpcNotification) -> Unit
|
||||
) {
|
||||
incomingRequests.setNotificationHandlerRaw(method, additionalContext, handler)
|
||||
}
|
||||
|
||||
fun close() {
|
||||
transport.close()
|
||||
val message = "Protocol closed"
|
||||
incomingRequests.cancelPendingIncomingRequests(CancellationException(message))
|
||||
outgoingRequests.cancelPendingOutgoingRequests(CancellationException(message))
|
||||
scope.cancel(message)
|
||||
requestsDispatcher.close()
|
||||
}
|
||||
|
||||
private suspend fun handleIncomingMessage(message: JsonRpcMessage) {
|
||||
trace?.invoke("<- ${message.traceSummary()}")
|
||||
when (message) {
|
||||
is JsonRpcNotification -> incomingRequests.handleNotification(message)
|
||||
is JsonRpcRequest -> incomingRequests.handleRequest(message)
|
||||
is JsonRpcResponse -> outgoingRequests.handleResponse(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
23
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolState.kt
vendored
Normal file
23
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolState.kt
vendored
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.rpc.JsonRpcNotification
|
||||
import com.agentclientprotocol.rpc.JsonRpcRequest
|
||||
import com.agentclientprotocol.rpc.MethodName
|
||||
import com.agentclientprotocol.rpc.RequestId
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
internal class AcpProtocolState {
|
||||
private val requestIdCounter = AtomicInteger(0)
|
||||
|
||||
val pendingOutgoingRequests = ConcurrentHashMap<RequestId, CompletableDeferred<JsonElement>>()
|
||||
val pendingIncomingRequests = ConcurrentHashMap<RequestId, Job>()
|
||||
val requestHandlers = ConcurrentHashMap<MethodName, suspend (JsonRpcRequest) -> JsonElement?>()
|
||||
val notificationHandlers =
|
||||
ConcurrentHashMap<MethodName, suspend (JsonRpcNotification) -> Unit>()
|
||||
|
||||
fun nextRequestId(): RequestId = RequestId.Companion.create(requestIdCounter.incrementAndGet())
|
||||
}
|
||||
40
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolSupport.kt
vendored
Normal file
40
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolSupport.kt
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.rpc.JsonRpcErrorCode
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.AcpExpectedError
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.JsonRpcException
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.serialization.SerializationException
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlin.coroutines.AbstractCoroutineContextElement
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
internal class JsonRpcRequestContextElement :
|
||||
AbstractCoroutineContextElement(Key) {
|
||||
object Key : CoroutineContext.Key<JsonRpcRequestContextElement>
|
||||
}
|
||||
|
||||
internal class JsonRpcIncomingRequestCanceledException(
|
||||
message: String,
|
||||
val data: JsonElement? = null
|
||||
) : CancellationException(message)
|
||||
|
||||
internal fun convertJsonRpcExceptionIfPossible(jsonRpcException: JsonRpcException): Exception {
|
||||
return when (jsonRpcException.code) {
|
||||
JsonRpcErrorCode.PARSE_ERROR.code -> SerializationException(
|
||||
jsonRpcException.message,
|
||||
jsonRpcException
|
||||
)
|
||||
|
||||
JsonRpcErrorCode.INVALID_PARAMS.code -> AcpExpectedError(
|
||||
jsonRpcException.message ?: "Invalid params"
|
||||
)
|
||||
|
||||
JsonRpcErrorCode.CANCELLED.code -> CancellationException(
|
||||
jsonRpcException.message ?: "Cancelled on the counterpart side",
|
||||
jsonRpcException
|
||||
)
|
||||
|
||||
else -> jsonRpcException
|
||||
}
|
||||
}
|
||||
32
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolTrace.kt
vendored
Normal file
32
src/main/kotlin/ee/carlrobert/codegpt/agent/external/runtime/AcpProtocolTrace.kt
vendored
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
package ee.carlrobert.codegpt.agent.external.runtime
|
||||
|
||||
import com.agentclientprotocol.rpc.JsonRpcError
|
||||
import com.agentclientprotocol.rpc.JsonRpcMessage
|
||||
import com.agentclientprotocol.rpc.JsonRpcNotification
|
||||
import com.agentclientprotocol.rpc.JsonRpcRequest
|
||||
import com.agentclientprotocol.rpc.JsonRpcResponse
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
|
||||
internal fun JsonRpcMessage.traceSummary(): String {
|
||||
return when (this) {
|
||||
is JsonRpcRequest -> "request id=$id method=$method params=${params.traceSummary()}"
|
||||
is JsonRpcNotification -> "notification method=$method params=${params.traceSummary()}"
|
||||
is JsonRpcResponse -> "response id=$id ${error.traceSummary(result)}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun JsonRpcError?.traceSummary(result: JsonElement?): String {
|
||||
return if (this == null) {
|
||||
"result=${result.traceSummary()}"
|
||||
} else {
|
||||
"errorCode=$code errorMessage=${message.orEmpty().singleLine()} errorData=${data.traceSummary()}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun JsonElement?.traceSummary(limit: Int = 400): String {
|
||||
return this?.toString()?.singleLine()?.take(limit) ?: "null"
|
||||
}
|
||||
|
||||
private fun String.singleLine(): String {
|
||||
return replace(Regex("\\s+"), " ").trim()
|
||||
}
|
||||
|
|
@ -13,8 +13,8 @@ import kotlinx.serialization.Serializable
|
|||
|
||||
class BashOutputTool(
|
||||
workingDirectory: String,
|
||||
hookManager: HookManager,
|
||||
private val sessionId: String,
|
||||
hookManager: HookManager,
|
||||
) :
|
||||
BaseTool<BashOutputTool.Args, BashOutputTool.Result>(
|
||||
workingDirectory = workingDirectory,
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ import java.nio.file.Paths
|
|||
* Enhanced search tool using IntelliJ's native SearchService and FindModel.
|
||||
*/
|
||||
class IntelliJSearchTool(
|
||||
hookManager: HookManager,
|
||||
sessionId: String,
|
||||
private val project: Project,
|
||||
sessionId: String,
|
||||
hookManager: HookManager,
|
||||
) : BaseTool<IntelliJSearchTool.Args, IntelliJSearchTool.Result>(
|
||||
workingDirectory = project.basePath ?: System.getProperty("user.dir"),
|
||||
argsSerializer = Args.serializer(),
|
||||
|
|
|
|||
|
|
@ -527,7 +527,7 @@ private class ExternalSubagentEventsAdapter(
|
|||
delegate.onSubAgentToolCompleted(parentId, childId, toolName, result)
|
||||
}
|
||||
|
||||
override fun onQueuedMessagesResolved() = Unit
|
||||
override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit
|
||||
|
||||
override suspend fun approveToolCall(request: ToolApprovalRequest): Boolean {
|
||||
return when (request.type) {
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ import java.net.URI
|
|||
|
||||
class WebFetchTool(
|
||||
workingDirectory: String,
|
||||
private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)",
|
||||
sessionId: String,
|
||||
hookManager: HookManager,
|
||||
private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)",
|
||||
) : BaseTool<WebFetchTool.Args, WebFetchTool.Result>(
|
||||
workingDirectory = workingDirectory,
|
||||
argsSerializer = Args.serializer(),
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ import java.time.format.DateTimeFormatter
|
|||
|
||||
class WebSearchTool(
|
||||
workingDirectory: String,
|
||||
private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)",
|
||||
sessionId: String,
|
||||
hookManager: HookManager,
|
||||
private val userAgent: String = "Mozilla/5.0 (compatible; ProxyAI/1.0; +https://tryproxy.io)",
|
||||
) : BaseTool<WebSearchTool.Args, WebSearchTool.Result>(
|
||||
workingDirectory = workingDirectory,
|
||||
argsSerializer = Args.serializer(),
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ internal object AgentCompletionRunner : CompletionRunner {
|
|||
}
|
||||
}
|
||||
|
||||
override fun onQueuedMessagesResolved() = Unit
|
||||
override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit
|
||||
}
|
||||
|
||||
val cancellableRequest = CancellableRequest {
|
||||
|
|
@ -137,7 +137,7 @@ internal object AgentCompletionRunner : CompletionRunner {
|
|||
}
|
||||
(msg as? ai.koog.prompt.message.Message.Reasoning)?.let {
|
||||
if (it.content.isNotBlank()) {
|
||||
events.onTextReceived("<think>${it.content}</think>")
|
||||
events.onThinkingReceived(it.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import com.intellij.openapi.diagnostic.Logger
|
|||
import com.intellij.openapi.diagnostic.thisLogger
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import ee.carlrobert.codegpt.agent.external.ExternalAcpAgents
|
||||
import ee.carlrobert.codegpt.settings.agents.SubagentDefaults
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookConfiguration
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package ee.carlrobert.codegpt.toolwindow.agent
|
|||
|
||||
import ai.koog.agents.core.agent.exception.AIAgentStuckInTheNodeException
|
||||
import ai.koog.http.client.KoogHttpClientException
|
||||
import com.agentclientprotocol.model.PlanEntry
|
||||
import com.agentclientprotocol.model.PlanEntryStatus
|
||||
import com.intellij.openapi.Disposable
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
|
|
@ -18,6 +20,7 @@ import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
|||
import ee.carlrobert.codegpt.agent.rollback.RollbackService
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import ee.carlrobert.codegpt.settings.agents.SubagentRuntimeResolver
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTApiException
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.*
|
||||
|
|
@ -297,6 +300,12 @@ class AgentEventHandler(
|
|||
} else {
|
||||
request
|
||||
}
|
||||
logger.debug(
|
||||
"Enqueueing agent approval for session=$sessionId type=${resolvedRequest.type} title=${resolvedRequest.title.logPreview()} payload=${resolvedRequest.payload.logSummary()}"
|
||||
)
|
||||
approvalTrace(
|
||||
"approval/enqueue session=$sessionId type=${resolvedRequest.type} title=${resolvedRequest.title.logPreview()} payload=${resolvedRequest.payload.logSummary()}"
|
||||
)
|
||||
runInEdt {
|
||||
approvalQueue.addLast(ApprovalRequest(resolvedRequest, deferred))
|
||||
maybeShowNextApproval()
|
||||
|
|
@ -318,6 +327,12 @@ class AgentEventHandler(
|
|||
}
|
||||
|
||||
val decision = CompletableDeferred<Boolean>()
|
||||
logger.debug(
|
||||
"Enqueueing agent approval for session=$sessionId type=${request.type} title=${request.title.logPreview()} payload=${request.payload.logSummary()}"
|
||||
)
|
||||
approvalTrace(
|
||||
"approval/enqueue session=$sessionId type=${request.type} title=${request.title.logPreview()} payload=${request.payload.logSummary()}"
|
||||
)
|
||||
runInEdt {
|
||||
approvalQueue.addLast(ApprovalRequest(request, decision))
|
||||
maybeShowNextApproval()
|
||||
|
|
@ -344,6 +359,23 @@ class AgentEventHandler(
|
|||
}
|
||||
}
|
||||
|
||||
override fun onThinkingReceived(text: String) {
|
||||
runInEdt {
|
||||
currentResponseBody?.appendThinking(text)
|
||||
scrollablePanel.update()
|
||||
scrollablePanel.scrollToBottom()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onPlanUpdated(entries: List<PlanEntry>) {
|
||||
runInEdt {
|
||||
todoListPanel.updateTodos(entries.toTodoItems())
|
||||
todoListPanel.isVisible = entries.isNotEmpty()
|
||||
scrollablePanel.update()
|
||||
scrollablePanel.scrollToBottom()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onToolStarting(id: String, toolName: String, args: Any?) {
|
||||
when (args) {
|
||||
is TodoWriteTool.Args -> {
|
||||
|
|
@ -511,10 +543,29 @@ class AgentEventHandler(
|
|||
}
|
||||
}
|
||||
|
||||
override fun onQueuedMessagesResolved() {
|
||||
val pendingMessage = project.service<AgentService>()
|
||||
.getPendingMessages(sessionId)
|
||||
.firstOrNull { it.uiVisible } ?: return
|
||||
private fun List<PlanEntry>.toTodoItems(): List<TodoWriteTool.TodoItem> {
|
||||
return map { entry ->
|
||||
TodoWriteTool.TodoItem(
|
||||
content = entry.content,
|
||||
status = when (entry.status) {
|
||||
PlanEntryStatus.PENDING -> TodoWriteTool.TodoStatus.PENDING
|
||||
PlanEntryStatus.IN_PROGRESS -> TodoWriteTool.TodoStatus.IN_PROGRESS
|
||||
PlanEntryStatus.COMPLETED -> TodoWriteTool.TodoStatus.COMPLETED
|
||||
},
|
||||
activeForm = entry.content
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onQueuedMessagesResolved(message: MessageWithContext?) {
|
||||
val pendingMessage = message
|
||||
?: project.service<AgentService>()
|
||||
.getPendingMessages(sessionId)
|
||||
.firstOrNull { it.uiVisible }
|
||||
?: return
|
||||
logger.debug(
|
||||
"Resolving queued message in UI for session=$sessionId messageId=${pendingMessage.id} uiVisible=${pendingMessage.uiVisible} preview=${pendingMessage.text.logPreview()}"
|
||||
)
|
||||
onQueuedMessagesResolved(pendingMessage)
|
||||
}
|
||||
|
||||
|
|
@ -525,11 +576,48 @@ class AgentEventHandler(
|
|||
}
|
||||
|
||||
override fun onTokenUsageAvailable(tokenUsage: Long) {
|
||||
lastReportedPromptTokens = tokenUsage
|
||||
onUsageAvailable(AgentUsageEvent(usedTokens = tokenUsage))
|
||||
}
|
||||
|
||||
val event = TokenUsageEvent(sessionId, tokenUsage)
|
||||
override fun onUsageAvailable(event: AgentUsageEvent) {
|
||||
lastReportedPromptTokens = event.usedTokens
|
||||
project.messageBus.syncPublisher(TokenUsageListener.TOKEN_USAGE_TOPIC)
|
||||
.onTokenUsageChanged(event)
|
||||
.onTokenUsageChanged(
|
||||
TokenUsageEvent(
|
||||
sessionId = sessionId,
|
||||
totalTokens = event.usedTokens,
|
||||
sizeTokens = event.sizeTokens,
|
||||
costAmount = event.costAmount,
|
||||
costCurrency = event.costCurrency
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
override fun onRuntimeOptionsUpdated() {
|
||||
runInEdt {
|
||||
userInputPanel.refreshModelDependentState()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onSessionInfoUpdated(title: String?, updatedAt: String?) {
|
||||
val normalizedTitle = title?.trim().orEmpty()
|
||||
if (normalizedTitle.isEmpty()) {
|
||||
return
|
||||
}
|
||||
|
||||
val contentManager = project.service<AgentToolWindowContentManager>()
|
||||
val session = contentManager.getSession(sessionId) ?: return
|
||||
val previousExternalTitle = session.externalAgentSessionTitle
|
||||
session.externalAgentSessionTitle = normalizedTitle
|
||||
|
||||
val shouldRename = session.displayName.isBlank() ||
|
||||
session.displayName == previousExternalTitle ||
|
||||
session.displayName.matches(Regex("""Agent \d+( \(\d+\))?"""))
|
||||
|
||||
if (shouldRename) {
|
||||
project.messageBus.syncPublisher(AgentTabTitleNotifier.AGENT_TAB_TITLE_TOPIC)
|
||||
.updateTabTitle(sessionId, normalizedTitle)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onCreditsAvailable(event: AgentCreditsEvent) {
|
||||
|
|
@ -587,6 +675,12 @@ class AgentEventHandler(
|
|||
|
||||
val contentManager = project.service<AgentToolWindowContentManager>()
|
||||
if (contentManager.isSessionAutoApproved(sessionId)) {
|
||||
approvalTrace(
|
||||
"approval/auto-approve session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}"
|
||||
)
|
||||
logger.debug(
|
||||
"Auto-approving queued agent approval for session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}"
|
||||
)
|
||||
next.deferred.complete(true)
|
||||
currentApproval = null
|
||||
maybeShowNextApproval()
|
||||
|
|
@ -615,6 +709,13 @@ class AgentEventHandler(
|
|||
updateEditToolCardPreview(next.model)
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
"Showing agent approval for session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()} payload=${next.model.payload.logSummary()} queueRemaining=${approvalQueue.size}"
|
||||
)
|
||||
approvalTrace(
|
||||
"approval/show session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()} payload=${next.model.payload.logSummary()} queueRemaining=${approvalQueue.size}"
|
||||
)
|
||||
|
||||
runCatching {
|
||||
project.service<AgentToolWindowContentManager>()
|
||||
.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.APPROVAL)
|
||||
|
|
@ -629,6 +730,12 @@ class AgentEventHandler(
|
|||
.markSessionAsAutoApproved(sessionId)
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
"Approved agent approval for session=$sessionId type=${next.model.type} auto=$auto title=${next.model.title.logPreview()}"
|
||||
)
|
||||
approvalTrace(
|
||||
"approval/approved session=$sessionId type=${next.model.type} auto=$auto title=${next.model.title.logPreview()}"
|
||||
)
|
||||
next.deferred.complete(true)
|
||||
currentApproval = null
|
||||
clearApprovalContainer()
|
||||
|
|
@ -639,6 +746,12 @@ class AgentEventHandler(
|
|||
maybeShowNextApproval()
|
||||
},
|
||||
onReject = {
|
||||
logger.debug(
|
||||
"Rejected agent approval for session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}"
|
||||
)
|
||||
approvalTrace(
|
||||
"approval/rejected session=$sessionId type=${next.model.type} title=${next.model.title.logPreview()}"
|
||||
)
|
||||
next.deferred.complete(false)
|
||||
currentApproval = null
|
||||
clearApprovalContainer()
|
||||
|
|
@ -758,6 +871,28 @@ class AgentEventHandler(
|
|||
return viewHolder
|
||||
}
|
||||
|
||||
private fun String.logPreview(limit: Int = 120): String {
|
||||
return replace('\n', ' ')
|
||||
.replace(Regex("\\s+"), " ")
|
||||
.trim()
|
||||
.take(limit)
|
||||
}
|
||||
|
||||
private fun ToolApprovalPayload?.logSummary(): String {
|
||||
return when (this) {
|
||||
is WritePayload -> "write:${filePath.logPreview(80)}"
|
||||
is EditPayload -> "edit:${filePath.logPreview(80)}"
|
||||
is BashPayload -> "bash:${command.logPreview(80)}"
|
||||
null -> "none"
|
||||
}
|
||||
}
|
||||
|
||||
private fun approvalTrace(message: String) {
|
||||
if (ConfigurationSettings.getState().debugModeEnabled) {
|
||||
logger.info("[APPROVAL TRACE] $message")
|
||||
}
|
||||
}
|
||||
|
||||
class RunViewHolder(
|
||||
private val vm: AgentRunViewModel,
|
||||
private val view: AgentRunDslPanel,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import com.agentclientprotocol.model.AvailableCommand
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
|
||||
data class AcpConfigOptionChoice(
|
||||
val value: String,
|
||||
|
|
@ -25,7 +27,7 @@ object AcpConfigOptions {
|
|||
fun selectable(options: List<AcpConfigOption>): List<AcpConfigOption> {
|
||||
return options
|
||||
.asSequence()
|
||||
.filter { it.type == "select" && it.options.isNotEmpty() }
|
||||
.filter { it.type in setOf("select", "boolean") && it.options.isNotEmpty() }
|
||||
.sortedBy { categoryOrder(it.category) }
|
||||
.toList()
|
||||
}
|
||||
|
|
@ -40,7 +42,9 @@ object AcpConfigOptions {
|
|||
}
|
||||
|
||||
fun selectedValueName(options: List<AcpConfigOption>, category: String): String? {
|
||||
val option = selectable(options).firstOrNull { it.category == category } ?: return null
|
||||
val option = selectable(options).firstOrNull {
|
||||
it.category == category || it.id == category
|
||||
} ?: return null
|
||||
return option.options
|
||||
.firstOrNull { it.value == option.currentValue }
|
||||
?.name
|
||||
|
|
@ -73,6 +77,24 @@ object AcpConfigOptions {
|
|||
.toMap(linkedMapOf())
|
||||
}
|
||||
|
||||
fun summaryParts(
|
||||
options: List<AcpConfigOption>,
|
||||
maxEntries: Int = 3
|
||||
): List<String> {
|
||||
return selectable(options)
|
||||
.mapNotNull { option ->
|
||||
selectedValueName(options, option.category ?: option.id)?.let { value ->
|
||||
if (option.category in setOf("model", "mode", "thought_level")) {
|
||||
value
|
||||
} else {
|
||||
"${label(option)}: $value"
|
||||
}
|
||||
}
|
||||
}
|
||||
.distinct()
|
||||
.take(maxEntries)
|
||||
}
|
||||
|
||||
private fun categoryOrder(category: String?): Int {
|
||||
return when (category.orEmpty()) {
|
||||
"model" -> 0
|
||||
|
|
@ -95,6 +117,7 @@ data class AgentSession(
|
|||
var modelCode: String? = null,
|
||||
var externalAgentId: String? = null,
|
||||
var externalAgentSessionId: String? = null,
|
||||
var externalAgentMcpServerIds: Set<String> = emptySet(),
|
||||
var externalAgentConfigOptions: List<AcpConfigOption> = emptyList(),
|
||||
var externalAgentConfigSelections: Map<String, String> = emptyMap(),
|
||||
var externalAgentConfigLoading: Boolean = false,
|
||||
|
|
@ -105,4 +128,14 @@ data class AgentSession(
|
|||
var lastActiveAt: Long = System.currentTimeMillis()
|
||||
) {
|
||||
var externalAgentErrorMessage: String? = null
|
||||
var externalAgentPeerProfileId: String? = null
|
||||
var externalAgentAvailableCommands: List<AvailableCommand> = emptyList()
|
||||
var externalAgentVendorMeta: JsonElement? = null
|
||||
var externalAgentRequestMeta: JsonElement? = null
|
||||
var externalAgentSessionTitle: String? = null
|
||||
|
||||
fun shouldRecreateExternalAgentSession(selectedMcpServerIds: Set<String>): Boolean {
|
||||
return !externalAgentSessionId.isNullOrBlank() &&
|
||||
externalAgentMcpServerIds != selectedMcpServerIds
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class AgentToolWindowPanel(
|
|||
private val tabbedPane = contentManager.initializeTabbedPane()
|
||||
private val centerLayout = CardLayout()
|
||||
private val centerPanel = JPanel(centerLayout)
|
||||
private val creditsLabel = AgentCreditsToolbarLabel(project, ::currentSession)
|
||||
private var landingPanel: AgentToolWindowTabPanel? = null
|
||||
|
||||
init {
|
||||
|
|
@ -75,7 +76,6 @@ class AgentToolWindowPanel(
|
|||
val toolbar = ActionManager.getInstance()
|
||||
.createActionToolbar("AgentToolWindow", actionGroup, true)
|
||||
toolbar.targetComponent = tabbedPane
|
||||
val creditsLabel = AgentCreditsToolbarLabel(project)
|
||||
Disposer.register(this, creditsLabel)
|
||||
return BorderLayoutPanel()
|
||||
.addToLeft(toolbar.component)
|
||||
|
|
@ -87,6 +87,7 @@ class AgentToolWindowPanel(
|
|||
private fun showTabsView() {
|
||||
disposeLandingPanel()
|
||||
centerLayout.show(centerPanel, TABS_CARD)
|
||||
creditsLabel.refresh()
|
||||
}
|
||||
|
||||
private fun showLandingView() {
|
||||
|
|
@ -97,6 +98,7 @@ class AgentToolWindowPanel(
|
|||
landingPanel?.requestFocusForTextArea()
|
||||
centerPanel.revalidate()
|
||||
centerPanel.repaint()
|
||||
creditsLabel.refresh()
|
||||
}
|
||||
|
||||
private fun createLandingPanel(): AgentToolWindowTabPanel {
|
||||
|
|
@ -122,6 +124,11 @@ class AgentToolWindowPanel(
|
|||
landingPanel = null
|
||||
}
|
||||
|
||||
private fun currentSession(): AgentSession? {
|
||||
return landingPanel?.getAgentSession()
|
||||
?: contentManager.getActiveTabPanel()?.getAgentSession()
|
||||
}
|
||||
|
||||
override fun dispose() {
|
||||
tabbedPane.setTabLifecycleCallbacks(onTabsOpened = {}, onAllTabsClosed = {})
|
||||
disposeLandingPanel()
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ class AgentToolWindowTabPanel(
|
|||
).apply {
|
||||
isVisible = false
|
||||
}
|
||||
private val tokenUsageCounterPanel = TokenUsageCounterPanel(project, sessionId)
|
||||
|
||||
private val userInputPanel = UserInputPanel(
|
||||
project,
|
||||
|
|
@ -128,8 +129,8 @@ class AgentToolWindowTabPanel(
|
|||
onSubmit = ::handleSubmit,
|
||||
onStop = ::handleCancel,
|
||||
withRemovableSelectedEditorTag = true,
|
||||
agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId),
|
||||
agentTokenCounterVisibilityProvider = { agentSession.externalAgentId.isNullOrBlank() },
|
||||
agentTokenCounterPanel = tokenUsageCounterPanel,
|
||||
agentTokenCounterVisibilityProvider = { tokenUsageCounterPanel.hasReportedUsage() },
|
||||
sessionIdProvider = { sessionId },
|
||||
conversationIdProvider = { conversation.id },
|
||||
onStartSessionTimeline = ::showSessionStartTimelineDialog,
|
||||
|
|
@ -300,8 +301,7 @@ class AgentToolWindowTabPanel(
|
|||
if (message.text.isBlank()) return
|
||||
disposeLandingPanelIfPresent()
|
||||
scrollablePanel.clearLandingViewIfVisible()
|
||||
val modelSettings = ModelSettings.getInstance()
|
||||
val agentModelSelection = modelSettings.getModelSelectionForFeature(FeatureType.AGENT)
|
||||
val agentModelSelection = service<ModelSettings>().getModelSelectionForFeature(FeatureType.AGENT)
|
||||
agentSession.serviceType = agentModelSelection.provider
|
||||
agentSession.modelCode = agentModelSelection.selectionId
|
||||
|
||||
|
|
@ -409,10 +409,13 @@ class AgentToolWindowTabPanel(
|
|||
project.service<ExternalAcpAgentService>().closeSession(sessionId)
|
||||
agentSession.externalAgentId = externalAgentId
|
||||
agentSession.externalAgentSessionId = null
|
||||
agentSession.externalAgentMcpServerIds = emptySet()
|
||||
agentSession.externalAgentConfigOptions = emptyList()
|
||||
agentSession.externalAgentConfigSelections = emptyMap()
|
||||
agentSession.externalAgentErrorMessage = null
|
||||
agentSession.externalAgentConfigLoading = !externalAgentId.isNullOrBlank()
|
||||
project.messageBus.syncPublisher(AgentUiStateNotifier.AGENT_UI_STATE_TOPIC)
|
||||
.sessionRuntimeChanged(sessionId)
|
||||
if (!externalAgentId.isNullOrBlank()) {
|
||||
backgroundScope.launch {
|
||||
runCatching {
|
||||
|
|
@ -474,15 +477,18 @@ class AgentToolWindowTabPanel(
|
|||
withContext(Dispatchers.EDT) {
|
||||
inputPanel.refreshModelDependentState()
|
||||
}
|
||||
runCatching {
|
||||
project.service<ExternalAcpAgentService>()
|
||||
.setSessionConfigOption(agentSession, optionId, value)
|
||||
}.onFailure { ex ->
|
||||
try {
|
||||
runCatching {
|
||||
project.service<ExternalAcpAgentService>()
|
||||
.setSessionConfigOption(agentSession, optionId, value)
|
||||
}.onFailure { ex ->
|
||||
OverlayUtil.showNotification(
|
||||
"${displayExternalAgentName(agentSession.externalAgentId ?: "agent")} option update failed. ${buildExternalAgentConfigFailureMessage(ex)}",
|
||||
NotificationType.ERROR
|
||||
)
|
||||
}
|
||||
} finally {
|
||||
agentSession.externalAgentConfigLoading = false
|
||||
OverlayUtil.showNotification(
|
||||
"${displayExternalAgentName(agentSession.externalAgentId ?: "agent")} option update failed. ${buildExternalAgentConfigFailureMessage(ex)}",
|
||||
NotificationType.ERROR
|
||||
)
|
||||
}
|
||||
withContext(Dispatchers.EDT) {
|
||||
inputPanel.refreshModelDependentState()
|
||||
|
|
|
|||
|
|
@ -48,7 +48,11 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
init {
|
||||
tabComponentInsets = null
|
||||
setComponentPopupMenu(TabPopupMenu())
|
||||
addChangeListener { refreshTabState() }
|
||||
addChangeListener {
|
||||
refreshTabState()
|
||||
project.messageBus.syncPublisher(AgentUiStateNotifier.AGENT_UI_STATE_TOPIC)
|
||||
.activeSessionChanged()
|
||||
}
|
||||
}
|
||||
|
||||
enum class TabStatus {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import com.intellij.util.messages.Topic
|
||||
|
||||
interface AgentUiContextListener {
|
||||
fun onUiContextChanged()
|
||||
|
||||
companion object {
|
||||
val AGENT_UI_CONTEXT_TOPIC = Topic("Agent UI Context Changes", AgentUiContextListener::class.java)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import com.intellij.util.messages.Topic
|
||||
|
||||
interface AgentUiStateNotifier {
|
||||
|
||||
fun activeSessionChanged()
|
||||
|
||||
fun sessionRuntimeChanged(sessionId: String)
|
||||
|
||||
companion object {
|
||||
@JvmField
|
||||
val AGENT_UI_STATE_TOPIC: Topic<AgentUiStateNotifier> =
|
||||
Topic.create("agentUiState", AgentUiStateNotifier::class.java)
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,9 @@ import com.intellij.util.messages.Topic
|
|||
data class TokenUsageEvent(
|
||||
val sessionId: String,
|
||||
val totalTokens: Long,
|
||||
val sizeTokens: Long? = null,
|
||||
val costAmount: Double? = null,
|
||||
val costCurrency: String? = null,
|
||||
)
|
||||
|
||||
interface TokenUsageListener {
|
||||
|
|
@ -13,4 +16,4 @@ interface TokenUsageListener {
|
|||
companion object {
|
||||
val TOKEN_USAGE_TOPIC = Topic("Token Usage Changes", TokenUsageListener::class.java)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,17 +10,23 @@ import com.intellij.util.ui.JBUI
|
|||
import ee.carlrobert.codegpt.CodeGPTBundle
|
||||
import ee.carlrobert.codegpt.CodeGPTKeys
|
||||
import ee.carlrobert.codegpt.settings.models.ModelSettings
|
||||
import ee.carlrobert.codegpt.settings.service.*
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelChangeNotifier
|
||||
import ee.carlrobert.codegpt.settings.service.ModelChangeNotifierAdapter
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTService
|
||||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTUserDetails
|
||||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTUserDetailsNotifier
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentSession
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsListener
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentUiStateNotifier
|
||||
import java.text.NumberFormat
|
||||
import java.util.*
|
||||
|
||||
class AgentCreditsToolbarLabel(
|
||||
private val project: Project
|
||||
private val project: Project,
|
||||
private val sessionProvider: () -> AgentSession?
|
||||
) : JBLabel(), Disposable {
|
||||
|
||||
private val numberFormat = NumberFormat.getNumberInstance(Locale.US).apply {
|
||||
|
|
@ -67,13 +73,31 @@ class AgentCreditsToolbarLabel(
|
|||
}
|
||||
}
|
||||
)
|
||||
messageBusConnection.subscribe(
|
||||
AgentUiStateNotifier.AGENT_UI_STATE_TOPIC,
|
||||
object : AgentUiStateNotifier {
|
||||
override fun activeSessionChanged() {
|
||||
updateDisplay()
|
||||
}
|
||||
|
||||
override fun sessionRuntimeChanged(sessionId: String) {
|
||||
updateDisplay()
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fun refresh() {
|
||||
updateDisplay()
|
||||
}
|
||||
|
||||
private fun updateDisplay() {
|
||||
ApplicationManager.getApplication().invokeLater {
|
||||
val activeSession = sessionProvider()
|
||||
val provider = ModelSettings.getInstance()
|
||||
.getServiceForFeature(FeatureType.AGENT)
|
||||
if (provider != ServiceType.PROXYAI) {
|
||||
val isExternalAgentSelected = !activeSession?.externalAgentId.isNullOrBlank()
|
||||
if (provider != ServiceType.PROXYAI || isExternalAgentSelected) {
|
||||
text = null
|
||||
toolTipText = null
|
||||
isVisible = false
|
||||
|
|
|
|||
|
|
@ -316,6 +316,7 @@ class AgentModelComboBoxAction(
|
|||
override fun actionPerformed(event: AnActionEvent) {
|
||||
agentSession.externalAgentId = preset.id
|
||||
agentSession.externalAgentSessionId = null
|
||||
agentSession.externalAgentMcpServerIds = emptySet()
|
||||
onAgentRuntimeChanged(preset.id)
|
||||
updateExternalAgentPresentation(preset.id)
|
||||
onModelChange(modelSettings.getServiceForFeature(FeatureType.AGENT))
|
||||
|
|
@ -484,6 +485,7 @@ class AgentModelComboBoxAction(
|
|||
private fun clearExternalAgentSelection() {
|
||||
agentSession.externalAgentId = null
|
||||
agentSession.externalAgentSessionId = null
|
||||
agentSession.externalAgentMcpServerIds = emptySet()
|
||||
onAgentRuntimeChanged(null)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -144,13 +144,7 @@ class AgentRuntimeOptionsComboBoxAction(
|
|||
?.takeIf(String::isNotBlank)
|
||||
?.let { return it }
|
||||
|
||||
val parts = listOfNotNull(
|
||||
AcpConfigOptions.selectedValueName(agentSession.externalAgentConfigOptions, "model"),
|
||||
AcpConfigOptions.selectedValueName(agentSession.externalAgentConfigOptions, "thought_level")
|
||||
).ifEmpty {
|
||||
listOfNotNull(AcpConfigOptions.selectedValueName(agentSession.externalAgentConfigOptions, "mode"))
|
||||
}
|
||||
|
||||
val parts = AcpConfigOptions.summaryParts(agentSession.externalAgentConfigOptions)
|
||||
return parts.takeIf { it.isNotEmpty() }?.joinToString(" · ") ?: "Options"
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import com.intellij.ui.JBColor
|
|||
import com.intellij.util.ui.JBUI
|
||||
import com.intellij.util.ui.components.BorderLayoutPanel
|
||||
import ee.carlrobert.codegpt.Icons
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpBashPreviewArgs
|
||||
import ee.carlrobert.codegpt.agent.external.events.AcpSearchPreviewArgs
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.AgentUiConfig
|
||||
|
|
@ -66,11 +68,11 @@ object ToolCallDescriptorFactory {
|
|||
|
||||
private fun detectToolKind(toolName: String, args: Any, result: Any?): ToolKind {
|
||||
return when {
|
||||
toolName == "IntelliJSearch" || args is IntelliJSearchTool.Args -> ToolKind.SEARCH
|
||||
toolName == "IntelliJSearch" || args is IntelliJSearchTool.Args || args is AcpSearchPreviewArgs -> ToolKind.SEARCH
|
||||
toolName == "Read" || args is ReadTool.Args -> ToolKind.READ
|
||||
toolName == "Write" || args is WriteTool.Args -> ToolKind.WRITE
|
||||
toolName == "Edit" || args is EditTool.Args -> ToolKind.EDIT
|
||||
toolName == "Bash" || args is BashTool.Args -> ToolKind.BASH
|
||||
toolName == "Bash" || args is BashTool.Args || args is AcpBashPreviewArgs -> ToolKind.BASH
|
||||
toolName == "BashOutput" || args is BashOutputTool.Args -> ToolKind.BASH_OUTPUT
|
||||
toolName == "KillShell" || args is KillShellTool.Args -> ToolKind.KILL_SHELL
|
||||
toolName == "WebSearch" || args is WebSearchTool.Args || result is WebSearchTool.Result -> ToolKind.WEB
|
||||
|
|
@ -465,19 +467,24 @@ object ToolCallDescriptorFactory {
|
|||
): ToolCallDescriptor {
|
||||
val command = when (args) {
|
||||
is BashTool.Args -> args.command
|
||||
is AcpBashPreviewArgs -> args.command ?: args.title
|
||||
is BashOutputTool.Args -> args.bashId
|
||||
is KillShellTool.Args -> "kill_shell"
|
||||
else -> "Unknown"
|
||||
}
|
||||
val isGenericBashPreview = args is AcpBashPreviewArgs &&
|
||||
args.command == null &&
|
||||
args.title.equals("Run shell command", ignoreCase = true)
|
||||
val titleMain = if (isGenericBashPreview) "Pending command" else truncateCommand(command)
|
||||
val tooltip = if (isGenericBashPreview) "Command pending approval" else "Command: $command"
|
||||
|
||||
val truncatedCommand = truncateCommand(command)
|
||||
|
||||
return ToolCallDescriptor(
|
||||
kind = ToolKind.BASH,
|
||||
icon = AllIcons.Nodes.Console,
|
||||
titlePrefix = "Bash:",
|
||||
titleMain = truncatedCommand,
|
||||
tooltip = "Command: $command",
|
||||
titleMain = titleMain,
|
||||
tooltip = tooltip,
|
||||
supportsStreaming = true,
|
||||
args = args,
|
||||
result = result,
|
||||
|
|
@ -504,16 +511,29 @@ object ToolCallDescriptorFactory {
|
|||
projectId: String?
|
||||
): ToolCallDescriptor {
|
||||
val searchArgs = args as? IntelliJSearchTool.Args
|
||||
val pattern = searchArgs?.pattern ?: ""
|
||||
val scopeOrPath = searchArgs?.path?.substringAfterLast('/') ?: (searchArgs?.scope ?: "")
|
||||
val titleMain = buildSearchDisplay(truncatePattern(pattern), scopeOrPath)
|
||||
val searchPreviewArgs = args as? AcpSearchPreviewArgs
|
||||
val pattern = searchArgs?.pattern ?: searchPreviewArgs?.pattern.orEmpty()
|
||||
val scopeOrPath = searchArgs?.path?.substringAfterLast('/')
|
||||
?: searchPreviewArgs?.path?.substringAfterLast('/')
|
||||
?: (searchArgs?.scope ?: "")
|
||||
val titleMain = if (pattern.isBlank()) {
|
||||
scopeOrPath.ifBlank {
|
||||
searchPreviewArgs?.title?.takeIf {
|
||||
it.isNotBlank() && !it.equals("search", ignoreCase = true)
|
||||
} ?: "Search"
|
||||
}
|
||||
} else {
|
||||
buildSearchDisplay(truncatePattern(pattern), scopeOrPath)
|
||||
}
|
||||
|
||||
return ToolCallDescriptor(
|
||||
kind = ToolKind.SEARCH,
|
||||
icon = AllIcons.Actions.Search,
|
||||
titlePrefix = "Search:",
|
||||
titleMain = titleMain,
|
||||
tooltip = if (scopeOrPath.isBlank()) {
|
||||
tooltip = if (pattern.isBlank()) {
|
||||
searchPreviewArgs?.path?.let { "Search in $it" } ?: "Search"
|
||||
} else if (scopeOrPath.isBlank()) {
|
||||
"Search: \"$pattern\""
|
||||
} else {
|
||||
"Search: \"$pattern\" in $scopeOrPath"
|
||||
|
|
|
|||
|
|
@ -40,6 +40,10 @@ class TokenUsageCounterPanel(
|
|||
private var messageBusConnection: MessageBusConnection? = null
|
||||
private var lastEstimatedPromptTokens: Long = 0L
|
||||
private var calibrationOffset: Long = 0L
|
||||
private var lastReportedSizeTokens: Long? = null
|
||||
private var lastReportedCostAmount: Double? = null
|
||||
private var lastReportedCostCurrency: String? = null
|
||||
private var hasReportedUsage: Boolean = false
|
||||
|
||||
init {
|
||||
Disposer.register(ApplicationManager.getApplication(), scope)
|
||||
|
|
@ -47,7 +51,7 @@ class TokenUsageCounterPanel(
|
|||
isOpaque = false
|
||||
font = JBFont.small()
|
||||
border = JBUI.Borders.empty(5, 7)
|
||||
text = "100% context left"
|
||||
isVisible = false
|
||||
setupMessageBusConnection()
|
||||
}
|
||||
|
||||
|
|
@ -61,7 +65,11 @@ class TokenUsageCounterPanel(
|
|||
if (event.sessionId == sessionId) {
|
||||
val model = getAgentModelForSession(event.sessionId)
|
||||
?: getSelectedAgentModel()
|
||||
?: return
|
||||
hasReportedUsage = true
|
||||
isVisible = true
|
||||
lastReportedSizeTokens = event.sizeTokens?.takeIf { it > 0L }
|
||||
lastReportedCostAmount = event.costAmount
|
||||
lastReportedCostCurrency = event.costCurrency?.takeIf { it.isNotBlank() }
|
||||
calibrationOffset = event.totalTokens - lastEstimatedPromptTokens
|
||||
updateDisplay(event.totalTokens, model)
|
||||
}
|
||||
|
|
@ -71,13 +79,18 @@ class TokenUsageCounterPanel(
|
|||
}
|
||||
|
||||
fun updateFromTotalTokens(totalTokens: Long) {
|
||||
if (!hasReportedUsage) {
|
||||
return
|
||||
}
|
||||
lastEstimatedPromptTokens = totalTokens
|
||||
val model = getSelectedAgentModel() ?: return
|
||||
val model = getSelectedAgentModel()
|
||||
val calibratedTotal = (totalTokens + calibrationOffset).coerceAtLeast(0L)
|
||||
updateDisplay(calibratedTotal, model)
|
||||
}
|
||||
|
||||
private fun updateDisplay(totalTokens: Long, model: LLModel) {
|
||||
fun hasReportedUsage(): Boolean = hasReportedUsage
|
||||
|
||||
private fun updateDisplay(totalTokens: Long, model: LLModel?) {
|
||||
scope.launch {
|
||||
withEdt {
|
||||
updateColorAndText(totalTokens, model)
|
||||
|
|
@ -88,10 +101,14 @@ class TokenUsageCounterPanel(
|
|||
}
|
||||
}
|
||||
|
||||
private fun updateColorAndText(usedPromptTokens: Long, model: LLModel) {
|
||||
val budget = computeBudget(model)
|
||||
private fun updateColorAndText(usedPromptTokens: Long, model: LLModel?) {
|
||||
val effectiveCapacity = (
|
||||
lastReportedSizeTokens
|
||||
?: model?.let { computeBudget(it).inputBudget }
|
||||
?: return
|
||||
).coerceAtLeast(1L)
|
||||
val percentageUsed =
|
||||
((usedPromptTokens.coerceAtLeast(0).toDouble() / budget.inputBudget.toDouble()) * 100.0)
|
||||
((usedPromptTokens.coerceAtLeast(0).toDouble() / effectiveCapacity.toDouble()) * 100.0)
|
||||
.coerceIn(0.0, 100.0)
|
||||
val percentageLeft = (100.0 - percentageUsed).coerceIn(0.0, 100.0)
|
||||
|
||||
|
|
@ -105,14 +122,33 @@ class TokenUsageCounterPanel(
|
|||
text = "${percentageLeft.toInt()}% context left"
|
||||
}
|
||||
|
||||
private fun updateTooltipText(totalTokens: Long, model: LLModel) {
|
||||
val budget = computeBudget(model)
|
||||
private fun updateTooltipText(totalTokens: Long, model: LLModel?) {
|
||||
toolTipText = buildString {
|
||||
append("<html><body>")
|
||||
append("<b>Usage Details</b><br>")
|
||||
append("Input size: ${numberFormat.format(totalTokens)} tokens<br>")
|
||||
append("Max output size: ${numberFormat.format(budget.reservedOutput)} tokens<br>")
|
||||
append("Max context size: ${numberFormat.format(budget.contextLength)} tokens<br>")
|
||||
append("Current usage: ${numberFormat.format(totalTokens)} tokens<br>")
|
||||
lastReportedSizeTokens?.let {
|
||||
append("Reported context size: ${numberFormat.format(it)} tokens<br>")
|
||||
append(
|
||||
"Context remaining: ${
|
||||
numberFormat.format((it - totalTokens).coerceAtLeast(0L))
|
||||
} tokens<br>"
|
||||
)
|
||||
}
|
||||
if (lastReportedSizeTokens == null) {
|
||||
model?.let {
|
||||
val budget = computeBudget(it)
|
||||
append("Max output size: ${numberFormat.format(budget.reservedOutput)} tokens<br>")
|
||||
append("Max context size: ${numberFormat.format(budget.contextLength)} tokens<br>")
|
||||
}
|
||||
}
|
||||
if (lastReportedCostAmount != null && lastReportedCostCurrency != null) {
|
||||
append(
|
||||
"Reported cost: ${
|
||||
numberFormat.format(lastReportedCostAmount)
|
||||
} ${lastReportedCostCurrency}<br>"
|
||||
)
|
||||
}
|
||||
append("</body></html>")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package ee.carlrobert.codegpt.ui.hover
|
||||
|
||||
import com.intellij.codeInsight.documentation.DocumentationHintEditorPane
|
||||
import com.intellij.codeInsight.documentation.DocumentationManager
|
||||
import com.intellij.lang.documentation.DocumentationImageResolver
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.application.ModalityState
|
||||
import com.intellij.openapi.application.ReadAction
|
||||
|
|
@ -21,7 +19,6 @@ import com.intellij.util.ui.UIUtil
|
|||
import ee.carlrobert.codegpt.util.NavigationResolverFactory
|
||||
import ee.carlrobert.codegpt.util.EditorUtil
|
||||
import java.awt.Dimension
|
||||
import java.awt.Image
|
||||
import java.awt.MouseInfo
|
||||
import java.awt.Point
|
||||
import java.awt.event.MouseAdapter
|
||||
|
|
@ -30,6 +27,7 @@ import java.net.URLDecoder
|
|||
import java.nio.charset.StandardCharsets
|
||||
import java.util.concurrent.ScheduledFuture
|
||||
import java.util.concurrent.TimeUnit
|
||||
import javax.swing.JEditorPane
|
||||
import javax.swing.JComponent
|
||||
import javax.swing.JTextPane
|
||||
import javax.swing.SwingUtilities
|
||||
|
|
@ -73,6 +71,10 @@ object PsiLinkHoverPreview {
|
|||
private val executor = AppExecutorUtil.getAppExecutorService()
|
||||
|
||||
override fun hyperlinkUpdate(e: HyperlinkEvent) {
|
||||
if (isAnchorLink(e.description)) {
|
||||
handleAnchorEvent(e)
|
||||
return
|
||||
}
|
||||
when (e.eventType) {
|
||||
HyperlinkEvent.EventType.ENTERED -> onEntered(e)
|
||||
HyperlinkEvent.EventType.EXITED -> onExited()
|
||||
|
|
@ -80,6 +82,29 @@ object PsiLinkHoverPreview {
|
|||
}
|
||||
}
|
||||
|
||||
private fun handleAnchorEvent(e: HyperlinkEvent) {
|
||||
when (e.eventType) {
|
||||
HyperlinkEvent.EventType.ENTERED -> {
|
||||
scheduledClose?.cancel(true)
|
||||
scheduledClose = null
|
||||
pending?.cancel(true)
|
||||
pending = null
|
||||
lastDesc = null
|
||||
}
|
||||
|
||||
HyperlinkEvent.EventType.EXITED -> onExited()
|
||||
HyperlinkEvent.EventType.ACTIVATED -> {
|
||||
scheduledClose?.cancel(true)
|
||||
scheduledClose = null
|
||||
pending?.cancel(true)
|
||||
pending = null
|
||||
scrollToReference(e.source, e.description)
|
||||
}
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
private fun onExited() {
|
||||
scheduledClose?.cancel(true)
|
||||
scheduledClose = scheduledExecutor.schedule({
|
||||
|
|
@ -89,28 +114,24 @@ object PsiLinkHoverPreview {
|
|||
}, EXIT_CLOSE_DELAY_MS, TimeUnit.MILLISECONDS)
|
||||
}
|
||||
|
||||
fun cancelAndHide() {
|
||||
fun cancelAndHide(force: Boolean = false) {
|
||||
scheduledClose?.cancel(true)
|
||||
scheduledClose = null
|
||||
|
||||
val comp = popupComponent
|
||||
if (comp != null && isMouseOverComponent(comp)) return
|
||||
if (!force && comp != null && isMouseOverComponent(comp)) return
|
||||
|
||||
pending?.cancel(true)
|
||||
pending = null
|
||||
|
||||
val popupToCancel = popup ?: return
|
||||
popup = null
|
||||
popupComponent = null
|
||||
|
||||
try {
|
||||
if (SwingUtilities.isEventDispatchThread()) {
|
||||
popup?.cancel()
|
||||
popup = null
|
||||
popupComponent = null
|
||||
} else {
|
||||
ApplicationManager.getApplication().invokeLater({
|
||||
popup?.cancel()
|
||||
popup = null
|
||||
popupComponent = null
|
||||
}, ModalityState.any())
|
||||
}
|
||||
ApplicationManager.getApplication().invokeLater({
|
||||
popupToCancel.cancel()
|
||||
}, ModalityState.any())
|
||||
} catch (t: Throwable) {
|
||||
logger.warn("Failed while cancelling popup", t)
|
||||
}
|
||||
|
|
@ -170,11 +191,7 @@ object PsiLinkHoverPreview {
|
|||
private fun showDocHint(where: RelativePoint, html: String) {
|
||||
scheduledClose?.cancel(true)
|
||||
scheduledClose = null
|
||||
cancelAndHide()
|
||||
|
||||
val imageResolver = object : DocumentationImageResolver {
|
||||
override fun resolveImage(p0: String): Image? = null
|
||||
}
|
||||
cancelAndHide(force = true)
|
||||
|
||||
val safeHtml = if (html.contains("<body", ignoreCase = true)) {
|
||||
html.replaceFirst(
|
||||
|
|
@ -185,7 +202,7 @@ object PsiLinkHoverPreview {
|
|||
"<html><body style=\"margin:0;padding:0\">$html</body></html>"
|
||||
}
|
||||
|
||||
val editor = DocumentationHintEditorPane(project, emptyMap(), imageResolver).apply {
|
||||
val editor = JEditorPane().apply {
|
||||
isEditable = false
|
||||
contentType = "text/html"
|
||||
text = safeHtml
|
||||
|
|
@ -196,6 +213,11 @@ object PsiLinkHoverPreview {
|
|||
isOpaque = true
|
||||
margin = JBUI.insets(8)
|
||||
}
|
||||
editor.addHyperlinkListener { event ->
|
||||
if (event.eventType == HyperlinkEvent.EventType.ACTIVATED && isAnchorLink(event.description)) {
|
||||
scrollToReference(editor, event.description)
|
||||
}
|
||||
}
|
||||
|
||||
val scroll = ScrollPaneFactory.createScrollPane(editor).apply {
|
||||
border = JBUI.Borders.empty()
|
||||
|
|
@ -218,8 +240,6 @@ object PsiLinkHoverPreview {
|
|||
.setCancelKeyEnabled(true)
|
||||
.createPopup()
|
||||
|
||||
editor.setHint(newPopup)
|
||||
|
||||
popup = newPopup
|
||||
popupComponent = scroll
|
||||
|
||||
|
|
@ -252,6 +272,25 @@ object PsiLinkHoverPreview {
|
|||
}
|
||||
}
|
||||
|
||||
private fun isAnchorLink(description: String?): Boolean {
|
||||
return !description.isNullOrBlank() && description.startsWith("#")
|
||||
}
|
||||
|
||||
private fun scrollToReference(source: Any?, description: String?) {
|
||||
val anchor = description?.removePrefix("#")?.takeIf { it.isNotBlank() } ?: return
|
||||
val editor = source as? JEditorPane ?: return
|
||||
ApplicationManager.getApplication().invokeLater({
|
||||
if (!editor.isDisplayable) {
|
||||
return@invokeLater
|
||||
}
|
||||
runCatching {
|
||||
editor.scrollToReference(anchor)
|
||||
}.onFailure { error ->
|
||||
logger.debug("Failed to scroll to anchor #$anchor", error)
|
||||
}
|
||||
}, ModalityState.any())
|
||||
}
|
||||
|
||||
private fun decode(value: String): String = try {
|
||||
URLDecoder.decode(value, StandardCharsets.UTF_8)
|
||||
} catch (_: Throwable) {
|
||||
|
|
@ -303,7 +342,7 @@ object PsiLinkHoverPreview {
|
|||
}
|
||||
|
||||
private fun computeHoverSize(
|
||||
editor: DocumentationHintEditorPane,
|
||||
editor: JEditorPane,
|
||||
maxW: Int,
|
||||
maxH: Int
|
||||
): Dimension {
|
||||
|
|
|
|||
|
|
@ -981,6 +981,6 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
this.text.append(text)
|
||||
}
|
||||
|
||||
override fun onQueuedMessagesResolved() = Unit
|
||||
override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class AgentServiceIntegrationTest : IntegrationTest() {
|
|||
private lateinit var originalRuntimeFactory: AgentRuntimeFactory
|
||||
private val createdSessions = mutableListOf<String>()
|
||||
private val noopEvents = object : AgentEvents {
|
||||
override fun onQueuedMessagesResolved() = Unit
|
||||
override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit
|
||||
}
|
||||
private val runTimeoutMillis = 5_000L
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ class AgentServiceIntegrationTest : IntegrationTest() {
|
|||
var callbackMessageId: UUID? = null
|
||||
var callbackRef: CheckpointRef? = null
|
||||
val events = object : AgentEvents {
|
||||
override fun onQueuedMessagesResolved() = Unit
|
||||
override fun onQueuedMessagesResolved(message: MessageWithContext?) = Unit
|
||||
|
||||
override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {
|
||||
callbackMessageId = runMessageId
|
||||
|
|
|
|||
|
|
@ -110,4 +110,5 @@ class CustomOpenAIResponsesSerializationIntegrationTest : IntegrationTest() {
|
|||
)
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
102
src/test/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocolJsonTest.kt
vendored
Normal file
102
src/test/kotlin/ee/carlrobert/codegpt/agent/external/acpcompat/AcpProtocolJsonTest.kt
vendored
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package ee.carlrobert.codegpt.agent.external.acpcompat
|
||||
|
||||
import com.agentclientprotocol.model.AcpMethod
|
||||
import com.agentclientprotocol.model.McpServer
|
||||
import com.agentclientprotocol.model.NewSessionRequest
|
||||
import com.agentclientprotocol.rpc.JsonRpcMessage
|
||||
import com.agentclientprotocol.transport.Transport
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.vendor.AcpCompatibilityRegistry
|
||||
import ee.carlrobert.codegpt.agent.external.acpcompat.vendor.AcpPeerProfile
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.serialization.encodeToString
|
||||
import kotlinx.serialization.json.Json
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertContains
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AcpProtocolJsonTest {
|
||||
private val parser = Json { ignoreUnknownKeys = true }
|
||||
private val compatibilityRegistry = AcpCompatibilityRegistry()
|
||||
|
||||
@Test
|
||||
fun serializesMcpServerTypeInformation() {
|
||||
val protocol = AcpProtocol(
|
||||
parentScope = CoroutineScope(SupervisorJob() + Dispatchers.Unconfined),
|
||||
transport = NoOpTransport()
|
||||
)
|
||||
|
||||
val encoded = protocol.json.encodeToString(
|
||||
NewSessionRequest(
|
||||
cwd = "/tmp",
|
||||
mcpServers = listOf(
|
||||
McpServer.Stdio(
|
||||
name = "local",
|
||||
command = "npx",
|
||||
args = listOf("server"),
|
||||
env = emptyList()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assertContains(encoded, "\"type\":\"stdio\"")
|
||||
assertContains(encoded, "\"mcpServers\"")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun normalizesEnvVarAuthMethodFromVarsArray() {
|
||||
val payload = parser.parseToJsonElement(
|
||||
"""
|
||||
{
|
||||
"protocolVersion": 1,
|
||||
"agentCapabilities": {},
|
||||
"authMethods": [
|
||||
{
|
||||
"type": "env_var",
|
||||
"id": "openai-api-key",
|
||||
"name": "Use OPENAI_API_KEY",
|
||||
"vars": [
|
||||
{ "name": "OPENAI_API_KEY" }
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
""".trimIndent()
|
||||
)
|
||||
|
||||
val normalized = compatibilityRegistry.normalizeInboundPayload(
|
||||
profile = AcpPeerProfile.CODEX,
|
||||
methodName = AcpMethod.AgentMethods.Initialize.methodName,
|
||||
payload = payload
|
||||
).toString()
|
||||
|
||||
assertContains(normalized, "\"varName\":\"OPENAI_API_KEY\"")
|
||||
assertEquals(1, Regex("\"varName\"").findAll(normalized).count())
|
||||
}
|
||||
}
|
||||
|
||||
private class NoOpTransport : Transport {
|
||||
private val currentState = MutableStateFlow(Transport.State.CREATED)
|
||||
|
||||
override val state: StateFlow<Transport.State> = currentState
|
||||
|
||||
override fun start() {
|
||||
currentState.value = Transport.State.STARTED
|
||||
}
|
||||
|
||||
override fun send(message: JsonRpcMessage) = Unit
|
||||
|
||||
override fun onMessage(handler: (JsonRpcMessage) -> Unit) = Unit
|
||||
|
||||
override fun onError(handler: (Throwable) -> Unit) = Unit
|
||||
|
||||
override fun onClose(handler: () -> Unit) = Unit
|
||||
|
||||
override fun close() {
|
||||
currentState.value = Transport.State.CLOSED
|
||||
}
|
||||
}
|
||||
331
src/test/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpToolCallDecoderTest.kt
vendored
Normal file
331
src/test/kotlin/ee/carlrobert/codegpt/agent/external/events/AcpToolCallDecoderTest.kt
vendored
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
package ee.carlrobert.codegpt.agent.external.events
|
||||
|
||||
import com.agentclientprotocol.model.*
|
||||
import ee.carlrobert.codegpt.agent.external.AcpToolCallDecoder
|
||||
import ee.carlrobert.codegpt.agent.external.AcpToolCallStatus
|
||||
import ee.carlrobert.codegpt.agent.external.AcpToolEventFlavor
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import kotlin.test.*
|
||||
|
||||
class AcpToolCallDecoderTest {
|
||||
|
||||
private val json = Json { ignoreUnknownKeys = true }
|
||||
private val decoder = AcpToolCallDecoder(json)
|
||||
|
||||
@Test
|
||||
fun decodeToolCallStartedProducesTypedTerminalContent() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("tool-1"),
|
||||
title = "Run shell command",
|
||||
kind = ToolKind.EXECUTE,
|
||||
status = ToolCallStatus.IN_PROGRESS,
|
||||
content = listOf(ToolCallContent.Terminal("terminal-1")),
|
||||
locations = listOf(ToolCallLocation("src/Main.kt")),
|
||||
rawInput = json.parseToJsonElement("""{"command":"npm test"}""")
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("tool-1", started.toolCall.id)
|
||||
assertEquals("Bash", started.toolCall.toolName)
|
||||
val bashArgs = assertIs<AcpToolCallArgs.Bash>(started.toolCall.args)
|
||||
assertEquals("npm test", bashArgs.value.command)
|
||||
assertEquals(ToolKind.EXECUTE, started.toolCall.kind)
|
||||
assertEquals(AcpToolCallStatus.IN_PROGRESS, started.toolCall.status)
|
||||
val terminal = started.toolCall.content.single()
|
||||
assertIs<AcpToolCallContent.Terminal>(terminal)
|
||||
assertEquals("terminal-1", terminal.terminalId)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeToolCallUpdateProducesTypedEditArgs() {
|
||||
val update = SessionUpdate.ToolCallUpdate(
|
||||
toolCallId = ToolCallId("tool-2"),
|
||||
title = "Edit file",
|
||||
kind = ToolKind.EDIT,
|
||||
status = ToolCallStatus.COMPLETED,
|
||||
content = listOf(ToolCallContent.Diff("src/Main.kt", "old", "new")),
|
||||
rawInput = json.parseToJsonElement(
|
||||
"""{"file_path":"src/Main.kt","old_string":"old","new_string":"new"}"""
|
||||
)
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update)
|
||||
|
||||
val updated = assertIs<AcpExternalEvent.ToolCallUpdated>(event)
|
||||
assertEquals("Edit", updated.toolCall.toolName)
|
||||
val editArgs = assertIs<AcpToolCallArgs.Edit>(updated.toolCall.args)
|
||||
assertEquals("src/Main.kt", editArgs.value.filePath)
|
||||
assertEquals(AcpToolCallStatus.COMPLETED, updated.toolCall.status)
|
||||
val diff = updated.toolCall.content.single()
|
||||
assertIs<AcpToolCallContent.Diff>(diff)
|
||||
assertEquals("src/Main.kt", diff.path)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodePermissionRequestPreservesTypedToolCallSnapshot() {
|
||||
val request = RequestPermissionRequest(
|
||||
sessionId = SessionId("session-1"),
|
||||
toolCall = SessionUpdate.ToolCallUpdate(
|
||||
toolCallId = ToolCallId("tool-3"),
|
||||
title = "Allow write",
|
||||
kind = ToolKind.EDIT,
|
||||
status = ToolCallStatus.IN_PROGRESS,
|
||||
content = listOf(ToolCallContent.Diff("src/Main.kt", "old", "new")),
|
||||
locations = listOf(ToolCallLocation("src/Main.kt")),
|
||||
rawInput = json.parseToJsonElement(
|
||||
"""{"file_path":"src/Main.kt","old_string":"old","new_string":"new"}"""
|
||||
)
|
||||
),
|
||||
options = listOf(
|
||||
PermissionOption(
|
||||
optionId = PermissionOptionId("allow_once"),
|
||||
name = "Allow once",
|
||||
kind = PermissionOptionKind.ALLOW_ONCE
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
val permission = decoder.decodePermissionRequest(AcpToolEventFlavor.STANDARD, request)
|
||||
|
||||
assertEquals("Allow write", permission.toolCall.title)
|
||||
assertEquals("Edit", permission.toolCall.toolName)
|
||||
val editArgs = assertIs<AcpToolCallArgs.Edit>(permission.toolCall.args)
|
||||
assertEquals("src/Main.kt", editArgs.value.filePath)
|
||||
assertTrue(permission.details.contains("src/Main.kt"))
|
||||
assertEquals(1, permission.options.size)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeUsageUpdatePreservesUsageAndCostFields() {
|
||||
val update = SessionUpdate.UsageUpdate(
|
||||
used = 1_024,
|
||||
size = 128_000,
|
||||
cost = Cost(amount = 0.02, currency = "USD")
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update)
|
||||
|
||||
val usage = assertIs<AcpExternalEvent.UsageUpdate>(event)
|
||||
assertEquals(1_024, usage.used)
|
||||
assertEquals(128_000, usage.size)
|
||||
assertEquals(0.02, usage.cost?.amount)
|
||||
assertEquals("USD", usage.cost?.currency)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeConfigOptionUpdatePreservesSelectAndBooleanOptions() {
|
||||
val update = SessionUpdate.ConfigOptionUpdate(
|
||||
configOptions = listOf(
|
||||
SessionConfigOption.select(
|
||||
id = "reasoning_effort",
|
||||
name = "Reasoning Effort",
|
||||
currentValue = "high",
|
||||
options = SessionConfigSelectOptions.Flat(
|
||||
listOf(
|
||||
SessionConfigSelectOption(SessionConfigValueId("medium"), "Medium"),
|
||||
SessionConfigSelectOption(SessionConfigValueId("high"), "High")
|
||||
)
|
||||
)
|
||||
),
|
||||
SessionConfigOption.boolean(
|
||||
id = "sandbox",
|
||||
name = "Sandbox",
|
||||
currentValue = true
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update)
|
||||
|
||||
val configUpdate = assertIs<AcpExternalEvent.ConfigOptionUpdate>(event)
|
||||
assertEquals(2, configUpdate.configOptions.size)
|
||||
assertIs<SessionConfigOption.Select>(configUpdate.configOptions[0])
|
||||
assertIs<SessionConfigOption.BooleanOption>(configUpdate.configOptions[1])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeSessionInfoUpdatePreservesTitleAndTimestamp() {
|
||||
val update = SessionUpdate.SessionInfoUpdate(
|
||||
title = "Fix ACP runtime",
|
||||
updatedAt = "2026-03-22T18:30:00Z"
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update)
|
||||
|
||||
val sessionInfo = assertIs<AcpExternalEvent.SessionInfoUpdate>(event)
|
||||
assertEquals("Fix ACP runtime", sessionInfo.title)
|
||||
assertEquals("2026-03-22T18:30:00Z", sessionInfo.updatedAt)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeUnknownSessionUpdatePreservesRawPayload() {
|
||||
val rawJson = json.parseToJsonElement(
|
||||
"""{"sessionUpdate":"future_update","flag":true,"count":3}"""
|
||||
).jsonObject
|
||||
val update = SessionUpdate.UnknownSessionUpdate(
|
||||
sessionUpdateType = "future_update",
|
||||
rawJson = rawJson
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.STANDARD, update)
|
||||
|
||||
val unknown = assertIs<AcpExternalEvent.UnknownSessionUpdate>(event)
|
||||
assertEquals("future_update", unknown.type)
|
||||
assertEquals(rawJson, unknown.rawJson)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeGeminiSearchWithoutRawInputFallsBackToPreviewArgs() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("grep_search-1"),
|
||||
title = "Search",
|
||||
kind = ToolKind.SEARCH,
|
||||
status = ToolCallStatus.IN_PROGRESS,
|
||||
locations = listOf(ToolCallLocation("/tmp/README.md"))
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("IntelliJSearch", started.toolCall.toolName)
|
||||
val preview = assertIs<AcpToolCallArgs.SearchPreview>(started.toolCall.args).value
|
||||
assertEquals("Search", preview.title)
|
||||
assertEquals("/tmp/README.md", preview.path)
|
||||
assertEquals(null, preview.pattern)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeGeminiShellWithoutRawInputKeepsBashToolIdentity() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("run_shell_command-1"),
|
||||
title = "",
|
||||
kind = ToolKind.EXECUTE,
|
||||
status = ToolCallStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("Bash", started.toolCall.toolName)
|
||||
val preview = assertIs<AcpToolCallArgs.BashPreview>(started.toolCall.args).value
|
||||
assertEquals("Run shell command", preview.title)
|
||||
assertEquals(null, preview.command)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeGeminiWriteUpdateFromDiffMapsToWriteTool() {
|
||||
val update = SessionUpdate.ToolCallUpdate(
|
||||
toolCallId = ToolCallId("write_file-1"),
|
||||
title = "",
|
||||
kind = null,
|
||||
status = ToolCallStatus.COMPLETED,
|
||||
content = listOf(ToolCallContent.Diff("src/Main.kt", "new content"))
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update)
|
||||
|
||||
val updated = assertIs<AcpExternalEvent.ToolCallUpdated>(event)
|
||||
assertEquals("Write", updated.toolCall.toolName)
|
||||
val writeArgs = assertIs<AcpToolCallArgs.Write>(updated.toolCall.args).value
|
||||
assertEquals("src/Main.kt", writeArgs.filePath)
|
||||
assertEquals("new content", writeArgs.content)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeGeminiGlobMapsToGlobTool() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("glob-1"),
|
||||
title = "'**/*.java'",
|
||||
kind = ToolKind.SEARCH,
|
||||
status = ToolCallStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("Glob", started.toolCall.toolName)
|
||||
val preview = assertIs<AcpToolCallArgs.SearchPreview>(started.toolCall.args).value
|
||||
assertEquals("Find files", preview.title)
|
||||
assertEquals("**/*.java", preview.pattern)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeGeminiListDirectoryMapsToListDirectoryTool() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("list_directory-1"),
|
||||
title = ".",
|
||||
kind = ToolKind.SEARCH,
|
||||
status = ToolCallStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("ListDirectory", started.toolCall.toolName)
|
||||
val preview = assertIs<AcpToolCallArgs.SearchPreview>(started.toolCall.args).value
|
||||
assertEquals("List directory", preview.title)
|
||||
assertEquals(".", preview.path)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeGeminiReplaceUpdateFromDiffMapsToEditTool() {
|
||||
val update = SessionUpdate.ToolCallUpdate(
|
||||
toolCallId = ToolCallId("replace-1"),
|
||||
title = "",
|
||||
kind = null,
|
||||
status = ToolCallStatus.COMPLETED,
|
||||
content = listOf(ToolCallContent.Diff("src/Main.kt", "new", "old"))
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.GEMINI_CLI, update)
|
||||
|
||||
val updated = assertIs<AcpExternalEvent.ToolCallUpdated>(event)
|
||||
assertEquals("Edit", updated.toolCall.toolName)
|
||||
val editArgs = assertIs<AcpToolCallArgs.Edit>(updated.toolCall.args).value
|
||||
assertEquals("src/Main.kt", editArgs.filePath)
|
||||
assertEquals("old", editArgs.oldString)
|
||||
assertEquals("new", editArgs.newString)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeZedSearchWithoutQueryablePayloadFallsBackToPreviewArgs() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("zed-search-1"),
|
||||
title = "Search",
|
||||
kind = ToolKind.SEARCH,
|
||||
status = ToolCallStatus.IN_PROGRESS,
|
||||
locations = listOf(ToolCallLocation("/tmp/src")),
|
||||
rawInput = json.parseToJsonElement("""{"parsed_cmd":[{"type":"search","path":"/tmp/src"}]}""")
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.ZED_ADAPTER, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("IntelliJSearch", started.toolCall.toolName)
|
||||
val preview = assertIs<AcpToolCallArgs.SearchPreview>(started.toolCall.args).value
|
||||
assertEquals("Search", preview.title)
|
||||
assertEquals("/tmp/src", preview.path)
|
||||
assertEquals(null, preview.pattern)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun decodeZedExecuteWithoutPayloadKeepsBashToolIdentity() {
|
||||
val update = SessionUpdate.ToolCall(
|
||||
toolCallId = ToolCallId("zed-bash-1"),
|
||||
title = "Run shell command",
|
||||
kind = ToolKind.EXECUTE,
|
||||
status = ToolCallStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
val event = decoder.decodeExternalEvent(AcpToolEventFlavor.ZED_ADAPTER, update)
|
||||
|
||||
val started = assertIs<AcpExternalEvent.ToolCallStarted>(event)
|
||||
assertEquals("Bash", started.toolCall.toolName)
|
||||
val preview = assertIs<AcpToolCallArgs.BashPreview>(started.toolCall.args).value
|
||||
assertEquals("Run shell command", preview.title)
|
||||
assertEquals(null, preview.command)
|
||||
}
|
||||
}
|
||||
58
src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHostTest.kt
vendored
Normal file
58
src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpFileHostTest.kt
vendored
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import com.agentclientprotocol.model.ReadTextFileRequest
|
||||
import com.agentclientprotocol.model.SessionId
|
||||
import java.nio.file.Files
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AcpFileHostTest {
|
||||
|
||||
@Test
|
||||
fun `readTextFile prefers editor content over disk content`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-file")
|
||||
val file = cwd.resolve("Example.txt")
|
||||
Files.writeString(file, "disk content")
|
||||
val host = AcpFileHost(
|
||||
openDocumentReader = AcpOpenDocumentReader { path ->
|
||||
if (path == file) "editor content" else null
|
||||
},
|
||||
writer = AcpTextFileWriter { _, _ -> error("unexpected write") }
|
||||
)
|
||||
|
||||
val response = host.readTextFile(
|
||||
AcpHostSessionContext(sessionId = "session-1", cwd = cwd),
|
||||
ReadTextFileRequest(
|
||||
sessionId = SessionId("session-1"),
|
||||
path = file.toString()
|
||||
)
|
||||
)
|
||||
|
||||
assertEquals("editor content", response.content)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `readTextFile applies line window after resolving editor content`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-file")
|
||||
val file = cwd.resolve("Example.txt")
|
||||
val host = AcpFileHost(
|
||||
openDocumentReader = AcpOpenDocumentReader { path ->
|
||||
if (path == file) "line1\nline2\nline3\nline4" else null
|
||||
},
|
||||
writer = AcpTextFileWriter { _, _ -> error("unexpected write") }
|
||||
)
|
||||
|
||||
val response = host.readTextFile(
|
||||
AcpHostSessionContext(sessionId = "session-1", cwd = cwd),
|
||||
ReadTextFileRequest(
|
||||
sessionId = SessionId("session-1"),
|
||||
path = file.toString(),
|
||||
line = 2u,
|
||||
limit = 2u
|
||||
)
|
||||
)
|
||||
|
||||
assertEquals("line2\nline3", response.content)
|
||||
}
|
||||
|
||||
}
|
||||
39
src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicyTest.kt
vendored
Normal file
39
src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpPathPolicyTest.kt
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import java.nio.file.Files
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
class AcpPathPolicyTest {
|
||||
|
||||
private val policy = AcpPathPolicy()
|
||||
|
||||
@Test
|
||||
fun `resolveWithinCwd keeps relative paths inside cwd`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-policy")
|
||||
|
||||
val resolved = policy.resolveWithinCwd("src/Main.kt", cwd)
|
||||
|
||||
assertEquals(cwd.resolve("src/Main.kt").toAbsolutePath().normalize(), resolved)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `resolveWithinCwd allows absolute path inside cwd`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-policy")
|
||||
val inside = cwd.resolve("nested/file.txt")
|
||||
|
||||
val resolved = policy.resolveWithinCwd(inside.toString(), cwd)
|
||||
|
||||
assertEquals(inside.toAbsolutePath().normalize(), resolved)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `resolveWithinCwd rejects path traversal`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-policy")
|
||||
|
||||
assertFailsWith<AcpHostPathBoundaryException> {
|
||||
policy.resolveWithinCwd("../secret.txt", cwd)
|
||||
}
|
||||
}
|
||||
}
|
||||
101
src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHostTest.kt
vendored
Normal file
101
src/test/kotlin/ee/carlrobert/codegpt/agent/external/host/AcpTerminalHostTest.kt
vendored
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
package ee.carlrobert.codegpt.agent.external.host
|
||||
|
||||
import com.agentclientprotocol.model.CreateTerminalRequest
|
||||
import com.agentclientprotocol.model.EnvVariable
|
||||
import com.agentclientprotocol.model.SessionId
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class AcpTerminalHostTest {
|
||||
|
||||
@Test
|
||||
fun `createTerminal resolves cwd and stores terminal process`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-terminal")
|
||||
val launcher = FakeTerminalProcessLauncher()
|
||||
val host = AcpTerminalHost(launcher)
|
||||
val session = AcpHostSessionContext(sessionId = "session-1", cwd = cwd)
|
||||
|
||||
val response = host.createTerminal(
|
||||
session,
|
||||
CreateTerminalRequest(
|
||||
sessionId = SessionId("session-1"),
|
||||
command = "git",
|
||||
args = listOf("status"),
|
||||
cwd = "worktree",
|
||||
env = listOf(EnvVariable(name = "FOO", value = "bar")),
|
||||
outputByteLimit = 128uL
|
||||
)
|
||||
)
|
||||
|
||||
assertTrue(response.terminalId.isNotBlank())
|
||||
assertEquals(cwd.resolve("worktree").toAbsolutePath().normalize(), launcher.lastCwd)
|
||||
assertEquals(mapOf("FOO" to "bar"), launcher.nextEnv)
|
||||
assertEquals(128uL, launcher.process.outputByteLimit)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `createTerminal rejects session mismatch`() {
|
||||
val cwd = Files.createTempDirectory("acp-host-terminal")
|
||||
val host = AcpTerminalHost(FakeTerminalProcessLauncher())
|
||||
|
||||
assertFailsWith<IllegalArgumentException> {
|
||||
host.createTerminal(
|
||||
AcpHostSessionContext(sessionId = "session-1", cwd = cwd),
|
||||
CreateTerminalRequest(
|
||||
sessionId = SessionId("different-session"),
|
||||
command = "echo"
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class FakeTerminalProcessLauncher : AcpTerminalProcessLauncher {
|
||||
val process = FakeTerminalProcess()
|
||||
var lastCwd: Path? = null
|
||||
var nextEnv: Map<String, String> = emptyMap()
|
||||
|
||||
override fun launch(
|
||||
command: String,
|
||||
args: List<String>,
|
||||
cwd: Path,
|
||||
env: Map<String, String>,
|
||||
outputByteLimit: ULong?
|
||||
): AcpTerminalProcess {
|
||||
lastCwd = cwd
|
||||
nextEnv = env
|
||||
process.command = command
|
||||
process.args = args
|
||||
process.cwd = cwd
|
||||
process.outputByteLimit = outputByteLimit
|
||||
return process
|
||||
}
|
||||
}
|
||||
|
||||
private class FakeTerminalProcess : AcpTerminalProcess {
|
||||
override val terminalId: String = "fake-terminal"
|
||||
var command: String = ""
|
||||
var args: List<String> = emptyList()
|
||||
var cwd: Path? = null
|
||||
var outputByteLimit: ULong? = null
|
||||
var snapshot: AcpTerminalOutputSnapshot = AcpTerminalOutputSnapshot("", truncated = false)
|
||||
var waitResult: AcpTerminalExitStatus = AcpTerminalExitStatus(exitCode = 0u)
|
||||
var killed: Boolean = false
|
||||
var released: Boolean = false
|
||||
|
||||
override fun output(): AcpTerminalOutputSnapshot = snapshot
|
||||
|
||||
override suspend fun waitForExit(): AcpTerminalExitStatus = waitResult
|
||||
|
||||
override fun release() {
|
||||
released = true
|
||||
}
|
||||
|
||||
override fun kill() {
|
||||
killed = true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AcpConfigOptionsTest {
|
||||
|
||||
@Test
|
||||
fun summaryPartsIncludesBooleanAndCustomOptions() {
|
||||
val options = listOf(
|
||||
AcpConfigOption(
|
||||
id = "model",
|
||||
name = "Model",
|
||||
category = "model",
|
||||
type = "select",
|
||||
currentValue = "gpt-5",
|
||||
options = listOf(AcpConfigOptionChoice("gpt-5", "GPT-5"))
|
||||
),
|
||||
AcpConfigOption(
|
||||
id = "sandbox",
|
||||
name = "Sandbox",
|
||||
type = "boolean",
|
||||
currentValue = "true",
|
||||
options = listOf(
|
||||
AcpConfigOptionChoice("true", "Enabled"),
|
||||
AcpConfigOptionChoice("false", "Disabled")
|
||||
)
|
||||
),
|
||||
AcpConfigOption(
|
||||
id = "approval_mode",
|
||||
name = "Approval",
|
||||
type = "select",
|
||||
currentValue = "manual",
|
||||
options = listOf(
|
||||
AcpConfigOptionChoice("manual", "Manual"),
|
||||
AcpConfigOptionChoice("auto", "Automatic")
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assertEquals(
|
||||
listOf("GPT-5", "Sandbox: Enabled", "Approval: Manual"),
|
||||
AcpConfigOptions.summaryParts(options)
|
||||
)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue