feat: agent timeline

This commit is contained in:
Carl-Robert Linnupuu 2026-02-13 16:55:39 +00:00
parent 24e2a5fbad
commit 3f32f19f72
31 changed files with 3779 additions and 450 deletions

View file

@ -1,11 +1,13 @@
package ee.carlrobert.codegpt.agent
import ai.koog.prompt.executor.clients.LLMClientException
import ee.carlrobert.codegpt.agent.history.CheckpointRef
import ee.carlrobert.codegpt.agent.tools.AskUserQuestionTool
import ee.carlrobert.codegpt.conversations.message.TokenUsage
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest
import java.util.UUID
interface AgentEvents {
fun onTextReceived(text: String) {}
@ -22,6 +24,7 @@ interface AgentEvents {
}
fun onRetry(attempt: Int, maxAttempts: Int, reason: String? = null) {}
fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {}
fun onQueuedMessagesResolved()
fun onTokenUsageAvailable(tokenUsage: Long) {}
fun onCreditsAvailable(event: AgentCreditsEvent) {}

View file

@ -19,9 +19,12 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.McpTagDetails
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.datetime.Clock
import kotlinx.serialization.json.JsonNull
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.io.path.Path
import ai.koog.prompt.message.Message as PromptMessage
internal fun interface AgentRuntimeFactory {
fun create(
@ -53,16 +56,17 @@ class AgentService(private val project: Project) {
private val pendingMessages = ConcurrentHashMap<String, ArrayDeque<MessageWithContext>>()
private val sessionAgents = ConcurrentHashMap<String, AIAgent<MessageWithContext, String>>()
private val sessionRuntimes = ConcurrentHashMap<String, SessionRuntime>()
internal var runtimeFactory: AgentRuntimeFactory = AgentRuntimeFactory { p, storage, provider, events, sid, pending ->
ProxyAIAgent.createService(
project = p,
checkpointStorage = storage,
provider = provider,
events = events,
sessionId = sid,
pendingMessages = pending
)
}
internal var runtimeFactory: AgentRuntimeFactory =
AgentRuntimeFactory { p, storage, provider, events, sid, pending ->
ProxyAIAgent.createService(
project = p,
checkpointStorage = storage,
provider = provider,
events = events,
sessionId = sid,
pendingMessages = pending
)
}
private val checkpointStorage =
JVMFilePersistenceStorageProvider(Path(project.basePath ?: "", ".proxyai"))
private val historyService = project.service<AgentCheckpointHistoryService>()
@ -132,7 +136,8 @@ class AgentService(private val project: Project) {
} catch (ex: Exception) {
logger.error(ex)
} finally {
refreshSessionResumeCheckpoint(sessionId, agent.id)
val ref = refreshSessionResumeCheckpoint(sessionId, agent.id)
events.onRunCheckpointUpdated(message.id, ref)
sessionAgents.remove(sessionId, agent)
runCatching { runtime.service.removeAgentWithId(agent.id) }
.onFailure { ex ->
@ -201,7 +206,39 @@ class AgentService(private val project: Project) {
.update(sessionId, conversationId, selectedServerIds)
}
private suspend fun refreshSessionResumeCheckpoint(sessionId: String, agentId: String) {
suspend fun createSeedCheckpointFromHistory(history: List<PromptMessage>): CheckpointRef? =
withContext(Dispatchers.IO) {
if (history.isEmpty()) {
return@withContext null
}
val agentId = UUID.randomUUID().toString()
val checkpointId = UUID.randomUUID().toString()
val checkpoint = AgentCheckpointData(
checkpointId = checkpointId,
createdAt = Clock.System.now(),
nodePath = "$agentId/single_run/nodeExecuteTool",
lastInput = JsonNull,
messageHistory = history,
version = 0
)
runCatching {
checkpointStorage.saveCheckpoint(agentId, checkpoint)
CheckpointRef(agentId, checkpointId)
}.onFailure { ex ->
logger.warn(
"Agent checkpoints: failed to create seed checkpoint from history " +
"agentId=$agentId error=${ex.message}",
ex
)
}.getOrNull()
}
private suspend fun refreshSessionResumeCheckpoint(
sessionId: String,
agentId: String
): CheckpointRef? {
val ref = runCatching {
historyService.loadLatestResumeCheckpoint(agentId)
?.let { CheckpointRef(agentId, it.checkpointId) }
@ -216,6 +253,7 @@ class AgentService(private val project: Project) {
if (ref != null) {
project.service<AgentToolWindowContentManager>().setResumeCheckpointRef(sessionId, ref)
}
return ref
}
private fun ensureSessionRuntime(

View file

@ -171,6 +171,8 @@ object ProxyAIAgent {
}
onNodeExecutionCompleted { ctx ->
if (stream) return@onNodeExecutionCompleted
(ctx.output as? List<*>)?.forEach { msg ->
(msg as? Message.Assistant)?.let {
events.onTextReceived(it.content)

View file

@ -169,25 +169,20 @@ object ToolSpecs {
fun find(toolName: String): ToolSpec<*, *>? = specsByName[toolName.lowercase()]
fun approvalTypeFor(toolName: String): ToolApprovalType {
return find(toolName)?.approvalType ?: ToolApprovalType.GENERIC
}
fun approvalTypeFor(toolName: String): ToolApprovalType =
find(toolName)?.approvalType ?: ToolApprovalType.GENERIC
fun decodeArgsOrNull(
json: Json,
toolName: String,
payload: String
): Any? {
return decodeOrNull(json, find(toolName)?.argsSerializer, payload)
}
) = decodeOrNull(json, find(toolName)?.argsSerializer, payload)
fun decodeResultOrNull(
json: Json,
toolName: String,
payload: String
): Any? {
return decodeOrNull(json, find(toolName)?.resultSerializer, payload)
}
) = decodeOrNull(json, find(toolName)?.resultSerializer, payload)
@Suppress("UNCHECKED_CAST")
private fun decodeOrNull(

View file

@ -3,9 +3,9 @@ package ee.carlrobert.codegpt.agent.history
import ai.koog.agents.snapshot.feature.AgentCheckpointData
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks
import ee.carlrobert.llm.client.openai.completion.response.ToolCall
import ee.carlrobert.llm.client.openai.completion.response.ToolFunctionResponse
import ai.koog.prompt.message.Message as PromptMessage
object AgentCheckpointConversationMapper {
@ -14,21 +14,54 @@ object AgentCheckpointConversationMapper {
projectInstructions: String?
): Conversation {
val conversation = Conversation()
val history = checkpoint.messageHistory.filterNot { it is PromptMessage.System }
val turns = AgentCheckpointTurnSequencer.toVisibleTurns(
history = checkpoint.messageHistory,
projectInstructions = projectInstructions
)
var currentPrompt: String? = null
val response = StringBuilder()
val toolCalls = mutableListOf<ToolCall>()
val toolResults = LinkedHashMap<String, String>()
var syntheticToolIdIndex = 0
turns.forEach { turn ->
val response = StringBuilder()
val toolCalls = mutableListOf<ToolCall>()
val toolResults = LinkedHashMap<String, String>()
var syntheticToolIdIndex = 0
fun flushTurn() {
val prompt = currentPrompt?.trim().orEmpty()
if (prompt.isBlank()) {
return
turn.events.forEach { event ->
when (event) {
is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> {
appendAssistant(response, event.content)
}
is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> {
appendAssistant(response, event.content)
}
is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> {
val callId = event.id?.takeIf { it.isNotBlank() }
?: "tool-call-${++syntheticToolIdIndex}"
toolCalls.add(
ToolCall(
null,
callId,
"function",
ToolFunctionResponse(event.tool, event.content.trim())
)
)
}
is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> {
val callId = event.id?.takeIf { it.isNotBlank() }
?: toolCalls.lastOrNull()?.id
?: "tool-call-${++syntheticToolIdIndex}"
val prior = toolResults[callId]
val merged = if (prior.isNullOrBlank()) event.content.trim() else {
"$prior\n\n${event.content.trim()}"
}
toolResults[callId] = merged
}
}
}
val uiMessage = Message(prompt)
val uiMessage = Message(turn.prompt)
uiMessage.response = response.toString().trim()
if (toolCalls.isNotEmpty()) {
uiMessage.toolCalls = ArrayList(toolCalls)
@ -37,64 +70,13 @@ object AgentCheckpointConversationMapper {
uiMessage.toolCallResults = LinkedHashMap(toolResults)
}
conversation.addMessage(uiMessage)
currentPrompt = null
response.setLength(0)
toolCalls.clear()
toolResults.clear()
}
history.forEach { msg ->
when (msg) {
is PromptMessage.User -> {
val text = msg.content.trim()
if (shouldHideInAgentToolWindow(msg, projectInstructions)) {
return@forEach
}
flushTurn()
currentPrompt = text
}
is PromptMessage.Assistant -> appendAssistant(response, msg.content)
is PromptMessage.Reasoning -> appendAssistant(response, msg.content)
is PromptMessage.Tool.Call -> {
if (currentPrompt != null) {
val callId =
msg.id?.takeIf { it.isNotBlank() }
?: "tool-call-${++syntheticToolIdIndex}"
toolCalls.add(
ToolCall(
null,
callId,
"function",
ToolFunctionResponse(msg.tool, msg.content.trim())
)
)
}
}
is PromptMessage.Tool.Result -> {
if (currentPrompt != null) {
val callId = msg.id?.takeIf { it.isNotBlank() }
?: toolCalls.lastOrNull()?.id
?: "tool-call-${++syntheticToolIdIndex}"
val prior = toolResults[callId]
val merged = if (prior.isNullOrBlank()) msg.content.trim() else {
"$prior\n\n${msg.content.trim()}"
}
toolResults[callId] = merged
}
}
else -> Unit
}
}
flushTurn()
return conversation
}
private fun appendAssistant(sb: StringBuilder, content: String) {
val text = content.trim()
val text = content.stripThinkingBlocks()
if (text.isBlank()) {
return
}

View file

@ -96,6 +96,13 @@ class AgentCheckpointHistoryService(project: Project) {
?: checkpoints.maxByOrNull { it.createdAt }
}
suspend fun listCheckpoints(agentId: String): List<AgentCheckpointData> =
withContext(Dispatchers.IO) {
storage.getCheckpoints(agentId)
.filterNot { it.isTombstone() }
.sortedByDescending { it.createdAt }
}
private suspend fun buildSummary(agentId: String): AgentHistoryThreadSummary? {
val checkpoints = storage.getCheckpoints(agentId)
.filterNot { it.isTombstone() }

View file

@ -0,0 +1,193 @@
package ee.carlrobert.codegpt.agent.history
import kotlinx.serialization.json.booleanOrNull
import kotlinx.serialization.json.contentOrNull
import kotlinx.serialization.json.jsonPrimitive
import ai.koog.prompt.message.Message as PromptMessage
object AgentCheckpointTurnSequencer {
data class Turn(
val prompt: String,
val userNonSystemMessageCount: Int,
val events: List<TurnEvent>
)
sealed interface TurnEvent {
val nonSystemMessageCount: Int
data class Assistant(
val content: String,
override val nonSystemMessageCount: Int
) : TurnEvent
data class Reasoning(
val content: String,
override val nonSystemMessageCount: Int
) : TurnEvent
data class ToolCall(
val id: String?,
val tool: String,
val content: String,
override val nonSystemMessageCount: Int
) : TurnEvent
data class ToolResult(
val id: String?,
val tool: String,
val content: String,
override val nonSystemMessageCount: Int
) : TurnEvent
}
fun toVisibleTurns(
history: List<PromptMessage>,
projectInstructions: String?,
preserveSyntheticContinuation: Boolean = false
): List<Turn> {
val turns = mutableListOf<Turn>()
var currentPrompt: String? = null
var currentUserNonSystemMessageCount = 0
val currentEvents = mutableListOf<TurnEvent>()
fun flushTurn() {
val prompt = currentPrompt?.trim().orEmpty()
if (prompt.isBlank() || currentUserNonSystemMessageCount <= 0) {
return
}
turns.add(
Turn(
prompt = prompt,
userNonSystemMessageCount = currentUserNonSystemMessageCount,
events = currentEvents.toList()
)
)
currentPrompt = null
currentUserNonSystemMessageCount = 0
currentEvents.clear()
}
history
.filterNot { it is PromptMessage.System }
.forEachIndexed { index, message ->
val nonSystemMessageCount = index + 1
when (message) {
is PromptMessage.User -> {
if (isSyntheticTimelineUserMessage(message)) {
if (preserveSyntheticContinuation && currentPrompt != null) {
// Keep collecting events for the currently active visible turn.
// Synthetic todo prompts are injected internally mid-run.
return@forEachIndexed
}
flushTurn()
currentPrompt = null
currentUserNonSystemMessageCount = 0
return@forEachIndexed
}
flushTurn()
if (isHiddenUserMessage(message, projectInstructions)) {
currentPrompt = null
currentUserNonSystemMessageCount = 0
return@forEachIndexed
}
currentPrompt = message.content.trim()
currentUserNonSystemMessageCount = nonSystemMessageCount
}
is PromptMessage.Assistant -> {
if (currentPrompt != null) {
currentEvents.add(
TurnEvent.Assistant(
content = message.content,
nonSystemMessageCount = nonSystemMessageCount
)
)
}
}
is PromptMessage.Reasoning -> {
if (currentPrompt != null) {
currentEvents.add(
TurnEvent.Reasoning(
content = message.content,
nonSystemMessageCount = nonSystemMessageCount
)
)
}
}
is PromptMessage.Tool.Call -> {
if (currentPrompt != null && !isTodoWriteTool(message.tool)) {
currentEvents.add(
TurnEvent.ToolCall(
id = message.id,
tool = message.tool,
content = message.content,
nonSystemMessageCount = nonSystemMessageCount
)
)
}
}
is PromptMessage.Tool.Result -> {
if (currentPrompt != null && !isTodoWriteTool(message.tool)) {
currentEvents.add(
TurnEvent.ToolResult(
id = message.id,
tool = message.tool,
content = message.content,
nonSystemMessageCount = nonSystemMessageCount
)
)
}
}
else -> Unit
}
}
flushTurn()
return turns
}
fun isTodoWriteTool(toolName: String): Boolean {
return toolName.equals("TodoWrite", ignoreCase = true)
}
fun isSyntheticTimelineUserMessage(message: PromptMessage.User): Boolean {
val normalized = message.content.lowercase()
return normalized.contains("haven't created a todo list yet")
}
private fun isHiddenUserMessage(
message: PromptMessage.User,
projectInstructions: String?
): Boolean {
return isCacheableInstructionMessage(message) ||
isProjectInstructionsMessage(message.content, projectInstructions)
}
private fun isCacheableInstructionMessage(message: PromptMessage.User): Boolean {
val cacheable = message.metaInfo.metadata
?.get("cacheable")
?.jsonPrimitive
?: return false
return cacheable.booleanOrNull ?: (cacheable.contentOrNull?.equals(
"true",
ignoreCase = true
) == true)
}
private fun isProjectInstructionsMessage(text: String, projectInstructions: String?): Boolean {
if (projectInstructions.isNullOrBlank()) {
return false
}
return normalize(text) == normalize(projectInstructions)
}
private fun normalize(value: String): String {
return value.replace("\\s+".toRegex(), " ").trim()
}
}

View file

@ -13,9 +13,7 @@ import com.intellij.openapi.fileTypes.FileTypeManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFileManager
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
import ee.carlrobert.codegpt.agent.tools.WriteTool
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
@ -23,85 +21,100 @@ import java.nio.file.Paths
import java.time.Instant
import java.time.format.DateTimeFormatter
import java.util.concurrent.ConcurrentHashMap
import java.util.UUID
/**
* Tracks file changes during agent runs and provides rollback to a checkpoint.
*
* Uses IntelliJ's LocalHistory labels for persistence and a lightweight in-memory
* snapshot of modified files for fast restore.
*
* Now uses direct tracking via trackEdit() and trackWrite() methods instead of
* VFS event monitoring to avoid capturing build artifacts and temporary files.
*/
@Service(Service.Level.PROJECT)
class RollbackService(private val project: Project) {
private val activeRuns = ConcurrentHashMap<String, RunTracker>()
private val snapshots = ConcurrentHashMap<String, SnapshotState>()
private val activeRunsBySession = ConcurrentHashMap<String, String>()
private val snapshotsByRunId = ConcurrentHashMap<String, SnapshotState>()
private val latestSnapshotRunIdBySession = ConcurrentHashMap<String, String>()
@Volatile
private var isApplyingRollback = false
private val MAX_TRACKABLE_BYTES = 10 * 1024 * 1024
private val MAX_TRACKABLE_BYTES = 10 * 1024 * 1024 // 10MB
init {
// No VFS listener - using direct tracking via trackEdit() and trackWrite()
}
/**
* Track an EditTool operation directly from the agent.
* This captures the original file content before the edit is applied.
*/
fun trackEdit(
sessionId: String,
filePath: String,
originalContent: String
) {
val tracker = activeRuns[sessionId] ?: return
val normalizedPath = filePath.replace("\\", "/")
tracker.recordExplicitEdit(normalizedPath, originalContent)
val runId = activeRunsBySession[sessionId] ?: return
trackEditForRun(runId, filePath, originalContent)
}
/**
* Track a WriteTool operation directly from the agent.
* This determines if the file was new or existing before the write.
*/
fun trackWrite(sessionId: String, filePath: String, args: WriteTool.Args) {
val tracker = activeRuns[sessionId] ?: return
fun trackEditForRun(
runId: String,
filePath: String,
originalContent: String
) {
val tracker = activeRuns[runId] ?: return
val normalizedPath = filePath.replace("\\", "/")
tracker.recordExplicitWrite(normalizedPath, args)
tracker.recordExplicitEdit(normalizedPath, originalContent)
}
fun startSession(sessionId: String) {
fun trackWrite(sessionId: String, filePath: String) {
val runId = activeRunsBySession[sessionId] ?: return
trackWriteForRun(runId, filePath)
}
fun trackWriteForRun(runId: String, filePath: String) {
val tracker = activeRuns[runId] ?: return
val normalizedPath = filePath.replace("\\", "/")
tracker.recordExplicitWrite(normalizedPath)
}
fun startSession(sessionId: String): String {
val labelText = "ProxyAI: Agent run ${DateTimeFormatter.ISO_INSTANT.format(Instant.now())}"
val label = LocalHistory.getInstance().putSystemLabel(project, labelText)
activeRuns[sessionId] = RunTracker(sessionId, Instant.now(), labelText, label)
snapshots.remove(sessionId)
val runId = UUID.randomUUID().toString()
val tracker = RunTracker(runId, sessionId, Instant.now(), labelText, label)
activeRuns[runId] = tracker
activeRunsBySession[sessionId] = runId
latestSnapshotRunIdBySession.remove(sessionId)
return runId
}
fun finishSession(sessionId: String): RollbackSnapshot? {
val tracker = activeRuns.remove(sessionId) ?: return getSnapshot(sessionId)
val runId = activeRunsBySession.remove(sessionId) ?: return getSnapshot(sessionId)
return finishRun(runId)
}
fun finishRun(runId: String): RollbackSnapshot? {
val tracker = activeRuns.remove(runId) ?: return getRunSnapshot(runId)
val snapshot = SnapshotState(
sessionId = sessionId,
runId = runId,
sessionId = tracker.sessionId,
label = tracker.label,
labelRef = tracker.labelRef,
startedAt = tracker.startedAt,
completedAt = Instant.now(),
changes = tracker.changes.toMap()
)
snapshots[sessionId] = snapshot
snapshotsByRunId[runId] = snapshot
latestSnapshotRunIdBySession[tracker.sessionId] = runId
return snapshot.toSnapshot()
}
fun getSnapshot(sessionId: String): RollbackSnapshot? =
snapshots[sessionId]?.toSnapshot()
fun getSnapshot(sessionId: String): RollbackSnapshot? {
val runId = latestSnapshotRunIdBySession[sessionId] ?: return null
return snapshotsByRunId[runId]?.toSnapshot()
}
fun getRunSnapshot(runId: String): RollbackSnapshot? =
snapshotsByRunId[runId]?.toSnapshot()
fun clearSnapshot(sessionId: String) {
snapshots.remove(sessionId)
val runId = latestSnapshotRunIdBySession.remove(sessionId) ?: return
snapshotsByRunId.remove(runId)
}
fun clearRunSnapshot(runId: String) {
val snapshot = snapshotsByRunId.remove(runId) ?: return
latestSnapshotRunIdBySession.remove(snapshot.sessionId, runId)
}
fun getDiffData(sessionId: String, path: String): RollbackDiffData? {
val snapshot = snapshots[sessionId] ?: return null
val snapshot = snapshotForSession(sessionId) ?: return null
val change = snapshot.changes[path] ?: return null
if (change.kind != ChangeKind.DELETED && !isTrackable(path)) return null
val beforeText = when (change.kind) {
@ -131,28 +144,19 @@ class RollbackService(private val project: Project) {
}
fun isRollbackAvailable(sessionId: String): Boolean =
snapshots[sessionId]?.changes?.isNotEmpty() == true && !activeRuns.containsKey(sessionId)
snapshotForSession(sessionId)?.changes?.isNotEmpty() == true &&
!activeRunsBySession.containsKey(sessionId)
fun isDisplayable(path: String): Boolean = isTrackable(path)
suspend fun rollbackFile(sessionId: String, path: String): RollbackResult =
withContext(Dispatchers.IO) {
val snapshot = snapshots[sessionId]
val snapshot = snapshotForSession(sessionId)
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
val change = snapshot.changes[path]
?: return@withContext RollbackResult.Failure("No change tracked for $path")
val errors = mutableListOf<String>()
runInEdt(ModalityState.defaultModalityState()) {
runWriteAction {
try {
isApplyingRollback = true
applyChangeWithLabel(snapshot.labelRef, path, change, errors)
} finally {
isApplyingRollback = false
}
}
}
val errors = applyChanges(snapshot.labelRef, mapOf(path to change))
if (errors.isNotEmpty()) {
RollbackResult.Failure(errors.joinToString("\n"))
@ -160,36 +164,38 @@ class RollbackService(private val project: Project) {
val updated = snapshot.changes.toMutableMap()
updated.remove(path)
if (updated.isEmpty()) {
snapshots.remove(sessionId)
clearRunSnapshot(snapshot.runId)
} else {
snapshots[sessionId] = snapshot.copy(changes = updated.toMap())
snapshotsByRunId[snapshot.runId] = snapshot.copy(changes = updated.toMap())
}
RollbackResult.Success("Rollback completed")
}
}
suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.Main) {
val snapshot = snapshots[sessionId]
suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.IO) {
val snapshot = snapshotForSession(sessionId)
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
val errors = mutableListOf<String>()
runInEdt(ModalityState.defaultModalityState()) {
runWriteAction {
try {
isApplyingRollback = true
snapshot.changes.forEach { (path, change) ->
applyChangeWithLabel(snapshot.labelRef, path, change, errors)
}
} finally {
isApplyingRollback = false
}
}
}
val errors = applyChanges(snapshot.labelRef, snapshot.changes)
if (errors.isNotEmpty()) {
RollbackResult.Failure(errors.joinToString("\n"))
} else {
snapshots.remove(sessionId)
clearRunSnapshot(snapshot.runId)
RollbackResult.Success("Rollback completed")
}
}
suspend fun rollbackRun(runId: String): RollbackResult = withContext(Dispatchers.IO) {
val snapshot = snapshotsByRunId[runId]
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
val errors = applyChanges(snapshot.labelRef, snapshot.changes)
if (errors.isNotEmpty()) {
RollbackResult.Failure(errors.joinToString("\n"))
} else {
clearRunSnapshot(runId)
RollbackResult.Success("Rollback completed")
}
}
@ -269,8 +275,8 @@ class RollbackService(private val project: Project) {
errors.add("Missing parent directory for $path")
return@runCatching null
}
val parentVf = VirtualFileManager.getInstance()
.findFileByUrl("file://$parentPath") ?: return@runCatching null
val parentVf = LocalFileSystem.getInstance()
.refreshAndFindFileByPath(parentPath) ?: return@runCatching null
parentVf.createChildData(this, ioFile.name)
}.getOrNull()
@ -345,7 +351,6 @@ class RollbackService(private val project: Project) {
): ByteArray? {
val labelContent = runCatching { label.getByteContent(path) }
.getOrNull()?.bytes
// If label content is null or empty, use fallback (captured at track time)
return if (labelContent == null || labelContent.isEmpty()) {
fallback
} else {
@ -353,6 +358,26 @@ class RollbackService(private val project: Project) {
}
}
private fun applyChanges(
label: Label,
changes: Map<String, TrackedChange>
): List<String> {
val errors = mutableListOf<String>()
runInEdt(ModalityState.defaultModalityState()) {
runWriteAction {
changes.forEach { (path, change) ->
applyChangeWithLabel(label, path, change, errors)
}
}
}
return errors
}
private fun snapshotForSession(sessionId: String): SnapshotState? {
val runId = latestSnapshotRunIdBySession[sessionId] ?: return null
return snapshotsByRunId[runId]
}
companion object {
fun getInstance(project: Project): RollbackService {
return project.getService(RollbackService::class.java)
@ -360,6 +385,7 @@ class RollbackService(private val project: Project) {
}
private class RunTracker(
val runId: String,
val sessionId: String,
val startedAt: Instant,
val label: String,
@ -378,24 +404,23 @@ class RollbackService(private val project: Project) {
)
}
fun recordExplicitWrite(filePath: String, args: WriteTool.Args) {
val path = filePath.replace("\\", "/")
val existing = changes[path]
val file = File(path)
fun recordExplicitWrite(filePath: String) {
val existing = changes[filePath]
val file = File(filePath)
if (existing?.kind == ChangeKind.DELETED) {
changes[path] = existing.copy(kind = ChangeKind.MODIFIED)
changes[filePath] = existing.copy(kind = ChangeKind.MODIFIED)
return
}
if (existing == null) {
val originalContent = if (file.exists()) {
runCatching { file.readText() }.getOrNull()
runCatching { file.readText(Charsets.UTF_8) }.getOrNull()
} else null
changes[path] = TrackedChange(
changes[filePath] = TrackedChange(
kind = if (file.exists()) ChangeKind.MODIFIED else ChangeKind.ADDED,
originalPath = null,
originalContent = originalContent?.toByteArray()
originalContent = originalContent?.toByteArray(Charsets.UTF_8)
)
}
}
@ -428,6 +453,7 @@ class RollbackService(private val project: Project) {
}
private data class SnapshotState(
val runId: String,
val sessionId: String,
val label: String,
val labelRef: Label,
@ -438,6 +464,7 @@ class RollbackService(private val project: Project) {
fun toSnapshot(): RollbackSnapshot? {
if (changes.isEmpty()) return null
return RollbackSnapshot(
runId = runId,
sessionId = sessionId,
label = label,
startedAt = startedAt,
@ -455,6 +482,7 @@ class RollbackService(private val project: Project) {
}
data class RollbackSnapshot(
val runId: String,
val sessionId: String,
val label: String,
val startedAt: Instant,

View file

@ -84,12 +84,12 @@ class BashTool(
- If the commands depend on each other and must run sequentially, use a single Bash call with '&&' to chain them together (e.g., `git add . && git commit -m "message" && git push`). For instance, if one operation must complete before another starts (like mkdir before cp, Write before Bash for git operations, or git add before git commit), run these operations sequentially instead.
- Use ';' only when you need to run commands sequentially but don't care if earlier commands fail
- DO NOT use newlines to separate commands (newlines are ok in quoted strings)
- Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of `cd`. You may use `cd` if the User explicitly requests it.
- Commands run with the working directory already set to the project root. Prefer relative project paths (e.g., `src/...` or `./src/...`) instead of root-absolute paths like `/src` or hardcoded machine-specific paths like `/Users/.../project/...`. You may use `cd` only if the user explicitly requests it.
<good-example>
pytest /foo/bar/tests
find src -type f -name "*.kt"
</good-example>
<bad-example>
cd /foo/bar && pytest tests
find /src -type f -name "*.kt"
</bad-example>
# Committing changes with git
@ -261,11 +261,10 @@ class BashTool(
)
}
val isAskPattern = shouldAskForConfirmation(args.command)
val resolvedToolId = toolId ?: throw IllegalArgumentException("Tool ID is missing")
val confirmation = when {
isWhiteListed(args) && !isAskPattern -> ShellCommandConfirmation.Approved
isWhiteListed(args) -> ShellCommandConfirmation.Approved
else -> confirmationHandler.requestConfirmation(args)
}
return when (confirmation) {
@ -377,7 +376,6 @@ class BashTool(
}
}
} catch (_: IOException) {
// Ignore IO exception if the stream is closed
}
}
@ -391,7 +389,6 @@ class BashTool(
}
}
} catch (_: IOException) {
// Ignore IO exception if the stream is closed
}
}
@ -463,8 +460,7 @@ class BashTool(
}
when (event) {
WaitEvent.EXIT -> done = true
WaitEvent.ACTIVITY -> { /* reset idle timer */
}
WaitEvent.ACTIVITY -> Unit
null -> {
timedOut.set(true)
@ -541,28 +537,6 @@ class BashTool(
}
}
private fun shouldAskForConfirmation(command: String): Boolean {
val askPatterns = listOf(
"find * -delete*",
"find * -exec*",
"find * -fprint*",
"find * -fls*",
"find * -fprintf*",
"find * -ok*",
"sort --output=*",
"sort -o *",
"tree -o *"
)
return askPatterns.any { pattern ->
if (pattern.contains("*")) {
command.startsWith(pattern.removeSuffix("*"))
} else {
command == pattern
}
}
}
private fun shouldBlockByIgnore(command: String): Boolean {
val readers = setOf(
"cat",
@ -607,8 +581,6 @@ class BashTool(
lastWasReader = tokenIsReader
}
val settingsService = project.service<ProxyAISettingsService>()
// TODO(PROXYAI-IGNORE): Replace deny-style bash path checks with visibility filtering.
// Bash output and directory listings should hide ignored paths instead of returning policy-denied.
return paths.any { candidate ->
settingsService.isPathIgnored(toAbsolute(candidate))
}

View file

@ -8,11 +8,15 @@ import com.intellij.openapi.components.service
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.text.StringUtil
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.openapi.vfs.VfsUtil
import com.intellij.psi.PsiDocumentManager
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
import ee.carlrobert.codegpt.settings.hooks.HookManager
import ee.carlrobert.codegpt.tokens.truncateToolResult
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import java.io.File
@ -122,43 +126,59 @@ class WriteTool(
)
}
val fileUrl = file.toURI().toString()
val virtualFile =
VirtualFileManager.getInstance().findFileByUrl(fileUrl)
?: if (isNewFile) {
VirtualFileManager.getInstance().refreshAndFindFileByUrl(fileUrl)
} else {
null
}
val bytesWritten = args.content.toByteArray(StandardCharsets.UTF_8).size
val projectToUse = project
val normalizedPath = filePath.replace("\\", "/")
val fileUrl = "file://$normalizedPath"
// Never do filesystem IO or synchronous VFS refreshes on the EDT.
// If we can resolve a VirtualFile+Document, use the Document API for proper IDE integration.
val vfm = VirtualFileManager.getInstance()
val virtualFile = if (!isNewFile) {
vfm.findFileByUrl(fileUrl) ?: vfm.refreshAndFindFileByUrl(fileUrl)
} else {
vfm.findFileByUrl(fileUrl)
}
if (virtualFile != null) {
val document =
val document = withContext(Dispatchers.Default) {
runReadAction { FileDocumentManager.getInstance().getDocument(virtualFile) }
runInEdt {
if (document != null) {
}
if (document != null) {
runInEdt {
runWriteAction {
PsiDocumentManager.getInstance(projectToUse).commitDocument(document)
document.setText(StringUtil.convertLineSeparators(args.content))
PsiDocumentManager.getInstance(projectToUse).commitDocument(document)
FileDocumentManager.getInstance().saveDocument(document)
}
} else {
runWriteAction {
virtualFile.setBinaryContent(args.content.toByteArray(StandardCharsets.UTF_8))
}
}
} else {
// For non-document-backed files, fall back to plain IO and refresh the specific VirtualFile.
withContext(Dispatchers.IO) {
file.writeBytes(args.content.toByteArray(StandardCharsets.UTF_8))
}
VfsUtil.markDirtyAndRefresh(true, false, false, virtualFile)
}
} else if (isNewFile) {
runInEdt {
runWriteAction {
file.writeText(
StringUtil.convertLineSeparators(args.content),
StandardCharsets.UTF_8
)
VirtualFileManager.getInstance().syncRefresh()
}
} else {
// Fallback for files not present in VFS (e.g., newly created or external paths).
withContext(Dispatchers.IO) {
file.writeText(
StringUtil.convertLineSeparators(args.content),
StandardCharsets.UTF_8
)
}
val lfs = LocalFileSystem.getInstance()
val parentVf = file.parentFile?.let { parent ->
lfs.findFileByIoFile(parent) ?: lfs.refreshAndFindFileByIoFile(parent)
}
if (parentVf != null) {
// Reload the directory children so the newly created file shows up.
VfsUtil.markDirtyAndRefresh(true, false, true, parentVf)
} else {
// As a last resort, refresh just the file path.
lfs.refreshAndFindFileByIoFile(file)
}
}

View file

@ -8,6 +8,7 @@ import com.intellij.diff.requests.SimpleDiffRequest
import com.intellij.diff.util.DiffUserDataKeys
import com.intellij.diff.util.DiffUserDataKeysEx
import com.intellij.icons.AllIcons
import com.intellij.openapi.Disposable
import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.actionSystem.Presentation
@ -34,7 +35,9 @@ import ee.carlrobert.codegpt.agent.tools.EditTool
import ee.carlrobert.codegpt.agent.tools.WriteTool
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.applyStringReplacement
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.getFileContentWithFallback
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import java.awt.BorderLayout
import java.awt.FlowLayout
import java.awt.Insets
@ -54,8 +57,9 @@ import javax.swing.JPanel
*/
class AgentApprovalManager(
private val project: Project
) {
) : Disposable {
private val approvalPopups = ConcurrentHashMap<String, JBPopup>()
private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO)
companion object {
private val AGENT_DIFF_REQUEST_KEY: Key<String> = Key.create("agent.approval.diffRequest")
@ -96,7 +100,7 @@ class AgentApprovalManager(
args.filePath
}
ApplicationManager.getApplication().executeOnPooledThread {
backgroundScope.launch {
val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path)
val factory = DiffContentFactory.getInstance()
@ -340,4 +344,11 @@ class AgentApprovalManager(
}
}
}
override fun dispose() {
backgroundScope.dispose()
runInEdt {
approvalPopups.keys.toList().forEach(::closeDiffById)
}
}
}

View file

@ -13,6 +13,7 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.LocalFileSystem
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.agent.*
import ee.carlrobert.codegpt.agent.history.CheckpointRef
import ee.carlrobert.codegpt.agent.rollback.RollbackService
import ee.carlrobert.codegpt.agent.tools.*
import ee.carlrobert.codegpt.settings.service.ServiceType
@ -24,6 +25,7 @@ import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.*
import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatMessageResponseBody
import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatToolWindowScrollablePanel
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
import kotlinx.coroutines.*
import java.awt.Component
import java.util.*
@ -40,6 +42,8 @@ class AgentEventHandler(
private val userInputPanel: UserInputPanel,
private val onShowLoading: (String) -> Unit,
private val onHideLoading: () -> Unit,
private val onRunFinishedCallback: () -> Unit = {},
private val onRunCheckpointUpdatedCallback: (UUID, CheckpointRef?) -> Unit = { _, _ -> },
private val onQueuedMessagesResolved: (MessageWithContext) -> Unit = {}
) : AgentEvents, Disposable {
@ -67,6 +71,9 @@ class AgentEventHandler(
@Volatile
private var lastEditArgs: EditTool.Args? = null
@Volatile
private var currentRollbackRunId: String? = null
private val approvalQueue: ArrayDeque<ApprovalRequest> = ArrayDeque()
@Volatile
@ -84,7 +91,7 @@ class AgentEventHandler(
private var runViewHolder: RunViewHolder? = null
private val subagentViewHolders = ConcurrentHashMap<String, RunViewHolder>()
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
private val serviceScope = DisposableCoroutineScope(Dispatchers.Default)
data class ApprovalRequest(
var model: ToolApprovalRequest,
@ -105,6 +112,7 @@ class AgentEventHandler(
runViewHolder = null
subagentViewHolders.clear()
lastReportedPromptTokens = 0
currentRollbackRunId = null
}
private fun clearApprovalContainer() {
@ -162,6 +170,10 @@ class AgentEventHandler(
currentResponseBody = responseBody
}
fun setCurrentRollbackRunId(runId: String?) {
currentRollbackRunId = runId
}
override fun onAgentCompleted(agentId: String) {
runCatching {
project.service<AgentToolWindowContentManager>().getTabbedPane()
@ -226,8 +238,12 @@ class AgentEventHandler(
private fun handleDone() {
runInEdt {
project.service<RollbackService>().finishSession(sessionId)
currentRollbackRunId?.let { runId ->
project.service<RollbackService>().finishRun(runId)
} ?: project.service<RollbackService>().finishSession(sessionId)
currentRollbackRunId = null
currentResponseBody?.finishThinking()
onRunFinishedCallback()
onHideLoading()
userInputPanel.setStopEnabled(false)
scrollablePanel.update()
@ -483,6 +499,12 @@ class AgentEventHandler(
onQueuedMessagesResolved(pendingMessage)
}
override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {
runInEdt {
onRunCheckpointUpdatedCallback(runMessageId, ref)
}
}
override fun onTokenUsageAvailable(tokenUsage: Long) {
lastReportedPromptTokens = tokenUsage
@ -739,6 +761,8 @@ class AgentEventHandler(
}
override fun dispose() {
serviceScope.dispose()
agentApprovalManager.dispose()
mainToolCards.clear()
approvalQueue.clear()
subagentViewHolders.clear()
@ -752,16 +776,26 @@ class AgentEventHandler(
val documentText = vf?.let { file ->
runReadAction { FileDocumentManager.getInstance().getDocument(file)?.text }
}
documentText ?: java.io.File(normalizedPath).readText()
documentText ?: java.io.File(normalizedPath).readText(Charsets.UTF_8)
}.getOrNull() ?: ""
project.service<RollbackService>()
.trackEdit(sessionId, normalizedPath, originalContent)
val rollbackService = project.service<RollbackService>()
val runId = currentRollbackRunId
if (runId != null) {
rollbackService.trackEditForRun(runId, normalizedPath, originalContent)
} else {
rollbackService.trackEdit(sessionId, normalizedPath, originalContent)
}
}
private fun trackWriteOperation(args: WriteTool.Args) {
lastWriteArgs = args
val normalizedPath = args.filePath.replace("\\", "/")
project.service<RollbackService>()
.trackWrite(sessionId, normalizedPath, args)
val rollbackService = project.service<RollbackService>()
val runId = currentRollbackRunId
if (runId != null) {
rollbackService.trackWriteForRun(runId, normalizedPath)
} else {
rollbackService.trackWrite(sessionId, normalizedPath)
}
}
}

View file

@ -0,0 +1,9 @@
package ee.carlrobert.codegpt.toolwindow.agent
internal object AgentMessageText {
fun abbreviate(text: String, maxLength: Int): String {
val normalized = text.replace("\\s+".toRegex(), " ").trim()
if (normalized.length <= maxLength) return normalized
return normalized.take(maxLength - 1).trimEnd() + ""
}
}

View file

@ -0,0 +1,479 @@
package ee.carlrobert.codegpt.toolwindow.agent
import com.intellij.icons.AllIcons
import com.intellij.openapi.actionSystem.ActionManager
import com.intellij.openapi.actionSystem.ActionUpdateThread
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.actionSystem.DefaultActionGroup
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.project.DumbAwareAction
import com.intellij.ui.CheckboxTree
import com.intellij.ui.CheckedTreeNode
import com.intellij.ui.SimpleTextAttributes
import com.intellij.ui.components.JBLabel
import com.intellij.ui.components.JBScrollPane
import com.intellij.util.ui.JBUI
import java.awt.BorderLayout
import java.awt.Dimension
import java.awt.event.ActionEvent
import java.awt.event.KeyAdapter
import java.awt.event.KeyEvent
import java.awt.event.MouseAdapter
import java.awt.event.MouseEvent
import javax.swing.AbstractAction
import javax.swing.KeyStroke
import javax.swing.JPanel
import javax.swing.JTree
import javax.swing.ScrollPaneConstants
import javax.swing.SwingUtilities
import javax.swing.tree.DefaultMutableTreeNode
import javax.swing.tree.DefaultTreeModel
internal object AgentSessionTimelineDialogBuilder {
fun build(
points: List<RunTimelinePoint>,
onSelect: (RunTimelinePoint) -> Unit,
reloadPoints: (suspend () -> List<RunTimelinePoint>)?,
contextSelectionEnabled: Boolean,
onNonCopyMenuAction: (() -> Unit)?,
onSelectionStateChanged: (() -> Unit)?,
timelinePointDisplayLabel: (RunTimelinePoint) -> String,
canRollbackTimelinePoint: (RunTimelinePoint) -> Boolean,
rollbackToTimelinePoint: (RunTimelinePoint, String) -> Unit,
canCopyTimelinePointOutput: (RunTimelinePoint) -> Boolean,
copyTimelinePointOutput: (RunTimelinePoint, String) -> Unit,
launchBackground: (suspend () -> Unit) -> Unit
): TimelineDialogUi {
fun runNodeKey(runNumber: Int): String = "run:$runNumber"
fun checkpointNodeKey(point: RunTimelinePoint): String = "cp:${point.cacheKey}"
fun buildTreeRoot(
dataPoints: List<RunTimelinePoint>,
checkedNodeState: Map<String, Boolean>,
withCheckboxes: Boolean
): CheckedTreeNode {
val groups = buildTimelineRunGroups(dataPoints)
val root = CheckedTreeNode(TimelineTreeEntry("Session"))
groups.forEach { group ->
val toolCallSuffix = if (group.toolCallCount == 1) "tool call" else "tool calls"
val runEntry = TimelineTreeEntry(
title = "Run ${group.runNumber}",
subtitle = "${group.toolCallCount} $toolCallSuffix",
icon = AllIcons.Vcs.History,
runNumber = group.runNumber
)
val runNode: DefaultMutableTreeNode = if (withCheckboxes) {
CheckedTreeNode(runEntry).apply {
isChecked = checkedNodeState[runNodeKey(group.runNumber)] ?: true
}
} else {
DefaultMutableTreeNode(runEntry)
}
group.points.forEach { point ->
val pointEntry = TimelineTreeEntry(
title = point.title,
subtitle = point.subtitle.trim(),
icon = point.icon,
point = point
)
val pointNode: DefaultMutableTreeNode = if (withCheckboxes) {
val parentChecked = (runNode as? CheckedTreeNode)?.isChecked ?: true
CheckedTreeNode(pointEntry).apply {
isChecked = checkedNodeState[checkpointNodeKey(point)] ?: parentChecked
}
} else {
DefaultMutableTreeNode(pointEntry)
}
runNode.add(pointNode)
}
root.add(runNode)
}
return root
}
fun collectCheckedNonSystemMessageCounts(root: DefaultMutableTreeNode): Set<Int> {
val checkedNonSystemMessageCounts = mutableSetOf<Int>()
for (index in 0 until root.childCount) {
val runNode = root.getChildAt(index) as? DefaultMutableTreeNode ?: continue
for (childIndex in 0 until runNode.childCount) {
val childNode =
runNode.getChildAt(childIndex) as? DefaultMutableTreeNode ?: continue
val childEntry = childNode.userObject as? TimelineTreeEntry ?: continue
val point = childEntry.point ?: continue
val nonSystemMessageCount = point.nonSystemMessageCount ?: continue
if (childNode is CheckedTreeNode && childNode.isChecked) {
checkedNonSystemMessageCounts.add(nonSystemMessageCount)
}
}
}
return checkedNonSystemMessageCounts
}
fun captureCheckedRunState(root: DefaultMutableTreeNode): Map<String, Boolean> {
val state = mutableMapOf<String, Boolean>()
for (index in 0 until root.childCount) {
val runNode = root.getChildAt(index) as? DefaultMutableTreeNode ?: continue
val runEntry = runNode.userObject as? TimelineTreeEntry
val runNumber = runEntry?.runNumber
if (runNode is CheckedTreeNode && runNumber != null) {
state[runNodeKey(runNumber)] = runNode.isChecked
}
for (childIndex in 0 until runNode.childCount) {
val childNode =
runNode.getChildAt(childIndex) as? DefaultMutableTreeNode ?: continue
val childEntry = childNode.userObject as? TimelineTreeEntry ?: continue
val point = childEntry.point ?: continue
if (childNode is CheckedTreeNode) {
state[checkpointNodeKey(point)] = childNode.isChecked
}
}
}
return state
}
fun ensureCheckedStateDefaults(
dataPoints: List<RunTimelinePoint>,
checkedNodeState: MutableMap<String, Boolean>
) {
if (!contextSelectionEnabled) return
buildTimelineRunGroups(dataPoints).forEach { group ->
checkedNodeState.putIfAbsent(runNodeKey(group.runNumber), true)
}
dataPoints.forEach { point ->
checkedNodeState.putIfAbsent(checkpointNodeKey(point), true)
}
}
fun expandAllRows(tree: JTree) {
var row = 0
while (row < tree.rowCount) {
tree.expandRow(row)
row++
}
}
var latestPoints = points
var editMode = false
val groups = buildTimelineRunGroups(latestPoints)
val checkedNodeState = if (contextSelectionEnabled) {
buildMap {
groups.forEach { group -> put(runNodeKey(group.runNumber), true) }
latestPoints.forEach { point -> put(checkpointNodeKey(point), true) }
}.toMutableMap()
} else {
mutableMapOf()
}
ensureCheckedStateDefaults(latestPoints, checkedNodeState)
fun headerText(): String {
return when {
!contextSelectionEnabled ->
"Choose a point in this session timeline."
editMode ->
"Right-click row for rollback/new/copy. Check runs or checkpoints to keep in context."
else ->
"Right-click row for rollback/new/copy. Click Edit to modify current context."
}
}
val header = JBLabel(headerText()).apply {
foreground = JBUI.CurrentTheme.Label.disabledForeground()
font = JBUI.Fonts.smallFont()
border = JBUI.Borders.emptyBottom(5)
}
lateinit var tree: CheckboxTree
fun captureCurrentCheckedState() {
if (!contextSelectionEnabled || !editMode) return
val currentRoot = tree.model.root as? DefaultMutableTreeNode ?: return
checkedNodeState.clear()
checkedNodeState.putAll(captureCheckedRunState(currentRoot))
ensureCheckedStateDefaults(latestPoints, checkedNodeState)
}
fun rebuildTree(
dataPoints: List<RunTimelinePoint>,
captureCurrentSelectionState: Boolean
) {
if (captureCurrentSelectionState) {
captureCurrentCheckedState()
}
latestPoints = dataPoints
ensureCheckedStateDefaults(latestPoints, checkedNodeState)
tree.model = DefaultTreeModel(
buildTreeRoot(
dataPoints = latestPoints,
checkedNodeState = checkedNodeState,
withCheckboxes = contextSelectionEnabled && editMode
)
)
expandAllRows(tree)
tree.revalidate()
tree.repaint()
if (contextSelectionEnabled && editMode) {
onSelectionStateChanged?.invoke()
}
}
tree = object : CheckboxTree(
object : CheckboxTree.CheckboxTreeCellRenderer() {
override fun customizeRenderer(
tree: JTree,
value: Any?,
selected: Boolean,
expanded: Boolean,
leaf: Boolean,
row: Int,
hasFocus: Boolean
) {
val node = value as? DefaultMutableTreeNode ?: return
val entry = node.userObject as? TimelineTreeEntry ?: return
textRenderer.icon = entry.icon
if (entry.point == null) {
textRenderer.append(
entry.title,
SimpleTextAttributes.REGULAR_BOLD_ATTRIBUTES
)
toolTipText = null
if (entry.subtitle.isNotBlank()) {
textRenderer.append(
" ${entry.subtitle}",
SimpleTextAttributes.GRAYED_ATTRIBUTES
)
}
} else {
textRenderer.append(entry.title, SimpleTextAttributes.REGULAR_ATTRIBUTES)
if (entry.subtitle.isNotBlank()) {
textRenderer.append(
" ${entry.subtitle}",
SimpleTextAttributes.GRAYED_ATTRIBUTES
)
toolTipText = entry.subtitle
} else {
toolTipText = null
}
}
}
},
buildTreeRoot(
dataPoints = latestPoints,
checkedNodeState = checkedNodeState,
withCheckboxes = contextSelectionEnabled && editMode
)
) {}.apply {
isRootVisible = false
showsRootHandles = true
rowHeight = 0
border = JBUI.Borders.empty(2, 0)
expandAllRows(this)
fun selectCurrentTimelinePoint() {
val selectedNode = lastSelectedPathComponent as? DefaultMutableTreeNode ?: return
val selectedEntry = selectedNode.userObject as? TimelineTreeEntry ?: return
val point = selectedEntry.point ?: return
onSelect(point)
}
addMouseListener(object : MouseAdapter() {
override fun mouseClicked(e: MouseEvent) {
if (e.button == MouseEvent.BUTTON1 && e.clickCount >= 2) {
selectCurrentTimelinePoint()
}
}
override fun mousePressed(e: MouseEvent) {
maybeShowTimelineContextMenu(e)
}
override fun mouseReleased(e: MouseEvent) {
maybeShowTimelineContextMenu(e)
if (contextSelectionEnabled && editMode && !e.isPopupTrigger && e.button == MouseEvent.BUTTON1) {
SwingUtilities.invokeLater {
onSelectionStateChanged?.invoke()
}
}
}
private fun maybeShowTimelineContextMenu(e: MouseEvent) {
if (!e.isPopupTrigger && e.button != MouseEvent.BUTTON3) return
val path = getPathForLocation(e.x, e.y) ?: return
selectionPath = path
val selectedNode = path.lastPathComponent as? DefaultMutableTreeNode ?: return
val selectedEntry = selectedNode.userObject as? TimelineTreeEntry ?: return
val point = selectedEntry.point
val selectedLabel =
point?.let { timelinePointDisplayLabel(it) } ?: selectedEntry.title
val group = DefaultActionGroup(
object : DumbAwareAction(
"Rollback",
"Rollback to checkpoint",
AllIcons.Actions.Undo
) {
override fun actionPerformed(event: AnActionEvent) {
val timelinePoint = point ?: return
onNonCopyMenuAction?.invoke()
rollbackToTimelinePoint(timelinePoint, selectedLabel)
}
override fun update(event: AnActionEvent) {
event.presentation.isEnabled =
point != null && canRollbackTimelinePoint(point)
}
override fun getActionUpdateThread(): ActionUpdateThread =
ActionUpdateThread.EDT
},
object : DumbAwareAction(
"Continue From New Session",
"Start new session from given checkpoint",
AllIcons.General.Add
) {
override fun actionPerformed(event: AnActionEvent) {
val timelinePoint = point ?: return
onNonCopyMenuAction?.invoke()
onSelect(timelinePoint)
}
override fun update(event: AnActionEvent) {
event.presentation.isEnabled = point != null
}
override fun getActionUpdateThread(): ActionUpdateThread =
ActionUpdateThread.EDT
},
object : DumbAwareAction(
"Copy Output",
"Copies the Tool or Assistant's output",
AllIcons.General.Copy
) {
override fun actionPerformed(event: AnActionEvent) {
val timelinePoint = point ?: return
copyTimelinePointOutput(timelinePoint, selectedLabel)
}
override fun update(event: AnActionEvent) {
event.presentation.isEnabled =
point != null && canCopyTimelinePointOutput(point)
}
override fun getActionUpdateThread(): ActionUpdateThread =
ActionUpdateThread.EDT
}
)
ActionManager.getInstance()
.createActionPopupMenu("AgentTimeline.CheckpointMenu", group)
.component
.show(e.component, e.x, e.y)
}
})
addKeyListener(object : KeyAdapter() {
override fun keyReleased(e: KeyEvent) {
if (contextSelectionEnabled && editMode && e.keyCode == KeyEvent.VK_SPACE) {
onSelectionStateChanged?.invoke()
}
}
})
inputMap.put(KeyStroke.getKeyStroke("ENTER"), "timeline.select")
actionMap.put("timeline.select", object : AbstractAction() {
override fun actionPerformed(e: ActionEvent?) {
selectCurrentTimelinePoint()
}
})
}
val scrollPane = JBScrollPane(tree).apply {
border = JBUI.Borders.empty()
horizontalScrollBarPolicy = ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER
verticalScrollBar.unitIncrement = 16
}
SwingUtilities.invokeLater {
if (tree.rowCount > 0) {
tree.scrollRowToVisible(tree.rowCount - 1)
scrollPane.verticalScrollBar.value = scrollPane.verticalScrollBar.maximum
}
}
val popupWidth = JBUI.scale(620)
val minPopupWidth = JBUI.scale(560)
val maxPopupWidth = JBUI.scale(760)
val popupHeight =
JBUI.scale(computeTimelineDialogHeight(groups.sumOf { it.points.size + 1 }))
val container = JPanel(BorderLayout()).apply {
isOpaque = false
border = JBUI.Borders.empty()
preferredSize = Dimension(popupWidth, popupHeight)
minimumSize = Dimension(minPopupWidth, popupHeight)
maximumSize = Dimension(maxPopupWidth, popupHeight)
add(header, BorderLayout.NORTH)
add(scrollPane, BorderLayout.CENTER)
}
return TimelineDialogUi(
component = container,
getSelectedPoint = {
val node = tree.lastSelectedPathComponent as? DefaultMutableTreeNode
val entry = node?.userObject as? TimelineTreeEntry
entry?.point
},
getCheckedNonSystemMessageCounts = {
val root = tree.model.root as? DefaultMutableTreeNode
if (root == null) emptySet() else collectCheckedNonSystemMessageCounts(root)
},
refresh = {
reloadPoints?.let { loader ->
launchBackground {
val refreshedPoints = loader()
runInEdt {
rebuildTree(refreshedPoints, captureCurrentSelectionState = true)
}
}
}
},
setEditMode = { value ->
if (contextSelectionEnabled && editMode != value) {
if (editMode) captureCurrentCheckedState()
editMode = value
header.text = headerText()
rebuildTree(latestPoints, captureCurrentSelectionState = false)
}
},
isEditMode = { editMode }
)
}
private fun buildTimelineRunGroups(points: List<RunTimelinePoint>): List<TimelineRunGroup> {
if (points.isEmpty()) return emptyList()
val grouped = linkedMapOf<Int, MutableList<RunTimelinePoint>>()
points.forEach { point ->
val key = if (point.runIndex <= 0) 1 else point.runIndex
grouped.getOrPut(key) { mutableListOf() }.add(point)
}
return grouped.values.mapIndexed { index, runPoints ->
TimelineRunGroup(
runNumber = index + 1,
points = runPoints,
toolCallCount = runPoints.count { it.kind == TimelinePointKind.TOOL_CALL }
)
}
}
private fun computeTimelineDialogHeight(estimatedRowCount: Int): Int {
val visibleRows = estimatedRowCount.coerceIn(4, 16)
val rowsHeight = visibleRows * 26
return (rowsHeight + 44).coerceIn(160, 460)
}
}

View file

@ -0,0 +1,328 @@
package ee.carlrobert.codegpt.toolwindow.agent
import ai.koog.agents.snapshot.feature.AgentCheckpointData
import ai.koog.prompt.message.Message as PromptMessage
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.LocalFileSystem
import ee.carlrobert.codegpt.agent.ToolName
import ee.carlrobert.codegpt.agent.ToolSpecs
import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService
import ee.carlrobert.codegpt.agent.tools.EditTool
import ee.carlrobert.codegpt.agent.tools.ReadTool
import ee.carlrobert.codegpt.agent.tools.WriteTool
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.booleanOrNull
import kotlinx.serialization.json.jsonObject
import java.io.File
import java.nio.file.Paths
import java.util.ArrayDeque
internal class AgentSessionTimelineHistoricalRollbackSupport(
private val project: Project,
private val historyService: AgentCheckpointHistoryService,
private val replayJson: Json
) {
suspend fun collectOperations(point: RunTimelinePoint): List<HistoricalRollbackOperation> {
val checkpointRef = point.checkpointRef ?: return emptyList()
val latestCheckpoint: AgentCheckpointData =
historyService.listCheckpoints(checkpointRef.agentId).firstOrNull()
?: historyService.loadLatestResumeCheckpoint(checkpointRef.agentId)
?: historyService.loadCheckpoint(checkpointRef)
?: return emptyList()
val cutoff = point.nonSystemMessageCount ?: return emptyList()
val history = latestCheckpoint.messageHistory.filterNot { it is PromptMessage.System }
if (history.isEmpty() || cutoff >= history.size) return emptyList()
data class PendingCall(val index: Int, val call: PromptMessage.Tool.Call)
val pendingById = mutableMapOf<String, PendingCall>()
val pendingWithoutId = ArrayDeque<PendingCall>()
val latestKnownContentByFile = mutableMapOf<String, String>()
val operations = mutableListOf<HistoricalRollbackOperation>()
history.forEachIndexed { index, message ->
when (message) {
is PromptMessage.Tool.Call -> {
val callId = message.id?.takeIf { it.isNotBlank() }
if (callId != null) {
pendingById[callId] = PendingCall(index, message)
} else {
pendingWithoutId.addLast(PendingCall(index, message))
}
}
is PromptMessage.Tool.Result -> {
val pendingCall = message.id
?.takeIf { it.isNotBlank() }
?.let { pendingById.remove(it) }
?: pendingWithoutId.pollFirst()
?: return@forEachIndexed
val call = pendingCall.call
val tool = HistoricalRollbackCompatibility.resolveSupportedTool(call.tool)
?: return@forEachIndexed
if (!HistoricalRollbackCompatibility.isSuccessfulResult(tool, message.content, replayJson)) {
return@forEachIndexed
}
when (tool) {
ToolName.READ -> {
val args = decodeReadArgs(call.tool, call.content) ?: return@forEachIndexed
val filePath = normalizeToolFilePath(args.filePath)
?: return@forEachIndexed
val content =
decodeReadToolResultContent(message.content) ?: return@forEachIndexed
latestKnownContentByFile[filePath] = content
return@forEachIndexed
}
ToolName.EDIT -> {
val args = decodeEditArgs(call.tool, call.content) ?: return@forEachIndexed
val filePath = normalizeToolFilePath(args.filePath)
?: return@forEachIndexed
val oldString = args.oldString
val newString = args.newString
val replaceAll = args.replaceAll
if (oldString.isEmpty() || newString.isEmpty() || oldString == newString) return@forEachIndexed
if (pendingCall.index >= cutoff) {
operations.add(
HistoricalRollbackOperation(
filePath = filePath,
searchText = newString,
replacementText = oldString,
replaceAll = replaceAll,
sourceTool = HistoricalRollbackSourceTool.EDIT
)
)
}
latestKnownContentByFile[filePath]?.let { known ->
if (known.contains(oldString)) {
latestKnownContentByFile[filePath] =
if (replaceAll) known.replace(oldString, newString)
else known.replaceFirst(oldString, newString)
}
}
return@forEachIndexed
}
ToolName.WRITE -> {
val args = decodeWriteArgs(call.tool, call.content) ?: return@forEachIndexed
val filePath = normalizeToolFilePath(args.filePath)
?: return@forEachIndexed
val newContent = args.content
val previousContent = latestKnownContentByFile[filePath]
if (pendingCall.index >= cutoff && previousContent != null && previousContent != newContent) {
operations.add(
HistoricalRollbackOperation(
filePath = filePath,
searchText = newContent,
replacementText = previousContent,
replaceAll = false,
sourceTool = HistoricalRollbackSourceTool.WRITE
)
)
}
latestKnownContentByFile[filePath] = newContent
}
else -> Unit
}
}
else -> Unit
}
}
return operations
}
fun buildConfirmationText(
selectedLabel: String,
operations: List<HistoricalRollbackOperation>
): String {
val ordered = operations.asReversed()
val previewLimit = 12
val listed = ordered.take(previewLimit).joinToString(separator = "\n") { operation ->
"${operation.sourceTool.symbol} ${toProjectRelativePath(operation.filePath)}"
}
val remaining = ordered.size - previewLimit
val suffix = if (remaining > 0) "\n...and $remaining more operation(s)." else ""
return """
This rollback will return the session to:
$selectedLabel
It will also replay ${operations.size} file operation(s) in reverse order:
$listed$suffix
""".trimIndent()
}
fun applyOperations(operations: List<HistoricalRollbackOperation>): List<String> {
val errors = mutableListOf<String>()
operations.asReversed().forEach { operation ->
val filePath = operation.filePath
val currentText = readFileText(filePath)
if (currentText == null) {
errors.add("File not found: ${toProjectRelativePath(filePath)}")
return@forEach
}
if (!currentText.contains(operation.searchText)) {
errors.add(
"Expected content not found in ${toProjectRelativePath(filePath)} for ${operation.sourceTool.displayName} rollback."
)
return@forEach
}
val updatedText = if (operation.replaceAll) {
currentText.replace(operation.searchText, operation.replacementText)
} else {
currentText.replaceFirst(operation.searchText, operation.replacementText)
}
if (updatedText == currentText) {
errors.add("No changes applied to ${toProjectRelativePath(filePath)}")
return@forEach
}
val writeOk = writeFileText(filePath, updatedText)
if (!writeOk) {
errors.add("Failed to write ${toProjectRelativePath(filePath)}")
}
}
return errors
}
private fun parseToolArgs(rawArgs: String): Map<String, JsonElement>? {
return runCatching { replayJson.parseToJsonElement(rawArgs).jsonObject }.getOrNull()
}
private fun decodeReadArgs(rawToolName: String, rawArgs: String): ReadTool.Args? {
val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? ReadTool.Args
if (typed != null) return typed
val args = parseToolArgs(rawArgs) ?: return null
val filePath = stringValue(args["file_path"])
?: stringValue(args["path"])
?: stringValue(args["pathInProject"])
?: return null
return ReadTool.Args(filePath = filePath)
}
private fun decodeEditArgs(rawToolName: String, rawArgs: String): EditTool.Args? {
val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? EditTool.Args
if (typed != null) return typed
val args = parseToolArgs(rawArgs) ?: return null
val filePath = stringValue(args["file_path"]) ?: return null
val oldString = stringValue(args["old_string"]) ?: return null
val newString = stringValue(args["new_string"]) ?: return null
val shortDescription = stringValue(args["short_description"]) ?: "Recovered historical edit"
val replaceAll = booleanValue(args["replace_all"]) ?: false
return EditTool.Args(
filePath = filePath,
oldString = oldString,
newString = newString,
shortDescription = shortDescription,
replaceAll = replaceAll
)
}
private fun decodeWriteArgs(rawToolName: String, rawArgs: String): WriteTool.Args? {
val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? WriteTool.Args
if (typed != null) return typed
val args = parseToolArgs(rawArgs) ?: return null
val filePath = stringValue(args["file_path"]) ?: return null
val content = stringValue(args["content"]) ?: return null
return WriteTool.Args(filePath = filePath, content = content)
}
private fun booleanValue(element: JsonElement?): Boolean? {
val primitive = element as? JsonPrimitive ?: return null
return if (primitive.isString) primitive.content.toBooleanStrictOrNull() else primitive.booleanOrNull
}
private fun stringValue(element: JsonElement?): String? {
if (element == null) return null
val primitive = element as? JsonPrimitive
return if (primitive != null && primitive.isString) primitive.content else element.toString()
}
private fun normalizeToolFilePath(rawPath: String?): String? {
val trimmed = rawPath?.trim()?.takeIf { it.isNotEmpty() } ?: return null
val normalized = trimmed.replace("\\", "/")
val file = File(normalized)
if (file.isAbsolute) {
return file.toPath().normalize().toString().replace("\\", "/")
}
val basePath = project.basePath ?: return file.absolutePath.replace("\\", "/")
return Paths.get(basePath).resolve(normalized).normalize().toString().replace("\\", "/")
}
private fun decodeReadToolResultContent(content: String): String? {
if (content.isBlank()) return ""
val numberedLines = content.lineSequence().mapNotNull { line ->
val tabIndex = line.indexOf('\t')
if (tabIndex <= 0) return@mapNotNull null
val prefix = line.substring(0, tabIndex)
if (!prefix.all { it.isDigit() }) return@mapNotNull null
line.substring(tabIndex + 1)
}.toList()
if (numberedLines.isNotEmpty()) return numberedLines.joinToString(separator = "\n")
if (content.startsWith("Error reading file", ignoreCase = true)) return null
return content
}
private fun readFileText(path: String): String? {
val virtualFile =
LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return null
val documentText = runReadAction {
FileDocumentManager.getInstance().getDocument(virtualFile)?.text
}
if (documentText != null) return documentText
return runCatching { String(virtualFile.contentsToByteArray(), Charsets.UTF_8) }.getOrNull()
}
private fun writeFileText(path: String, content: String): Boolean {
val virtualFile =
LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return false
val document = runReadAction { FileDocumentManager.getInstance().getDocument(virtualFile) }
return runCatching {
runInEdt {
runWriteAction {
if (document != null) {
document.setText(content)
FileDocumentManager.getInstance().saveDocument(document)
} else {
virtualFile.setBinaryContent(content.toByteArray(Charsets.UTF_8))
}
}
}
}.isSuccess
}
private fun toProjectRelativePath(path: String): String {
val basePath = project.basePath ?: return path
return runCatching {
val absolute = Paths.get(path).normalize()
val base = Paths.get(basePath).normalize()
if (absolute.startsWith(base)) {
base.relativize(absolute).toString().replace("\\", "/")
} else {
path
}
}.getOrDefault(path)
}
}

View file

@ -0,0 +1,70 @@
package ee.carlrobert.codegpt.toolwindow.agent
import ee.carlrobert.codegpt.agent.history.CheckpointRef
import javax.swing.Icon
import javax.swing.JComponent
import ai.koog.prompt.message.Message as PromptMessage
internal data class RunTimelinePoint(
val cacheKey: String,
val checkpointRef: CheckpointRef?,
val title: String,
val subtitle: String,
val runLabel: String? = null,
val icon: Icon? = null,
val nonSystemMessageCount: Int? = null,
val outputText: String? = null,
val toolCallId: String? = null,
val kind: TimelinePointKind = TimelinePointKind.ASSISTANT,
val runIndex: Int = 0
)
internal enum class TimelinePointKind {
USER,
ASSISTANT,
REASONING,
TOOL_CALL
}
internal data class TimelineRunGroup(
val runNumber: Int,
val points: List<RunTimelinePoint>,
val toolCallCount: Int
)
internal data class TimelineTreeEntry(
val title: String,
val subtitle: String = "",
val icon: Icon? = null,
val point: RunTimelinePoint? = null,
val runNumber: Int? = null
)
internal data class HistoricalRollbackOperation(
val filePath: String,
val searchText: String,
val replacementText: String,
val replaceAll: Boolean,
val sourceTool: HistoricalRollbackSourceTool
)
internal enum class HistoricalRollbackSourceTool(
val displayName: String,
val symbol: String
) {
EDIT(displayName = "Edit", symbol = "E"),
WRITE(displayName = "Write", symbol = "W")
}
internal data class SessionContextSnapshot(
val baseHistory: List<PromptMessage>
)
internal data class TimelineDialogUi(
val component: JComponent,
val getSelectedPoint: () -> RunTimelinePoint?,
val getCheckedNonSystemMessageCounts: () -> Set<Int>,
val refresh: () -> Unit,
val setEditMode: (Boolean) -> Unit,
val isEditMode: () -> Boolean
)

View file

@ -2,21 +2,23 @@ package ee.carlrobert.codegpt.toolwindow.agent
import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.EDT
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Disposer
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.ui.AnimatedIcon
import com.intellij.ui.JBColor
import com.intellij.ui.components.JBLabel
import com.intellij.util.ui.JBUI
import com.intellij.util.ui.components.BorderLayoutPanel
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.agent.AgentService
import ee.carlrobert.codegpt.agent.AgentToolOutputNotifier
import ee.carlrobert.codegpt.agent.MessageWithContext
import ee.carlrobert.codegpt.agent.ToolSpecs
import ee.carlrobert.codegpt.agent.ToolRunContext
import ee.carlrobert.codegpt.agent.*
import ee.carlrobert.codegpt.agent.ProxyAIAgent.loadProjectInstructions
import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService
import ee.carlrobert.codegpt.agent.history.AgentCheckpointTurnSequencer
import ee.carlrobert.codegpt.agent.history.CheckpointRef
import ee.carlrobert.codegpt.agent.rollback.RollbackService
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.conversations.message.Message
@ -42,9 +44,12 @@ import ee.carlrobert.codegpt.ui.queue.QueuedMessagePanel
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
import ee.carlrobert.codegpt.util.EditorUtil
import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks
import ee.carlrobert.codegpt.util.coroutines.CoroutineDispatchers
import kotlinx.coroutines.launch
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
import kotlinx.coroutines.*
import kotlinx.serialization.json.Json
import java.util.*
import javax.swing.Box
import javax.swing.BoxLayout
import javax.swing.JComponent
@ -54,6 +59,10 @@ class AgentToolWindowTabPanel(
private val project: Project,
private val agentSession: AgentSession
) : BorderLayoutPanel(), Disposable {
companion object {
private const val RECOVERED_CONVERSATION_RENDER_BATCH_SIZE = 6
}
private val replayJson = Json {
ignoreUnknownKeys = true
isLenient = true
@ -63,6 +72,7 @@ class AgentToolWindowTabPanel(
private val scrollablePanel = ChatToolWindowScrollablePanel()
private val tagManager = TagManager()
private val dispatchers = CoroutineDispatchers()
private val backgroundScope = DisposableCoroutineScope(dispatchers.io())
private val sessionId = agentSession.sessionId
private val conversation = agentSession.conversation
private val psiRepository = PsiStructureRepository(
@ -107,18 +117,43 @@ class AgentToolWindowTabPanel(
this,
FeatureType.AGENT,
tagManager,
onSubmit = { text -> handleSubmit(text) },
onStop = { handleCancel() },
onSubmit = ::handleSubmit,
onStop = ::handleCancel,
withRemovableSelectedEditorTag = true,
agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId),
sessionIdProvider = { sessionId },
conversationIdProvider = { conversation.id }
conversationIdProvider = { conversation.id },
onStartSessionTimeline = ::showSessionStartTimelineDialog
)
private var rollbackPanel: RollbackPanel
private val todoListPanel = TodoListPanel()
private val projectMessageBusConnection = project.messageBus.connect()
private val appMessageBusConnection = ApplicationManager.getApplication().messageBus.connect()
private val rollbackService = RollbackService.getInstance(project)
private val historyService = project.service<AgentCheckpointHistoryService>()
private data class RunCardState(
val runMessageId: UUID,
val rollbackRunId: String,
var responsePanel: ResponseMessagePanel,
var sourceMessage: Message? = null,
var completed: Boolean = false
)
// Insertion order matters: timeline run numbering depends on the message order.
private val runCardsByMessageId = linkedMapOf<UUID, RunCardState>()
private var activeRunMessageId: UUID? = null
private var recoveredConversationJob: Job? = null
private var activeLandingPanel: AgentToolWindowLandingPanel? = null
private val timelineController = AgentSessionTimelineController(
project = project,
agentSession = agentSession,
conversation = conversation,
runStateForRunIndex = ::runStateForRunIndex,
applySeededSessionState = ::applySeededSessionState,
onAfterRollbackRefresh = ::refreshViewAfterRollback
)
private val eventHandler = AgentEventHandler(
project = project,
@ -140,6 +175,12 @@ class AgentToolWindowTabPanel(
repaint()
rollbackPanel.refreshOperations()
},
onRunFinishedCallback = {
markActiveRunCompleted()
},
onRunCheckpointUpdatedCallback = { runMessageId, ref ->
updateRunCheckpoint(runMessageId, ref)
},
onQueuedMessagesResolved = { message ->
runInEdt {
clearQueuedMessagesAndCreateNewResponse(
@ -163,12 +204,15 @@ class AgentToolWindowTabPanel(
}
userInputPanel.setStopEnabled(false)
Disposer.register(this, rollbackPanel)
Disposer.register(this, eventHandler)
Disposer.register(this, timelineController)
Disposer.register(this, backgroundScope)
}
private fun setupMessageBusSubscriptions() {
project.service<AgentService>().queuedMessageProcessed.let { flow ->
kotlinx.coroutines.CoroutineScope(dispatchers.io()).launch {
backgroundScope.launch {
flow.collect { processedMessage ->
ApplicationManager.getApplication().invokeLater {
removeQueuedMessage(processedMessage)
@ -230,6 +274,7 @@ class AgentToolWindowTabPanel(
private fun handleSubmit(text: String) {
if (text.isBlank()) return
disposeLandingPanelIfPresent()
scrollablePanel.clearLandingViewIfVisible()
agentSession.serviceType =
ModelSelectionService.getInstance().getServiceForFeature(FeatureType.AGENT)
@ -255,7 +300,7 @@ class AgentToolWindowTabPanel(
.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.RUNNING)
}
rollbackService.startSession(sessionId)
val rollbackRunId = rollbackService.startSession(sessionId)
rollbackPanel.refreshOperations()
val message = MessageWithContext(text, userInputPanel.getSelectedTags())
@ -282,8 +327,16 @@ class AgentToolWindowTabPanel(
messagePanel.add(responsePanel)
scrollablePanel.update()
registerRunCard(
runMessageId = message.id,
rollbackRunId = rollbackRunId,
responsePanel = responsePanel,
prompt = text
)
eventHandler.resetForNewSubmission()
eventHandler.setCurrentResponseBody(responseBody)
eventHandler.setCurrentRollbackRunId(rollbackRunId)
loadingLabel.text = CodeGPTBundle.get("toolwindow.chat.loading")
loadingLabel.isVisible = true
@ -300,7 +353,15 @@ class AgentToolWindowTabPanel(
agentService.cancelCurrentRun(sessionId)
agentService.clearPendingMessages(sessionId)
rollbackService.finishSession(sessionId)
val activeRun = activeRunMessageId?.let { runCardsByMessageId[it] }
if (activeRun != null) {
rollbackService.finishRun(activeRun.rollbackRunId)
runCardsByMessageId.remove(activeRun.runMessageId)
activeRunMessageId = null
} else {
rollbackService.finishSession(sessionId)
}
eventHandler.setCurrentRollbackRunId(null)
rollbackPanel.refreshOperations()
approvalContainer.removeAll()
@ -316,37 +377,182 @@ class AgentToolWindowTabPanel(
}
private fun displayLandingView() {
scrollablePanel.displayLandingView(createLandingView())
disposeLandingPanelIfPresent()
val landingPanel = createLandingView()
activeLandingPanel = landingPanel
scrollablePanel.displayLandingView(landingPanel)
}
private fun displayRecoveredConversation() {
disposeLandingPanelIfPresent()
scrollablePanel.clearAll()
conversation.messages.forEach { message ->
val prompt = message.prompt.orEmpty()
val wrapper = scrollablePanel.addMessage(message.id)
val userPanel = UserMessagePanel(project, message, this)
userPanel.addCopyAction { CopyAction.copyToClipboard(prompt) }
wrapper.add(userPanel)
recoveredConversationJob?.cancel()
recoveredConversationJob = backgroundScope.launch {
val recoveredTurns =
runCatching { loadRecoveredTurnsFromResumeCheckpoint() }.getOrNull()
renderRecoveredConversation(recoveredTurns)
}
}
val responseBody = ChatMessageResponseBody(
project,
false,
false,
false,
false,
false,
this
)
addRecoveredToolCards(responseBody, message)
responseBody.withResponse(message.response.orEmpty())
val responsePanel = ResponseMessagePanel().apply {
setResponseContent(responseBody)
private suspend fun renderRecoveredConversation(
recoveredTurns: List<AgentCheckpointTurnSequencer.Turn>?
) {
withContext(Dispatchers.EDT) {
if (!isActive || project.isDisposed) return@withContext
val canRenderInOrder = recoveredTurns != null &&
recoveredTurns.size == conversation.messages.size &&
recoveredTurns.indices.all { index ->
recoveredTurns[index].prompt == conversation.messages[index].prompt.orEmpty()
.trim()
}
val messages = conversation.messages.toList()
var nextIndex = 0
while (nextIndex < messages.size) {
if (!isActive || project.isDisposed) return@withContext
val batchEnd = minOf(
nextIndex + RECOVERED_CONVERSATION_RENDER_BATCH_SIZE,
messages.size
)
while (nextIndex < batchEnd) {
if (!isActive || project.isDisposed) return@withContext
val index = nextIndex
val message = messages[index]
val prompt = message.prompt.orEmpty()
val wrapper = scrollablePanel.addMessage(message.id)
val userPanel = UserMessagePanel(project, message, this@AgentToolWindowTabPanel)
userPanel.addCopyAction { CopyAction.copyToClipboard(prompt) }
wrapper.add(userPanel)
val responseBody = ChatMessageResponseBody(
project,
false,
false,
false,
false,
false,
this@AgentToolWindowTabPanel
)
val renderedInOrder = if (canRenderInOrder) {
renderRecoveredTurnInOrder(responseBody, recoveredTurns[index].events)
} else {
false
}
if (!renderedInOrder) {
addRecoveredToolCards(responseBody, message)
responseBody.withResponse(message.response.orEmpty().stripThinkingBlocks())
}
val responsePanel = ResponseMessagePanel().apply {
setResponseContent(responseBody)
}
wrapper.add(responsePanel)
registerRecoveredRunCard(message, responsePanel)
nextIndex += 1
}
scrollablePanel.update()
if (nextIndex < messages.size) {
yield()
}
}
wrapper.add(responsePanel)
scrollablePanel.scrollToBottom()
}
}
private suspend fun loadRecoveredTurnsFromResumeCheckpoint(): List<AgentCheckpointTurnSequencer.Turn>? {
val resumeRef = agentSession.resumeCheckpointRef ?: return null
val checkpoint = historyService.loadCheckpoint(resumeRef)
?: historyService.loadResumeCheckpoint(resumeRef)
?: return null
val projectInstructions = loadProjectInstructions(project.basePath)
return AgentCheckpointTurnSequencer.toVisibleTurns(
history = checkpoint.messageHistory,
projectInstructions = projectInstructions,
preserveSyntheticContinuation = true
)
}
private fun renderRecoveredTurnInOrder(
responseBody: ChatMessageResponseBody,
events: List<AgentCheckpointTurnSequencer.TurnEvent>
): Boolean {
if (events.isEmpty()) {
return false
}
scrollablePanel.update()
scrollablePanel.scrollToBottom()
val pendingById = mutableMapOf<String, ToolCallCard>()
val pendingWithoutId = ArrayDeque<ToolCallCard>()
var rendered = false
events.forEach { event ->
when (event) {
is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> {
val text = event.content.stripThinkingBlocks()
if (text.isNotBlank()) {
responseBody.withResponse(text)
rendered = true
}
}
is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> {
val text = event.content.stripThinkingBlocks()
if (text.isNotBlank()) {
responseBody.withResponse(text)
rendered = true
}
}
is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> {
val toolName = event.tool.ifBlank { "Tool" }
if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) {
return@forEach
}
val rawArgs = event.content
val args = parseRecoveredToolArgs(toolName, rawArgs)
val card = createRecoveredToolCard(toolName, args, rawArgs)
responseBody.addToolStatusPanel(card)
val callId = event.id?.takeIf { it.isNotBlank() }
if (callId != null) {
pendingById[callId] = card
} else {
pendingWithoutId.addLast(card)
}
rendered = true
}
is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> {
val toolName = event.tool.ifBlank { "Tool" }
if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) {
return@forEach
}
val rawResult = event.content
val parsedResult = parseRecoveredToolResult(toolName, rawResult)
val success = inferRecoveredToolSuccess(parsedResult, rawResult)
val card = event.id
?.takeIf { it.isNotBlank() }
?.let { pendingById.remove(it) }
?: pendingWithoutId.pollFirst()
?: run {
val orphan = createRecoveredToolCard(toolName, "", "")
responseBody.addToolStatusPanel(orphan)
orphan
}
card.complete(success, parsedResult ?: rawResult)
rendered = true
}
}
}
return rendered
}
private fun addRecoveredToolCards(responseBody: ChatMessageResponseBody, message: Message) {
@ -355,7 +561,7 @@ class AgentToolWindowTabPanel(
toolCalls.forEach { toolCall ->
val toolName = toolCall.function.name ?: return@forEach
if (toolName == "TodoWrite") {
if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) {
return@forEach
}
@ -419,6 +625,11 @@ class AgentToolWindowTabPanel(
return AgentToolWindowLandingPanel(project)
}
private fun disposeLandingPanelIfPresent() {
activeLandingPanel?.let { Disposer.dispose(it) }
activeLandingPanel = null
}
fun getSessionId(): String = sessionId
fun getAgentSession(): AgentSession = agentSession
@ -483,8 +694,19 @@ class AgentToolWindowTabPanel(
responsePanel.setResponseContent(responseBody)
messagePanel.add(responsePanel)
val rollbackRunId = activeRunMessageId
?.let { runCardsByMessageId[it] }
?.rollbackRunId
eventHandler.resetForNewSubmission()
eventHandler.setCurrentResponseBody(responseBody)
eventHandler.setCurrentRollbackRunId(rollbackRunId)
activeRunMessageId?.let { runMessageId ->
runCardsByMessageId[runMessageId]?.let { state ->
state.responsePanel = responsePanel
}
}
scrollablePanel.update()
}
@ -518,8 +740,118 @@ class AgentToolWindowTabPanel(
}
}
private fun registerRunCard(
runMessageId: UUID,
rollbackRunId: String,
responsePanel: ResponseMessagePanel,
prompt: String
) {
val state = RunCardState(
runMessageId = runMessageId,
rollbackRunId = rollbackRunId,
responsePanel = responsePanel,
sourceMessage = Message(prompt)
)
runCardsByMessageId[runMessageId] = state
activeRunMessageId = runMessageId
}
private fun registerRecoveredRunCard(message: Message, responsePanel: ResponseMessagePanel) {
val copiedMessage = Message(
message.prompt.orEmpty(),
message.response.orEmpty().stripThinkingBlocks()
).apply {
message.toolCalls?.let { toolCalls = ArrayList(it) }
message.toolCallResults?.let { toolCallResults = LinkedHashMap(it) }
}
val state = RunCardState(
runMessageId = message.id,
rollbackRunId = "",
responsePanel = responsePanel,
sourceMessage = copiedMessage,
completed = true
)
runCardsByMessageId[message.id] = state
timelineController.invalidateTimelineCache()
}
private fun markActiveRunCompleted() {
val runMessageId = activeRunMessageId ?: return
val state = runCardsByMessageId[runMessageId] ?: return
state.completed = true
activeRunMessageId = null
}
private fun updateRunCheckpoint(runMessageId: UUID, ref: CheckpointRef?) {
val state = runCardsByMessageId[runMessageId] ?: return
timelineController.invalidateTimelineCache()
if (ref != null) {
state.sourceMessage = null
}
}
private fun showSessionStartTimelineDialog() {
timelineController.showSessionStartTimelineDialog()
}
private fun runStateForRunIndex(runIndex: Int): AgentTimelineRunState? {
if (runIndex <= 0) return null
val state = runCardsByMessageId.values.elementAtOrNull(runIndex - 1) ?: return null
return AgentTimelineRunState(
rollbackRunId = state.rollbackRunId,
sourceMessage = state.sourceMessage
)
}
private fun applySeededSessionState(seededConversation: Conversation, seedRef: CheckpointRef) {
conversation.messages = seededConversation.messages
runCardsByMessageId.clear()
activeRunMessageId = null
timelineController.invalidateTimelineCache()
agentSession.runtimeAgentId = seedRef.agentId
agentSession.resumeCheckpointRef = seedRef
val contentManager = project.service<AgentToolWindowContentManager>()
contentManager.setRuntimeAgentId(sessionId, seedRef.agentId)
contentManager.setResumeCheckpointRef(sessionId, seedRef)
project.service<AgentService>().clearPendingMessages(sessionId)
loadingLabel.isVisible = false
clearQueuedMessages()
approvalContainer.removeAll()
approvalContainer.isVisible = false
eventHandler.resetForNewSubmission()
eventHandler.setCurrentRollbackRunId(null)
userInputPanel.setStopEnabled(false)
if (conversation.messages.isEmpty()) {
displayLandingView()
} else {
displayRecoveredConversation()
}
refreshViewAfterRollback()
runCatching {
contentManager.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.STOPPED)
}
}
private fun refreshViewAfterRollback() {
timelineController.invalidateTimelineCache()
runCatching { VirtualFileManager.getInstance().asyncRefresh(null) }
rollbackPanel.refreshOperations()
scrollablePanel.update()
revalidate()
repaint()
}
override fun dispose() {
recoveredConversationJob?.cancel()
disposeLandingPanelIfPresent()
ToolRunContext.cleanupSession(sessionId)
runCardsByMessageId.clear()
activeRunMessageId = null
projectMessageBusConnection.disconnect()
appMessageBusConnection.disconnect()

View file

@ -0,0 +1,59 @@
package ee.carlrobert.codegpt.toolwindow.agent
import ee.carlrobert.codegpt.agent.ToolName
import ee.carlrobert.codegpt.agent.ToolSpecs
import ee.carlrobert.codegpt.agent.tools.EditTool
import ee.carlrobert.codegpt.agent.tools.ReadTool
import ee.carlrobert.codegpt.agent.tools.WriteTool
import kotlinx.serialization.json.Json
internal object HistoricalRollbackCompatibility {
private val supportedTools = setOf(ToolName.READ, ToolName.EDIT, ToolName.WRITE)
fun resolveSupportedTool(rawToolName: String): ToolName? {
val normalized = rawToolName.trim()
if (normalized.isEmpty()) return null
val resolved = ToolSpecs.find(normalized)?.name ?: return null
return resolved.takeIf { it in supportedTools }
}
fun isSuccessfulResult(toolName: ToolName, content: String, replayJson: Json): Boolean {
if (toolName !in supportedTools) return false
val normalized = content.trim()
val decoded = ToolSpecs.decodeResultOrNull(replayJson, toolName.id, normalized)
if (decoded != null) {
return when (toolName) {
ToolName.READ -> decoded is ReadTool.Result.Success
ToolName.EDIT -> decoded is EditTool.Result.Success
ToolName.WRITE -> decoded is WriteTool.Result.Success
else -> false
}
}
// TODO: Is there a better way to determine if the result is successful?
// This string comparison won't cut it
return when (toolName) {
ToolName.READ -> !normalized.startsWith(READ_ERROR_PREFIX, ignoreCase = true)
ToolName.EDIT ->
normalized.isNotBlank() &&
!normalized.startsWith(EDIT_ERROR_PREFIX, ignoreCase = true) &&
(normalized.contains(EDIT_SUCCESS_MARKER, ignoreCase = true) ||
normalized.contains(LEGACY_EDIT_SUCCESS_MARKER, ignoreCase = true))
ToolName.WRITE ->
normalized.isNotBlank() &&
normalized.contains(WRITE_SUCCESS_MARKER, ignoreCase = true) &&
!normalized.startsWith(WRITE_ERROR_PREFIX, ignoreCase = true)
else -> false
}
}
private const val READ_ERROR_PREFIX = "Error reading file"
private const val EDIT_ERROR_PREFIX = "Error editing file"
private const val WRITE_ERROR_PREFIX = "Error writing file"
private const val EDIT_SUCCESS_MARKER = "Successfully edited file"
private const val LEGACY_EDIT_SUCCESS_MARKER = "Successfully made"
private const val WRITE_SUCCESS_MARKER = "successfully"
}

View file

@ -1,6 +1,5 @@
package ee.carlrobert.codegpt.toolwindow.agent.ui
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.components.service
@ -21,27 +20,33 @@ import ee.carlrobert.codegpt.agent.history.AgentCheckpointConversationMapper
import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService
import ee.carlrobert.codegpt.agent.history.AgentHistoryThreadSummary
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
import ee.carlrobert.codegpt.settings.ProxyAISubagent
import ee.carlrobert.codegpt.settings.agents.SubagentsConfigurable
import ee.carlrobert.codegpt.settings.hooks.HookConfig
import ee.carlrobert.codegpt.settings.hooks.HookConfiguration
import ee.carlrobert.codegpt.settings.skills.SkillDescriptor
import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService
import ee.carlrobert.codegpt.tokens.TokenComputationService
import ee.carlrobert.codegpt.toolwindow.agent.AgentToolWindowContentManager
import ee.carlrobert.codegpt.toolwindow.agent.history.AgentHistoryListPanel
import ee.carlrobert.codegpt.toolwindow.ui.ResponseMessagePanel
import ee.carlrobert.codegpt.ui.UIUtil.createTextPane
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
import com.intellij.openapi.Disposable
import java.awt.BorderLayout
import java.awt.Color
import java.awt.Desktop
import java.net.URI
import java.nio.charset.StandardCharsets
import java.nio.file.Path
import java.time.Instant
import java.time.ZoneId
import java.time.format.DateTimeFormatter
import javax.swing.Box
import javax.swing.BoxLayout
import javax.swing.JPanel
class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessagePanel() {
class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessagePanel(), Disposable {
companion object {
private val logger = thisLogger()
@ -49,7 +54,10 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
private val historyService = project.service<AgentCheckpointHistoryService>()
private val historyListPanel = AgentHistoryListPanel(defaultLimit = 5)
private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO)
private var refreshHistory = true
@Volatile
private var disposed = false
private fun createLabel(text: String) = JBLabel(text)
private fun createLink(text: String, onClick: () -> Unit) = ActionLink(text) { onClick() }
@ -166,77 +174,39 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
tokensLabel.foreground = healthColor(tokens)
topRow.add(tokensLabel)
val bottomRow = JPanel()
bottomRow.layout = BoxLayout(bottomRow, BoxLayout.X_AXIS)
bottomRow.alignmentX = LEFT_ALIGNMENT
val settingsService = project.service<ProxyAISettingsService>()
val skills = project.service<SkillDiscoveryService>().listSkills()
val subagents = settingsService.getSubagents()
val hookEntries = collectHookEntries(settingsService.getHooks())
val enabledHooksCount = hookEntries.count { it.hook.enabled }
val ts = vf.timeStamp
val now = Instant.now()
val fileTime = Instant.ofEpochMilli(ts)
val zonedNow = now.atZone(ZoneId.systemDefault())
val zonedFileTime = fileTime.atZone(ZoneId.systemDefault())
val timeLabel = when {
zonedFileTime.toLocalDate() == zonedNow.toLocalDate() -> {
val timeFormat =
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
JBLabel("Modified today at ${timeFormat.format(fileTime)}")
}
zonedFileTime.toLocalDate() == zonedNow.minusDays(1).toLocalDate() -> {
val timeFormat =
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
JBLabel("Modified yesterday at ${timeFormat.format(fileTime)}")
}
fileTime.isAfter(now.minusSeconds(7 * 24 * 60 * 60)) -> {
val dayFormat =
DateTimeFormatter.ofPattern("EEEE").withZone(ZoneId.systemDefault())
val timeFormat =
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
JBLabel(
"Modified on ${dayFormat.format(fileTime)} at ${
timeFormat.format(
fileTime
)
}"
)
}
zonedFileTime.toLocalDate().year == zonedNow.toLocalDate().year -> {
val dateFormat =
DateTimeFormatter.ofPattern("MMM d").withZone(ZoneId.systemDefault())
val timeFormat =
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
JBLabel(
"Modified on ${dateFormat.format(fileTime)} at ${
timeFormat.format(
fileTime
)
}"
)
}
else -> {
val dateFormat =
DateTimeFormatter.ofPattern("MMM d, yyyy").withZone(ZoneId.systemDefault())
val timeFormat =
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
JBLabel(
"Modified on ${dateFormat.format(fileTime)} at ${
timeFormat.format(
fileTime
)
}"
)
}
}
timeLabel.foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor
bottomRow.add(timeLabel)
val detailsRow = JPanel()
detailsRow.layout = BoxLayout(detailsRow, BoxLayout.X_AXIS)
detailsRow.alignmentX = LEFT_ALIGNMENT
detailsRow.add(
createHoverDetailsLabel(
text = "Skills ${skills.size}",
tooltip = buildSkillsTooltip(skills)
)
)
detailsRow.add(createDetailsSeparator())
detailsRow.add(
createHoverDetailsLabel(
text = "Hooks $enabledHooksCount",
tooltip = buildHooksTooltip(hookEntries)
)
)
detailsRow.add(createDetailsSeparator())
detailsRow.add(
createHoverDetailsLabel(
text = "Subagents ${subagents.size}",
tooltip = buildSubagentsTooltip(subagents)
)
)
container.add(topRow)
container.add(Box.createVerticalStrut(2))
container.add(bottomRow)
container.add(detailsRow)
}
panel.add(container)
@ -299,22 +269,23 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
limit: Int,
onResult: (List<AgentHistoryThreadSummary>, Boolean, Int) -> Unit
) {
if (disposed || project.isDisposed) return
val shouldRefresh = offset == 0 && refreshHistory
ApplicationManager.getApplication().executeOnPooledThread {
backgroundScope.launch {
val page = runCatching {
runBlocking {
historyService.listThreadsPage(
query = query,
offset = offset,
limit = limit,
refresh = shouldRefresh
)
}
historyService.listThreadsPage(
query = query,
offset = offset,
limit = limit,
refresh = shouldRefresh
)
}
.onFailure { logger.warn("Failed to load checkpoint history", it) }
.getOrNull()
if (disposed || project.isDisposed) return@launch
runInEdt {
if (disposed || project.isDisposed) return@runInEdt
if (page != null) {
refreshHistory = false
}
@ -324,25 +295,37 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
}
private fun openCheckpointThread(thread: AgentHistoryThreadSummary) {
ApplicationManager.getApplication().executeOnPooledThread {
if (disposed || project.isDisposed) return
backgroundScope.launch {
val checkpoint = runCatching {
historyService.loadCheckpoint(thread.latest)
}.onFailure {
logger.warn("Failed to open checkpoint thread ${thread.agentId}", it)
}.getOrNull() ?: return@launch
val conversation = runCatching {
val checkpoint = runBlocking { historyService.loadCheckpoint(thread.latest) }
?: return@executeOnPooledThread
AgentCheckpointConversationMapper.toConversation(
checkpoint = checkpoint,
projectInstructions = loadProjectInstructions(project.basePath)
)
}.onFailure {
logger.warn("Failed to open checkpoint thread ${thread.agentId}", it)
}.getOrNull() ?: return@executeOnPooledThread
}.getOrNull() ?: return@launch
ApplicationManager.getApplication().invokeLater {
if (disposed || project.isDisposed) return@launch
runInEdt {
if (disposed || project.isDisposed) return@runInEdt
project.service<AgentToolWindowContentManager>()
.openCheckpointConversation(thread, conversation)
}
}
}
override fun dispose() {
disposed = true
backgroundScope.dispose()
}
private fun welcomeMessage(): String {
val name = GeneralSettings.getCurrentState().displayName
return """
@ -385,4 +368,93 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
val bl = (a.blue + (b.blue - a.blue) * t).toInt()
return Color(r, g, bl)
}
private fun createHoverDetailsLabel(text: String, tooltip: String): JBLabel {
return JBLabel(text).apply {
foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor
toolTipText = tooltip
}
}
private fun createDetailsSeparator(): JBLabel {
return JBLabel("").apply {
foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor
}
}
private fun collectHookEntries(configuration: HookConfiguration): List<HookEntry> = buildList {
addAll(configuration.beforeToolUse.map { HookEntry("Before tool", it) })
addAll(configuration.afterToolUse.map { HookEntry("After tool", it) })
addAll(configuration.subagentStart.map { HookEntry("Subagent start", it) })
addAll(configuration.subagentStop.map { HookEntry("Subagent stop", it) })
addAll(configuration.beforeShellExecution.map { HookEntry("Before shell", it) })
addAll(configuration.afterShellExecution.map { HookEntry("After shell", it) })
addAll(configuration.beforeReadFile.map { HookEntry("Before read", it) })
addAll(configuration.afterFileEdit.map { HookEntry("After edit", it) })
addAll(configuration.stop.map { HookEntry("Stop", it) })
}
private fun buildSkillsTooltip(skills: List<SkillDescriptor>): String {
if (skills.isEmpty()) {
return tooltipHtml("Skills", listOf("No skills discovered in .proxyai/skills"))
}
val lines = skills.take(5).map { "${it.name}${it.title}" }.toMutableList()
val remaining = skills.size - lines.size
if (remaining > 0) {
lines.add("+$remaining more")
}
return tooltipHtml("Skills (${skills.size})", lines)
}
private fun buildHooksTooltip(hookEntries: List<HookEntry>): String {
val enabled = hookEntries.filter { it.hook.enabled }
if (hookEntries.isEmpty()) {
return tooltipHtml("Hooks", listOf("No hooks configured"))
}
val lines = mutableListOf<String>()
enabled.take(4).forEach { entry ->
lines.add("${entry.event}: ${entry.hook.command}")
}
val remainingEnabled = enabled.size - 4
if (remainingEnabled > 0) {
lines.add("+$remainingEnabled more enabled")
}
return tooltipHtml("Hooks", lines)
}
private fun buildSubagentsTooltip(subagents: List<ProxyAISubagent>): String {
if (subagents.isEmpty()) {
return tooltipHtml("Subagents", listOf("No subagents configured"))
}
val lines = subagents.take(5).map { "${it.title}" }.toMutableList()
val remaining = subagents.size - lines.size
if (remaining > 0) {
lines.add("+$remaining more")
}
return tooltipHtml("Subagents (${subagents.size})", lines)
}
private fun tooltipHtml(title: String, lines: List<String>): String {
val escapedTitle = escapeHtml(title)
val escapedLines = lines.joinToString("<br>") { escapeHtml(it) }
return "<html><b>$escapedTitle</b><br>$escapedLines</html>"
}
private fun escapeHtml(value: String): String {
return value
.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
}
private data class HookEntry(
val event: String,
val hook: HookConfig
)
}

View file

@ -5,6 +5,7 @@ import com.intellij.diff.DiffManager
import com.intellij.diff.requests.SimpleDiffRequest
import com.intellij.icons.AllIcons
import com.intellij.ide.actions.OpenFileAction
import com.intellij.openapi.Disposable
import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.application.EDT
@ -23,6 +24,7 @@ import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.drawCenteredText
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.lineDiffStats
import ee.carlrobert.codegpt.ui.IconActionButton
import ee.carlrobert.codegpt.ui.components.LeftEllipsisLabel
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
import kotlinx.coroutines.*
import java.awt.BorderLayout
import java.awt.Dimension
@ -40,15 +42,11 @@ import java.time.format.DateTimeFormatter
import java.util.concurrent.ConcurrentHashMap
import javax.swing.*
/**
* Panel showing file operations performed by the agent with rollback controls.
*
*/
class RollbackPanel(
private val project: Project,
private val sessionId: String,
private val onRollbackComplete: () -> Unit
) : BorderLayoutPanel() {
) : BorderLayoutPanel(), Disposable {
private val rollbackService = RollbackService.getInstance(project)
private val titleLabel = JBLabel()
private val timeLabel = JBLabel()
@ -58,7 +56,8 @@ class RollbackPanel(
private val keepAllLink = createKeepAllLink()
private val diffStatsCache = ConcurrentHashMap<String, Triple<Int, Int, Int>>()
private val diffDataCache = ConcurrentHashMap<String, RollbackDiffData>()
private val backgroundScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private var lastSnapshotRunId: String? = null
private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO)
init {
setupUI()
@ -130,12 +129,26 @@ class RollbackPanel(
.sortedBy { it.path }
withContext(Dispatchers.EDT) {
refreshOperationsUI(changes, snapshot?.completedAt)
refreshOperationsUI(
changes = changes,
completedAt = snapshot?.completedAt,
snapshotRunId = snapshot?.runId
)
}
}
}
private fun refreshOperationsUI(changes: List<FileChange>, completedAt: Instant?) {
private fun refreshOperationsUI(
changes: List<FileChange>,
completedAt: Instant?,
snapshotRunId: String?
) {
if (snapshotRunId != lastSnapshotRunId) {
diffStatsCache.clear()
diffDataCache.clear()
lastSnapshotRunId = snapshotRunId
}
isVisible = changes.isNotEmpty()
if (changes.isEmpty()) {
titleLabel.text = "Changes"
@ -168,10 +181,10 @@ class RollbackPanel(
}
private fun handleRollback() {
CoroutineScope(Dispatchers.Main).launch {
backgroundScope.launch {
val result = rollbackService.rollbackSession(sessionId)
withContext(Dispatchers.Main) {
withContext(Dispatchers.EDT) {
when (result) {
is RollbackResult.Success -> {
refreshOperations()
@ -424,10 +437,10 @@ class RollbackPanel(
}
private fun rollbackFile(path: String) {
CoroutineScope(Dispatchers.Main).launch {
backgroundScope.launch {
val result = rollbackService.rollbackFile(sessionId, path)
withContext(Dispatchers.Main) {
withContext(Dispatchers.EDT) {
when (result) {
is RollbackResult.Success -> {
refreshOperations()
@ -466,7 +479,7 @@ class RollbackPanel(
backgroundScope.launch {
val diffData = rollbackService.getDiffData(sessionId, path) ?: return@launch
diffDataCache[path] = diffData
withContext(Dispatchers.Main) {
withContext(Dispatchers.EDT) {
openDiff(diffData)
}
}
@ -480,4 +493,8 @@ class RollbackPanel(
val request = SimpleDiffRequest(title, before, after, "Before", "After")
DiffManager.getInstance().showDiff(project, request)
}
}
override fun dispose() {
backgroundScope.dispose()
}
}

View file

@ -70,6 +70,10 @@ abstract class BaseMessagePanel : BorderLayoutPanel() {
)
}
fun addHeaderAction(action: AnAction, actionCode: String) {
addIconActionButton(IconActionButton(action, actionCode))
}
fun addContent(content: JComponent) {
body.addContent(content)
}

View file

@ -73,6 +73,7 @@ class UserInputPanel @JvmOverloads constructor(
private val agentTokenCounterPanel: JComponent? = null,
private val sessionIdProvider: (() -> String?)? = null,
private val conversationIdProvider: (() -> UUID?)? = null,
private val onStartSessionTimeline: (() -> Unit)? = null,
) : BorderLayoutPanel() {
constructor(
@ -98,6 +99,7 @@ class UserInputPanel @JvmOverloads constructor(
null,
withRemovableSelectedEditorTag,
null,
null,
null
)
@ -220,7 +222,26 @@ class UserInputPanel @JvmOverloads constructor(
} else {
null
}
private val promptEnhancerSeparator = if (featureType == FeatureType.AGENT) {
private val sessionTimelineButton = if (featureType == FeatureType.AGENT) {
IconActionButton(
object : AnAction(
"Timeline",
"Choose a timeline point from this session",
AllIcons.Vcs.History
) {
override fun actionPerformed(e: AnActionEvent) {
onStartSessionTimeline?.invoke()
}
},
"SESSION_TIMELINE"
).apply {
isVisible = onStartSessionTimeline != null
cursor = Cursor.getPredefinedCursor(Cursor.HAND_CURSOR)
}
} else {
null
}
private val sessionTimelineSeparator = if (featureType == FeatureType.AGENT) {
createActionSeparator()
} else {
null
@ -600,9 +621,12 @@ class UserInputPanel @JvmOverloads constructor(
panel {
row {
if (applyChip != null) cell(applyChip).gap(RightGap.SMALL)
if (promptEnhancerButton != null && promptEnhancerSeparator != null) {
if (promptEnhancerButton != null) {
cell(promptEnhancerButton).gap(RightGap.SMALL)
cell(promptEnhancerSeparator).gap(RightGap.SMALL)
}
if (sessionTimelineButton != null && sessionTimelineSeparator != null) {
cell(sessionTimelineButton).gap(RightGap.SMALL)
cell(sessionTimelineSeparator).gap(RightGap.SMALL)
}
cell(submitButton).gap(RightGap.SMALL)
cell(stopButton)

View file

@ -13,7 +13,7 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.FolderTagDetails
class FolderActionItem(
private val project: Project,
private val folder: VirtualFile
) : AbstractLookupActionItem() {
) : AbstractLookupActionItem(), InsertsDisplayNameLookupItem {
override val displayName = folder.name
override val icon = AllIcons.Nodes.Folder
@ -33,4 +33,4 @@ class FolderActionItem(
override fun execute(project: Project, userInputPanel: UserInputPanel) {
userInputPanel.addTag(FolderTagDetails(folder))
}
}
}

View file

@ -10,9 +10,10 @@ import com.intellij.openapi.vfs.VirtualFile
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
import ee.carlrobert.codegpt.ui.textarea.lookup.action.AbstractLookupActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.InsertsDisplayNameLookupItem
class FileActionItem(private val project: Project, val file: VirtualFile) :
AbstractLookupActionItem() {
AbstractLookupActionItem(), InsertsDisplayNameLookupItem {
override val displayName = file.name
override val icon = file.fileType.icon ?: AllIcons.FileTypes.Any_type
@ -32,4 +33,4 @@ class FileActionItem(private val project: Project, val file: VirtualFile) :
override fun execute(project: Project, userInputPanel: UserInputPanel) {
userInputPanel.addTag(EditorTagDetails(file))
}
}
}

View file

@ -4,6 +4,13 @@ import ai.grazie.nlp.utils.takeWhitespaces
object StringUtil {
private val thinkBlockRegex = Regex("(?s)<think>.*?</think>\\s*")
fun String.stripThinkingBlocks(): String {
if (this.isBlank()) return ""
return thinkBlockRegex.replace(this, "").trim()
}
fun adjustWhitespace(
completionLine: String,
editorLine: String
@ -33,12 +40,4 @@ object StringUtil {
val intersection = bigrams1.intersect(bigrams2).size
return (2.0 * intersection) / (bigrams1.size + bigrams2.size)
}
fun String.extractUntilNewline(): String {
val index = this.indexOf('\n')
if (index == -1) {
return this
}
return this.substring(0, index + 1)
}
}

View file

@ -69,26 +69,26 @@ class AgentServiceIntegrationTest : IntegrationTest() {
val secondMessage = MessageWithContext("Second request")
submitAndAwait(sessionId, firstMessage)
val firstSnapshot = snapshot(sessionId)
val firstActualSnapshot = snapshot(sessionId)
submitAndAwait(sessionId, secondMessage)
val secondSnapshot = snapshot(sessionId)
val secondActualSnapshot = snapshot(sessionId)
val runtimeAgentId = firstSnapshot.runtimeAgentId
val managedService = factory.createdServices.single()
assertThat(runtimeAgentId).isNotBlank()
val actualRuntimeAgentId = firstActualSnapshot.runtimeAgentId
val actualManagedService = factory.createdServices.single()
assertThat(actualRuntimeAgentId).isNotBlank()
assertThat(factory.createdServices).hasSize(1)
assertThat(listOf(firstSnapshot, secondSnapshot))
assertThat(listOf(firstActualSnapshot, secondActualSnapshot))
.extracting("runtimeAgentId", "resumeCheckpointRef.agentId")
.containsExactly(
tuple(runtimeAgentId, runtimeAgentId),
tuple(runtimeAgentId, runtimeAgentId)
tuple(actualRuntimeAgentId, actualRuntimeAgentId),
tuple(actualRuntimeAgentId, actualRuntimeAgentId)
)
assertThat(firstSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
assertThat(secondSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
assertThat(secondSnapshot.resumeCheckpointRef?.checkpointId)
.isNotEqualTo(firstSnapshot.resumeCheckpointRef?.checkpointId)
assertThat(managedService.createdAgentIds).hasSize(2)
assertThat(managedService.createdAgentIds.last()).isEqualTo(runtimeAgentId)
assertThat(firstActualSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
assertThat(secondActualSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
assertThat(secondActualSnapshot.resumeCheckpointRef?.checkpointId)
.isNotEqualTo(firstActualSnapshot.resumeCheckpointRef?.checkpointId)
assertThat(actualManagedService.createdAgentIds).hasSize(2)
assertThat(actualManagedService.createdAgentIds.last()).isEqualTo(actualRuntimeAgentId)
}
fun testSubmitMessageRebuildsRuntimeWhenMcpSelectionChanges() {
@ -102,12 +102,38 @@ class AgentServiceIntegrationTest : IntegrationTest() {
submitAndAwait(sessionId, withoutMcp)
submitAndAwait(sessionId, withMcp)
assertThat(factory.createdServices).hasSize(2)
assertThat(factory.createdServices)
val actualCreatedServices = factory.createdServices
assertThat(actualCreatedServices).hasSize(2)
assertThat(actualCreatedServices)
.extracting("closeAllCalls")
.containsExactly(1, 0)
}
fun testSubmitMessageEmitsRunCheckpointCallback() {
val sessionId = createSession("agent-runtime-callback")
val factory = RecordingRuntimeFactory(agentModel())
agentService.runtimeFactory = factory
val message = MessageWithContext("Callback request")
var callbackMessageId: UUID? = null
var callbackRef: CheckpointRef? = null
val events = object : AgentEvents {
override fun onQueuedMessagesResolved() = Unit
override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {
callbackMessageId = runMessageId
callbackRef = ref
}
}
agentService.submitMessage(message, events, sessionId)
awaitSessionToFinish(sessionId)
val actualSession = contentManager.getSession(sessionId)
assertThat(callbackMessageId).isEqualTo(message.id)
assertThat(callbackRef).isNotNull
assertThat(callbackRef?.agentId).isEqualTo(actualSession?.runtimeAgentId)
}
fun testGetCheckpointFallsBackToRuntimeAgentIdWhenResumeRefIsStale() {
val sessionId = createSession("agent-checkpoint-fallback")
val runtimeAgentId = "runtime-agent-${UUID.randomUUID()}"
@ -133,11 +159,11 @@ class AgentServiceIntegrationTest : IntegrationTest() {
contentManager.setRuntimeAgentId(sessionId, runtimeAgentId)
contentManager.setResumeCheckpointRef(sessionId, staleRef)
val checkpoint = runBlocking { agentService.getCheckpoint(sessionId) }
assertThat(checkpoint).isNotNull
assertThat(checkpoint!!.checkpointId).isEqualTo(checkpointId)
assertThat(contentManager.getSession(sessionId)!!.resumeCheckpointRef)
val actualCheckpoint = runBlocking { agentService.getCheckpoint(sessionId) }
val actualSession = contentManager.getSession(sessionId)
assertThat(actualCheckpoint).isNotNull
assertThat(actualCheckpoint!!.checkpointId).isEqualTo(checkpointId)
assertThat(actualSession!!.resumeCheckpointRef)
.isEqualTo(CheckpointRef(runtimeAgentId, checkpointId))
}

View file

@ -2,8 +2,9 @@ package ee.carlrobert.codegpt.agent
import com.intellij.openapi.vfs.LocalFileSystem
import ee.carlrobert.codegpt.agent.rollback.ChangeKind
import ee.carlrobert.codegpt.agent.rollback.RollbackResult
import ee.carlrobert.codegpt.agent.rollback.RollbackService
import ee.carlrobert.codegpt.agent.tools.WriteTool
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
import java.io.File
@ -31,10 +32,10 @@ class RollbackServiceTrackingTest : IntegrationTest() {
filePath = filePath,
originalContent = "before"
)
val snapshot = rollbackService.finishSession(sessionId)
val actualSnapshot = rollbackService.finishSession(sessionId)
assertThat(snapshot!!.changes).hasSize(1)
assertThat(snapshot.changes.single())
assertThat(actualSnapshot!!.changes).hasSize(1)
assertThat(actualSnapshot.changes.single())
.extracting("path", "kind")
.containsExactly(filePath, ChangeKind.MODIFIED)
}
@ -46,10 +47,10 @@ class RollbackServiceTrackingTest : IntegrationTest() {
val sessionId = "s2"
rollbackService.startSession(sessionId)
rollbackService.trackWrite(sessionId, filePath, WriteTool.Args(filePath, "hello"))
val snapshot = rollbackService.finishSession(sessionId)
rollbackService.trackWrite(sessionId, filePath)
val actualSnapshot = rollbackService.finishSession(sessionId)
assertThat(snapshot).isNotNull
assertThat(actualSnapshot).isNotNull
.extracting { it!!.changes.single().kind }
.isEqualTo(ChangeKind.ADDED)
}
@ -62,14 +63,81 @@ class RollbackServiceTrackingTest : IntegrationTest() {
val sessionId = "s3"
rollbackService.startSession(sessionId)
rollbackService.trackWrite(sessionId, filePath, WriteTool.Args(filePath, "after"))
rollbackService.trackWrite(sessionId, filePath)
file.writeText("after")
LocalFileSystem.getInstance().refreshAndFindFileByPath(filePath)
val snapshot = rollbackService.finishSession(sessionId)
val actualSnapshot = rollbackService.finishSession(sessionId)
val actualDiff = rollbackService.getDiffData(sessionId, filePath)
assertThat(snapshot!!.changes.single().kind).isEqualTo(ChangeKind.MODIFIED)
assertThat(rollbackService.getDiffData(sessionId, filePath))
assertThat(actualSnapshot!!.changes.single().kind).isEqualTo(ChangeKind.MODIFIED)
assertThat(actualDiff)
.extracting("beforeText", "afterText")
.containsExactly("before", "after")
}
}
fun testSnapshotsRemainAvailablePerRunWithinSameSession() {
val rollbackService = RollbackService(project)
val sessionId = "run-scoped-session"
val firstPath = getTestFilePath("run_scoped_a.txt")
File(firstPath).delete()
val firstRunId = rollbackService.startSession(sessionId)
rollbackService.trackWrite(sessionId, firstPath)
File(firstPath).writeText("run-a")
rollbackService.finishSession(sessionId)
val secondPath = getTestFilePath("run_scoped_b.txt")
File(secondPath).delete()
val secondRunId = rollbackService.startSession(sessionId)
rollbackService.trackWrite(sessionId, secondPath)
File(secondPath).writeText("run-b")
rollbackService.finishSession(sessionId)
val firstRunSnapshot = rollbackService.getRunSnapshot(firstRunId)
val secondRunSnapshot = rollbackService.getRunSnapshot(secondRunId)
val latestSessionSnapshot = rollbackService.getSnapshot(sessionId)
assertThat(firstRunSnapshot)
.extracting("runId")
.isEqualTo(firstRunId)
assertThat(secondRunSnapshot)
.extracting("runId")
.isEqualTo(secondRunId)
assertThat(latestSessionSnapshot)
.extracting("runId")
.isEqualTo(secondRunId)
}
fun testRollbackRunDoesNotClearOtherRunSnapshots() {
val rollbackService = RollbackService(project)
val sessionId = "run-rollback-session"
val firstPath = getTestFilePath("run_rollback_a.txt")
File(firstPath).delete()
val firstRunId = rollbackService.startSession(sessionId)
rollbackService.trackWrite(sessionId, firstPath)
File(firstPath).writeText("run-a")
rollbackService.finishSession(sessionId)
val secondPath = getTestFilePath("run_rollback_b.txt")
File(secondPath).delete()
val secondRunId = rollbackService.startSession(sessionId)
rollbackService.trackWrite(sessionId, secondPath)
File(secondPath).writeText("run-b")
rollbackService.finishSession(sessionId)
val actualRollbackResult = runBlocking {
rollbackService.rollbackRun(firstRunId)
}
val firstRunSnapshotAfterRollback = rollbackService.getRunSnapshot(firstRunId)
val secondRunSnapshotAfterRollback = rollbackService.getRunSnapshot(secondRunId)
val latestSessionSnapshot = rollbackService.getSnapshot(sessionId)
assertThat(actualRollbackResult).isInstanceOf(RollbackResult.Success::class.java)
assertThat(firstRunSnapshotAfterRollback).isNull()
assertThat(secondRunSnapshotAfterRollback).isNotNull()
assertThat(latestSessionSnapshot)
.extracting("runId")
.isEqualTo(secondRunId)
}
}

View file

@ -4,7 +4,6 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.RequestMetaInfo
import ai.koog.prompt.message.ResponseMetaInfo
import ee.carlrobert.codegpt.conversations.Conversation
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant
import kotlinx.serialization.json.JsonNull
@ -15,7 +14,6 @@ import org.assertj.core.groups.Tuple.tuple
import kotlin.test.Test
class AgentCheckpointConversationMapperTest {
private val actualUserTurn = "Actual user prompt" to "Actual response"
@Test
fun `maps user and assistant turns into conversation messages`() {
@ -29,9 +27,9 @@ class AgentCheckpointConversationMapperTest {
)
)
val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
assertThat(conversation.messages)
assertThat(actualConversation.messages)
.extracting("prompt", "response")
.containsExactly(
tuple("First prompt", "First response"),
@ -60,14 +58,14 @@ class AgentCheckpointConversationMapperTest {
)
)
val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
val mapped = conversation.messages.single()
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
val actualMessage = actualConversation.messages.single()
assertThat(mapped.response).isEqualTo("Working on it")
assertThat(mapped.toolCalls.orEmpty())
assertThat(actualMessage.response).isEqualTo("Working on it")
assertThat(actualMessage.toolCalls.orEmpty())
.extracting("function.name", "function.arguments")
.containsExactly(tuple("Bash", """{"command":"ls"}"""))
assertThat(mapped.toolCallResults).containsEntry("tool-1", "file1.txt")
assertThat(actualMessage.toolCallResults).containsEntry("tool-1", "file1.txt")
}
@Test
@ -81,12 +79,14 @@ class AgentCheckpointConversationMapperTest {
)
)
val conversation = AgentCheckpointConversationMapper.toConversation(
val actualConversation = AgentCheckpointConversationMapper.toConversation(
checkpoint = checkpoint,
projectInstructions = projectInstructions
)
assertSingleTurn(conversation, actualUserTurn)
assertThat(actualConversation.messages)
.extracting("prompt", "response")
.containsExactly(tuple("Actual user prompt", "Actual response"))
}
@Test
@ -102,12 +102,43 @@ class AgentCheckpointConversationMapperTest {
)
)
val conversation = AgentCheckpointConversationMapper.toConversation(
val actualConversation = AgentCheckpointConversationMapper.toConversation(
checkpoint = checkpoint,
projectInstructions = null
)
assertSingleTurn(conversation, actualUserTurn)
assertThat(actualConversation.messages)
.extracting("prompt", "response")
.containsExactly(tuple("Actual user prompt", "Actual response"))
}
@Test
fun `does not merge hidden user turn output into previous visible turn`() {
val checkpoint = checkpoint(
listOf(
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
Message.User(
content = "Hidden cacheable instruction",
metaInfo = cacheableMetaInfo()
),
Message.Assistant("Hidden response", ResponseMetaInfo.Empty),
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
Message.Assistant("Visible response 2", ResponseMetaInfo.Empty)
)
)
val actualConversation = AgentCheckpointConversationMapper.toConversation(
checkpoint = checkpoint,
projectInstructions = null
)
assertThat(actualConversation.messages)
.extracting("prompt", "response")
.containsExactly(
tuple("Visible prompt 1", "Visible response 1"),
tuple("Visible prompt 2", "Visible response 2")
)
}
@Test
@ -136,12 +167,80 @@ class AgentCheckpointConversationMapperTest {
)
)
val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
assertThat(conversation.messages.single().toolCallResults)
assertThat(actualConversation.messages.single().toolCallResults)
.containsEntry("tool-1", "file1.txt\n\nfile2.txt")
}
@Test
fun `filters synthetic timeline user turn and its output`() {
val checkpoint = checkpoint(
listOf(
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
Message.User("I haven't created a todo list yet.", RequestMetaInfo.Empty),
Message.Assistant("Synthetic response", ResponseMetaInfo.Empty),
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
Message.Assistant("Visible response 2", ResponseMetaInfo.Empty)
)
)
val actualConversation = AgentCheckpointConversationMapper.toConversation(
checkpoint = checkpoint,
projectInstructions = null
)
assertThat(actualConversation.messages)
.extracting("prompt", "response")
.containsExactly(
tuple("Visible prompt 1", "Visible response 1"),
tuple("Visible prompt 2", "Visible response 2")
)
}
@Test
fun `filters internal timeline tools from mapped message`() {
val checkpoint = checkpoint(
listOf(
Message.User("Visible prompt", RequestMetaInfo.Empty),
Message.Tool.Call(
id = "todo-1",
tool = "TodoWrite",
content = """{"todos":[{"content":"x","status":"pending"}]}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "todo-1",
tool = "TodoWrite",
content = "ok",
metaInfo = RequestMetaInfo.Empty
),
Message.Tool.Call(
id = "tool-1",
tool = "Bash",
content = """{"command":"ls"}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "tool-1",
tool = "Bash",
content = "file1.txt",
metaInfo = RequestMetaInfo.Empty
)
)
)
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
val actualMessage = actualConversation.messages.single()
assertThat(actualMessage.toolCalls.orEmpty())
.extracting("function.name")
.containsExactly("Bash")
assertThat(actualMessage.toolCallResults)
.containsOnlyKeys("tool-1")
}
private fun cacheableMetaInfo(): RequestMetaInfo {
return RequestMetaInfo(
timestamp = Clock.System.now(),
@ -149,15 +248,6 @@ class AgentCheckpointConversationMapperTest {
)
}
private fun assertSingleTurn(
conversation: Conversation,
expectedTurn: Pair<String, String>
) {
assertThat(conversation.messages)
.extracting("prompt", "response")
.containsExactly(tuple(expectedTurn.first, expectedTurn.second))
}
private fun checkpoint(history: List<Message>): AgentCheckpointData {
return AgentCheckpointData(
checkpointId = "checkpoint-1",

View file

@ -0,0 +1,248 @@
package ee.carlrobert.codegpt.agent.history
import ai.koog.agents.snapshot.feature.AgentCheckpointData
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.RequestMetaInfo
import ai.koog.prompt.message.ResponseMetaInfo
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import org.assertj.core.api.Assertions.assertThat
import kotlin.test.Test
class AgentCheckpointTurnSequencerTest {
@Test
fun `preserves assistant and tool event order within a turn`() {
val history = listOf(
Message.User("Prompt 1", RequestMetaInfo.Empty),
Message.Assistant("Assistant 1", ResponseMetaInfo.Empty),
Message.Tool.Call(
id = "tool-1",
tool = "Bash",
content = """{"command":"ls"}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Assistant("Assistant 2", ResponseMetaInfo.Empty),
Message.Tool.Result(
id = "tool-1",
tool = "Bash",
content = "file.txt",
metaInfo = RequestMetaInfo.Empty
),
Message.Reasoning(content = "Reasoning 1", metaInfo = ResponseMetaInfo.Empty),
Message.User("Prompt 2", RequestMetaInfo.Empty),
Message.Assistant("Assistant 3", ResponseMetaInfo.Empty)
)
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(
history = history,
projectInstructions = null
)
assertThat(actualTurns).hasSize(2)
assertThat(actualTurns[0].prompt).isEqualTo("Prompt 1")
assertThat(actualTurns[0].userNonSystemMessageCount).isEqualTo(1)
assertThat(actualTurns[0].events.map(::eventSignature))
.containsExactly(
"assistant:Assistant 1",
"tool-call:Bash:tool-1",
"assistant:Assistant 2",
"tool-result:Bash:tool-1",
"reasoning:Reasoning 1"
)
assertThat(actualTurns[0].events.map { it.nonSystemMessageCount })
.containsExactly(2, 3, 4, 5, 6)
assertThat(actualTurns[1].prompt).isEqualTo("Prompt 2")
assertThat(actualTurns[1].userNonSystemMessageCount).isEqualTo(7)
assertThat(actualTurns[1].events.map(::eventSignature))
.containsExactly("assistant:Assistant 3")
assertThat(actualTurns[1].events.map { it.nonSystemMessageCount })
.containsExactly(8)
}
@Test
fun `filters hidden and synthetic user turns and internal tools`() {
val history = listOf(
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
Message.User("Hidden cacheable prompt", cacheableMetaInfo()),
Message.Assistant("Hidden response", ResponseMetaInfo.Empty),
Message.User("I haven't created a todo list yet.", RequestMetaInfo.Empty),
Message.Assistant("Synthetic response", ResponseMetaInfo.Empty),
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
Message.Tool.Call(
id = "todo-1",
tool = "TodoWrite",
content = """{"todos":[{"content":"x","status":"pending"}]}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "todo-1",
tool = "TodoWrite",
content = "ok",
metaInfo = RequestMetaInfo.Empty
),
Message.Tool.Call(
id = "tool-1",
tool = "Bash",
content = """{"command":"ls"}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "tool-1",
tool = "Bash",
content = "file.txt",
metaInfo = RequestMetaInfo.Empty
),
Message.Assistant("Visible response 2", ResponseMetaInfo.Empty)
)
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(
history = history,
projectInstructions = null
)
assertThat(actualTurns).hasSize(2)
assertThat(actualTurns.map { it.prompt })
.containsExactly("Visible prompt 1", "Visible prompt 2")
assertThat(actualTurns.map { it.userNonSystemMessageCount })
.containsExactly(1, 7)
assertThat(actualTurns[0].events.map(::eventSignature))
.containsExactly("assistant:Visible response 1")
assertThat(actualTurns[0].events.map { it.nonSystemMessageCount })
.containsExactly(2)
assertThat(actualTurns[1].events.map(::eventSignature))
.containsExactly(
"tool-call:Bash:tool-1",
"tool-result:Bash:tool-1",
"assistant:Visible response 2"
)
assertThat(actualTurns[1].events.map { it.nonSystemMessageCount })
.containsExactly(10, 11, 12)
}
@Test
fun `mapper and sequencer keep same visible turn prompts`() {
val history = listOf(
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
Message.User("Hidden cacheable prompt", cacheableMetaInfo()),
Message.Assistant("Hidden response", ResponseMetaInfo.Empty),
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
Message.Tool.Call(
id = "tool-1",
tool = "Bash",
content = """{"command":"ls"}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "tool-1",
tool = "Bash",
content = "file.txt",
metaInfo = RequestMetaInfo.Empty
)
)
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(history, null)
val actualConversation = AgentCheckpointConversationMapper.toConversation(
checkpoint = AgentCheckpointData(
checkpointId = "checkpoint-1",
createdAt = Instant.parse("2026-02-04T00:00:00Z"),
nodePath = "agent/single_run/nodeExecuteTool",
lastInput = JsonNull,
messageHistory = history,
version = 0
),
projectInstructions = null
)
assertThat(actualConversation.messages.map { it.prompt })
.containsExactlyElementsOf(actualTurns.map { it.prompt })
}
@Test
fun `keeps events after synthetic todo user message within the active visible turn`() {
val history = listOf(
Message.User("Implement feature X", RequestMetaInfo.Empty),
Message.Assistant("Starting implementation", ResponseMetaInfo.Empty),
Message.Tool.Call(
id = "bash-1",
tool = "Bash",
content = """{"command":"ls"}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "bash-1",
tool = "Bash",
content = "file.txt",
metaInfo = RequestMetaInfo.Empty
),
Message.User(
"It seems that you haven't created a todo list yet. If the task on hand requires multiple steps then create a todo list to track your changes.",
RequestMetaInfo.Empty
),
Message.Tool.Call(
id = "todo-1",
tool = "TodoWrite",
content = """{"todos":[{"content":"x","status":"pending"}]}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "todo-1",
tool = "TodoWrite",
content = "ok",
metaInfo = RequestMetaInfo.Empty
),
Message.Tool.Call(
id = "read-1",
tool = "Read",
content = """{"path":"src/Main.kt"}""",
metaInfo = ResponseMetaInfo.Empty
),
Message.Tool.Result(
id = "read-1",
tool = "Read",
content = "content",
metaInfo = RequestMetaInfo.Empty
),
Message.Assistant("Done", ResponseMetaInfo.Empty)
)
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(
history = history,
projectInstructions = null,
preserveSyntheticContinuation = true
)
assertThat(actualTurns).hasSize(1)
assertThat(actualTurns.single().prompt).isEqualTo("Implement feature X")
assertThat(actualTurns.single().events.map(::eventSignature))
.containsExactly(
"assistant:Starting implementation",
"tool-call:Bash:bash-1",
"tool-result:Bash:bash-1",
"tool-call:Read:read-1",
"tool-result:Read:read-1",
"assistant:Done"
)
}
private fun eventSignature(event: AgentCheckpointTurnSequencer.TurnEvent): String {
return when (event) {
is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> "assistant:${event.content}"
is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> "reasoning:${event.content}"
is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> "tool-call:${event.tool}:${event.id}"
is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> "tool-result:${event.tool}:${event.id}"
}
}
private fun cacheableMetaInfo(): RequestMetaInfo {
return RequestMetaInfo(
timestamp = Clock.System.now(),
metadata = JsonObject(mapOf("cacheable" to JsonPrimitive(true)))
)
}
}

View file

@ -17,7 +17,6 @@ class ChatToolWindowTabbedPaneTest : BasePlatformTestCase() {
assertThat(tabbedPane.activeTabMapping).isEmpty()
}
fun testAddingNewTabs() {
val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable())