fix: rollback changes tracking

This commit is contained in:
Carl-Robert Linnupuu 2026-01-29 00:41:36 +00:00
parent ee7b662ec4
commit 6ae4187f6b
5 changed files with 111 additions and 21 deletions

View file

@ -3,16 +3,18 @@ package ee.carlrobert.codegpt.agent.rollback
import com.intellij.history.Label
import com.intellij.history.LocalHistory
import com.intellij.openapi.application.ModalityState
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.components.Service
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.fileTypes.FileTypeManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFileManager
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
import ee.carlrobert.codegpt.agent.tools.EditTool
import ee.carlrobert.codegpt.agent.tools.EditArgsSnapshot
import ee.carlrobert.codegpt.agent.tools.WriteTool
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
@ -51,9 +53,15 @@ class RollbackService(private val project: Project) {
* Track an EditTool operation directly from the agent.
* This captures the original file content before the edit is applied.
*/
fun trackEdit(sessionId: String, filePath: String, args: EditTool.Args, originalContent: String) {
fun trackEdit(
sessionId: String,
filePath: String,
args: EditArgsSnapshot,
originalContent: String
) {
val tracker = activeRuns[sessionId] ?: return
tracker.recordExplicitEdit(filePath, args, originalContent)
val normalizedPath = filePath.replace("\\", "/")
tracker.recordExplicitEdit(normalizedPath, args, originalContent)
}
/**
@ -62,7 +70,8 @@ class RollbackService(private val project: Project) {
*/
fun trackWrite(sessionId: String, filePath: String, args: WriteTool.Args) {
val tracker = activeRuns[sessionId] ?: return
tracker.recordExplicitWrite(filePath, args)
val normalizedPath = filePath.replace("\\", "/")
tracker.recordExplicitWrite(normalizedPath, args)
}
fun startSession(sessionId: String) {
@ -94,16 +103,23 @@ class RollbackService(private val project: Project) {
}
fun getDiffData(sessionId: String, path: String): RollbackDiffData? {
if (!isTrackable(path)) return null
val snapshot = snapshots[sessionId] ?: return null
val change = snapshot.changes[path] ?: return null
if (change.kind != ChangeKind.DELETED && !isTrackable(path)) return null
val beforeText = when (change.kind) {
ChangeKind.ADDED -> ""
else -> decodeLabelContent(
snapshot.labelRef,
change.originalPath ?: path,
change.originalContent
)
else -> {
val original = change.originalContent
if (original != null && original.isNotEmpty()) {
decodeContent(original)
} else {
decodeLabelContent(
snapshot.labelRef,
change.originalPath ?: path,
change.originalContent
)
}
}
}
val afterText = when (change.kind) {
ChangeKind.DELETED -> ""
@ -212,6 +228,10 @@ class RollbackService(private val project: Project) {
private fun readCurrentText(path: String): String {
val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return ""
val docText = runReadAction {
FileDocumentManager.getInstance().getDocument(vf)?.text
}
if (docText != null) return docText
return runCatching { VfsUtilCore.loadText(vf) }.getOrDefault("")
}
@ -348,7 +368,7 @@ class RollbackService(private val project: Project) {
val labelRef: Label,
val changes: MutableMap<String, TrackedChange> = ConcurrentHashMap()
) {
fun recordExplicitEdit(filePath: String, args: EditTool.Args, originalContent: String) {
fun recordExplicitEdit(filePath: String, args: EditArgsSnapshot, originalContent: String) {
val existing = changes[filePath]
if (existing?.kind == ChangeKind.ADDED) return
if (existing?.kind == ChangeKind.MOVED) return
@ -466,4 +486,4 @@ enum class ChangeKind {
sealed class RollbackResult {
data class Success(val message: String) : RollbackResult()
data class Failure(val message: String) : RollbackResult()
}
}

View file

@ -0,0 +1,38 @@
package ee.carlrobert.codegpt.agent.tools
data class EditArgsSnapshot(
val filePath: String,
val oldString: String,
val newString: String,
val replaceAll: Boolean,
val shortDescription: String
)
fun EditTool.Args.toSnapshot(): EditArgsSnapshot {
return EditArgsSnapshot(
filePath = filePath,
oldString = oldString,
newString = newString,
replaceAll = replaceAll,
shortDescription = shortDescription
)
}
fun ProxyAIEditTool.Args.toSnapshot(): EditArgsSnapshot {
return EditArgsSnapshot(
filePath = filePath,
oldString = "",
newString = updateSnippet,
replaceAll = false,
shortDescription = shortDescription
)
}
fun snapshotFromEditArgs(args: Any?): EditArgsSnapshot? {
return when (args) {
is EditTool.Args -> args.toSnapshot()
is ProxyAIEditTool.Args -> args.toSnapshot()
is EditArgsSnapshot -> args
else -> null
}
}

View file

@ -210,4 +210,18 @@ object EditorUtil {
}
}
}
}
@JvmStatic
fun writeDocumentContent(project: Project, virtualFile: VirtualFile, content: String): Boolean {
val document = runReadAction {
FileDocumentManager.getInstance().getDocument(virtualFile)
} ?: return false
WriteCommandAction.runWriteCommandAction(project) {
document.setText(content)
FileDocumentManager.getInstance().saveDocument(document)
}
return true
}
}

View file

@ -9,6 +9,7 @@ import com.intellij.openapi.util.io.FileUtil.createDirectory
import com.intellij.openapi.vfs.JarFileSystem
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.openapi.vfs.VirtualFile
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings.getLlamaModelsPath
import java.io.File
@ -42,7 +43,6 @@ object FileUtil {
}
}
@JvmStatic
fun readContent(virtualFile: VirtualFile): String {
try {
return VfsUtilCore.loadText(virtualFile)
@ -113,7 +113,6 @@ object FileUtil {
}.takeIf { it } ?: throw RuntimeException("Failed to create directory: $directoryPath")
}
@JvmStatic
fun getFileExtension(filename: String?): String {
val pattern = Pattern.compile("[^.]+$")
val matcher = filename?.let { pattern.matcher(it) }
@ -124,7 +123,6 @@ object FileUtil {
return ""
}
@JvmStatic
fun findLanguageExtensionMapping(language: String? = ""): Map.Entry<String, String> {
val defaultValue = mapOf("Text" to ".txt").entries.first()
val mapper = ObjectMapper()
@ -169,7 +167,6 @@ object FileUtil {
}
}
@JvmStatic
fun getImageMediaType(fileName: String?): String {
return when (val fileExtension = getFileExtension(fileName)) {
"png" -> "image/png"
@ -178,7 +175,6 @@ object FileUtil {
}
}
@JvmStatic
fun getResourceContent(filePath: String?): String {
try {
Objects.requireNonNull(filePath?.let { FileUtil::class.java.getResourceAsStream(it) })
@ -216,7 +212,6 @@ object FileUtil {
return value.toString()
}
@JvmStatic
fun findFirstExtension(
languageFileExtensionMappings: List<LanguageFileExtensionDetails>,
language: String? = ""
@ -254,4 +249,27 @@ object FileUtil {
null
}
}
fun findVirtualFile(normalizedPath: String): VirtualFile? {
return VirtualFileManager.getInstance().refreshAndFindFileByUrl("file://$normalizedPath")
?: LocalFileSystem.getInstance().findFileByIoFile(File(normalizedPath))
}
fun validateFileForEdit(filePath: String): Result<File> {
val normalizedPath = filePath.replace("\\", "/")
val file = File(normalizedPath)
return when {
!file.exists() -> Result.failure(
IllegalArgumentException("File not found: $filePath (File does not exist on filesystem)")
)
!file.isFile -> Result.failure(
IllegalArgumentException("Path is not a file: $filePath")
)
!file.canWrite() -> Result.failure(
IllegalArgumentException("File is not writable: $filePath")
)
else -> Result.success(file)
}
}
}

View file

@ -1,7 +1,7 @@
package ee.carlrobert.codegpt.agent.rollback
import com.intellij.openapi.vfs.LocalFileSystem
import ee.carlrobert.codegpt.agent.tools.EditTool
import ee.carlrobert.codegpt.agent.tools.EditArgsSnapshot
import ee.carlrobert.codegpt.agent.tools.WriteTool
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
@ -28,7 +28,7 @@ class RollbackServiceTrackingTest : IntegrationTest() {
rollbackService.trackEdit(
sessionId = sessionId,
filePath = filePath,
args = EditTool.Args(filePath, "before", "after", "test", false),
args = EditArgsSnapshot(filePath, "before", "after", false, "test"),
originalContent = "before"
)
val snapshot = rollbackService.finishSession(sessionId)