From 3f32f19f72c429f4829665f4a4be5ce4c188a003 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Fri, 13 Feb 2026 16:55:39 +0000 Subject: [PATCH] feat: agent timeline --- .../carlrobert/codegpt/agent/AgentEvents.kt | 3 + .../carlrobert/codegpt/agent/AgentService.kt | 62 +- .../carlrobert/codegpt/agent/ProxyAIAgent.kt | 2 + .../ee/carlrobert/codegpt/agent/ToolSpecs.kt | 13 +- .../AgentCheckpointConversationMapper.kt | 110 +- .../history/AgentCheckpointHistoryService.kt | 7 + .../history/AgentCheckpointTurnSequencer.kt | 193 +++ .../codegpt/agent/rollback/RollbackService.kt | 202 +-- .../codegpt/agent/tools/BashTool.kt | 38 +- .../codegpt/agent/tools/WriteTool.kt | 72 +- .../toolwindow/agent/AgentApprovalManager.kt | 15 +- .../toolwindow/agent/AgentEventHandler.kt | 48 +- .../toolwindow/agent/AgentMessageText.kt | 9 + .../agent/AgentSessionTimelineController.kt | 1219 +++++++++++++++++ .../AgentSessionTimelineDialogBuilder.kt | 479 +++++++ ...essionTimelineHistoricalRollbackSupport.kt | 328 +++++ .../agent/AgentSessionTimelineModels.kt | 70 + .../agent/AgentToolWindowTabPanel.kt | 404 +++++- .../agent/HistoricalRollbackCompatibility.kt | 59 + .../agent/ui/AgentToolWindowLandingPanel.kt | 246 ++-- .../toolwindow/agent/ui/RollbackPanel.kt | 45 +- .../codegpt/toolwindow/ui/BaseMessagePanel.kt | 4 + .../codegpt/ui/textarea/UserInputPanel.kt | 30 +- .../lookup/action/FolderActionItem.kt | 4 +- .../lookup/action/files/FileActionItem.kt | 5 +- .../ee/carlrobert/codegpt/util/StringUtil.kt | 15 +- .../agent/AgentServiceIntegrationTest.kt | 68 +- .../agent/RollbackServiceTrackingTest.kt | 92 +- .../AgentCheckpointConversationMapperTest.kt | 138 +- .../AgentCheckpointTurnSequencerTest.kt | 248 ++++ .../chat/ChatToolWindowTabbedPaneTest.kt | 1 - 31 files changed, 3779 insertions(+), 450 deletions(-) create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencer.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentMessageText.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineController.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineDialogBuilder.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineHistoricalRollbackSupport.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineModels.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/HistoricalRollbackCompatibility.kt create mode 100644 src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencerTest.kt diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt index d5be8056..4957d40b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentEvents.kt @@ -1,11 +1,13 @@ package ee.carlrobert.codegpt.agent import ai.koog.prompt.executor.clients.LLMClientException +import ee.carlrobert.codegpt.agent.history.CheckpointRef import ee.carlrobert.codegpt.agent.tools.AskUserQuestionTool import ee.carlrobert.codegpt.conversations.message.TokenUsage import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest +import java.util.UUID interface AgentEvents { fun onTextReceived(text: String) {} @@ -22,6 +24,7 @@ interface AgentEvents { } fun onRetry(attempt: Int, maxAttempts: Int, reason: String? = null) {} + fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {} fun onQueuedMessagesResolved() fun onTokenUsageAvailable(tokenUsage: Long) {} fun onCreditsAvailable(event: AgentCreditsEvent) {} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt index f76b82ea..e3efc535 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/AgentService.kt @@ -19,9 +19,12 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.McpTagDetails import kotlinx.coroutines.* import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.asSharedFlow +import kotlinx.datetime.Clock +import kotlinx.serialization.json.JsonNull import java.util.* import java.util.concurrent.ConcurrentHashMap import kotlin.io.path.Path +import ai.koog.prompt.message.Message as PromptMessage internal fun interface AgentRuntimeFactory { fun create( @@ -53,16 +56,17 @@ class AgentService(private val project: Project) { private val pendingMessages = ConcurrentHashMap>() private val sessionAgents = ConcurrentHashMap>() private val sessionRuntimes = ConcurrentHashMap() - internal var runtimeFactory: AgentRuntimeFactory = AgentRuntimeFactory { p, storage, provider, events, sid, pending -> - ProxyAIAgent.createService( - project = p, - checkpointStorage = storage, - provider = provider, - events = events, - sessionId = sid, - pendingMessages = pending - ) - } + internal var runtimeFactory: AgentRuntimeFactory = + AgentRuntimeFactory { p, storage, provider, events, sid, pending -> + ProxyAIAgent.createService( + project = p, + checkpointStorage = storage, + provider = provider, + events = events, + sessionId = sid, + pendingMessages = pending + ) + } private val checkpointStorage = JVMFilePersistenceStorageProvider(Path(project.basePath ?: "", ".proxyai")) private val historyService = project.service() @@ -132,7 +136,8 @@ class AgentService(private val project: Project) { } catch (ex: Exception) { logger.error(ex) } finally { - refreshSessionResumeCheckpoint(sessionId, agent.id) + val ref = refreshSessionResumeCheckpoint(sessionId, agent.id) + events.onRunCheckpointUpdated(message.id, ref) sessionAgents.remove(sessionId, agent) runCatching { runtime.service.removeAgentWithId(agent.id) } .onFailure { ex -> @@ -201,7 +206,39 @@ class AgentService(private val project: Project) { .update(sessionId, conversationId, selectedServerIds) } - private suspend fun refreshSessionResumeCheckpoint(sessionId: String, agentId: String) { + suspend fun createSeedCheckpointFromHistory(history: List): CheckpointRef? = + withContext(Dispatchers.IO) { + if (history.isEmpty()) { + return@withContext null + } + + val agentId = UUID.randomUUID().toString() + val checkpointId = UUID.randomUUID().toString() + val checkpoint = AgentCheckpointData( + checkpointId = checkpointId, + createdAt = Clock.System.now(), + nodePath = "$agentId/single_run/nodeExecuteTool", + lastInput = JsonNull, + messageHistory = history, + version = 0 + ) + + runCatching { + checkpointStorage.saveCheckpoint(agentId, checkpoint) + CheckpointRef(agentId, checkpointId) + }.onFailure { ex -> + logger.warn( + "Agent checkpoints: failed to create seed checkpoint from history " + + "agentId=$agentId error=${ex.message}", + ex + ) + }.getOrNull() + } + + private suspend fun refreshSessionResumeCheckpoint( + sessionId: String, + agentId: String + ): CheckpointRef? { val ref = runCatching { historyService.loadLatestResumeCheckpoint(agentId) ?.let { CheckpointRef(agentId, it.checkpointId) } @@ -216,6 +253,7 @@ class AgentService(private val project: Project) { if (ref != null) { project.service().setResumeCheckpointRef(sessionId, ref) } + return ref } private fun ensureSessionRuntime( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt index 35b2e258..faa14baf 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ProxyAIAgent.kt @@ -171,6 +171,8 @@ object ProxyAIAgent { } onNodeExecutionCompleted { ctx -> + if (stream) return@onNodeExecutionCompleted + (ctx.output as? List<*>)?.forEach { msg -> (msg as? Message.Assistant)?.let { events.onTextReceived(it.content) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt index 82491f4f..fad8028c 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/ToolSpecs.kt @@ -169,25 +169,20 @@ object ToolSpecs { fun find(toolName: String): ToolSpec<*, *>? = specsByName[toolName.lowercase()] - fun approvalTypeFor(toolName: String): ToolApprovalType { - return find(toolName)?.approvalType ?: ToolApprovalType.GENERIC - } + fun approvalTypeFor(toolName: String): ToolApprovalType = + find(toolName)?.approvalType ?: ToolApprovalType.GENERIC fun decodeArgsOrNull( json: Json, toolName: String, payload: String - ): Any? { - return decodeOrNull(json, find(toolName)?.argsSerializer, payload) - } + ) = decodeOrNull(json, find(toolName)?.argsSerializer, payload) fun decodeResultOrNull( json: Json, toolName: String, payload: String - ): Any? { - return decodeOrNull(json, find(toolName)?.resultSerializer, payload) - } + ) = decodeOrNull(json, find(toolName)?.resultSerializer, payload) @Suppress("UNCHECKED_CAST") private fun decodeOrNull( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapper.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapper.kt index 6c74363a..0c74637b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapper.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapper.kt @@ -3,9 +3,9 @@ package ee.carlrobert.codegpt.agent.history import ai.koog.agents.snapshot.feature.AgentCheckpointData import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks import ee.carlrobert.llm.client.openai.completion.response.ToolCall import ee.carlrobert.llm.client.openai.completion.response.ToolFunctionResponse -import ai.koog.prompt.message.Message as PromptMessage object AgentCheckpointConversationMapper { @@ -14,21 +14,54 @@ object AgentCheckpointConversationMapper { projectInstructions: String? ): Conversation { val conversation = Conversation() - val history = checkpoint.messageHistory.filterNot { it is PromptMessage.System } + val turns = AgentCheckpointTurnSequencer.toVisibleTurns( + history = checkpoint.messageHistory, + projectInstructions = projectInstructions + ) - var currentPrompt: String? = null - val response = StringBuilder() - val toolCalls = mutableListOf() - val toolResults = LinkedHashMap() - var syntheticToolIdIndex = 0 + turns.forEach { turn -> + val response = StringBuilder() + val toolCalls = mutableListOf() + val toolResults = LinkedHashMap() + var syntheticToolIdIndex = 0 - fun flushTurn() { - val prompt = currentPrompt?.trim().orEmpty() - if (prompt.isBlank()) { - return + turn.events.forEach { event -> + when (event) { + is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> { + appendAssistant(response, event.content) + } + + is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> { + appendAssistant(response, event.content) + } + + is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> { + val callId = event.id?.takeIf { it.isNotBlank() } + ?: "tool-call-${++syntheticToolIdIndex}" + toolCalls.add( + ToolCall( + null, + callId, + "function", + ToolFunctionResponse(event.tool, event.content.trim()) + ) + ) + } + + is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> { + val callId = event.id?.takeIf { it.isNotBlank() } + ?: toolCalls.lastOrNull()?.id + ?: "tool-call-${++syntheticToolIdIndex}" + val prior = toolResults[callId] + val merged = if (prior.isNullOrBlank()) event.content.trim() else { + "$prior\n\n${event.content.trim()}" + } + toolResults[callId] = merged + } + } } - val uiMessage = Message(prompt) + val uiMessage = Message(turn.prompt) uiMessage.response = response.toString().trim() if (toolCalls.isNotEmpty()) { uiMessage.toolCalls = ArrayList(toolCalls) @@ -37,64 +70,13 @@ object AgentCheckpointConversationMapper { uiMessage.toolCallResults = LinkedHashMap(toolResults) } conversation.addMessage(uiMessage) - currentPrompt = null - response.setLength(0) - toolCalls.clear() - toolResults.clear() } - history.forEach { msg -> - when (msg) { - is PromptMessage.User -> { - val text = msg.content.trim() - if (shouldHideInAgentToolWindow(msg, projectInstructions)) { - return@forEach - } - flushTurn() - currentPrompt = text - } - - is PromptMessage.Assistant -> appendAssistant(response, msg.content) - is PromptMessage.Reasoning -> appendAssistant(response, msg.content) - is PromptMessage.Tool.Call -> { - if (currentPrompt != null) { - val callId = - msg.id?.takeIf { it.isNotBlank() } - ?: "tool-call-${++syntheticToolIdIndex}" - toolCalls.add( - ToolCall( - null, - callId, - "function", - ToolFunctionResponse(msg.tool, msg.content.trim()) - ) - ) - } - } - - is PromptMessage.Tool.Result -> { - if (currentPrompt != null) { - val callId = msg.id?.takeIf { it.isNotBlank() } - ?: toolCalls.lastOrNull()?.id - ?: "tool-call-${++syntheticToolIdIndex}" - val prior = toolResults[callId] - val merged = if (prior.isNullOrBlank()) msg.content.trim() else { - "$prior\n\n${msg.content.trim()}" - } - toolResults[callId] = merged - } - } - - else -> Unit - } - } - - flushTurn() return conversation } private fun appendAssistant(sb: StringBuilder, content: String) { - val text = content.trim() + val text = content.stripThinkingBlocks() if (text.isBlank()) { return } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointHistoryService.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointHistoryService.kt index f2167efb..98139e54 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointHistoryService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointHistoryService.kt @@ -96,6 +96,13 @@ class AgentCheckpointHistoryService(project: Project) { ?: checkpoints.maxByOrNull { it.createdAt } } + suspend fun listCheckpoints(agentId: String): List = + withContext(Dispatchers.IO) { + storage.getCheckpoints(agentId) + .filterNot { it.isTombstone() } + .sortedByDescending { it.createdAt } + } + private suspend fun buildSummary(agentId: String): AgentHistoryThreadSummary? { val checkpoints = storage.getCheckpoints(agentId) .filterNot { it.isTombstone() } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencer.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencer.kt new file mode 100644 index 00000000..01d152d1 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencer.kt @@ -0,0 +1,193 @@ +package ee.carlrobert.codegpt.agent.history + +import kotlinx.serialization.json.booleanOrNull +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.jsonPrimitive +import ai.koog.prompt.message.Message as PromptMessage + +object AgentCheckpointTurnSequencer { + + data class Turn( + val prompt: String, + val userNonSystemMessageCount: Int, + val events: List + ) + + sealed interface TurnEvent { + val nonSystemMessageCount: Int + + data class Assistant( + val content: String, + override val nonSystemMessageCount: Int + ) : TurnEvent + + data class Reasoning( + val content: String, + override val nonSystemMessageCount: Int + ) : TurnEvent + + data class ToolCall( + val id: String?, + val tool: String, + val content: String, + override val nonSystemMessageCount: Int + ) : TurnEvent + + data class ToolResult( + val id: String?, + val tool: String, + val content: String, + override val nonSystemMessageCount: Int + ) : TurnEvent + } + + fun toVisibleTurns( + history: List, + projectInstructions: String?, + preserveSyntheticContinuation: Boolean = false + ): List { + val turns = mutableListOf() + var currentPrompt: String? = null + var currentUserNonSystemMessageCount = 0 + val currentEvents = mutableListOf() + + fun flushTurn() { + val prompt = currentPrompt?.trim().orEmpty() + if (prompt.isBlank() || currentUserNonSystemMessageCount <= 0) { + return + } + turns.add( + Turn( + prompt = prompt, + userNonSystemMessageCount = currentUserNonSystemMessageCount, + events = currentEvents.toList() + ) + ) + currentPrompt = null + currentUserNonSystemMessageCount = 0 + currentEvents.clear() + } + + history + .filterNot { it is PromptMessage.System } + .forEachIndexed { index, message -> + val nonSystemMessageCount = index + 1 + when (message) { + is PromptMessage.User -> { + if (isSyntheticTimelineUserMessage(message)) { + if (preserveSyntheticContinuation && currentPrompt != null) { + // Keep collecting events for the currently active visible turn. + // Synthetic todo prompts are injected internally mid-run. + return@forEachIndexed + } + flushTurn() + currentPrompt = null + currentUserNonSystemMessageCount = 0 + return@forEachIndexed + } + + flushTurn() + if (isHiddenUserMessage(message, projectInstructions)) { + currentPrompt = null + currentUserNonSystemMessageCount = 0 + return@forEachIndexed + } + currentPrompt = message.content.trim() + currentUserNonSystemMessageCount = nonSystemMessageCount + } + + is PromptMessage.Assistant -> { + if (currentPrompt != null) { + currentEvents.add( + TurnEvent.Assistant( + content = message.content, + nonSystemMessageCount = nonSystemMessageCount + ) + ) + } + } + + is PromptMessage.Reasoning -> { + if (currentPrompt != null) { + currentEvents.add( + TurnEvent.Reasoning( + content = message.content, + nonSystemMessageCount = nonSystemMessageCount + ) + ) + } + } + + is PromptMessage.Tool.Call -> { + if (currentPrompt != null && !isTodoWriteTool(message.tool)) { + currentEvents.add( + TurnEvent.ToolCall( + id = message.id, + tool = message.tool, + content = message.content, + nonSystemMessageCount = nonSystemMessageCount + ) + ) + } + } + + is PromptMessage.Tool.Result -> { + if (currentPrompt != null && !isTodoWriteTool(message.tool)) { + currentEvents.add( + TurnEvent.ToolResult( + id = message.id, + tool = message.tool, + content = message.content, + nonSystemMessageCount = nonSystemMessageCount + ) + ) + } + } + + else -> Unit + } + } + + flushTurn() + return turns + } + + fun isTodoWriteTool(toolName: String): Boolean { + return toolName.equals("TodoWrite", ignoreCase = true) + } + + fun isSyntheticTimelineUserMessage(message: PromptMessage.User): Boolean { + val normalized = message.content.lowercase() + return normalized.contains("haven't created a todo list yet") + } + + private fun isHiddenUserMessage( + message: PromptMessage.User, + projectInstructions: String? + ): Boolean { + return isCacheableInstructionMessage(message) || + isProjectInstructionsMessage(message.content, projectInstructions) + } + + private fun isCacheableInstructionMessage(message: PromptMessage.User): Boolean { + val cacheable = message.metaInfo.metadata + ?.get("cacheable") + ?.jsonPrimitive + ?: return false + return cacheable.booleanOrNull ?: (cacheable.contentOrNull?.equals( + "true", + ignoreCase = true + ) == true) + } + + private fun isProjectInstructionsMessage(text: String, projectInstructions: String?): Boolean { + if (projectInstructions.isNullOrBlank()) { + return false + } + return normalize(text) == normalize(projectInstructions) + } + + private fun normalize(value: String): String { + return value.replace("\\s+".toRegex(), " ").trim() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt index 72a18e02..aa3b264c 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt @@ -13,9 +13,7 @@ import com.intellij.openapi.fileTypes.FileTypeManager import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.LocalFileSystem import com.intellij.openapi.vfs.VfsUtilCore -import com.intellij.openapi.vfs.VirtualFileManager import ee.carlrobert.codegpt.settings.ProxyAISettingsService -import ee.carlrobert.codegpt.agent.tools.WriteTool import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import java.io.File @@ -23,85 +21,100 @@ import java.nio.file.Paths import java.time.Instant import java.time.format.DateTimeFormatter import java.util.concurrent.ConcurrentHashMap +import java.util.UUID -/** - * Tracks file changes during agent runs and provides rollback to a checkpoint. - * - * Uses IntelliJ's LocalHistory labels for persistence and a lightweight in-memory - * snapshot of modified files for fast restore. - * - * Now uses direct tracking via trackEdit() and trackWrite() methods instead of - * VFS event monitoring to avoid capturing build artifacts and temporary files. - */ @Service(Service.Level.PROJECT) class RollbackService(private val project: Project) { private val activeRuns = ConcurrentHashMap() - private val snapshots = ConcurrentHashMap() + private val activeRunsBySession = ConcurrentHashMap() + private val snapshotsByRunId = ConcurrentHashMap() + private val latestSnapshotRunIdBySession = ConcurrentHashMap() - @Volatile - private var isApplyingRollback = false + private val MAX_TRACKABLE_BYTES = 10 * 1024 * 1024 - private val MAX_TRACKABLE_BYTES = 10 * 1024 * 1024 // 10MB - - init { - // No VFS listener - using direct tracking via trackEdit() and trackWrite() - } - - /** - * Track an EditTool operation directly from the agent. - * This captures the original file content before the edit is applied. - */ fun trackEdit( sessionId: String, filePath: String, originalContent: String ) { - val tracker = activeRuns[sessionId] ?: return - val normalizedPath = filePath.replace("\\", "/") - tracker.recordExplicitEdit(normalizedPath, originalContent) + val runId = activeRunsBySession[sessionId] ?: return + trackEditForRun(runId, filePath, originalContent) } - /** - * Track a WriteTool operation directly from the agent. - * This determines if the file was new or existing before the write. - */ - fun trackWrite(sessionId: String, filePath: String, args: WriteTool.Args) { - val tracker = activeRuns[sessionId] ?: return + fun trackEditForRun( + runId: String, + filePath: String, + originalContent: String + ) { + val tracker = activeRuns[runId] ?: return val normalizedPath = filePath.replace("\\", "/") - tracker.recordExplicitWrite(normalizedPath, args) + tracker.recordExplicitEdit(normalizedPath, originalContent) } - fun startSession(sessionId: String) { + fun trackWrite(sessionId: String, filePath: String) { + val runId = activeRunsBySession[sessionId] ?: return + trackWriteForRun(runId, filePath) + } + + fun trackWriteForRun(runId: String, filePath: String) { + val tracker = activeRuns[runId] ?: return + val normalizedPath = filePath.replace("\\", "/") + tracker.recordExplicitWrite(normalizedPath) + } + + fun startSession(sessionId: String): String { val labelText = "ProxyAI: Agent run ${DateTimeFormatter.ISO_INSTANT.format(Instant.now())}" val label = LocalHistory.getInstance().putSystemLabel(project, labelText) - activeRuns[sessionId] = RunTracker(sessionId, Instant.now(), labelText, label) - snapshots.remove(sessionId) + val runId = UUID.randomUUID().toString() + val tracker = RunTracker(runId, sessionId, Instant.now(), labelText, label) + activeRuns[runId] = tracker + activeRunsBySession[sessionId] = runId + latestSnapshotRunIdBySession.remove(sessionId) + return runId } fun finishSession(sessionId: String): RollbackSnapshot? { - val tracker = activeRuns.remove(sessionId) ?: return getSnapshot(sessionId) + val runId = activeRunsBySession.remove(sessionId) ?: return getSnapshot(sessionId) + return finishRun(runId) + } + + fun finishRun(runId: String): RollbackSnapshot? { + val tracker = activeRuns.remove(runId) ?: return getRunSnapshot(runId) val snapshot = SnapshotState( - sessionId = sessionId, + runId = runId, + sessionId = tracker.sessionId, label = tracker.label, labelRef = tracker.labelRef, startedAt = tracker.startedAt, completedAt = Instant.now(), changes = tracker.changes.toMap() ) - snapshots[sessionId] = snapshot + snapshotsByRunId[runId] = snapshot + latestSnapshotRunIdBySession[tracker.sessionId] = runId return snapshot.toSnapshot() } - fun getSnapshot(sessionId: String): RollbackSnapshot? = - snapshots[sessionId]?.toSnapshot() + fun getSnapshot(sessionId: String): RollbackSnapshot? { + val runId = latestSnapshotRunIdBySession[sessionId] ?: return null + return snapshotsByRunId[runId]?.toSnapshot() + } + + fun getRunSnapshot(runId: String): RollbackSnapshot? = + snapshotsByRunId[runId]?.toSnapshot() fun clearSnapshot(sessionId: String) { - snapshots.remove(sessionId) + val runId = latestSnapshotRunIdBySession.remove(sessionId) ?: return + snapshotsByRunId.remove(runId) + } + + fun clearRunSnapshot(runId: String) { + val snapshot = snapshotsByRunId.remove(runId) ?: return + latestSnapshotRunIdBySession.remove(snapshot.sessionId, runId) } fun getDiffData(sessionId: String, path: String): RollbackDiffData? { - val snapshot = snapshots[sessionId] ?: return null + val snapshot = snapshotForSession(sessionId) ?: return null val change = snapshot.changes[path] ?: return null if (change.kind != ChangeKind.DELETED && !isTrackable(path)) return null val beforeText = when (change.kind) { @@ -131,28 +144,19 @@ class RollbackService(private val project: Project) { } fun isRollbackAvailable(sessionId: String): Boolean = - snapshots[sessionId]?.changes?.isNotEmpty() == true && !activeRuns.containsKey(sessionId) + snapshotForSession(sessionId)?.changes?.isNotEmpty() == true && + !activeRunsBySession.containsKey(sessionId) fun isDisplayable(path: String): Boolean = isTrackable(path) suspend fun rollbackFile(sessionId: String, path: String): RollbackResult = withContext(Dispatchers.IO) { - val snapshot = snapshots[sessionId] + val snapshot = snapshotForSession(sessionId) ?: return@withContext RollbackResult.Failure("No rollback snapshot available") val change = snapshot.changes[path] ?: return@withContext RollbackResult.Failure("No change tracked for $path") - val errors = mutableListOf() - runInEdt(ModalityState.defaultModalityState()) { - runWriteAction { - try { - isApplyingRollback = true - applyChangeWithLabel(snapshot.labelRef, path, change, errors) - } finally { - isApplyingRollback = false - } - } - } + val errors = applyChanges(snapshot.labelRef, mapOf(path to change)) if (errors.isNotEmpty()) { RollbackResult.Failure(errors.joinToString("\n")) @@ -160,36 +164,38 @@ class RollbackService(private val project: Project) { val updated = snapshot.changes.toMutableMap() updated.remove(path) if (updated.isEmpty()) { - snapshots.remove(sessionId) + clearRunSnapshot(snapshot.runId) } else { - snapshots[sessionId] = snapshot.copy(changes = updated.toMap()) + snapshotsByRunId[snapshot.runId] = snapshot.copy(changes = updated.toMap()) } RollbackResult.Success("Rollback completed") } } - suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.Main) { - val snapshot = snapshots[sessionId] + suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.IO) { + val snapshot = snapshotForSession(sessionId) ?: return@withContext RollbackResult.Failure("No rollback snapshot available") - val errors = mutableListOf() - runInEdt(ModalityState.defaultModalityState()) { - runWriteAction { - try { - isApplyingRollback = true - snapshot.changes.forEach { (path, change) -> - applyChangeWithLabel(snapshot.labelRef, path, change, errors) - } - } finally { - isApplyingRollback = false - } - } - } + val errors = applyChanges(snapshot.labelRef, snapshot.changes) if (errors.isNotEmpty()) { RollbackResult.Failure(errors.joinToString("\n")) } else { - snapshots.remove(sessionId) + clearRunSnapshot(snapshot.runId) + RollbackResult.Success("Rollback completed") + } + } + + suspend fun rollbackRun(runId: String): RollbackResult = withContext(Dispatchers.IO) { + val snapshot = snapshotsByRunId[runId] + ?: return@withContext RollbackResult.Failure("No rollback snapshot available") + + val errors = applyChanges(snapshot.labelRef, snapshot.changes) + + if (errors.isNotEmpty()) { + RollbackResult.Failure(errors.joinToString("\n")) + } else { + clearRunSnapshot(runId) RollbackResult.Success("Rollback completed") } } @@ -269,8 +275,8 @@ class RollbackService(private val project: Project) { errors.add("Missing parent directory for $path") return@runCatching null } - val parentVf = VirtualFileManager.getInstance() - .findFileByUrl("file://$parentPath") ?: return@runCatching null + val parentVf = LocalFileSystem.getInstance() + .refreshAndFindFileByPath(parentPath) ?: return@runCatching null parentVf.createChildData(this, ioFile.name) }.getOrNull() @@ -345,7 +351,6 @@ class RollbackService(private val project: Project) { ): ByteArray? { val labelContent = runCatching { label.getByteContent(path) } .getOrNull()?.bytes - // If label content is null or empty, use fallback (captured at track time) return if (labelContent == null || labelContent.isEmpty()) { fallback } else { @@ -353,6 +358,26 @@ class RollbackService(private val project: Project) { } } + private fun applyChanges( + label: Label, + changes: Map + ): List { + val errors = mutableListOf() + runInEdt(ModalityState.defaultModalityState()) { + runWriteAction { + changes.forEach { (path, change) -> + applyChangeWithLabel(label, path, change, errors) + } + } + } + return errors + } + + private fun snapshotForSession(sessionId: String): SnapshotState? { + val runId = latestSnapshotRunIdBySession[sessionId] ?: return null + return snapshotsByRunId[runId] + } + companion object { fun getInstance(project: Project): RollbackService { return project.getService(RollbackService::class.java) @@ -360,6 +385,7 @@ class RollbackService(private val project: Project) { } private class RunTracker( + val runId: String, val sessionId: String, val startedAt: Instant, val label: String, @@ -378,24 +404,23 @@ class RollbackService(private val project: Project) { ) } - fun recordExplicitWrite(filePath: String, args: WriteTool.Args) { - val path = filePath.replace("\\", "/") - val existing = changes[path] - val file = File(path) - + fun recordExplicitWrite(filePath: String) { + val existing = changes[filePath] + val file = File(filePath) + if (existing?.kind == ChangeKind.DELETED) { - changes[path] = existing.copy(kind = ChangeKind.MODIFIED) + changes[filePath] = existing.copy(kind = ChangeKind.MODIFIED) return } if (existing == null) { val originalContent = if (file.exists()) { - runCatching { file.readText() }.getOrNull() + runCatching { file.readText(Charsets.UTF_8) }.getOrNull() } else null - changes[path] = TrackedChange( + changes[filePath] = TrackedChange( kind = if (file.exists()) ChangeKind.MODIFIED else ChangeKind.ADDED, originalPath = null, - originalContent = originalContent?.toByteArray() + originalContent = originalContent?.toByteArray(Charsets.UTF_8) ) } } @@ -428,6 +453,7 @@ class RollbackService(private val project: Project) { } private data class SnapshotState( + val runId: String, val sessionId: String, val label: String, val labelRef: Label, @@ -438,6 +464,7 @@ class RollbackService(private val project: Project) { fun toSnapshot(): RollbackSnapshot? { if (changes.isEmpty()) return null return RollbackSnapshot( + runId = runId, sessionId = sessionId, label = label, startedAt = startedAt, @@ -455,6 +482,7 @@ class RollbackService(private val project: Project) { } data class RollbackSnapshot( + val runId: String, val sessionId: String, val label: String, val startedAt: Instant, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashTool.kt index 88a3b0ab..332ac8da 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashTool.kt @@ -84,12 +84,12 @@ class BashTool( - If the commands depend on each other and must run sequentially, use a single Bash call with '&&' to chain them together (e.g., `git add . && git commit -m "message" && git push`). For instance, if one operation must complete before another starts (like mkdir before cp, Write before Bash for git operations, or git add before git commit), run these operations sequentially instead. - Use ';' only when you need to run commands sequentially but don't care if earlier commands fail - DO NOT use newlines to separate commands (newlines are ok in quoted strings) - - Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of `cd`. You may use `cd` if the User explicitly requests it. + - Commands run with the working directory already set to the project root. Prefer relative project paths (e.g., `src/...` or `./src/...`) instead of root-absolute paths like `/src` or hardcoded machine-specific paths like `/Users/.../project/...`. You may use `cd` only if the user explicitly requests it. - pytest /foo/bar/tests + find src -type f -name "*.kt" - cd /foo/bar && pytest tests + find /src -type f -name "*.kt" # Committing changes with git @@ -261,11 +261,10 @@ class BashTool( ) } - val isAskPattern = shouldAskForConfirmation(args.command) val resolvedToolId = toolId ?: throw IllegalArgumentException("Tool ID is missing") val confirmation = when { - isWhiteListed(args) && !isAskPattern -> ShellCommandConfirmation.Approved + isWhiteListed(args) -> ShellCommandConfirmation.Approved else -> confirmationHandler.requestConfirmation(args) } return when (confirmation) { @@ -377,7 +376,6 @@ class BashTool( } } } catch (_: IOException) { - // Ignore IO exception if the stream is closed } } @@ -391,7 +389,6 @@ class BashTool( } } } catch (_: IOException) { - // Ignore IO exception if the stream is closed } } @@ -463,8 +460,7 @@ class BashTool( } when (event) { WaitEvent.EXIT -> done = true - WaitEvent.ACTIVITY -> { /* reset idle timer */ - } + WaitEvent.ACTIVITY -> Unit null -> { timedOut.set(true) @@ -541,28 +537,6 @@ class BashTool( } } - private fun shouldAskForConfirmation(command: String): Boolean { - val askPatterns = listOf( - "find * -delete*", - "find * -exec*", - "find * -fprint*", - "find * -fls*", - "find * -fprintf*", - "find * -ok*", - "sort --output=*", - "sort -o *", - "tree -o *" - ) - - return askPatterns.any { pattern -> - if (pattern.contains("*")) { - command.startsWith(pattern.removeSuffix("*")) - } else { - command == pattern - } - } - } - private fun shouldBlockByIgnore(command: String): Boolean { val readers = setOf( "cat", @@ -607,8 +581,6 @@ class BashTool( lastWasReader = tokenIsReader } val settingsService = project.service() - // TODO(PROXYAI-IGNORE): Replace deny-style bash path checks with visibility filtering. - // Bash output and directory listings should hide ignored paths instead of returning policy-denied. return paths.any { candidate -> settingsService.isPathIgnored(toAbsolute(candidate)) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WriteTool.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WriteTool.kt index 8289fe17..e3778d20 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WriteTool.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/tools/WriteTool.kt @@ -8,11 +8,15 @@ import com.intellij.openapi.components.service import com.intellij.openapi.fileEditor.FileDocumentManager import com.intellij.openapi.project.Project import com.intellij.openapi.util.text.StringUtil +import com.intellij.openapi.vfs.LocalFileSystem import com.intellij.openapi.vfs.VirtualFileManager +import com.intellij.openapi.vfs.VfsUtil import com.intellij.psi.PsiDocumentManager import ee.carlrobert.codegpt.settings.ProxyAISettingsService import ee.carlrobert.codegpt.settings.hooks.HookManager import ee.carlrobert.codegpt.tokens.truncateToolResult +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import java.io.File @@ -122,43 +126,59 @@ class WriteTool( ) } - val fileUrl = file.toURI().toString() - val virtualFile = - VirtualFileManager.getInstance().findFileByUrl(fileUrl) - ?: if (isNewFile) { - VirtualFileManager.getInstance().refreshAndFindFileByUrl(fileUrl) - } else { - null - } - val bytesWritten = args.content.toByteArray(StandardCharsets.UTF_8).size val projectToUse = project + val normalizedPath = filePath.replace("\\", "/") + val fileUrl = "file://$normalizedPath" + + // Never do filesystem IO or synchronous VFS refreshes on the EDT. + // If we can resolve a VirtualFile+Document, use the Document API for proper IDE integration. + val vfm = VirtualFileManager.getInstance() + val virtualFile = if (!isNewFile) { + vfm.findFileByUrl(fileUrl) ?: vfm.refreshAndFindFileByUrl(fileUrl) + } else { + vfm.findFileByUrl(fileUrl) + } + if (virtualFile != null) { - val document = + val document = withContext(Dispatchers.Default) { runReadAction { FileDocumentManager.getInstance().getDocument(virtualFile) } - runInEdt { - if (document != null) { + } + if (document != null) { + runInEdt { runWriteAction { - PsiDocumentManager.getInstance(projectToUse).commitDocument(document) document.setText(StringUtil.convertLineSeparators(args.content)) + PsiDocumentManager.getInstance(projectToUse).commitDocument(document) FileDocumentManager.getInstance().saveDocument(document) } - } else { - runWriteAction { - virtualFile.setBinaryContent(args.content.toByteArray(StandardCharsets.UTF_8)) - } } + } else { + // For non-document-backed files, fall back to plain IO and refresh the specific VirtualFile. + withContext(Dispatchers.IO) { + file.writeBytes(args.content.toByteArray(StandardCharsets.UTF_8)) + } + VfsUtil.markDirtyAndRefresh(true, false, false, virtualFile) } - } else if (isNewFile) { - runInEdt { - runWriteAction { - file.writeText( - StringUtil.convertLineSeparators(args.content), - StandardCharsets.UTF_8 - ) - VirtualFileManager.getInstance().syncRefresh() - } + } else { + // Fallback for files not present in VFS (e.g., newly created or external paths). + withContext(Dispatchers.IO) { + file.writeText( + StringUtil.convertLineSeparators(args.content), + StandardCharsets.UTF_8 + ) + } + + val lfs = LocalFileSystem.getInstance() + val parentVf = file.parentFile?.let { parent -> + lfs.findFileByIoFile(parent) ?: lfs.refreshAndFindFileByIoFile(parent) + } + if (parentVf != null) { + // Reload the directory children so the newly created file shows up. + VfsUtil.markDirtyAndRefresh(true, false, true, parentVf) + } else { + // As a last resort, refresh just the file path. + lfs.refreshAndFindFileByIoFile(file) } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentApprovalManager.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentApprovalManager.kt index 696cb082..8ea17e08 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentApprovalManager.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentApprovalManager.kt @@ -8,6 +8,7 @@ import com.intellij.diff.requests.SimpleDiffRequest import com.intellij.diff.util.DiffUserDataKeys import com.intellij.diff.util.DiffUserDataKeysEx import com.intellij.icons.AllIcons +import com.intellij.openapi.Disposable import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.Presentation @@ -34,7 +35,9 @@ import ee.carlrobert.codegpt.agent.tools.EditTool import ee.carlrobert.codegpt.agent.tools.WriteTool import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.applyStringReplacement import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.getFileContentWithFallback +import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers import java.awt.BorderLayout import java.awt.FlowLayout import java.awt.Insets @@ -54,8 +57,9 @@ import javax.swing.JPanel */ class AgentApprovalManager( private val project: Project -) { +) : Disposable { private val approvalPopups = ConcurrentHashMap() + private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO) companion object { private val AGENT_DIFF_REQUEST_KEY: Key = Key.create("agent.approval.diffRequest") @@ -96,7 +100,7 @@ class AgentApprovalManager( args.filePath } - ApplicationManager.getApplication().executeOnPooledThread { + backgroundScope.launch { val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) val factory = DiffContentFactory.getInstance() @@ -340,4 +344,11 @@ class AgentApprovalManager( } } } + + override fun dispose() { + backgroundScope.dispose() + runInEdt { + approvalPopups.keys.toList().forEach(::closeDiffById) + } + } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt index fc299f06..7055e192 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt @@ -13,6 +13,7 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.LocalFileSystem import ee.carlrobert.codegpt.CodeGPTBundle import ee.carlrobert.codegpt.agent.* +import ee.carlrobert.codegpt.agent.history.CheckpointRef import ee.carlrobert.codegpt.agent.rollback.RollbackService import ee.carlrobert.codegpt.agent.tools.* import ee.carlrobert.codegpt.settings.service.ServiceType @@ -24,6 +25,7 @@ import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.* import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatMessageResponseBody import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatToolWindowScrollablePanel import ee.carlrobert.codegpt.ui.textarea.UserInputPanel +import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope import kotlinx.coroutines.* import java.awt.Component import java.util.* @@ -40,6 +42,8 @@ class AgentEventHandler( private val userInputPanel: UserInputPanel, private val onShowLoading: (String) -> Unit, private val onHideLoading: () -> Unit, + private val onRunFinishedCallback: () -> Unit = {}, + private val onRunCheckpointUpdatedCallback: (UUID, CheckpointRef?) -> Unit = { _, _ -> }, private val onQueuedMessagesResolved: (MessageWithContext) -> Unit = {} ) : AgentEvents, Disposable { @@ -67,6 +71,9 @@ class AgentEventHandler( @Volatile private var lastEditArgs: EditTool.Args? = null + @Volatile + private var currentRollbackRunId: String? = null + private val approvalQueue: ArrayDeque = ArrayDeque() @Volatile @@ -84,7 +91,7 @@ class AgentEventHandler( private var runViewHolder: RunViewHolder? = null private val subagentViewHolders = ConcurrentHashMap() - private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val serviceScope = DisposableCoroutineScope(Dispatchers.Default) data class ApprovalRequest( var model: ToolApprovalRequest, @@ -105,6 +112,7 @@ class AgentEventHandler( runViewHolder = null subagentViewHolders.clear() lastReportedPromptTokens = 0 + currentRollbackRunId = null } private fun clearApprovalContainer() { @@ -162,6 +170,10 @@ class AgentEventHandler( currentResponseBody = responseBody } + fun setCurrentRollbackRunId(runId: String?) { + currentRollbackRunId = runId + } + override fun onAgentCompleted(agentId: String) { runCatching { project.service().getTabbedPane() @@ -226,8 +238,12 @@ class AgentEventHandler( private fun handleDone() { runInEdt { - project.service().finishSession(sessionId) + currentRollbackRunId?.let { runId -> + project.service().finishRun(runId) + } ?: project.service().finishSession(sessionId) + currentRollbackRunId = null currentResponseBody?.finishThinking() + onRunFinishedCallback() onHideLoading() userInputPanel.setStopEnabled(false) scrollablePanel.update() @@ -483,6 +499,12 @@ class AgentEventHandler( onQueuedMessagesResolved(pendingMessage) } + override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) { + runInEdt { + onRunCheckpointUpdatedCallback(runMessageId, ref) + } + } + override fun onTokenUsageAvailable(tokenUsage: Long) { lastReportedPromptTokens = tokenUsage @@ -739,6 +761,8 @@ class AgentEventHandler( } override fun dispose() { + serviceScope.dispose() + agentApprovalManager.dispose() mainToolCards.clear() approvalQueue.clear() subagentViewHolders.clear() @@ -752,16 +776,26 @@ class AgentEventHandler( val documentText = vf?.let { file -> runReadAction { FileDocumentManager.getInstance().getDocument(file)?.text } } - documentText ?: java.io.File(normalizedPath).readText() + documentText ?: java.io.File(normalizedPath).readText(Charsets.UTF_8) }.getOrNull() ?: "" - project.service() - .trackEdit(sessionId, normalizedPath, originalContent) + val rollbackService = project.service() + val runId = currentRollbackRunId + if (runId != null) { + rollbackService.trackEditForRun(runId, normalizedPath, originalContent) + } else { + rollbackService.trackEdit(sessionId, normalizedPath, originalContent) + } } private fun trackWriteOperation(args: WriteTool.Args) { lastWriteArgs = args val normalizedPath = args.filePath.replace("\\", "/") - project.service() - .trackWrite(sessionId, normalizedPath, args) + val rollbackService = project.service() + val runId = currentRollbackRunId + if (runId != null) { + rollbackService.trackWriteForRun(runId, normalizedPath) + } else { + rollbackService.trackWrite(sessionId, normalizedPath) + } } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentMessageText.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentMessageText.kt new file mode 100644 index 00000000..797f4f47 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentMessageText.kt @@ -0,0 +1,9 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +internal object AgentMessageText { + fun abbreviate(text: String, maxLength: Int): String { + val normalized = text.replace("\\s+".toRegex(), " ").trim() + if (normalized.length <= maxLength) return normalized + return normalized.take(maxLength - 1).trimEnd() + "…" + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineController.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineController.kt new file mode 100644 index 00000000..a0511b9b --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineController.kt @@ -0,0 +1,1219 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import ai.koog.agents.snapshot.feature.AgentCheckpointData +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import com.intellij.icons.AllIcons +import com.intellij.openapi.Disposable +import com.intellij.openapi.application.runInEdt +import com.intellij.openapi.components.service +import com.intellij.openapi.project.Project +import com.intellij.openapi.ui.DialogWrapper +import com.intellij.openapi.ui.Messages +import com.intellij.ui.components.JBLabel +import com.intellij.util.ui.JBUI +import ee.carlrobert.codegpt.EncodingManager +import ee.carlrobert.codegpt.agent.AgentService +import ee.carlrobert.codegpt.agent.ProxyAIAgent.loadProjectInstructions +import ee.carlrobert.codegpt.agent.history.AgentCheckpointConversationMapper +import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService +import ee.carlrobert.codegpt.agent.history.AgentCheckpointTurnSequencer +import ee.carlrobert.codegpt.agent.history.CheckpointRef +import ee.carlrobert.codegpt.agent.rollback.ChangeKind +import ee.carlrobert.codegpt.agent.rollback.FileChange +import ee.carlrobert.codegpt.agent.rollback.RollbackResult +import ee.carlrobert.codegpt.agent.rollback.RollbackService +import ee.carlrobert.codegpt.conversations.Conversation +import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.toolwindow.chat.editor.actions.CopyAction +import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks +import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.serialization.json.* +import java.awt.BorderLayout +import java.awt.CardLayout +import java.awt.FlowLayout +import java.awt.event.ActionEvent +import java.util.* +import javax.swing.Action +import javax.swing.JButton +import javax.swing.JComponent +import javax.swing.JPanel +import ai.koog.prompt.message.Message as PromptMessage + +internal data class AgentTimelineRunState( + val rollbackRunId: String, + val sourceMessage: Message? +) + +internal class AgentSessionTimelineController( + private val project: Project, + private val agentSession: AgentSession, + private val conversation: Conversation, + private val runStateForRunIndex: (Int) -> AgentTimelineRunState?, + private val applySeededSessionState: (Conversation, CheckpointRef) -> Unit, + private val onAfterRollbackRefresh: () -> Unit +) : Disposable { + private val replayJson = Json { + ignoreUnknownKeys = true + isLenient = true + explicitNulls = false + } + private val rollbackService = RollbackService.getInstance(project) + private val historyService = project.service() + private val historicalRollbackSupport = + AgentSessionTimelineHistoricalRollbackSupport(project, historyService, replayJson) + private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO) + + private var sessionTimelinePointsCache: List? = null + + private fun launchBackground(block: suspend () -> Unit) { + backgroundScope.launch { + block() + } + } + + override fun dispose() { + backgroundScope.dispose() + } + + fun invalidateTimelineCache() { + sessionTimelinePointsCache = null + } + + fun showSessionStartTimelineDialog() { + val agentService = project.service() + val sessionId = agentSession.sessionId + if (agentService.isSessionRunning(sessionId)) { + Messages.showErrorDialog( + project, + "Stop the active run before opening timeline context editor.", + "Agent Timeline" + ) + return + } + + launchBackground { + val payload = runCatching { + val points = loadSessionTimelinePoints() + val timelineBaseHistory = loadSessionTimelineBaseHistory() + val snapshot = loadSessionContextSnapshot(timelineBaseHistory) + points to snapshot + }.getOrNull() + val points = payload?.first.orEmpty() + if (points.isEmpty()) { + runInEdt { + Messages.showErrorDialog( + project, + "No timeline points are available for this session yet.", + "Start New Session" + ) + } + return@launchBackground + } + + runInEdt { + val snapshot = payload?.second + showTimelineDialog( + points = points, + onSelect = { point -> + startNewSessionFromTimelinePoint(point) + }, + reloadPoints = { loadSessionTimelinePoints(forceRefresh = true) }, + contextSnapshot = snapshot + ) + } + } + } + + private fun loadSessionContextSnapshot( + timelineBaseHistory: List? = null + ): SessionContextSnapshot? { + if (timelineBaseHistory != null) { + val baseHistory = timelineBaseHistory.toList() + if (baseHistory.none { it !is PromptMessage.System }) { + return null + } + return SessionContextSnapshot(baseHistory = baseHistory) + } + + val fallbackHistory = buildPromptHistoryFromConversation(conversation.messages) + if (fallbackHistory.none { it !is PromptMessage.System }) { + return null + } + return SessionContextSnapshot(baseHistory = fallbackHistory) + } + + private suspend fun loadSessionTimelineBaseHistory(): List? { + val agentId = resolveTimelineAgentId() ?: return null + val checkpoints = loadTimelineCheckpoints(agentId) + return checkpoints.lastOrNull()?.messageHistory?.toList() + } + + private fun resolveTimelineAgentId(): String? { + val resumeRef = agentSession.resumeCheckpointRef + return resumeRef?.agentId ?: agentSession.runtimeAgentId + } + + private suspend fun loadTimelineCheckpoints(agentId: String): List { + val resumeRef = agentSession.resumeCheckpointRef + var checkpoints = historyService.listCheckpoints(agentId).sortedBy { it.createdAt } + if (resumeRef != null) { + val anchorIndex = checkpoints.indexOfFirst { it.checkpointId == resumeRef.checkpointId } + if (anchorIndex >= 0) { + checkpoints = checkpoints.subList(0, anchorIndex + 1) + } + } + return checkpoints + } + + private fun buildPromptHistoryFromConversation(messages: List): List { + return buildList { + messages.forEach { message -> + val prompt = message.prompt.orEmpty().trim() + if (prompt.isBlank()) return@forEach + add(PromptMessage.User(prompt, RequestMetaInfo.Empty)) + + val response = message.response.orEmpty().stripThinkingBlocks().trim() + if (response.isNotBlank()) { + add(PromptMessage.Assistant(response, ResponseMetaInfo.Empty)) + } + } + } + } + + private fun applyEditedSessionContext( + baseHistory: List, + selectedNonSystemMessageCounts: Set, + alwaysIncludedNonSystemMessageCounts: Set + ) { + launchBackground { + val payload = runCatching { + buildSeededContextPayload( + baseHistory = baseHistory, + selectedNonSystemMessageCounts = selectedNonSystemMessageCounts, + alwaysIncludedNonSystemMessageCounts = alwaysIncludedNonSystemMessageCounts + ) + }.getOrNull() + + runInEdt { + if (payload == null) { + Messages.showErrorDialog( + project, + "Unable to apply session context changes.", + "Edit Session Context" + ) + return@runInEdt + } + + val (seededConversation, seedRef) = payload + invalidateTimelineCache() + applySeededSessionState(seededConversation, seedRef) + } + } + } + + private fun startNewSessionFromEditedContext( + baseHistory: List, + selectedNonSystemMessageCounts: Set, + alwaysIncludedNonSystemMessageCounts: Set + ) { + launchBackground { + val payload = runCatching { + buildSeededContextPayload( + baseHistory = baseHistory, + selectedNonSystemMessageCounts = selectedNonSystemMessageCounts, + alwaysIncludedNonSystemMessageCounts = alwaysIncludedNonSystemMessageCounts + ) + }.getOrNull() + + runInEdt { + if (payload == null) { + Messages.showErrorDialog( + project, + "Unable to create a new session from selected checkpoints.", + "Agent Timeline" + ) + return@runInEdt + } + + val (seededConversation, seedRef) = payload + val newSession = AgentSession( + sessionId = UUID.randomUUID().toString(), + conversation = seededConversation, + runtimeAgentId = seedRef.agentId, + resumeCheckpointRef = seedRef + ) + project.service() + .createNewAgentTab(newSession, select = true) + } + } + } + + private suspend fun buildSeededContextPayload( + baseHistory: List, + selectedNonSystemMessageCounts: Set, + alwaysIncludedNonSystemMessageCounts: Set + ): Pair? { + val history = rebuildHistoryFromEditedContext( + baseHistory = baseHistory, + selectedNonSystemMessageCounts = selectedNonSystemMessageCounts, + alwaysIncludedNonSystemMessageCounts = alwaysIncludedNonSystemMessageCounts + ) + return createSeededConversationFromHistory(history) + } + + private suspend fun createSeededConversationFromHistory(history: List): Pair? { + if (history.none { it !is PromptMessage.System }) return null + + val seedRef = project.service() + .createSeedCheckpointFromHistory(history) + ?: return null + val seedCheckpoint = historyService.loadCheckpoint(seedRef) ?: return null + val seededConversation = AgentCheckpointConversationMapper.toConversation( + checkpoint = seedCheckpoint, + projectInstructions = loadProjectInstructions(project.basePath) + ) + return seededConversation to seedRef + } + + private fun rebuildHistoryFromEditedContext( + baseHistory: List, + selectedNonSystemMessageCounts: Set, + alwaysIncludedNonSystemMessageCounts: Set + ): List { + val includeIndexes = (selectedNonSystemMessageCounts + alwaysIncludedNonSystemMessageCounts) + .filter { it > 0 } + .toMutableSet() + if (includeIndexes.isEmpty()) { + return baseHistory.filterIsInstance() + } + + val nonSystemMessages = baseHistory.filterNot { it is PromptMessage.System } + val selectedToolCallIds = mutableSetOf() + + includeIndexes.toList().forEach { oneBasedIndex -> + val zeroBasedIndex = oneBasedIndex - 1 + if (zeroBasedIndex !in nonSystemMessages.indices) return@forEach + if (nonSystemMessages[zeroBasedIndex] is PromptMessage.User) return@forEach + + var cursor = zeroBasedIndex - 1 + while (cursor >= 0) { + if (nonSystemMessages[cursor] is PromptMessage.User) { + includeIndexes += cursor + 1 + break + } + cursor -= 1 + } + } + + nonSystemMessages.forEachIndexed { index, message -> + val nonSystemIndex = index + 1 + if (nonSystemIndex !in includeIndexes) return@forEachIndexed + + val toolCall = message as? PromptMessage.Tool.Call ?: return@forEachIndexed + val callId = toolCall.id?.takeIf { it.isNotBlank() } + if (callId != null) selectedToolCallIds += callId + + var cursor = index + 1 + while (cursor < nonSystemMessages.size) { + val nextResult = nonSystemMessages[cursor] as? PromptMessage.Tool.Result ?: break + val nextResultId = nextResult.id?.takeIf { it.isNotBlank() } + if (callId != null && nextResultId != null && nextResultId != callId) break + includeIndexes += cursor + 1 + cursor += 1 + } + } + + if (selectedToolCallIds.isNotEmpty()) { + nonSystemMessages.forEachIndexed { index, message -> + val result = message as? PromptMessage.Tool.Result ?: return@forEachIndexed + val resultId = result.id?.takeIf { it.isNotBlank() } ?: return@forEachIndexed + if (resultId in selectedToolCallIds) includeIndexes += index + 1 + } + } + + val rebuilt = mutableListOf() + var nonSystemIndex = 0 + baseHistory.forEach { message -> + if (message is PromptMessage.System) { + rebuilt += message + return@forEach + } + nonSystemIndex += 1 + if (nonSystemIndex in includeIndexes) rebuilt += message + } + return rebuilt + } + + private fun computeAlwaysIncludedNonSystemMessageCounts( + baseHistory: List, + points: List + ): Set { + val selectableNonSystemMessageCounts = points + .mapNotNull { it.nonSystemMessageCount } + .filter { it > 0 } + .toSet() + val nonSystemMessages = baseHistory.filterNot { it is PromptMessage.System } + + return nonSystemMessages.mapIndexedNotNull { index, message -> + val nonSystemMessageCount = index + 1 + if (nonSystemMessageCount in selectableNonSystemMessageCounts) return@mapIndexedNotNull null + if (message is PromptMessage.Tool.Result && !isInternalTimelineTool(message.tool)) { + return@mapIndexedNotNull null + } + nonSystemMessageCount + }.toSet() + } + + private fun estimatePromptHistoryTokenCount(history: List): Int { + val encodingManager = EncodingManager.getInstance() + return history.sumOf { message -> + val text = when (message) { + is PromptMessage.System -> message.content + is PromptMessage.User -> message.content + is PromptMessage.Assistant -> message.content + is PromptMessage.Reasoning -> message.content + is PromptMessage.Tool.Call -> message.content + is PromptMessage.Tool.Result -> message.content + } + if (text.isBlank()) 0 else encodingManager.countTokens(text) + } + } + + private suspend fun loadSessionTimelinePoints(forceRefresh: Boolean = false): List { + if (!forceRefresh) { + sessionTimelinePointsCache?.let { return it } + } + + val points = mutableListOf() + val agentId = resolveTimelineAgentId() + + if (!agentId.isNullOrBlank()) { + val checkpoints = loadTimelineCheckpoints(agentId) + points += buildSessionTimelinePointsFromCheckpoints(agentId, checkpoints) + } + + sessionTimelinePointsCache = points + return points + } + + private fun buildSessionTimelinePointsFromCheckpoints( + agentId: String, + checkpoints: List + ): List { + val checkpoint = checkpoints.lastOrNull() ?: return emptyList() + val projectInstructions = loadProjectInstructions(project.basePath) + val checkpointRef = CheckpointRef(agentId, checkpoint.checkpointId) + val turns = AgentCheckpointTurnSequencer.toVisibleTurns( + history = checkpoint.messageHistory, + projectInstructions = projectInstructions, + preserveSyntheticContinuation = true + ) + if (turns.isEmpty()) return emptyList() + + val points = mutableListOf() + val nonSystemMessages = checkpoint.messageHistory.filterNot { it is PromptMessage.System } + turns.forEachIndexed { index, turn -> + val runIndex = index + 1 + val runLabel = + resolveRunLabelForTurn(nonSystemMessages = nonSystemMessages, turn = turn) + + points.add( + RunTimelinePoint( + cacheKey = "${checkpoint.checkpointId}:${turn.userNonSystemMessageCount}", + checkpointRef = checkpointRef, + title = "User message", + subtitle = AgentMessageText.abbreviate(turn.prompt, 120), + runLabel = runLabel, + icon = AllIcons.General.User, + nonSystemMessageCount = turn.userNonSystemMessageCount, + kind = TimelinePointKind.USER, + runIndex = runIndex + ) + ) + + turn.events.forEach { event -> + toTimelinePointFromTurnEvent( + checkpoint = checkpoint, + checkpointRef = checkpointRef, + event = event, + runIndex = runIndex, + runLabel = runLabel + )?.let(points::add) + } + } + return points + } + + private fun resolveRunLabelForTurn( + nonSystemMessages: List, + turn: AgentCheckpointTurnSequencer.Turn + ): String { + val userMessageText = AgentMessageText.abbreviate(turn.prompt, 80) + val userIndex = turn.userNonSystemMessageCount - 1 + if (userIndex !in nonSystemMessages.indices) return userMessageText + + var cursor = userIndex + 1 + while (cursor < nonSystemMessages.size) { + val message = nonSystemMessages[cursor] + if (message is PromptMessage.User) break + + val call = message as? PromptMessage.Tool.Call + if (call != null && isTodoWriteTool(call.tool)) { + extractTodoWriteRunLabel(call.content)?.let { return it } + } + cursor += 1 + } + + return userMessageText + } + + private fun toTimelinePointFromTurnEvent( + checkpoint: AgentCheckpointData, + checkpointRef: CheckpointRef, + event: AgentCheckpointTurnSequencer.TurnEvent, + runIndex: Int, + runLabel: String + ): RunTimelinePoint? { + return when (event) { + is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> { + val text = event.content.stripThinkingBlocks().trim() + if (text.isBlank()) return null + RunTimelinePoint( + cacheKey = "${checkpoint.checkpointId}:${event.nonSystemMessageCount}", + checkpointRef = checkpointRef, + title = "Assistant response", + subtitle = AgentMessageText.abbreviate(text, 120), + runLabel = runLabel, + icon = AllIcons.General.Balloon, + nonSystemMessageCount = event.nonSystemMessageCount, + outputText = text, + kind = TimelinePointKind.ASSISTANT, + runIndex = runIndex + ) + } + + is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> { + val text = event.content.stripThinkingBlocks().trim() + if (text.isBlank()) return null + RunTimelinePoint( + cacheKey = "${checkpoint.checkpointId}:${event.nonSystemMessageCount}", + checkpointRef = checkpointRef, + title = "Assistant reasoning", + subtitle = AgentMessageText.abbreviate(text, 120), + runLabel = runLabel, + icon = AllIcons.General.ContextHelp, + nonSystemMessageCount = event.nonSystemMessageCount, + outputText = text, + kind = TimelinePointKind.REASONING, + runIndex = runIndex + ) + } + + is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> { + val toolName = event.tool.ifBlank { "Tool" } + if (isInternalTimelineTool(toolName)) return null + RunTimelinePoint( + cacheKey = "${checkpoint.checkpointId}:${event.nonSystemMessageCount}", + checkpointRef = checkpointRef, + title = toolName, + subtitle = extractToolCallSubtitle(toolName, event.content), + runLabel = runLabel, + icon = iconForTool(toolName), + nonSystemMessageCount = event.nonSystemMessageCount, + toolCallId = event.id, + kind = TimelinePointKind.TOOL_CALL, + runIndex = runIndex + ) + } + + is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> null + } + } + + private fun showTimelineDialog( + points: List, + onSelect: (RunTimelinePoint) -> Unit, + reloadPoints: (suspend () -> List)?, + contextSnapshot: SessionContextSnapshot? + ) { + var dialog: DialogWrapper? = null + var onSelectionStateChanged: (() -> Unit)? = null + val ui = AgentSessionTimelineDialogBuilder.build( + points = points, + onSelect = onSelect, + reloadPoints = reloadPoints, + contextSelectionEnabled = contextSnapshot != null, + onNonCopyMenuAction = { + dialog?.close(DialogWrapper.CANCEL_EXIT_CODE) + }, + onSelectionStateChanged = { + onSelectionStateChanged?.invoke() + }, + timelinePointDisplayLabel = ::timelinePointDisplayLabel, + canRollbackTimelinePoint = ::canRollbackTimelinePoint, + rollbackToTimelinePoint = ::rollbackToTimelinePoint, + canCopyTimelinePointOutput = ::canCopyTimelinePointOutput, + copyTimelinePointOutput = ::copyTimelinePointOutput, + launchBackground = ::launchBackground + ) + + if (contextSnapshot != null) { + val alwaysIncludedNonSystemMessageCounts = computeAlwaysIncludedNonSystemMessageCounts( + baseHistory = contextSnapshot.baseHistory, + points = points + ) + + fun checkedMessageCountsOrError(): Set? { + val checkedMessageCounts = ui.getCheckedNonSystemMessageCounts() + if (checkedMessageCounts.isEmpty()) { + Messages.showErrorDialog( + project, + "Check at least one run or checkpoint.", + "Agent Timeline" + ) + return null + } + return checkedMessageCounts + } + + val contextStatsLabel = JBLabel().apply { + foreground = JBUI.CurrentTheme.Label.disabledForeground() + font = JBUI.Fonts.smallFont() + isVisible = true + } + + fun updateContextStats() { + val history = if (ui.isEditMode()) { + val selectedCounts = ui.getCheckedNonSystemMessageCounts() + rebuildHistoryFromEditedContext( + baseHistory = contextSnapshot.baseHistory, + selectedNonSystemMessageCounts = selectedCounts, + alwaysIncludedNonSystemMessageCounts = alwaysIncludedNonSystemMessageCounts + ) + } else { + contextSnapshot.baseHistory + } + val messageCount = history.count { it !is PromptMessage.System } + val tokenCount = estimatePromptHistoryTokenCount(history) + val messageSuffix = if (messageCount == 1) "message" else "messages" + val tokenText = "%,d".format(tokenCount) + contextStatsLabel.text = if (ui.isEditMode()) { + "Selected: $messageCount $messageSuffix • ~$tokenText tokens" + } else { + "Context: $messageCount $messageSuffix • ~$tokenText tokens" + } + } + + onSelectionStateChanged = { updateContextStats() } + + dialog = object : DialogWrapper(project, true) { + private val closeAction = object : DialogWrapperAction("Close") { + override fun doAction(e: ActionEvent?) { + close(CANCEL_EXIT_CODE) + } + } + private val editAction = object : DialogWrapperAction("Edit") { + override fun doAction(e: ActionEvent?) { + ui.setEditMode(true) + syncActions() + } + } + private val cancelAction = object : DialogWrapperAction("Cancel") { + override fun doAction(e: ActionEvent?) { + ui.setEditMode(false) + syncActions() + } + } + private val editSessionAction = object : DialogWrapperAction("Apply") { + override fun doAction(e: ActionEvent?) { + val selectedMessageCounts = checkedMessageCountsOrError() ?: return + close(OK_EXIT_CODE) + applyEditedSessionContext( + baseHistory = contextSnapshot.baseHistory, + selectedNonSystemMessageCounts = selectedMessageCounts, + alwaysIncludedNonSystemMessageCounts = alwaysIncludedNonSystemMessageCounts + ) + } + } + private val newSessionAction = object : DialogWrapperAction("New Session") { + override fun doAction(e: ActionEvent?) { + val selectedMessageCounts = checkedMessageCountsOrError() ?: return + close(OK_EXIT_CODE) + startNewSessionFromEditedContext( + baseHistory = contextSnapshot.baseHistory, + selectedNonSystemMessageCounts = selectedMessageCounts, + alwaysIncludedNonSystemMessageCounts = alwaysIncludedNonSystemMessageCounts + ) + } + } + private var closeButton: JButton? = null + private var editButton: JButton? = null + private var newSessionButton: JButton? = null + private lateinit var actionCards: JPanel + private lateinit var actionCardLayout: CardLayout + + private fun syncActions() { + val inEdit = ui.isEditMode() + if (::actionCards.isInitialized) { + actionCardLayout.show(actionCards, if (inEdit) "edit" else "view") + } + rootPane?.defaultButton = + if (inEdit) newSessionButton else editButton ?: closeButton + actionCards.revalidate() + actionCards.repaint() + updateContextStats() + } + + init { + title = "Agent Timeline" + isResizable = true + init() + syncActions() + } + + override fun createCenterPanel(): JComponent { + return JPanel(BorderLayout()).apply { + border = JBUI.Borders.empty(8, 10, 0, 10) + add(ui.component, BorderLayout.CENTER) + } + } + + override fun createSouthPanel(): JComponent { + val closeBtn = createJButtonForAction(closeAction).also { closeButton = it } + val editBtn = createJButtonForAction(editAction).also { editButton = it } + val cancelBtn = createJButtonForAction(cancelAction) + val editSessionBtn = createJButtonForAction(editSessionAction) + val newSessionBtn = createJButtonForAction(newSessionAction).also { + newSessionButton = it + }.apply { + putClientProperty("JButton.buttonType", "default") + } + + val viewButtons = JPanel(FlowLayout(FlowLayout.RIGHT, 8, 0)).apply { + isOpaque = false + add(closeBtn) + add(editBtn) + } + val editButtons = JPanel(FlowLayout(FlowLayout.RIGHT, 8, 0)).apply { + isOpaque = false + add(cancelBtn) + add(editSessionBtn) + add(newSessionBtn) + } + + actionCardLayout = CardLayout() + actionCards = JPanel(actionCardLayout).apply { + isOpaque = false + add(viewButtons, "view") + add(editButtons, "edit") + } + syncActions() + + return JPanel(BorderLayout()).apply { + isOpaque = false + border = JBUI.Borders.empty(8, 10, 8, 10) + add(contextStatsLabel, BorderLayout.WEST) + add(actionCards, BorderLayout.EAST) + } + } + + override fun createActions(): Array = emptyArray() + } + dialog.show() + return + } + + dialog = object : DialogWrapper(project, true) { + init { + title = "Agent Timeline" + isResizable = true + init() + } + + override fun createCenterPanel(): JComponent { + return JPanel(BorderLayout()).apply { + border = JBUI.Borders.empty(8, 10, 6, 10) + add(ui.component, BorderLayout.CENTER) + } + } + + override fun createActions(): Array = emptyArray() + } + dialog.show() + } + + private fun startNewSessionFromTimelinePoint(point: RunTimelinePoint) { + launchBackground { + val payload = runCatching { + val checkpointRef = point.checkpointRef ?: return@runCatching null + val checkpoint = + historyService.loadCheckpoint(checkpointRef) ?: return@runCatching null + val messageHistory = trimHistoryByNonSystemCount( + checkpoint.messageHistory, + point.nonSystemMessageCount + ) + if (messageHistory.none { it !is PromptMessage.System }) return@runCatching null + + createSeededConversationFromHistory(messageHistory) + }.getOrNull() + + if (payload == null) { + runInEdt { + Messages.showErrorDialog( + project, + "Unable to start a new session from the selected timeline point.", + "Start New Session" + ) + } + return@launchBackground + } + + runInEdt { + val (seededConversation, seedRef) = payload + val newSession = AgentSession( + sessionId = UUID.randomUUID().toString(), + conversation = seededConversation, + runtimeAgentId = seedRef.agentId, + resumeCheckpointRef = seedRef + ) + project.service() + .createNewAgentTab(newSession, select = true) + } + } + } + + private fun trimHistoryByNonSystemCount( + history: List, + nonSystemMessageCount: Int? + ): List { + val limit = nonSystemMessageCount ?: return history + if (limit <= 0) return history.filterIsInstance() + + val trimmed = mutableListOf() + var remaining = limit + history.forEach { message -> + if (message is PromptMessage.System) { + trimmed.add(message) + return@forEach + } + if (remaining <= 0) return@forEach + trimmed.add(message) + remaining -= 1 + } + return trimmed + } + + private fun isInternalTimelineTool(toolName: String): Boolean { + return AgentCheckpointTurnSequencer.isTodoWriteTool(toolName) + } + + private fun isTodoWriteTool(toolName: String): Boolean { + return toolName.equals("TodoWrite", ignoreCase = true) || + toolName.equals("TodoWriteTool", ignoreCase = true) + } + + private fun extractTodoWriteRunLabel(rawArgs: String): String? { + val argsObject = + runCatching { replayJson.parseToJsonElement(rawArgs).jsonObject }.getOrNull() + ?: return null + + listOf("title", "short_description", "description", "summary", "task").forEach { key -> + val value = stringValue(argsObject[key])?.trim().orEmpty() + if (value.isNotBlank()) return AgentMessageText.abbreviate(value, 80) + } + + val todos = argsObject["todos"]?.let { element -> + runCatching { element.jsonArray }.getOrNull() + }.orEmpty() + todos.forEach { todo -> + val todoObject = runCatching { todo.jsonObject }.getOrNull() ?: return@forEach + listOf("content", "title", "task", "description").forEach { key -> + val value = stringValue(todoObject[key])?.trim().orEmpty() + if (value.isNotBlank()) return AgentMessageText.abbreviate(value, 80) + } + } + + return null + } + + private fun iconForTool(toolName: String): javax.swing.Icon { + val normalized = toolName.lowercase() + return when { + normalized.contains("read") -> AllIcons.Actions.Show + normalized.contains("bash") || normalized.contains("shell") -> AllIcons.Nodes.Console + normalized.contains("search") -> AllIcons.Actions.Search + normalized.contains("web") -> AllIcons.General.Web + normalized.contains("write") || normalized.contains("edit") -> AllIcons.Actions.Edit + else -> AllIcons.Nodes.Function + } + } + + private fun extractToolCallSubtitle(toolName: String, rawArgs: String): String { + val argsObject = + runCatching { replayJson.parseToJsonElement(rawArgs).jsonObject }.getOrNull() + if (argsObject == null) return AgentMessageText.abbreviate(rawArgs, 140) + + val preferredKeys = when { + toolName.equals("Read", ignoreCase = true) -> listOf( + "file_path", + "path", + "pathInProject" + ) + + toolName.contains("Bash", ignoreCase = true) -> listOf("command", "cmd") + toolName.contains("Search", ignoreCase = true) -> listOf( + "query", + "searchText", + "regexPattern", + "pattern", + "q" + ) + + isTodoWriteTool(toolName) -> listOf( + "title", + "short_description", + "description", + "summary", + "task" + ) + + else -> listOf( + "short_description", + "description", + "file_path", + "path", + "query", + "command" + ) + } + + preferredKeys.forEach { key -> + val value = stringValue(argsObject[key]) ?: return@forEach + if (value.isNotBlank()) return AgentMessageText.abbreviate(value, 140) + } + + val firstEntryValue = argsObject.entries.asSequence() + .mapNotNull { (_, value) -> stringValue(value) } + .firstOrNull { it.isNotBlank() } + ?.let { AgentMessageText.abbreviate(it, 140) } + if (!firstEntryValue.isNullOrBlank()) return firstEntryValue + + return AgentMessageText.abbreviate(rawArgs, 140) + } + + private fun stringValue(element: JsonElement?): String? { + if (element == null) return null + val primitive = element as? JsonPrimitive + return if (primitive != null && primitive.isString) primitive.content else element.toString() + } + + private fun timelinePointDisplayLabel(point: RunTimelinePoint): String { + val runNumber = if (point.runIndex <= 0) 1 else point.runIndex + val base = point.title.ifBlank { "Timeline point" } + val subtitle = point.subtitle.trim() + return if (subtitle.isBlank()) { + "Run $runNumber • $base" + } else { + "Run $runNumber • $base: ${AgentMessageText.abbreviate(subtitle, 80)}" + } + } + + private fun canRollbackTimelinePoint(point: RunTimelinePoint): Boolean { + val runState = runStateForTimelinePoint(point) + if (runState != null && runState.rollbackRunId.isNotBlank()) { + if (rollbackService.getRunSnapshot(runState.rollbackRunId)?.changes?.isNotEmpty() == true) { + return true + } + } + return point.checkpointRef != null && point.nonSystemMessageCount != null + } + + private fun runStateForTimelinePoint(point: RunTimelinePoint): AgentTimelineRunState? { + val runNumber = if (point.runIndex <= 0) 1 else point.runIndex + return runStateForRunIndex(runNumber) + } + + private fun rollbackToTimelinePoint( + point: RunTimelinePoint, + selectedLabel: String, + onCompleted: (Boolean) -> Unit = {} + ) { + val runState = runStateForTimelinePoint(point) + val rollbackRunId = runState?.rollbackRunId?.takeIf { it.isNotBlank() } + if (rollbackRunId != null) { + val snapshot = rollbackService.getRunSnapshot(rollbackRunId) + val displayableChanges = snapshot?.changes + ?.filter { rollbackService.isDisplayable(it.path) } + .orEmpty() + if (displayableChanges.isNotEmpty()) { + val confirm = Messages.showYesNoDialog( + project, + buildRollbackConfirmationText(selectedLabel, displayableChanges), + "Rollback", + "Rollback", + "Cancel", + AllIcons.General.WarningDialog + ) + if (confirm != Messages.YES) return + + launchBackground { + val result = rollbackService.rollbackRun(rollbackRunId) + runInEdt { + when (result) { + is RollbackResult.Success -> { + if (point.checkpointRef != null && point.nonSystemMessageCount != null) { + syncCurrentSessionViewToTimelinePoint(point, onCompleted) + } else { + onAfterRollbackRefresh() + onCompleted(true) + } + } + + is RollbackResult.Failure -> { + onCompleted(false) + Messages.showErrorDialog(project, result.message, "Rollback Failed") + } + } + } + } + return + } + } + + val checkpointRef = point.checkpointRef + if (checkpointRef == null || point.nonSystemMessageCount == null) { + Messages.showErrorDialog( + project, + "Rollback is not available for the selected entry.", + "Rollback" + ) + onCompleted(false) + return + } + + launchBackground { + val operations = historicalRollbackSupport.collectOperations(point) + runInEdt { + if (operations.isEmpty()) { + val confirm = Messages.showYesNoDialog( + project, + """ + This will rewind the current session to: + $selectedLabel + + Continue? + """.trimIndent(), + "Rollback", + "Rollback", + "Cancel", + AllIcons.General.WarningDialog + ) + if (confirm != Messages.YES) { + onCompleted(false) + return@runInEdt + } + syncCurrentSessionViewToTimelinePoint(point, onCompleted) + return@runInEdt + } + + val confirm = Messages.showYesNoDialog( + project, + historicalRollbackSupport.buildConfirmationText(selectedLabel, operations), + "Rollback", + "Rollback", + "Cancel", + AllIcons.General.WarningDialog + ) + if (confirm != Messages.YES) { + onCompleted(false) + return@runInEdt + } + + launchBackground { + val errors = historicalRollbackSupport.applyOperations(operations) + runInEdt { + syncCurrentSessionViewToTimelinePoint(point, onCompleted) + if (errors.isNotEmpty()) { + val details = errors.take(10).joinToString(separator = "\n") + val suffix = + if (errors.size > 10) "\n...and ${errors.size - 10} more error(s)." else "" + Messages.showErrorDialog(project, "$details$suffix", "Rollback Failed") + } + } + } + } + } + } + + private fun syncCurrentSessionViewToTimelinePoint( + point: RunTimelinePoint, + onCompleted: (Boolean) -> Unit = {} + ) { + launchBackground { + val payload = + runCatching { buildSessionStateFromTimelinePoint(point) }.getOrNull() + + runInEdt { + if (payload == null) { + onAfterRollbackRefresh() + onCompleted(false) + return@runInEdt + } + + val (seededConversation, seedRef) = payload + invalidateTimelineCache() + applySeededSessionState(seededConversation, seedRef) + onCompleted(true) + } + } + } + + private suspend fun buildSessionStateFromTimelinePoint( + point: RunTimelinePoint + ): Pair? { + val checkpointRef = point.checkpointRef ?: return null + val checkpoint = historyService.listCheckpoints(checkpointRef.agentId).firstOrNull() + ?: historyService.loadCheckpoint(checkpointRef) + ?: return null + val messageHistory = + trimHistoryByNonSystemCount(checkpoint.messageHistory, point.nonSystemMessageCount) + if (messageHistory.none { it !is PromptMessage.System }) return null + + return createSeededConversationFromHistory(messageHistory) + } + + private fun buildRollbackConfirmationText( + selectedLabel: String, + changes: List + ): String { + val previewLimit = 12 + val listedChanges = changes.take(previewLimit).joinToString(separator = "\n") { change -> + val symbol = when (change.kind) { + ChangeKind.ADDED -> "+" + ChangeKind.DELETED -> "-" + ChangeKind.MODIFIED -> "~" + ChangeKind.MOVED -> "~" + } + val basePath = toProjectRelativePath(change.path) + if (change.kind == ChangeKind.MOVED && !change.originalPath.isNullOrBlank()) { + val fromPath = toProjectRelativePath(change.originalPath) + "$symbol $basePath (renamed from $fromPath)" + } else { + "$symbol $basePath" + } + } + + val remaining = changes.size - previewLimit + val suffix = if (remaining > 0) "\n...and $remaining more file(s)." else "" + + return """ + This rollback will return the session to: + $selectedLabel + + It will revert ${changes.size} file change(s): + + $listedChanges$suffix + """.trimIndent() + } + + private fun toProjectRelativePath(path: String): String { + val basePath = project.basePath?.replace("\\", "/") ?: return path.replace("\\", "/") + val normalizedPath = path.replace("\\", "/") + return if (normalizedPath.startsWith(basePath)) { + normalizedPath.removePrefix(basePath).trimStart('/') + } else { + normalizedPath + } + } + + private fun canCopyTimelinePointOutput(point: RunTimelinePoint): Boolean { + return when (point.kind) { + TimelinePointKind.ASSISTANT, + TimelinePointKind.REASONING, + TimelinePointKind.TOOL_CALL -> true + + else -> false + } + } + + private fun copyTimelinePointOutput(point: RunTimelinePoint, selectedLabel: String) { + launchBackground { + val output = + runCatching { resolveTimelinePointOutput(point) }.getOrNull() + + runInEdt { + if (output.isNullOrBlank()) { + Messages.showErrorDialog( + project, + "No output found for \"$selectedLabel\".", + "Copy Output" + ) + return@runInEdt + } + CopyAction.copyToClipboard(output) + } + } + } + + private suspend fun resolveTimelinePointOutput(point: RunTimelinePoint): String? { + point.outputText?.trim()?.takeIf { it.isNotBlank() }?.let { return it } + + if (point.kind == TimelinePointKind.TOOL_CALL) { + if (point.checkpointRef == null) { + val sourceMessage = runStateForTimelinePoint(point)?.sourceMessage + val output = point.toolCallId + ?.let { callId -> sourceMessage?.toolCallResults?.get(callId) } + ?.trim() + if (!output.isNullOrBlank()) return output + + return sourceMessage?.toolCallResults + ?.values + ?.lastOrNull { !it.isNullOrBlank() } + ?.trim() + } + + val checkpointRef = point.checkpointRef + val index = point.nonSystemMessageCount?.minus(1) ?: return null + if (index < 0) return null + + val checkpoint = historyService.loadCheckpoint(checkpointRef) ?: return null + val history = checkpoint.messageHistory.filterNot { it is PromptMessage.System } + if (index !in history.indices) return null + + val selected = history[index] as? PromptMessage.Tool.Call ?: return null + val selectedId = selected.id?.takeIf { it.isNotBlank() } + + val outputs = mutableListOf() + for (nextIndex in index + 1 until history.size) { + when (val next = history[nextIndex]) { + is PromptMessage.Tool.Result -> { + val resultId = next.id?.takeIf { it.isNotBlank() } + if (selectedId != null) { + if (resultId == selectedId) { + val output = next.content.trim() + if (output.isNotBlank()) outputs.add(output) + } + } else { + val output = next.content.trim() + if (output.isNotBlank()) return output + } + } + + is PromptMessage.User, + is PromptMessage.Assistant, + is PromptMessage.Reasoning -> break + + else -> Unit + } + } + if (outputs.isNotEmpty()) { + return outputs.joinToString(separator = "\n\n") + } + } + + return null + } + +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineDialogBuilder.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineDialogBuilder.kt new file mode 100644 index 00000000..8dd56a88 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineDialogBuilder.kt @@ -0,0 +1,479 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import com.intellij.icons.AllIcons +import com.intellij.openapi.actionSystem.ActionManager +import com.intellij.openapi.actionSystem.ActionUpdateThread +import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.actionSystem.DefaultActionGroup +import com.intellij.openapi.application.runInEdt +import com.intellij.openapi.project.DumbAwareAction +import com.intellij.ui.CheckboxTree +import com.intellij.ui.CheckedTreeNode +import com.intellij.ui.SimpleTextAttributes +import com.intellij.ui.components.JBLabel +import com.intellij.ui.components.JBScrollPane +import com.intellij.util.ui.JBUI +import java.awt.BorderLayout +import java.awt.Dimension +import java.awt.event.ActionEvent +import java.awt.event.KeyAdapter +import java.awt.event.KeyEvent +import java.awt.event.MouseAdapter +import java.awt.event.MouseEvent +import javax.swing.AbstractAction +import javax.swing.KeyStroke +import javax.swing.JPanel +import javax.swing.JTree +import javax.swing.ScrollPaneConstants +import javax.swing.SwingUtilities +import javax.swing.tree.DefaultMutableTreeNode +import javax.swing.tree.DefaultTreeModel + +internal object AgentSessionTimelineDialogBuilder { + + fun build( + points: List, + onSelect: (RunTimelinePoint) -> Unit, + reloadPoints: (suspend () -> List)?, + contextSelectionEnabled: Boolean, + onNonCopyMenuAction: (() -> Unit)?, + onSelectionStateChanged: (() -> Unit)?, + timelinePointDisplayLabel: (RunTimelinePoint) -> String, + canRollbackTimelinePoint: (RunTimelinePoint) -> Boolean, + rollbackToTimelinePoint: (RunTimelinePoint, String) -> Unit, + canCopyTimelinePointOutput: (RunTimelinePoint) -> Boolean, + copyTimelinePointOutput: (RunTimelinePoint, String) -> Unit, + launchBackground: (suspend () -> Unit) -> Unit + ): TimelineDialogUi { + fun runNodeKey(runNumber: Int): String = "run:$runNumber" + fun checkpointNodeKey(point: RunTimelinePoint): String = "cp:${point.cacheKey}" + + fun buildTreeRoot( + dataPoints: List, + checkedNodeState: Map, + withCheckboxes: Boolean + ): CheckedTreeNode { + val groups = buildTimelineRunGroups(dataPoints) + val root = CheckedTreeNode(TimelineTreeEntry("Session")) + groups.forEach { group -> + val toolCallSuffix = if (group.toolCallCount == 1) "tool call" else "tool calls" + val runEntry = TimelineTreeEntry( + title = "Run ${group.runNumber}", + subtitle = "${group.toolCallCount} $toolCallSuffix", + icon = AllIcons.Vcs.History, + runNumber = group.runNumber + ) + val runNode: DefaultMutableTreeNode = if (withCheckboxes) { + CheckedTreeNode(runEntry).apply { + isChecked = checkedNodeState[runNodeKey(group.runNumber)] ?: true + } + } else { + DefaultMutableTreeNode(runEntry) + } + group.points.forEach { point -> + val pointEntry = TimelineTreeEntry( + title = point.title, + subtitle = point.subtitle.trim(), + icon = point.icon, + point = point + ) + val pointNode: DefaultMutableTreeNode = if (withCheckboxes) { + val parentChecked = (runNode as? CheckedTreeNode)?.isChecked ?: true + CheckedTreeNode(pointEntry).apply { + isChecked = checkedNodeState[checkpointNodeKey(point)] ?: parentChecked + } + } else { + DefaultMutableTreeNode(pointEntry) + } + runNode.add(pointNode) + } + root.add(runNode) + } + return root + } + + fun collectCheckedNonSystemMessageCounts(root: DefaultMutableTreeNode): Set { + val checkedNonSystemMessageCounts = mutableSetOf() + for (index in 0 until root.childCount) { + val runNode = root.getChildAt(index) as? DefaultMutableTreeNode ?: continue + for (childIndex in 0 until runNode.childCount) { + val childNode = + runNode.getChildAt(childIndex) as? DefaultMutableTreeNode ?: continue + val childEntry = childNode.userObject as? TimelineTreeEntry ?: continue + val point = childEntry.point ?: continue + val nonSystemMessageCount = point.nonSystemMessageCount ?: continue + if (childNode is CheckedTreeNode && childNode.isChecked) { + checkedNonSystemMessageCounts.add(nonSystemMessageCount) + } + } + } + return checkedNonSystemMessageCounts + } + + fun captureCheckedRunState(root: DefaultMutableTreeNode): Map { + val state = mutableMapOf() + for (index in 0 until root.childCount) { + val runNode = root.getChildAt(index) as? DefaultMutableTreeNode ?: continue + val runEntry = runNode.userObject as? TimelineTreeEntry + val runNumber = runEntry?.runNumber + if (runNode is CheckedTreeNode && runNumber != null) { + state[runNodeKey(runNumber)] = runNode.isChecked + } + + for (childIndex in 0 until runNode.childCount) { + val childNode = + runNode.getChildAt(childIndex) as? DefaultMutableTreeNode ?: continue + val childEntry = childNode.userObject as? TimelineTreeEntry ?: continue + val point = childEntry.point ?: continue + if (childNode is CheckedTreeNode) { + state[checkpointNodeKey(point)] = childNode.isChecked + } + } + } + return state + } + + fun ensureCheckedStateDefaults( + dataPoints: List, + checkedNodeState: MutableMap + ) { + if (!contextSelectionEnabled) return + buildTimelineRunGroups(dataPoints).forEach { group -> + checkedNodeState.putIfAbsent(runNodeKey(group.runNumber), true) + } + dataPoints.forEach { point -> + checkedNodeState.putIfAbsent(checkpointNodeKey(point), true) + } + } + + fun expandAllRows(tree: JTree) { + var row = 0 + while (row < tree.rowCount) { + tree.expandRow(row) + row++ + } + } + + var latestPoints = points + var editMode = false + val groups = buildTimelineRunGroups(latestPoints) + val checkedNodeState = if (contextSelectionEnabled) { + buildMap { + groups.forEach { group -> put(runNodeKey(group.runNumber), true) } + latestPoints.forEach { point -> put(checkpointNodeKey(point), true) } + }.toMutableMap() + } else { + mutableMapOf() + } + ensureCheckedStateDefaults(latestPoints, checkedNodeState) + + fun headerText(): String { + return when { + !contextSelectionEnabled -> + "Choose a point in this session timeline." + + editMode -> + "Right-click row for rollback/new/copy. Check runs or checkpoints to keep in context." + + else -> + "Right-click row for rollback/new/copy. Click Edit to modify current context." + } + } + + val header = JBLabel(headerText()).apply { + foreground = JBUI.CurrentTheme.Label.disabledForeground() + font = JBUI.Fonts.smallFont() + border = JBUI.Borders.emptyBottom(5) + } + + lateinit var tree: CheckboxTree + + fun captureCurrentCheckedState() { + if (!contextSelectionEnabled || !editMode) return + val currentRoot = tree.model.root as? DefaultMutableTreeNode ?: return + checkedNodeState.clear() + checkedNodeState.putAll(captureCheckedRunState(currentRoot)) + ensureCheckedStateDefaults(latestPoints, checkedNodeState) + } + + fun rebuildTree( + dataPoints: List, + captureCurrentSelectionState: Boolean + ) { + if (captureCurrentSelectionState) { + captureCurrentCheckedState() + } + latestPoints = dataPoints + ensureCheckedStateDefaults(latestPoints, checkedNodeState) + tree.model = DefaultTreeModel( + buildTreeRoot( + dataPoints = latestPoints, + checkedNodeState = checkedNodeState, + withCheckboxes = contextSelectionEnabled && editMode + ) + ) + expandAllRows(tree) + tree.revalidate() + tree.repaint() + if (contextSelectionEnabled && editMode) { + onSelectionStateChanged?.invoke() + } + } + + tree = object : CheckboxTree( + object : CheckboxTree.CheckboxTreeCellRenderer() { + override fun customizeRenderer( + tree: JTree, + value: Any?, + selected: Boolean, + expanded: Boolean, + leaf: Boolean, + row: Int, + hasFocus: Boolean + ) { + val node = value as? DefaultMutableTreeNode ?: return + val entry = node.userObject as? TimelineTreeEntry ?: return + textRenderer.icon = entry.icon + if (entry.point == null) { + textRenderer.append( + entry.title, + SimpleTextAttributes.REGULAR_BOLD_ATTRIBUTES + ) + toolTipText = null + if (entry.subtitle.isNotBlank()) { + textRenderer.append( + " ${entry.subtitle}", + SimpleTextAttributes.GRAYED_ATTRIBUTES + ) + } + } else { + textRenderer.append(entry.title, SimpleTextAttributes.REGULAR_ATTRIBUTES) + if (entry.subtitle.isNotBlank()) { + textRenderer.append( + " ${entry.subtitle}", + SimpleTextAttributes.GRAYED_ATTRIBUTES + ) + toolTipText = entry.subtitle + } else { + toolTipText = null + } + } + } + }, + buildTreeRoot( + dataPoints = latestPoints, + checkedNodeState = checkedNodeState, + withCheckboxes = contextSelectionEnabled && editMode + ) + ) {}.apply { + isRootVisible = false + showsRootHandles = true + rowHeight = 0 + border = JBUI.Borders.empty(2, 0) + + expandAllRows(this) + + fun selectCurrentTimelinePoint() { + val selectedNode = lastSelectedPathComponent as? DefaultMutableTreeNode ?: return + val selectedEntry = selectedNode.userObject as? TimelineTreeEntry ?: return + val point = selectedEntry.point ?: return + onSelect(point) + } + + addMouseListener(object : MouseAdapter() { + override fun mouseClicked(e: MouseEvent) { + if (e.button == MouseEvent.BUTTON1 && e.clickCount >= 2) { + selectCurrentTimelinePoint() + } + } + + override fun mousePressed(e: MouseEvent) { + maybeShowTimelineContextMenu(e) + } + + override fun mouseReleased(e: MouseEvent) { + maybeShowTimelineContextMenu(e) + if (contextSelectionEnabled && editMode && !e.isPopupTrigger && e.button == MouseEvent.BUTTON1) { + SwingUtilities.invokeLater { + onSelectionStateChanged?.invoke() + } + } + } + + private fun maybeShowTimelineContextMenu(e: MouseEvent) { + if (!e.isPopupTrigger && e.button != MouseEvent.BUTTON3) return + + val path = getPathForLocation(e.x, e.y) ?: return + selectionPath = path + val selectedNode = path.lastPathComponent as? DefaultMutableTreeNode ?: return + val selectedEntry = selectedNode.userObject as? TimelineTreeEntry ?: return + val point = selectedEntry.point + val selectedLabel = + point?.let { timelinePointDisplayLabel(it) } ?: selectedEntry.title + + val group = DefaultActionGroup( + object : DumbAwareAction( + "Rollback", + "Rollback to checkpoint", + AllIcons.Actions.Undo + ) { + override fun actionPerformed(event: AnActionEvent) { + val timelinePoint = point ?: return + onNonCopyMenuAction?.invoke() + rollbackToTimelinePoint(timelinePoint, selectedLabel) + } + + override fun update(event: AnActionEvent) { + event.presentation.isEnabled = + point != null && canRollbackTimelinePoint(point) + } + + override fun getActionUpdateThread(): ActionUpdateThread = + ActionUpdateThread.EDT + }, + object : DumbAwareAction( + "Continue From New Session", + "Start new session from given checkpoint", + AllIcons.General.Add + ) { + override fun actionPerformed(event: AnActionEvent) { + val timelinePoint = point ?: return + onNonCopyMenuAction?.invoke() + onSelect(timelinePoint) + } + + override fun update(event: AnActionEvent) { + event.presentation.isEnabled = point != null + } + + override fun getActionUpdateThread(): ActionUpdateThread = + ActionUpdateThread.EDT + }, + object : DumbAwareAction( + "Copy Output", + "Copies the Tool or Assistant's output", + AllIcons.General.Copy + ) { + override fun actionPerformed(event: AnActionEvent) { + val timelinePoint = point ?: return + copyTimelinePointOutput(timelinePoint, selectedLabel) + } + + override fun update(event: AnActionEvent) { + event.presentation.isEnabled = + point != null && canCopyTimelinePointOutput(point) + } + + override fun getActionUpdateThread(): ActionUpdateThread = + ActionUpdateThread.EDT + } + ) + ActionManager.getInstance() + .createActionPopupMenu("AgentTimeline.CheckpointMenu", group) + .component + .show(e.component, e.x, e.y) + } + }) + + addKeyListener(object : KeyAdapter() { + override fun keyReleased(e: KeyEvent) { + if (contextSelectionEnabled && editMode && e.keyCode == KeyEvent.VK_SPACE) { + onSelectionStateChanged?.invoke() + } + } + }) + + inputMap.put(KeyStroke.getKeyStroke("ENTER"), "timeline.select") + actionMap.put("timeline.select", object : AbstractAction() { + override fun actionPerformed(e: ActionEvent?) { + selectCurrentTimelinePoint() + } + }) + } + + val scrollPane = JBScrollPane(tree).apply { + border = JBUI.Borders.empty() + horizontalScrollBarPolicy = ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER + verticalScrollBar.unitIncrement = 16 + } + + SwingUtilities.invokeLater { + if (tree.rowCount > 0) { + tree.scrollRowToVisible(tree.rowCount - 1) + scrollPane.verticalScrollBar.value = scrollPane.verticalScrollBar.maximum + } + } + + val popupWidth = JBUI.scale(620) + val minPopupWidth = JBUI.scale(560) + val maxPopupWidth = JBUI.scale(760) + val popupHeight = + JBUI.scale(computeTimelineDialogHeight(groups.sumOf { it.points.size + 1 })) + + val container = JPanel(BorderLayout()).apply { + isOpaque = false + border = JBUI.Borders.empty() + preferredSize = Dimension(popupWidth, popupHeight) + minimumSize = Dimension(minPopupWidth, popupHeight) + maximumSize = Dimension(maxPopupWidth, popupHeight) + add(header, BorderLayout.NORTH) + add(scrollPane, BorderLayout.CENTER) + } + + return TimelineDialogUi( + component = container, + getSelectedPoint = { + val node = tree.lastSelectedPathComponent as? DefaultMutableTreeNode + val entry = node?.userObject as? TimelineTreeEntry + entry?.point + }, + getCheckedNonSystemMessageCounts = { + val root = tree.model.root as? DefaultMutableTreeNode + if (root == null) emptySet() else collectCheckedNonSystemMessageCounts(root) + }, + refresh = { + reloadPoints?.let { loader -> + launchBackground { + val refreshedPoints = loader() + runInEdt { + rebuildTree(refreshedPoints, captureCurrentSelectionState = true) + } + } + } + }, + setEditMode = { value -> + if (contextSelectionEnabled && editMode != value) { + if (editMode) captureCurrentCheckedState() + editMode = value + header.text = headerText() + rebuildTree(latestPoints, captureCurrentSelectionState = false) + } + }, + isEditMode = { editMode } + ) + } + + private fun buildTimelineRunGroups(points: List): List { + if (points.isEmpty()) return emptyList() + + val grouped = linkedMapOf>() + points.forEach { point -> + val key = if (point.runIndex <= 0) 1 else point.runIndex + grouped.getOrPut(key) { mutableListOf() }.add(point) + } + + return grouped.values.mapIndexed { index, runPoints -> + TimelineRunGroup( + runNumber = index + 1, + points = runPoints, + toolCallCount = runPoints.count { it.kind == TimelinePointKind.TOOL_CALL } + ) + } + } + + private fun computeTimelineDialogHeight(estimatedRowCount: Int): Int { + val visibleRows = estimatedRowCount.coerceIn(4, 16) + val rowsHeight = visibleRows * 26 + return (rowsHeight + 44).coerceIn(160, 460) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineHistoricalRollbackSupport.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineHistoricalRollbackSupport.kt new file mode 100644 index 00000000..45385295 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineHistoricalRollbackSupport.kt @@ -0,0 +1,328 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import ai.koog.agents.snapshot.feature.AgentCheckpointData +import ai.koog.prompt.message.Message as PromptMessage +import com.intellij.openapi.application.runInEdt +import com.intellij.openapi.application.runReadAction +import com.intellij.openapi.application.runWriteAction +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.LocalFileSystem +import ee.carlrobert.codegpt.agent.ToolName +import ee.carlrobert.codegpt.agent.ToolSpecs +import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService +import ee.carlrobert.codegpt.agent.tools.EditTool +import ee.carlrobert.codegpt.agent.tools.ReadTool +import ee.carlrobert.codegpt.agent.tools.WriteTool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.booleanOrNull +import kotlinx.serialization.json.jsonObject +import java.io.File +import java.nio.file.Paths +import java.util.ArrayDeque + +internal class AgentSessionTimelineHistoricalRollbackSupport( + private val project: Project, + private val historyService: AgentCheckpointHistoryService, + private val replayJson: Json +) { + + suspend fun collectOperations(point: RunTimelinePoint): List { + val checkpointRef = point.checkpointRef ?: return emptyList() + val latestCheckpoint: AgentCheckpointData = + historyService.listCheckpoints(checkpointRef.agentId).firstOrNull() + ?: historyService.loadLatestResumeCheckpoint(checkpointRef.agentId) + ?: historyService.loadCheckpoint(checkpointRef) + ?: return emptyList() + val cutoff = point.nonSystemMessageCount ?: return emptyList() + val history = latestCheckpoint.messageHistory.filterNot { it is PromptMessage.System } + if (history.isEmpty() || cutoff >= history.size) return emptyList() + + data class PendingCall(val index: Int, val call: PromptMessage.Tool.Call) + + val pendingById = mutableMapOf() + val pendingWithoutId = ArrayDeque() + val latestKnownContentByFile = mutableMapOf() + val operations = mutableListOf() + + history.forEachIndexed { index, message -> + when (message) { + is PromptMessage.Tool.Call -> { + val callId = message.id?.takeIf { it.isNotBlank() } + if (callId != null) { + pendingById[callId] = PendingCall(index, message) + } else { + pendingWithoutId.addLast(PendingCall(index, message)) + } + } + + is PromptMessage.Tool.Result -> { + val pendingCall = message.id + ?.takeIf { it.isNotBlank() } + ?.let { pendingById.remove(it) } + ?: pendingWithoutId.pollFirst() + ?: return@forEachIndexed + val call = pendingCall.call + + val tool = HistoricalRollbackCompatibility.resolveSupportedTool(call.tool) + ?: return@forEachIndexed + if (!HistoricalRollbackCompatibility.isSuccessfulResult(tool, message.content, replayJson)) { + return@forEachIndexed + } + + when (tool) { + ToolName.READ -> { + val args = decodeReadArgs(call.tool, call.content) ?: return@forEachIndexed + val filePath = normalizeToolFilePath(args.filePath) + ?: return@forEachIndexed + val content = + decodeReadToolResultContent(message.content) ?: return@forEachIndexed + latestKnownContentByFile[filePath] = content + return@forEachIndexed + } + + ToolName.EDIT -> { + val args = decodeEditArgs(call.tool, call.content) ?: return@forEachIndexed + val filePath = normalizeToolFilePath(args.filePath) + ?: return@forEachIndexed + val oldString = args.oldString + val newString = args.newString + val replaceAll = args.replaceAll + if (oldString.isEmpty() || newString.isEmpty() || oldString == newString) return@forEachIndexed + + if (pendingCall.index >= cutoff) { + operations.add( + HistoricalRollbackOperation( + filePath = filePath, + searchText = newString, + replacementText = oldString, + replaceAll = replaceAll, + sourceTool = HistoricalRollbackSourceTool.EDIT + ) + ) + } + + latestKnownContentByFile[filePath]?.let { known -> + if (known.contains(oldString)) { + latestKnownContentByFile[filePath] = + if (replaceAll) known.replace(oldString, newString) + else known.replaceFirst(oldString, newString) + } + } + return@forEachIndexed + } + + ToolName.WRITE -> { + val args = decodeWriteArgs(call.tool, call.content) ?: return@forEachIndexed + val filePath = normalizeToolFilePath(args.filePath) + ?: return@forEachIndexed + val newContent = args.content + val previousContent = latestKnownContentByFile[filePath] + if (pendingCall.index >= cutoff && previousContent != null && previousContent != newContent) { + operations.add( + HistoricalRollbackOperation( + filePath = filePath, + searchText = newContent, + replacementText = previousContent, + replaceAll = false, + sourceTool = HistoricalRollbackSourceTool.WRITE + ) + ) + } + latestKnownContentByFile[filePath] = newContent + } + + else -> Unit + } + } + + else -> Unit + } + } + + return operations + } + + fun buildConfirmationText( + selectedLabel: String, + operations: List + ): String { + val ordered = operations.asReversed() + val previewLimit = 12 + val listed = ordered.take(previewLimit).joinToString(separator = "\n") { operation -> + "${operation.sourceTool.symbol} ${toProjectRelativePath(operation.filePath)}" + } + val remaining = ordered.size - previewLimit + val suffix = if (remaining > 0) "\n...and $remaining more operation(s)." else "" + return """ + This rollback will return the session to: + $selectedLabel + + It will also replay ${operations.size} file operation(s) in reverse order: + + $listed$suffix + """.trimIndent() + } + + fun applyOperations(operations: List): List { + val errors = mutableListOf() + operations.asReversed().forEach { operation -> + val filePath = operation.filePath + val currentText = readFileText(filePath) + if (currentText == null) { + errors.add("File not found: ${toProjectRelativePath(filePath)}") + return@forEach + } + if (!currentText.contains(operation.searchText)) { + errors.add( + "Expected content not found in ${toProjectRelativePath(filePath)} for ${operation.sourceTool.displayName} rollback." + ) + return@forEach + } + + val updatedText = if (operation.replaceAll) { + currentText.replace(operation.searchText, operation.replacementText) + } else { + currentText.replaceFirst(operation.searchText, operation.replacementText) + } + if (updatedText == currentText) { + errors.add("No changes applied to ${toProjectRelativePath(filePath)}") + return@forEach + } + + val writeOk = writeFileText(filePath, updatedText) + if (!writeOk) { + errors.add("Failed to write ${toProjectRelativePath(filePath)}") + } + } + return errors + } + + private fun parseToolArgs(rawArgs: String): Map? { + return runCatching { replayJson.parseToJsonElement(rawArgs).jsonObject }.getOrNull() + } + + private fun decodeReadArgs(rawToolName: String, rawArgs: String): ReadTool.Args? { + val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? ReadTool.Args + if (typed != null) return typed + + val args = parseToolArgs(rawArgs) ?: return null + val filePath = stringValue(args["file_path"]) + ?: stringValue(args["path"]) + ?: stringValue(args["pathInProject"]) + ?: return null + return ReadTool.Args(filePath = filePath) + } + + private fun decodeEditArgs(rawToolName: String, rawArgs: String): EditTool.Args? { + val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? EditTool.Args + if (typed != null) return typed + + val args = parseToolArgs(rawArgs) ?: return null + val filePath = stringValue(args["file_path"]) ?: return null + val oldString = stringValue(args["old_string"]) ?: return null + val newString = stringValue(args["new_string"]) ?: return null + val shortDescription = stringValue(args["short_description"]) ?: "Recovered historical edit" + val replaceAll = booleanValue(args["replace_all"]) ?: false + + return EditTool.Args( + filePath = filePath, + oldString = oldString, + newString = newString, + shortDescription = shortDescription, + replaceAll = replaceAll + ) + } + + private fun decodeWriteArgs(rawToolName: String, rawArgs: String): WriteTool.Args? { + val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? WriteTool.Args + if (typed != null) return typed + + val args = parseToolArgs(rawArgs) ?: return null + val filePath = stringValue(args["file_path"]) ?: return null + val content = stringValue(args["content"]) ?: return null + return WriteTool.Args(filePath = filePath, content = content) + } + + private fun booleanValue(element: JsonElement?): Boolean? { + val primitive = element as? JsonPrimitive ?: return null + return if (primitive.isString) primitive.content.toBooleanStrictOrNull() else primitive.booleanOrNull + } + + private fun stringValue(element: JsonElement?): String? { + if (element == null) return null + val primitive = element as? JsonPrimitive + return if (primitive != null && primitive.isString) primitive.content else element.toString() + } + + private fun normalizeToolFilePath(rawPath: String?): String? { + val trimmed = rawPath?.trim()?.takeIf { it.isNotEmpty() } ?: return null + val normalized = trimmed.replace("\\", "/") + val file = File(normalized) + if (file.isAbsolute) { + return file.toPath().normalize().toString().replace("\\", "/") + } + + val basePath = project.basePath ?: return file.absolutePath.replace("\\", "/") + return Paths.get(basePath).resolve(normalized).normalize().toString().replace("\\", "/") + } + + private fun decodeReadToolResultContent(content: String): String? { + if (content.isBlank()) return "" + + val numberedLines = content.lineSequence().mapNotNull { line -> + val tabIndex = line.indexOf('\t') + if (tabIndex <= 0) return@mapNotNull null + val prefix = line.substring(0, tabIndex) + if (!prefix.all { it.isDigit() }) return@mapNotNull null + line.substring(tabIndex + 1) + }.toList() + if (numberedLines.isNotEmpty()) return numberedLines.joinToString(separator = "\n") + + if (content.startsWith("Error reading file", ignoreCase = true)) return null + return content + } + + private fun readFileText(path: String): String? { + val virtualFile = + LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return null + val documentText = runReadAction { + FileDocumentManager.getInstance().getDocument(virtualFile)?.text + } + if (documentText != null) return documentText + return runCatching { String(virtualFile.contentsToByteArray(), Charsets.UTF_8) }.getOrNull() + } + + private fun writeFileText(path: String, content: String): Boolean { + val virtualFile = + LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return false + val document = runReadAction { FileDocumentManager.getInstance().getDocument(virtualFile) } + return runCatching { + runInEdt { + runWriteAction { + if (document != null) { + document.setText(content) + FileDocumentManager.getInstance().saveDocument(document) + } else { + virtualFile.setBinaryContent(content.toByteArray(Charsets.UTF_8)) + } + } + } + }.isSuccess + } + + private fun toProjectRelativePath(path: String): String { + val basePath = project.basePath ?: return path + return runCatching { + val absolute = Paths.get(path).normalize() + val base = Paths.get(basePath).normalize() + if (absolute.startsWith(base)) { + base.relativize(absolute).toString().replace("\\", "/") + } else { + path + } + }.getOrDefault(path) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineModels.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineModels.kt new file mode 100644 index 00000000..05d41fb8 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentSessionTimelineModels.kt @@ -0,0 +1,70 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import ee.carlrobert.codegpt.agent.history.CheckpointRef +import javax.swing.Icon +import javax.swing.JComponent +import ai.koog.prompt.message.Message as PromptMessage + +internal data class RunTimelinePoint( + val cacheKey: String, + val checkpointRef: CheckpointRef?, + val title: String, + val subtitle: String, + val runLabel: String? = null, + val icon: Icon? = null, + val nonSystemMessageCount: Int? = null, + val outputText: String? = null, + val toolCallId: String? = null, + val kind: TimelinePointKind = TimelinePointKind.ASSISTANT, + val runIndex: Int = 0 +) + +internal enum class TimelinePointKind { + USER, + ASSISTANT, + REASONING, + TOOL_CALL +} + +internal data class TimelineRunGroup( + val runNumber: Int, + val points: List, + val toolCallCount: Int +) + +internal data class TimelineTreeEntry( + val title: String, + val subtitle: String = "", + val icon: Icon? = null, + val point: RunTimelinePoint? = null, + val runNumber: Int? = null +) + +internal data class HistoricalRollbackOperation( + val filePath: String, + val searchText: String, + val replacementText: String, + val replaceAll: Boolean, + val sourceTool: HistoricalRollbackSourceTool +) + +internal enum class HistoricalRollbackSourceTool( + val displayName: String, + val symbol: String +) { + EDIT(displayName = "Edit", symbol = "E"), + WRITE(displayName = "Write", symbol = "W") +} + +internal data class SessionContextSnapshot( + val baseHistory: List +) + +internal data class TimelineDialogUi( + val component: JComponent, + val getSelectedPoint: () -> RunTimelinePoint?, + val getCheckedNonSystemMessageCounts: () -> Set, + val refresh: () -> Unit, + val setEditMode: (Boolean) -> Unit, + val isEditMode: () -> Boolean +) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt index caa9707c..1e375b66 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt @@ -2,21 +2,23 @@ package ee.carlrobert.codegpt.toolwindow.agent import com.intellij.openapi.Disposable import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.application.EDT import com.intellij.openapi.application.runInEdt import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import com.intellij.openapi.util.Disposer +import com.intellij.openapi.vfs.VirtualFileManager import com.intellij.ui.AnimatedIcon import com.intellij.ui.JBColor import com.intellij.ui.components.JBLabel import com.intellij.util.ui.JBUI import com.intellij.util.ui.components.BorderLayoutPanel import ee.carlrobert.codegpt.CodeGPTBundle -import ee.carlrobert.codegpt.agent.AgentService -import ee.carlrobert.codegpt.agent.AgentToolOutputNotifier -import ee.carlrobert.codegpt.agent.MessageWithContext -import ee.carlrobert.codegpt.agent.ToolSpecs -import ee.carlrobert.codegpt.agent.ToolRunContext +import ee.carlrobert.codegpt.agent.* +import ee.carlrobert.codegpt.agent.ProxyAIAgent.loadProjectInstructions +import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService +import ee.carlrobert.codegpt.agent.history.AgentCheckpointTurnSequencer +import ee.carlrobert.codegpt.agent.history.CheckpointRef import ee.carlrobert.codegpt.agent.rollback.RollbackService import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.message.Message @@ -42,9 +44,12 @@ import ee.carlrobert.codegpt.ui.queue.QueuedMessagePanel import ee.carlrobert.codegpt.ui.textarea.UserInputPanel import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager import ee.carlrobert.codegpt.util.EditorUtil +import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks import ee.carlrobert.codegpt.util.coroutines.CoroutineDispatchers -import kotlinx.coroutines.launch +import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope +import kotlinx.coroutines.* import kotlinx.serialization.json.Json +import java.util.* import javax.swing.Box import javax.swing.BoxLayout import javax.swing.JComponent @@ -54,6 +59,10 @@ class AgentToolWindowTabPanel( private val project: Project, private val agentSession: AgentSession ) : BorderLayoutPanel(), Disposable { + companion object { + private const val RECOVERED_CONVERSATION_RENDER_BATCH_SIZE = 6 + } + private val replayJson = Json { ignoreUnknownKeys = true isLenient = true @@ -63,6 +72,7 @@ class AgentToolWindowTabPanel( private val scrollablePanel = ChatToolWindowScrollablePanel() private val tagManager = TagManager() private val dispatchers = CoroutineDispatchers() + private val backgroundScope = DisposableCoroutineScope(dispatchers.io()) private val sessionId = agentSession.sessionId private val conversation = agentSession.conversation private val psiRepository = PsiStructureRepository( @@ -107,18 +117,43 @@ class AgentToolWindowTabPanel( this, FeatureType.AGENT, tagManager, - onSubmit = { text -> handleSubmit(text) }, - onStop = { handleCancel() }, + onSubmit = ::handleSubmit, + onStop = ::handleCancel, withRemovableSelectedEditorTag = true, agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId), sessionIdProvider = { sessionId }, - conversationIdProvider = { conversation.id } + conversationIdProvider = { conversation.id }, + onStartSessionTimeline = ::showSessionStartTimelineDialog ) private var rollbackPanel: RollbackPanel private val todoListPanel = TodoListPanel() private val projectMessageBusConnection = project.messageBus.connect() private val appMessageBusConnection = ApplicationManager.getApplication().messageBus.connect() private val rollbackService = RollbackService.getInstance(project) + private val historyService = project.service() + + private data class RunCardState( + val runMessageId: UUID, + val rollbackRunId: String, + var responsePanel: ResponseMessagePanel, + var sourceMessage: Message? = null, + var completed: Boolean = false + ) + + // Insertion order matters: timeline run numbering depends on the message order. + private val runCardsByMessageId = linkedMapOf() + private var activeRunMessageId: UUID? = null + private var recoveredConversationJob: Job? = null + private var activeLandingPanel: AgentToolWindowLandingPanel? = null + + private val timelineController = AgentSessionTimelineController( + project = project, + agentSession = agentSession, + conversation = conversation, + runStateForRunIndex = ::runStateForRunIndex, + applySeededSessionState = ::applySeededSessionState, + onAfterRollbackRefresh = ::refreshViewAfterRollback + ) private val eventHandler = AgentEventHandler( project = project, @@ -140,6 +175,12 @@ class AgentToolWindowTabPanel( repaint() rollbackPanel.refreshOperations() }, + onRunFinishedCallback = { + markActiveRunCompleted() + }, + onRunCheckpointUpdatedCallback = { runMessageId, ref -> + updateRunCheckpoint(runMessageId, ref) + }, onQueuedMessagesResolved = { message -> runInEdt { clearQueuedMessagesAndCreateNewResponse( @@ -163,12 +204,15 @@ class AgentToolWindowTabPanel( } userInputPanel.setStopEnabled(false) + Disposer.register(this, rollbackPanel) Disposer.register(this, eventHandler) + Disposer.register(this, timelineController) + Disposer.register(this, backgroundScope) } private fun setupMessageBusSubscriptions() { project.service().queuedMessageProcessed.let { flow -> - kotlinx.coroutines.CoroutineScope(dispatchers.io()).launch { + backgroundScope.launch { flow.collect { processedMessage -> ApplicationManager.getApplication().invokeLater { removeQueuedMessage(processedMessage) @@ -230,6 +274,7 @@ class AgentToolWindowTabPanel( private fun handleSubmit(text: String) { if (text.isBlank()) return + disposeLandingPanelIfPresent() scrollablePanel.clearLandingViewIfVisible() agentSession.serviceType = ModelSelectionService.getInstance().getServiceForFeature(FeatureType.AGENT) @@ -255,7 +300,7 @@ class AgentToolWindowTabPanel( .setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.RUNNING) } - rollbackService.startSession(sessionId) + val rollbackRunId = rollbackService.startSession(sessionId) rollbackPanel.refreshOperations() val message = MessageWithContext(text, userInputPanel.getSelectedTags()) @@ -282,8 +327,16 @@ class AgentToolWindowTabPanel( messagePanel.add(responsePanel) scrollablePanel.update() + registerRunCard( + runMessageId = message.id, + rollbackRunId = rollbackRunId, + responsePanel = responsePanel, + prompt = text + ) + eventHandler.resetForNewSubmission() eventHandler.setCurrentResponseBody(responseBody) + eventHandler.setCurrentRollbackRunId(rollbackRunId) loadingLabel.text = CodeGPTBundle.get("toolwindow.chat.loading") loadingLabel.isVisible = true @@ -300,7 +353,15 @@ class AgentToolWindowTabPanel( agentService.cancelCurrentRun(sessionId) agentService.clearPendingMessages(sessionId) - rollbackService.finishSession(sessionId) + val activeRun = activeRunMessageId?.let { runCardsByMessageId[it] } + if (activeRun != null) { + rollbackService.finishRun(activeRun.rollbackRunId) + runCardsByMessageId.remove(activeRun.runMessageId) + activeRunMessageId = null + } else { + rollbackService.finishSession(sessionId) + } + eventHandler.setCurrentRollbackRunId(null) rollbackPanel.refreshOperations() approvalContainer.removeAll() @@ -316,37 +377,182 @@ class AgentToolWindowTabPanel( } private fun displayLandingView() { - scrollablePanel.displayLandingView(createLandingView()) + disposeLandingPanelIfPresent() + val landingPanel = createLandingView() + activeLandingPanel = landingPanel + scrollablePanel.displayLandingView(landingPanel) } private fun displayRecoveredConversation() { + disposeLandingPanelIfPresent() scrollablePanel.clearAll() - conversation.messages.forEach { message -> - val prompt = message.prompt.orEmpty() - val wrapper = scrollablePanel.addMessage(message.id) - val userPanel = UserMessagePanel(project, message, this) - userPanel.addCopyAction { CopyAction.copyToClipboard(prompt) } - wrapper.add(userPanel) + recoveredConversationJob?.cancel() + recoveredConversationJob = backgroundScope.launch { + val recoveredTurns = + runCatching { loadRecoveredTurnsFromResumeCheckpoint() }.getOrNull() + renderRecoveredConversation(recoveredTurns) + } + } - val responseBody = ChatMessageResponseBody( - project, - false, - false, - false, - false, - false, - this - ) - addRecoveredToolCards(responseBody, message) - responseBody.withResponse(message.response.orEmpty()) - val responsePanel = ResponseMessagePanel().apply { - setResponseContent(responseBody) + private suspend fun renderRecoveredConversation( + recoveredTurns: List? + ) { + withContext(Dispatchers.EDT) { + if (!isActive || project.isDisposed) return@withContext + + val canRenderInOrder = recoveredTurns != null && + recoveredTurns.size == conversation.messages.size && + recoveredTurns.indices.all { index -> + recoveredTurns[index].prompt == conversation.messages[index].prompt.orEmpty() + .trim() + } + + val messages = conversation.messages.toList() + var nextIndex = 0 + while (nextIndex < messages.size) { + if (!isActive || project.isDisposed) return@withContext + + val batchEnd = minOf( + nextIndex + RECOVERED_CONVERSATION_RENDER_BATCH_SIZE, + messages.size + ) + + while (nextIndex < batchEnd) { + if (!isActive || project.isDisposed) return@withContext + + val index = nextIndex + val message = messages[index] + val prompt = message.prompt.orEmpty() + val wrapper = scrollablePanel.addMessage(message.id) + val userPanel = UserMessagePanel(project, message, this@AgentToolWindowTabPanel) + userPanel.addCopyAction { CopyAction.copyToClipboard(prompt) } + wrapper.add(userPanel) + + val responseBody = ChatMessageResponseBody( + project, + false, + false, + false, + false, + false, + this@AgentToolWindowTabPanel + ) + + val renderedInOrder = if (canRenderInOrder) { + renderRecoveredTurnInOrder(responseBody, recoveredTurns[index].events) + } else { + false + } + + if (!renderedInOrder) { + addRecoveredToolCards(responseBody, message) + responseBody.withResponse(message.response.orEmpty().stripThinkingBlocks()) + } + + val responsePanel = ResponseMessagePanel().apply { + setResponseContent(responseBody) + } + wrapper.add(responsePanel) + registerRecoveredRunCard(message, responsePanel) + nextIndex += 1 + } + + scrollablePanel.update() + if (nextIndex < messages.size) { + yield() + } } - wrapper.add(responsePanel) + + scrollablePanel.scrollToBottom() + } + } + + private suspend fun loadRecoveredTurnsFromResumeCheckpoint(): List? { + val resumeRef = agentSession.resumeCheckpointRef ?: return null + val checkpoint = historyService.loadCheckpoint(resumeRef) + ?: historyService.loadResumeCheckpoint(resumeRef) + ?: return null + val projectInstructions = loadProjectInstructions(project.basePath) + return AgentCheckpointTurnSequencer.toVisibleTurns( + history = checkpoint.messageHistory, + projectInstructions = projectInstructions, + preserveSyntheticContinuation = true + ) + } + + private fun renderRecoveredTurnInOrder( + responseBody: ChatMessageResponseBody, + events: List + ): Boolean { + if (events.isEmpty()) { + return false } - scrollablePanel.update() - scrollablePanel.scrollToBottom() + val pendingById = mutableMapOf() + val pendingWithoutId = ArrayDeque() + var rendered = false + + events.forEach { event -> + when (event) { + is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> { + val text = event.content.stripThinkingBlocks() + if (text.isNotBlank()) { + responseBody.withResponse(text) + rendered = true + } + } + + is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> { + val text = event.content.stripThinkingBlocks() + if (text.isNotBlank()) { + responseBody.withResponse(text) + rendered = true + } + } + + is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> { + val toolName = event.tool.ifBlank { "Tool" } + if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) { + return@forEach + } + val rawArgs = event.content + val args = parseRecoveredToolArgs(toolName, rawArgs) + val card = createRecoveredToolCard(toolName, args, rawArgs) + responseBody.addToolStatusPanel(card) + val callId = event.id?.takeIf { it.isNotBlank() } + if (callId != null) { + pendingById[callId] = card + } else { + pendingWithoutId.addLast(card) + } + rendered = true + } + + is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> { + val toolName = event.tool.ifBlank { "Tool" } + if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) { + return@forEach + } + val rawResult = event.content + val parsedResult = parseRecoveredToolResult(toolName, rawResult) + val success = inferRecoveredToolSuccess(parsedResult, rawResult) + val card = event.id + ?.takeIf { it.isNotBlank() } + ?.let { pendingById.remove(it) } + ?: pendingWithoutId.pollFirst() + ?: run { + val orphan = createRecoveredToolCard(toolName, "", "") + responseBody.addToolStatusPanel(orphan) + orphan + } + card.complete(success, parsedResult ?: rawResult) + rendered = true + } + + } + } + + return rendered } private fun addRecoveredToolCards(responseBody: ChatMessageResponseBody, message: Message) { @@ -355,7 +561,7 @@ class AgentToolWindowTabPanel( toolCalls.forEach { toolCall -> val toolName = toolCall.function.name ?: return@forEach - if (toolName == "TodoWrite") { + if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) { return@forEach } @@ -419,6 +625,11 @@ class AgentToolWindowTabPanel( return AgentToolWindowLandingPanel(project) } + private fun disposeLandingPanelIfPresent() { + activeLandingPanel?.let { Disposer.dispose(it) } + activeLandingPanel = null + } + fun getSessionId(): String = sessionId fun getAgentSession(): AgentSession = agentSession @@ -483,8 +694,19 @@ class AgentToolWindowTabPanel( responsePanel.setResponseContent(responseBody) messagePanel.add(responsePanel) + val rollbackRunId = activeRunMessageId + ?.let { runCardsByMessageId[it] } + ?.rollbackRunId + eventHandler.resetForNewSubmission() eventHandler.setCurrentResponseBody(responseBody) + eventHandler.setCurrentRollbackRunId(rollbackRunId) + + activeRunMessageId?.let { runMessageId -> + runCardsByMessageId[runMessageId]?.let { state -> + state.responsePanel = responsePanel + } + } scrollablePanel.update() } @@ -518,8 +740,118 @@ class AgentToolWindowTabPanel( } } + private fun registerRunCard( + runMessageId: UUID, + rollbackRunId: String, + responsePanel: ResponseMessagePanel, + prompt: String + ) { + val state = RunCardState( + runMessageId = runMessageId, + rollbackRunId = rollbackRunId, + responsePanel = responsePanel, + sourceMessage = Message(prompt) + ) + runCardsByMessageId[runMessageId] = state + activeRunMessageId = runMessageId + } + + private fun registerRecoveredRunCard(message: Message, responsePanel: ResponseMessagePanel) { + val copiedMessage = Message( + message.prompt.orEmpty(), + message.response.orEmpty().stripThinkingBlocks() + ).apply { + message.toolCalls?.let { toolCalls = ArrayList(it) } + message.toolCallResults?.let { toolCallResults = LinkedHashMap(it) } + } + val state = RunCardState( + runMessageId = message.id, + rollbackRunId = "", + responsePanel = responsePanel, + sourceMessage = copiedMessage, + completed = true + ) + runCardsByMessageId[message.id] = state + timelineController.invalidateTimelineCache() + } + + private fun markActiveRunCompleted() { + val runMessageId = activeRunMessageId ?: return + val state = runCardsByMessageId[runMessageId] ?: return + state.completed = true + activeRunMessageId = null + } + + private fun updateRunCheckpoint(runMessageId: UUID, ref: CheckpointRef?) { + val state = runCardsByMessageId[runMessageId] ?: return + timelineController.invalidateTimelineCache() + if (ref != null) { + state.sourceMessage = null + } + } + + private fun showSessionStartTimelineDialog() { + timelineController.showSessionStartTimelineDialog() + } + + private fun runStateForRunIndex(runIndex: Int): AgentTimelineRunState? { + if (runIndex <= 0) return null + val state = runCardsByMessageId.values.elementAtOrNull(runIndex - 1) ?: return null + return AgentTimelineRunState( + rollbackRunId = state.rollbackRunId, + sourceMessage = state.sourceMessage + ) + } + + private fun applySeededSessionState(seededConversation: Conversation, seedRef: CheckpointRef) { + conversation.messages = seededConversation.messages + runCardsByMessageId.clear() + activeRunMessageId = null + timelineController.invalidateTimelineCache() + + agentSession.runtimeAgentId = seedRef.agentId + agentSession.resumeCheckpointRef = seedRef + + val contentManager = project.service() + contentManager.setRuntimeAgentId(sessionId, seedRef.agentId) + contentManager.setResumeCheckpointRef(sessionId, seedRef) + + project.service().clearPendingMessages(sessionId) + loadingLabel.isVisible = false + clearQueuedMessages() + approvalContainer.removeAll() + approvalContainer.isVisible = false + eventHandler.resetForNewSubmission() + eventHandler.setCurrentRollbackRunId(null) + userInputPanel.setStopEnabled(false) + + if (conversation.messages.isEmpty()) { + displayLandingView() + } else { + displayRecoveredConversation() + } + + refreshViewAfterRollback() + runCatching { + contentManager.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.STOPPED) + } + } + + private fun refreshViewAfterRollback() { + timelineController.invalidateTimelineCache() + runCatching { VirtualFileManager.getInstance().asyncRefresh(null) } + rollbackPanel.refreshOperations() + scrollablePanel.update() + revalidate() + repaint() + } + override fun dispose() { + recoveredConversationJob?.cancel() + disposeLandingPanelIfPresent() ToolRunContext.cleanupSession(sessionId) + runCardsByMessageId.clear() + activeRunMessageId = null projectMessageBusConnection.disconnect() appMessageBusConnection.disconnect() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/HistoricalRollbackCompatibility.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/HistoricalRollbackCompatibility.kt new file mode 100644 index 00000000..b76196c6 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/HistoricalRollbackCompatibility.kt @@ -0,0 +1,59 @@ +package ee.carlrobert.codegpt.toolwindow.agent + +import ee.carlrobert.codegpt.agent.ToolName +import ee.carlrobert.codegpt.agent.ToolSpecs +import ee.carlrobert.codegpt.agent.tools.EditTool +import ee.carlrobert.codegpt.agent.tools.ReadTool +import ee.carlrobert.codegpt.agent.tools.WriteTool +import kotlinx.serialization.json.Json + +internal object HistoricalRollbackCompatibility { + + private val supportedTools = setOf(ToolName.READ, ToolName.EDIT, ToolName.WRITE) + + fun resolveSupportedTool(rawToolName: String): ToolName? { + val normalized = rawToolName.trim() + if (normalized.isEmpty()) return null + + val resolved = ToolSpecs.find(normalized)?.name ?: return null + return resolved.takeIf { it in supportedTools } + } + + fun isSuccessfulResult(toolName: ToolName, content: String, replayJson: Json): Boolean { + if (toolName !in supportedTools) return false + + val normalized = content.trim() + val decoded = ToolSpecs.decodeResultOrNull(replayJson, toolName.id, normalized) + if (decoded != null) { + return when (toolName) { + ToolName.READ -> decoded is ReadTool.Result.Success + ToolName.EDIT -> decoded is EditTool.Result.Success + ToolName.WRITE -> decoded is WriteTool.Result.Success + else -> false + } + } + + // TODO: Is there a better way to determine if the result is successful? + // This string comparison won't cut it + return when (toolName) { + ToolName.READ -> !normalized.startsWith(READ_ERROR_PREFIX, ignoreCase = true) + ToolName.EDIT -> + normalized.isNotBlank() && + !normalized.startsWith(EDIT_ERROR_PREFIX, ignoreCase = true) && + (normalized.contains(EDIT_SUCCESS_MARKER, ignoreCase = true) || + normalized.contains(LEGACY_EDIT_SUCCESS_MARKER, ignoreCase = true)) + ToolName.WRITE -> + normalized.isNotBlank() && + normalized.contains(WRITE_SUCCESS_MARKER, ignoreCase = true) && + !normalized.startsWith(WRITE_ERROR_PREFIX, ignoreCase = true) + else -> false + } + } + + private const val READ_ERROR_PREFIX = "Error reading file" + private const val EDIT_ERROR_PREFIX = "Error editing file" + private const val WRITE_ERROR_PREFIX = "Error writing file" + private const val EDIT_SUCCESS_MARKER = "Successfully edited file" + private const val LEGACY_EDIT_SUCCESS_MARKER = "Successfully made" + private const val WRITE_SUCCESS_MARKER = "successfully" +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentToolWindowLandingPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentToolWindowLandingPanel.kt index f2a5a18b..0fbc2119 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentToolWindowLandingPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/AgentToolWindowLandingPanel.kt @@ -1,6 +1,5 @@ package ee.carlrobert.codegpt.toolwindow.agent.ui -import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.application.runInEdt import com.intellij.openapi.command.WriteCommandAction import com.intellij.openapi.components.service @@ -21,27 +20,33 @@ import ee.carlrobert.codegpt.agent.history.AgentCheckpointConversationMapper import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService import ee.carlrobert.codegpt.agent.history.AgentHistoryThreadSummary import ee.carlrobert.codegpt.settings.GeneralSettings +import ee.carlrobert.codegpt.settings.ProxyAISettingsService +import ee.carlrobert.codegpt.settings.ProxyAISubagent import ee.carlrobert.codegpt.settings.agents.SubagentsConfigurable +import ee.carlrobert.codegpt.settings.hooks.HookConfig +import ee.carlrobert.codegpt.settings.hooks.HookConfiguration +import ee.carlrobert.codegpt.settings.skills.SkillDescriptor +import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService import ee.carlrobert.codegpt.tokens.TokenComputationService import ee.carlrobert.codegpt.toolwindow.agent.AgentToolWindowContentManager import ee.carlrobert.codegpt.toolwindow.agent.history.AgentHistoryListPanel import ee.carlrobert.codegpt.toolwindow.ui.ResponseMessagePanel import ee.carlrobert.codegpt.ui.UIUtil.createTextPane -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope +import com.intellij.openapi.Disposable import java.awt.BorderLayout import java.awt.Color import java.awt.Desktop import java.net.URI import java.nio.charset.StandardCharsets import java.nio.file.Path -import java.time.Instant -import java.time.ZoneId -import java.time.format.DateTimeFormatter import javax.swing.Box import javax.swing.BoxLayout import javax.swing.JPanel -class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessagePanel() { +class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessagePanel(), Disposable { companion object { private val logger = thisLogger() @@ -49,7 +54,10 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag private val historyService = project.service() private val historyListPanel = AgentHistoryListPanel(defaultLimit = 5) + private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO) private var refreshHistory = true + @Volatile + private var disposed = false private fun createLabel(text: String) = JBLabel(text) private fun createLink(text: String, onClick: () -> Unit) = ActionLink(text) { onClick() } @@ -166,77 +174,39 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag tokensLabel.foreground = healthColor(tokens) topRow.add(tokensLabel) - val bottomRow = JPanel() - bottomRow.layout = BoxLayout(bottomRow, BoxLayout.X_AXIS) - bottomRow.alignmentX = LEFT_ALIGNMENT + val settingsService = project.service() + val skills = project.service().listSkills() + val subagents = settingsService.getSubagents() + val hookEntries = collectHookEntries(settingsService.getHooks()) + val enabledHooksCount = hookEntries.count { it.hook.enabled } - val ts = vf.timeStamp - val now = Instant.now() - val fileTime = Instant.ofEpochMilli(ts) - val zonedNow = now.atZone(ZoneId.systemDefault()) - val zonedFileTime = fileTime.atZone(ZoneId.systemDefault()) - - val timeLabel = when { - zonedFileTime.toLocalDate() == zonedNow.toLocalDate() -> { - val timeFormat = - DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault()) - JBLabel("Modified today at ${timeFormat.format(fileTime)}") - } - - zonedFileTime.toLocalDate() == zonedNow.minusDays(1).toLocalDate() -> { - val timeFormat = - DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault()) - JBLabel("Modified yesterday at ${timeFormat.format(fileTime)}") - } - - fileTime.isAfter(now.minusSeconds(7 * 24 * 60 * 60)) -> { - val dayFormat = - DateTimeFormatter.ofPattern("EEEE").withZone(ZoneId.systemDefault()) - val timeFormat = - DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault()) - JBLabel( - "Modified on ${dayFormat.format(fileTime)} at ${ - timeFormat.format( - fileTime - ) - }" - ) - } - - zonedFileTime.toLocalDate().year == zonedNow.toLocalDate().year -> { - val dateFormat = - DateTimeFormatter.ofPattern("MMM d").withZone(ZoneId.systemDefault()) - val timeFormat = - DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault()) - JBLabel( - "Modified on ${dateFormat.format(fileTime)} at ${ - timeFormat.format( - fileTime - ) - }" - ) - } - - else -> { - val dateFormat = - DateTimeFormatter.ofPattern("MMM d, yyyy").withZone(ZoneId.systemDefault()) - val timeFormat = - DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault()) - JBLabel( - "Modified on ${dateFormat.format(fileTime)} at ${ - timeFormat.format( - fileTime - ) - }" - ) - } - } - timeLabel.foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor - bottomRow.add(timeLabel) + val detailsRow = JPanel() + detailsRow.layout = BoxLayout(detailsRow, BoxLayout.X_AXIS) + detailsRow.alignmentX = LEFT_ALIGNMENT + detailsRow.add( + createHoverDetailsLabel( + text = "Skills ${skills.size}", + tooltip = buildSkillsTooltip(skills) + ) + ) + detailsRow.add(createDetailsSeparator()) + detailsRow.add( + createHoverDetailsLabel( + text = "Hooks $enabledHooksCount", + tooltip = buildHooksTooltip(hookEntries) + ) + ) + detailsRow.add(createDetailsSeparator()) + detailsRow.add( + createHoverDetailsLabel( + text = "Subagents ${subagents.size}", + tooltip = buildSubagentsTooltip(subagents) + ) + ) container.add(topRow) container.add(Box.createVerticalStrut(2)) - container.add(bottomRow) + container.add(detailsRow) } panel.add(container) @@ -299,22 +269,23 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag limit: Int, onResult: (List, Boolean, Int) -> Unit ) { + if (disposed || project.isDisposed) return val shouldRefresh = offset == 0 && refreshHistory - ApplicationManager.getApplication().executeOnPooledThread { + backgroundScope.launch { val page = runCatching { - runBlocking { - historyService.listThreadsPage( - query = query, - offset = offset, - limit = limit, - refresh = shouldRefresh - ) - } + historyService.listThreadsPage( + query = query, + offset = offset, + limit = limit, + refresh = shouldRefresh + ) } .onFailure { logger.warn("Failed to load checkpoint history", it) } .getOrNull() + if (disposed || project.isDisposed) return@launch runInEdt { + if (disposed || project.isDisposed) return@runInEdt if (page != null) { refreshHistory = false } @@ -324,25 +295,37 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag } private fun openCheckpointThread(thread: AgentHistoryThreadSummary) { - ApplicationManager.getApplication().executeOnPooledThread { + if (disposed || project.isDisposed) return + backgroundScope.launch { + val checkpoint = runCatching { + historyService.loadCheckpoint(thread.latest) + }.onFailure { + logger.warn("Failed to open checkpoint thread ${thread.agentId}", it) + }.getOrNull() ?: return@launch + val conversation = runCatching { - val checkpoint = runBlocking { historyService.loadCheckpoint(thread.latest) } - ?: return@executeOnPooledThread AgentCheckpointConversationMapper.toConversation( checkpoint = checkpoint, projectInstructions = loadProjectInstructions(project.basePath) ) }.onFailure { logger.warn("Failed to open checkpoint thread ${thread.agentId}", it) - }.getOrNull() ?: return@executeOnPooledThread + }.getOrNull() ?: return@launch - ApplicationManager.getApplication().invokeLater { + if (disposed || project.isDisposed) return@launch + runInEdt { + if (disposed || project.isDisposed) return@runInEdt project.service() .openCheckpointConversation(thread, conversation) } } } + override fun dispose() { + disposed = true + backgroundScope.dispose() + } + private fun welcomeMessage(): String { val name = GeneralSettings.getCurrentState().displayName return """ @@ -385,4 +368,93 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag val bl = (a.blue + (b.blue - a.blue) * t).toInt() return Color(r, g, bl) } + + private fun createHoverDetailsLabel(text: String, tooltip: String): JBLabel { + return JBLabel(text).apply { + foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor + toolTipText = tooltip + } + } + + private fun createDetailsSeparator(): JBLabel { + return JBLabel(" • ").apply { + foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor + } + } + + private fun collectHookEntries(configuration: HookConfiguration): List = buildList { + addAll(configuration.beforeToolUse.map { HookEntry("Before tool", it) }) + addAll(configuration.afterToolUse.map { HookEntry("After tool", it) }) + addAll(configuration.subagentStart.map { HookEntry("Subagent start", it) }) + addAll(configuration.subagentStop.map { HookEntry("Subagent stop", it) }) + addAll(configuration.beforeShellExecution.map { HookEntry("Before shell", it) }) + addAll(configuration.afterShellExecution.map { HookEntry("After shell", it) }) + addAll(configuration.beforeReadFile.map { HookEntry("Before read", it) }) + addAll(configuration.afterFileEdit.map { HookEntry("After edit", it) }) + addAll(configuration.stop.map { HookEntry("Stop", it) }) + } + + private fun buildSkillsTooltip(skills: List): String { + if (skills.isEmpty()) { + return tooltipHtml("Skills", listOf("No skills discovered in .proxyai/skills")) + } + + val lines = skills.take(5).map { "• ${it.name} — ${it.title}" }.toMutableList() + val remaining = skills.size - lines.size + if (remaining > 0) { + lines.add("+$remaining more") + } + + return tooltipHtml("Skills (${skills.size})", lines) + } + + private fun buildHooksTooltip(hookEntries: List): String { + val enabled = hookEntries.filter { it.hook.enabled } + if (hookEntries.isEmpty()) { + return tooltipHtml("Hooks", listOf("No hooks configured")) + } + + val lines = mutableListOf() + enabled.take(4).forEach { entry -> + lines.add("• ${entry.event}: ${entry.hook.command}") + } + val remainingEnabled = enabled.size - 4 + if (remainingEnabled > 0) { + lines.add("+$remainingEnabled more enabled") + } + + return tooltipHtml("Hooks", lines) + } + + private fun buildSubagentsTooltip(subagents: List): String { + if (subagents.isEmpty()) { + return tooltipHtml("Subagents", listOf("No subagents configured")) + } + + val lines = subagents.take(5).map { "• ${it.title}" }.toMutableList() + val remaining = subagents.size - lines.size + if (remaining > 0) { + lines.add("+$remaining more") + } + + return tooltipHtml("Subagents (${subagents.size})", lines) + } + + private fun tooltipHtml(title: String, lines: List): String { + val escapedTitle = escapeHtml(title) + val escapedLines = lines.joinToString("
") { escapeHtml(it) } + return "$escapedTitle
$escapedLines" + } + + private fun escapeHtml(value: String): String { + return value + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + } + + private data class HookEntry( + val event: String, + val hook: HookConfig + ) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt index 28fa8122..1258548e 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt @@ -5,6 +5,7 @@ import com.intellij.diff.DiffManager import com.intellij.diff.requests.SimpleDiffRequest import com.intellij.icons.AllIcons import com.intellij.ide.actions.OpenFileAction +import com.intellij.openapi.Disposable import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.application.EDT @@ -23,6 +24,7 @@ import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.drawCenteredText import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.lineDiffStats import ee.carlrobert.codegpt.ui.IconActionButton import ee.carlrobert.codegpt.ui.components.LeftEllipsisLabel +import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope import kotlinx.coroutines.* import java.awt.BorderLayout import java.awt.Dimension @@ -40,15 +42,11 @@ import java.time.format.DateTimeFormatter import java.util.concurrent.ConcurrentHashMap import javax.swing.* -/** - * Panel showing file operations performed by the agent with rollback controls. - * - */ class RollbackPanel( private val project: Project, private val sessionId: String, private val onRollbackComplete: () -> Unit -) : BorderLayoutPanel() { +) : BorderLayoutPanel(), Disposable { private val rollbackService = RollbackService.getInstance(project) private val titleLabel = JBLabel() private val timeLabel = JBLabel() @@ -58,7 +56,8 @@ class RollbackPanel( private val keepAllLink = createKeepAllLink() private val diffStatsCache = ConcurrentHashMap>() private val diffDataCache = ConcurrentHashMap() - private val backgroundScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + private var lastSnapshotRunId: String? = null + private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO) init { setupUI() @@ -130,12 +129,26 @@ class RollbackPanel( .sortedBy { it.path } withContext(Dispatchers.EDT) { - refreshOperationsUI(changes, snapshot?.completedAt) + refreshOperationsUI( + changes = changes, + completedAt = snapshot?.completedAt, + snapshotRunId = snapshot?.runId + ) } } } - private fun refreshOperationsUI(changes: List, completedAt: Instant?) { + private fun refreshOperationsUI( + changes: List, + completedAt: Instant?, + snapshotRunId: String? + ) { + if (snapshotRunId != lastSnapshotRunId) { + diffStatsCache.clear() + diffDataCache.clear() + lastSnapshotRunId = snapshotRunId + } + isVisible = changes.isNotEmpty() if (changes.isEmpty()) { titleLabel.text = "Changes" @@ -168,10 +181,10 @@ class RollbackPanel( } private fun handleRollback() { - CoroutineScope(Dispatchers.Main).launch { + backgroundScope.launch { val result = rollbackService.rollbackSession(sessionId) - withContext(Dispatchers.Main) { + withContext(Dispatchers.EDT) { when (result) { is RollbackResult.Success -> { refreshOperations() @@ -424,10 +437,10 @@ class RollbackPanel( } private fun rollbackFile(path: String) { - CoroutineScope(Dispatchers.Main).launch { + backgroundScope.launch { val result = rollbackService.rollbackFile(sessionId, path) - withContext(Dispatchers.Main) { + withContext(Dispatchers.EDT) { when (result) { is RollbackResult.Success -> { refreshOperations() @@ -466,7 +479,7 @@ class RollbackPanel( backgroundScope.launch { val diffData = rollbackService.getDiffData(sessionId, path) ?: return@launch diffDataCache[path] = diffData - withContext(Dispatchers.Main) { + withContext(Dispatchers.EDT) { openDiff(diffData) } } @@ -480,4 +493,8 @@ class RollbackPanel( val request = SimpleDiffRequest(title, before, after, "Before", "After") DiffManager.getInstance().showDiff(project, request) } -} \ No newline at end of file + + override fun dispose() { + backgroundScope.dispose() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/ui/BaseMessagePanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/ui/BaseMessagePanel.kt index acea0c77..1d9e1b37 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/ui/BaseMessagePanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/ui/BaseMessagePanel.kt @@ -70,6 +70,10 @@ abstract class BaseMessagePanel : BorderLayoutPanel() { ) } + fun addHeaderAction(action: AnAction, actionCode: String) { + addIconActionButton(IconActionButton(action, actionCode)) + } + fun addContent(content: JComponent) { body.addContent(content) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/UserInputPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/UserInputPanel.kt index 70c362fe..720b1013 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/UserInputPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/UserInputPanel.kt @@ -73,6 +73,7 @@ class UserInputPanel @JvmOverloads constructor( private val agentTokenCounterPanel: JComponent? = null, private val sessionIdProvider: (() -> String?)? = null, private val conversationIdProvider: (() -> UUID?)? = null, + private val onStartSessionTimeline: (() -> Unit)? = null, ) : BorderLayoutPanel() { constructor( @@ -98,6 +99,7 @@ class UserInputPanel @JvmOverloads constructor( null, withRemovableSelectedEditorTag, null, + null, null ) @@ -220,7 +222,26 @@ class UserInputPanel @JvmOverloads constructor( } else { null } - private val promptEnhancerSeparator = if (featureType == FeatureType.AGENT) { + private val sessionTimelineButton = if (featureType == FeatureType.AGENT) { + IconActionButton( + object : AnAction( + "Timeline", + "Choose a timeline point from this session", + AllIcons.Vcs.History + ) { + override fun actionPerformed(e: AnActionEvent) { + onStartSessionTimeline?.invoke() + } + }, + "SESSION_TIMELINE" + ).apply { + isVisible = onStartSessionTimeline != null + cursor = Cursor.getPredefinedCursor(Cursor.HAND_CURSOR) + } + } else { + null + } + private val sessionTimelineSeparator = if (featureType == FeatureType.AGENT) { createActionSeparator() } else { null @@ -600,9 +621,12 @@ class UserInputPanel @JvmOverloads constructor( panel { row { if (applyChip != null) cell(applyChip).gap(RightGap.SMALL) - if (promptEnhancerButton != null && promptEnhancerSeparator != null) { + if (promptEnhancerButton != null) { cell(promptEnhancerButton).gap(RightGap.SMALL) - cell(promptEnhancerSeparator).gap(RightGap.SMALL) + } + if (sessionTimelineButton != null && sessionTimelineSeparator != null) { + cell(sessionTimelineButton).gap(RightGap.SMALL) + cell(sessionTimelineSeparator).gap(RightGap.SMALL) } cell(submitButton).gap(RightGap.SMALL) cell(stopButton) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/FolderActionItem.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/FolderActionItem.kt index 9038de57..60f96fb9 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/FolderActionItem.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/FolderActionItem.kt @@ -13,7 +13,7 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.FolderTagDetails class FolderActionItem( private val project: Project, private val folder: VirtualFile -) : AbstractLookupActionItem() { +) : AbstractLookupActionItem(), InsertsDisplayNameLookupItem { override val displayName = folder.name override val icon = AllIcons.Nodes.Folder @@ -33,4 +33,4 @@ class FolderActionItem( override fun execute(project: Project, userInputPanel: UserInputPanel) { userInputPanel.addTag(FolderTagDetails(folder)) } -} \ No newline at end of file +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/files/FileActionItem.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/files/FileActionItem.kt index dd75b260..69016245 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/files/FileActionItem.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/textarea/lookup/action/files/FileActionItem.kt @@ -10,9 +10,10 @@ import com.intellij.openapi.vfs.VirtualFile import ee.carlrobert.codegpt.ui.textarea.UserInputPanel import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails import ee.carlrobert.codegpt.ui.textarea.lookup.action.AbstractLookupActionItem +import ee.carlrobert.codegpt.ui.textarea.lookup.action.InsertsDisplayNameLookupItem class FileActionItem(private val project: Project, val file: VirtualFile) : - AbstractLookupActionItem() { + AbstractLookupActionItem(), InsertsDisplayNameLookupItem { override val displayName = file.name override val icon = file.fileType.icon ?: AllIcons.FileTypes.Any_type @@ -32,4 +33,4 @@ class FileActionItem(private val project: Project, val file: VirtualFile) : override fun execute(project: Project, userInputPanel: UserInputPanel) { userInputPanel.addTag(EditorTagDetails(file)) } -} \ No newline at end of file +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/util/StringUtil.kt b/src/main/kotlin/ee/carlrobert/codegpt/util/StringUtil.kt index 518345b6..45d4284c 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/util/StringUtil.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/util/StringUtil.kt @@ -4,6 +4,13 @@ import ai.grazie.nlp.utils.takeWhitespaces object StringUtil { + private val thinkBlockRegex = Regex("(?s).*?\\s*") + + fun String.stripThinkingBlocks(): String { + if (this.isBlank()) return "" + return thinkBlockRegex.replace(this, "").trim() + } + fun adjustWhitespace( completionLine: String, editorLine: String @@ -33,12 +40,4 @@ object StringUtil { val intersection = bigrams1.intersect(bigrams2).size return (2.0 * intersection) / (bigrams1.size + bigrams2.size) } - - fun String.extractUntilNewline(): String { - val index = this.indexOf('\n') - if (index == -1) { - return this - } - return this.substring(0, index + 1) - } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt index 94403fe8..c0147790 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/AgentServiceIntegrationTest.kt @@ -69,26 +69,26 @@ class AgentServiceIntegrationTest : IntegrationTest() { val secondMessage = MessageWithContext("Second request") submitAndAwait(sessionId, firstMessage) - val firstSnapshot = snapshot(sessionId) + val firstActualSnapshot = snapshot(sessionId) submitAndAwait(sessionId, secondMessage) - val secondSnapshot = snapshot(sessionId) + val secondActualSnapshot = snapshot(sessionId) - val runtimeAgentId = firstSnapshot.runtimeAgentId - val managedService = factory.createdServices.single() - assertThat(runtimeAgentId).isNotBlank() + val actualRuntimeAgentId = firstActualSnapshot.runtimeAgentId + val actualManagedService = factory.createdServices.single() + assertThat(actualRuntimeAgentId).isNotBlank() assertThat(factory.createdServices).hasSize(1) - assertThat(listOf(firstSnapshot, secondSnapshot)) + assertThat(listOf(firstActualSnapshot, secondActualSnapshot)) .extracting("runtimeAgentId", "resumeCheckpointRef.agentId") .containsExactly( - tuple(runtimeAgentId, runtimeAgentId), - tuple(runtimeAgentId, runtimeAgentId) + tuple(actualRuntimeAgentId, actualRuntimeAgentId), + tuple(actualRuntimeAgentId, actualRuntimeAgentId) ) - assertThat(firstSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank() - assertThat(secondSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank() - assertThat(secondSnapshot.resumeCheckpointRef?.checkpointId) - .isNotEqualTo(firstSnapshot.resumeCheckpointRef?.checkpointId) - assertThat(managedService.createdAgentIds).hasSize(2) - assertThat(managedService.createdAgentIds.last()).isEqualTo(runtimeAgentId) + assertThat(firstActualSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank() + assertThat(secondActualSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank() + assertThat(secondActualSnapshot.resumeCheckpointRef?.checkpointId) + .isNotEqualTo(firstActualSnapshot.resumeCheckpointRef?.checkpointId) + assertThat(actualManagedService.createdAgentIds).hasSize(2) + assertThat(actualManagedService.createdAgentIds.last()).isEqualTo(actualRuntimeAgentId) } fun testSubmitMessageRebuildsRuntimeWhenMcpSelectionChanges() { @@ -102,12 +102,38 @@ class AgentServiceIntegrationTest : IntegrationTest() { submitAndAwait(sessionId, withoutMcp) submitAndAwait(sessionId, withMcp) - assertThat(factory.createdServices).hasSize(2) - assertThat(factory.createdServices) + val actualCreatedServices = factory.createdServices + assertThat(actualCreatedServices).hasSize(2) + assertThat(actualCreatedServices) .extracting("closeAllCalls") .containsExactly(1, 0) } + fun testSubmitMessageEmitsRunCheckpointCallback() { + val sessionId = createSession("agent-runtime-callback") + val factory = RecordingRuntimeFactory(agentModel()) + agentService.runtimeFactory = factory + val message = MessageWithContext("Callback request") + var callbackMessageId: UUID? = null + var callbackRef: CheckpointRef? = null + val events = object : AgentEvents { + override fun onQueuedMessagesResolved() = Unit + + override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) { + callbackMessageId = runMessageId + callbackRef = ref + } + } + + agentService.submitMessage(message, events, sessionId) + awaitSessionToFinish(sessionId) + + val actualSession = contentManager.getSession(sessionId) + assertThat(callbackMessageId).isEqualTo(message.id) + assertThat(callbackRef).isNotNull + assertThat(callbackRef?.agentId).isEqualTo(actualSession?.runtimeAgentId) + } + fun testGetCheckpointFallsBackToRuntimeAgentIdWhenResumeRefIsStale() { val sessionId = createSession("agent-checkpoint-fallback") val runtimeAgentId = "runtime-agent-${UUID.randomUUID()}" @@ -133,11 +159,11 @@ class AgentServiceIntegrationTest : IntegrationTest() { contentManager.setRuntimeAgentId(sessionId, runtimeAgentId) contentManager.setResumeCheckpointRef(sessionId, staleRef) - val checkpoint = runBlocking { agentService.getCheckpoint(sessionId) } - - assertThat(checkpoint).isNotNull - assertThat(checkpoint!!.checkpointId).isEqualTo(checkpointId) - assertThat(contentManager.getSession(sessionId)!!.resumeCheckpointRef) + val actualCheckpoint = runBlocking { agentService.getCheckpoint(sessionId) } + val actualSession = contentManager.getSession(sessionId) + assertThat(actualCheckpoint).isNotNull + assertThat(actualCheckpoint!!.checkpointId).isEqualTo(checkpointId) + assertThat(actualSession!!.resumeCheckpointRef) .isEqualTo(CheckpointRef(runtimeAgentId, checkpointId)) } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/RollbackServiceTrackingTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/RollbackServiceTrackingTest.kt index a14c408f..767cb1a9 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/RollbackServiceTrackingTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/RollbackServiceTrackingTest.kt @@ -2,8 +2,9 @@ package ee.carlrobert.codegpt.agent import com.intellij.openapi.vfs.LocalFileSystem import ee.carlrobert.codegpt.agent.rollback.ChangeKind +import ee.carlrobert.codegpt.agent.rollback.RollbackResult import ee.carlrobert.codegpt.agent.rollback.RollbackService -import ee.carlrobert.codegpt.agent.tools.WriteTool +import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat import testsupport.IntegrationTest import java.io.File @@ -31,10 +32,10 @@ class RollbackServiceTrackingTest : IntegrationTest() { filePath = filePath, originalContent = "before" ) - val snapshot = rollbackService.finishSession(sessionId) + val actualSnapshot = rollbackService.finishSession(sessionId) - assertThat(snapshot!!.changes).hasSize(1) - assertThat(snapshot.changes.single()) + assertThat(actualSnapshot!!.changes).hasSize(1) + assertThat(actualSnapshot.changes.single()) .extracting("path", "kind") .containsExactly(filePath, ChangeKind.MODIFIED) } @@ -46,10 +47,10 @@ class RollbackServiceTrackingTest : IntegrationTest() { val sessionId = "s2" rollbackService.startSession(sessionId) - rollbackService.trackWrite(sessionId, filePath, WriteTool.Args(filePath, "hello")) - val snapshot = rollbackService.finishSession(sessionId) + rollbackService.trackWrite(sessionId, filePath) + val actualSnapshot = rollbackService.finishSession(sessionId) - assertThat(snapshot).isNotNull + assertThat(actualSnapshot).isNotNull .extracting { it!!.changes.single().kind } .isEqualTo(ChangeKind.ADDED) } @@ -62,14 +63,81 @@ class RollbackServiceTrackingTest : IntegrationTest() { val sessionId = "s3" rollbackService.startSession(sessionId) - rollbackService.trackWrite(sessionId, filePath, WriteTool.Args(filePath, "after")) + rollbackService.trackWrite(sessionId, filePath) file.writeText("after") LocalFileSystem.getInstance().refreshAndFindFileByPath(filePath) - val snapshot = rollbackService.finishSession(sessionId) + val actualSnapshot = rollbackService.finishSession(sessionId) + val actualDiff = rollbackService.getDiffData(sessionId, filePath) - assertThat(snapshot!!.changes.single().kind).isEqualTo(ChangeKind.MODIFIED) - assertThat(rollbackService.getDiffData(sessionId, filePath)) + assertThat(actualSnapshot!!.changes.single().kind).isEqualTo(ChangeKind.MODIFIED) + assertThat(actualDiff) .extracting("beforeText", "afterText") .containsExactly("before", "after") } -} \ No newline at end of file + + fun testSnapshotsRemainAvailablePerRunWithinSameSession() { + val rollbackService = RollbackService(project) + val sessionId = "run-scoped-session" + + val firstPath = getTestFilePath("run_scoped_a.txt") + File(firstPath).delete() + val firstRunId = rollbackService.startSession(sessionId) + rollbackService.trackWrite(sessionId, firstPath) + File(firstPath).writeText("run-a") + rollbackService.finishSession(sessionId) + + val secondPath = getTestFilePath("run_scoped_b.txt") + File(secondPath).delete() + val secondRunId = rollbackService.startSession(sessionId) + rollbackService.trackWrite(sessionId, secondPath) + File(secondPath).writeText("run-b") + rollbackService.finishSession(sessionId) + + val firstRunSnapshot = rollbackService.getRunSnapshot(firstRunId) + val secondRunSnapshot = rollbackService.getRunSnapshot(secondRunId) + val latestSessionSnapshot = rollbackService.getSnapshot(sessionId) + + assertThat(firstRunSnapshot) + .extracting("runId") + .isEqualTo(firstRunId) + assertThat(secondRunSnapshot) + .extracting("runId") + .isEqualTo(secondRunId) + assertThat(latestSessionSnapshot) + .extracting("runId") + .isEqualTo(secondRunId) + } + + fun testRollbackRunDoesNotClearOtherRunSnapshots() { + val rollbackService = RollbackService(project) + val sessionId = "run-rollback-session" + + val firstPath = getTestFilePath("run_rollback_a.txt") + File(firstPath).delete() + val firstRunId = rollbackService.startSession(sessionId) + rollbackService.trackWrite(sessionId, firstPath) + File(firstPath).writeText("run-a") + rollbackService.finishSession(sessionId) + + val secondPath = getTestFilePath("run_rollback_b.txt") + File(secondPath).delete() + val secondRunId = rollbackService.startSession(sessionId) + rollbackService.trackWrite(sessionId, secondPath) + File(secondPath).writeText("run-b") + rollbackService.finishSession(sessionId) + + val actualRollbackResult = runBlocking { + rollbackService.rollbackRun(firstRunId) + } + val firstRunSnapshotAfterRollback = rollbackService.getRunSnapshot(firstRunId) + val secondRunSnapshotAfterRollback = rollbackService.getRunSnapshot(secondRunId) + val latestSessionSnapshot = rollbackService.getSnapshot(sessionId) + + assertThat(actualRollbackResult).isInstanceOf(RollbackResult.Success::class.java) + assertThat(firstRunSnapshotAfterRollback).isNull() + assertThat(secondRunSnapshotAfterRollback).isNotNull() + assertThat(latestSessionSnapshot) + .extracting("runId") + .isEqualTo(secondRunId) + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapperTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapperTest.kt index 074970ea..ac1715ba 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapperTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointConversationMapperTest.kt @@ -4,7 +4,6 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo -import ee.carlrobert.codegpt.conversations.Conversation import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlinx.serialization.json.JsonNull @@ -15,7 +14,6 @@ import org.assertj.core.groups.Tuple.tuple import kotlin.test.Test class AgentCheckpointConversationMapperTest { - private val actualUserTurn = "Actual user prompt" to "Actual response" @Test fun `maps user and assistant turns into conversation messages`() { @@ -29,9 +27,9 @@ class AgentCheckpointConversationMapperTest { ) ) - val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) + val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) - assertThat(conversation.messages) + assertThat(actualConversation.messages) .extracting("prompt", "response") .containsExactly( tuple("First prompt", "First response"), @@ -60,14 +58,14 @@ class AgentCheckpointConversationMapperTest { ) ) - val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) - val mapped = conversation.messages.single() + val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) + val actualMessage = actualConversation.messages.single() - assertThat(mapped.response).isEqualTo("Working on it") - assertThat(mapped.toolCalls.orEmpty()) + assertThat(actualMessage.response).isEqualTo("Working on it") + assertThat(actualMessage.toolCalls.orEmpty()) .extracting("function.name", "function.arguments") .containsExactly(tuple("Bash", """{"command":"ls"}""")) - assertThat(mapped.toolCallResults).containsEntry("tool-1", "file1.txt") + assertThat(actualMessage.toolCallResults).containsEntry("tool-1", "file1.txt") } @Test @@ -81,12 +79,14 @@ class AgentCheckpointConversationMapperTest { ) ) - val conversation = AgentCheckpointConversationMapper.toConversation( + val actualConversation = AgentCheckpointConversationMapper.toConversation( checkpoint = checkpoint, projectInstructions = projectInstructions ) - assertSingleTurn(conversation, actualUserTurn) + assertThat(actualConversation.messages) + .extracting("prompt", "response") + .containsExactly(tuple("Actual user prompt", "Actual response")) } @Test @@ -102,12 +102,43 @@ class AgentCheckpointConversationMapperTest { ) ) - val conversation = AgentCheckpointConversationMapper.toConversation( + val actualConversation = AgentCheckpointConversationMapper.toConversation( checkpoint = checkpoint, projectInstructions = null ) - assertSingleTurn(conversation, actualUserTurn) + assertThat(actualConversation.messages) + .extracting("prompt", "response") + .containsExactly(tuple("Actual user prompt", "Actual response")) + } + + @Test + fun `does not merge hidden user turn output into previous visible turn`() { + val checkpoint = checkpoint( + listOf( + Message.User("Visible prompt 1", RequestMetaInfo.Empty), + Message.Assistant("Visible response 1", ResponseMetaInfo.Empty), + Message.User( + content = "Hidden cacheable instruction", + metaInfo = cacheableMetaInfo() + ), + Message.Assistant("Hidden response", ResponseMetaInfo.Empty), + Message.User("Visible prompt 2", RequestMetaInfo.Empty), + Message.Assistant("Visible response 2", ResponseMetaInfo.Empty) + ) + ) + + val actualConversation = AgentCheckpointConversationMapper.toConversation( + checkpoint = checkpoint, + projectInstructions = null + ) + + assertThat(actualConversation.messages) + .extracting("prompt", "response") + .containsExactly( + tuple("Visible prompt 1", "Visible response 1"), + tuple("Visible prompt 2", "Visible response 2") + ) } @Test @@ -136,12 +167,80 @@ class AgentCheckpointConversationMapperTest { ) ) - val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) + val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) - assertThat(conversation.messages.single().toolCallResults) + assertThat(actualConversation.messages.single().toolCallResults) .containsEntry("tool-1", "file1.txt\n\nfile2.txt") } + @Test + fun `filters synthetic timeline user turn and its output`() { + val checkpoint = checkpoint( + listOf( + Message.User("Visible prompt 1", RequestMetaInfo.Empty), + Message.Assistant("Visible response 1", ResponseMetaInfo.Empty), + Message.User("I haven't created a todo list yet.", RequestMetaInfo.Empty), + Message.Assistant("Synthetic response", ResponseMetaInfo.Empty), + Message.User("Visible prompt 2", RequestMetaInfo.Empty), + Message.Assistant("Visible response 2", ResponseMetaInfo.Empty) + ) + ) + + val actualConversation = AgentCheckpointConversationMapper.toConversation( + checkpoint = checkpoint, + projectInstructions = null + ) + + assertThat(actualConversation.messages) + .extracting("prompt", "response") + .containsExactly( + tuple("Visible prompt 1", "Visible response 1"), + tuple("Visible prompt 2", "Visible response 2") + ) + } + + @Test + fun `filters internal timeline tools from mapped message`() { + val checkpoint = checkpoint( + listOf( + Message.User("Visible prompt", RequestMetaInfo.Empty), + Message.Tool.Call( + id = "todo-1", + tool = "TodoWrite", + content = """{"todos":[{"content":"x","status":"pending"}]}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "todo-1", + tool = "TodoWrite", + content = "ok", + metaInfo = RequestMetaInfo.Empty + ), + Message.Tool.Call( + id = "tool-1", + tool = "Bash", + content = """{"command":"ls"}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "tool-1", + tool = "Bash", + content = "file1.txt", + metaInfo = RequestMetaInfo.Empty + ) + ) + ) + + val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null) + + val actualMessage = actualConversation.messages.single() + assertThat(actualMessage.toolCalls.orEmpty()) + .extracting("function.name") + .containsExactly("Bash") + assertThat(actualMessage.toolCallResults) + .containsOnlyKeys("tool-1") + } + private fun cacheableMetaInfo(): RequestMetaInfo { return RequestMetaInfo( timestamp = Clock.System.now(), @@ -149,15 +248,6 @@ class AgentCheckpointConversationMapperTest { ) } - private fun assertSingleTurn( - conversation: Conversation, - expectedTurn: Pair - ) { - assertThat(conversation.messages) - .extracting("prompt", "response") - .containsExactly(tuple(expectedTurn.first, expectedTurn.second)) - } - private fun checkpoint(history: List): AgentCheckpointData { return AgentCheckpointData( checkpointId = "checkpoint-1", diff --git a/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencerTest.kt new file mode 100644 index 00000000..92e517b3 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/agent/history/AgentCheckpointTurnSequencerTest.kt @@ -0,0 +1,248 @@ +package ee.carlrobert.codegpt.agent.history + +import ai.koog.agents.snapshot.feature.AgentCheckpointData +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.assertj.core.api.Assertions.assertThat +import kotlin.test.Test + +class AgentCheckpointTurnSequencerTest { + + @Test + fun `preserves assistant and tool event order within a turn`() { + val history = listOf( + Message.User("Prompt 1", RequestMetaInfo.Empty), + Message.Assistant("Assistant 1", ResponseMetaInfo.Empty), + Message.Tool.Call( + id = "tool-1", + tool = "Bash", + content = """{"command":"ls"}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Assistant("Assistant 2", ResponseMetaInfo.Empty), + Message.Tool.Result( + id = "tool-1", + tool = "Bash", + content = "file.txt", + metaInfo = RequestMetaInfo.Empty + ), + Message.Reasoning(content = "Reasoning 1", metaInfo = ResponseMetaInfo.Empty), + Message.User("Prompt 2", RequestMetaInfo.Empty), + Message.Assistant("Assistant 3", ResponseMetaInfo.Empty) + ) + + val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns( + history = history, + projectInstructions = null + ) + + assertThat(actualTurns).hasSize(2) + assertThat(actualTurns[0].prompt).isEqualTo("Prompt 1") + assertThat(actualTurns[0].userNonSystemMessageCount).isEqualTo(1) + assertThat(actualTurns[0].events.map(::eventSignature)) + .containsExactly( + "assistant:Assistant 1", + "tool-call:Bash:tool-1", + "assistant:Assistant 2", + "tool-result:Bash:tool-1", + "reasoning:Reasoning 1" + ) + assertThat(actualTurns[0].events.map { it.nonSystemMessageCount }) + .containsExactly(2, 3, 4, 5, 6) + assertThat(actualTurns[1].prompt).isEqualTo("Prompt 2") + assertThat(actualTurns[1].userNonSystemMessageCount).isEqualTo(7) + assertThat(actualTurns[1].events.map(::eventSignature)) + .containsExactly("assistant:Assistant 3") + assertThat(actualTurns[1].events.map { it.nonSystemMessageCount }) + .containsExactly(8) + } + + @Test + fun `filters hidden and synthetic user turns and internal tools`() { + val history = listOf( + Message.User("Visible prompt 1", RequestMetaInfo.Empty), + Message.Assistant("Visible response 1", ResponseMetaInfo.Empty), + Message.User("Hidden cacheable prompt", cacheableMetaInfo()), + Message.Assistant("Hidden response", ResponseMetaInfo.Empty), + Message.User("I haven't created a todo list yet.", RequestMetaInfo.Empty), + Message.Assistant("Synthetic response", ResponseMetaInfo.Empty), + Message.User("Visible prompt 2", RequestMetaInfo.Empty), + Message.Tool.Call( + id = "todo-1", + tool = "TodoWrite", + content = """{"todos":[{"content":"x","status":"pending"}]}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "todo-1", + tool = "TodoWrite", + content = "ok", + metaInfo = RequestMetaInfo.Empty + ), + Message.Tool.Call( + id = "tool-1", + tool = "Bash", + content = """{"command":"ls"}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "tool-1", + tool = "Bash", + content = "file.txt", + metaInfo = RequestMetaInfo.Empty + ), + Message.Assistant("Visible response 2", ResponseMetaInfo.Empty) + ) + + val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns( + history = history, + projectInstructions = null + ) + + assertThat(actualTurns).hasSize(2) + assertThat(actualTurns.map { it.prompt }) + .containsExactly("Visible prompt 1", "Visible prompt 2") + assertThat(actualTurns.map { it.userNonSystemMessageCount }) + .containsExactly(1, 7) + assertThat(actualTurns[0].events.map(::eventSignature)) + .containsExactly("assistant:Visible response 1") + assertThat(actualTurns[0].events.map { it.nonSystemMessageCount }) + .containsExactly(2) + assertThat(actualTurns[1].events.map(::eventSignature)) + .containsExactly( + "tool-call:Bash:tool-1", + "tool-result:Bash:tool-1", + "assistant:Visible response 2" + ) + assertThat(actualTurns[1].events.map { it.nonSystemMessageCount }) + .containsExactly(10, 11, 12) + } + + @Test + fun `mapper and sequencer keep same visible turn prompts`() { + val history = listOf( + Message.User("Visible prompt 1", RequestMetaInfo.Empty), + Message.Assistant("Visible response 1", ResponseMetaInfo.Empty), + Message.User("Hidden cacheable prompt", cacheableMetaInfo()), + Message.Assistant("Hidden response", ResponseMetaInfo.Empty), + Message.User("Visible prompt 2", RequestMetaInfo.Empty), + Message.Tool.Call( + id = "tool-1", + tool = "Bash", + content = """{"command":"ls"}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "tool-1", + tool = "Bash", + content = "file.txt", + metaInfo = RequestMetaInfo.Empty + ) + ) + + val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(history, null) + val actualConversation = AgentCheckpointConversationMapper.toConversation( + checkpoint = AgentCheckpointData( + checkpointId = "checkpoint-1", + createdAt = Instant.parse("2026-02-04T00:00:00Z"), + nodePath = "agent/single_run/nodeExecuteTool", + lastInput = JsonNull, + messageHistory = history, + version = 0 + ), + projectInstructions = null + ) + + assertThat(actualConversation.messages.map { it.prompt }) + .containsExactlyElementsOf(actualTurns.map { it.prompt }) + } + + @Test + fun `keeps events after synthetic todo user message within the active visible turn`() { + val history = listOf( + Message.User("Implement feature X", RequestMetaInfo.Empty), + Message.Assistant("Starting implementation", ResponseMetaInfo.Empty), + Message.Tool.Call( + id = "bash-1", + tool = "Bash", + content = """{"command":"ls"}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "bash-1", + tool = "Bash", + content = "file.txt", + metaInfo = RequestMetaInfo.Empty + ), + Message.User( + "It seems that you haven't created a todo list yet. If the task on hand requires multiple steps then create a todo list to track your changes.", + RequestMetaInfo.Empty + ), + Message.Tool.Call( + id = "todo-1", + tool = "TodoWrite", + content = """{"todos":[{"content":"x","status":"pending"}]}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "todo-1", + tool = "TodoWrite", + content = "ok", + metaInfo = RequestMetaInfo.Empty + ), + Message.Tool.Call( + id = "read-1", + tool = "Read", + content = """{"path":"src/Main.kt"}""", + metaInfo = ResponseMetaInfo.Empty + ), + Message.Tool.Result( + id = "read-1", + tool = "Read", + content = "content", + metaInfo = RequestMetaInfo.Empty + ), + Message.Assistant("Done", ResponseMetaInfo.Empty) + ) + + val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns( + history = history, + projectInstructions = null, + preserveSyntheticContinuation = true + ) + + assertThat(actualTurns).hasSize(1) + assertThat(actualTurns.single().prompt).isEqualTo("Implement feature X") + assertThat(actualTurns.single().events.map(::eventSignature)) + .containsExactly( + "assistant:Starting implementation", + "tool-call:Bash:bash-1", + "tool-result:Bash:bash-1", + "tool-call:Read:read-1", + "tool-result:Read:read-1", + "assistant:Done" + ) + } + + private fun eventSignature(event: AgentCheckpointTurnSequencer.TurnEvent): String { + return when (event) { + is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> "assistant:${event.content}" + is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> "reasoning:${event.content}" + is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> "tool-call:${event.tool}:${event.id}" + is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> "tool-result:${event.tool}:${event.id}" + } + } + + private fun cacheableMetaInfo(): RequestMetaInfo { + return RequestMetaInfo( + timestamp = Clock.System.now(), + metadata = JsonObject(mapOf("cacheable" to JsonPrimitive(true))) + ) + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt index 572ffab9..80bea9f5 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt @@ -17,7 +17,6 @@ class ChatToolWindowTabbedPaneTest : BasePlatformTestCase() { assertThat(tabbedPane.activeTabMapping).isEmpty() } - fun testAddingNewTabs() { val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable())