mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 07:54:46 +00:00
fix: rollback changes tracking
This commit is contained in:
parent
ee7b662ec4
commit
6ae4187f6b
5 changed files with 111 additions and 21 deletions
|
|
@ -3,16 +3,18 @@ package ee.carlrobert.codegpt.agent.rollback
|
|||
import com.intellij.history.Label
|
||||
import com.intellij.history.LocalHistory
|
||||
import com.intellij.openapi.application.ModalityState
|
||||
import com.intellij.openapi.application.runReadAction
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
import com.intellij.openapi.application.runWriteAction
|
||||
import com.intellij.openapi.components.Service
|
||||
import com.intellij.openapi.fileEditor.FileDocumentManager
|
||||
import com.intellij.openapi.fileTypes.FileTypeManager
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.openapi.vfs.VfsUtilCore
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import ee.carlrobert.codegpt.agent.tools.EditTool
|
||||
import ee.carlrobert.codegpt.agent.tools.EditArgsSnapshot
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
|
|
@ -51,9 +53,15 @@ class RollbackService(private val project: Project) {
|
|||
* Track an EditTool operation directly from the agent.
|
||||
* This captures the original file content before the edit is applied.
|
||||
*/
|
||||
fun trackEdit(sessionId: String, filePath: String, args: EditTool.Args, originalContent: String) {
|
||||
fun trackEdit(
|
||||
sessionId: String,
|
||||
filePath: String,
|
||||
args: EditArgsSnapshot,
|
||||
originalContent: String
|
||||
) {
|
||||
val tracker = activeRuns[sessionId] ?: return
|
||||
tracker.recordExplicitEdit(filePath, args, originalContent)
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
tracker.recordExplicitEdit(normalizedPath, args, originalContent)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -62,7 +70,8 @@ class RollbackService(private val project: Project) {
|
|||
*/
|
||||
fun trackWrite(sessionId: String, filePath: String, args: WriteTool.Args) {
|
||||
val tracker = activeRuns[sessionId] ?: return
|
||||
tracker.recordExplicitWrite(filePath, args)
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
tracker.recordExplicitWrite(normalizedPath, args)
|
||||
}
|
||||
|
||||
fun startSession(sessionId: String) {
|
||||
|
|
@ -94,16 +103,23 @@ class RollbackService(private val project: Project) {
|
|||
}
|
||||
|
||||
fun getDiffData(sessionId: String, path: String): RollbackDiffData? {
|
||||
if (!isTrackable(path)) return null
|
||||
val snapshot = snapshots[sessionId] ?: return null
|
||||
val change = snapshot.changes[path] ?: return null
|
||||
if (change.kind != ChangeKind.DELETED && !isTrackable(path)) return null
|
||||
val beforeText = when (change.kind) {
|
||||
ChangeKind.ADDED -> ""
|
||||
else -> decodeLabelContent(
|
||||
snapshot.labelRef,
|
||||
change.originalPath ?: path,
|
||||
change.originalContent
|
||||
)
|
||||
else -> {
|
||||
val original = change.originalContent
|
||||
if (original != null && original.isNotEmpty()) {
|
||||
decodeContent(original)
|
||||
} else {
|
||||
decodeLabelContent(
|
||||
snapshot.labelRef,
|
||||
change.originalPath ?: path,
|
||||
change.originalContent
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
val afterText = when (change.kind) {
|
||||
ChangeKind.DELETED -> ""
|
||||
|
|
@ -212,6 +228,10 @@ class RollbackService(private val project: Project) {
|
|||
|
||||
private fun readCurrentText(path: String): String {
|
||||
val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return ""
|
||||
val docText = runReadAction {
|
||||
FileDocumentManager.getInstance().getDocument(vf)?.text
|
||||
}
|
||||
if (docText != null) return docText
|
||||
return runCatching { VfsUtilCore.loadText(vf) }.getOrDefault("")
|
||||
}
|
||||
|
||||
|
|
@ -348,7 +368,7 @@ class RollbackService(private val project: Project) {
|
|||
val labelRef: Label,
|
||||
val changes: MutableMap<String, TrackedChange> = ConcurrentHashMap()
|
||||
) {
|
||||
fun recordExplicitEdit(filePath: String, args: EditTool.Args, originalContent: String) {
|
||||
fun recordExplicitEdit(filePath: String, args: EditArgsSnapshot, originalContent: String) {
|
||||
val existing = changes[filePath]
|
||||
if (existing?.kind == ChangeKind.ADDED) return
|
||||
if (existing?.kind == ChangeKind.MOVED) return
|
||||
|
|
@ -466,4 +486,4 @@ enum class ChangeKind {
|
|||
sealed class RollbackResult {
|
||||
data class Success(val message: String) : RollbackResult()
|
||||
data class Failure(val message: String) : RollbackResult()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
package ee.carlrobert.codegpt.agent.tools
|
||||
|
||||
data class EditArgsSnapshot(
|
||||
val filePath: String,
|
||||
val oldString: String,
|
||||
val newString: String,
|
||||
val replaceAll: Boolean,
|
||||
val shortDescription: String
|
||||
)
|
||||
|
||||
fun EditTool.Args.toSnapshot(): EditArgsSnapshot {
|
||||
return EditArgsSnapshot(
|
||||
filePath = filePath,
|
||||
oldString = oldString,
|
||||
newString = newString,
|
||||
replaceAll = replaceAll,
|
||||
shortDescription = shortDescription
|
||||
)
|
||||
}
|
||||
|
||||
fun ProxyAIEditTool.Args.toSnapshot(): EditArgsSnapshot {
|
||||
return EditArgsSnapshot(
|
||||
filePath = filePath,
|
||||
oldString = "",
|
||||
newString = updateSnippet,
|
||||
replaceAll = false,
|
||||
shortDescription = shortDescription
|
||||
)
|
||||
}
|
||||
|
||||
fun snapshotFromEditArgs(args: Any?): EditArgsSnapshot? {
|
||||
return when (args) {
|
||||
is EditTool.Args -> args.toSnapshot()
|
||||
is ProxyAIEditTool.Args -> args.toSnapshot()
|
||||
is EditArgsSnapshot -> args
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
|
@ -210,4 +210,18 @@ object EditorUtil {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun writeDocumentContent(project: Project, virtualFile: VirtualFile, content: String): Boolean {
|
||||
val document = runReadAction {
|
||||
FileDocumentManager.getInstance().getDocument(virtualFile)
|
||||
} ?: return false
|
||||
|
||||
WriteCommandAction.runWriteCommandAction(project) {
|
||||
document.setText(content)
|
||||
FileDocumentManager.getInstance().saveDocument(document)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import com.intellij.openapi.util.io.FileUtil.createDirectory
|
|||
import com.intellij.openapi.vfs.JarFileSystem
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.openapi.vfs.VfsUtilCore
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings.getLlamaModelsPath
|
||||
import java.io.File
|
||||
|
|
@ -42,7 +43,6 @@ object FileUtil {
|
|||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun readContent(virtualFile: VirtualFile): String {
|
||||
try {
|
||||
return VfsUtilCore.loadText(virtualFile)
|
||||
|
|
@ -113,7 +113,6 @@ object FileUtil {
|
|||
}.takeIf { it } ?: throw RuntimeException("Failed to create directory: $directoryPath")
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun getFileExtension(filename: String?): String {
|
||||
val pattern = Pattern.compile("[^.]+$")
|
||||
val matcher = filename?.let { pattern.matcher(it) }
|
||||
|
|
@ -124,7 +123,6 @@ object FileUtil {
|
|||
return ""
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun findLanguageExtensionMapping(language: String? = ""): Map.Entry<String, String> {
|
||||
val defaultValue = mapOf("Text" to ".txt").entries.first()
|
||||
val mapper = ObjectMapper()
|
||||
|
|
@ -169,7 +167,6 @@ object FileUtil {
|
|||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun getImageMediaType(fileName: String?): String {
|
||||
return when (val fileExtension = getFileExtension(fileName)) {
|
||||
"png" -> "image/png"
|
||||
|
|
@ -178,7 +175,6 @@ object FileUtil {
|
|||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun getResourceContent(filePath: String?): String {
|
||||
try {
|
||||
Objects.requireNonNull(filePath?.let { FileUtil::class.java.getResourceAsStream(it) })
|
||||
|
|
@ -216,7 +212,6 @@ object FileUtil {
|
|||
return value.toString()
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun findFirstExtension(
|
||||
languageFileExtensionMappings: List<LanguageFileExtensionDetails>,
|
||||
language: String? = ""
|
||||
|
|
@ -254,4 +249,27 @@ object FileUtil {
|
|||
null
|
||||
}
|
||||
}
|
||||
|
||||
fun findVirtualFile(normalizedPath: String): VirtualFile? {
|
||||
return VirtualFileManager.getInstance().refreshAndFindFileByUrl("file://$normalizedPath")
|
||||
?: LocalFileSystem.getInstance().findFileByIoFile(File(normalizedPath))
|
||||
}
|
||||
|
||||
fun validateFileForEdit(filePath: String): Result<File> {
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
val file = File(normalizedPath)
|
||||
|
||||
return when {
|
||||
!file.exists() -> Result.failure(
|
||||
IllegalArgumentException("File not found: $filePath (File does not exist on filesystem)")
|
||||
)
|
||||
!file.isFile -> Result.failure(
|
||||
IllegalArgumentException("Path is not a file: $filePath")
|
||||
)
|
||||
!file.canWrite() -> Result.failure(
|
||||
IllegalArgumentException("File is not writable: $filePath")
|
||||
)
|
||||
else -> Result.success(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
package ee.carlrobert.codegpt.agent.rollback
|
||||
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import ee.carlrobert.codegpt.agent.tools.EditTool
|
||||
import ee.carlrobert.codegpt.agent.tools.EditArgsSnapshot
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
|
|
@ -28,7 +28,7 @@ class RollbackServiceTrackingTest : IntegrationTest() {
|
|||
rollbackService.trackEdit(
|
||||
sessionId = sessionId,
|
||||
filePath = filePath,
|
||||
args = EditTool.Args(filePath, "before", "after", "test", false),
|
||||
args = EditArgsSnapshot(filePath, "before", "after", false, "test"),
|
||||
originalContent = "before"
|
||||
)
|
||||
val snapshot = rollbackService.finishSession(sessionId)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue