mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 16:28:46 +00:00
refactor: session checkpoint centralization
This commit is contained in:
parent
e6e92a3616
commit
8552f4e634
8 changed files with 63 additions and 173 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import ai.koog.agents.core.agent.AIAgent
|
||||
import ai.koog.agents.snapshot.feature.AgentCheckpointData
|
||||
import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider
|
||||
import ai.koog.prompt.executor.clients.LLMClientException
|
||||
import com.intellij.openapi.components.Service
|
||||
|
|
@ -10,7 +11,6 @@ import com.intellij.openapi.project.Project
|
|||
import ee.carlrobert.codegpt.conversations.message.TokenUsageTracker
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentSessionState
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.flow.asSharedFlow
|
||||
|
|
@ -40,6 +40,22 @@ class AgentService(private val project: Project) {
|
|||
pendingMessages.getOrPut(sessionId) { ArrayDeque() }.add(message)
|
||||
}
|
||||
|
||||
suspend fun getCheckpoint(sessionId: String): AgentCheckpointData? {
|
||||
val prevAgentId = sessionAgents[sessionId]?.id ?: return null
|
||||
return runCatching {
|
||||
checkpointStorage.getCheckpoints(prevAgentId)
|
||||
.filter { it.nodePath != "tombstone" }
|
||||
.maxByOrNull { it.createdAt }
|
||||
}.onFailure { ex ->
|
||||
val sessionInfo = sessionId.let { " session=$it" }
|
||||
logger.error(
|
||||
"Agent checkpoints: failed to load for$sessionInfo agentId=$prevAgentId " +
|
||||
"error=${ex.message}",
|
||||
ex
|
||||
)
|
||||
}.getOrNull()
|
||||
}
|
||||
|
||||
fun submitMessage(message: MessageWithContext, events: AgentEvents, sessionId: String) {
|
||||
if (isSessionRunning(sessionId)) {
|
||||
addToQueue(message, sessionId)
|
||||
|
|
@ -47,27 +63,7 @@ class AgentService(private val project: Project) {
|
|||
}
|
||||
|
||||
val provider = service<ModelSelectionService>().getServiceForFeature(FeatureType.AGENT)
|
||||
val previousCheckpoint = runBlocking {
|
||||
val prevAgentId =
|
||||
project.service<AgentSessionState>().state.sessions
|
||||
.firstOrNull { it.sessionId == sessionId }?.lastAgentId
|
||||
if (prevAgentId == null) {
|
||||
return@runBlocking null
|
||||
}
|
||||
|
||||
runCatching {
|
||||
val checkpoints = checkpointStorage.getCheckpoints(prevAgentId)
|
||||
val latestCheckpoint = checkpoints
|
||||
.filter { it.nodePath != "tombstone" }
|
||||
.maxByOrNull { it.createdAt }
|
||||
latestCheckpoint
|
||||
}.onFailure { ex ->
|
||||
logger.error(
|
||||
"Agent checkpoints: failed to load for session=$sessionId " +
|
||||
"agentId=$prevAgentId error=${ex.message}", ex
|
||||
)
|
||||
}.getOrNull()
|
||||
}
|
||||
val previousCheckpoint = runBlocking { getCheckpoint(sessionId) }
|
||||
|
||||
val agent = ProxyAIAgent.create(
|
||||
project,
|
||||
|
|
@ -79,7 +75,6 @@ class AgentService(private val project: Project) {
|
|||
pendingMessages
|
||||
)
|
||||
sessionAgents[sessionId] = agent
|
||||
project.service<AgentSessionState>().updateSession(sessionId, lastAgentId = agent.id)
|
||||
sessionJobs[sessionId] = CoroutineScope(Dispatchers.IO).launch {
|
||||
try {
|
||||
agent.run(message)
|
||||
|
|
@ -100,6 +95,13 @@ class AgentService(private val project: Project) {
|
|||
sessionJobs.remove(sessionId)
|
||||
}
|
||||
|
||||
fun removeSession(sessionId: String) {
|
||||
cancelCurrentRun(sessionId)
|
||||
pendingMessages.remove(sessionId)
|
||||
sessionAgents.remove(sessionId)
|
||||
sessionTokenTrackers.remove(sessionId)
|
||||
}
|
||||
|
||||
fun isSessionRunning(sessionId: String): Boolean {
|
||||
return sessionJobs[sessionId]?.isActive == true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,16 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentSessionState
|
||||
import ee.carlrobert.codegpt.ui.textarea.TagProcessorFactory
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails
|
||||
import ee.carlrobert.codegpt.util.GitUtil
|
||||
import ee.carlrobert.codegpt.util.ThinkingOutputParser
|
||||
import ee.carlrobert.codegpt.util.file.FileUtil
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlin.io.path.Path
|
||||
import ai.koog.prompt.message.Message as KoogMessage
|
||||
import ai.koog.prompt.message.Message as PromptMessage
|
||||
import ee.carlrobert.codegpt.conversations.message.Message as ChatMessage
|
||||
|
|
@ -157,14 +154,8 @@ class PromptEnhancer(private val project: Project) {
|
|||
|
||||
private suspend fun buildHistoryContext(sessionId: String?): String {
|
||||
if (sessionId.isNullOrBlank()) return ""
|
||||
val sessionState = project.service<AgentSessionState>()
|
||||
val agentId = sessionState.getLastAgentId(sessionId) ?: return ""
|
||||
val storage = JVMFilePersistenceStorageProvider(Path(project.basePath ?: "", ".proxyai"))
|
||||
val checkpoint = storage.getCheckpoints(agentId)
|
||||
.filter { it.nodePath != "tombstone" }
|
||||
.maxByOrNull { it.createdAt }
|
||||
?: return ""
|
||||
return formatHistory(checkpoint.messageHistory)
|
||||
val checkpoint = project.service<AgentService>().getCheckpoint(sessionId)
|
||||
return formatHistory(checkpoint?.messageHistory ?: emptyList())
|
||||
}
|
||||
|
||||
private fun formatHistory(messages: List<PromptMessage>): String {
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
|||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentSessionState
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalType
|
||||
|
|
@ -159,11 +158,7 @@ object ProxyAIAgent {
|
|||
checkpointId = ctx.context.runId,
|
||||
version = 0L
|
||||
)
|
||||
project.service<AgentSessionState>().updateSession(
|
||||
sessionId = sessionId,
|
||||
lastAgentId = ctx.context.agentId,
|
||||
checkpointId = checkpoint?.checkpointId
|
||||
)
|
||||
checkpoint?.checkpointId ?: return@onNodeExecutionCompleted
|
||||
|
||||
if (stream) return@onNodeExecutionCompleted
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import ai.koog.agents.core.feature.handler.agent.AgentCompletedContext
|
||||
import ai.koog.http.client.KoogHttpClientException
|
||||
import ai.koog.prompt.executor.clients.LLMClientException
|
||||
import com.intellij.openapi.Disposable
|
||||
|
|
@ -163,20 +162,6 @@ class AgentEventHandler(
|
|||
project.service<AgentToolWindowContentManager>().getTabbedPane()
|
||||
.onAgentCompleted(sessionId)
|
||||
}
|
||||
|
||||
val tabTitle = runCatching {
|
||||
project.service<AgentToolWindowContentManager>()
|
||||
.getTabbedPane()
|
||||
.tryFindTabTitle(sessionId)
|
||||
.orElse(null)
|
||||
}.getOrNull()
|
||||
val resolvedAgentId = project.service<AgentService>().getAgentForSession(sessionId)?.id
|
||||
?: agentId
|
||||
project.service<AgentSessionState>().updateSession(
|
||||
sessionId,
|
||||
lastAgentId = resolvedAgentId,
|
||||
displayName = tabTitle
|
||||
)
|
||||
handleDone()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,85 +0,0 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.agent
|
||||
|
||||
import com.intellij.openapi.components.*
|
||||
import com.intellij.util.xmlb.XmlSerializerUtil
|
||||
|
||||
@State(
|
||||
name = "ProxyAI_AgentSessionState",
|
||||
storages = [Storage(StoragePathMacros.WORKSPACE_FILE)]
|
||||
)
|
||||
@Service(Service.Level.PROJECT)
|
||||
class AgentSessionState : PersistentStateComponent<AgentSessionState.State> {
|
||||
|
||||
data class SessionEntry(
|
||||
var sessionId: String = "",
|
||||
var lastAgentId: String? = null,
|
||||
var checkpointId: String? = null,
|
||||
var displayName: String = "",
|
||||
)
|
||||
|
||||
data class State(
|
||||
var sessions: MutableList<SessionEntry> = mutableListOf(),
|
||||
var lastActiveSessionId: String? = null
|
||||
)
|
||||
|
||||
private var state = State()
|
||||
|
||||
override fun getState(): State = state
|
||||
|
||||
override fun loadState(state: State) {
|
||||
XmlSerializerUtil.copyBean(state, this.state)
|
||||
}
|
||||
|
||||
fun getSessionIds(): List<String> {
|
||||
return state.sessions.mapNotNull { it.sessionId.takeIf { id -> id.isNotBlank() } }
|
||||
}
|
||||
|
||||
fun getLastActiveSessionId(): String? = state.lastActiveSessionId
|
||||
|
||||
fun setLastActiveSessionId(sessionId: String) {
|
||||
state.lastActiveSessionId = sessionId
|
||||
}
|
||||
|
||||
fun getLastAgentId(sessionId: String): String? {
|
||||
return state.sessions.firstOrNull { it.sessionId == sessionId }?.lastAgentId
|
||||
}
|
||||
|
||||
fun getDisplayName(sessionId: String): String? {
|
||||
return state.sessions.firstOrNull { it.sessionId == sessionId }?.displayName
|
||||
}
|
||||
|
||||
fun ensureSession(sessionId: String): SessionEntry {
|
||||
val existing = state.sessions.firstOrNull { it.sessionId == sessionId }
|
||||
if (existing != null) return existing
|
||||
|
||||
val entry = SessionEntry(sessionId = sessionId)
|
||||
state.sessions.add(entry)
|
||||
return entry
|
||||
}
|
||||
|
||||
fun updateSession(sessionId: String, lastAgentId: String? = null, checkpointId: String? = null, displayName: String? = null) {
|
||||
val entry = ensureSession(sessionId)
|
||||
if (lastAgentId != null) {
|
||||
entry.lastAgentId = lastAgentId
|
||||
}
|
||||
if (displayName != null) {
|
||||
entry.displayName = displayName
|
||||
}
|
||||
if (checkpointId != null) {
|
||||
entry.checkpointId = checkpointId
|
||||
}
|
||||
}
|
||||
|
||||
fun removeSession(sessionId: String) {
|
||||
state.sessions.removeIf { it.sessionId == sessionId }
|
||||
if (state.lastActiveSessionId == sessionId) {
|
||||
state.lastActiveSessionId = null
|
||||
}
|
||||
}
|
||||
|
||||
fun replaceSession(oldSessionId: String, newSessionId: String) {
|
||||
removeSession(oldSessionId)
|
||||
ensureSession(newSessionId)
|
||||
state.lastActiveSessionId = newSessionId
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,9 @@ import com.intellij.openapi.components.Service
|
|||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.Disposer
|
||||
import ee.carlrobert.codegpt.agent.AgentService
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import java.util.UUID
|
||||
|
||||
@Service(Service.Level.PROJECT)
|
||||
class AgentToolWindowContentManager(private val project: Project) : Disposable {
|
||||
|
|
@ -12,7 +15,6 @@ class AgentToolWindowContentManager(private val project: Project) : Disposable {
|
|||
private val activeSessions = mutableMapOf<String, AgentSession>()
|
||||
private val tabPanels = mutableMapOf<String, AgentToolWindowTabPanel>()
|
||||
private val autoApprovedSessions = mutableSetOf<String>()
|
||||
private val sessionState = project.service<AgentSessionState>()
|
||||
private val tabbedPane = AgentToolWindowTabbedPane(project)
|
||||
|
||||
fun initializeTabbedPane(): AgentToolWindowTabbedPane {
|
||||
|
|
@ -28,22 +30,12 @@ class AgentToolWindowContentManager(private val project: Project) : Disposable {
|
|||
return tabbedPane
|
||||
}
|
||||
|
||||
fun createNewAgentTab(
|
||||
sessionId: String? = null,
|
||||
select: Boolean = true
|
||||
): AgentToolWindowTabPanel {
|
||||
val tabPanel = if (sessionId != null) {
|
||||
AgentToolWindowTabPanel(project, sessionId)
|
||||
} else {
|
||||
AgentToolWindowTabPanel(project)
|
||||
}
|
||||
val resolvedSessionId = tabPanel.getSessionId()
|
||||
val session = AgentSession(resolvedSessionId, tabPanel.getConversation())
|
||||
val sessionEntry = sessionState.ensureSession(resolvedSessionId)
|
||||
if (sessionEntry.displayName.isNotBlank()) {
|
||||
tabPanel.getAgentSession().displayName = sessionEntry.displayName
|
||||
session.displayName = sessionEntry.displayName
|
||||
}
|
||||
fun createNewAgentTab(select: Boolean = true): AgentToolWindowTabPanel {
|
||||
return createNewAgentTab(AgentSession(UUID.randomUUID().toString(), Conversation()), select)
|
||||
}
|
||||
|
||||
fun createNewAgentTab(session: AgentSession, select: Boolean = true): AgentToolWindowTabPanel {
|
||||
val tabPanel = AgentToolWindowTabPanel(project, session)
|
||||
activeSessions[session.sessionId] = session
|
||||
tabPanels[session.sessionId] = tabPanel
|
||||
tabbedPane.addNewTab(tabPanel, select)
|
||||
|
|
@ -82,6 +74,15 @@ class AgentToolWindowContentManager(private val project: Project) : Disposable {
|
|||
return autoApprovedSessions.contains(sessionId)
|
||||
}
|
||||
|
||||
fun removeSession(sessionId: String) {
|
||||
activeSessions.remove(sessionId)
|
||||
tabPanels.remove(sessionId)
|
||||
autoApprovedSessions.remove(sessionId)
|
||||
project.service<AgentService>().removeSession(sessionId)
|
||||
}
|
||||
|
||||
fun getSession(sessionId: String): AgentSession? = activeSessions[sessionId]
|
||||
|
||||
companion object {
|
||||
fun getInstance(project: Project): AgentToolWindowContentManager {
|
||||
return project.service()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
|||
import ee.carlrobert.codegpt.util.EditorUtil
|
||||
import ee.carlrobert.codegpt.util.coroutines.CoroutineDispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import java.util.*
|
||||
import javax.swing.Box
|
||||
import javax.swing.BoxLayout
|
||||
import javax.swing.JComponent
|
||||
|
|
@ -50,15 +49,14 @@ import javax.swing.JPanel
|
|||
|
||||
class AgentToolWindowTabPanel(
|
||||
private val project: Project,
|
||||
private val sessionId: String = UUID.randomUUID().toString()
|
||||
private val agentSession: AgentSession
|
||||
) : BorderLayoutPanel(), Disposable {
|
||||
|
||||
private val scrollablePanel = ChatToolWindowScrollablePanel()
|
||||
private val tagManager = TagManager()
|
||||
private val dispatchers = CoroutineDispatchers()
|
||||
|
||||
private val conversation = Conversation()
|
||||
private val agentSession = AgentSession(sessionId, conversation)
|
||||
private val sessionId = agentSession.sessionId
|
||||
private val conversation = agentSession.conversation
|
||||
private val psiRepository = PsiStructureRepository(
|
||||
this,
|
||||
project,
|
||||
|
|
@ -221,7 +219,8 @@ class AgentToolWindowTabPanel(
|
|||
|
||||
private fun handleSubmit(text: String) {
|
||||
if (text.isBlank()) return
|
||||
agentSession.serviceType = ModelSelectionService.getInstance().getServiceForFeature(FeatureType.AGENT)
|
||||
agentSession.serviceType =
|
||||
ModelSelectionService.getInstance().getServiceForFeature(FeatureType.AGENT)
|
||||
|
||||
val agentService = project.service<AgentService>()
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import com.intellij.ui.components.JBLabel
|
|||
import com.intellij.ui.components.JBTabbedPane
|
||||
import com.intellij.util.ui.JBUI
|
||||
import ee.carlrobert.codegpt.agent.AgentService
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import java.awt.*
|
||||
import java.awt.event.ActionEvent
|
||||
import java.awt.event.ActionListener
|
||||
|
|
@ -146,7 +147,6 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
val title = getTitle(toolWindowPanel, nextIndex)
|
||||
val sessionId = toolWindowPanel.getSessionId()
|
||||
toolWindowPanel.getAgentSession().displayName = title
|
||||
project.service<AgentSessionState>().updateSession(sessionId, displayName = title)
|
||||
|
||||
super.insertTab(title, null, toolWindowPanel, null, nextIndex)
|
||||
activeTabMapping[title] = toolWindowPanel
|
||||
|
|
@ -219,7 +219,6 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
activeTabMapping[uniqueName] = panel
|
||||
|
||||
panel.getAgentSession().displayName = uniqueName
|
||||
project.service<AgentSessionState>().updateSession(sessionId, displayName = uniqueName)
|
||||
|
||||
applyIconForSession(sessionId)
|
||||
}
|
||||
|
|
@ -252,7 +251,6 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
selectedState.unseen = false
|
||||
applyIconForSession(selectedSessionId)
|
||||
}
|
||||
project.service<AgentSessionState>().setLastActiveSessionId(selectedSessionId)
|
||||
}
|
||||
|
||||
for (i in 0 until tabCount) {
|
||||
|
|
@ -312,15 +310,19 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
fun resetCurrentlyActiveTabPanel() {
|
||||
tryFindActiveTabPanel().ifPresent { tabPanel ->
|
||||
val oldSessionId = tabPanel.getSessionId()
|
||||
val oldDisplayName = tabPanel.getAgentSession().displayName
|
||||
Disposer.dispose(tabPanel)
|
||||
activeTabMapping.remove(getTitleAt(selectedIndex))
|
||||
removeTabAt(selectedIndex)
|
||||
sessionStates.remove(oldSessionId)
|
||||
|
||||
val newTabPanel = AgentToolWindowTabPanel(project)
|
||||
project.service<AgentSessionState>()
|
||||
.replaceSession(oldSessionId, newTabPanel.getSessionId())
|
||||
addNewTab(newTabPanel)
|
||||
project.service<AgentToolWindowContentManager>().removeSession(oldSessionId)
|
||||
val newSession = AgentSession(
|
||||
UUID.randomUUID().toString(),
|
||||
Conversation(),
|
||||
displayName = oldDisplayName
|
||||
)
|
||||
project.service<AgentToolWindowContentManager>().createNewAgentTab(newSession)
|
||||
repaint()
|
||||
revalidate()
|
||||
}
|
||||
|
|
@ -375,7 +377,7 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
if (tabIndex >= 0) {
|
||||
activeTabMapping[title]?.let { panel ->
|
||||
sessionStates.remove(panel.getSessionId())
|
||||
project.service<AgentSessionState>().removeSession(panel.getSessionId())
|
||||
project.service<AgentToolWindowContentManager>().removeSession(panel.getSessionId())
|
||||
Disposer.dispose(panel)
|
||||
}
|
||||
removeTabAt(tabIndex)
|
||||
|
|
@ -399,7 +401,7 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
val title = getTitleAt(selectedPopupTabIndex)
|
||||
activeTabMapping[title]?.let { panel ->
|
||||
sessionStates.remove(panel.getSessionId())
|
||||
project.service<AgentSessionState>().removeSession(panel.getSessionId())
|
||||
project.service<AgentToolWindowContentManager>().removeSession(panel.getSessionId())
|
||||
Disposer.dispose(panel)
|
||||
}
|
||||
removeTabAt(selectedPopupTabIndex)
|
||||
|
|
@ -416,7 +418,7 @@ class AgentToolWindowTabbedPane(private val project: Project) : JBTabbedPane(),
|
|||
activeTabMapping.values
|
||||
.map { it.getSessionId() }
|
||||
.filter { it != keepSessionId }
|
||||
.forEach { project.service<AgentSessionState>().removeSession(it) }
|
||||
.forEach { project.service<AgentToolWindowContentManager>().removeSession(it) }
|
||||
|
||||
clearAll()
|
||||
tabPanel?.let { addNewTab(it) }
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue