mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 16:28:46 +00:00
feat: agent timeline
This commit is contained in:
parent
24e2a5fbad
commit
3f32f19f72
31 changed files with 3779 additions and 450 deletions
|
|
@ -1,11 +1,13 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import ai.koog.prompt.executor.clients.LLMClientException
|
||||
import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
||||
import ee.carlrobert.codegpt.agent.tools.AskUserQuestionTool
|
||||
import ee.carlrobert.codegpt.conversations.message.TokenUsage
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest
|
||||
import java.util.UUID
|
||||
|
||||
interface AgentEvents {
|
||||
fun onTextReceived(text: String) {}
|
||||
|
|
@ -22,6 +24,7 @@ interface AgentEvents {
|
|||
}
|
||||
|
||||
fun onRetry(attempt: Int, maxAttempts: Int, reason: String? = null) {}
|
||||
fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {}
|
||||
fun onQueuedMessagesResolved()
|
||||
fun onTokenUsageAvailable(tokenUsage: Long) {}
|
||||
fun onCreditsAvailable(event: AgentCreditsEvent) {}
|
||||
|
|
|
|||
|
|
@ -19,9 +19,12 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.McpTagDetails
|
|||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.flow.asSharedFlow
|
||||
import kotlinx.datetime.Clock
|
||||
import kotlinx.serialization.json.JsonNull
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import kotlin.io.path.Path
|
||||
import ai.koog.prompt.message.Message as PromptMessage
|
||||
|
||||
internal fun interface AgentRuntimeFactory {
|
||||
fun create(
|
||||
|
|
@ -53,16 +56,17 @@ class AgentService(private val project: Project) {
|
|||
private val pendingMessages = ConcurrentHashMap<String, ArrayDeque<MessageWithContext>>()
|
||||
private val sessionAgents = ConcurrentHashMap<String, AIAgent<MessageWithContext, String>>()
|
||||
private val sessionRuntimes = ConcurrentHashMap<String, SessionRuntime>()
|
||||
internal var runtimeFactory: AgentRuntimeFactory = AgentRuntimeFactory { p, storage, provider, events, sid, pending ->
|
||||
ProxyAIAgent.createService(
|
||||
project = p,
|
||||
checkpointStorage = storage,
|
||||
provider = provider,
|
||||
events = events,
|
||||
sessionId = sid,
|
||||
pendingMessages = pending
|
||||
)
|
||||
}
|
||||
internal var runtimeFactory: AgentRuntimeFactory =
|
||||
AgentRuntimeFactory { p, storage, provider, events, sid, pending ->
|
||||
ProxyAIAgent.createService(
|
||||
project = p,
|
||||
checkpointStorage = storage,
|
||||
provider = provider,
|
||||
events = events,
|
||||
sessionId = sid,
|
||||
pendingMessages = pending
|
||||
)
|
||||
}
|
||||
private val checkpointStorage =
|
||||
JVMFilePersistenceStorageProvider(Path(project.basePath ?: "", ".proxyai"))
|
||||
private val historyService = project.service<AgentCheckpointHistoryService>()
|
||||
|
|
@ -132,7 +136,8 @@ class AgentService(private val project: Project) {
|
|||
} catch (ex: Exception) {
|
||||
logger.error(ex)
|
||||
} finally {
|
||||
refreshSessionResumeCheckpoint(sessionId, agent.id)
|
||||
val ref = refreshSessionResumeCheckpoint(sessionId, agent.id)
|
||||
events.onRunCheckpointUpdated(message.id, ref)
|
||||
sessionAgents.remove(sessionId, agent)
|
||||
runCatching { runtime.service.removeAgentWithId(agent.id) }
|
||||
.onFailure { ex ->
|
||||
|
|
@ -201,7 +206,39 @@ class AgentService(private val project: Project) {
|
|||
.update(sessionId, conversationId, selectedServerIds)
|
||||
}
|
||||
|
||||
private suspend fun refreshSessionResumeCheckpoint(sessionId: String, agentId: String) {
|
||||
suspend fun createSeedCheckpointFromHistory(history: List<PromptMessage>): CheckpointRef? =
|
||||
withContext(Dispatchers.IO) {
|
||||
if (history.isEmpty()) {
|
||||
return@withContext null
|
||||
}
|
||||
|
||||
val agentId = UUID.randomUUID().toString()
|
||||
val checkpointId = UUID.randomUUID().toString()
|
||||
val checkpoint = AgentCheckpointData(
|
||||
checkpointId = checkpointId,
|
||||
createdAt = Clock.System.now(),
|
||||
nodePath = "$agentId/single_run/nodeExecuteTool",
|
||||
lastInput = JsonNull,
|
||||
messageHistory = history,
|
||||
version = 0
|
||||
)
|
||||
|
||||
runCatching {
|
||||
checkpointStorage.saveCheckpoint(agentId, checkpoint)
|
||||
CheckpointRef(agentId, checkpointId)
|
||||
}.onFailure { ex ->
|
||||
logger.warn(
|
||||
"Agent checkpoints: failed to create seed checkpoint from history " +
|
||||
"agentId=$agentId error=${ex.message}",
|
||||
ex
|
||||
)
|
||||
}.getOrNull()
|
||||
}
|
||||
|
||||
private suspend fun refreshSessionResumeCheckpoint(
|
||||
sessionId: String,
|
||||
agentId: String
|
||||
): CheckpointRef? {
|
||||
val ref = runCatching {
|
||||
historyService.loadLatestResumeCheckpoint(agentId)
|
||||
?.let { CheckpointRef(agentId, it.checkpointId) }
|
||||
|
|
@ -216,6 +253,7 @@ class AgentService(private val project: Project) {
|
|||
if (ref != null) {
|
||||
project.service<AgentToolWindowContentManager>().setResumeCheckpointRef(sessionId, ref)
|
||||
}
|
||||
return ref
|
||||
}
|
||||
|
||||
private fun ensureSessionRuntime(
|
||||
|
|
|
|||
|
|
@ -171,6 +171,8 @@ object ProxyAIAgent {
|
|||
}
|
||||
|
||||
onNodeExecutionCompleted { ctx ->
|
||||
if (stream) return@onNodeExecutionCompleted
|
||||
|
||||
(ctx.output as? List<*>)?.forEach { msg ->
|
||||
(msg as? Message.Assistant)?.let {
|
||||
events.onTextReceived(it.content)
|
||||
|
|
|
|||
|
|
@ -169,25 +169,20 @@ object ToolSpecs {
|
|||
|
||||
fun find(toolName: String): ToolSpec<*, *>? = specsByName[toolName.lowercase()]
|
||||
|
||||
fun approvalTypeFor(toolName: String): ToolApprovalType {
|
||||
return find(toolName)?.approvalType ?: ToolApprovalType.GENERIC
|
||||
}
|
||||
fun approvalTypeFor(toolName: String): ToolApprovalType =
|
||||
find(toolName)?.approvalType ?: ToolApprovalType.GENERIC
|
||||
|
||||
fun decodeArgsOrNull(
|
||||
json: Json,
|
||||
toolName: String,
|
||||
payload: String
|
||||
): Any? {
|
||||
return decodeOrNull(json, find(toolName)?.argsSerializer, payload)
|
||||
}
|
||||
) = decodeOrNull(json, find(toolName)?.argsSerializer, payload)
|
||||
|
||||
fun decodeResultOrNull(
|
||||
json: Json,
|
||||
toolName: String,
|
||||
payload: String
|
||||
): Any? {
|
||||
return decodeOrNull(json, find(toolName)?.resultSerializer, payload)
|
||||
}
|
||||
) = decodeOrNull(json, find(toolName)?.resultSerializer, payload)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun decodeOrNull(
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ package ee.carlrobert.codegpt.agent.history
|
|||
import ai.koog.agents.snapshot.feature.AgentCheckpointData
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks
|
||||
import ee.carlrobert.llm.client.openai.completion.response.ToolCall
|
||||
import ee.carlrobert.llm.client.openai.completion.response.ToolFunctionResponse
|
||||
import ai.koog.prompt.message.Message as PromptMessage
|
||||
|
||||
object AgentCheckpointConversationMapper {
|
||||
|
||||
|
|
@ -14,21 +14,54 @@ object AgentCheckpointConversationMapper {
|
|||
projectInstructions: String?
|
||||
): Conversation {
|
||||
val conversation = Conversation()
|
||||
val history = checkpoint.messageHistory.filterNot { it is PromptMessage.System }
|
||||
val turns = AgentCheckpointTurnSequencer.toVisibleTurns(
|
||||
history = checkpoint.messageHistory,
|
||||
projectInstructions = projectInstructions
|
||||
)
|
||||
|
||||
var currentPrompt: String? = null
|
||||
val response = StringBuilder()
|
||||
val toolCalls = mutableListOf<ToolCall>()
|
||||
val toolResults = LinkedHashMap<String, String>()
|
||||
var syntheticToolIdIndex = 0
|
||||
turns.forEach { turn ->
|
||||
val response = StringBuilder()
|
||||
val toolCalls = mutableListOf<ToolCall>()
|
||||
val toolResults = LinkedHashMap<String, String>()
|
||||
var syntheticToolIdIndex = 0
|
||||
|
||||
fun flushTurn() {
|
||||
val prompt = currentPrompt?.trim().orEmpty()
|
||||
if (prompt.isBlank()) {
|
||||
return
|
||||
turn.events.forEach { event ->
|
||||
when (event) {
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> {
|
||||
appendAssistant(response, event.content)
|
||||
}
|
||||
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> {
|
||||
appendAssistant(response, event.content)
|
||||
}
|
||||
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> {
|
||||
val callId = event.id?.takeIf { it.isNotBlank() }
|
||||
?: "tool-call-${++syntheticToolIdIndex}"
|
||||
toolCalls.add(
|
||||
ToolCall(
|
||||
null,
|
||||
callId,
|
||||
"function",
|
||||
ToolFunctionResponse(event.tool, event.content.trim())
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> {
|
||||
val callId = event.id?.takeIf { it.isNotBlank() }
|
||||
?: toolCalls.lastOrNull()?.id
|
||||
?: "tool-call-${++syntheticToolIdIndex}"
|
||||
val prior = toolResults[callId]
|
||||
val merged = if (prior.isNullOrBlank()) event.content.trim() else {
|
||||
"$prior\n\n${event.content.trim()}"
|
||||
}
|
||||
toolResults[callId] = merged
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val uiMessage = Message(prompt)
|
||||
val uiMessage = Message(turn.prompt)
|
||||
uiMessage.response = response.toString().trim()
|
||||
if (toolCalls.isNotEmpty()) {
|
||||
uiMessage.toolCalls = ArrayList(toolCalls)
|
||||
|
|
@ -37,64 +70,13 @@ object AgentCheckpointConversationMapper {
|
|||
uiMessage.toolCallResults = LinkedHashMap(toolResults)
|
||||
}
|
||||
conversation.addMessage(uiMessage)
|
||||
currentPrompt = null
|
||||
response.setLength(0)
|
||||
toolCalls.clear()
|
||||
toolResults.clear()
|
||||
}
|
||||
|
||||
history.forEach { msg ->
|
||||
when (msg) {
|
||||
is PromptMessage.User -> {
|
||||
val text = msg.content.trim()
|
||||
if (shouldHideInAgentToolWindow(msg, projectInstructions)) {
|
||||
return@forEach
|
||||
}
|
||||
flushTurn()
|
||||
currentPrompt = text
|
||||
}
|
||||
|
||||
is PromptMessage.Assistant -> appendAssistant(response, msg.content)
|
||||
is PromptMessage.Reasoning -> appendAssistant(response, msg.content)
|
||||
is PromptMessage.Tool.Call -> {
|
||||
if (currentPrompt != null) {
|
||||
val callId =
|
||||
msg.id?.takeIf { it.isNotBlank() }
|
||||
?: "tool-call-${++syntheticToolIdIndex}"
|
||||
toolCalls.add(
|
||||
ToolCall(
|
||||
null,
|
||||
callId,
|
||||
"function",
|
||||
ToolFunctionResponse(msg.tool, msg.content.trim())
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
is PromptMessage.Tool.Result -> {
|
||||
if (currentPrompt != null) {
|
||||
val callId = msg.id?.takeIf { it.isNotBlank() }
|
||||
?: toolCalls.lastOrNull()?.id
|
||||
?: "tool-call-${++syntheticToolIdIndex}"
|
||||
val prior = toolResults[callId]
|
||||
val merged = if (prior.isNullOrBlank()) msg.content.trim() else {
|
||||
"$prior\n\n${msg.content.trim()}"
|
||||
}
|
||||
toolResults[callId] = merged
|
||||
}
|
||||
}
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
flushTurn()
|
||||
return conversation
|
||||
}
|
||||
|
||||
private fun appendAssistant(sb: StringBuilder, content: String) {
|
||||
val text = content.trim()
|
||||
val text = content.stripThinkingBlocks()
|
||||
if (text.isBlank()) {
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,6 +96,13 @@ class AgentCheckpointHistoryService(project: Project) {
|
|||
?: checkpoints.maxByOrNull { it.createdAt }
|
||||
}
|
||||
|
||||
suspend fun listCheckpoints(agentId: String): List<AgentCheckpointData> =
|
||||
withContext(Dispatchers.IO) {
|
||||
storage.getCheckpoints(agentId)
|
||||
.filterNot { it.isTombstone() }
|
||||
.sortedByDescending { it.createdAt }
|
||||
}
|
||||
|
||||
private suspend fun buildSummary(agentId: String): AgentHistoryThreadSummary? {
|
||||
val checkpoints = storage.getCheckpoints(agentId)
|
||||
.filterNot { it.isTombstone() }
|
||||
|
|
|
|||
|
|
@ -0,0 +1,193 @@
|
|||
package ee.carlrobert.codegpt.agent.history
|
||||
|
||||
import kotlinx.serialization.json.booleanOrNull
|
||||
import kotlinx.serialization.json.contentOrNull
|
||||
import kotlinx.serialization.json.jsonPrimitive
|
||||
import ai.koog.prompt.message.Message as PromptMessage
|
||||
|
||||
object AgentCheckpointTurnSequencer {
|
||||
|
||||
data class Turn(
|
||||
val prompt: String,
|
||||
val userNonSystemMessageCount: Int,
|
||||
val events: List<TurnEvent>
|
||||
)
|
||||
|
||||
sealed interface TurnEvent {
|
||||
val nonSystemMessageCount: Int
|
||||
|
||||
data class Assistant(
|
||||
val content: String,
|
||||
override val nonSystemMessageCount: Int
|
||||
) : TurnEvent
|
||||
|
||||
data class Reasoning(
|
||||
val content: String,
|
||||
override val nonSystemMessageCount: Int
|
||||
) : TurnEvent
|
||||
|
||||
data class ToolCall(
|
||||
val id: String?,
|
||||
val tool: String,
|
||||
val content: String,
|
||||
override val nonSystemMessageCount: Int
|
||||
) : TurnEvent
|
||||
|
||||
data class ToolResult(
|
||||
val id: String?,
|
||||
val tool: String,
|
||||
val content: String,
|
||||
override val nonSystemMessageCount: Int
|
||||
) : TurnEvent
|
||||
}
|
||||
|
||||
fun toVisibleTurns(
|
||||
history: List<PromptMessage>,
|
||||
projectInstructions: String?,
|
||||
preserveSyntheticContinuation: Boolean = false
|
||||
): List<Turn> {
|
||||
val turns = mutableListOf<Turn>()
|
||||
var currentPrompt: String? = null
|
||||
var currentUserNonSystemMessageCount = 0
|
||||
val currentEvents = mutableListOf<TurnEvent>()
|
||||
|
||||
fun flushTurn() {
|
||||
val prompt = currentPrompt?.trim().orEmpty()
|
||||
if (prompt.isBlank() || currentUserNonSystemMessageCount <= 0) {
|
||||
return
|
||||
}
|
||||
turns.add(
|
||||
Turn(
|
||||
prompt = prompt,
|
||||
userNonSystemMessageCount = currentUserNonSystemMessageCount,
|
||||
events = currentEvents.toList()
|
||||
)
|
||||
)
|
||||
currentPrompt = null
|
||||
currentUserNonSystemMessageCount = 0
|
||||
currentEvents.clear()
|
||||
}
|
||||
|
||||
history
|
||||
.filterNot { it is PromptMessage.System }
|
||||
.forEachIndexed { index, message ->
|
||||
val nonSystemMessageCount = index + 1
|
||||
when (message) {
|
||||
is PromptMessage.User -> {
|
||||
if (isSyntheticTimelineUserMessage(message)) {
|
||||
if (preserveSyntheticContinuation && currentPrompt != null) {
|
||||
// Keep collecting events for the currently active visible turn.
|
||||
// Synthetic todo prompts are injected internally mid-run.
|
||||
return@forEachIndexed
|
||||
}
|
||||
flushTurn()
|
||||
currentPrompt = null
|
||||
currentUserNonSystemMessageCount = 0
|
||||
return@forEachIndexed
|
||||
}
|
||||
|
||||
flushTurn()
|
||||
if (isHiddenUserMessage(message, projectInstructions)) {
|
||||
currentPrompt = null
|
||||
currentUserNonSystemMessageCount = 0
|
||||
return@forEachIndexed
|
||||
}
|
||||
currentPrompt = message.content.trim()
|
||||
currentUserNonSystemMessageCount = nonSystemMessageCount
|
||||
}
|
||||
|
||||
is PromptMessage.Assistant -> {
|
||||
if (currentPrompt != null) {
|
||||
currentEvents.add(
|
||||
TurnEvent.Assistant(
|
||||
content = message.content,
|
||||
nonSystemMessageCount = nonSystemMessageCount
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
is PromptMessage.Reasoning -> {
|
||||
if (currentPrompt != null) {
|
||||
currentEvents.add(
|
||||
TurnEvent.Reasoning(
|
||||
content = message.content,
|
||||
nonSystemMessageCount = nonSystemMessageCount
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
is PromptMessage.Tool.Call -> {
|
||||
if (currentPrompt != null && !isTodoWriteTool(message.tool)) {
|
||||
currentEvents.add(
|
||||
TurnEvent.ToolCall(
|
||||
id = message.id,
|
||||
tool = message.tool,
|
||||
content = message.content,
|
||||
nonSystemMessageCount = nonSystemMessageCount
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
is PromptMessage.Tool.Result -> {
|
||||
if (currentPrompt != null && !isTodoWriteTool(message.tool)) {
|
||||
currentEvents.add(
|
||||
TurnEvent.ToolResult(
|
||||
id = message.id,
|
||||
tool = message.tool,
|
||||
content = message.content,
|
||||
nonSystemMessageCount = nonSystemMessageCount
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
flushTurn()
|
||||
return turns
|
||||
}
|
||||
|
||||
fun isTodoWriteTool(toolName: String): Boolean {
|
||||
return toolName.equals("TodoWrite", ignoreCase = true)
|
||||
}
|
||||
|
||||
fun isSyntheticTimelineUserMessage(message: PromptMessage.User): Boolean {
|
||||
val normalized = message.content.lowercase()
|
||||
return normalized.contains("haven't created a todo list yet")
|
||||
}
|
||||
|
||||
private fun isHiddenUserMessage(
|
||||
message: PromptMessage.User,
|
||||
projectInstructions: String?
|
||||
): Boolean {
|
||||
return isCacheableInstructionMessage(message) ||
|
||||
isProjectInstructionsMessage(message.content, projectInstructions)
|
||||
}
|
||||
|
||||
private fun isCacheableInstructionMessage(message: PromptMessage.User): Boolean {
|
||||
val cacheable = message.metaInfo.metadata
|
||||
?.get("cacheable")
|
||||
?.jsonPrimitive
|
||||
?: return false
|
||||
return cacheable.booleanOrNull ?: (cacheable.contentOrNull?.equals(
|
||||
"true",
|
||||
ignoreCase = true
|
||||
) == true)
|
||||
}
|
||||
|
||||
private fun isProjectInstructionsMessage(text: String, projectInstructions: String?): Boolean {
|
||||
if (projectInstructions.isNullOrBlank()) {
|
||||
return false
|
||||
}
|
||||
return normalize(text) == normalize(projectInstructions)
|
||||
}
|
||||
|
||||
private fun normalize(value: String): String {
|
||||
return value.replace("\\s+".toRegex(), " ").trim()
|
||||
}
|
||||
}
|
||||
|
|
@ -13,9 +13,7 @@ import com.intellij.openapi.fileTypes.FileTypeManager
|
|||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.openapi.vfs.VfsUtilCore
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
|
|
@ -23,85 +21,100 @@ import java.nio.file.Paths
|
|||
import java.time.Instant
|
||||
import java.time.format.DateTimeFormatter
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
* Tracks file changes during agent runs and provides rollback to a checkpoint.
|
||||
*
|
||||
* Uses IntelliJ's LocalHistory labels for persistence and a lightweight in-memory
|
||||
* snapshot of modified files for fast restore.
|
||||
*
|
||||
* Now uses direct tracking via trackEdit() and trackWrite() methods instead of
|
||||
* VFS event monitoring to avoid capturing build artifacts and temporary files.
|
||||
*/
|
||||
@Service(Service.Level.PROJECT)
|
||||
class RollbackService(private val project: Project) {
|
||||
|
||||
private val activeRuns = ConcurrentHashMap<String, RunTracker>()
|
||||
private val snapshots = ConcurrentHashMap<String, SnapshotState>()
|
||||
private val activeRunsBySession = ConcurrentHashMap<String, String>()
|
||||
private val snapshotsByRunId = ConcurrentHashMap<String, SnapshotState>()
|
||||
private val latestSnapshotRunIdBySession = ConcurrentHashMap<String, String>()
|
||||
|
||||
@Volatile
|
||||
private var isApplyingRollback = false
|
||||
private val MAX_TRACKABLE_BYTES = 10 * 1024 * 1024
|
||||
|
||||
private val MAX_TRACKABLE_BYTES = 10 * 1024 * 1024 // 10MB
|
||||
|
||||
init {
|
||||
// No VFS listener - using direct tracking via trackEdit() and trackWrite()
|
||||
}
|
||||
|
||||
/**
|
||||
* Track an EditTool operation directly from the agent.
|
||||
* This captures the original file content before the edit is applied.
|
||||
*/
|
||||
fun trackEdit(
|
||||
sessionId: String,
|
||||
filePath: String,
|
||||
originalContent: String
|
||||
) {
|
||||
val tracker = activeRuns[sessionId] ?: return
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
tracker.recordExplicitEdit(normalizedPath, originalContent)
|
||||
val runId = activeRunsBySession[sessionId] ?: return
|
||||
trackEditForRun(runId, filePath, originalContent)
|
||||
}
|
||||
|
||||
/**
|
||||
* Track a WriteTool operation directly from the agent.
|
||||
* This determines if the file was new or existing before the write.
|
||||
*/
|
||||
fun trackWrite(sessionId: String, filePath: String, args: WriteTool.Args) {
|
||||
val tracker = activeRuns[sessionId] ?: return
|
||||
fun trackEditForRun(
|
||||
runId: String,
|
||||
filePath: String,
|
||||
originalContent: String
|
||||
) {
|
||||
val tracker = activeRuns[runId] ?: return
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
tracker.recordExplicitWrite(normalizedPath, args)
|
||||
tracker.recordExplicitEdit(normalizedPath, originalContent)
|
||||
}
|
||||
|
||||
fun startSession(sessionId: String) {
|
||||
fun trackWrite(sessionId: String, filePath: String) {
|
||||
val runId = activeRunsBySession[sessionId] ?: return
|
||||
trackWriteForRun(runId, filePath)
|
||||
}
|
||||
|
||||
fun trackWriteForRun(runId: String, filePath: String) {
|
||||
val tracker = activeRuns[runId] ?: return
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
tracker.recordExplicitWrite(normalizedPath)
|
||||
}
|
||||
|
||||
fun startSession(sessionId: String): String {
|
||||
val labelText = "ProxyAI: Agent run ${DateTimeFormatter.ISO_INSTANT.format(Instant.now())}"
|
||||
val label = LocalHistory.getInstance().putSystemLabel(project, labelText)
|
||||
activeRuns[sessionId] = RunTracker(sessionId, Instant.now(), labelText, label)
|
||||
snapshots.remove(sessionId)
|
||||
val runId = UUID.randomUUID().toString()
|
||||
val tracker = RunTracker(runId, sessionId, Instant.now(), labelText, label)
|
||||
activeRuns[runId] = tracker
|
||||
activeRunsBySession[sessionId] = runId
|
||||
latestSnapshotRunIdBySession.remove(sessionId)
|
||||
return runId
|
||||
}
|
||||
|
||||
fun finishSession(sessionId: String): RollbackSnapshot? {
|
||||
val tracker = activeRuns.remove(sessionId) ?: return getSnapshot(sessionId)
|
||||
val runId = activeRunsBySession.remove(sessionId) ?: return getSnapshot(sessionId)
|
||||
return finishRun(runId)
|
||||
}
|
||||
|
||||
fun finishRun(runId: String): RollbackSnapshot? {
|
||||
val tracker = activeRuns.remove(runId) ?: return getRunSnapshot(runId)
|
||||
val snapshot = SnapshotState(
|
||||
sessionId = sessionId,
|
||||
runId = runId,
|
||||
sessionId = tracker.sessionId,
|
||||
label = tracker.label,
|
||||
labelRef = tracker.labelRef,
|
||||
startedAt = tracker.startedAt,
|
||||
completedAt = Instant.now(),
|
||||
changes = tracker.changes.toMap()
|
||||
)
|
||||
snapshots[sessionId] = snapshot
|
||||
snapshotsByRunId[runId] = snapshot
|
||||
latestSnapshotRunIdBySession[tracker.sessionId] = runId
|
||||
return snapshot.toSnapshot()
|
||||
}
|
||||
|
||||
fun getSnapshot(sessionId: String): RollbackSnapshot? =
|
||||
snapshots[sessionId]?.toSnapshot()
|
||||
fun getSnapshot(sessionId: String): RollbackSnapshot? {
|
||||
val runId = latestSnapshotRunIdBySession[sessionId] ?: return null
|
||||
return snapshotsByRunId[runId]?.toSnapshot()
|
||||
}
|
||||
|
||||
fun getRunSnapshot(runId: String): RollbackSnapshot? =
|
||||
snapshotsByRunId[runId]?.toSnapshot()
|
||||
|
||||
fun clearSnapshot(sessionId: String) {
|
||||
snapshots.remove(sessionId)
|
||||
val runId = latestSnapshotRunIdBySession.remove(sessionId) ?: return
|
||||
snapshotsByRunId.remove(runId)
|
||||
}
|
||||
|
||||
fun clearRunSnapshot(runId: String) {
|
||||
val snapshot = snapshotsByRunId.remove(runId) ?: return
|
||||
latestSnapshotRunIdBySession.remove(snapshot.sessionId, runId)
|
||||
}
|
||||
|
||||
fun getDiffData(sessionId: String, path: String): RollbackDiffData? {
|
||||
val snapshot = snapshots[sessionId] ?: return null
|
||||
val snapshot = snapshotForSession(sessionId) ?: return null
|
||||
val change = snapshot.changes[path] ?: return null
|
||||
if (change.kind != ChangeKind.DELETED && !isTrackable(path)) return null
|
||||
val beforeText = when (change.kind) {
|
||||
|
|
@ -131,28 +144,19 @@ class RollbackService(private val project: Project) {
|
|||
}
|
||||
|
||||
fun isRollbackAvailable(sessionId: String): Boolean =
|
||||
snapshots[sessionId]?.changes?.isNotEmpty() == true && !activeRuns.containsKey(sessionId)
|
||||
snapshotForSession(sessionId)?.changes?.isNotEmpty() == true &&
|
||||
!activeRunsBySession.containsKey(sessionId)
|
||||
|
||||
fun isDisplayable(path: String): Boolean = isTrackable(path)
|
||||
|
||||
suspend fun rollbackFile(sessionId: String, path: String): RollbackResult =
|
||||
withContext(Dispatchers.IO) {
|
||||
val snapshot = snapshots[sessionId]
|
||||
val snapshot = snapshotForSession(sessionId)
|
||||
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
|
||||
val change = snapshot.changes[path]
|
||||
?: return@withContext RollbackResult.Failure("No change tracked for $path")
|
||||
|
||||
val errors = mutableListOf<String>()
|
||||
runInEdt(ModalityState.defaultModalityState()) {
|
||||
runWriteAction {
|
||||
try {
|
||||
isApplyingRollback = true
|
||||
applyChangeWithLabel(snapshot.labelRef, path, change, errors)
|
||||
} finally {
|
||||
isApplyingRollback = false
|
||||
}
|
||||
}
|
||||
}
|
||||
val errors = applyChanges(snapshot.labelRef, mapOf(path to change))
|
||||
|
||||
if (errors.isNotEmpty()) {
|
||||
RollbackResult.Failure(errors.joinToString("\n"))
|
||||
|
|
@ -160,36 +164,38 @@ class RollbackService(private val project: Project) {
|
|||
val updated = snapshot.changes.toMutableMap()
|
||||
updated.remove(path)
|
||||
if (updated.isEmpty()) {
|
||||
snapshots.remove(sessionId)
|
||||
clearRunSnapshot(snapshot.runId)
|
||||
} else {
|
||||
snapshots[sessionId] = snapshot.copy(changes = updated.toMap())
|
||||
snapshotsByRunId[snapshot.runId] = snapshot.copy(changes = updated.toMap())
|
||||
}
|
||||
RollbackResult.Success("Rollback completed")
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.Main) {
|
||||
val snapshot = snapshots[sessionId]
|
||||
suspend fun rollbackSession(sessionId: String): RollbackResult = withContext(Dispatchers.IO) {
|
||||
val snapshot = snapshotForSession(sessionId)
|
||||
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
|
||||
|
||||
val errors = mutableListOf<String>()
|
||||
runInEdt(ModalityState.defaultModalityState()) {
|
||||
runWriteAction {
|
||||
try {
|
||||
isApplyingRollback = true
|
||||
snapshot.changes.forEach { (path, change) ->
|
||||
applyChangeWithLabel(snapshot.labelRef, path, change, errors)
|
||||
}
|
||||
} finally {
|
||||
isApplyingRollback = false
|
||||
}
|
||||
}
|
||||
}
|
||||
val errors = applyChanges(snapshot.labelRef, snapshot.changes)
|
||||
|
||||
if (errors.isNotEmpty()) {
|
||||
RollbackResult.Failure(errors.joinToString("\n"))
|
||||
} else {
|
||||
snapshots.remove(sessionId)
|
||||
clearRunSnapshot(snapshot.runId)
|
||||
RollbackResult.Success("Rollback completed")
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun rollbackRun(runId: String): RollbackResult = withContext(Dispatchers.IO) {
|
||||
val snapshot = snapshotsByRunId[runId]
|
||||
?: return@withContext RollbackResult.Failure("No rollback snapshot available")
|
||||
|
||||
val errors = applyChanges(snapshot.labelRef, snapshot.changes)
|
||||
|
||||
if (errors.isNotEmpty()) {
|
||||
RollbackResult.Failure(errors.joinToString("\n"))
|
||||
} else {
|
||||
clearRunSnapshot(runId)
|
||||
RollbackResult.Success("Rollback completed")
|
||||
}
|
||||
}
|
||||
|
|
@ -269,8 +275,8 @@ class RollbackService(private val project: Project) {
|
|||
errors.add("Missing parent directory for $path")
|
||||
return@runCatching null
|
||||
}
|
||||
val parentVf = VirtualFileManager.getInstance()
|
||||
.findFileByUrl("file://$parentPath") ?: return@runCatching null
|
||||
val parentVf = LocalFileSystem.getInstance()
|
||||
.refreshAndFindFileByPath(parentPath) ?: return@runCatching null
|
||||
parentVf.createChildData(this, ioFile.name)
|
||||
}.getOrNull()
|
||||
|
||||
|
|
@ -345,7 +351,6 @@ class RollbackService(private val project: Project) {
|
|||
): ByteArray? {
|
||||
val labelContent = runCatching { label.getByteContent(path) }
|
||||
.getOrNull()?.bytes
|
||||
// If label content is null or empty, use fallback (captured at track time)
|
||||
return if (labelContent == null || labelContent.isEmpty()) {
|
||||
fallback
|
||||
} else {
|
||||
|
|
@ -353,6 +358,26 @@ class RollbackService(private val project: Project) {
|
|||
}
|
||||
}
|
||||
|
||||
private fun applyChanges(
|
||||
label: Label,
|
||||
changes: Map<String, TrackedChange>
|
||||
): List<String> {
|
||||
val errors = mutableListOf<String>()
|
||||
runInEdt(ModalityState.defaultModalityState()) {
|
||||
runWriteAction {
|
||||
changes.forEach { (path, change) ->
|
||||
applyChangeWithLabel(label, path, change, errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
return errors
|
||||
}
|
||||
|
||||
private fun snapshotForSession(sessionId: String): SnapshotState? {
|
||||
val runId = latestSnapshotRunIdBySession[sessionId] ?: return null
|
||||
return snapshotsByRunId[runId]
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun getInstance(project: Project): RollbackService {
|
||||
return project.getService(RollbackService::class.java)
|
||||
|
|
@ -360,6 +385,7 @@ class RollbackService(private val project: Project) {
|
|||
}
|
||||
|
||||
private class RunTracker(
|
||||
val runId: String,
|
||||
val sessionId: String,
|
||||
val startedAt: Instant,
|
||||
val label: String,
|
||||
|
|
@ -378,24 +404,23 @@ class RollbackService(private val project: Project) {
|
|||
)
|
||||
}
|
||||
|
||||
fun recordExplicitWrite(filePath: String, args: WriteTool.Args) {
|
||||
val path = filePath.replace("\\", "/")
|
||||
val existing = changes[path]
|
||||
val file = File(path)
|
||||
|
||||
fun recordExplicitWrite(filePath: String) {
|
||||
val existing = changes[filePath]
|
||||
val file = File(filePath)
|
||||
|
||||
if (existing?.kind == ChangeKind.DELETED) {
|
||||
changes[path] = existing.copy(kind = ChangeKind.MODIFIED)
|
||||
changes[filePath] = existing.copy(kind = ChangeKind.MODIFIED)
|
||||
return
|
||||
}
|
||||
|
||||
if (existing == null) {
|
||||
val originalContent = if (file.exists()) {
|
||||
runCatching { file.readText() }.getOrNull()
|
||||
runCatching { file.readText(Charsets.UTF_8) }.getOrNull()
|
||||
} else null
|
||||
changes[path] = TrackedChange(
|
||||
changes[filePath] = TrackedChange(
|
||||
kind = if (file.exists()) ChangeKind.MODIFIED else ChangeKind.ADDED,
|
||||
originalPath = null,
|
||||
originalContent = originalContent?.toByteArray()
|
||||
originalContent = originalContent?.toByteArray(Charsets.UTF_8)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -428,6 +453,7 @@ class RollbackService(private val project: Project) {
|
|||
}
|
||||
|
||||
private data class SnapshotState(
|
||||
val runId: String,
|
||||
val sessionId: String,
|
||||
val label: String,
|
||||
val labelRef: Label,
|
||||
|
|
@ -438,6 +464,7 @@ class RollbackService(private val project: Project) {
|
|||
fun toSnapshot(): RollbackSnapshot? {
|
||||
if (changes.isEmpty()) return null
|
||||
return RollbackSnapshot(
|
||||
runId = runId,
|
||||
sessionId = sessionId,
|
||||
label = label,
|
||||
startedAt = startedAt,
|
||||
|
|
@ -455,6 +482,7 @@ class RollbackService(private val project: Project) {
|
|||
}
|
||||
|
||||
data class RollbackSnapshot(
|
||||
val runId: String,
|
||||
val sessionId: String,
|
||||
val label: String,
|
||||
val startedAt: Instant,
|
||||
|
|
|
|||
|
|
@ -84,12 +84,12 @@ class BashTool(
|
|||
- If the commands depend on each other and must run sequentially, use a single Bash call with '&&' to chain them together (e.g., `git add . && git commit -m "message" && git push`). For instance, if one operation must complete before another starts (like mkdir before cp, Write before Bash for git operations, or git add before git commit), run these operations sequentially instead.
|
||||
- Use ';' only when you need to run commands sequentially but don't care if earlier commands fail
|
||||
- DO NOT use newlines to separate commands (newlines are ok in quoted strings)
|
||||
- Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of `cd`. You may use `cd` if the User explicitly requests it.
|
||||
- Commands run with the working directory already set to the project root. Prefer relative project paths (e.g., `src/...` or `./src/...`) instead of root-absolute paths like `/src` or hardcoded machine-specific paths like `/Users/.../project/...`. You may use `cd` only if the user explicitly requests it.
|
||||
<good-example>
|
||||
pytest /foo/bar/tests
|
||||
find src -type f -name "*.kt"
|
||||
</good-example>
|
||||
<bad-example>
|
||||
cd /foo/bar && pytest tests
|
||||
find /src -type f -name "*.kt"
|
||||
</bad-example>
|
||||
|
||||
# Committing changes with git
|
||||
|
|
@ -261,11 +261,10 @@ class BashTool(
|
|||
)
|
||||
}
|
||||
|
||||
val isAskPattern = shouldAskForConfirmation(args.command)
|
||||
val resolvedToolId = toolId ?: throw IllegalArgumentException("Tool ID is missing")
|
||||
|
||||
val confirmation = when {
|
||||
isWhiteListed(args) && !isAskPattern -> ShellCommandConfirmation.Approved
|
||||
isWhiteListed(args) -> ShellCommandConfirmation.Approved
|
||||
else -> confirmationHandler.requestConfirmation(args)
|
||||
}
|
||||
return when (confirmation) {
|
||||
|
|
@ -377,7 +376,6 @@ class BashTool(
|
|||
}
|
||||
}
|
||||
} catch (_: IOException) {
|
||||
// Ignore IO exception if the stream is closed
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -391,7 +389,6 @@ class BashTool(
|
|||
}
|
||||
}
|
||||
} catch (_: IOException) {
|
||||
// Ignore IO exception if the stream is closed
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -463,8 +460,7 @@ class BashTool(
|
|||
}
|
||||
when (event) {
|
||||
WaitEvent.EXIT -> done = true
|
||||
WaitEvent.ACTIVITY -> { /* reset idle timer */
|
||||
}
|
||||
WaitEvent.ACTIVITY -> Unit
|
||||
|
||||
null -> {
|
||||
timedOut.set(true)
|
||||
|
|
@ -541,28 +537,6 @@ class BashTool(
|
|||
}
|
||||
}
|
||||
|
||||
private fun shouldAskForConfirmation(command: String): Boolean {
|
||||
val askPatterns = listOf(
|
||||
"find * -delete*",
|
||||
"find * -exec*",
|
||||
"find * -fprint*",
|
||||
"find * -fls*",
|
||||
"find * -fprintf*",
|
||||
"find * -ok*",
|
||||
"sort --output=*",
|
||||
"sort -o *",
|
||||
"tree -o *"
|
||||
)
|
||||
|
||||
return askPatterns.any { pattern ->
|
||||
if (pattern.contains("*")) {
|
||||
command.startsWith(pattern.removeSuffix("*"))
|
||||
} else {
|
||||
command == pattern
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun shouldBlockByIgnore(command: String): Boolean {
|
||||
val readers = setOf(
|
||||
"cat",
|
||||
|
|
@ -607,8 +581,6 @@ class BashTool(
|
|||
lastWasReader = tokenIsReader
|
||||
}
|
||||
val settingsService = project.service<ProxyAISettingsService>()
|
||||
// TODO(PROXYAI-IGNORE): Replace deny-style bash path checks with visibility filtering.
|
||||
// Bash output and directory listings should hide ignored paths instead of returning policy-denied.
|
||||
return paths.any { candidate ->
|
||||
settingsService.isPathIgnored(toAbsolute(candidate))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,11 +8,15 @@ import com.intellij.openapi.components.service
|
|||
import com.intellij.openapi.fileEditor.FileDocumentManager
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.text.StringUtil
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import com.intellij.openapi.vfs.VfsUtil
|
||||
import com.intellij.psi.PsiDocumentManager
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookManager
|
||||
import ee.carlrobert.codegpt.tokens.truncateToolResult
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
import java.io.File
|
||||
|
|
@ -122,43 +126,59 @@ class WriteTool(
|
|||
)
|
||||
}
|
||||
|
||||
val fileUrl = file.toURI().toString()
|
||||
val virtualFile =
|
||||
VirtualFileManager.getInstance().findFileByUrl(fileUrl)
|
||||
?: if (isNewFile) {
|
||||
VirtualFileManager.getInstance().refreshAndFindFileByUrl(fileUrl)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
||||
val bytesWritten = args.content.toByteArray(StandardCharsets.UTF_8).size
|
||||
val projectToUse = project
|
||||
|
||||
val normalizedPath = filePath.replace("\\", "/")
|
||||
val fileUrl = "file://$normalizedPath"
|
||||
|
||||
// Never do filesystem IO or synchronous VFS refreshes on the EDT.
|
||||
// If we can resolve a VirtualFile+Document, use the Document API for proper IDE integration.
|
||||
val vfm = VirtualFileManager.getInstance()
|
||||
val virtualFile = if (!isNewFile) {
|
||||
vfm.findFileByUrl(fileUrl) ?: vfm.refreshAndFindFileByUrl(fileUrl)
|
||||
} else {
|
||||
vfm.findFileByUrl(fileUrl)
|
||||
}
|
||||
|
||||
if (virtualFile != null) {
|
||||
val document =
|
||||
val document = withContext(Dispatchers.Default) {
|
||||
runReadAction { FileDocumentManager.getInstance().getDocument(virtualFile) }
|
||||
runInEdt {
|
||||
if (document != null) {
|
||||
}
|
||||
if (document != null) {
|
||||
runInEdt {
|
||||
runWriteAction {
|
||||
PsiDocumentManager.getInstance(projectToUse).commitDocument(document)
|
||||
document.setText(StringUtil.convertLineSeparators(args.content))
|
||||
PsiDocumentManager.getInstance(projectToUse).commitDocument(document)
|
||||
FileDocumentManager.getInstance().saveDocument(document)
|
||||
}
|
||||
} else {
|
||||
runWriteAction {
|
||||
virtualFile.setBinaryContent(args.content.toByteArray(StandardCharsets.UTF_8))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-document-backed files, fall back to plain IO and refresh the specific VirtualFile.
|
||||
withContext(Dispatchers.IO) {
|
||||
file.writeBytes(args.content.toByteArray(StandardCharsets.UTF_8))
|
||||
}
|
||||
VfsUtil.markDirtyAndRefresh(true, false, false, virtualFile)
|
||||
}
|
||||
} else if (isNewFile) {
|
||||
runInEdt {
|
||||
runWriteAction {
|
||||
file.writeText(
|
||||
StringUtil.convertLineSeparators(args.content),
|
||||
StandardCharsets.UTF_8
|
||||
)
|
||||
VirtualFileManager.getInstance().syncRefresh()
|
||||
}
|
||||
} else {
|
||||
// Fallback for files not present in VFS (e.g., newly created or external paths).
|
||||
withContext(Dispatchers.IO) {
|
||||
file.writeText(
|
||||
StringUtil.convertLineSeparators(args.content),
|
||||
StandardCharsets.UTF_8
|
||||
)
|
||||
}
|
||||
|
||||
val lfs = LocalFileSystem.getInstance()
|
||||
val parentVf = file.parentFile?.let { parent ->
|
||||
lfs.findFileByIoFile(parent) ?: lfs.refreshAndFindFileByIoFile(parent)
|
||||
}
|
||||
if (parentVf != null) {
|
||||
// Reload the directory children so the newly created file shows up.
|
||||
VfsUtil.markDirtyAndRefresh(true, false, true, parentVf)
|
||||
} else {
|
||||
// As a last resort, refresh just the file path.
|
||||
lfs.refreshAndFindFileByIoFile(file)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import com.intellij.diff.requests.SimpleDiffRequest
|
|||
import com.intellij.diff.util.DiffUserDataKeys
|
||||
import com.intellij.diff.util.DiffUserDataKeysEx
|
||||
import com.intellij.icons.AllIcons
|
||||
import com.intellij.openapi.Disposable
|
||||
import com.intellij.openapi.actionSystem.AnAction
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
import com.intellij.openapi.actionSystem.Presentation
|
||||
|
|
@ -34,7 +35,9 @@ import ee.carlrobert.codegpt.agent.tools.EditTool
|
|||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.applyStringReplacement
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.getFileContentWithFallback
|
||||
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import java.awt.BorderLayout
|
||||
import java.awt.FlowLayout
|
||||
import java.awt.Insets
|
||||
|
|
@ -54,8 +57,9 @@ import javax.swing.JPanel
|
|||
*/
|
||||
class AgentApprovalManager(
|
||||
private val project: Project
|
||||
) {
|
||||
) : Disposable {
|
||||
private val approvalPopups = ConcurrentHashMap<String, JBPopup>()
|
||||
private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO)
|
||||
|
||||
companion object {
|
||||
private val AGENT_DIFF_REQUEST_KEY: Key<String> = Key.create("agent.approval.diffRequest")
|
||||
|
|
@ -96,7 +100,7 @@ class AgentApprovalManager(
|
|||
args.filePath
|
||||
}
|
||||
|
||||
ApplicationManager.getApplication().executeOnPooledThread {
|
||||
backgroundScope.launch {
|
||||
val vf = LocalFileSystem.getInstance().refreshAndFindFileByPath(path)
|
||||
val factory = DiffContentFactory.getInstance()
|
||||
|
||||
|
|
@ -340,4 +344,11 @@ class AgentApprovalManager(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun dispose() {
|
||||
backgroundScope.dispose()
|
||||
runInEdt {
|
||||
approvalPopups.keys.toList().forEach(::closeDiffById)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import com.intellij.openapi.project.Project
|
|||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import ee.carlrobert.codegpt.CodeGPTBundle
|
||||
import ee.carlrobert.codegpt.agent.*
|
||||
import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
||||
import ee.carlrobert.codegpt.agent.rollback.RollbackService
|
||||
import ee.carlrobert.codegpt.agent.tools.*
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
|
|
@ -24,6 +25,7 @@ import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.*
|
|||
import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatMessageResponseBody
|
||||
import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatToolWindowScrollablePanel
|
||||
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
|
||||
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
|
||||
import kotlinx.coroutines.*
|
||||
import java.awt.Component
|
||||
import java.util.*
|
||||
|
|
@ -40,6 +42,8 @@ class AgentEventHandler(
|
|||
private val userInputPanel: UserInputPanel,
|
||||
private val onShowLoading: (String) -> Unit,
|
||||
private val onHideLoading: () -> Unit,
|
||||
private val onRunFinishedCallback: () -> Unit = {},
|
||||
private val onRunCheckpointUpdatedCallback: (UUID, CheckpointRef?) -> Unit = { _, _ -> },
|
||||
private val onQueuedMessagesResolved: (MessageWithContext) -> Unit = {}
|
||||
) : AgentEvents, Disposable {
|
||||
|
||||
|
|
@ -67,6 +71,9 @@ class AgentEventHandler(
|
|||
@Volatile
|
||||
private var lastEditArgs: EditTool.Args? = null
|
||||
|
||||
@Volatile
|
||||
private var currentRollbackRunId: String? = null
|
||||
|
||||
private val approvalQueue: ArrayDeque<ApprovalRequest> = ArrayDeque()
|
||||
|
||||
@Volatile
|
||||
|
|
@ -84,7 +91,7 @@ class AgentEventHandler(
|
|||
private var runViewHolder: RunViewHolder? = null
|
||||
|
||||
private val subagentViewHolders = ConcurrentHashMap<String, RunViewHolder>()
|
||||
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||
private val serviceScope = DisposableCoroutineScope(Dispatchers.Default)
|
||||
|
||||
data class ApprovalRequest(
|
||||
var model: ToolApprovalRequest,
|
||||
|
|
@ -105,6 +112,7 @@ class AgentEventHandler(
|
|||
runViewHolder = null
|
||||
subagentViewHolders.clear()
|
||||
lastReportedPromptTokens = 0
|
||||
currentRollbackRunId = null
|
||||
}
|
||||
|
||||
private fun clearApprovalContainer() {
|
||||
|
|
@ -162,6 +170,10 @@ class AgentEventHandler(
|
|||
currentResponseBody = responseBody
|
||||
}
|
||||
|
||||
fun setCurrentRollbackRunId(runId: String?) {
|
||||
currentRollbackRunId = runId
|
||||
}
|
||||
|
||||
override fun onAgentCompleted(agentId: String) {
|
||||
runCatching {
|
||||
project.service<AgentToolWindowContentManager>().getTabbedPane()
|
||||
|
|
@ -226,8 +238,12 @@ class AgentEventHandler(
|
|||
|
||||
private fun handleDone() {
|
||||
runInEdt {
|
||||
project.service<RollbackService>().finishSession(sessionId)
|
||||
currentRollbackRunId?.let { runId ->
|
||||
project.service<RollbackService>().finishRun(runId)
|
||||
} ?: project.service<RollbackService>().finishSession(sessionId)
|
||||
currentRollbackRunId = null
|
||||
currentResponseBody?.finishThinking()
|
||||
onRunFinishedCallback()
|
||||
onHideLoading()
|
||||
userInputPanel.setStopEnabled(false)
|
||||
scrollablePanel.update()
|
||||
|
|
@ -483,6 +499,12 @@ class AgentEventHandler(
|
|||
onQueuedMessagesResolved(pendingMessage)
|
||||
}
|
||||
|
||||
override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {
|
||||
runInEdt {
|
||||
onRunCheckpointUpdatedCallback(runMessageId, ref)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onTokenUsageAvailable(tokenUsage: Long) {
|
||||
lastReportedPromptTokens = tokenUsage
|
||||
|
||||
|
|
@ -739,6 +761,8 @@ class AgentEventHandler(
|
|||
}
|
||||
|
||||
override fun dispose() {
|
||||
serviceScope.dispose()
|
||||
agentApprovalManager.dispose()
|
||||
mainToolCards.clear()
|
||||
approvalQueue.clear()
|
||||
subagentViewHolders.clear()
|
||||
|
|
@ -752,16 +776,26 @@ class AgentEventHandler(
|
|||
val documentText = vf?.let { file ->
|
||||
runReadAction { FileDocumentManager.getInstance().getDocument(file)?.text }
|
||||
}
|
||||
documentText ?: java.io.File(normalizedPath).readText()
|
||||
documentText ?: java.io.File(normalizedPath).readText(Charsets.UTF_8)
|
||||
}.getOrNull() ?: ""
|
||||
project.service<RollbackService>()
|
||||
.trackEdit(sessionId, normalizedPath, originalContent)
|
||||
val rollbackService = project.service<RollbackService>()
|
||||
val runId = currentRollbackRunId
|
||||
if (runId != null) {
|
||||
rollbackService.trackEditForRun(runId, normalizedPath, originalContent)
|
||||
} else {
|
||||
rollbackService.trackEdit(sessionId, normalizedPath, originalContent)
|
||||
}
|
||||
}
|
||||
|
||||
private fun trackWriteOperation(args: WriteTool.Args) {
|
||||
lastWriteArgs = args
|
||||
val normalizedPath = args.filePath.replace("\\", "/")
|
||||
project.service<RollbackService>()
|
||||
.trackWrite(sessionId, normalizedPath, args)
|
||||
val rollbackService = project.service<RollbackService>()
|
||||
val runId = currentRollbackRunId
|
||||
if (runId != null) {
|
||||
rollbackService.trackWriteForRun(runId, normalizedPath)
|
||||
} else {
|
||||
rollbackService.trackWrite(sessionId, normalizedPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
internal object AgentMessageText {
|
||||
fun abbreviate(text: String, maxLength: Int): String {
|
||||
val normalized = text.replace("\\s+".toRegex(), " ").trim()
|
||||
if (normalized.length <= maxLength) return normalized
|
||||
return normalized.take(maxLength - 1).trimEnd() + "…"
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,479 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import com.intellij.icons.AllIcons
|
||||
import com.intellij.openapi.actionSystem.ActionManager
|
||||
import com.intellij.openapi.actionSystem.ActionUpdateThread
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
import com.intellij.openapi.actionSystem.DefaultActionGroup
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
import com.intellij.openapi.project.DumbAwareAction
|
||||
import com.intellij.ui.CheckboxTree
|
||||
import com.intellij.ui.CheckedTreeNode
|
||||
import com.intellij.ui.SimpleTextAttributes
|
||||
import com.intellij.ui.components.JBLabel
|
||||
import com.intellij.ui.components.JBScrollPane
|
||||
import com.intellij.util.ui.JBUI
|
||||
import java.awt.BorderLayout
|
||||
import java.awt.Dimension
|
||||
import java.awt.event.ActionEvent
|
||||
import java.awt.event.KeyAdapter
|
||||
import java.awt.event.KeyEvent
|
||||
import java.awt.event.MouseAdapter
|
||||
import java.awt.event.MouseEvent
|
||||
import javax.swing.AbstractAction
|
||||
import javax.swing.KeyStroke
|
||||
import javax.swing.JPanel
|
||||
import javax.swing.JTree
|
||||
import javax.swing.ScrollPaneConstants
|
||||
import javax.swing.SwingUtilities
|
||||
import javax.swing.tree.DefaultMutableTreeNode
|
||||
import javax.swing.tree.DefaultTreeModel
|
||||
|
||||
internal object AgentSessionTimelineDialogBuilder {
|
||||
|
||||
fun build(
|
||||
points: List<RunTimelinePoint>,
|
||||
onSelect: (RunTimelinePoint) -> Unit,
|
||||
reloadPoints: (suspend () -> List<RunTimelinePoint>)?,
|
||||
contextSelectionEnabled: Boolean,
|
||||
onNonCopyMenuAction: (() -> Unit)?,
|
||||
onSelectionStateChanged: (() -> Unit)?,
|
||||
timelinePointDisplayLabel: (RunTimelinePoint) -> String,
|
||||
canRollbackTimelinePoint: (RunTimelinePoint) -> Boolean,
|
||||
rollbackToTimelinePoint: (RunTimelinePoint, String) -> Unit,
|
||||
canCopyTimelinePointOutput: (RunTimelinePoint) -> Boolean,
|
||||
copyTimelinePointOutput: (RunTimelinePoint, String) -> Unit,
|
||||
launchBackground: (suspend () -> Unit) -> Unit
|
||||
): TimelineDialogUi {
|
||||
fun runNodeKey(runNumber: Int): String = "run:$runNumber"
|
||||
fun checkpointNodeKey(point: RunTimelinePoint): String = "cp:${point.cacheKey}"
|
||||
|
||||
fun buildTreeRoot(
|
||||
dataPoints: List<RunTimelinePoint>,
|
||||
checkedNodeState: Map<String, Boolean>,
|
||||
withCheckboxes: Boolean
|
||||
): CheckedTreeNode {
|
||||
val groups = buildTimelineRunGroups(dataPoints)
|
||||
val root = CheckedTreeNode(TimelineTreeEntry("Session"))
|
||||
groups.forEach { group ->
|
||||
val toolCallSuffix = if (group.toolCallCount == 1) "tool call" else "tool calls"
|
||||
val runEntry = TimelineTreeEntry(
|
||||
title = "Run ${group.runNumber}",
|
||||
subtitle = "${group.toolCallCount} $toolCallSuffix",
|
||||
icon = AllIcons.Vcs.History,
|
||||
runNumber = group.runNumber
|
||||
)
|
||||
val runNode: DefaultMutableTreeNode = if (withCheckboxes) {
|
||||
CheckedTreeNode(runEntry).apply {
|
||||
isChecked = checkedNodeState[runNodeKey(group.runNumber)] ?: true
|
||||
}
|
||||
} else {
|
||||
DefaultMutableTreeNode(runEntry)
|
||||
}
|
||||
group.points.forEach { point ->
|
||||
val pointEntry = TimelineTreeEntry(
|
||||
title = point.title,
|
||||
subtitle = point.subtitle.trim(),
|
||||
icon = point.icon,
|
||||
point = point
|
||||
)
|
||||
val pointNode: DefaultMutableTreeNode = if (withCheckboxes) {
|
||||
val parentChecked = (runNode as? CheckedTreeNode)?.isChecked ?: true
|
||||
CheckedTreeNode(pointEntry).apply {
|
||||
isChecked = checkedNodeState[checkpointNodeKey(point)] ?: parentChecked
|
||||
}
|
||||
} else {
|
||||
DefaultMutableTreeNode(pointEntry)
|
||||
}
|
||||
runNode.add(pointNode)
|
||||
}
|
||||
root.add(runNode)
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
fun collectCheckedNonSystemMessageCounts(root: DefaultMutableTreeNode): Set<Int> {
|
||||
val checkedNonSystemMessageCounts = mutableSetOf<Int>()
|
||||
for (index in 0 until root.childCount) {
|
||||
val runNode = root.getChildAt(index) as? DefaultMutableTreeNode ?: continue
|
||||
for (childIndex in 0 until runNode.childCount) {
|
||||
val childNode =
|
||||
runNode.getChildAt(childIndex) as? DefaultMutableTreeNode ?: continue
|
||||
val childEntry = childNode.userObject as? TimelineTreeEntry ?: continue
|
||||
val point = childEntry.point ?: continue
|
||||
val nonSystemMessageCount = point.nonSystemMessageCount ?: continue
|
||||
if (childNode is CheckedTreeNode && childNode.isChecked) {
|
||||
checkedNonSystemMessageCounts.add(nonSystemMessageCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
return checkedNonSystemMessageCounts
|
||||
}
|
||||
|
||||
fun captureCheckedRunState(root: DefaultMutableTreeNode): Map<String, Boolean> {
|
||||
val state = mutableMapOf<String, Boolean>()
|
||||
for (index in 0 until root.childCount) {
|
||||
val runNode = root.getChildAt(index) as? DefaultMutableTreeNode ?: continue
|
||||
val runEntry = runNode.userObject as? TimelineTreeEntry
|
||||
val runNumber = runEntry?.runNumber
|
||||
if (runNode is CheckedTreeNode && runNumber != null) {
|
||||
state[runNodeKey(runNumber)] = runNode.isChecked
|
||||
}
|
||||
|
||||
for (childIndex in 0 until runNode.childCount) {
|
||||
val childNode =
|
||||
runNode.getChildAt(childIndex) as? DefaultMutableTreeNode ?: continue
|
||||
val childEntry = childNode.userObject as? TimelineTreeEntry ?: continue
|
||||
val point = childEntry.point ?: continue
|
||||
if (childNode is CheckedTreeNode) {
|
||||
state[checkpointNodeKey(point)] = childNode.isChecked
|
||||
}
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
fun ensureCheckedStateDefaults(
|
||||
dataPoints: List<RunTimelinePoint>,
|
||||
checkedNodeState: MutableMap<String, Boolean>
|
||||
) {
|
||||
if (!contextSelectionEnabled) return
|
||||
buildTimelineRunGroups(dataPoints).forEach { group ->
|
||||
checkedNodeState.putIfAbsent(runNodeKey(group.runNumber), true)
|
||||
}
|
||||
dataPoints.forEach { point ->
|
||||
checkedNodeState.putIfAbsent(checkpointNodeKey(point), true)
|
||||
}
|
||||
}
|
||||
|
||||
fun expandAllRows(tree: JTree) {
|
||||
var row = 0
|
||||
while (row < tree.rowCount) {
|
||||
tree.expandRow(row)
|
||||
row++
|
||||
}
|
||||
}
|
||||
|
||||
var latestPoints = points
|
||||
var editMode = false
|
||||
val groups = buildTimelineRunGroups(latestPoints)
|
||||
val checkedNodeState = if (contextSelectionEnabled) {
|
||||
buildMap {
|
||||
groups.forEach { group -> put(runNodeKey(group.runNumber), true) }
|
||||
latestPoints.forEach { point -> put(checkpointNodeKey(point), true) }
|
||||
}.toMutableMap()
|
||||
} else {
|
||||
mutableMapOf()
|
||||
}
|
||||
ensureCheckedStateDefaults(latestPoints, checkedNodeState)
|
||||
|
||||
fun headerText(): String {
|
||||
return when {
|
||||
!contextSelectionEnabled ->
|
||||
"Choose a point in this session timeline."
|
||||
|
||||
editMode ->
|
||||
"Right-click row for rollback/new/copy. Check runs or checkpoints to keep in context."
|
||||
|
||||
else ->
|
||||
"Right-click row for rollback/new/copy. Click Edit to modify current context."
|
||||
}
|
||||
}
|
||||
|
||||
val header = JBLabel(headerText()).apply {
|
||||
foreground = JBUI.CurrentTheme.Label.disabledForeground()
|
||||
font = JBUI.Fonts.smallFont()
|
||||
border = JBUI.Borders.emptyBottom(5)
|
||||
}
|
||||
|
||||
lateinit var tree: CheckboxTree
|
||||
|
||||
fun captureCurrentCheckedState() {
|
||||
if (!contextSelectionEnabled || !editMode) return
|
||||
val currentRoot = tree.model.root as? DefaultMutableTreeNode ?: return
|
||||
checkedNodeState.clear()
|
||||
checkedNodeState.putAll(captureCheckedRunState(currentRoot))
|
||||
ensureCheckedStateDefaults(latestPoints, checkedNodeState)
|
||||
}
|
||||
|
||||
fun rebuildTree(
|
||||
dataPoints: List<RunTimelinePoint>,
|
||||
captureCurrentSelectionState: Boolean
|
||||
) {
|
||||
if (captureCurrentSelectionState) {
|
||||
captureCurrentCheckedState()
|
||||
}
|
||||
latestPoints = dataPoints
|
||||
ensureCheckedStateDefaults(latestPoints, checkedNodeState)
|
||||
tree.model = DefaultTreeModel(
|
||||
buildTreeRoot(
|
||||
dataPoints = latestPoints,
|
||||
checkedNodeState = checkedNodeState,
|
||||
withCheckboxes = contextSelectionEnabled && editMode
|
||||
)
|
||||
)
|
||||
expandAllRows(tree)
|
||||
tree.revalidate()
|
||||
tree.repaint()
|
||||
if (contextSelectionEnabled && editMode) {
|
||||
onSelectionStateChanged?.invoke()
|
||||
}
|
||||
}
|
||||
|
||||
tree = object : CheckboxTree(
|
||||
object : CheckboxTree.CheckboxTreeCellRenderer() {
|
||||
override fun customizeRenderer(
|
||||
tree: JTree,
|
||||
value: Any?,
|
||||
selected: Boolean,
|
||||
expanded: Boolean,
|
||||
leaf: Boolean,
|
||||
row: Int,
|
||||
hasFocus: Boolean
|
||||
) {
|
||||
val node = value as? DefaultMutableTreeNode ?: return
|
||||
val entry = node.userObject as? TimelineTreeEntry ?: return
|
||||
textRenderer.icon = entry.icon
|
||||
if (entry.point == null) {
|
||||
textRenderer.append(
|
||||
entry.title,
|
||||
SimpleTextAttributes.REGULAR_BOLD_ATTRIBUTES
|
||||
)
|
||||
toolTipText = null
|
||||
if (entry.subtitle.isNotBlank()) {
|
||||
textRenderer.append(
|
||||
" ${entry.subtitle}",
|
||||
SimpleTextAttributes.GRAYED_ATTRIBUTES
|
||||
)
|
||||
}
|
||||
} else {
|
||||
textRenderer.append(entry.title, SimpleTextAttributes.REGULAR_ATTRIBUTES)
|
||||
if (entry.subtitle.isNotBlank()) {
|
||||
textRenderer.append(
|
||||
" ${entry.subtitle}",
|
||||
SimpleTextAttributes.GRAYED_ATTRIBUTES
|
||||
)
|
||||
toolTipText = entry.subtitle
|
||||
} else {
|
||||
toolTipText = null
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
buildTreeRoot(
|
||||
dataPoints = latestPoints,
|
||||
checkedNodeState = checkedNodeState,
|
||||
withCheckboxes = contextSelectionEnabled && editMode
|
||||
)
|
||||
) {}.apply {
|
||||
isRootVisible = false
|
||||
showsRootHandles = true
|
||||
rowHeight = 0
|
||||
border = JBUI.Borders.empty(2, 0)
|
||||
|
||||
expandAllRows(this)
|
||||
|
||||
fun selectCurrentTimelinePoint() {
|
||||
val selectedNode = lastSelectedPathComponent as? DefaultMutableTreeNode ?: return
|
||||
val selectedEntry = selectedNode.userObject as? TimelineTreeEntry ?: return
|
||||
val point = selectedEntry.point ?: return
|
||||
onSelect(point)
|
||||
}
|
||||
|
||||
addMouseListener(object : MouseAdapter() {
|
||||
override fun mouseClicked(e: MouseEvent) {
|
||||
if (e.button == MouseEvent.BUTTON1 && e.clickCount >= 2) {
|
||||
selectCurrentTimelinePoint()
|
||||
}
|
||||
}
|
||||
|
||||
override fun mousePressed(e: MouseEvent) {
|
||||
maybeShowTimelineContextMenu(e)
|
||||
}
|
||||
|
||||
override fun mouseReleased(e: MouseEvent) {
|
||||
maybeShowTimelineContextMenu(e)
|
||||
if (contextSelectionEnabled && editMode && !e.isPopupTrigger && e.button == MouseEvent.BUTTON1) {
|
||||
SwingUtilities.invokeLater {
|
||||
onSelectionStateChanged?.invoke()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun maybeShowTimelineContextMenu(e: MouseEvent) {
|
||||
if (!e.isPopupTrigger && e.button != MouseEvent.BUTTON3) return
|
||||
|
||||
val path = getPathForLocation(e.x, e.y) ?: return
|
||||
selectionPath = path
|
||||
val selectedNode = path.lastPathComponent as? DefaultMutableTreeNode ?: return
|
||||
val selectedEntry = selectedNode.userObject as? TimelineTreeEntry ?: return
|
||||
val point = selectedEntry.point
|
||||
val selectedLabel =
|
||||
point?.let { timelinePointDisplayLabel(it) } ?: selectedEntry.title
|
||||
|
||||
val group = DefaultActionGroup(
|
||||
object : DumbAwareAction(
|
||||
"Rollback",
|
||||
"Rollback to checkpoint",
|
||||
AllIcons.Actions.Undo
|
||||
) {
|
||||
override fun actionPerformed(event: AnActionEvent) {
|
||||
val timelinePoint = point ?: return
|
||||
onNonCopyMenuAction?.invoke()
|
||||
rollbackToTimelinePoint(timelinePoint, selectedLabel)
|
||||
}
|
||||
|
||||
override fun update(event: AnActionEvent) {
|
||||
event.presentation.isEnabled =
|
||||
point != null && canRollbackTimelinePoint(point)
|
||||
}
|
||||
|
||||
override fun getActionUpdateThread(): ActionUpdateThread =
|
||||
ActionUpdateThread.EDT
|
||||
},
|
||||
object : DumbAwareAction(
|
||||
"Continue From New Session",
|
||||
"Start new session from given checkpoint",
|
||||
AllIcons.General.Add
|
||||
) {
|
||||
override fun actionPerformed(event: AnActionEvent) {
|
||||
val timelinePoint = point ?: return
|
||||
onNonCopyMenuAction?.invoke()
|
||||
onSelect(timelinePoint)
|
||||
}
|
||||
|
||||
override fun update(event: AnActionEvent) {
|
||||
event.presentation.isEnabled = point != null
|
||||
}
|
||||
|
||||
override fun getActionUpdateThread(): ActionUpdateThread =
|
||||
ActionUpdateThread.EDT
|
||||
},
|
||||
object : DumbAwareAction(
|
||||
"Copy Output",
|
||||
"Copies the Tool or Assistant's output",
|
||||
AllIcons.General.Copy
|
||||
) {
|
||||
override fun actionPerformed(event: AnActionEvent) {
|
||||
val timelinePoint = point ?: return
|
||||
copyTimelinePointOutput(timelinePoint, selectedLabel)
|
||||
}
|
||||
|
||||
override fun update(event: AnActionEvent) {
|
||||
event.presentation.isEnabled =
|
||||
point != null && canCopyTimelinePointOutput(point)
|
||||
}
|
||||
|
||||
override fun getActionUpdateThread(): ActionUpdateThread =
|
||||
ActionUpdateThread.EDT
|
||||
}
|
||||
)
|
||||
ActionManager.getInstance()
|
||||
.createActionPopupMenu("AgentTimeline.CheckpointMenu", group)
|
||||
.component
|
||||
.show(e.component, e.x, e.y)
|
||||
}
|
||||
})
|
||||
|
||||
addKeyListener(object : KeyAdapter() {
|
||||
override fun keyReleased(e: KeyEvent) {
|
||||
if (contextSelectionEnabled && editMode && e.keyCode == KeyEvent.VK_SPACE) {
|
||||
onSelectionStateChanged?.invoke()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
inputMap.put(KeyStroke.getKeyStroke("ENTER"), "timeline.select")
|
||||
actionMap.put("timeline.select", object : AbstractAction() {
|
||||
override fun actionPerformed(e: ActionEvent?) {
|
||||
selectCurrentTimelinePoint()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
val scrollPane = JBScrollPane(tree).apply {
|
||||
border = JBUI.Borders.empty()
|
||||
horizontalScrollBarPolicy = ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER
|
||||
verticalScrollBar.unitIncrement = 16
|
||||
}
|
||||
|
||||
SwingUtilities.invokeLater {
|
||||
if (tree.rowCount > 0) {
|
||||
tree.scrollRowToVisible(tree.rowCount - 1)
|
||||
scrollPane.verticalScrollBar.value = scrollPane.verticalScrollBar.maximum
|
||||
}
|
||||
}
|
||||
|
||||
val popupWidth = JBUI.scale(620)
|
||||
val minPopupWidth = JBUI.scale(560)
|
||||
val maxPopupWidth = JBUI.scale(760)
|
||||
val popupHeight =
|
||||
JBUI.scale(computeTimelineDialogHeight(groups.sumOf { it.points.size + 1 }))
|
||||
|
||||
val container = JPanel(BorderLayout()).apply {
|
||||
isOpaque = false
|
||||
border = JBUI.Borders.empty()
|
||||
preferredSize = Dimension(popupWidth, popupHeight)
|
||||
minimumSize = Dimension(minPopupWidth, popupHeight)
|
||||
maximumSize = Dimension(maxPopupWidth, popupHeight)
|
||||
add(header, BorderLayout.NORTH)
|
||||
add(scrollPane, BorderLayout.CENTER)
|
||||
}
|
||||
|
||||
return TimelineDialogUi(
|
||||
component = container,
|
||||
getSelectedPoint = {
|
||||
val node = tree.lastSelectedPathComponent as? DefaultMutableTreeNode
|
||||
val entry = node?.userObject as? TimelineTreeEntry
|
||||
entry?.point
|
||||
},
|
||||
getCheckedNonSystemMessageCounts = {
|
||||
val root = tree.model.root as? DefaultMutableTreeNode
|
||||
if (root == null) emptySet() else collectCheckedNonSystemMessageCounts(root)
|
||||
},
|
||||
refresh = {
|
||||
reloadPoints?.let { loader ->
|
||||
launchBackground {
|
||||
val refreshedPoints = loader()
|
||||
runInEdt {
|
||||
rebuildTree(refreshedPoints, captureCurrentSelectionState = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
setEditMode = { value ->
|
||||
if (contextSelectionEnabled && editMode != value) {
|
||||
if (editMode) captureCurrentCheckedState()
|
||||
editMode = value
|
||||
header.text = headerText()
|
||||
rebuildTree(latestPoints, captureCurrentSelectionState = false)
|
||||
}
|
||||
},
|
||||
isEditMode = { editMode }
|
||||
)
|
||||
}
|
||||
|
||||
private fun buildTimelineRunGroups(points: List<RunTimelinePoint>): List<TimelineRunGroup> {
|
||||
if (points.isEmpty()) return emptyList()
|
||||
|
||||
val grouped = linkedMapOf<Int, MutableList<RunTimelinePoint>>()
|
||||
points.forEach { point ->
|
||||
val key = if (point.runIndex <= 0) 1 else point.runIndex
|
||||
grouped.getOrPut(key) { mutableListOf() }.add(point)
|
||||
}
|
||||
|
||||
return grouped.values.mapIndexed { index, runPoints ->
|
||||
TimelineRunGroup(
|
||||
runNumber = index + 1,
|
||||
points = runPoints,
|
||||
toolCallCount = runPoints.count { it.kind == TimelinePointKind.TOOL_CALL }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun computeTimelineDialogHeight(estimatedRowCount: Int): Int {
|
||||
val visibleRows = estimatedRowCount.coerceIn(4, 16)
|
||||
val rowsHeight = visibleRows * 26
|
||||
return (rowsHeight + 44).coerceIn(160, 460)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,328 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import ai.koog.agents.snapshot.feature.AgentCheckpointData
|
||||
import ai.koog.prompt.message.Message as PromptMessage
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
import com.intellij.openapi.application.runReadAction
|
||||
import com.intellij.openapi.application.runWriteAction
|
||||
import com.intellij.openapi.fileEditor.FileDocumentManager
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import ee.carlrobert.codegpt.agent.ToolName
|
||||
import ee.carlrobert.codegpt.agent.ToolSpecs
|
||||
import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService
|
||||
import ee.carlrobert.codegpt.agent.tools.EditTool
|
||||
import ee.carlrobert.codegpt.agent.tools.ReadTool
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import kotlinx.serialization.json.booleanOrNull
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import java.io.File
|
||||
import java.nio.file.Paths
|
||||
import java.util.ArrayDeque
|
||||
|
||||
internal class AgentSessionTimelineHistoricalRollbackSupport(
|
||||
private val project: Project,
|
||||
private val historyService: AgentCheckpointHistoryService,
|
||||
private val replayJson: Json
|
||||
) {
|
||||
|
||||
suspend fun collectOperations(point: RunTimelinePoint): List<HistoricalRollbackOperation> {
|
||||
val checkpointRef = point.checkpointRef ?: return emptyList()
|
||||
val latestCheckpoint: AgentCheckpointData =
|
||||
historyService.listCheckpoints(checkpointRef.agentId).firstOrNull()
|
||||
?: historyService.loadLatestResumeCheckpoint(checkpointRef.agentId)
|
||||
?: historyService.loadCheckpoint(checkpointRef)
|
||||
?: return emptyList()
|
||||
val cutoff = point.nonSystemMessageCount ?: return emptyList()
|
||||
val history = latestCheckpoint.messageHistory.filterNot { it is PromptMessage.System }
|
||||
if (history.isEmpty() || cutoff >= history.size) return emptyList()
|
||||
|
||||
data class PendingCall(val index: Int, val call: PromptMessage.Tool.Call)
|
||||
|
||||
val pendingById = mutableMapOf<String, PendingCall>()
|
||||
val pendingWithoutId = ArrayDeque<PendingCall>()
|
||||
val latestKnownContentByFile = mutableMapOf<String, String>()
|
||||
val operations = mutableListOf<HistoricalRollbackOperation>()
|
||||
|
||||
history.forEachIndexed { index, message ->
|
||||
when (message) {
|
||||
is PromptMessage.Tool.Call -> {
|
||||
val callId = message.id?.takeIf { it.isNotBlank() }
|
||||
if (callId != null) {
|
||||
pendingById[callId] = PendingCall(index, message)
|
||||
} else {
|
||||
pendingWithoutId.addLast(PendingCall(index, message))
|
||||
}
|
||||
}
|
||||
|
||||
is PromptMessage.Tool.Result -> {
|
||||
val pendingCall = message.id
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { pendingById.remove(it) }
|
||||
?: pendingWithoutId.pollFirst()
|
||||
?: return@forEachIndexed
|
||||
val call = pendingCall.call
|
||||
|
||||
val tool = HistoricalRollbackCompatibility.resolveSupportedTool(call.tool)
|
||||
?: return@forEachIndexed
|
||||
if (!HistoricalRollbackCompatibility.isSuccessfulResult(tool, message.content, replayJson)) {
|
||||
return@forEachIndexed
|
||||
}
|
||||
|
||||
when (tool) {
|
||||
ToolName.READ -> {
|
||||
val args = decodeReadArgs(call.tool, call.content) ?: return@forEachIndexed
|
||||
val filePath = normalizeToolFilePath(args.filePath)
|
||||
?: return@forEachIndexed
|
||||
val content =
|
||||
decodeReadToolResultContent(message.content) ?: return@forEachIndexed
|
||||
latestKnownContentByFile[filePath] = content
|
||||
return@forEachIndexed
|
||||
}
|
||||
|
||||
ToolName.EDIT -> {
|
||||
val args = decodeEditArgs(call.tool, call.content) ?: return@forEachIndexed
|
||||
val filePath = normalizeToolFilePath(args.filePath)
|
||||
?: return@forEachIndexed
|
||||
val oldString = args.oldString
|
||||
val newString = args.newString
|
||||
val replaceAll = args.replaceAll
|
||||
if (oldString.isEmpty() || newString.isEmpty() || oldString == newString) return@forEachIndexed
|
||||
|
||||
if (pendingCall.index >= cutoff) {
|
||||
operations.add(
|
||||
HistoricalRollbackOperation(
|
||||
filePath = filePath,
|
||||
searchText = newString,
|
||||
replacementText = oldString,
|
||||
replaceAll = replaceAll,
|
||||
sourceTool = HistoricalRollbackSourceTool.EDIT
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
latestKnownContentByFile[filePath]?.let { known ->
|
||||
if (known.contains(oldString)) {
|
||||
latestKnownContentByFile[filePath] =
|
||||
if (replaceAll) known.replace(oldString, newString)
|
||||
else known.replaceFirst(oldString, newString)
|
||||
}
|
||||
}
|
||||
return@forEachIndexed
|
||||
}
|
||||
|
||||
ToolName.WRITE -> {
|
||||
val args = decodeWriteArgs(call.tool, call.content) ?: return@forEachIndexed
|
||||
val filePath = normalizeToolFilePath(args.filePath)
|
||||
?: return@forEachIndexed
|
||||
val newContent = args.content
|
||||
val previousContent = latestKnownContentByFile[filePath]
|
||||
if (pendingCall.index >= cutoff && previousContent != null && previousContent != newContent) {
|
||||
operations.add(
|
||||
HistoricalRollbackOperation(
|
||||
filePath = filePath,
|
||||
searchText = newContent,
|
||||
replacementText = previousContent,
|
||||
replaceAll = false,
|
||||
sourceTool = HistoricalRollbackSourceTool.WRITE
|
||||
)
|
||||
)
|
||||
}
|
||||
latestKnownContentByFile[filePath] = newContent
|
||||
}
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
return operations
|
||||
}
|
||||
|
||||
fun buildConfirmationText(
|
||||
selectedLabel: String,
|
||||
operations: List<HistoricalRollbackOperation>
|
||||
): String {
|
||||
val ordered = operations.asReversed()
|
||||
val previewLimit = 12
|
||||
val listed = ordered.take(previewLimit).joinToString(separator = "\n") { operation ->
|
||||
"${operation.sourceTool.symbol} ${toProjectRelativePath(operation.filePath)}"
|
||||
}
|
||||
val remaining = ordered.size - previewLimit
|
||||
val suffix = if (remaining > 0) "\n...and $remaining more operation(s)." else ""
|
||||
return """
|
||||
This rollback will return the session to:
|
||||
$selectedLabel
|
||||
|
||||
It will also replay ${operations.size} file operation(s) in reverse order:
|
||||
|
||||
$listed$suffix
|
||||
""".trimIndent()
|
||||
}
|
||||
|
||||
fun applyOperations(operations: List<HistoricalRollbackOperation>): List<String> {
|
||||
val errors = mutableListOf<String>()
|
||||
operations.asReversed().forEach { operation ->
|
||||
val filePath = operation.filePath
|
||||
val currentText = readFileText(filePath)
|
||||
if (currentText == null) {
|
||||
errors.add("File not found: ${toProjectRelativePath(filePath)}")
|
||||
return@forEach
|
||||
}
|
||||
if (!currentText.contains(operation.searchText)) {
|
||||
errors.add(
|
||||
"Expected content not found in ${toProjectRelativePath(filePath)} for ${operation.sourceTool.displayName} rollback."
|
||||
)
|
||||
return@forEach
|
||||
}
|
||||
|
||||
val updatedText = if (operation.replaceAll) {
|
||||
currentText.replace(operation.searchText, operation.replacementText)
|
||||
} else {
|
||||
currentText.replaceFirst(operation.searchText, operation.replacementText)
|
||||
}
|
||||
if (updatedText == currentText) {
|
||||
errors.add("No changes applied to ${toProjectRelativePath(filePath)}")
|
||||
return@forEach
|
||||
}
|
||||
|
||||
val writeOk = writeFileText(filePath, updatedText)
|
||||
if (!writeOk) {
|
||||
errors.add("Failed to write ${toProjectRelativePath(filePath)}")
|
||||
}
|
||||
}
|
||||
return errors
|
||||
}
|
||||
|
||||
private fun parseToolArgs(rawArgs: String): Map<String, JsonElement>? {
|
||||
return runCatching { replayJson.parseToJsonElement(rawArgs).jsonObject }.getOrNull()
|
||||
}
|
||||
|
||||
private fun decodeReadArgs(rawToolName: String, rawArgs: String): ReadTool.Args? {
|
||||
val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? ReadTool.Args
|
||||
if (typed != null) return typed
|
||||
|
||||
val args = parseToolArgs(rawArgs) ?: return null
|
||||
val filePath = stringValue(args["file_path"])
|
||||
?: stringValue(args["path"])
|
||||
?: stringValue(args["pathInProject"])
|
||||
?: return null
|
||||
return ReadTool.Args(filePath = filePath)
|
||||
}
|
||||
|
||||
private fun decodeEditArgs(rawToolName: String, rawArgs: String): EditTool.Args? {
|
||||
val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? EditTool.Args
|
||||
if (typed != null) return typed
|
||||
|
||||
val args = parseToolArgs(rawArgs) ?: return null
|
||||
val filePath = stringValue(args["file_path"]) ?: return null
|
||||
val oldString = stringValue(args["old_string"]) ?: return null
|
||||
val newString = stringValue(args["new_string"]) ?: return null
|
||||
val shortDescription = stringValue(args["short_description"]) ?: "Recovered historical edit"
|
||||
val replaceAll = booleanValue(args["replace_all"]) ?: false
|
||||
|
||||
return EditTool.Args(
|
||||
filePath = filePath,
|
||||
oldString = oldString,
|
||||
newString = newString,
|
||||
shortDescription = shortDescription,
|
||||
replaceAll = replaceAll
|
||||
)
|
||||
}
|
||||
|
||||
private fun decodeWriteArgs(rawToolName: String, rawArgs: String): WriteTool.Args? {
|
||||
val typed = ToolSpecs.decodeArgsOrNull(replayJson, rawToolName, rawArgs) as? WriteTool.Args
|
||||
if (typed != null) return typed
|
||||
|
||||
val args = parseToolArgs(rawArgs) ?: return null
|
||||
val filePath = stringValue(args["file_path"]) ?: return null
|
||||
val content = stringValue(args["content"]) ?: return null
|
||||
return WriteTool.Args(filePath = filePath, content = content)
|
||||
}
|
||||
|
||||
private fun booleanValue(element: JsonElement?): Boolean? {
|
||||
val primitive = element as? JsonPrimitive ?: return null
|
||||
return if (primitive.isString) primitive.content.toBooleanStrictOrNull() else primitive.booleanOrNull
|
||||
}
|
||||
|
||||
private fun stringValue(element: JsonElement?): String? {
|
||||
if (element == null) return null
|
||||
val primitive = element as? JsonPrimitive
|
||||
return if (primitive != null && primitive.isString) primitive.content else element.toString()
|
||||
}
|
||||
|
||||
private fun normalizeToolFilePath(rawPath: String?): String? {
|
||||
val trimmed = rawPath?.trim()?.takeIf { it.isNotEmpty() } ?: return null
|
||||
val normalized = trimmed.replace("\\", "/")
|
||||
val file = File(normalized)
|
||||
if (file.isAbsolute) {
|
||||
return file.toPath().normalize().toString().replace("\\", "/")
|
||||
}
|
||||
|
||||
val basePath = project.basePath ?: return file.absolutePath.replace("\\", "/")
|
||||
return Paths.get(basePath).resolve(normalized).normalize().toString().replace("\\", "/")
|
||||
}
|
||||
|
||||
private fun decodeReadToolResultContent(content: String): String? {
|
||||
if (content.isBlank()) return ""
|
||||
|
||||
val numberedLines = content.lineSequence().mapNotNull { line ->
|
||||
val tabIndex = line.indexOf('\t')
|
||||
if (tabIndex <= 0) return@mapNotNull null
|
||||
val prefix = line.substring(0, tabIndex)
|
||||
if (!prefix.all { it.isDigit() }) return@mapNotNull null
|
||||
line.substring(tabIndex + 1)
|
||||
}.toList()
|
||||
if (numberedLines.isNotEmpty()) return numberedLines.joinToString(separator = "\n")
|
||||
|
||||
if (content.startsWith("Error reading file", ignoreCase = true)) return null
|
||||
return content
|
||||
}
|
||||
|
||||
private fun readFileText(path: String): String? {
|
||||
val virtualFile =
|
||||
LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return null
|
||||
val documentText = runReadAction {
|
||||
FileDocumentManager.getInstance().getDocument(virtualFile)?.text
|
||||
}
|
||||
if (documentText != null) return documentText
|
||||
return runCatching { String(virtualFile.contentsToByteArray(), Charsets.UTF_8) }.getOrNull()
|
||||
}
|
||||
|
||||
private fun writeFileText(path: String, content: String): Boolean {
|
||||
val virtualFile =
|
||||
LocalFileSystem.getInstance().refreshAndFindFileByPath(path) ?: return false
|
||||
val document = runReadAction { FileDocumentManager.getInstance().getDocument(virtualFile) }
|
||||
return runCatching {
|
||||
runInEdt {
|
||||
runWriteAction {
|
||||
if (document != null) {
|
||||
document.setText(content)
|
||||
FileDocumentManager.getInstance().saveDocument(document)
|
||||
} else {
|
||||
virtualFile.setBinaryContent(content.toByteArray(Charsets.UTF_8))
|
||||
}
|
||||
}
|
||||
}
|
||||
}.isSuccess
|
||||
}
|
||||
|
||||
private fun toProjectRelativePath(path: String): String {
|
||||
val basePath = project.basePath ?: return path
|
||||
return runCatching {
|
||||
val absolute = Paths.get(path).normalize()
|
||||
val base = Paths.get(basePath).normalize()
|
||||
if (absolute.startsWith(base)) {
|
||||
base.relativize(absolute).toString().replace("\\", "/")
|
||||
} else {
|
||||
path
|
||||
}
|
||||
}.getOrDefault(path)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
||||
import javax.swing.Icon
|
||||
import javax.swing.JComponent
|
||||
import ai.koog.prompt.message.Message as PromptMessage
|
||||
|
||||
internal data class RunTimelinePoint(
|
||||
val cacheKey: String,
|
||||
val checkpointRef: CheckpointRef?,
|
||||
val title: String,
|
||||
val subtitle: String,
|
||||
val runLabel: String? = null,
|
||||
val icon: Icon? = null,
|
||||
val nonSystemMessageCount: Int? = null,
|
||||
val outputText: String? = null,
|
||||
val toolCallId: String? = null,
|
||||
val kind: TimelinePointKind = TimelinePointKind.ASSISTANT,
|
||||
val runIndex: Int = 0
|
||||
)
|
||||
|
||||
internal enum class TimelinePointKind {
|
||||
USER,
|
||||
ASSISTANT,
|
||||
REASONING,
|
||||
TOOL_CALL
|
||||
}
|
||||
|
||||
internal data class TimelineRunGroup(
|
||||
val runNumber: Int,
|
||||
val points: List<RunTimelinePoint>,
|
||||
val toolCallCount: Int
|
||||
)
|
||||
|
||||
internal data class TimelineTreeEntry(
|
||||
val title: String,
|
||||
val subtitle: String = "",
|
||||
val icon: Icon? = null,
|
||||
val point: RunTimelinePoint? = null,
|
||||
val runNumber: Int? = null
|
||||
)
|
||||
|
||||
internal data class HistoricalRollbackOperation(
|
||||
val filePath: String,
|
||||
val searchText: String,
|
||||
val replacementText: String,
|
||||
val replaceAll: Boolean,
|
||||
val sourceTool: HistoricalRollbackSourceTool
|
||||
)
|
||||
|
||||
internal enum class HistoricalRollbackSourceTool(
|
||||
val displayName: String,
|
||||
val symbol: String
|
||||
) {
|
||||
EDIT(displayName = "Edit", symbol = "E"),
|
||||
WRITE(displayName = "Write", symbol = "W")
|
||||
}
|
||||
|
||||
internal data class SessionContextSnapshot(
|
||||
val baseHistory: List<PromptMessage>
|
||||
)
|
||||
|
||||
internal data class TimelineDialogUi(
|
||||
val component: JComponent,
|
||||
val getSelectedPoint: () -> RunTimelinePoint?,
|
||||
val getCheckedNonSystemMessageCounts: () -> Set<Int>,
|
||||
val refresh: () -> Unit,
|
||||
val setEditMode: (Boolean) -> Unit,
|
||||
val isEditMode: () -> Boolean
|
||||
)
|
||||
|
|
@ -2,21 +2,23 @@ package ee.carlrobert.codegpt.toolwindow.agent
|
|||
|
||||
import com.intellij.openapi.Disposable
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.application.EDT
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.Disposer
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import com.intellij.ui.AnimatedIcon
|
||||
import com.intellij.ui.JBColor
|
||||
import com.intellij.ui.components.JBLabel
|
||||
import com.intellij.util.ui.JBUI
|
||||
import com.intellij.util.ui.components.BorderLayoutPanel
|
||||
import ee.carlrobert.codegpt.CodeGPTBundle
|
||||
import ee.carlrobert.codegpt.agent.AgentService
|
||||
import ee.carlrobert.codegpt.agent.AgentToolOutputNotifier
|
||||
import ee.carlrobert.codegpt.agent.MessageWithContext
|
||||
import ee.carlrobert.codegpt.agent.ToolSpecs
|
||||
import ee.carlrobert.codegpt.agent.ToolRunContext
|
||||
import ee.carlrobert.codegpt.agent.*
|
||||
import ee.carlrobert.codegpt.agent.ProxyAIAgent.loadProjectInstructions
|
||||
import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService
|
||||
import ee.carlrobert.codegpt.agent.history.AgentCheckpointTurnSequencer
|
||||
import ee.carlrobert.codegpt.agent.history.CheckpointRef
|
||||
import ee.carlrobert.codegpt.agent.rollback.RollbackService
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
|
|
@ -42,9 +44,12 @@ import ee.carlrobert.codegpt.ui.queue.QueuedMessagePanel
|
|||
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
||||
import ee.carlrobert.codegpt.util.EditorUtil
|
||||
import ee.carlrobert.codegpt.util.StringUtil.stripThinkingBlocks
|
||||
import ee.carlrobert.codegpt.util.coroutines.CoroutineDispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.util.*
|
||||
import javax.swing.Box
|
||||
import javax.swing.BoxLayout
|
||||
import javax.swing.JComponent
|
||||
|
|
@ -54,6 +59,10 @@ class AgentToolWindowTabPanel(
|
|||
private val project: Project,
|
||||
private val agentSession: AgentSession
|
||||
) : BorderLayoutPanel(), Disposable {
|
||||
companion object {
|
||||
private const val RECOVERED_CONVERSATION_RENDER_BATCH_SIZE = 6
|
||||
}
|
||||
|
||||
private val replayJson = Json {
|
||||
ignoreUnknownKeys = true
|
||||
isLenient = true
|
||||
|
|
@ -63,6 +72,7 @@ class AgentToolWindowTabPanel(
|
|||
private val scrollablePanel = ChatToolWindowScrollablePanel()
|
||||
private val tagManager = TagManager()
|
||||
private val dispatchers = CoroutineDispatchers()
|
||||
private val backgroundScope = DisposableCoroutineScope(dispatchers.io())
|
||||
private val sessionId = agentSession.sessionId
|
||||
private val conversation = agentSession.conversation
|
||||
private val psiRepository = PsiStructureRepository(
|
||||
|
|
@ -107,18 +117,43 @@ class AgentToolWindowTabPanel(
|
|||
this,
|
||||
FeatureType.AGENT,
|
||||
tagManager,
|
||||
onSubmit = { text -> handleSubmit(text) },
|
||||
onStop = { handleCancel() },
|
||||
onSubmit = ::handleSubmit,
|
||||
onStop = ::handleCancel,
|
||||
withRemovableSelectedEditorTag = true,
|
||||
agentTokenCounterPanel = TokenUsageCounterPanel(project, sessionId),
|
||||
sessionIdProvider = { sessionId },
|
||||
conversationIdProvider = { conversation.id }
|
||||
conversationIdProvider = { conversation.id },
|
||||
onStartSessionTimeline = ::showSessionStartTimelineDialog
|
||||
)
|
||||
private var rollbackPanel: RollbackPanel
|
||||
private val todoListPanel = TodoListPanel()
|
||||
private val projectMessageBusConnection = project.messageBus.connect()
|
||||
private val appMessageBusConnection = ApplicationManager.getApplication().messageBus.connect()
|
||||
private val rollbackService = RollbackService.getInstance(project)
|
||||
private val historyService = project.service<AgentCheckpointHistoryService>()
|
||||
|
||||
private data class RunCardState(
|
||||
val runMessageId: UUID,
|
||||
val rollbackRunId: String,
|
||||
var responsePanel: ResponseMessagePanel,
|
||||
var sourceMessage: Message? = null,
|
||||
var completed: Boolean = false
|
||||
)
|
||||
|
||||
// Insertion order matters: timeline run numbering depends on the message order.
|
||||
private val runCardsByMessageId = linkedMapOf<UUID, RunCardState>()
|
||||
private var activeRunMessageId: UUID? = null
|
||||
private var recoveredConversationJob: Job? = null
|
||||
private var activeLandingPanel: AgentToolWindowLandingPanel? = null
|
||||
|
||||
private val timelineController = AgentSessionTimelineController(
|
||||
project = project,
|
||||
agentSession = agentSession,
|
||||
conversation = conversation,
|
||||
runStateForRunIndex = ::runStateForRunIndex,
|
||||
applySeededSessionState = ::applySeededSessionState,
|
||||
onAfterRollbackRefresh = ::refreshViewAfterRollback
|
||||
)
|
||||
|
||||
private val eventHandler = AgentEventHandler(
|
||||
project = project,
|
||||
|
|
@ -140,6 +175,12 @@ class AgentToolWindowTabPanel(
|
|||
repaint()
|
||||
rollbackPanel.refreshOperations()
|
||||
},
|
||||
onRunFinishedCallback = {
|
||||
markActiveRunCompleted()
|
||||
},
|
||||
onRunCheckpointUpdatedCallback = { runMessageId, ref ->
|
||||
updateRunCheckpoint(runMessageId, ref)
|
||||
},
|
||||
onQueuedMessagesResolved = { message ->
|
||||
runInEdt {
|
||||
clearQueuedMessagesAndCreateNewResponse(
|
||||
|
|
@ -163,12 +204,15 @@ class AgentToolWindowTabPanel(
|
|||
}
|
||||
|
||||
userInputPanel.setStopEnabled(false)
|
||||
Disposer.register(this, rollbackPanel)
|
||||
Disposer.register(this, eventHandler)
|
||||
Disposer.register(this, timelineController)
|
||||
Disposer.register(this, backgroundScope)
|
||||
}
|
||||
|
||||
private fun setupMessageBusSubscriptions() {
|
||||
project.service<AgentService>().queuedMessageProcessed.let { flow ->
|
||||
kotlinx.coroutines.CoroutineScope(dispatchers.io()).launch {
|
||||
backgroundScope.launch {
|
||||
flow.collect { processedMessage ->
|
||||
ApplicationManager.getApplication().invokeLater {
|
||||
removeQueuedMessage(processedMessage)
|
||||
|
|
@ -230,6 +274,7 @@ class AgentToolWindowTabPanel(
|
|||
|
||||
private fun handleSubmit(text: String) {
|
||||
if (text.isBlank()) return
|
||||
disposeLandingPanelIfPresent()
|
||||
scrollablePanel.clearLandingViewIfVisible()
|
||||
agentSession.serviceType =
|
||||
ModelSelectionService.getInstance().getServiceForFeature(FeatureType.AGENT)
|
||||
|
|
@ -255,7 +300,7 @@ class AgentToolWindowTabPanel(
|
|||
.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.RUNNING)
|
||||
}
|
||||
|
||||
rollbackService.startSession(sessionId)
|
||||
val rollbackRunId = rollbackService.startSession(sessionId)
|
||||
rollbackPanel.refreshOperations()
|
||||
|
||||
val message = MessageWithContext(text, userInputPanel.getSelectedTags())
|
||||
|
|
@ -282,8 +327,16 @@ class AgentToolWindowTabPanel(
|
|||
messagePanel.add(responsePanel)
|
||||
scrollablePanel.update()
|
||||
|
||||
registerRunCard(
|
||||
runMessageId = message.id,
|
||||
rollbackRunId = rollbackRunId,
|
||||
responsePanel = responsePanel,
|
||||
prompt = text
|
||||
)
|
||||
|
||||
eventHandler.resetForNewSubmission()
|
||||
eventHandler.setCurrentResponseBody(responseBody)
|
||||
eventHandler.setCurrentRollbackRunId(rollbackRunId)
|
||||
|
||||
loadingLabel.text = CodeGPTBundle.get("toolwindow.chat.loading")
|
||||
loadingLabel.isVisible = true
|
||||
|
|
@ -300,7 +353,15 @@ class AgentToolWindowTabPanel(
|
|||
agentService.cancelCurrentRun(sessionId)
|
||||
agentService.clearPendingMessages(sessionId)
|
||||
|
||||
rollbackService.finishSession(sessionId)
|
||||
val activeRun = activeRunMessageId?.let { runCardsByMessageId[it] }
|
||||
if (activeRun != null) {
|
||||
rollbackService.finishRun(activeRun.rollbackRunId)
|
||||
runCardsByMessageId.remove(activeRun.runMessageId)
|
||||
activeRunMessageId = null
|
||||
} else {
|
||||
rollbackService.finishSession(sessionId)
|
||||
}
|
||||
eventHandler.setCurrentRollbackRunId(null)
|
||||
rollbackPanel.refreshOperations()
|
||||
|
||||
approvalContainer.removeAll()
|
||||
|
|
@ -316,37 +377,182 @@ class AgentToolWindowTabPanel(
|
|||
}
|
||||
|
||||
private fun displayLandingView() {
|
||||
scrollablePanel.displayLandingView(createLandingView())
|
||||
disposeLandingPanelIfPresent()
|
||||
val landingPanel = createLandingView()
|
||||
activeLandingPanel = landingPanel
|
||||
scrollablePanel.displayLandingView(landingPanel)
|
||||
}
|
||||
|
||||
private fun displayRecoveredConversation() {
|
||||
disposeLandingPanelIfPresent()
|
||||
scrollablePanel.clearAll()
|
||||
conversation.messages.forEach { message ->
|
||||
val prompt = message.prompt.orEmpty()
|
||||
val wrapper = scrollablePanel.addMessage(message.id)
|
||||
val userPanel = UserMessagePanel(project, message, this)
|
||||
userPanel.addCopyAction { CopyAction.copyToClipboard(prompt) }
|
||||
wrapper.add(userPanel)
|
||||
recoveredConversationJob?.cancel()
|
||||
recoveredConversationJob = backgroundScope.launch {
|
||||
val recoveredTurns =
|
||||
runCatching { loadRecoveredTurnsFromResumeCheckpoint() }.getOrNull()
|
||||
renderRecoveredConversation(recoveredTurns)
|
||||
}
|
||||
}
|
||||
|
||||
val responseBody = ChatMessageResponseBody(
|
||||
project,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
this
|
||||
)
|
||||
addRecoveredToolCards(responseBody, message)
|
||||
responseBody.withResponse(message.response.orEmpty())
|
||||
val responsePanel = ResponseMessagePanel().apply {
|
||||
setResponseContent(responseBody)
|
||||
private suspend fun renderRecoveredConversation(
|
||||
recoveredTurns: List<AgentCheckpointTurnSequencer.Turn>?
|
||||
) {
|
||||
withContext(Dispatchers.EDT) {
|
||||
if (!isActive || project.isDisposed) return@withContext
|
||||
|
||||
val canRenderInOrder = recoveredTurns != null &&
|
||||
recoveredTurns.size == conversation.messages.size &&
|
||||
recoveredTurns.indices.all { index ->
|
||||
recoveredTurns[index].prompt == conversation.messages[index].prompt.orEmpty()
|
||||
.trim()
|
||||
}
|
||||
|
||||
val messages = conversation.messages.toList()
|
||||
var nextIndex = 0
|
||||
while (nextIndex < messages.size) {
|
||||
if (!isActive || project.isDisposed) return@withContext
|
||||
|
||||
val batchEnd = minOf(
|
||||
nextIndex + RECOVERED_CONVERSATION_RENDER_BATCH_SIZE,
|
||||
messages.size
|
||||
)
|
||||
|
||||
while (nextIndex < batchEnd) {
|
||||
if (!isActive || project.isDisposed) return@withContext
|
||||
|
||||
val index = nextIndex
|
||||
val message = messages[index]
|
||||
val prompt = message.prompt.orEmpty()
|
||||
val wrapper = scrollablePanel.addMessage(message.id)
|
||||
val userPanel = UserMessagePanel(project, message, this@AgentToolWindowTabPanel)
|
||||
userPanel.addCopyAction { CopyAction.copyToClipboard(prompt) }
|
||||
wrapper.add(userPanel)
|
||||
|
||||
val responseBody = ChatMessageResponseBody(
|
||||
project,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
this@AgentToolWindowTabPanel
|
||||
)
|
||||
|
||||
val renderedInOrder = if (canRenderInOrder) {
|
||||
renderRecoveredTurnInOrder(responseBody, recoveredTurns[index].events)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
||||
if (!renderedInOrder) {
|
||||
addRecoveredToolCards(responseBody, message)
|
||||
responseBody.withResponse(message.response.orEmpty().stripThinkingBlocks())
|
||||
}
|
||||
|
||||
val responsePanel = ResponseMessagePanel().apply {
|
||||
setResponseContent(responseBody)
|
||||
}
|
||||
wrapper.add(responsePanel)
|
||||
registerRecoveredRunCard(message, responsePanel)
|
||||
nextIndex += 1
|
||||
}
|
||||
|
||||
scrollablePanel.update()
|
||||
if (nextIndex < messages.size) {
|
||||
yield()
|
||||
}
|
||||
}
|
||||
wrapper.add(responsePanel)
|
||||
|
||||
scrollablePanel.scrollToBottom()
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun loadRecoveredTurnsFromResumeCheckpoint(): List<AgentCheckpointTurnSequencer.Turn>? {
|
||||
val resumeRef = agentSession.resumeCheckpointRef ?: return null
|
||||
val checkpoint = historyService.loadCheckpoint(resumeRef)
|
||||
?: historyService.loadResumeCheckpoint(resumeRef)
|
||||
?: return null
|
||||
val projectInstructions = loadProjectInstructions(project.basePath)
|
||||
return AgentCheckpointTurnSequencer.toVisibleTurns(
|
||||
history = checkpoint.messageHistory,
|
||||
projectInstructions = projectInstructions,
|
||||
preserveSyntheticContinuation = true
|
||||
)
|
||||
}
|
||||
|
||||
private fun renderRecoveredTurnInOrder(
|
||||
responseBody: ChatMessageResponseBody,
|
||||
events: List<AgentCheckpointTurnSequencer.TurnEvent>
|
||||
): Boolean {
|
||||
if (events.isEmpty()) {
|
||||
return false
|
||||
}
|
||||
|
||||
scrollablePanel.update()
|
||||
scrollablePanel.scrollToBottom()
|
||||
val pendingById = mutableMapOf<String, ToolCallCard>()
|
||||
val pendingWithoutId = ArrayDeque<ToolCallCard>()
|
||||
var rendered = false
|
||||
|
||||
events.forEach { event ->
|
||||
when (event) {
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> {
|
||||
val text = event.content.stripThinkingBlocks()
|
||||
if (text.isNotBlank()) {
|
||||
responseBody.withResponse(text)
|
||||
rendered = true
|
||||
}
|
||||
}
|
||||
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> {
|
||||
val text = event.content.stripThinkingBlocks()
|
||||
if (text.isNotBlank()) {
|
||||
responseBody.withResponse(text)
|
||||
rendered = true
|
||||
}
|
||||
}
|
||||
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> {
|
||||
val toolName = event.tool.ifBlank { "Tool" }
|
||||
if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) {
|
||||
return@forEach
|
||||
}
|
||||
val rawArgs = event.content
|
||||
val args = parseRecoveredToolArgs(toolName, rawArgs)
|
||||
val card = createRecoveredToolCard(toolName, args, rawArgs)
|
||||
responseBody.addToolStatusPanel(card)
|
||||
val callId = event.id?.takeIf { it.isNotBlank() }
|
||||
if (callId != null) {
|
||||
pendingById[callId] = card
|
||||
} else {
|
||||
pendingWithoutId.addLast(card)
|
||||
}
|
||||
rendered = true
|
||||
}
|
||||
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> {
|
||||
val toolName = event.tool.ifBlank { "Tool" }
|
||||
if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) {
|
||||
return@forEach
|
||||
}
|
||||
val rawResult = event.content
|
||||
val parsedResult = parseRecoveredToolResult(toolName, rawResult)
|
||||
val success = inferRecoveredToolSuccess(parsedResult, rawResult)
|
||||
val card = event.id
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { pendingById.remove(it) }
|
||||
?: pendingWithoutId.pollFirst()
|
||||
?: run {
|
||||
val orphan = createRecoveredToolCard(toolName, "", "")
|
||||
responseBody.addToolStatusPanel(orphan)
|
||||
orphan
|
||||
}
|
||||
card.complete(success, parsedResult ?: rawResult)
|
||||
rendered = true
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return rendered
|
||||
}
|
||||
|
||||
private fun addRecoveredToolCards(responseBody: ChatMessageResponseBody, message: Message) {
|
||||
|
|
@ -355,7 +561,7 @@ class AgentToolWindowTabPanel(
|
|||
|
||||
toolCalls.forEach { toolCall ->
|
||||
val toolName = toolCall.function.name ?: return@forEach
|
||||
if (toolName == "TodoWrite") {
|
||||
if (AgentCheckpointTurnSequencer.isTodoWriteTool(toolName)) {
|
||||
return@forEach
|
||||
}
|
||||
|
||||
|
|
@ -419,6 +625,11 @@ class AgentToolWindowTabPanel(
|
|||
return AgentToolWindowLandingPanel(project)
|
||||
}
|
||||
|
||||
private fun disposeLandingPanelIfPresent() {
|
||||
activeLandingPanel?.let { Disposer.dispose(it) }
|
||||
activeLandingPanel = null
|
||||
}
|
||||
|
||||
fun getSessionId(): String = sessionId
|
||||
|
||||
fun getAgentSession(): AgentSession = agentSession
|
||||
|
|
@ -483,8 +694,19 @@ class AgentToolWindowTabPanel(
|
|||
responsePanel.setResponseContent(responseBody)
|
||||
messagePanel.add(responsePanel)
|
||||
|
||||
val rollbackRunId = activeRunMessageId
|
||||
?.let { runCardsByMessageId[it] }
|
||||
?.rollbackRunId
|
||||
|
||||
eventHandler.resetForNewSubmission()
|
||||
eventHandler.setCurrentResponseBody(responseBody)
|
||||
eventHandler.setCurrentRollbackRunId(rollbackRunId)
|
||||
|
||||
activeRunMessageId?.let { runMessageId ->
|
||||
runCardsByMessageId[runMessageId]?.let { state ->
|
||||
state.responsePanel = responsePanel
|
||||
}
|
||||
}
|
||||
|
||||
scrollablePanel.update()
|
||||
}
|
||||
|
|
@ -518,8 +740,118 @@ class AgentToolWindowTabPanel(
|
|||
}
|
||||
}
|
||||
|
||||
private fun registerRunCard(
|
||||
runMessageId: UUID,
|
||||
rollbackRunId: String,
|
||||
responsePanel: ResponseMessagePanel,
|
||||
prompt: String
|
||||
) {
|
||||
val state = RunCardState(
|
||||
runMessageId = runMessageId,
|
||||
rollbackRunId = rollbackRunId,
|
||||
responsePanel = responsePanel,
|
||||
sourceMessage = Message(prompt)
|
||||
)
|
||||
runCardsByMessageId[runMessageId] = state
|
||||
activeRunMessageId = runMessageId
|
||||
}
|
||||
|
||||
private fun registerRecoveredRunCard(message: Message, responsePanel: ResponseMessagePanel) {
|
||||
val copiedMessage = Message(
|
||||
message.prompt.orEmpty(),
|
||||
message.response.orEmpty().stripThinkingBlocks()
|
||||
).apply {
|
||||
message.toolCalls?.let { toolCalls = ArrayList(it) }
|
||||
message.toolCallResults?.let { toolCallResults = LinkedHashMap(it) }
|
||||
}
|
||||
val state = RunCardState(
|
||||
runMessageId = message.id,
|
||||
rollbackRunId = "",
|
||||
responsePanel = responsePanel,
|
||||
sourceMessage = copiedMessage,
|
||||
completed = true
|
||||
)
|
||||
runCardsByMessageId[message.id] = state
|
||||
timelineController.invalidateTimelineCache()
|
||||
}
|
||||
|
||||
private fun markActiveRunCompleted() {
|
||||
val runMessageId = activeRunMessageId ?: return
|
||||
val state = runCardsByMessageId[runMessageId] ?: return
|
||||
state.completed = true
|
||||
activeRunMessageId = null
|
||||
}
|
||||
|
||||
private fun updateRunCheckpoint(runMessageId: UUID, ref: CheckpointRef?) {
|
||||
val state = runCardsByMessageId[runMessageId] ?: return
|
||||
timelineController.invalidateTimelineCache()
|
||||
if (ref != null) {
|
||||
state.sourceMessage = null
|
||||
}
|
||||
}
|
||||
|
||||
private fun showSessionStartTimelineDialog() {
|
||||
timelineController.showSessionStartTimelineDialog()
|
||||
}
|
||||
|
||||
private fun runStateForRunIndex(runIndex: Int): AgentTimelineRunState? {
|
||||
if (runIndex <= 0) return null
|
||||
val state = runCardsByMessageId.values.elementAtOrNull(runIndex - 1) ?: return null
|
||||
return AgentTimelineRunState(
|
||||
rollbackRunId = state.rollbackRunId,
|
||||
sourceMessage = state.sourceMessage
|
||||
)
|
||||
}
|
||||
|
||||
private fun applySeededSessionState(seededConversation: Conversation, seedRef: CheckpointRef) {
|
||||
conversation.messages = seededConversation.messages
|
||||
runCardsByMessageId.clear()
|
||||
activeRunMessageId = null
|
||||
timelineController.invalidateTimelineCache()
|
||||
|
||||
agentSession.runtimeAgentId = seedRef.agentId
|
||||
agentSession.resumeCheckpointRef = seedRef
|
||||
|
||||
val contentManager = project.service<AgentToolWindowContentManager>()
|
||||
contentManager.setRuntimeAgentId(sessionId, seedRef.agentId)
|
||||
contentManager.setResumeCheckpointRef(sessionId, seedRef)
|
||||
|
||||
project.service<AgentService>().clearPendingMessages(sessionId)
|
||||
loadingLabel.isVisible = false
|
||||
clearQueuedMessages()
|
||||
approvalContainer.removeAll()
|
||||
approvalContainer.isVisible = false
|
||||
eventHandler.resetForNewSubmission()
|
||||
eventHandler.setCurrentRollbackRunId(null)
|
||||
userInputPanel.setStopEnabled(false)
|
||||
|
||||
if (conversation.messages.isEmpty()) {
|
||||
displayLandingView()
|
||||
} else {
|
||||
displayRecoveredConversation()
|
||||
}
|
||||
|
||||
refreshViewAfterRollback()
|
||||
runCatching {
|
||||
contentManager.setTabStatus(sessionId, AgentToolWindowTabbedPane.TabStatus.STOPPED)
|
||||
}
|
||||
}
|
||||
|
||||
private fun refreshViewAfterRollback() {
|
||||
timelineController.invalidateTimelineCache()
|
||||
runCatching { VirtualFileManager.getInstance().asyncRefresh(null) }
|
||||
rollbackPanel.refreshOperations()
|
||||
scrollablePanel.update()
|
||||
revalidate()
|
||||
repaint()
|
||||
}
|
||||
|
||||
override fun dispose() {
|
||||
recoveredConversationJob?.cancel()
|
||||
disposeLandingPanelIfPresent()
|
||||
ToolRunContext.cleanupSession(sessionId)
|
||||
runCardsByMessageId.clear()
|
||||
activeRunMessageId = null
|
||||
|
||||
projectMessageBusConnection.disconnect()
|
||||
appMessageBusConnection.disconnect()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import ee.carlrobert.codegpt.agent.ToolName
|
||||
import ee.carlrobert.codegpt.agent.ToolSpecs
|
||||
import ee.carlrobert.codegpt.agent.tools.EditTool
|
||||
import ee.carlrobert.codegpt.agent.tools.ReadTool
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import kotlinx.serialization.json.Json
|
||||
|
||||
internal object HistoricalRollbackCompatibility {
|
||||
|
||||
private val supportedTools = setOf(ToolName.READ, ToolName.EDIT, ToolName.WRITE)
|
||||
|
||||
fun resolveSupportedTool(rawToolName: String): ToolName? {
|
||||
val normalized = rawToolName.trim()
|
||||
if (normalized.isEmpty()) return null
|
||||
|
||||
val resolved = ToolSpecs.find(normalized)?.name ?: return null
|
||||
return resolved.takeIf { it in supportedTools }
|
||||
}
|
||||
|
||||
fun isSuccessfulResult(toolName: ToolName, content: String, replayJson: Json): Boolean {
|
||||
if (toolName !in supportedTools) return false
|
||||
|
||||
val normalized = content.trim()
|
||||
val decoded = ToolSpecs.decodeResultOrNull(replayJson, toolName.id, normalized)
|
||||
if (decoded != null) {
|
||||
return when (toolName) {
|
||||
ToolName.READ -> decoded is ReadTool.Result.Success
|
||||
ToolName.EDIT -> decoded is EditTool.Result.Success
|
||||
ToolName.WRITE -> decoded is WriteTool.Result.Success
|
||||
else -> false
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Is there a better way to determine if the result is successful?
|
||||
// This string comparison won't cut it
|
||||
return when (toolName) {
|
||||
ToolName.READ -> !normalized.startsWith(READ_ERROR_PREFIX, ignoreCase = true)
|
||||
ToolName.EDIT ->
|
||||
normalized.isNotBlank() &&
|
||||
!normalized.startsWith(EDIT_ERROR_PREFIX, ignoreCase = true) &&
|
||||
(normalized.contains(EDIT_SUCCESS_MARKER, ignoreCase = true) ||
|
||||
normalized.contains(LEGACY_EDIT_SUCCESS_MARKER, ignoreCase = true))
|
||||
ToolName.WRITE ->
|
||||
normalized.isNotBlank() &&
|
||||
normalized.contains(WRITE_SUCCESS_MARKER, ignoreCase = true) &&
|
||||
!normalized.startsWith(WRITE_ERROR_PREFIX, ignoreCase = true)
|
||||
else -> false
|
||||
}
|
||||
}
|
||||
|
||||
private const val READ_ERROR_PREFIX = "Error reading file"
|
||||
private const val EDIT_ERROR_PREFIX = "Error editing file"
|
||||
private const val WRITE_ERROR_PREFIX = "Error writing file"
|
||||
private const val EDIT_SUCCESS_MARKER = "Successfully edited file"
|
||||
private const val LEGACY_EDIT_SUCCESS_MARKER = "Successfully made"
|
||||
private const val WRITE_SUCCESS_MARKER = "successfully"
|
||||
}
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent.ui
|
||||
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.application.runInEdt
|
||||
import com.intellij.openapi.command.WriteCommandAction
|
||||
import com.intellij.openapi.components.service
|
||||
|
|
@ -21,27 +20,33 @@ import ee.carlrobert.codegpt.agent.history.AgentCheckpointConversationMapper
|
|||
import ee.carlrobert.codegpt.agent.history.AgentCheckpointHistoryService
|
||||
import ee.carlrobert.codegpt.agent.history.AgentHistoryThreadSummary
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISubagent
|
||||
import ee.carlrobert.codegpt.settings.agents.SubagentsConfigurable
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookConfig
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookConfiguration
|
||||
import ee.carlrobert.codegpt.settings.skills.SkillDescriptor
|
||||
import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService
|
||||
import ee.carlrobert.codegpt.tokens.TokenComputationService
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentToolWindowContentManager
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.history.AgentHistoryListPanel
|
||||
import ee.carlrobert.codegpt.toolwindow.ui.ResponseMessagePanel
|
||||
import ee.carlrobert.codegpt.ui.UIUtil.createTextPane
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
|
||||
import com.intellij.openapi.Disposable
|
||||
import java.awt.BorderLayout
|
||||
import java.awt.Color
|
||||
import java.awt.Desktop
|
||||
import java.net.URI
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.Path
|
||||
import java.time.Instant
|
||||
import java.time.ZoneId
|
||||
import java.time.format.DateTimeFormatter
|
||||
import javax.swing.Box
|
||||
import javax.swing.BoxLayout
|
||||
import javax.swing.JPanel
|
||||
|
||||
class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessagePanel() {
|
||||
class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessagePanel(), Disposable {
|
||||
|
||||
companion object {
|
||||
private val logger = thisLogger()
|
||||
|
|
@ -49,7 +54,10 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
|
|||
|
||||
private val historyService = project.service<AgentCheckpointHistoryService>()
|
||||
private val historyListPanel = AgentHistoryListPanel(defaultLimit = 5)
|
||||
private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO)
|
||||
private var refreshHistory = true
|
||||
@Volatile
|
||||
private var disposed = false
|
||||
|
||||
private fun createLabel(text: String) = JBLabel(text)
|
||||
private fun createLink(text: String, onClick: () -> Unit) = ActionLink(text) { onClick() }
|
||||
|
|
@ -166,77 +174,39 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
|
|||
tokensLabel.foreground = healthColor(tokens)
|
||||
topRow.add(tokensLabel)
|
||||
|
||||
val bottomRow = JPanel()
|
||||
bottomRow.layout = BoxLayout(bottomRow, BoxLayout.X_AXIS)
|
||||
bottomRow.alignmentX = LEFT_ALIGNMENT
|
||||
val settingsService = project.service<ProxyAISettingsService>()
|
||||
val skills = project.service<SkillDiscoveryService>().listSkills()
|
||||
val subagents = settingsService.getSubagents()
|
||||
val hookEntries = collectHookEntries(settingsService.getHooks())
|
||||
val enabledHooksCount = hookEntries.count { it.hook.enabled }
|
||||
|
||||
val ts = vf.timeStamp
|
||||
val now = Instant.now()
|
||||
val fileTime = Instant.ofEpochMilli(ts)
|
||||
val zonedNow = now.atZone(ZoneId.systemDefault())
|
||||
val zonedFileTime = fileTime.atZone(ZoneId.systemDefault())
|
||||
|
||||
val timeLabel = when {
|
||||
zonedFileTime.toLocalDate() == zonedNow.toLocalDate() -> {
|
||||
val timeFormat =
|
||||
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
|
||||
JBLabel("Modified today at ${timeFormat.format(fileTime)}")
|
||||
}
|
||||
|
||||
zonedFileTime.toLocalDate() == zonedNow.minusDays(1).toLocalDate() -> {
|
||||
val timeFormat =
|
||||
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
|
||||
JBLabel("Modified yesterday at ${timeFormat.format(fileTime)}")
|
||||
}
|
||||
|
||||
fileTime.isAfter(now.minusSeconds(7 * 24 * 60 * 60)) -> {
|
||||
val dayFormat =
|
||||
DateTimeFormatter.ofPattern("EEEE").withZone(ZoneId.systemDefault())
|
||||
val timeFormat =
|
||||
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
|
||||
JBLabel(
|
||||
"Modified on ${dayFormat.format(fileTime)} at ${
|
||||
timeFormat.format(
|
||||
fileTime
|
||||
)
|
||||
}"
|
||||
)
|
||||
}
|
||||
|
||||
zonedFileTime.toLocalDate().year == zonedNow.toLocalDate().year -> {
|
||||
val dateFormat =
|
||||
DateTimeFormatter.ofPattern("MMM d").withZone(ZoneId.systemDefault())
|
||||
val timeFormat =
|
||||
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
|
||||
JBLabel(
|
||||
"Modified on ${dateFormat.format(fileTime)} at ${
|
||||
timeFormat.format(
|
||||
fileTime
|
||||
)
|
||||
}"
|
||||
)
|
||||
}
|
||||
|
||||
else -> {
|
||||
val dateFormat =
|
||||
DateTimeFormatter.ofPattern("MMM d, yyyy").withZone(ZoneId.systemDefault())
|
||||
val timeFormat =
|
||||
DateTimeFormatter.ofPattern("h:mm a").withZone(ZoneId.systemDefault())
|
||||
JBLabel(
|
||||
"Modified on ${dateFormat.format(fileTime)} at ${
|
||||
timeFormat.format(
|
||||
fileTime
|
||||
)
|
||||
}"
|
||||
)
|
||||
}
|
||||
}
|
||||
timeLabel.foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor
|
||||
bottomRow.add(timeLabel)
|
||||
val detailsRow = JPanel()
|
||||
detailsRow.layout = BoxLayout(detailsRow, BoxLayout.X_AXIS)
|
||||
detailsRow.alignmentX = LEFT_ALIGNMENT
|
||||
detailsRow.add(
|
||||
createHoverDetailsLabel(
|
||||
text = "Skills ${skills.size}",
|
||||
tooltip = buildSkillsTooltip(skills)
|
||||
)
|
||||
)
|
||||
detailsRow.add(createDetailsSeparator())
|
||||
detailsRow.add(
|
||||
createHoverDetailsLabel(
|
||||
text = "Hooks $enabledHooksCount",
|
||||
tooltip = buildHooksTooltip(hookEntries)
|
||||
)
|
||||
)
|
||||
detailsRow.add(createDetailsSeparator())
|
||||
detailsRow.add(
|
||||
createHoverDetailsLabel(
|
||||
text = "Subagents ${subagents.size}",
|
||||
tooltip = buildSubagentsTooltip(subagents)
|
||||
)
|
||||
)
|
||||
|
||||
container.add(topRow)
|
||||
container.add(Box.createVerticalStrut(2))
|
||||
container.add(bottomRow)
|
||||
container.add(detailsRow)
|
||||
}
|
||||
|
||||
panel.add(container)
|
||||
|
|
@ -299,22 +269,23 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
|
|||
limit: Int,
|
||||
onResult: (List<AgentHistoryThreadSummary>, Boolean, Int) -> Unit
|
||||
) {
|
||||
if (disposed || project.isDisposed) return
|
||||
val shouldRefresh = offset == 0 && refreshHistory
|
||||
ApplicationManager.getApplication().executeOnPooledThread {
|
||||
backgroundScope.launch {
|
||||
val page = runCatching {
|
||||
runBlocking {
|
||||
historyService.listThreadsPage(
|
||||
query = query,
|
||||
offset = offset,
|
||||
limit = limit,
|
||||
refresh = shouldRefresh
|
||||
)
|
||||
}
|
||||
historyService.listThreadsPage(
|
||||
query = query,
|
||||
offset = offset,
|
||||
limit = limit,
|
||||
refresh = shouldRefresh
|
||||
)
|
||||
}
|
||||
.onFailure { logger.warn("Failed to load checkpoint history", it) }
|
||||
.getOrNull()
|
||||
|
||||
if (disposed || project.isDisposed) return@launch
|
||||
runInEdt {
|
||||
if (disposed || project.isDisposed) return@runInEdt
|
||||
if (page != null) {
|
||||
refreshHistory = false
|
||||
}
|
||||
|
|
@ -324,25 +295,37 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
|
|||
}
|
||||
|
||||
private fun openCheckpointThread(thread: AgentHistoryThreadSummary) {
|
||||
ApplicationManager.getApplication().executeOnPooledThread {
|
||||
if (disposed || project.isDisposed) return
|
||||
backgroundScope.launch {
|
||||
val checkpoint = runCatching {
|
||||
historyService.loadCheckpoint(thread.latest)
|
||||
}.onFailure {
|
||||
logger.warn("Failed to open checkpoint thread ${thread.agentId}", it)
|
||||
}.getOrNull() ?: return@launch
|
||||
|
||||
val conversation = runCatching {
|
||||
val checkpoint = runBlocking { historyService.loadCheckpoint(thread.latest) }
|
||||
?: return@executeOnPooledThread
|
||||
AgentCheckpointConversationMapper.toConversation(
|
||||
checkpoint = checkpoint,
|
||||
projectInstructions = loadProjectInstructions(project.basePath)
|
||||
)
|
||||
}.onFailure {
|
||||
logger.warn("Failed to open checkpoint thread ${thread.agentId}", it)
|
||||
}.getOrNull() ?: return@executeOnPooledThread
|
||||
}.getOrNull() ?: return@launch
|
||||
|
||||
ApplicationManager.getApplication().invokeLater {
|
||||
if (disposed || project.isDisposed) return@launch
|
||||
runInEdt {
|
||||
if (disposed || project.isDisposed) return@runInEdt
|
||||
project.service<AgentToolWindowContentManager>()
|
||||
.openCheckpointConversation(thread, conversation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun dispose() {
|
||||
disposed = true
|
||||
backgroundScope.dispose()
|
||||
}
|
||||
|
||||
private fun welcomeMessage(): String {
|
||||
val name = GeneralSettings.getCurrentState().displayName
|
||||
return """
|
||||
|
|
@ -385,4 +368,93 @@ class AgentToolWindowLandingPanel(private val project: Project) : ResponseMessag
|
|||
val bl = (a.blue + (b.blue - a.blue) * t).toInt()
|
||||
return Color(r, g, bl)
|
||||
}
|
||||
|
||||
private fun createHoverDetailsLabel(text: String, tooltip: String): JBLabel {
|
||||
return JBLabel(text).apply {
|
||||
foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor
|
||||
toolTipText = tooltip
|
||||
}
|
||||
}
|
||||
|
||||
private fun createDetailsSeparator(): JBLabel {
|
||||
return JBLabel(" • ").apply {
|
||||
foreground = SimpleTextAttributes.GRAYED_ATTRIBUTES.fgColor
|
||||
}
|
||||
}
|
||||
|
||||
private fun collectHookEntries(configuration: HookConfiguration): List<HookEntry> = buildList {
|
||||
addAll(configuration.beforeToolUse.map { HookEntry("Before tool", it) })
|
||||
addAll(configuration.afterToolUse.map { HookEntry("After tool", it) })
|
||||
addAll(configuration.subagentStart.map { HookEntry("Subagent start", it) })
|
||||
addAll(configuration.subagentStop.map { HookEntry("Subagent stop", it) })
|
||||
addAll(configuration.beforeShellExecution.map { HookEntry("Before shell", it) })
|
||||
addAll(configuration.afterShellExecution.map { HookEntry("After shell", it) })
|
||||
addAll(configuration.beforeReadFile.map { HookEntry("Before read", it) })
|
||||
addAll(configuration.afterFileEdit.map { HookEntry("After edit", it) })
|
||||
addAll(configuration.stop.map { HookEntry("Stop", it) })
|
||||
}
|
||||
|
||||
private fun buildSkillsTooltip(skills: List<SkillDescriptor>): String {
|
||||
if (skills.isEmpty()) {
|
||||
return tooltipHtml("Skills", listOf("No skills discovered in .proxyai/skills"))
|
||||
}
|
||||
|
||||
val lines = skills.take(5).map { "• ${it.name} — ${it.title}" }.toMutableList()
|
||||
val remaining = skills.size - lines.size
|
||||
if (remaining > 0) {
|
||||
lines.add("+$remaining more")
|
||||
}
|
||||
|
||||
return tooltipHtml("Skills (${skills.size})", lines)
|
||||
}
|
||||
|
||||
private fun buildHooksTooltip(hookEntries: List<HookEntry>): String {
|
||||
val enabled = hookEntries.filter { it.hook.enabled }
|
||||
if (hookEntries.isEmpty()) {
|
||||
return tooltipHtml("Hooks", listOf("No hooks configured"))
|
||||
}
|
||||
|
||||
val lines = mutableListOf<String>()
|
||||
enabled.take(4).forEach { entry ->
|
||||
lines.add("• ${entry.event}: ${entry.hook.command}")
|
||||
}
|
||||
val remainingEnabled = enabled.size - 4
|
||||
if (remainingEnabled > 0) {
|
||||
lines.add("+$remainingEnabled more enabled")
|
||||
}
|
||||
|
||||
return tooltipHtml("Hooks", lines)
|
||||
}
|
||||
|
||||
private fun buildSubagentsTooltip(subagents: List<ProxyAISubagent>): String {
|
||||
if (subagents.isEmpty()) {
|
||||
return tooltipHtml("Subagents", listOf("No subagents configured"))
|
||||
}
|
||||
|
||||
val lines = subagents.take(5).map { "• ${it.title}" }.toMutableList()
|
||||
val remaining = subagents.size - lines.size
|
||||
if (remaining > 0) {
|
||||
lines.add("+$remaining more")
|
||||
}
|
||||
|
||||
return tooltipHtml("Subagents (${subagents.size})", lines)
|
||||
}
|
||||
|
||||
private fun tooltipHtml(title: String, lines: List<String>): String {
|
||||
val escapedTitle = escapeHtml(title)
|
||||
val escapedLines = lines.joinToString("<br>") { escapeHtml(it) }
|
||||
return "<html><b>$escapedTitle</b><br>$escapedLines</html>"
|
||||
}
|
||||
|
||||
private fun escapeHtml(value: String): String {
|
||||
return value
|
||||
.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
}
|
||||
|
||||
private data class HookEntry(
|
||||
val event: String,
|
||||
val hook: HookConfig
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import com.intellij.diff.DiffManager
|
|||
import com.intellij.diff.requests.SimpleDiffRequest
|
||||
import com.intellij.icons.AllIcons
|
||||
import com.intellij.ide.actions.OpenFileAction
|
||||
import com.intellij.openapi.Disposable
|
||||
import com.intellij.openapi.actionSystem.AnAction
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
import com.intellij.openapi.application.EDT
|
||||
|
|
@ -23,6 +24,7 @@ import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.drawCenteredText
|
|||
import ee.carlrobert.codegpt.toolwindow.agent.ui.renderer.lineDiffStats
|
||||
import ee.carlrobert.codegpt.ui.IconActionButton
|
||||
import ee.carlrobert.codegpt.ui.components.LeftEllipsisLabel
|
||||
import ee.carlrobert.codegpt.util.coroutines.DisposableCoroutineScope
|
||||
import kotlinx.coroutines.*
|
||||
import java.awt.BorderLayout
|
||||
import java.awt.Dimension
|
||||
|
|
@ -40,15 +42,11 @@ import java.time.format.DateTimeFormatter
|
|||
import java.util.concurrent.ConcurrentHashMap
|
||||
import javax.swing.*
|
||||
|
||||
/**
|
||||
* Panel showing file operations performed by the agent with rollback controls.
|
||||
*
|
||||
*/
|
||||
class RollbackPanel(
|
||||
private val project: Project,
|
||||
private val sessionId: String,
|
||||
private val onRollbackComplete: () -> Unit
|
||||
) : BorderLayoutPanel() {
|
||||
) : BorderLayoutPanel(), Disposable {
|
||||
private val rollbackService = RollbackService.getInstance(project)
|
||||
private val titleLabel = JBLabel()
|
||||
private val timeLabel = JBLabel()
|
||||
|
|
@ -58,7 +56,8 @@ class RollbackPanel(
|
|||
private val keepAllLink = createKeepAllLink()
|
||||
private val diffStatsCache = ConcurrentHashMap<String, Triple<Int, Int, Int>>()
|
||||
private val diffDataCache = ConcurrentHashMap<String, RollbackDiffData>()
|
||||
private val backgroundScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
private var lastSnapshotRunId: String? = null
|
||||
private val backgroundScope = DisposableCoroutineScope(Dispatchers.IO)
|
||||
|
||||
init {
|
||||
setupUI()
|
||||
|
|
@ -130,12 +129,26 @@ class RollbackPanel(
|
|||
.sortedBy { it.path }
|
||||
|
||||
withContext(Dispatchers.EDT) {
|
||||
refreshOperationsUI(changes, snapshot?.completedAt)
|
||||
refreshOperationsUI(
|
||||
changes = changes,
|
||||
completedAt = snapshot?.completedAt,
|
||||
snapshotRunId = snapshot?.runId
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun refreshOperationsUI(changes: List<FileChange>, completedAt: Instant?) {
|
||||
private fun refreshOperationsUI(
|
||||
changes: List<FileChange>,
|
||||
completedAt: Instant?,
|
||||
snapshotRunId: String?
|
||||
) {
|
||||
if (snapshotRunId != lastSnapshotRunId) {
|
||||
diffStatsCache.clear()
|
||||
diffDataCache.clear()
|
||||
lastSnapshotRunId = snapshotRunId
|
||||
}
|
||||
|
||||
isVisible = changes.isNotEmpty()
|
||||
if (changes.isEmpty()) {
|
||||
titleLabel.text = "Changes"
|
||||
|
|
@ -168,10 +181,10 @@ class RollbackPanel(
|
|||
}
|
||||
|
||||
private fun handleRollback() {
|
||||
CoroutineScope(Dispatchers.Main).launch {
|
||||
backgroundScope.launch {
|
||||
val result = rollbackService.rollbackSession(sessionId)
|
||||
|
||||
withContext(Dispatchers.Main) {
|
||||
withContext(Dispatchers.EDT) {
|
||||
when (result) {
|
||||
is RollbackResult.Success -> {
|
||||
refreshOperations()
|
||||
|
|
@ -424,10 +437,10 @@ class RollbackPanel(
|
|||
}
|
||||
|
||||
private fun rollbackFile(path: String) {
|
||||
CoroutineScope(Dispatchers.Main).launch {
|
||||
backgroundScope.launch {
|
||||
val result = rollbackService.rollbackFile(sessionId, path)
|
||||
|
||||
withContext(Dispatchers.Main) {
|
||||
withContext(Dispatchers.EDT) {
|
||||
when (result) {
|
||||
is RollbackResult.Success -> {
|
||||
refreshOperations()
|
||||
|
|
@ -466,7 +479,7 @@ class RollbackPanel(
|
|||
backgroundScope.launch {
|
||||
val diffData = rollbackService.getDiffData(sessionId, path) ?: return@launch
|
||||
diffDataCache[path] = diffData
|
||||
withContext(Dispatchers.Main) {
|
||||
withContext(Dispatchers.EDT) {
|
||||
openDiff(diffData)
|
||||
}
|
||||
}
|
||||
|
|
@ -480,4 +493,8 @@ class RollbackPanel(
|
|||
val request = SimpleDiffRequest(title, before, after, "Before", "After")
|
||||
DiffManager.getInstance().showDiff(project, request)
|
||||
}
|
||||
}
|
||||
|
||||
override fun dispose() {
|
||||
backgroundScope.dispose()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,6 +70,10 @@ abstract class BaseMessagePanel : BorderLayoutPanel() {
|
|||
)
|
||||
}
|
||||
|
||||
fun addHeaderAction(action: AnAction, actionCode: String) {
|
||||
addIconActionButton(IconActionButton(action, actionCode))
|
||||
}
|
||||
|
||||
fun addContent(content: JComponent) {
|
||||
body.addContent(content)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class UserInputPanel @JvmOverloads constructor(
|
|||
private val agentTokenCounterPanel: JComponent? = null,
|
||||
private val sessionIdProvider: (() -> String?)? = null,
|
||||
private val conversationIdProvider: (() -> UUID?)? = null,
|
||||
private val onStartSessionTimeline: (() -> Unit)? = null,
|
||||
) : BorderLayoutPanel() {
|
||||
|
||||
constructor(
|
||||
|
|
@ -98,6 +99,7 @@ class UserInputPanel @JvmOverloads constructor(
|
|||
null,
|
||||
withRemovableSelectedEditorTag,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
)
|
||||
|
||||
|
|
@ -220,7 +222,26 @@ class UserInputPanel @JvmOverloads constructor(
|
|||
} else {
|
||||
null
|
||||
}
|
||||
private val promptEnhancerSeparator = if (featureType == FeatureType.AGENT) {
|
||||
private val sessionTimelineButton = if (featureType == FeatureType.AGENT) {
|
||||
IconActionButton(
|
||||
object : AnAction(
|
||||
"Timeline",
|
||||
"Choose a timeline point from this session",
|
||||
AllIcons.Vcs.History
|
||||
) {
|
||||
override fun actionPerformed(e: AnActionEvent) {
|
||||
onStartSessionTimeline?.invoke()
|
||||
}
|
||||
},
|
||||
"SESSION_TIMELINE"
|
||||
).apply {
|
||||
isVisible = onStartSessionTimeline != null
|
||||
cursor = Cursor.getPredefinedCursor(Cursor.HAND_CURSOR)
|
||||
}
|
||||
} else {
|
||||
null
|
||||
}
|
||||
private val sessionTimelineSeparator = if (featureType == FeatureType.AGENT) {
|
||||
createActionSeparator()
|
||||
} else {
|
||||
null
|
||||
|
|
@ -600,9 +621,12 @@ class UserInputPanel @JvmOverloads constructor(
|
|||
panel {
|
||||
row {
|
||||
if (applyChip != null) cell(applyChip).gap(RightGap.SMALL)
|
||||
if (promptEnhancerButton != null && promptEnhancerSeparator != null) {
|
||||
if (promptEnhancerButton != null) {
|
||||
cell(promptEnhancerButton).gap(RightGap.SMALL)
|
||||
cell(promptEnhancerSeparator).gap(RightGap.SMALL)
|
||||
}
|
||||
if (sessionTimelineButton != null && sessionTimelineSeparator != null) {
|
||||
cell(sessionTimelineButton).gap(RightGap.SMALL)
|
||||
cell(sessionTimelineSeparator).gap(RightGap.SMALL)
|
||||
}
|
||||
cell(submitButton).gap(RightGap.SMALL)
|
||||
cell(stopButton)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.FolderTagDetails
|
|||
class FolderActionItem(
|
||||
private val project: Project,
|
||||
private val folder: VirtualFile
|
||||
) : AbstractLookupActionItem() {
|
||||
) : AbstractLookupActionItem(), InsertsDisplayNameLookupItem {
|
||||
|
||||
override val displayName = folder.name
|
||||
override val icon = AllIcons.Nodes.Folder
|
||||
|
|
@ -33,4 +33,4 @@ class FolderActionItem(
|
|||
override fun execute(project: Project, userInputPanel: UserInputPanel) {
|
||||
userInputPanel.addTag(FolderTagDetails(folder))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ import com.intellij.openapi.vfs.VirtualFile
|
|||
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.AbstractLookupActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.InsertsDisplayNameLookupItem
|
||||
|
||||
class FileActionItem(private val project: Project, val file: VirtualFile) :
|
||||
AbstractLookupActionItem() {
|
||||
AbstractLookupActionItem(), InsertsDisplayNameLookupItem {
|
||||
|
||||
override val displayName = file.name
|
||||
override val icon = file.fileType.icon ?: AllIcons.FileTypes.Any_type
|
||||
|
|
@ -32,4 +33,4 @@ class FileActionItem(private val project: Project, val file: VirtualFile) :
|
|||
override fun execute(project: Project, userInputPanel: UserInputPanel) {
|
||||
userInputPanel.addTag(EditorTagDetails(file))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,13 @@ import ai.grazie.nlp.utils.takeWhitespaces
|
|||
|
||||
object StringUtil {
|
||||
|
||||
private val thinkBlockRegex = Regex("(?s)<think>.*?</think>\\s*")
|
||||
|
||||
fun String.stripThinkingBlocks(): String {
|
||||
if (this.isBlank()) return ""
|
||||
return thinkBlockRegex.replace(this, "").trim()
|
||||
}
|
||||
|
||||
fun adjustWhitespace(
|
||||
completionLine: String,
|
||||
editorLine: String
|
||||
|
|
@ -33,12 +40,4 @@ object StringUtil {
|
|||
val intersection = bigrams1.intersect(bigrams2).size
|
||||
return (2.0 * intersection) / (bigrams1.size + bigrams2.size)
|
||||
}
|
||||
|
||||
fun String.extractUntilNewline(): String {
|
||||
val index = this.indexOf('\n')
|
||||
if (index == -1) {
|
||||
return this
|
||||
}
|
||||
return this.substring(0, index + 1)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,26 +69,26 @@ class AgentServiceIntegrationTest : IntegrationTest() {
|
|||
val secondMessage = MessageWithContext("Second request")
|
||||
|
||||
submitAndAwait(sessionId, firstMessage)
|
||||
val firstSnapshot = snapshot(sessionId)
|
||||
val firstActualSnapshot = snapshot(sessionId)
|
||||
submitAndAwait(sessionId, secondMessage)
|
||||
val secondSnapshot = snapshot(sessionId)
|
||||
val secondActualSnapshot = snapshot(sessionId)
|
||||
|
||||
val runtimeAgentId = firstSnapshot.runtimeAgentId
|
||||
val managedService = factory.createdServices.single()
|
||||
assertThat(runtimeAgentId).isNotBlank()
|
||||
val actualRuntimeAgentId = firstActualSnapshot.runtimeAgentId
|
||||
val actualManagedService = factory.createdServices.single()
|
||||
assertThat(actualRuntimeAgentId).isNotBlank()
|
||||
assertThat(factory.createdServices).hasSize(1)
|
||||
assertThat(listOf(firstSnapshot, secondSnapshot))
|
||||
assertThat(listOf(firstActualSnapshot, secondActualSnapshot))
|
||||
.extracting("runtimeAgentId", "resumeCheckpointRef.agentId")
|
||||
.containsExactly(
|
||||
tuple(runtimeAgentId, runtimeAgentId),
|
||||
tuple(runtimeAgentId, runtimeAgentId)
|
||||
tuple(actualRuntimeAgentId, actualRuntimeAgentId),
|
||||
tuple(actualRuntimeAgentId, actualRuntimeAgentId)
|
||||
)
|
||||
assertThat(firstSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
|
||||
assertThat(secondSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
|
||||
assertThat(secondSnapshot.resumeCheckpointRef?.checkpointId)
|
||||
.isNotEqualTo(firstSnapshot.resumeCheckpointRef?.checkpointId)
|
||||
assertThat(managedService.createdAgentIds).hasSize(2)
|
||||
assertThat(managedService.createdAgentIds.last()).isEqualTo(runtimeAgentId)
|
||||
assertThat(firstActualSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
|
||||
assertThat(secondActualSnapshot.resumeCheckpointRef?.checkpointId).isNotBlank()
|
||||
assertThat(secondActualSnapshot.resumeCheckpointRef?.checkpointId)
|
||||
.isNotEqualTo(firstActualSnapshot.resumeCheckpointRef?.checkpointId)
|
||||
assertThat(actualManagedService.createdAgentIds).hasSize(2)
|
||||
assertThat(actualManagedService.createdAgentIds.last()).isEqualTo(actualRuntimeAgentId)
|
||||
}
|
||||
|
||||
fun testSubmitMessageRebuildsRuntimeWhenMcpSelectionChanges() {
|
||||
|
|
@ -102,12 +102,38 @@ class AgentServiceIntegrationTest : IntegrationTest() {
|
|||
submitAndAwait(sessionId, withoutMcp)
|
||||
submitAndAwait(sessionId, withMcp)
|
||||
|
||||
assertThat(factory.createdServices).hasSize(2)
|
||||
assertThat(factory.createdServices)
|
||||
val actualCreatedServices = factory.createdServices
|
||||
assertThat(actualCreatedServices).hasSize(2)
|
||||
assertThat(actualCreatedServices)
|
||||
.extracting("closeAllCalls")
|
||||
.containsExactly(1, 0)
|
||||
}
|
||||
|
||||
fun testSubmitMessageEmitsRunCheckpointCallback() {
|
||||
val sessionId = createSession("agent-runtime-callback")
|
||||
val factory = RecordingRuntimeFactory(agentModel())
|
||||
agentService.runtimeFactory = factory
|
||||
val message = MessageWithContext("Callback request")
|
||||
var callbackMessageId: UUID? = null
|
||||
var callbackRef: CheckpointRef? = null
|
||||
val events = object : AgentEvents {
|
||||
override fun onQueuedMessagesResolved() = Unit
|
||||
|
||||
override fun onRunCheckpointUpdated(runMessageId: UUID, ref: CheckpointRef?) {
|
||||
callbackMessageId = runMessageId
|
||||
callbackRef = ref
|
||||
}
|
||||
}
|
||||
|
||||
agentService.submitMessage(message, events, sessionId)
|
||||
awaitSessionToFinish(sessionId)
|
||||
|
||||
val actualSession = contentManager.getSession(sessionId)
|
||||
assertThat(callbackMessageId).isEqualTo(message.id)
|
||||
assertThat(callbackRef).isNotNull
|
||||
assertThat(callbackRef?.agentId).isEqualTo(actualSession?.runtimeAgentId)
|
||||
}
|
||||
|
||||
fun testGetCheckpointFallsBackToRuntimeAgentIdWhenResumeRefIsStale() {
|
||||
val sessionId = createSession("agent-checkpoint-fallback")
|
||||
val runtimeAgentId = "runtime-agent-${UUID.randomUUID()}"
|
||||
|
|
@ -133,11 +159,11 @@ class AgentServiceIntegrationTest : IntegrationTest() {
|
|||
contentManager.setRuntimeAgentId(sessionId, runtimeAgentId)
|
||||
contentManager.setResumeCheckpointRef(sessionId, staleRef)
|
||||
|
||||
val checkpoint = runBlocking { agentService.getCheckpoint(sessionId) }
|
||||
|
||||
assertThat(checkpoint).isNotNull
|
||||
assertThat(checkpoint!!.checkpointId).isEqualTo(checkpointId)
|
||||
assertThat(contentManager.getSession(sessionId)!!.resumeCheckpointRef)
|
||||
val actualCheckpoint = runBlocking { agentService.getCheckpoint(sessionId) }
|
||||
val actualSession = contentManager.getSession(sessionId)
|
||||
assertThat(actualCheckpoint).isNotNull
|
||||
assertThat(actualCheckpoint!!.checkpointId).isEqualTo(checkpointId)
|
||||
assertThat(actualSession!!.resumeCheckpointRef)
|
||||
.isEqualTo(CheckpointRef(runtimeAgentId, checkpointId))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ package ee.carlrobert.codegpt.agent
|
|||
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import ee.carlrobert.codegpt.agent.rollback.ChangeKind
|
||||
import ee.carlrobert.codegpt.agent.rollback.RollbackResult
|
||||
import ee.carlrobert.codegpt.agent.rollback.RollbackService
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
import java.io.File
|
||||
|
|
@ -31,10 +32,10 @@ class RollbackServiceTrackingTest : IntegrationTest() {
|
|||
filePath = filePath,
|
||||
originalContent = "before"
|
||||
)
|
||||
val snapshot = rollbackService.finishSession(sessionId)
|
||||
val actualSnapshot = rollbackService.finishSession(sessionId)
|
||||
|
||||
assertThat(snapshot!!.changes).hasSize(1)
|
||||
assertThat(snapshot.changes.single())
|
||||
assertThat(actualSnapshot!!.changes).hasSize(1)
|
||||
assertThat(actualSnapshot.changes.single())
|
||||
.extracting("path", "kind")
|
||||
.containsExactly(filePath, ChangeKind.MODIFIED)
|
||||
}
|
||||
|
|
@ -46,10 +47,10 @@ class RollbackServiceTrackingTest : IntegrationTest() {
|
|||
val sessionId = "s2"
|
||||
rollbackService.startSession(sessionId)
|
||||
|
||||
rollbackService.trackWrite(sessionId, filePath, WriteTool.Args(filePath, "hello"))
|
||||
val snapshot = rollbackService.finishSession(sessionId)
|
||||
rollbackService.trackWrite(sessionId, filePath)
|
||||
val actualSnapshot = rollbackService.finishSession(sessionId)
|
||||
|
||||
assertThat(snapshot).isNotNull
|
||||
assertThat(actualSnapshot).isNotNull
|
||||
.extracting { it!!.changes.single().kind }
|
||||
.isEqualTo(ChangeKind.ADDED)
|
||||
}
|
||||
|
|
@ -62,14 +63,81 @@ class RollbackServiceTrackingTest : IntegrationTest() {
|
|||
val sessionId = "s3"
|
||||
rollbackService.startSession(sessionId)
|
||||
|
||||
rollbackService.trackWrite(sessionId, filePath, WriteTool.Args(filePath, "after"))
|
||||
rollbackService.trackWrite(sessionId, filePath)
|
||||
file.writeText("after")
|
||||
LocalFileSystem.getInstance().refreshAndFindFileByPath(filePath)
|
||||
val snapshot = rollbackService.finishSession(sessionId)
|
||||
val actualSnapshot = rollbackService.finishSession(sessionId)
|
||||
val actualDiff = rollbackService.getDiffData(sessionId, filePath)
|
||||
|
||||
assertThat(snapshot!!.changes.single().kind).isEqualTo(ChangeKind.MODIFIED)
|
||||
assertThat(rollbackService.getDiffData(sessionId, filePath))
|
||||
assertThat(actualSnapshot!!.changes.single().kind).isEqualTo(ChangeKind.MODIFIED)
|
||||
assertThat(actualDiff)
|
||||
.extracting("beforeText", "afterText")
|
||||
.containsExactly("before", "after")
|
||||
}
|
||||
}
|
||||
|
||||
fun testSnapshotsRemainAvailablePerRunWithinSameSession() {
|
||||
val rollbackService = RollbackService(project)
|
||||
val sessionId = "run-scoped-session"
|
||||
|
||||
val firstPath = getTestFilePath("run_scoped_a.txt")
|
||||
File(firstPath).delete()
|
||||
val firstRunId = rollbackService.startSession(sessionId)
|
||||
rollbackService.trackWrite(sessionId, firstPath)
|
||||
File(firstPath).writeText("run-a")
|
||||
rollbackService.finishSession(sessionId)
|
||||
|
||||
val secondPath = getTestFilePath("run_scoped_b.txt")
|
||||
File(secondPath).delete()
|
||||
val secondRunId = rollbackService.startSession(sessionId)
|
||||
rollbackService.trackWrite(sessionId, secondPath)
|
||||
File(secondPath).writeText("run-b")
|
||||
rollbackService.finishSession(sessionId)
|
||||
|
||||
val firstRunSnapshot = rollbackService.getRunSnapshot(firstRunId)
|
||||
val secondRunSnapshot = rollbackService.getRunSnapshot(secondRunId)
|
||||
val latestSessionSnapshot = rollbackService.getSnapshot(sessionId)
|
||||
|
||||
assertThat(firstRunSnapshot)
|
||||
.extracting("runId")
|
||||
.isEqualTo(firstRunId)
|
||||
assertThat(secondRunSnapshot)
|
||||
.extracting("runId")
|
||||
.isEqualTo(secondRunId)
|
||||
assertThat(latestSessionSnapshot)
|
||||
.extracting("runId")
|
||||
.isEqualTo(secondRunId)
|
||||
}
|
||||
|
||||
fun testRollbackRunDoesNotClearOtherRunSnapshots() {
|
||||
val rollbackService = RollbackService(project)
|
||||
val sessionId = "run-rollback-session"
|
||||
|
||||
val firstPath = getTestFilePath("run_rollback_a.txt")
|
||||
File(firstPath).delete()
|
||||
val firstRunId = rollbackService.startSession(sessionId)
|
||||
rollbackService.trackWrite(sessionId, firstPath)
|
||||
File(firstPath).writeText("run-a")
|
||||
rollbackService.finishSession(sessionId)
|
||||
|
||||
val secondPath = getTestFilePath("run_rollback_b.txt")
|
||||
File(secondPath).delete()
|
||||
val secondRunId = rollbackService.startSession(sessionId)
|
||||
rollbackService.trackWrite(sessionId, secondPath)
|
||||
File(secondPath).writeText("run-b")
|
||||
rollbackService.finishSession(sessionId)
|
||||
|
||||
val actualRollbackResult = runBlocking {
|
||||
rollbackService.rollbackRun(firstRunId)
|
||||
}
|
||||
val firstRunSnapshotAfterRollback = rollbackService.getRunSnapshot(firstRunId)
|
||||
val secondRunSnapshotAfterRollback = rollbackService.getRunSnapshot(secondRunId)
|
||||
val latestSessionSnapshot = rollbackService.getSnapshot(sessionId)
|
||||
|
||||
assertThat(actualRollbackResult).isInstanceOf(RollbackResult.Success::class.java)
|
||||
assertThat(firstRunSnapshotAfterRollback).isNull()
|
||||
assertThat(secondRunSnapshotAfterRollback).isNotNull()
|
||||
assertThat(latestSessionSnapshot)
|
||||
.extracting("runId")
|
||||
.isEqualTo(secondRunId)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData
|
|||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.message.RequestMetaInfo
|
||||
import ai.koog.prompt.message.ResponseMetaInfo
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import kotlinx.datetime.Clock
|
||||
import kotlinx.datetime.Instant
|
||||
import kotlinx.serialization.json.JsonNull
|
||||
|
|
@ -15,7 +14,6 @@ import org.assertj.core.groups.Tuple.tuple
|
|||
import kotlin.test.Test
|
||||
|
||||
class AgentCheckpointConversationMapperTest {
|
||||
private val actualUserTurn = "Actual user prompt" to "Actual response"
|
||||
|
||||
@Test
|
||||
fun `maps user and assistant turns into conversation messages`() {
|
||||
|
|
@ -29,9 +27,9 @@ class AgentCheckpointConversationMapperTest {
|
|||
)
|
||||
)
|
||||
|
||||
val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
|
||||
assertThat(conversation.messages)
|
||||
assertThat(actualConversation.messages)
|
||||
.extracting("prompt", "response")
|
||||
.containsExactly(
|
||||
tuple("First prompt", "First response"),
|
||||
|
|
@ -60,14 +58,14 @@ class AgentCheckpointConversationMapperTest {
|
|||
)
|
||||
)
|
||||
|
||||
val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
val mapped = conversation.messages.single()
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
val actualMessage = actualConversation.messages.single()
|
||||
|
||||
assertThat(mapped.response).isEqualTo("Working on it")
|
||||
assertThat(mapped.toolCalls.orEmpty())
|
||||
assertThat(actualMessage.response).isEqualTo("Working on it")
|
||||
assertThat(actualMessage.toolCalls.orEmpty())
|
||||
.extracting("function.name", "function.arguments")
|
||||
.containsExactly(tuple("Bash", """{"command":"ls"}"""))
|
||||
assertThat(mapped.toolCallResults).containsEntry("tool-1", "file1.txt")
|
||||
assertThat(actualMessage.toolCallResults).containsEntry("tool-1", "file1.txt")
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -81,12 +79,14 @@ class AgentCheckpointConversationMapperTest {
|
|||
)
|
||||
)
|
||||
|
||||
val conversation = AgentCheckpointConversationMapper.toConversation(
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(
|
||||
checkpoint = checkpoint,
|
||||
projectInstructions = projectInstructions
|
||||
)
|
||||
|
||||
assertSingleTurn(conversation, actualUserTurn)
|
||||
assertThat(actualConversation.messages)
|
||||
.extracting("prompt", "response")
|
||||
.containsExactly(tuple("Actual user prompt", "Actual response"))
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -102,12 +102,43 @@ class AgentCheckpointConversationMapperTest {
|
|||
)
|
||||
)
|
||||
|
||||
val conversation = AgentCheckpointConversationMapper.toConversation(
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(
|
||||
checkpoint = checkpoint,
|
||||
projectInstructions = null
|
||||
)
|
||||
|
||||
assertSingleTurn(conversation, actualUserTurn)
|
||||
assertThat(actualConversation.messages)
|
||||
.extracting("prompt", "response")
|
||||
.containsExactly(tuple("Actual user prompt", "Actual response"))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `does not merge hidden user turn output into previous visible turn`() {
|
||||
val checkpoint = checkpoint(
|
||||
listOf(
|
||||
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
|
||||
Message.User(
|
||||
content = "Hidden cacheable instruction",
|
||||
metaInfo = cacheableMetaInfo()
|
||||
),
|
||||
Message.Assistant("Hidden response", ResponseMetaInfo.Empty),
|
||||
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Visible response 2", ResponseMetaInfo.Empty)
|
||||
)
|
||||
)
|
||||
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(
|
||||
checkpoint = checkpoint,
|
||||
projectInstructions = null
|
||||
)
|
||||
|
||||
assertThat(actualConversation.messages)
|
||||
.extracting("prompt", "response")
|
||||
.containsExactly(
|
||||
tuple("Visible prompt 1", "Visible response 1"),
|
||||
tuple("Visible prompt 2", "Visible response 2")
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -136,12 +167,80 @@ class AgentCheckpointConversationMapperTest {
|
|||
)
|
||||
)
|
||||
|
||||
val conversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
|
||||
assertThat(conversation.messages.single().toolCallResults)
|
||||
assertThat(actualConversation.messages.single().toolCallResults)
|
||||
.containsEntry("tool-1", "file1.txt\n\nfile2.txt")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `filters synthetic timeline user turn and its output`() {
|
||||
val checkpoint = checkpoint(
|
||||
listOf(
|
||||
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
|
||||
Message.User("I haven't created a todo list yet.", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Synthetic response", ResponseMetaInfo.Empty),
|
||||
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Visible response 2", ResponseMetaInfo.Empty)
|
||||
)
|
||||
)
|
||||
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(
|
||||
checkpoint = checkpoint,
|
||||
projectInstructions = null
|
||||
)
|
||||
|
||||
assertThat(actualConversation.messages)
|
||||
.extracting("prompt", "response")
|
||||
.containsExactly(
|
||||
tuple("Visible prompt 1", "Visible response 1"),
|
||||
tuple("Visible prompt 2", "Visible response 2")
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `filters internal timeline tools from mapped message`() {
|
||||
val checkpoint = checkpoint(
|
||||
listOf(
|
||||
Message.User("Visible prompt", RequestMetaInfo.Empty),
|
||||
Message.Tool.Call(
|
||||
id = "todo-1",
|
||||
tool = "TodoWrite",
|
||||
content = """{"todos":[{"content":"x","status":"pending"}]}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "todo-1",
|
||||
tool = "TodoWrite",
|
||||
content = "ok",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Call(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = """{"command":"ls"}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = "file1.txt",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(checkpoint, null)
|
||||
|
||||
val actualMessage = actualConversation.messages.single()
|
||||
assertThat(actualMessage.toolCalls.orEmpty())
|
||||
.extracting("function.name")
|
||||
.containsExactly("Bash")
|
||||
assertThat(actualMessage.toolCallResults)
|
||||
.containsOnlyKeys("tool-1")
|
||||
}
|
||||
|
||||
private fun cacheableMetaInfo(): RequestMetaInfo {
|
||||
return RequestMetaInfo(
|
||||
timestamp = Clock.System.now(),
|
||||
|
|
@ -149,15 +248,6 @@ class AgentCheckpointConversationMapperTest {
|
|||
)
|
||||
}
|
||||
|
||||
private fun assertSingleTurn(
|
||||
conversation: Conversation,
|
||||
expectedTurn: Pair<String, String>
|
||||
) {
|
||||
assertThat(conversation.messages)
|
||||
.extracting("prompt", "response")
|
||||
.containsExactly(tuple(expectedTurn.first, expectedTurn.second))
|
||||
}
|
||||
|
||||
private fun checkpoint(history: List<Message>): AgentCheckpointData {
|
||||
return AgentCheckpointData(
|
||||
checkpointId = "checkpoint-1",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,248 @@
|
|||
package ee.carlrobert.codegpt.agent.history
|
||||
|
||||
import ai.koog.agents.snapshot.feature.AgentCheckpointData
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.message.RequestMetaInfo
|
||||
import ai.koog.prompt.message.ResponseMetaInfo
|
||||
import kotlinx.datetime.Clock
|
||||
import kotlinx.datetime.Instant
|
||||
import kotlinx.serialization.json.JsonNull
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import kotlin.test.Test
|
||||
|
||||
class AgentCheckpointTurnSequencerTest {
|
||||
|
||||
@Test
|
||||
fun `preserves assistant and tool event order within a turn`() {
|
||||
val history = listOf(
|
||||
Message.User("Prompt 1", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Assistant 1", ResponseMetaInfo.Empty),
|
||||
Message.Tool.Call(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = """{"command":"ls"}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Assistant("Assistant 2", ResponseMetaInfo.Empty),
|
||||
Message.Tool.Result(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = "file.txt",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Reasoning(content = "Reasoning 1", metaInfo = ResponseMetaInfo.Empty),
|
||||
Message.User("Prompt 2", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Assistant 3", ResponseMetaInfo.Empty)
|
||||
)
|
||||
|
||||
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(
|
||||
history = history,
|
||||
projectInstructions = null
|
||||
)
|
||||
|
||||
assertThat(actualTurns).hasSize(2)
|
||||
assertThat(actualTurns[0].prompt).isEqualTo("Prompt 1")
|
||||
assertThat(actualTurns[0].userNonSystemMessageCount).isEqualTo(1)
|
||||
assertThat(actualTurns[0].events.map(::eventSignature))
|
||||
.containsExactly(
|
||||
"assistant:Assistant 1",
|
||||
"tool-call:Bash:tool-1",
|
||||
"assistant:Assistant 2",
|
||||
"tool-result:Bash:tool-1",
|
||||
"reasoning:Reasoning 1"
|
||||
)
|
||||
assertThat(actualTurns[0].events.map { it.nonSystemMessageCount })
|
||||
.containsExactly(2, 3, 4, 5, 6)
|
||||
assertThat(actualTurns[1].prompt).isEqualTo("Prompt 2")
|
||||
assertThat(actualTurns[1].userNonSystemMessageCount).isEqualTo(7)
|
||||
assertThat(actualTurns[1].events.map(::eventSignature))
|
||||
.containsExactly("assistant:Assistant 3")
|
||||
assertThat(actualTurns[1].events.map { it.nonSystemMessageCount })
|
||||
.containsExactly(8)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `filters hidden and synthetic user turns and internal tools`() {
|
||||
val history = listOf(
|
||||
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
|
||||
Message.User("Hidden cacheable prompt", cacheableMetaInfo()),
|
||||
Message.Assistant("Hidden response", ResponseMetaInfo.Empty),
|
||||
Message.User("I haven't created a todo list yet.", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Synthetic response", ResponseMetaInfo.Empty),
|
||||
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
|
||||
Message.Tool.Call(
|
||||
id = "todo-1",
|
||||
tool = "TodoWrite",
|
||||
content = """{"todos":[{"content":"x","status":"pending"}]}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "todo-1",
|
||||
tool = "TodoWrite",
|
||||
content = "ok",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Call(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = """{"command":"ls"}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = "file.txt",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Assistant("Visible response 2", ResponseMetaInfo.Empty)
|
||||
)
|
||||
|
||||
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(
|
||||
history = history,
|
||||
projectInstructions = null
|
||||
)
|
||||
|
||||
assertThat(actualTurns).hasSize(2)
|
||||
assertThat(actualTurns.map { it.prompt })
|
||||
.containsExactly("Visible prompt 1", "Visible prompt 2")
|
||||
assertThat(actualTurns.map { it.userNonSystemMessageCount })
|
||||
.containsExactly(1, 7)
|
||||
assertThat(actualTurns[0].events.map(::eventSignature))
|
||||
.containsExactly("assistant:Visible response 1")
|
||||
assertThat(actualTurns[0].events.map { it.nonSystemMessageCount })
|
||||
.containsExactly(2)
|
||||
assertThat(actualTurns[1].events.map(::eventSignature))
|
||||
.containsExactly(
|
||||
"tool-call:Bash:tool-1",
|
||||
"tool-result:Bash:tool-1",
|
||||
"assistant:Visible response 2"
|
||||
)
|
||||
assertThat(actualTurns[1].events.map { it.nonSystemMessageCount })
|
||||
.containsExactly(10, 11, 12)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `mapper and sequencer keep same visible turn prompts`() {
|
||||
val history = listOf(
|
||||
Message.User("Visible prompt 1", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Visible response 1", ResponseMetaInfo.Empty),
|
||||
Message.User("Hidden cacheable prompt", cacheableMetaInfo()),
|
||||
Message.Assistant("Hidden response", ResponseMetaInfo.Empty),
|
||||
Message.User("Visible prompt 2", RequestMetaInfo.Empty),
|
||||
Message.Tool.Call(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = """{"command":"ls"}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "tool-1",
|
||||
tool = "Bash",
|
||||
content = "file.txt",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
)
|
||||
)
|
||||
|
||||
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(history, null)
|
||||
val actualConversation = AgentCheckpointConversationMapper.toConversation(
|
||||
checkpoint = AgentCheckpointData(
|
||||
checkpointId = "checkpoint-1",
|
||||
createdAt = Instant.parse("2026-02-04T00:00:00Z"),
|
||||
nodePath = "agent/single_run/nodeExecuteTool",
|
||||
lastInput = JsonNull,
|
||||
messageHistory = history,
|
||||
version = 0
|
||||
),
|
||||
projectInstructions = null
|
||||
)
|
||||
|
||||
assertThat(actualConversation.messages.map { it.prompt })
|
||||
.containsExactlyElementsOf(actualTurns.map { it.prompt })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `keeps events after synthetic todo user message within the active visible turn`() {
|
||||
val history = listOf(
|
||||
Message.User("Implement feature X", RequestMetaInfo.Empty),
|
||||
Message.Assistant("Starting implementation", ResponseMetaInfo.Empty),
|
||||
Message.Tool.Call(
|
||||
id = "bash-1",
|
||||
tool = "Bash",
|
||||
content = """{"command":"ls"}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "bash-1",
|
||||
tool = "Bash",
|
||||
content = "file.txt",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.User(
|
||||
"It seems that you haven't created a todo list yet. If the task on hand requires multiple steps then create a todo list to track your changes.",
|
||||
RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Call(
|
||||
id = "todo-1",
|
||||
tool = "TodoWrite",
|
||||
content = """{"todos":[{"content":"x","status":"pending"}]}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "todo-1",
|
||||
tool = "TodoWrite",
|
||||
content = "ok",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Call(
|
||||
id = "read-1",
|
||||
tool = "Read",
|
||||
content = """{"path":"src/Main.kt"}""",
|
||||
metaInfo = ResponseMetaInfo.Empty
|
||||
),
|
||||
Message.Tool.Result(
|
||||
id = "read-1",
|
||||
tool = "Read",
|
||||
content = "content",
|
||||
metaInfo = RequestMetaInfo.Empty
|
||||
),
|
||||
Message.Assistant("Done", ResponseMetaInfo.Empty)
|
||||
)
|
||||
|
||||
val actualTurns = AgentCheckpointTurnSequencer.toVisibleTurns(
|
||||
history = history,
|
||||
projectInstructions = null,
|
||||
preserveSyntheticContinuation = true
|
||||
)
|
||||
|
||||
assertThat(actualTurns).hasSize(1)
|
||||
assertThat(actualTurns.single().prompt).isEqualTo("Implement feature X")
|
||||
assertThat(actualTurns.single().events.map(::eventSignature))
|
||||
.containsExactly(
|
||||
"assistant:Starting implementation",
|
||||
"tool-call:Bash:bash-1",
|
||||
"tool-result:Bash:bash-1",
|
||||
"tool-call:Read:read-1",
|
||||
"tool-result:Read:read-1",
|
||||
"assistant:Done"
|
||||
)
|
||||
}
|
||||
|
||||
private fun eventSignature(event: AgentCheckpointTurnSequencer.TurnEvent): String {
|
||||
return when (event) {
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.Assistant -> "assistant:${event.content}"
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.Reasoning -> "reasoning:${event.content}"
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.ToolCall -> "tool-call:${event.tool}:${event.id}"
|
||||
is AgentCheckpointTurnSequencer.TurnEvent.ToolResult -> "tool-result:${event.tool}:${event.id}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun cacheableMetaInfo(): RequestMetaInfo {
|
||||
return RequestMetaInfo(
|
||||
timestamp = Clock.System.now(),
|
||||
metadata = JsonObject(mapOf("cacheable" to JsonPrimitive(true)))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -17,7 +17,6 @@ class ChatToolWindowTabbedPaneTest : BasePlatformTestCase() {
|
|||
assertThat(tabbedPane.activeTabMapping).isEmpty()
|
||||
}
|
||||
|
||||
|
||||
fun testAddingNewTabs() {
|
||||
val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable())
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue