diff --git a/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt b/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt new file mode 100644 index 00000000..7c2aed0a --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/agent/rollback/RollbackService.kt @@ -0,0 +1,560 @@ +package ee.carlrobert.codegpt.agent.rollback + +import com.intellij.history.Label +import com.intellij.history.LocalHistory +import com.intellij.openapi.application.runInEdt +import com.intellij.openapi.application.runReadAction +import com.intellij.openapi.application.runWriteAction +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.components.Service +import com.intellij.openapi.fileTypes.FileTypeManager +import com.intellij.openapi.project.Project +import com.intellij.openapi.roots.ProjectFileIndex +import com.intellij.openapi.util.io.FileUtil +import com.intellij.openapi.vfs.* +import com.intellij.openapi.vfs.newvfs.BulkFileListener +import com.intellij.openapi.vfs.newvfs.events.* +import ee.carlrobert.codegpt.settings.ProxyAISettingsService +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.suspendCancellableCoroutine +import kotlinx.coroutines.withContext +import java.io.File +import java.nio.file.Paths +import java.time.Instant +import java.time.format.DateTimeFormatter +import java.util.concurrent.ConcurrentHashMap +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException + +/** + * Tracks file changes during agent runs and provides rollback to a checkpoint. + * + * Uses IntelliJ's LocalHistory labels for persistence and a lightweight in-memory + * snapshot of modified files for fast restore. + */ +@Service(Service.Level.PROJECT) +class RollbackService(private val project: Project) { + + private val activeRuns = ConcurrentHashMap() + private val snapshots = ConcurrentHashMap() + private val settingsService = project.getService(ProxyAISettingsService::class.java) + + @Volatile + private var isApplyingRollback = false + + init { + val connection = project.messageBus.connect() + connection.subscribe(VirtualFileManager.VFS_CHANGES, object : BulkFileListener { + override fun before(events: List) { + if (isApplyingRollback || activeRuns.isEmpty()) return + events.forEach { event -> + when (event) { + is VFileContentChangeEvent -> recordModified(event.file) + is VFileDeleteEvent -> recordDeleted(event.file) + is VFileMoveEvent -> { + val oldPath = "${event.oldParent.path}/${event.file.name}" + val newPath = "${event.newParent.path}/${event.file.name}" + recordMoved(event.file, oldPath, newPath) + } + + is VFilePropertyChangeEvent -> { + if (event.propertyName == VirtualFile.PROP_NAME) { + val oldName = event.oldValue as? String ?: return@forEach + val newName = event.newValue as? String ?: return@forEach + val parentPath = event.file.parent?.path ?: return@forEach + recordMoved( + event.file, + "$parentPath/$oldName", + "$parentPath/$newName" + ) + } + } + } + } + } + + override fun after(events: List) { + if (isApplyingRollback || activeRuns.isEmpty()) return + events.forEach { event -> + if (event is VFileCreateEvent) { + recordCreated(event.path, event.isDirectory) + } + } + } + }) + } + + fun startSession(sessionId: String) { + val labelText = "ProxyAI: Agent run ${DateTimeFormatter.ISO_INSTANT.format(Instant.now())}" + val label = LocalHistory.getInstance().putSystemLabel(project, labelText) + activeRuns[sessionId] = RunTracker(sessionId, Instant.now(), labelText, label) + snapshots.remove(sessionId) + } + + fun finishSession(sessionId: String): RollbackSnapshot? { + val tracker = activeRuns.remove(sessionId) ?: return getSnapshot(sessionId) + val snapshot = SnapshotState( + sessionId = sessionId, + label = tracker.label, + labelRef = tracker.labelRef, + startedAt = tracker.startedAt, + completedAt = Instant.now(), + changes = tracker.changes.toMap() + ) + snapshots[sessionId] = snapshot + return snapshot.toSnapshot() + } + + fun getSnapshot(sessionId: String): RollbackSnapshot? = + snapshots[sessionId]?.toSnapshot() + + fun clearSnapshot(sessionId: String) { + snapshots.remove(sessionId) + } + + fun getDiffData(sessionId: String, path: String): RollbackDiffData? { + if (!isTrackable(path)) return null + val snapshot = snapshots[sessionId] ?: return null + val change = snapshot.changes[path] ?: return null + val beforeText = when (change.kind) { + ChangeKind.ADDED -> "" + else -> decodeLabelContent( + snapshot.labelRef, + change.originalPath ?: path, + change.originalContent + ) + } + val afterText = when (change.kind) { + ChangeKind.DELETED -> "" + else -> readCurrentText(path) + } + return RollbackDiffData( + path = path, + beforeText = beforeText, + afterText = if (change.kind == ChangeKind.DELETED) "" else afterText + ) + } + + fun isRollbackAvailable(sessionId: String): Boolean = + snapshots[sessionId]?.changes?.isNotEmpty() == true && !activeRuns.containsKey(sessionId) + + fun isDisplayable(path: String): Boolean = isTrackable(path) + + suspend fun rollbackFile(sessionId: String, path: String): RollbackResult = + withContext(Dispatchers.Main) { + val snapshot = snapshots[sessionId] + ?: return@withContext RollbackResult.Failure("No rollback snapshot available") + val change = snapshot.changes[path] + ?: return@withContext RollbackResult.Failure("No change tracked for $path") + + val errors = mutableListOf() + isApplyingRollback = true + try { + runWriteSafe("ProxyAI Rollback") { + applyChangeWithLabel(snapshot.labelRef, path, change, errors) + } + } finally { + isApplyingRollback = false + } + + if (errors.isNotEmpty()) { + RollbackResult.Failure(errors.joinToString("\n")) + } else { + val updated = snapshot.changes.toMutableMap() + updated.remove(path) + if (updated.isEmpty()) { + snapshots.remove(sessionId) + } else { + snapshots[sessionId] = snapshot.copy(changes = updated.toMap()) + } + RollbackResult.Success("Rollback completed") + } + } + + suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.Main) { + val snapshot = snapshots[sessionId] + ?: return@withContext RollbackResult.Failure("No rollback snapshot available") + + val errors = mutableListOf() + isApplyingRollback = true + try { + runWriteSafe("ProxyAI Rollback") { + snapshot.changes.forEach { (path, change) -> + applyChangeWithLabel(snapshot.labelRef, path, change, errors) + } + } + } finally { + isApplyingRollback = false + } + + if (errors.isNotEmpty()) { + RollbackResult.Failure(errors.joinToString("\n")) + } else { + snapshots.remove(sessionId) + RollbackResult.Success("Rollback completed") + } + } + + private fun recordModified(file: VirtualFile) { + if (!isTrackable(file)) return + activeRuns.values.forEach { it.recordModified(file) } + } + + private fun recordDeleted(file: VirtualFile) { + if (!isTrackable(file)) return + activeRuns.values.forEach { it.recordDeleted(file) } + } + + private fun recordCreated(path: String, isDirectory: Boolean) { + if (isDirectory || !isTrackable(path)) return + activeRuns.values.forEach { it.recordCreated(path) } + } + + private fun recordMoved(file: VirtualFile, oldPath: String, newPath: String) { + if (!isTrackable(file)) return + val oldNormalized = oldPath.replace("\\", "/") + val newNormalized = newPath.replace("\\", "/") + if (!isTrackable(oldNormalized) && !isTrackable(newNormalized)) return + activeRuns.values.forEach { it.recordMoved(file, oldNormalized, newNormalized) } + } + + private fun isTrackable(file: VirtualFile): Boolean { + if (file.isDirectory || !file.isValid) return false + if (FileTypeManager.getInstance().isFileIgnored(file)) return false + if (runReadAction { !ProjectFileIndex.getInstance(project).isInContent(file) }) return false + if (settingsService.isPathIgnored(file.path)) return false + return true + } + + private fun isTrackable(path: String): Boolean { + val file = LocalFileSystem.getInstance().findFileByPath(path) + if (file != null) return isTrackable(file) + + val basePath = project.basePath ?: return false + val normalized = FileUtil.toSystemIndependentName(path) + val normalizedBase = FileUtil.toSystemIndependentName(basePath) + if (!FileUtil.isAncestor( + normalizedBase, + normalized, + false + ) && normalized != normalizedBase + ) { + return false + } + val fileName = runCatching { Paths.get(normalized).fileName?.toString() } + .getOrNull() + ?: normalized.substringAfterLast('/') + if (FileTypeManager.getInstance().isFileIgnored(fileName)) { + return false + } + if (settingsService.isPathIgnored(normalized)) return false + return true + } + + private fun readContentSafe(file: VirtualFile): ByteArray? = + runCatching { file.contentsToByteArray() }.getOrNull() + + private fun decodeContent(content: ByteArray?): String { + if (content == null) return "" + return runCatching { String(content, Charsets.UTF_8) }.getOrDefault("") + } + + private fun decodeLabelContent(label: Label, path: String, fallback: ByteArray?): String { + val bytes = resolveLabelContent(label, path, fallback) ?: return "" + return decodeContent(bytes) + } + + private fun readCurrentText(path: String): String { + val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return "" + return runCatching { VfsUtilCore.loadText(vf) }.getOrDefault("") + } + + private fun deleteFile(path: String, errors: MutableList) { + val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return + if (!vf.exists()) return + runCatching { + runInEdt { + runWriteAction { + vf.delete(this) + } + } + }.onFailure { + errors.add("Failed to delete $path: ${it.message}") + } + } + + private fun restoreFile(path: String?, content: ByteArray?, errors: MutableList) { + if (path.isNullOrBlank()) return + if (content == null) { + errors.add("Missing original content for $path") + return + } + + val ioFile = File(path) + val parent = ioFile.parentFile + if (parent != null && !parent.exists()) { + runCatching { parent.mkdirs() }.onFailure { + errors.add("Failed to create parent directories for $path: ${it.message}") + return + } + } + + val file = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) + ?: runCatching { + val parentPath = parent?.path ?: run { + errors.add("Missing parent directory for $path") + return@runCatching null + } + val parentVf = VfsUtil.createDirectories(parentPath) + parentVf.createChildData(this, ioFile.name) + }.getOrNull() + + if (file == null) { + errors.add("Failed to recreate file $path") + return + } + + runCatching { + runInEdt { + runWriteAction { + file.setBinaryContent(content) + } + } + }.onFailure { + errors.add("Failed to restore $path: ${it.message}") + } + } + + private fun applyChangeWithLabel( + label: Label, + path: String, + change: TrackedChange, + errors: MutableList + ) { + val file = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) + when (change.kind) { + ChangeKind.ADDED -> { + if (file != null) { + val reverted = runCatching { label.revert(project, file) }.isSuccess + if (!reverted) deleteFile(path, errors) + } else { + deleteFile(path, errors) + } + } + + ChangeKind.MODIFIED -> { + if (file != null) { + val reverted = runCatching { label.revert(project, file) }.isSuccess + if (!reverted) { + val before = resolveLabelContent(label, path, change.originalContent) + restoreFile(path, before, errors) + } + } else { + val before = resolveLabelContent(label, path, change.originalContent) + restoreFile(path, before, errors) + } + } + + ChangeKind.DELETED -> { + val before = resolveLabelContent(label, path, change.originalContent) + restoreFile(path, before, errors) + } + + ChangeKind.MOVED -> { + deleteFile(path, errors) + val originalPath = change.originalPath + if (originalPath != null) { + val before = resolveLabelContent(label, originalPath, change.originalContent) + restoreFile(originalPath, before, errors) + } + } + } + } + + private fun resolveLabelContent( + label: Label, + path: String, + fallback: ByteArray? + ): ByteArray? { + return runCatching { label.getByteContent(path) } + .getOrNull()?.bytes ?: fallback + } + + private suspend fun runWriteSafe(actionName: String, block: () -> Unit) { + suspendCancellableCoroutine { cont -> + runInEdt { + try { + WriteCommandAction.runWriteCommandAction( + project, + actionName, + null, + { block() } + ) + if (cont.isActive) cont.resume(Unit) + } catch (e: Throwable) { + if (cont.isActive) cont.resumeWithException(e) + } + } + } + } + + companion object { + fun getInstance(project: Project): RollbackService { + return project.getService(RollbackService::class.java) + } + } + + private inner class RunTracker( + val sessionId: String, + val startedAt: Instant, + val label: String, + val labelRef: Label, + val changes: MutableMap = ConcurrentHashMap() + ) { + fun recordModified(file: VirtualFile) { + val path = file.path + val existing = changes[path] + if (existing?.kind == ChangeKind.ADDED) return + if (existing?.kind == ChangeKind.MOVED) return + if (existing?.kind == ChangeKind.MODIFIED) return + changes[path] = TrackedChange( + kind = ChangeKind.MODIFIED, + originalPath = null, + originalContent = readContentSafe(file) + ) + } + + fun recordDeleted(file: VirtualFile) { + val path = file.path + val existing = changes[path] + if (existing?.kind == ChangeKind.ADDED) { + changes.remove(path) + return + } + if (existing?.kind == ChangeKind.DELETED) return + + val original = existing?.originalContent ?: readContentSafe(file) + changes[path] = TrackedChange( + kind = ChangeKind.DELETED, + originalPath = null, + originalContent = original + ) + } + + fun recordCreated(path: String) { + val existing = changes[path] + if (existing?.kind == ChangeKind.DELETED) { + changes[path] = existing.copy(kind = ChangeKind.MODIFIED) + return + } + if (existing == null) { + changes[path] = TrackedChange( + kind = ChangeKind.ADDED, + originalPath = null, + originalContent = null + ) + } + } + + fun recordMoved(file: VirtualFile, oldPath: String, newPath: String) { + if (oldPath == newPath) return + val existing = changes.remove(oldPath) + + if (existing?.kind == ChangeKind.ADDED) { + changes[newPath] = existing.copy(kind = ChangeKind.ADDED) + return + } + + val content = existing?.originalContent ?: readContentSafe(file) + changes[newPath] = TrackedChange( + kind = ChangeKind.MOVED, + originalPath = oldPath, + originalContent = content + ) + } + } + + private data class TrackedChange( + val kind: ChangeKind, + val originalPath: String?, + val originalContent: ByteArray? + ) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as TrackedChange + + if (kind != other.kind) return false + if (originalPath != other.originalPath) return false + if (!originalContent.contentEquals(other.originalContent)) return false + + return true + } + + override fun hashCode(): Int { + var result = kind.hashCode() + result = 31 * result + (originalPath?.hashCode() ?: 0) + result = 31 * result + (originalContent?.contentHashCode() ?: 0) + return result + } + } + + private data class SnapshotState( + val sessionId: String, + val label: String, + val labelRef: Label, + val startedAt: Instant, + val completedAt: Instant, + val changes: Map + ) { + fun toSnapshot(): RollbackSnapshot? { + if (changes.isEmpty()) return null + return RollbackSnapshot( + sessionId = sessionId, + label = label, + startedAt = startedAt, + completedAt = completedAt, + changes = changes.map { (path, change) -> + FileChange( + path = path, + kind = change.kind, + originalPath = change.originalPath + ) + }.sortedBy { it.path } + ) + } + } +} + +data class RollbackSnapshot( + val sessionId: String, + val label: String, + val startedAt: Instant, + val completedAt: Instant, + val changes: List +) + +data class FileChange( + val path: String, + val kind: ChangeKind, + val originalPath: String? +) + +data class RollbackDiffData( + val path: String, + val beforeText: String, + val afterText: String +) + +enum class ChangeKind { + ADDED, + MODIFIED, + DELETED, + MOVED +} + +sealed class RollbackResult { + data class Success(val message: String) : RollbackResult() + data class Failure(val message: String) : RollbackResult() +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt index a9585442..cc0111d5 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentEventHandler.kt @@ -11,6 +11,7 @@ import com.intellij.openapi.diagnostic.thisLogger import com.intellij.openapi.project.Project import ee.carlrobert.codegpt.CodeGPTBundle import ee.carlrobert.codegpt.agent.* +import ee.carlrobert.codegpt.agent.rollback.RollbackService import ee.carlrobert.codegpt.agent.tools.* import ee.carlrobert.codegpt.conversations.message.TokenUsage import ee.carlrobert.codegpt.settings.service.ServiceType @@ -209,6 +210,7 @@ class AgentEventHandler( private fun handleDone() { runInEdt { + project.service().finishSession(sessionId) currentResponseBody?.finishThinking() onHideLoading() userInputPanel.setStopEnabled(false) @@ -416,16 +418,6 @@ class AgentEventHandler( toolName: String, result: Any? ) { - when (result) { - is TaskTool.InternalResult -> { - result.tokenUsage?.let { usage -> onTokenUsageUpdated(usage) } - } - - is TokenUsage -> { - onTokenUsageUpdated(result) - } - } - runInEdt { if (childId != null) { val holder = subagentViewHolders.values.find { viewHolder -> diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt index 75879513..e69990c5 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/AgentToolWindowTabPanel.kt @@ -16,6 +16,7 @@ import ee.carlrobert.codegpt.agent.AgentService import ee.carlrobert.codegpt.agent.AgentToolOutputNotifier import ee.carlrobert.codegpt.agent.MessageWithContext import ee.carlrobert.codegpt.agent.ToolRunContext +import ee.carlrobert.codegpt.agent.rollback.RollbackService import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.message.Message import ee.carlrobert.codegpt.conversations.message.QueuedMessage @@ -23,6 +24,7 @@ import ee.carlrobert.codegpt.psistructure.PsiStructureProvider import ee.carlrobert.codegpt.settings.service.FeatureType import ee.carlrobert.codegpt.settings.service.ModelSelectionService import ee.carlrobert.codegpt.toolwindow.agent.ui.AgentToolWindowLandingPanel +import ee.carlrobert.codegpt.toolwindow.agent.ui.RollbackPanel import ee.carlrobert.codegpt.toolwindow.agent.ui.TodoListPanel import ee.carlrobert.codegpt.toolwindow.chat.MessageBuilder import ee.carlrobert.codegpt.toolwindow.chat.editor.actions.CopyAction @@ -105,9 +107,11 @@ class AgentToolWindowTabPanel( agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId), sessionIdProvider = { sessionId } ) + private lateinit var rollbackPanel: RollbackPanel private val todoListPanel = TodoListPanel() private val projectMessageBusConnection = project.messageBus.connect() private val appMessageBusConnection = ApplicationManager.getApplication().messageBus.connect() + private val rollbackService = RollbackService.getInstance(project) private val agentEventHandler = AgentEventHandler( project = project, @@ -127,6 +131,7 @@ class AgentToolWindowTabPanel( loadingLabel.isVisible = false revalidate() repaint() + rollbackPanel.refreshOperations() }, onQueuedMessagesResolved = { message -> runInEdt { @@ -139,6 +144,9 @@ class AgentToolWindowTabPanel( init { setupMessageBusSubscriptions() + rollbackPanel = RollbackPanel(project, sessionId) { + rollbackPanel.refreshOperations() + } setupUI() if (conversation.messages.isEmpty()) { @@ -182,6 +190,9 @@ class AgentToolWindowTabPanel( isOpaque = false } + rollbackPanel.alignmentX = LEFT_ALIGNMENT + topContainer.add(rollbackPanel) + todoListPanel.alignmentX = LEFT_ALIGNMENT topContainer.add(todoListPanel) @@ -233,6 +244,9 @@ class AgentToolWindowTabPanel( .setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.RUNNING) } + rollbackService.startSession(sessionId) + rollbackPanel.refreshOperations() + val message = MessageWithContext(text, userInputPanel.getSelectedTags()) val messagePanel = scrollablePanel.addMessage(message.id) val userPanel = UserMessagePanel( @@ -275,6 +289,9 @@ class AgentToolWindowTabPanel( agentService.cancelCurrentRun(sessionId) agentService.clearPendingMessages(sessionId) + rollbackService.finishSession(sessionId) + rollbackPanel.refreshOperations() + approvalContainer.removeAll() clearQueuedMessages() approvalContainer.isVisible = false diff --git a/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt new file mode 100644 index 00000000..d2274437 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/toolwindow/agent/ui/RollbackPanel.kt @@ -0,0 +1,391 @@ +package ee.carlrobert.codegpt.toolwindow.agent.ui + +import com.intellij.diff.DiffContentFactory +import com.intellij.diff.DiffManager +import com.intellij.diff.requests.SimpleDiffRequest +import com.intellij.icons.AllIcons +import com.intellij.ide.actions.OpenFileAction +import com.intellij.openapi.actionSystem.AnAction +import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.application.EDT +import com.intellij.openapi.project.Project +import com.intellij.openapi.ui.Messages +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.ui.JBColor +import com.intellij.ui.components.ActionLink +import com.intellij.ui.components.JBLabel +import com.intellij.util.ui.JBUI +import com.intellij.util.ui.components.BorderLayoutPanel +import ee.carlrobert.codegpt.agent.rollback.* +import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.ChangeColors +import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.lineDiffStats +import ee.carlrobert.codegpt.ui.IconActionButton +import kotlinx.coroutines.* +import java.awt.FlowLayout +import java.awt.GridLayout +import java.time.Instant +import java.time.ZoneId +import java.time.format.DateTimeFormatter +import java.util.concurrent.ConcurrentHashMap +import javax.swing.* + +/** + * Panel showing file operations performed by the agent with rollback controls. + * + */ +class RollbackPanel( + private val project: Project, + private val sessionId: String, + private val onRollbackComplete: () -> Unit +) : BorderLayoutPanel() { + + private val rollbackService = RollbackService.getInstance(project) + private val titleLabel = JBLabel() + private val timeLabel = JBLabel() + private val summaryPanel = JPanel(FlowLayout(FlowLayout.RIGHT, 6, 2)) + private val changesPanel = JPanel() + private val footerPanel = JPanel(GridLayout(1, 2, 8, 0)) + private val diffStatsCache = ConcurrentHashMap>() + private val diffDataCache = ConcurrentHashMap() + private val backgroundScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + + init { + setupUI() + } + + private fun setupUI() { + val headerPanel = JPanel(FlowLayout(FlowLayout.LEFT, 6, 2)).apply { + isOpaque = false + add(JBLabel(AllIcons.Actions.Diff)) + add(JBLabel(AllIcons.General.Add)) + add(titleLabel.apply { + font = font.deriveFont(java.awt.Font.BOLD) + }) + add(timeLabel.apply { + foreground = JBUI.CurrentTheme.Label.disabledForeground() + }) + } + + summaryPanel.apply { + isOpaque = false + } + + val topPanel = BorderLayoutPanel().apply { + addToLeft(headerPanel) + addToRight(summaryPanel) + border = JBUI.Borders.empty(8) + } + + changesPanel.apply { + layout = BoxLayout(this, BoxLayout.Y_AXIS) + isOpaque = false + } + + footerPanel.apply { + isOpaque = true + border = JBUI.Borders.customLine(JBColor.border(), 1, 0, 0, 0) + add(createRollbackAllButton()) + add(createKeepButton()) + } + + addToTop(topPanel) + addToCenter( + BorderLayoutPanel().apply { + addToCenter(changesPanel) + addToBottom(footerPanel) + } + ) + border = JBUI.Borders.compound( + JBUI.Borders.customLine(JBColor.border(), 0, 0, 1, 0), + JBUI.Borders.empty(0, 0, 8, 0) + ) + + refreshOperations() + } + + fun refreshOperations() { + val snapshot = rollbackService.getSnapshot(sessionId) + val changes = snapshot?.changes.orEmpty() + .filter { rollbackService.isDisplayable(it.path) } + .sortedBy { it.path } + isVisible = changes.isNotEmpty() + if (changes.isEmpty()) { + titleLabel.text = "Agent changes" + timeLabel.text = "" + changesPanel.removeAll() + summaryPanel.removeAll() + footerPanel.isVisible = false + revalidate() + repaint() + return + } + + val timeText = snapshot?.completedAt?.let { formatTime(it) } ?: "" + titleLabel.text = "Agent changes (${changes.size})" + timeLabel.text = if (timeText.isNotBlank()) "• $timeText" else "" + + updateSummary(changes) + preloadDiffStats(changes) + + changesPanel.removeAll() + changes.forEachIndexed { index, change -> + if (index > 0) changesPanel.add(Box.createVerticalStrut(4)) + changesPanel.add(createChangeRow(change)) + } + footerPanel.isVisible = true + + revalidate() + repaint() + } + + private fun handleRollback() { + CoroutineScope(Dispatchers.Main).launch { + val result = rollbackService.rollbackSession(sessionId) + + withContext(Dispatchers.Main) { + when (result) { + is RollbackResult.Success -> { + refreshOperations() + onRollbackComplete() + } + + is RollbackResult.Failure -> { + Messages.showErrorDialog(project, result.message, "Rollback Failed") + } + } + } + } + } + + private fun createChangeRow(change: FileChange): JComponent { + val row = BorderLayoutPanel().apply { + isOpaque = true + background = JBUI.CurrentTheme.List.background(false, false) + border = JBUI.Borders.compound( + JBUI.Borders.customLine(JBColor.border(), 1), + JBUI.Borders.empty(4, 8) + ) + } + + val left = JPanel(FlowLayout(FlowLayout.LEFT, 6, 2)).apply { + isOpaque = false + } + left.add(changeLabel(change)) + left.add(fileNameComponent(change)) + left.add(filePathLabel(change)) + addDiffStats(change, left) + + val actions = JPanel(FlowLayout(FlowLayout.RIGHT, 6, 0)).apply { + isOpaque = false + } + openDiffAction(change).let { actions.add(it) } + actions.add(rollbackAction(change)) + + row.addToLeft(left) + row.addToRight(actions) + return row + } + + private fun changeLabel(change: FileChange): JBLabel { + val (text, color) = when (change.kind) { + ChangeKind.ADDED -> "+" to ChangeColors.inserted + ChangeKind.DELETED -> "-" to ChangeColors.deleted + ChangeKind.MODIFIED -> "~" to ChangeColors.modified + ChangeKind.MOVED -> "~" to ChangeColors.modified + } + return JBLabel(text).apply { + foreground = color + font = JBUI.Fonts.smallFont() + } + } + + private fun fileNameComponent(change: FileChange): JComponent { + val display = displayFileName(change.path) + val file = LocalFileSystem.getInstance().findFileByPath(change.path) + return if (file != null && change.kind != ChangeKind.DELETED) { + ActionLink(display) { + OpenFileAction.openFile(file, project) + }.apply { + font = JBUI.Fonts.label().asBold() + } + } else { + JBLabel(display).apply { + font = JBUI.Fonts.label().asBold() + } + } + } + + private fun filePathLabel(change: FileChange): JBLabel { + val display = displayPath(change.path, change) + return JBLabel(display).apply { + foreground = JBUI.CurrentTheme.Label.disabledForeground() + font = JBUI.Fonts.smallFont() + } + } + + private fun openDiffAction(change: FileChange): JComponent { + return IconActionButton( + object : AnAction("Open Diff", "Open diff view", AllIcons.Actions.Diff) { + override fun actionPerformed(e: AnActionEvent) { + openDiffForPath(change.path) + } + }, + "OPEN_DIFF" + ) + } + + private fun rollbackAction(change: FileChange): JComponent { + return IconActionButton( + object : AnAction("Rollback File", "Rollback this file", AllIcons.Actions.Undo) { + override fun actionPerformed(e: AnActionEvent) { + rollbackFile(change.path) + } + }, + "ROLLBACK_FILE" + ) + } + + private fun displayPath(path: String, change: FileChange): String { + val base = project.basePath?.replace("\\", "/") + val normalized = path.replace("\\", "/") + val relative = if (base != null && normalized.startsWith(base)) { + normalized.removePrefix(base).trimStart('/') + } else normalized + val original = change.originalPath?.replace("\\", "/") + return if (change.kind == ChangeKind.MOVED && original != null) { + val originalDisplay = if (base != null && original.startsWith(base)) { + original.removePrefix(base).trimStart('/') + } else original + "$relative (renamed from $originalDisplay)" + } else { + relative + } + } + + private fun displayFileName(path: String): String { + val normalized = path.replace("\\", "/") + return normalized.substringAfterLast('/') + } + + private fun addDiffStats(change: FileChange, container: JPanel) { + val stats = diffStatsCache[change.path] ?: return + val (ins, del, mod) = stats + if (ins + del + mod == 0) return + container.add(colorLabel("+$ins", ChangeColors.inserted)) + container.add(colorLabel("-$del", ChangeColors.deleted)) + container.add(colorLabel("~$mod", ChangeColors.modified)) + } + + private fun colorLabel(text: String, color: JBColor): JBLabel = + JBLabel(text).apply { + foreground = color + font = JBUI.Fonts.smallFont() + } + + private fun formatTime(instant: Instant): String { + val formatter = DateTimeFormatter.ofPattern("HH:mm").withZone(ZoneId.systemDefault()) + return "Last run at ${formatter.format(instant)}" + } + + private fun updateSummary(changes: List) { + val added = changes.count { it.kind == ChangeKind.ADDED } + val deleted = changes.count { it.kind == ChangeKind.DELETED } + val modified = + changes.count { it.kind == ChangeKind.MODIFIED || it.kind == ChangeKind.MOVED } + summaryPanel.removeAll() + if (added > 0) summaryPanel.add(colorChip("+$added", ChangeColors.inserted)) + if (deleted > 0) summaryPanel.add(colorChip("-$deleted", ChangeColors.deleted)) + if (modified > 0) summaryPanel.add(colorChip("~$modified", ChangeColors.modified)) + summaryPanel.revalidate() + summaryPanel.repaint() + } + + private fun colorChip(text: String, color: JBColor): JBLabel = + JBLabel(text).apply { + foreground = color + font = JBUI.Fonts.smallFont() + } + + private fun createRollbackAllButton(): JComponent { + return JButton("Rollback all").apply { + isFocusable = false + isContentAreaFilled = true + isOpaque = true + putClientProperty("JButton.buttonType", "roundRect") + margin = JBUI.insets(6, 12) + addActionListener { handleRollback() } + } + } + + private fun createKeepButton(): JComponent { + return JButton("Keep changes").apply { + isFocusable = false + isContentAreaFilled = true + isOpaque = true + putClientProperty("JButton.buttonType", "roundRect") + margin = JBUI.insets(6, 12) + addActionListener { + rollbackService.clearSnapshot(sessionId) + refreshOperations() + } + } + } + + private fun rollbackFile(path: String) { + CoroutineScope(Dispatchers.Main).launch { + val result = rollbackService.rollbackFile(sessionId, path) + + withContext(Dispatchers.Main) { + when (result) { + is RollbackResult.Success -> { + refreshOperations() + onRollbackComplete() + } + + is RollbackResult.Failure -> { + Messages.showErrorDialog(project, result.message, "Rollback Failed") + } + } + } + } + } + + private fun preloadDiffStats(changes: List) { + changes.forEach { change -> + if (diffStatsCache.containsKey(change.path)) return@forEach + backgroundScope.launch { + val diffData = rollbackService.getDiffData(sessionId, change.path) ?: return@launch + diffDataCache[change.path] = diffData + diffStatsCache[change.path] = + lineDiffStats(diffData.beforeText, diffData.afterText) + withContext(Dispatchers.EDT) { + refreshOperations() + } + } + } + } + + private fun openDiffForPath(path: String) { + val cached = diffDataCache[path] + if (cached != null) { + openDiff(cached) + return + } + backgroundScope.launch { + val diffData = rollbackService.getDiffData(sessionId, path) ?: return@launch + diffDataCache[path] = diffData + withContext(Dispatchers.Main) { + openDiff(diffData) + } + } + } + + private fun openDiff(diffData: RollbackDiffData) { + val contentFactory = DiffContentFactory.getInstance() + val before = contentFactory.create(diffData.beforeText) + val after = contentFactory.create(diffData.afterText) + val title = "Agent change: ${displayFileName(diffData.path)}" + val request = SimpleDiffRequest(title, before, after, "Before", "After") + DiffManager.getInstance().showDiff(project, request) + } +}