mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 07:54:46 +00:00
feat: add rollback tracking and UI for agent runs
This commit is contained in:
parent
88468e1c98
commit
e558a27ada
4 changed files with 970 additions and 10 deletions
|
|
@ -0,0 +1,560 @@
|
|||
package ee.carlrobert.codegpt.agent.rollback
|
||||
|
||||
import com.intellij.history.Label
|
||||
import com.intellij.history.LocalHistory
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
import com.intellij.openapi.application.runReadAction
|
||||
import com.intellij.openapi.application.runWriteAction
|
||||
import com.intellij.openapi.command.WriteCommandAction
|
||||
import com.intellij.openapi.components.Service
|
||||
import com.intellij.openapi.fileTypes.FileTypeManager
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.roots.ProjectFileIndex
|
||||
import com.intellij.openapi.util.io.FileUtil
|
||||
import com.intellij.openapi.vfs.*
|
||||
import com.intellij.openapi.vfs.newvfs.BulkFileListener
|
||||
import com.intellij.openapi.vfs.newvfs.events.*
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.suspendCancellableCoroutine
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.nio.file.Paths
|
||||
import java.time.Instant
|
||||
import java.time.format.DateTimeFormatter
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import kotlin.coroutines.resume
|
||||
import kotlin.coroutines.resumeWithException
|
||||
|
||||
/**
|
||||
* Tracks file changes during agent runs and provides rollback to a checkpoint.
|
||||
*
|
||||
* Uses IntelliJ's LocalHistory labels for persistence and a lightweight in-memory
|
||||
* snapshot of modified files for fast restore.
|
||||
*/
|
||||
@Service(Service.Level.PROJECT)
|
||||
class RollbackService(private val project: Project) {
|
||||
|
||||
private val activeRuns = ConcurrentHashMap<String, RunTracker>()
|
||||
private val snapshots = ConcurrentHashMap<String, SnapshotState>()
|
||||
private val settingsService = project.getService(ProxyAISettingsService::class.java)
|
||||
|
||||
@Volatile
|
||||
private var isApplyingRollback = false
|
||||
|
||||
init {
|
||||
val connection = project.messageBus.connect()
|
||||
connection.subscribe(VirtualFileManager.VFS_CHANGES, object : BulkFileListener {
|
||||
override fun before(events: List<VFileEvent>) {
|
||||
if (isApplyingRollback || activeRuns.isEmpty()) return
|
||||
events.forEach { event ->
|
||||
when (event) {
|
||||
is VFileContentChangeEvent -> recordModified(event.file)
|
||||
is VFileDeleteEvent -> recordDeleted(event.file)
|
||||
is VFileMoveEvent -> {
|
||||
val oldPath = "${event.oldParent.path}/${event.file.name}"
|
||||
val newPath = "${event.newParent.path}/${event.file.name}"
|
||||
recordMoved(event.file, oldPath, newPath)
|
||||
}
|
||||
|
||||
is VFilePropertyChangeEvent -> {
|
||||
if (event.propertyName == VirtualFile.PROP_NAME) {
|
||||
val oldName = event.oldValue as? String ?: return@forEach
|
||||
val newName = event.newValue as? String ?: return@forEach
|
||||
val parentPath = event.file.parent?.path ?: return@forEach
|
||||
recordMoved(
|
||||
event.file,
|
||||
"$parentPath/$oldName",
|
||||
"$parentPath/$newName"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun after(events: List<VFileEvent>) {
|
||||
if (isApplyingRollback || activeRuns.isEmpty()) return
|
||||
events.forEach { event ->
|
||||
if (event is VFileCreateEvent) {
|
||||
recordCreated(event.path, event.isDirectory)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fun startSession(sessionId: String) {
|
||||
val labelText = "ProxyAI: Agent run ${DateTimeFormatter.ISO_INSTANT.format(Instant.now())}"
|
||||
val label = LocalHistory.getInstance().putSystemLabel(project, labelText)
|
||||
activeRuns[sessionId] = RunTracker(sessionId, Instant.now(), labelText, label)
|
||||
snapshots.remove(sessionId)
|
||||
}
|
||||
|
||||
fun finishSession(sessionId: String): RollbackSnapshot? {
|
||||
val tracker = activeRuns.remove(sessionId) ?: return getSnapshot(sessionId)
|
||||
val snapshot = SnapshotState(
|
||||
sessionId = sessionId,
|
||||
label = tracker.label,
|
||||
labelRef = tracker.labelRef,
|
||||
startedAt = tracker.startedAt,
|
||||
completedAt = Instant.now(),
|
||||
changes = tracker.changes.toMap()
|
||||
)
|
||||
snapshots[sessionId] = snapshot
|
||||
return snapshot.toSnapshot()
|
||||
}
|
||||
|
||||
fun getSnapshot(sessionId: String): RollbackSnapshot? =
|
||||
snapshots[sessionId]?.toSnapshot()
|
||||
|
||||
fun clearSnapshot(sessionId: String) {
|
||||
snapshots.remove(sessionId)
|
||||
}
|
||||
|
||||
fun getDiffData(sessionId: String, path: String): RollbackDiffData? {
|
||||
if (!isTrackable(path)) return null
|
||||
val snapshot = snapshots[sessionId] ?: return null
|
||||
val change = snapshot.changes[path] ?: return null
|
||||
val beforeText = when (change.kind) {
|
||||
ChangeKind.ADDED -> ""
|
||||
else -> decodeLabelContent(
|
||||
snapshot.labelRef,
|
||||
change.originalPath ?: path,
|
||||
change.originalContent
|
||||
)
|
||||
}
|
||||
val afterText = when (change.kind) {
|
||||
ChangeKind.DELETED -> ""
|
||||
else -> readCurrentText(path)
|
||||
}
|
||||
return RollbackDiffData(
|
||||
path = path,
|
||||
beforeText = beforeText,
|
||||
afterText = if (change.kind == ChangeKind.DELETED) "" else afterText
|
||||
)
|
||||
}
|
||||
|
||||
fun isRollbackAvailable(sessionId: String): Boolean =
|
||||
snapshots[sessionId]?.changes?.isNotEmpty() == true && !activeRuns.containsKey(sessionId)
|
||||
|
||||
fun isDisplayable(path: String): Boolean = isTrackable(path)
|
||||
|
||||
suspend fun rollbackFile(sessionId: String, path: String): RollbackResult =
|
||||
withContext(Dispatchers.Main) {
|
||||
val snapshot = snapshots[sessionId]
|
||||
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
|
||||
val change = snapshot.changes[path]
|
||||
?: return@withContext RollbackResult.Failure("No change tracked for $path")
|
||||
|
||||
val errors = mutableListOf<String>()
|
||||
isApplyingRollback = true
|
||||
try {
|
||||
runWriteSafe("ProxyAI Rollback") {
|
||||
applyChangeWithLabel(snapshot.labelRef, path, change, errors)
|
||||
}
|
||||
} finally {
|
||||
isApplyingRollback = false
|
||||
}
|
||||
|
||||
if (errors.isNotEmpty()) {
|
||||
RollbackResult.Failure(errors.joinToString("\n"))
|
||||
} else {
|
||||
val updated = snapshot.changes.toMutableMap()
|
||||
updated.remove(path)
|
||||
if (updated.isEmpty()) {
|
||||
snapshots.remove(sessionId)
|
||||
} else {
|
||||
snapshots[sessionId] = snapshot.copy(changes = updated.toMap())
|
||||
}
|
||||
RollbackResult.Success("Rollback completed")
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.Main) {
|
||||
val snapshot = snapshots[sessionId]
|
||||
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
|
||||
|
||||
val errors = mutableListOf<String>()
|
||||
isApplyingRollback = true
|
||||
try {
|
||||
runWriteSafe("ProxyAI Rollback") {
|
||||
snapshot.changes.forEach { (path, change) ->
|
||||
applyChangeWithLabel(snapshot.labelRef, path, change, errors)
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
isApplyingRollback = false
|
||||
}
|
||||
|
||||
if (errors.isNotEmpty()) {
|
||||
RollbackResult.Failure(errors.joinToString("\n"))
|
||||
} else {
|
||||
snapshots.remove(sessionId)
|
||||
RollbackResult.Success("Rollback completed")
|
||||
}
|
||||
}
|
||||
|
||||
private fun recordModified(file: VirtualFile) {
|
||||
if (!isTrackable(file)) return
|
||||
activeRuns.values.forEach { it.recordModified(file) }
|
||||
}
|
||||
|
||||
private fun recordDeleted(file: VirtualFile) {
|
||||
if (!isTrackable(file)) return
|
||||
activeRuns.values.forEach { it.recordDeleted(file) }
|
||||
}
|
||||
|
||||
private fun recordCreated(path: String, isDirectory: Boolean) {
|
||||
if (isDirectory || !isTrackable(path)) return
|
||||
activeRuns.values.forEach { it.recordCreated(path) }
|
||||
}
|
||||
|
||||
private fun recordMoved(file: VirtualFile, oldPath: String, newPath: String) {
|
||||
if (!isTrackable(file)) return
|
||||
val oldNormalized = oldPath.replace("\\", "/")
|
||||
val newNormalized = newPath.replace("\\", "/")
|
||||
if (!isTrackable(oldNormalized) && !isTrackable(newNormalized)) return
|
||||
activeRuns.values.forEach { it.recordMoved(file, oldNormalized, newNormalized) }
|
||||
}
|
||||
|
||||
private fun isTrackable(file: VirtualFile): Boolean {
|
||||
if (file.isDirectory || !file.isValid) return false
|
||||
if (FileTypeManager.getInstance().isFileIgnored(file)) return false
|
||||
if (runReadAction { !ProjectFileIndex.getInstance(project).isInContent(file) }) return false
|
||||
if (settingsService.isPathIgnored(file.path)) return false
|
||||
return true
|
||||
}
|
||||
|
||||
private fun isTrackable(path: String): Boolean {
|
||||
val file = LocalFileSystem.getInstance().findFileByPath(path)
|
||||
if (file != null) return isTrackable(file)
|
||||
|
||||
val basePath = project.basePath ?: return false
|
||||
val normalized = FileUtil.toSystemIndependentName(path)
|
||||
val normalizedBase = FileUtil.toSystemIndependentName(basePath)
|
||||
if (!FileUtil.isAncestor(
|
||||
normalizedBase,
|
||||
normalized,
|
||||
false
|
||||
) && normalized != normalizedBase
|
||||
) {
|
||||
return false
|
||||
}
|
||||
val fileName = runCatching { Paths.get(normalized).fileName?.toString() }
|
||||
.getOrNull()
|
||||
?: normalized.substringAfterLast('/')
|
||||
if (FileTypeManager.getInstance().isFileIgnored(fileName)) {
|
||||
return false
|
||||
}
|
||||
if (settingsService.isPathIgnored(normalized)) return false
|
||||
return true
|
||||
}
|
||||
|
||||
private fun readContentSafe(file: VirtualFile): ByteArray? =
|
||||
runCatching { file.contentsToByteArray() }.getOrNull()
|
||||
|
||||
private fun decodeContent(content: ByteArray?): String {
|
||||
if (content == null) return ""
|
||||
return runCatching { String(content, Charsets.UTF_8) }.getOrDefault("")
|
||||
}
|
||||
|
||||
private fun decodeLabelContent(label: Label, path: String, fallback: ByteArray?): String {
|
||||
val bytes = resolveLabelContent(label, path, fallback) ?: return ""
|
||||
return decodeContent(bytes)
|
||||
}
|
||||
|
||||
private fun readCurrentText(path: String): String {
|
||||
val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return ""
|
||||
return runCatching { VfsUtilCore.loadText(vf) }.getOrDefault("")
|
||||
}
|
||||
|
||||
private fun deleteFile(path: String, errors: MutableList<String>) {
|
||||
val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return
|
||||
if (!vf.exists()) return
|
||||
runCatching {
|
||||
runInEdt {
|
||||
runWriteAction {
|
||||
vf.delete(this)
|
||||
}
|
||||
}
|
||||
}.onFailure {
|
||||
errors.add("Failed to delete $path: ${it.message}")
|
||||
}
|
||||
}
|
||||
|
||||
private fun restoreFile(path: String?, content: ByteArray?, errors: MutableList<String>) {
|
||||
if (path.isNullOrBlank()) return
|
||||
if (content == null) {
|
||||
errors.add("Missing original content for $path")
|
||||
return
|
||||
}
|
||||
|
||||
val ioFile = File(path)
|
||||
val parent = ioFile.parentFile
|
||||
if (parent != null && !parent.exists()) {
|
||||
runCatching { parent.mkdirs() }.onFailure {
|
||||
errors.add("Failed to create parent directories for $path: ${it.message}")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
val file = LocalFileSystem.getInstance().refreshAndFindFileByPath(path)
|
||||
?: runCatching {
|
||||
val parentPath = parent?.path ?: run {
|
||||
errors.add("Missing parent directory for $path")
|
||||
return@runCatching null
|
||||
}
|
||||
val parentVf = VfsUtil.createDirectories(parentPath)
|
||||
parentVf.createChildData(this, ioFile.name)
|
||||
}.getOrNull()
|
||||
|
||||
if (file == null) {
|
||||
errors.add("Failed to recreate file $path")
|
||||
return
|
||||
}
|
||||
|
||||
runCatching {
|
||||
runInEdt {
|
||||
runWriteAction {
|
||||
file.setBinaryContent(content)
|
||||
}
|
||||
}
|
||||
}.onFailure {
|
||||
errors.add("Failed to restore $path: ${it.message}")
|
||||
}
|
||||
}
|
||||
|
||||
private fun applyChangeWithLabel(
|
||||
label: Label,
|
||||
path: String,
|
||||
change: TrackedChange,
|
||||
errors: MutableList<String>
|
||||
) {
|
||||
val file = LocalFileSystem.getInstance().refreshAndFindFileByPath(path)
|
||||
when (change.kind) {
|
||||
ChangeKind.ADDED -> {
|
||||
if (file != null) {
|
||||
val reverted = runCatching { label.revert(project, file) }.isSuccess
|
||||
if (!reverted) deleteFile(path, errors)
|
||||
} else {
|
||||
deleteFile(path, errors)
|
||||
}
|
||||
}
|
||||
|
||||
ChangeKind.MODIFIED -> {
|
||||
if (file != null) {
|
||||
val reverted = runCatching { label.revert(project, file) }.isSuccess
|
||||
if (!reverted) {
|
||||
val before = resolveLabelContent(label, path, change.originalContent)
|
||||
restoreFile(path, before, errors)
|
||||
}
|
||||
} else {
|
||||
val before = resolveLabelContent(label, path, change.originalContent)
|
||||
restoreFile(path, before, errors)
|
||||
}
|
||||
}
|
||||
|
||||
ChangeKind.DELETED -> {
|
||||
val before = resolveLabelContent(label, path, change.originalContent)
|
||||
restoreFile(path, before, errors)
|
||||
}
|
||||
|
||||
ChangeKind.MOVED -> {
|
||||
deleteFile(path, errors)
|
||||
val originalPath = change.originalPath
|
||||
if (originalPath != null) {
|
||||
val before = resolveLabelContent(label, originalPath, change.originalContent)
|
||||
restoreFile(originalPath, before, errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun resolveLabelContent(
|
||||
label: Label,
|
||||
path: String,
|
||||
fallback: ByteArray?
|
||||
): ByteArray? {
|
||||
return runCatching { label.getByteContent(path) }
|
||||
.getOrNull()?.bytes ?: fallback
|
||||
}
|
||||
|
||||
private suspend fun runWriteSafe(actionName: String, block: () -> Unit) {
|
||||
suspendCancellableCoroutine { cont ->
|
||||
runInEdt {
|
||||
try {
|
||||
WriteCommandAction.runWriteCommandAction(
|
||||
project,
|
||||
actionName,
|
||||
null,
|
||||
{ block() }
|
||||
)
|
||||
if (cont.isActive) cont.resume(Unit)
|
||||
} catch (e: Throwable) {
|
||||
if (cont.isActive) cont.resumeWithException(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun getInstance(project: Project): RollbackService {
|
||||
return project.getService(RollbackService::class.java)
|
||||
}
|
||||
}
|
||||
|
||||
private inner class RunTracker(
|
||||
val sessionId: String,
|
||||
val startedAt: Instant,
|
||||
val label: String,
|
||||
val labelRef: Label,
|
||||
val changes: MutableMap<String, TrackedChange> = ConcurrentHashMap()
|
||||
) {
|
||||
fun recordModified(file: VirtualFile) {
|
||||
val path = file.path
|
||||
val existing = changes[path]
|
||||
if (existing?.kind == ChangeKind.ADDED) return
|
||||
if (existing?.kind == ChangeKind.MOVED) return
|
||||
if (existing?.kind == ChangeKind.MODIFIED) return
|
||||
changes[path] = TrackedChange(
|
||||
kind = ChangeKind.MODIFIED,
|
||||
originalPath = null,
|
||||
originalContent = readContentSafe(file)
|
||||
)
|
||||
}
|
||||
|
||||
fun recordDeleted(file: VirtualFile) {
|
||||
val path = file.path
|
||||
val existing = changes[path]
|
||||
if (existing?.kind == ChangeKind.ADDED) {
|
||||
changes.remove(path)
|
||||
return
|
||||
}
|
||||
if (existing?.kind == ChangeKind.DELETED) return
|
||||
|
||||
val original = existing?.originalContent ?: readContentSafe(file)
|
||||
changes[path] = TrackedChange(
|
||||
kind = ChangeKind.DELETED,
|
||||
originalPath = null,
|
||||
originalContent = original
|
||||
)
|
||||
}
|
||||
|
||||
fun recordCreated(path: String) {
|
||||
val existing = changes[path]
|
||||
if (existing?.kind == ChangeKind.DELETED) {
|
||||
changes[path] = existing.copy(kind = ChangeKind.MODIFIED)
|
||||
return
|
||||
}
|
||||
if (existing == null) {
|
||||
changes[path] = TrackedChange(
|
||||
kind = ChangeKind.ADDED,
|
||||
originalPath = null,
|
||||
originalContent = null
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun recordMoved(file: VirtualFile, oldPath: String, newPath: String) {
|
||||
if (oldPath == newPath) return
|
||||
val existing = changes.remove(oldPath)
|
||||
|
||||
if (existing?.kind == ChangeKind.ADDED) {
|
||||
changes[newPath] = existing.copy(kind = ChangeKind.ADDED)
|
||||
return
|
||||
}
|
||||
|
||||
val content = existing?.originalContent ?: readContentSafe(file)
|
||||
changes[newPath] = TrackedChange(
|
||||
kind = ChangeKind.MOVED,
|
||||
originalPath = oldPath,
|
||||
originalContent = content
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private data class TrackedChange(
|
||||
val kind: ChangeKind,
|
||||
val originalPath: String?,
|
||||
val originalContent: ByteArray?
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (javaClass != other?.javaClass) return false
|
||||
|
||||
other as TrackedChange
|
||||
|
||||
if (kind != other.kind) return false
|
||||
if (originalPath != other.originalPath) return false
|
||||
if (!originalContent.contentEquals(other.originalContent)) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = kind.hashCode()
|
||||
result = 31 * result + (originalPath?.hashCode() ?: 0)
|
||||
result = 31 * result + (originalContent?.contentHashCode() ?: 0)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
private data class SnapshotState(
|
||||
val sessionId: String,
|
||||
val label: String,
|
||||
val labelRef: Label,
|
||||
val startedAt: Instant,
|
||||
val completedAt: Instant,
|
||||
val changes: Map<String, TrackedChange>
|
||||
) {
|
||||
fun toSnapshot(): RollbackSnapshot? {
|
||||
if (changes.isEmpty()) return null
|
||||
return RollbackSnapshot(
|
||||
sessionId = sessionId,
|
||||
label = label,
|
||||
startedAt = startedAt,
|
||||
completedAt = completedAt,
|
||||
changes = changes.map { (path, change) ->
|
||||
FileChange(
|
||||
path = path,
|
||||
kind = change.kind,
|
||||
originalPath = change.originalPath
|
||||
)
|
||||
}.sortedBy { it.path }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data class RollbackSnapshot(
|
||||
val sessionId: String,
|
||||
val label: String,
|
||||
val startedAt: Instant,
|
||||
val completedAt: Instant,
|
||||
val changes: List<FileChange>
|
||||
)
|
||||
|
||||
data class FileChange(
|
||||
val path: String,
|
||||
val kind: ChangeKind,
|
||||
val originalPath: String?
|
||||
)
|
||||
|
||||
data class RollbackDiffData(
|
||||
val path: String,
|
||||
val beforeText: String,
|
||||
val afterText: String
|
||||
)
|
||||
|
||||
enum class ChangeKind {
|
||||
ADDED,
|
||||
MODIFIED,
|
||||
DELETED,
|
||||
MOVED
|
||||
}
|
||||
|
||||
sealed class RollbackResult {
|
||||
data class Success(val message: String) : RollbackResult()
|
||||
data class Failure(val message: String) : RollbackResult()
|
||||
}
|
||||
|
|
@ -11,6 +11,7 @@ import com.intellij.openapi.diagnostic.thisLogger
|
|||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.CodeGPTBundle
|
||||
import ee.carlrobert.codegpt.agent.*
|
||||
import ee.carlrobert.codegpt.agent.rollback.RollbackService
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import ee.carlrobert.codegpt.conversations.message.TokenUsage
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
|
|
@ -209,6 +210,7 @@ class AgentEventHandler(
|
|||
|
||||
private fun handleDone() {
|
||||
runInEdt {
|
||||
project.service<RollbackService>().finishSession(sessionId)
|
||||
currentResponseBody?.finishThinking()
|
||||
onHideLoading()
|
||||
userInputPanel.setStopEnabled(false)
|
||||
|
|
@ -416,16 +418,6 @@ class AgentEventHandler(
|
|||
toolName: String,
|
||||
result: Any?
|
||||
) {
|
||||
when (result) {
|
||||
is TaskTool.InternalResult -> {
|
||||
result.tokenUsage?.let { usage -> onTokenUsageUpdated(usage) }
|
||||
}
|
||||
|
||||
is TokenUsage -> {
|
||||
onTokenUsageUpdated(result)
|
||||
}
|
||||
}
|
||||
|
||||
runInEdt {
|
||||
if (childId != null) {
|
||||
val holder = subagentViewHolders.values.find { viewHolder ->
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import ee.carlrobert.codegpt.agent.AgentService
|
|||
import ee.carlrobert.codegpt.agent.AgentToolOutputNotifier
|
||||
import ee.carlrobert.codegpt.agent.MessageWithContext
|
||||
import ee.carlrobert.codegpt.agent.ToolRunContext
|
||||
import ee.carlrobert.codegpt.agent.rollback.RollbackService
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.conversations.message.QueuedMessage
|
||||
|
|
@ -23,6 +24,7 @@ import ee.carlrobert.codegpt.psistructure.PsiStructureProvider
|
|||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.AgentToolWindowLandingPanel
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.RollbackPanel
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.TodoListPanel
|
||||
import ee.carlrobert.codegpt.toolwindow.chat.MessageBuilder
|
||||
import ee.carlrobert.codegpt.toolwindow.chat.editor.actions.CopyAction
|
||||
|
|
@ -105,9 +107,11 @@ class AgentToolWindowTabPanel(
|
|||
agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId),
|
||||
sessionIdProvider = { sessionId }
|
||||
)
|
||||
private lateinit var rollbackPanel: RollbackPanel
|
||||
private val todoListPanel = TodoListPanel()
|
||||
private val projectMessageBusConnection = project.messageBus.connect()
|
||||
private val appMessageBusConnection = ApplicationManager.getApplication().messageBus.connect()
|
||||
private val rollbackService = RollbackService.getInstance(project)
|
||||
|
||||
private val agentEventHandler = AgentEventHandler(
|
||||
project = project,
|
||||
|
|
@ -127,6 +131,7 @@ class AgentToolWindowTabPanel(
|
|||
loadingLabel.isVisible = false
|
||||
revalidate()
|
||||
repaint()
|
||||
rollbackPanel.refreshOperations()
|
||||
},
|
||||
onQueuedMessagesResolved = { message ->
|
||||
runInEdt {
|
||||
|
|
@ -139,6 +144,9 @@ class AgentToolWindowTabPanel(
|
|||
|
||||
init {
|
||||
setupMessageBusSubscriptions()
|
||||
rollbackPanel = RollbackPanel(project, sessionId) {
|
||||
rollbackPanel.refreshOperations()
|
||||
}
|
||||
setupUI()
|
||||
|
||||
if (conversation.messages.isEmpty()) {
|
||||
|
|
@ -182,6 +190,9 @@ class AgentToolWindowTabPanel(
|
|||
isOpaque = false
|
||||
}
|
||||
|
||||
rollbackPanel.alignmentX = LEFT_ALIGNMENT
|
||||
topContainer.add(rollbackPanel)
|
||||
|
||||
todoListPanel.alignmentX = LEFT_ALIGNMENT
|
||||
topContainer.add(todoListPanel)
|
||||
|
||||
|
|
@ -233,6 +244,9 @@ class AgentToolWindowTabPanel(
|
|||
.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.RUNNING)
|
||||
}
|
||||
|
||||
rollbackService.startSession(sessionId)
|
||||
rollbackPanel.refreshOperations()
|
||||
|
||||
val message = MessageWithContext(text, userInputPanel.getSelectedTags())
|
||||
val messagePanel = scrollablePanel.addMessage(message.id)
|
||||
val userPanel = UserMessagePanel(
|
||||
|
|
@ -275,6 +289,9 @@ class AgentToolWindowTabPanel(
|
|||
agentService.cancelCurrentRun(sessionId)
|
||||
agentService.clearPendingMessages(sessionId)
|
||||
|
||||
rollbackService.finishSession(sessionId)
|
||||
rollbackPanel.refreshOperations()
|
||||
|
||||
approvalContainer.removeAll()
|
||||
clearQueuedMessages()
|
||||
approvalContainer.isVisible = false
|
||||
|
|
|
|||
|
|
@ -0,0 +1,391 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent.ui
|
||||
|
||||
import com.intellij.diff.DiffContentFactory
|
||||
import com.intellij.diff.DiffManager
|
||||
import com.intellij.diff.requests.SimpleDiffRequest
|
||||
import com.intellij.icons.AllIcons
|
||||
import com.intellij.ide.actions.OpenFileAction
|
||||
import com.intellij.openapi.actionSystem.AnAction
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
import com.intellij.openapi.application.EDT
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.ui.Messages
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.ui.JBColor
|
||||
import com.intellij.ui.components.ActionLink
|
||||
import com.intellij.ui.components.JBLabel
|
||||
import com.intellij.util.ui.JBUI
|
||||
import com.intellij.util.ui.components.BorderLayoutPanel
|
||||
import ee.carlrobert.codegpt.agent.rollback.*
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.ChangeColors
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.lineDiffStats
|
||||
import ee.carlrobert.codegpt.ui.IconActionButton
|
||||
import kotlinx.coroutines.*
|
||||
import java.awt.FlowLayout
|
||||
import java.awt.GridLayout
|
||||
import java.time.Instant
|
||||
import java.time.ZoneId
|
||||
import java.time.format.DateTimeFormatter
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import javax.swing.*
|
||||
|
||||
/**
|
||||
* Panel showing file operations performed by the agent with rollback controls.
|
||||
*
|
||||
*/
|
||||
class RollbackPanel(
|
||||
private val project: Project,
|
||||
private val sessionId: String,
|
||||
private val onRollbackComplete: () -> Unit
|
||||
) : BorderLayoutPanel() {
|
||||
|
||||
private val rollbackService = RollbackService.getInstance(project)
|
||||
private val titleLabel = JBLabel()
|
||||
private val timeLabel = JBLabel()
|
||||
private val summaryPanel = JPanel(FlowLayout(FlowLayout.RIGHT, 6, 2))
|
||||
private val changesPanel = JPanel()
|
||||
private val footerPanel = JPanel(GridLayout(1, 2, 8, 0))
|
||||
private val diffStatsCache = ConcurrentHashMap<String, Triple<Int, Int, Int>>()
|
||||
private val diffDataCache = ConcurrentHashMap<String, RollbackDiffData>()
|
||||
private val backgroundScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
|
||||
init {
|
||||
setupUI()
|
||||
}
|
||||
|
||||
private fun setupUI() {
|
||||
val headerPanel = JPanel(FlowLayout(FlowLayout.LEFT, 6, 2)).apply {
|
||||
isOpaque = false
|
||||
add(JBLabel(AllIcons.Actions.Diff))
|
||||
add(JBLabel(AllIcons.General.Add))
|
||||
add(titleLabel.apply {
|
||||
font = font.deriveFont(java.awt.Font.BOLD)
|
||||
})
|
||||
add(timeLabel.apply {
|
||||
foreground = JBUI.CurrentTheme.Label.disabledForeground()
|
||||
})
|
||||
}
|
||||
|
||||
summaryPanel.apply {
|
||||
isOpaque = false
|
||||
}
|
||||
|
||||
val topPanel = BorderLayoutPanel().apply {
|
||||
addToLeft(headerPanel)
|
||||
addToRight(summaryPanel)
|
||||
border = JBUI.Borders.empty(8)
|
||||
}
|
||||
|
||||
changesPanel.apply {
|
||||
layout = BoxLayout(this, BoxLayout.Y_AXIS)
|
||||
isOpaque = false
|
||||
}
|
||||
|
||||
footerPanel.apply {
|
||||
isOpaque = true
|
||||
border = JBUI.Borders.customLine(JBColor.border(), 1, 0, 0, 0)
|
||||
add(createRollbackAllButton())
|
||||
add(createKeepButton())
|
||||
}
|
||||
|
||||
addToTop(topPanel)
|
||||
addToCenter(
|
||||
BorderLayoutPanel().apply {
|
||||
addToCenter(changesPanel)
|
||||
addToBottom(footerPanel)
|
||||
}
|
||||
)
|
||||
border = JBUI.Borders.compound(
|
||||
JBUI.Borders.customLine(JBColor.border(), 0, 0, 1, 0),
|
||||
JBUI.Borders.empty(0, 0, 8, 0)
|
||||
)
|
||||
|
||||
refreshOperations()
|
||||
}
|
||||
|
||||
fun refreshOperations() {
|
||||
val snapshot = rollbackService.getSnapshot(sessionId)
|
||||
val changes = snapshot?.changes.orEmpty()
|
||||
.filter { rollbackService.isDisplayable(it.path) }
|
||||
.sortedBy { it.path }
|
||||
isVisible = changes.isNotEmpty()
|
||||
if (changes.isEmpty()) {
|
||||
titleLabel.text = "Agent changes"
|
||||
timeLabel.text = ""
|
||||
changesPanel.removeAll()
|
||||
summaryPanel.removeAll()
|
||||
footerPanel.isVisible = false
|
||||
revalidate()
|
||||
repaint()
|
||||
return
|
||||
}
|
||||
|
||||
val timeText = snapshot?.completedAt?.let { formatTime(it) } ?: ""
|
||||
titleLabel.text = "Agent changes (${changes.size})"
|
||||
timeLabel.text = if (timeText.isNotBlank()) "• $timeText" else ""
|
||||
|
||||
updateSummary(changes)
|
||||
preloadDiffStats(changes)
|
||||
|
||||
changesPanel.removeAll()
|
||||
changes.forEachIndexed { index, change ->
|
||||
if (index > 0) changesPanel.add(Box.createVerticalStrut(4))
|
||||
changesPanel.add(createChangeRow(change))
|
||||
}
|
||||
footerPanel.isVisible = true
|
||||
|
||||
revalidate()
|
||||
repaint()
|
||||
}
|
||||
|
||||
private fun handleRollback() {
|
||||
CoroutineScope(Dispatchers.Main).launch {
|
||||
val result = rollbackService.rollbackSession(sessionId)
|
||||
|
||||
withContext(Dispatchers.Main) {
|
||||
when (result) {
|
||||
is RollbackResult.Success -> {
|
||||
refreshOperations()
|
||||
onRollbackComplete()
|
||||
}
|
||||
|
||||
is RollbackResult.Failure -> {
|
||||
Messages.showErrorDialog(project, result.message, "Rollback Failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun createChangeRow(change: FileChange): JComponent {
|
||||
val row = BorderLayoutPanel().apply {
|
||||
isOpaque = true
|
||||
background = JBUI.CurrentTheme.List.background(false, false)
|
||||
border = JBUI.Borders.compound(
|
||||
JBUI.Borders.customLine(JBColor.border(), 1),
|
||||
JBUI.Borders.empty(4, 8)
|
||||
)
|
||||
}
|
||||
|
||||
val left = JPanel(FlowLayout(FlowLayout.LEFT, 6, 2)).apply {
|
||||
isOpaque = false
|
||||
}
|
||||
left.add(changeLabel(change))
|
||||
left.add(fileNameComponent(change))
|
||||
left.add(filePathLabel(change))
|
||||
addDiffStats(change, left)
|
||||
|
||||
val actions = JPanel(FlowLayout(FlowLayout.RIGHT, 6, 0)).apply {
|
||||
isOpaque = false
|
||||
}
|
||||
openDiffAction(change).let { actions.add(it) }
|
||||
actions.add(rollbackAction(change))
|
||||
|
||||
row.addToLeft(left)
|
||||
row.addToRight(actions)
|
||||
return row
|
||||
}
|
||||
|
||||
private fun changeLabel(change: FileChange): JBLabel {
|
||||
val (text, color) = when (change.kind) {
|
||||
ChangeKind.ADDED -> "+" to ChangeColors.inserted
|
||||
ChangeKind.DELETED -> "-" to ChangeColors.deleted
|
||||
ChangeKind.MODIFIED -> "~" to ChangeColors.modified
|
||||
ChangeKind.MOVED -> "~" to ChangeColors.modified
|
||||
}
|
||||
return JBLabel(text).apply {
|
||||
foreground = color
|
||||
font = JBUI.Fonts.smallFont()
|
||||
}
|
||||
}
|
||||
|
||||
private fun fileNameComponent(change: FileChange): JComponent {
|
||||
val display = displayFileName(change.path)
|
||||
val file = LocalFileSystem.getInstance().findFileByPath(change.path)
|
||||
return if (file != null && change.kind != ChangeKind.DELETED) {
|
||||
ActionLink(display) {
|
||||
OpenFileAction.openFile(file, project)
|
||||
}.apply {
|
||||
font = JBUI.Fonts.label().asBold()
|
||||
}
|
||||
} else {
|
||||
JBLabel(display).apply {
|
||||
font = JBUI.Fonts.label().asBold()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun filePathLabel(change: FileChange): JBLabel {
|
||||
val display = displayPath(change.path, change)
|
||||
return JBLabel(display).apply {
|
||||
foreground = JBUI.CurrentTheme.Label.disabledForeground()
|
||||
font = JBUI.Fonts.smallFont()
|
||||
}
|
||||
}
|
||||
|
||||
private fun openDiffAction(change: FileChange): JComponent {
|
||||
return IconActionButton(
|
||||
object : AnAction("Open Diff", "Open diff view", AllIcons.Actions.Diff) {
|
||||
override fun actionPerformed(e: AnActionEvent) {
|
||||
openDiffForPath(change.path)
|
||||
}
|
||||
},
|
||||
"OPEN_DIFF"
|
||||
)
|
||||
}
|
||||
|
||||
private fun rollbackAction(change: FileChange): JComponent {
|
||||
return IconActionButton(
|
||||
object : AnAction("Rollback File", "Rollback this file", AllIcons.Actions.Undo) {
|
||||
override fun actionPerformed(e: AnActionEvent) {
|
||||
rollbackFile(change.path)
|
||||
}
|
||||
},
|
||||
"ROLLBACK_FILE"
|
||||
)
|
||||
}
|
||||
|
||||
private fun displayPath(path: String, change: FileChange): String {
|
||||
val base = project.basePath?.replace("\\", "/")
|
||||
val normalized = path.replace("\\", "/")
|
||||
val relative = if (base != null && normalized.startsWith(base)) {
|
||||
normalized.removePrefix(base).trimStart('/')
|
||||
} else normalized
|
||||
val original = change.originalPath?.replace("\\", "/")
|
||||
return if (change.kind == ChangeKind.MOVED && original != null) {
|
||||
val originalDisplay = if (base != null && original.startsWith(base)) {
|
||||
original.removePrefix(base).trimStart('/')
|
||||
} else original
|
||||
"$relative (renamed from $originalDisplay)"
|
||||
} else {
|
||||
relative
|
||||
}
|
||||
}
|
||||
|
||||
private fun displayFileName(path: String): String {
|
||||
val normalized = path.replace("\\", "/")
|
||||
return normalized.substringAfterLast('/')
|
||||
}
|
||||
|
||||
private fun addDiffStats(change: FileChange, container: JPanel) {
|
||||
val stats = diffStatsCache[change.path] ?: return
|
||||
val (ins, del, mod) = stats
|
||||
if (ins + del + mod == 0) return
|
||||
container.add(colorLabel("+$ins", ChangeColors.inserted))
|
||||
container.add(colorLabel("-$del", ChangeColors.deleted))
|
||||
container.add(colorLabel("~$mod", ChangeColors.modified))
|
||||
}
|
||||
|
||||
private fun colorLabel(text: String, color: JBColor): JBLabel =
|
||||
JBLabel(text).apply {
|
||||
foreground = color
|
||||
font = JBUI.Fonts.smallFont()
|
||||
}
|
||||
|
||||
private fun formatTime(instant: Instant): String {
|
||||
val formatter = DateTimeFormatter.ofPattern("HH:mm").withZone(ZoneId.systemDefault())
|
||||
return "Last run at ${formatter.format(instant)}"
|
||||
}
|
||||
|
||||
private fun updateSummary(changes: List<FileChange>) {
|
||||
val added = changes.count { it.kind == ChangeKind.ADDED }
|
||||
val deleted = changes.count { it.kind == ChangeKind.DELETED }
|
||||
val modified =
|
||||
changes.count { it.kind == ChangeKind.MODIFIED || it.kind == ChangeKind.MOVED }
|
||||
summaryPanel.removeAll()
|
||||
if (added > 0) summaryPanel.add(colorChip("+$added", ChangeColors.inserted))
|
||||
if (deleted > 0) summaryPanel.add(colorChip("-$deleted", ChangeColors.deleted))
|
||||
if (modified > 0) summaryPanel.add(colorChip("~$modified", ChangeColors.modified))
|
||||
summaryPanel.revalidate()
|
||||
summaryPanel.repaint()
|
||||
}
|
||||
|
||||
private fun colorChip(text: String, color: JBColor): JBLabel =
|
||||
JBLabel(text).apply {
|
||||
foreground = color
|
||||
font = JBUI.Fonts.smallFont()
|
||||
}
|
||||
|
||||
private fun createRollbackAllButton(): JComponent {
|
||||
return JButton("Rollback all").apply {
|
||||
isFocusable = false
|
||||
isContentAreaFilled = true
|
||||
isOpaque = true
|
||||
putClientProperty("JButton.buttonType", "roundRect")
|
||||
margin = JBUI.insets(6, 12)
|
||||
addActionListener { handleRollback() }
|
||||
}
|
||||
}
|
||||
|
||||
private fun createKeepButton(): JComponent {
|
||||
return JButton("Keep changes").apply {
|
||||
isFocusable = false
|
||||
isContentAreaFilled = true
|
||||
isOpaque = true
|
||||
putClientProperty("JButton.buttonType", "roundRect")
|
||||
margin = JBUI.insets(6, 12)
|
||||
addActionListener {
|
||||
rollbackService.clearSnapshot(sessionId)
|
||||
refreshOperations()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun rollbackFile(path: String) {
|
||||
CoroutineScope(Dispatchers.Main).launch {
|
||||
val result = rollbackService.rollbackFile(sessionId, path)
|
||||
|
||||
withContext(Dispatchers.Main) {
|
||||
when (result) {
|
||||
is RollbackResult.Success -> {
|
||||
refreshOperations()
|
||||
onRollbackComplete()
|
||||
}
|
||||
|
||||
is RollbackResult.Failure -> {
|
||||
Messages.showErrorDialog(project, result.message, "Rollback Failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun preloadDiffStats(changes: List<FileChange>) {
|
||||
changes.forEach { change ->
|
||||
if (diffStatsCache.containsKey(change.path)) return@forEach
|
||||
backgroundScope.launch {
|
||||
val diffData = rollbackService.getDiffData(sessionId, change.path) ?: return@launch
|
||||
diffDataCache[change.path] = diffData
|
||||
diffStatsCache[change.path] =
|
||||
lineDiffStats(diffData.beforeText, diffData.afterText)
|
||||
withContext(Dispatchers.EDT) {
|
||||
refreshOperations()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun openDiffForPath(path: String) {
|
||||
val cached = diffDataCache[path]
|
||||
if (cached != null) {
|
||||
openDiff(cached)
|
||||
return
|
||||
}
|
||||
backgroundScope.launch {
|
||||
val diffData = rollbackService.getDiffData(sessionId, path) ?: return@launch
|
||||
diffDataCache[path] = diffData
|
||||
withContext(Dispatchers.Main) {
|
||||
openDiff(diffData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun openDiff(diffData: RollbackDiffData) {
|
||||
val contentFactory = DiffContentFactory.getInstance()
|
||||
val before = contentFactory.create(diffData.beforeText)
|
||||
val after = contentFactory.create(diffData.afterText)
|
||||
val title = "Agent change: ${displayFileName(diffData.path)}"
|
||||
val request = SimpleDiffRequest(title, before, after, "Before", "After")
|
||||
DiffManager.getInstance().showDiff(project, request)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue