From 0f2216207be2bb5bd8eeeec0ad4e259228304cf1 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 15:08:14 -0300 Subject: [PATCH] enable multiple chat sessions --- migrations/3.surrealql | 14 +++- migrations/3_down.surrealql | 5 ++ open_notebook/domain/notebook.py | 121 +++++++++++++++++-------------- pages/2_📒_Notebooks.py | 83 ++++++++++++--------- pages/stream_app/chat.py | 89 ++++++++++++++++++----- pages/stream_app/note.py | 26 +++---- pages/stream_app/source.py | 9 ++- pages/stream_app/utils.py | 62 ++++++++++++++-- 8 files changed, 276 insertions(+), 133 deletions(-) diff --git a/migrations/3.surrealql b/migrations/3.surrealql index 73b79a7..f2a067f 100644 --- a/migrations/3.surrealql +++ b/migrations/3.surrealql @@ -1,4 +1,11 @@ -REMOVE FUNCTION fn::vector_search; + +DEFINE TABLE IF NOT EXISTS chat_session SCHEMALESS; + +DEFINE TABLE IF NOT EXISTS refers_to +TYPE RELATION +FROM chat_session TO notebook; + +REMOVE FUNCTION IF EXISTS fn::vector_search; DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) { let $source_embedding_search = @@ -16,7 +23,6 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_cou )} ELSE { [] }; - -- Busca em source_insight com threshold let $source_insight_search = IF $sources {( SELECT @@ -67,10 +73,10 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_cou }; -REMOVE FUNCTION fn::text_search; +REMOVE FUNCTION IF EXISTS fn::text_search; - DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) { +DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) { let $source_title_search = IF $sources {( diff --git a/migrations/3_down.surrealql b/migrations/3_down.surrealql index aaab4d9..b8438e0 100644 --- a/migrations/3_down.surrealql +++ b/migrations/3_down.surrealql @@ -1,3 +1,8 @@ +REMOVE TABLE IF EXISTS chat_session; + +REMOVE TABLE IF EXISTS refers_to; + + REMOVE FUNCTION fn::vector_search; diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index ee2035e..86944f8 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -1,11 +1,9 @@ -import os from typing import Any, ClassVar, Dict, List, Literal, Optional from loguru import logger from pydantic import BaseModel, Field, field_validator from open_notebook.database.repository import ( - repo_create, repo_query, ) from open_notebook.domain.base import ObjectModel @@ -68,6 +66,27 @@ class Notebook(ObjectModel): logger.exception(e) raise DatabaseOperationError(e) + @property + def chat_sessions(self) -> List["ChatSession"]: + try: + srcs = repo_query(f""" + select * from ( + select + <- chat_session as chat_session + from refers_to + where out={self.id} + fetch chat_session + ) + order by chat_session.updated desc + """) + return ( + [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else [] + ) + except Exception as e: + logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}") + logger.exception(e) + raise DatabaseOperationError(e) + class Asset(BaseModel): file_path: Optional[str] = None @@ -99,6 +118,22 @@ class Source(ObjectModel): else: return dict(id=self.id, title=self.title, insights=self.insights) + @property + def embedded_chunks(self) -> int: + try: + result = repo_query( + f""" + select count() as chunks from source_embedding where source={self.id} GROUP ALL + """ + ) + if len(result) == 0: + return 0 + return result[0]["chunks"] + except Exception as e: + logger.error(f"Error fetching insights for source {self.id}: {str(e)}") + logger.exception(e) + raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}") + @property def insights(self) -> List[SourceInsight]: try: @@ -118,24 +153,6 @@ class Source(ObjectModel): raise InvalidInputError("Notebook ID must be provided") return self.relate("reference", notebook_id) - def save_chunks(self, text: str) -> None: - if not text: - raise InvalidInputError("Text cannot be empty") - try: - chunks = split_text(text, chunk=500000, overlap=1000) - logger.debug(f"Split into {len(chunks)} chunks") - for i, chunk in enumerate(chunks): - logger.debug(f"Saving chunk {i}") - data = {"source": self.id, "order": i, "content": surreal_clean(chunk)} - repo_create( - "source_chunk", - data, - ) - except Exception as e: - logger.exception(e) - logger.error(f"Error saving chunks for source {self.id}: {str(e)}") - raise DatabaseOperationError(e) - def vectorize(self) -> None: EMBEDDING_MODEL = model_manager.embedding_model @@ -144,8 +161,6 @@ class Source(ObjectModel): return chunks = split_text( self.full_text, - chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)), - overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)), ) logger.debug(f"Split into {len(chunks)} chunks") @@ -166,26 +181,26 @@ class Source(ObjectModel): logger.exception(e) raise DatabaseOperationError(e) - @classmethod - def search(cls, query: str) -> List[Dict[str, Any]]: - if not query: - raise InvalidInputError("Search query cannot be empty") - try: - result = repo_query( - """ - SELECT * omit full_text - FROM source - WHERE string::lowercase(title) CONTAINS $query or title @@ $query - OR string::lowercase(summary) CONTAINS $query or summary @@ $query - OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query - """, - {"query": query}, - ) - return result - except Exception as e: - logger.error(f"Error searching sources: {str(e)}") - logger.exception(e) - raise DatabaseOperationError("Failed to search sources") + # @classmethod + # def search(cls, query: str) -> List[Dict[str, Any]]: + # if not query: + # raise InvalidInputError("Search query cannot be empty") + # try: + # result = repo_query( + # """ + # SELECT * omit full_text + # FROM source + # WHERE string::lowercase(title) CONTAINS $query or title @@ $query + # OR string::lowercase(summary) CONTAINS $query or summary @@ $query + # OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query + # """, + # {"query": query}, + # ) + # return result + # except Exception as e: + # logger.error(f"Error searching sources: {str(e)}") + # logger.exception(e) + # raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: EMBEDDING_MODEL = model_manager.embedding_model @@ -246,6 +261,16 @@ class Note(ObjectModel): return self.content +class ChatSession(ObjectModel): + table_name: ClassVar[str] = "chat_session" + title: Optional[str] = None + + def relate_to_notebook(self, notebook_id: str) -> Any: + if not notebook_id: + raise InvalidInputError("Notebook ID must be provided") + return self.relate("refers_to", notebook_id) + + def text_search(keyword: str, results: int, source: bool = True, note: bool = True): if not keyword: raise InvalidInputError("Search keyword cannot be empty") @@ -263,18 +288,6 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr raise DatabaseOperationError(e) -# def hybrid_search( -# keyword_search: List[str], -# embed_search: List[str], -# results: int = 50, -# source: bool = True, -# note: bool = True, -# ): -# EMBEDDING_MODEL = model_manager.embedding_model -# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None - - -# todo: mover o embedding pra ca def vector_search(keyword: str, results: int, source: bool = True, note: bool = True): if not keyword: raise InvalidInputError("Search keyword cannot be empty") diff --git a/pages/2_📒_Notebooks.py b/pages/2_📒_Notebooks.py index f9eb191..f1e3354 100644 --- a/pages/2_📒_Notebooks.py +++ b/pages/2_📒_Notebooks.py @@ -10,11 +10,14 @@ from pages.stream_app.utils import setup_page, setup_stream_state setup_page("📒 Open Notebook") -def notebook_header(current_notebook): +def notebook_header(current_notebook: Notebook): + """ + Defines the header of the notebook page, including the ability to edit the notebook's name and description. + """ c1, c2, c3 = st.columns([8, 2, 2]) c1.header(current_notebook.name) if c2.button("Back to the list", icon="🔙"): - st.session_state["current_notebook"] = None + st.session_state["current_notebook_id"] = None st.rerun() if c3.button("Refresh", icon="🔄"): @@ -49,20 +52,20 @@ def notebook_header(current_notebook): st.toast("Notebook unarchived", icon="🗃️") if c3.button("Delete forever", type="primary", icon="☠️"): current_notebook.delete() - st.session_state["current_notebook"] = None + st.session_state["current_notebook_id"] = None st.rerun() -def notebook_page(current_notebook_id): - current_notebook: Notebook = Notebook.get(current_notebook_id) - if not current_notebook: - st.error("Notebook not found") - return - if current_notebook_id not in st.session_state.keys(): - st.session_state[current_notebook_id] = current_notebook +def notebook_page(current_notebook: Notebook): + # Guarantees that we have an entry for this notebook in the session state + if current_notebook.id not in st.session_state: + st.session_state[current_notebook.id] = {"notebook": current_notebook} + + # sets up the active session + current_session = setup_stream_state( + current_notebook=current_notebook, + ) - session_id = st.session_state["active_session"] - st.session_state[session_id]["notebook"] = current_notebook sources = current_notebook.sources notes = current_notebook.notes @@ -74,18 +77,18 @@ def notebook_page(current_notebook_id): with sources_tab: with st.container(border=True): if st.button("Add Source", icon="➕"): - add_source(session_id) + add_source(current_notebook.id) for source in sources: - source_card(session_id=session_id, source=source) + source_card(source=source, notebook_id=current_notebook.id) with notes_tab: with st.container(border=True): if st.button("Write a Note", icon="📝"): - add_note(session_id) + add_note(current_notebook.id) for note in notes: - note_card(session_id=session_id, note=note) + note_card(note=note, notebook_id=current_notebook.id) with chat_tab: - chat_sidebar(session_id=session_id) + chat_sidebar(current_notebook=current_notebook, current_session=current_session) def notebook_list_item(notebook): @@ -96,40 +99,50 @@ def notebook_list_item(notebook): ) st.write(notebook.description) if st.button("Open", key=f"open_notebook_{notebook.id}"): - setup_stream_state(notebook.id) - st.session_state["current_notebook"] = notebook.id + st.session_state["current_notebook_id"] = notebook.id st.rerun() -if "current_notebook" not in st.session_state: - st.session_state["current_notebook"] = None +if "current_notebook_id" not in st.session_state: + st.session_state["current_notebook_id"] = None -if st.session_state["current_notebook"]: - notebook_page(st.session_state["current_notebook"]) +# todo: get the notebook, check if it exists and if it's archived +if st.session_state["current_notebook_id"]: + current_notebook: Notebook = Notebook.get(st.session_state["current_notebook_id"]) + if not current_notebook: + st.error("Notebook not found") + st.stop() + notebook_page(current_notebook) st.stop() st.title("📒 My Notebooks") -st.caption("Here are all your notebooks") +st.caption( + "Notebooks are a great way to organize your thoughts, ideas, and sources. You can create notebooks for different research topics and projects, to create new articles, etc. " +) + +with st.expander("➕ **New Notebook**"): + new_notebook_title = st.text_input("New Notebook Name") + new_notebook_description = st.text_area( + "Description", + placeholder="Explain the purpose of this notebook. The more details the better.", + ) + if st.button("Create a new Notebook", icon="➕"): + notebook = Notebook( + name=new_notebook_title, description=new_notebook_description + ) + notebook.save() + st.toast("Notebook created successfully", icon="📒") notebooks = Notebook.get_all(order_by="updated desc") +archived_notebooks = [nb for nb in notebooks if nb.archived] for notebook in notebooks: if notebook.archived: continue notebook_list_item(notebook) -with st.expander("➕ **New Notebook**"): - new_notebook_title = st.text_input("New Notebook Name") - new_notebook_description = st.text_area("Description") - if st.button("Create a new Notebook", icon="➕"): - notebook = Notebook( - name=new_notebook_title, description=new_notebook_description - ) - notebook.save() - st.rerun() - -archived_notebooks = [nb for nb in notebooks if nb.archived] if len(archived_notebooks) > 0: with st.expander(f"**🗃️ {len(archived_notebooks)} archived Notebooks**"): + st.write("ℹ Archived Notebooks can still be accessed and used in search.") for notebook in archived_notebooks: notebook_list_item(notebook) diff --git a/pages/stream_app/chat.py b/pages/stream_app/chat.py index c3c2426..cb33852 100644 --- a/pages/stream_app/chat.py +++ b/pages/stream_app/chat.py @@ -1,19 +1,21 @@ +import humanize import streamlit as st from langchain_core.runnables import RunnableConfig -from open_notebook.domain.notebook import Note, Source +from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source from open_notebook.graphs.chat import graph as chat_graph from open_notebook.plugins.podcasts import PodcastConfig from open_notebook.utils import token_count +from pages.stream_app.utils import create_session_for_notebook from .note import make_note_from_chat # todo: build a smarter, more robust context manager function -def build_context(session_id): - st.session_state[session_id]["context"] = dict(note=[], source=[]) +def build_context(notebook_id): + st.session_state[notebook_id]["context"] = dict(note=[], source=[]) - for id, status in st.session_state[session_id]["context_config"].items(): + for id, status in st.session_state[notebook_id]["context_config"].items(): if not id: continue @@ -24,6 +26,7 @@ def build_context(session_id): if "not in" in status: continue + # todo: there is problably a better way to handle this if item_type == "note": item: Note = Note.get(id) elif item_type == "source": @@ -34,30 +37,33 @@ def build_context(session_id): if not item: continue if "summary" in status: - st.session_state[session_id]["context"][item_type] += [ + st.session_state[notebook_id]["context"][item_type] += [ item.get_context(context_size="short") ] elif "full content" in status: - st.session_state[session_id]["context"][item_type] += [ + st.session_state[notebook_id]["context"][item_type] += [ item.get_context(context_size="long") ] - return st.session_state[session_id]["context"] + return st.session_state[notebook_id]["context"] -def execute_chat(txt_input, session_id): - current_state = st.session_state[session_id] +def execute_chat(txt_input, current_session): + current_state = st.session_state[current_session.id] current_state["messages"] += [txt_input] result = chat_graph.invoke( input=current_state, - config=RunnableConfig(configurable={"thread_id": session_id}), + config=RunnableConfig(configurable={"thread_id": current_session.id}), ) + current_session.save() return result -def chat_sidebar(session_id): - context = build_context(session_id=session_id) - tokens = token_count(str(context) + str(st.session_state[session_id]["messages"])) +def chat_sidebar(current_notebook: Notebook, current_session: ChatSession): + context = build_context(notebook_id=current_notebook.id) + tokens = token_count( + str(context) + str(st.session_state[current_session.id]["messages"]) + ) chat_tab, podcast_tab = st.tabs(["Chat", "Podcast"]) with st.expander(f"Context ({tokens} tokens), {len(str(context))} chars"): st.json(context) @@ -91,15 +97,64 @@ def chat_sidebar(session_id): st.success("Episode generated successfully") st.page_link("pages/5_🎙️_Podcasts.py", label="🎙️ Go to Podcasts") with chat_tab: + with st.expander( + f"**Session:** {current_session.title} - {humanize.naturaltime(current_session.updated)}" + ): + new_session_name = st.text_input( + "Current Session", + key="new_session_name", + value=current_session.title, + ) + c1, c2 = st.columns(2) + if c1.button("Rename", key="rename_session"): + current_session.title = new_session_name + current_session.save() + st.rerun() + if c2.button("Delete", key="delete_session_1"): + current_session.delete() + st.session_state[current_notebook.id]["active_session"] = None + st.rerun() + st.divider() + new_session_name = st.text_input( + "New Session Name", + key="new_session_name_f", + placeholder="Enter a name for the new session...", + ) + st.caption("If no name provided, we'll use the current date.") + if st.button("Create New Session", key="create_new_session"): + new_session = create_session_for_notebook( + notebook_id=current_notebook.id, session_name=new_session_name + ) + st.session_state[current_notebook.id]["active_session"] = new_session.id + st.rerun() + st.divider() + sessions = current_notebook.chat_sessions + if len(sessions) > 1: + st.markdown("**Other Sessions:**") + for session in sessions: + if session.id == current_session.id: + continue + + st.markdown( + f"{session.title} - {humanize.naturaltime(session.updated)}" + ) + if st.button(label="Load", key=f"load_session_{session.id}"): + st.session_state[current_notebook.id]["active_session"] = ( + session.id + ) + st.rerun() with st.container(border=True): request = st.chat_input("Enter your question") # removing for now since it's not multi-model capable right now st.caption(f"Total tokens: {tokens}") if request: - response = execute_chat(txt_input=request, session_id=session_id) - st.session_state[session_id]["messages"] = response["messages"] + response = execute_chat( + txt_input=request, + current_session=current_session, + ) + st.session_state[current_session.id]["messages"] = response["messages"] - for msg in st.session_state[session_id]["messages"][::-1]: + for msg in st.session_state[current_session.id]["messages"][::-1]: if msg.type not in ["human", "ai"]: continue if not msg.content: @@ -111,6 +166,6 @@ def chat_sidebar(session_id): if st.button("💾 New Note", key=f"render_save_{msg.id}"): make_note_from_chat( content=msg.content, - notebook_id=st.session_state[session_id]["notebook"].id, + notebook_id=current_notebook.id, ) st.rerun() diff --git a/pages/stream_app/note.py b/pages/stream_app/note.py index f22e29c..0516a1c 100644 --- a/pages/stream_app/note.py +++ b/pages/stream_app/note.py @@ -1,3 +1,5 @@ +from typing import Optional + import streamlit as st from humanize import naturaltime from loguru import logger @@ -11,22 +13,20 @@ from .consts import context_icons @st.dialog("Write a Note", width="large") -def add_note(session_id): +def add_note(notebook_id): note_title = st.text_input("Title") note_content = st.text_area("Content") if st.button("Save", key="add_note"): logger.debug("Adding note") note = Note(title=note_title, content=note_content, note_type="human") note.save() - note.add_to_notebook(st.session_state[session_id]["notebook"].id) + note.add_to_notebook(notebook_id) st.rerun() @st.dialog("Add a Source", width="large") -def note_panel(session_id=None, note_id=None): - if note_id: - note: Note = Note.get(note_id) - else: +def note_panel(notebook_id=None, note: Optional[Note] = None): + if not note: note: Note = Note(note_type="human") t_preview, t_edit = st.tabs(["Preview", "Edit"]) @@ -38,13 +38,13 @@ def note_panel(session_id=None, note_id=None): note.content = st_monaco( value=note.content, height="600px", language="markdown" ) - if st.button("Save", key=f"pn_edit_note_{note_id}"): + if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"): logger.debug("Editing note") note.save() if not note.id: - note.add_to_notebook(st.session_state[session_id]["notebook"].id) + note.add_to_notebook(notebook_id) st.rerun() - if st.button("Delete", type="primary", key=f"delete_note_{note_id}"): + if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"): logger.debug("Deleting note") note.delete() st.rerun() @@ -70,7 +70,7 @@ def make_note_from_chat(content, notebook_id=None): st.rerun() -def note_card(session_id, note): +def note_card(note, notebook_id): if note.note_type == "human": icon = "🤵" else: @@ -88,9 +88,9 @@ def note_card(session_id, note): st.caption(f"Updated: {naturaltime(note.updated)}") if st.button("Expand", icon="📝", key=f"edit_note_{note.id}"): - note_panel(session_id, note.id) + note_panel(notebook_id=notebook_id, note=note) - st.session_state[session_id]["context_config"][note.id] = context_state + st.session_state[notebook_id]["context_config"][note.id] = context_state def note_list_item(note_id, score=None): @@ -105,4 +105,4 @@ def note_list_item(note_id, score=None): ): st.write(note.content) if st.button("Edit Note", icon="📝", key=f"x_edit_note_{note.id}"): - note_panel(note_id=note.id) + note_panel(note=note) diff --git a/pages/stream_app/source.py b/pages/stream_app/source.py index 25f880b..c249b36 100644 --- a/pages/stream_app/source.py +++ b/pages/stream_app/source.py @@ -95,6 +95,7 @@ def source_panel(source_id): if st.button( "Embed vectors", icon="🦾", + disabled=source.embedded_chunks > 0, help="This will generate your embedding vectors on the database for powerful search capabilities", ): source.vectorize() @@ -119,7 +120,7 @@ def source_panel(source_id): @st.dialog("Add a Source", width="large") -def add_source(session_id): +def add_source(notebook_id): source_link = None source_file = None source_text = None @@ -167,7 +168,7 @@ def add_source(session_id): title=result.get("title"), ) source.save() - source.add_to_notebook(st.session_state[session_id]["notebook"].id) + source.add_to_notebook(notebook_id) st.write("Summarizing...") generate_toc_and_title(source) except UnsupportedTypeException as e: @@ -188,7 +189,7 @@ def add_source(session_id): st.rerun() -def source_card(session_id, source): +def source_card(source, notebook_id): # todo: more descriptive icons icon = "🔗" @@ -208,7 +209,7 @@ def source_card(session_id, source): if st.button("Expand", icon="📝", key=source.id): source_panel(source.id) - st.session_state[session_id]["context_config"][source.id] = context_state + st.session_state[notebook_id]["context_config"][source.id] = context_state def source_list_item(source_id, score=None): diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py index 88f8849..b790db0 100644 --- a/pages/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -1,8 +1,12 @@ +from datetime import datetime +from typing import List, Union + import streamlit as st from loguru import logger from open_notebook.database.migrate import MigrationManager from open_notebook.domain.models import model_manager +from open_notebook.domain.notebook import ChatSession, Notebook from open_notebook.graphs.chat import ThreadState, graph from open_notebook.utils import ( compare_versions, @@ -33,19 +37,65 @@ def version_sidebar(): ) -def setup_stream_state(session_id) -> None: +def create_session_for_notebook(notebook_id: str, session_name: str = None): + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + title = f"Chat Session {current_time}" if not session_name else session_name + chat_session = ChatSession(title=title) + chat_session.save() + chat_session.relate_to_notebook(notebook_id) + return chat_session + + +def setup_stream_state(current_notebook: Notebook) -> ChatSession: """ Sets the value of the current session_id for langgraph thread state. If there is no existing thread state for this session_id, it creates a new one. + Finally, it acquires the existing state for the session from Langgraph state and sets it in the streamlit session state. """ - existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values - if len(existing_state.keys()) == 0: - st.session_state[session_id] = ThreadState( + assert ( + current_notebook is not None and current_notebook.id + ), "Current Notebook not selected properly" + + if "context_config" not in st.session_state[current_notebook.id]: + st.session_state[current_notebook.id]["context_config"] = {} + + current_session_id = st.session_state[current_notebook.id].get("active_session") + + # gets the chat session if provided + chat_session: Union[ChatSession, None] = ( + ChatSession.get(current_session_id) if current_session_id else None + ) + + # if there is no chat session, create one or get the first one + if not chat_session: + sessions: List[ChatSession] = current_notebook.chat_sessions + if not sessions or len(sessions) == 0: + logger.debug("Creating new chat session") + chat_session = create_session_for_notebook(current_notebook.id) + else: + logger.debug("Getting last updated session") + chat_session = sessions[0] + + logger.debug(f"Chat session: {chat_session}") + + if not chat_session or chat_session.id is None: + raise ValueError("Problem acquiring chat session") + # sets the active session for the notebook + st.session_state[current_notebook.id]["active_session"] = chat_session.id + + # gets the existing state for the session from Langgraph state + existing_state = graph.get_state( + {"configurable": {"thread_id": chat_session.id}} + ).values + if not existing_state or len(existing_state.keys()) == 0: + st.session_state[chat_session.id] = ThreadState( messages=[], context=None, notebook=None, context_config={} ) else: - st.session_state[session_id] = existing_state - st.session_state["active_session"] = session_id + st.session_state[chat_session.id] = existing_state + + st.session_state[current_notebook.id]["active_session"] = chat_session.id + return chat_session def check_migration():