diff --git a/open_notebook/database/repository.py b/open_notebook/database/repository.py index f514c74..d90ac9d 100644 --- a/open_notebook/database/repository.py +++ b/open_notebook/database/repository.py @@ -40,6 +40,11 @@ def repo_create(table: str, data: Dict[str, Any]): return repo_query(query) +def repo_upsert(table: str, data: Dict[str, Any]): + query = f"UPSERT {table} CONTENT {data};" + return repo_query(query) + + def repo_update(id: str, data: Dict[str, Any]): query = "UPDATE $id CONTENT $data;" vars = {"id": id, "data": data} diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index ce0e496..d7f739d 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -1,8 +1,22 @@ from datetime import datetime -from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, cast +from typing import ( + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + cast, +) from loguru import logger -from pydantic import BaseModel, ValidationError, field_validator +from pydantic import ( + BaseModel, + ValidationError, + field_validator, + model_validator, +) from open_notebook.database.repository import ( repo_create, @@ -10,6 +24,7 @@ from open_notebook.database.repository import ( repo_query, repo_relate, repo_update, + repo_upsert, ) from open_notebook.exceptions import ( DatabaseOperationError, @@ -204,24 +219,92 @@ class ObjectModel(BaseModel): class RecordModel(BaseModel): record_id: ClassVar[str] + auto_save: ClassVar[bool] = ( + False # Default to False, can be overridden in subclasses + ) + _instances: ClassVar[Dict[str, "RecordModel"]] = {} # Store instances by record_id + + class Config: + validate_assignment = True + arbitrary_types_allowed = True + extra = "allow" + from_attributes = True + defer_build = True + + def __new__(cls, **kwargs): + # If an instance already exists for this record_id, return it + if cls.record_id in cls._instances: + instance = cls._instances[cls.record_id] + # Update instance with any new kwargs if provided + if kwargs: + for key, value in kwargs.items(): + setattr(instance, key, value) + return instance + + # If no instance exists, create a new one + instance = super().__new__(cls) + cls._instances[cls.record_id] = instance + return instance def __init__(self, **kwargs): - super().__init__(**kwargs) - self.load() + # Only initialize if this is a new instance + if not hasattr(self, "_initialized"): + object.__setattr__(self, "__dict__", {}) + # Load data from DB first + result = repo_query(f"SELECT * FROM {self.record_id};") + if result: + db_data = result[0] + else: + # Initialize empty object with None for Optional fields + db_data = { + field_name: None + for field_name, field_info in self.model_fields.items() + if not str(field_info.annotation).startswith("typing.ClassVar") + } + + # Initialize with DB data and any overrides + super().__init__(**{**db_data, **kwargs}) + object.__setattr__(self, "_initialized", True) + + @classmethod + def get_instance(cls) -> "RecordModel": + """Get or create the singleton instance""" + return cls() + + @model_validator(mode="after") + def auto_save_validator(self): + if self.__class__.auto_save: + self.update() + return self + + def update(self): + # Get all non-ClassVar fields and their values + data = { + field_name: getattr(self, field_name) + for field_name, field_info in self.model_fields.items() + if not str(field_info.annotation).startswith("typing.ClassVar") + } + + repo_upsert(self.record_id, data) - def load(self): result = repo_query(f"SELECT * FROM {self.record_id};") if result: - result = result[0] - else: - repo_create(self.record_id, {}) - result = {} - for key, value in result.items(): - if hasattr(self, key): - setattr(self, key, value) + for key, value in result[0].items(): + if hasattr(self, key): + object.__setattr__( + self, key, value + ) # Use object.__setattr__ to avoid triggering validation again return self - def update(self, data): - repo_update(self.record_id, data) - return self.load() + @classmethod + def clear_instance(cls): + """Clear the singleton instance (useful for testing)""" + if cls.record_id in cls._instances: + del cls._instances[cls.record_id] + + def patch(self, model_dict: dict): + """Update model attributes from dictionary and save""" + for key, value in model_dict.items(): + setattr(self, key, value) + self.update() diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index 22b3028..918b5e7 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -28,15 +28,14 @@ class Model(ObjectModel): class DefaultModels(RecordModel): record_id: ClassVar[str] = "open_notebook:default_models" - - default_chat_model: Optional[str] = None - default_transformation_model: Optional[str] = None - large_context_model: Optional[str] = None - default_text_to_speech_model: Optional[str] = None - default_speech_to_text_model: Optional[str] = None - # default_vision_model: Optional[str] = None - default_embedding_model: Optional[str] = None - default_tools_model: Optional[str] = None + default_chat_model: Optional[str] + default_transformation_model: Optional[str] + large_context_model: Optional[str] + default_text_to_speech_model: Optional[str] + default_speech_to_text_model: Optional[str] + # default_vision_model: Optional[str] + default_embedding_model: Optional[str] + default_tools_model: Optional[str] class ModelManager: @@ -54,7 +53,10 @@ class ModelManager: self._default_models = None self.refresh_defaults() - def get_model(self, model_id: str, **kwargs) -> ModelType: + def get_model(self, model_id: str, **kwargs) -> Optional[ModelType]: + if not model_id: + return None + cache_key = f"{model_id}:{str(kwargs)}" if cache_key in self._model_cache: @@ -68,9 +70,6 @@ class ModelManager: ) return cached_model - if not model_id: - return None - model: Model = Model.get(model_id) if not model: @@ -111,7 +110,10 @@ class ModelManager: @property def speech_to_text(self, **kwargs) -> Optional[SpeechToTextModel]: """Get the default speech-to-text model""" - model = self.get_default_model("speech_to_text", **kwargs) + model_id = self.defaults.default_speech_to_text_model + if not model_id: + return None + model = self.get_model(model_id, **kwargs) assert model is None or isinstance( model, SpeechToTextModel ), f"Expected SpeechToTextModel but got {type(model)}" @@ -120,7 +122,10 @@ class ModelManager: @property def text_to_speech(self, **kwargs) -> Optional[TextToSpeechModel]: """Get the default text-to-speech model""" - model = self.get_default_model("text_to_speech", **kwargs) + model_id = self.defaults.default_text_to_speech_model + if not model_id: + return None + model = self.get_model(model_id, **kwargs) assert model is None or isinstance( model, TextToSpeechModel ), f"Expected TextToSpeechModel but got {type(model)}" @@ -129,13 +134,16 @@ class ModelManager: @property def embedding_model(self, **kwargs) -> Optional[EmbeddingModel]: """Get the default embedding model""" - model = self.get_default_model("embedding", **kwargs) + model_id = self.defaults.default_embedding_model + if not model_id: + return None + model = self.get_model(model_id, **kwargs) assert model is None or isinstance( model, EmbeddingModel ), f"Expected EmbeddingModel but got {type(model)}" return model - def get_default_model(self, model_type: str, **kwargs) -> ModelType: + def get_default_model(self, model_type: str, **kwargs) -> Optional[ModelType]: """ Get the default model for a specific type. @@ -165,6 +173,9 @@ class ModelManager: elif model_type == "large_context": model_id = self.defaults.large_context_model + if not model_id: + return None + return self.get_model(model_id, **kwargs) def clear_cache(self): diff --git a/open_notebook/graphs/transformation.py b/open_notebook/graphs/transformation.py index f125c20..ab5c799 100644 --- a/open_notebook/graphs/transformation.py +++ b/open_notebook/graphs/transformation.py @@ -1,4 +1,3 @@ -from executing import Source from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.runnables import ( RunnableConfig, @@ -6,6 +5,7 @@ from langchain_core.runnables import ( from langgraph.graph import END, START, StateGraph from typing_extensions import TypedDict +from open_notebook.domain.notebook import Source from open_notebook.domain.transformation import DefaultPrompts, Transformation from open_notebook.graphs.utils import provision_langchain_model from open_notebook.prompter import Prompter @@ -26,7 +26,7 @@ def run_transformation(state: dict, config: RunnableConfig) -> dict: if not content: content = source.full_text transformation_prompt_text = transformation.prompt - default_prompts: DefaultPrompts = DefaultPrompts().load() + default_prompts: DefaultPrompts = DefaultPrompts() if default_prompts.transformation_instructions: transformation_prompt_text = f"{default_prompts.transformation_instructions}\n\n{transformation_prompt_text}" diff --git a/pages/3_🔍_Ask_and_Search.py b/pages/3_🔍_Ask_and_Search.py index 1eddbf2..452c9bf 100644 --- a/pages/3_🔍_Ask_and_Search.py +++ b/pages/3_🔍_Ask_and_Search.py @@ -54,7 +54,7 @@ with ask_tab: "The LLM will answer your query based on the documents in your knowledge base. " ) question = st.text_input("Question", "") - default_model = DefaultModels().load().default_chat_model + default_model = DefaultModels().default_chat_model strategy_model = model_selector( "Query Strategy Model", "strategy_model", diff --git a/pages/7_🤖_Models.py b/pages/7_🤖_Models.py index 28bb222..10abc85 100644 --- a/pages/7_🤖_Models.py +++ b/pages/7_🤖_Models.py @@ -89,7 +89,7 @@ def generate_new_models(models, suggested_models): return new_models -default_models = DefaultModels().model_dump() +default_models = DefaultModels() all_models = Model.get_all() with model_tab: @@ -176,82 +176,101 @@ with model_defaults_tab: "In this section, you can select the default models to be used on the various content operations done by Open Notebook. Some of these can be overriden in the different modules." ) defs = {} - defs["default_chat_model"] = model_selector( + # Handle chat model selection + selected_model = model_selector( "Default Chat Model", "default_chat_model", - selected_id=default_models.get("default_chat_model"), + selected_id=default_models.default_chat_model, help="This model will be used for chat.", model_type="language", ) + if selected_model: + default_models.default_chat_model = selected_model.id st.divider() - defs["default_transformation_model"] = model_selector( + # Handle transformation model selection + selected_model = model_selector( "Default Transformation Model", "default_transformation_model", - selected_id=default_models.get("default_transformation_model"), + selected_id=default_models.default_transformation_model, help="This model will be used for text transformations such as summaries, insights, etc.", model_type="language", ) + if selected_model: + default_models.default_transformation_model = selected_model.id st.caption("You can use a cheap model here like gpt-4o-mini, llama3, etc.") st.divider() - defs["default_tools_model"] = model_selector( + + # Handle tools model selection + selected_model = model_selector( "Default Tools Model", "default_tools_model", - selected_id=default_models.get("default_tools_model"), + selected_id=default_models.default_tools_model, help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.", model_type="language", ) + if selected_model: + default_models.default_tools_model = selected_model.id st.caption("Recommended to use a capable model here, like gpt-4o, claude, etc.") st.divider() - defs["large_context_model"] = model_selector( + + # Handle large context model selection + selected_model = model_selector( "Large Context Model", "large_context_model", - selected_id=default_models.get("large_context_model"), + selected_id=default_models.large_context_model, help="This model will be used for larger context generation -- recommended: Gemini", model_type="language", ) + if selected_model: + default_models.large_context_model = selected_model.id st.caption("Recommended to use Gemini models for larger context processing") st.divider() - defs["default_text_to_speech_model"] = model_selector( + + # Handle text-to-speech model selection + selected_model = model_selector( "Default Text to Speech Model", "default_text_to_speech_model", - selected_id=default_models.get("default_text_to_speech_model"), + selected_id=default_models.default_text_to_speech_model, help="This is the default model for converting text to speech (podcasts, etc)", model_type="text_to_speech", ) st.caption("You can override this model on different podcasts") + if selected_model: + default_models.default_text_to_speech_model = selected_model.id st.divider() - defs["default_speech_to_text_model"] = model_selector( + + # Handle speech-to-text model selection + selected_model = model_selector( "Default Speech to Text Model", - "default_speech_to_text_model", - selected_id=default_models.get("default_speech_to_text_model"), + selected_id=default_models.default_speech_to_text_model, help="This is the default model for converting speech to text (audio transcriptions, etc)", model_type="speech_to_text", + key="default_speech_to_text_model", ) - st.divider() - # defs["default_vision_model"] = ( - # model_selector( - # "Default Speech to Text Model", - # "default_vision_model", - # selected_id=default_models.get("default_vision_model"), - # help="This is the default model for vision tasks", - # model_type="vision", - # ), - # ) + if selected_model: + default_models.default_speech_to_text_model = selected_model.id - defs["default_embedding_model"] = model_selector( + st.divider() + # Handle embedding model selection + selected_model = model_selector( "Default Speech to Text Model", "default_embedding_model", - selected_id=default_models.get("default_embedding_model"), + selected_id=default_models.default_embedding_model, help="This is the default model for embeddings (semantic search, etc)", model_type="embedding", ) - st.caption( + if selected_model: + default_models.default_embedding_model = selected_model.id + st.warning( "Caution: you cannot change the embedding model once there is embeddings or they will need to be regenerated" ) for k, v in defs.items(): if v: defs[k] = v.id - DefaultModels().update(defs) - model_manager.refresh_defaults() + + if st.button("Save Defaults"): + default_models.patch(defs) + model_manager.refresh_defaults() + st.success("Saved") diff --git a/pages/8_💱_Transformations.py b/pages/8_💱_Transformations.py index 686f52b..940869c 100644 --- a/pages/8_💱_Transformations.py +++ b/pages/8_💱_Transformations.py @@ -25,7 +25,7 @@ with transformations_tab: st.markdown( "Transformations are prompts that will be used by the LLM to process a source and extract insights, summaries, etc. " ) - default_prompts: DefaultPrompts = DefaultPrompts().load() + default_prompts: DefaultPrompts = DefaultPrompts() with st.expander("**⚙️ Default Transformation Prompt**"): default_prompts.transformation_instructions = st.text_area( "Default Transformation Prompt", @@ -34,7 +34,7 @@ with transformations_tab: ) st.caption("This will be added to all your transformation prompts.") if st.button("Save", key="save_default_prompt"): - default_prompts.update(default_prompts.model_dump()) + default_prompts.update() st.toast("Default prompt saved successfully!") if st.button("Create new Transformation", icon="➕", key="new_transformation"): new_transformation = Transformation( diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py index d56d52a..758d6cf 100644 --- a/pages/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -6,7 +6,7 @@ import streamlit as st from loguru import logger from open_notebook.database.migrate import MigrationManager -from open_notebook.domain.models import model_manager +from open_notebook.domain.models import DefaultModels from open_notebook.domain.notebook import ChatSession, Notebook from open_notebook.graphs.chat import ThreadState, graph from open_notebook.utils import ( @@ -109,7 +109,7 @@ def check_migration(): def check_models(only_mandatory=True, stop_on_error=True): - default_models = model_manager.defaults + default_models = DefaultModels() mandatory_models = [ default_models.default_chat_model, default_models.default_transformation_model,