diff --git a/open_notebook/config.py b/open_notebook/config.py
index 4543a55..38a6972 100644
--- a/open_notebook/config.py
+++ b/open_notebook/config.py
@@ -3,6 +3,10 @@ import os
 import yaml
 from loguru import logger
 
+from open_notebook.domain.models import DefaultModels
+from open_notebook.models.embedding_models import get_embedding_model
+from open_notebook.models.speech_to_text_models import get_speech_to_text_model
+
 # todo: enable config file overwrite with env vars
 current_dir = os.path.dirname(os.path.abspath(__file__))
 project_root = os.path.dirname(current_dir)
@@ -32,3 +36,12 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True)
 # PODCASTS FOLDER
 PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
 os.makedirs(PODCASTS_FOLDER, exist_ok=True)
+
+
+DEFAULT_MODELS = DefaultModels.load()
+
+EMBEDDING_MODEL = get_embedding_model(DEFAULT_MODELS.default_embedding_model)
+
+SPEECH_TO_TEXT_MODEL = get_speech_to_text_model(
+    DEFAULT_MODELS.default_speech_to_text_model
+)
diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py
new file mode 100644
index 0000000..12074ff
--- /dev/null
+++ b/open_notebook/domain/models.py
@@ -0,0 +1,40 @@
+from typing import ClassVar, Optional
+
+from loguru import logger
+from pydantic import BaseModel
+
+from open_notebook.domain.base import ObjectModel
+from open_notebook.repository import (
+    repo_query,
+    repo_update,
+)
+
+
+class Model(ObjectModel):
+    table_name: ClassVar[str] = "model"
+    name: str
+    provider: str
+    type: str
+
+
+class DefaultModels(BaseModel):
+    default_chat_model: Optional[str] = None
+    default_transformation_model: Optional[str] = None
+    large_context_model: Optional[str] = None
+    default_text_to_speech_model: Optional[str] = None
+    default_speech_to_text_model: Optional[str] = None
+    # default_vision_model: Optional[str] = None
+    default_embedding_model: Optional[str] = None
+
+    @classmethod
+    def load(cls):
+        result = repo_query("SELECT * FROM open_notebook:default_models;")
+        if result:
+            logger.debug(result)
+            return cls(**result[0])
+        # fall back to empty defaults so callers always get an instance
+        return cls()
+
+    @classmethod
+    def update(cls, data):
+        repo_update("open_notebook:default_models", data)
diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py
index 37c6f04..84a30be 100644
--- a/open_notebook/graphs/chat.py
+++ b/open_notebook/graphs/chat.py
@@ -9,8 +9,8 @@ from langgraph.graph import END, START, StateGraph
 from langgraph.graph.message import add_messages
 from typing_extensions import TypedDict
 
-from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
-from open_notebook.domain import Notebook
+from open_notebook.config import DEFAULT_MODELS, LANGGRAPH_CHECKPOINT_FILE
+from open_notebook.domain.notebook import Notebook
 from open_notebook.graphs.utils import run_pattern
 
 
@@ -22,7 +22,9 @@ class ThreadState(TypedDict):
 
 
 def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
-    model_name = config.get("configurable", {}).get("model_name", None)
+    model_name = config.get("configurable", {}).get(
+        "model_name", DEFAULT_MODELS.default_chat_model
+    )
     ai_message = run_pattern(
         "chat",
         model_name,
diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py
index 5afafb7..4921403 100644
--- a/open_notebook/graphs/content_processing/audio.py
+++ b/open_notebook/graphs/content_processing/audio.py
@@ -4,9 +4,9 @@ from math import ceil
 from loguru import logger
 from pydub import AudioSegment
 
+from open_notebook.config import SPEECH_TO_TEXT_MODEL
 from open_notebook.graphs.content_processing.state import SourceState
 
-# todo: add a speechtotext model to the config
 # future: parallelize the transcription process
 
 
@@ -73,9 +73,6 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
 
 def extract_audio(data: SourceState):
     input_audio_path = data.get("file_path")
-    from openai import OpenAI
-
-    client = OpenAI()
 
     audio_files = []
     try:
@@ -83,11 +80,7 @@ def extract_audio(data: SourceState):
         transcriptions = []
 
         for audio_file in audio_files:
-            with open(audio_file, "rb") as audio:
-                transcription = client.audio.transcriptions.create(
-                    model="whisper-1", file=audio
-                )
-                transcriptions.append(transcription.text)
+            transcriptions.append(SPEECH_TO_TEXT_MODEL.transcribe(audio_file))
 
         return {"content": " ".join(transcriptions)}
 
diff --git a/open_notebook/graphs/doc_query.py b/open_notebook/graphs/doc_query.py
index 2c673db..9a2ffb9 100644
--- a/open_notebook/graphs/doc_query.py
+++ b/open_notebook/graphs/doc_query.py
@@ -6,7 +6,7 @@ from langchain_core.runnables import (
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import TypedDict
 
-from open_notebook.domain import Note, Notebook, Source
+from open_notebook.domain.notebook import Note, Notebook, Source
 from open_notebook.graphs.utils import run_pattern
 
 
diff --git a/open_notebook/graphs/multipattern.py b/open_notebook/graphs/multipattern.py
index 8f2b6d7..e74d7a3 100644
--- a/open_notebook/graphs/multipattern.py
+++ b/open_notebook/graphs/multipattern.py
@@ -1,5 +1,4 @@
 import operator
-import os
 from typing import List, Literal, Sequence
 
 from langchain_core.runnables import (
@@ -8,6 +7,7 @@ from langchain_core.runnables import (
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import Annotated, TypedDict
 
+from open_notebook.config import DEFAULT_MODELS
 from open_notebook.graphs.utils import run_pattern
 
 
@@ -19,7 +19,7 @@ class PatternChainState(TypedDict):
 
 def call_model(state: dict, config: RunnableConfig) -> dict:
     model_name = config.get("configurable", {}).get(
-        "model_name", os.environ.get("DEFAULT_MODEL")
+        "model_name", DEFAULT_MODELS.default_transformation_model
     )
     transformations = state["transformations"]
     current_transformation = transformations.pop(0)
diff --git a/open_notebook/graphs/pattern.py b/open_notebook/graphs/pattern.py
index b7a9bd0..c47cc14 100644
--- a/open_notebook/graphs/pattern.py
+++ b/open_notebook/graphs/pattern.py
@@ -1,11 +1,10 @@
-import os
-
 from langchain_core.runnables import (
     RunnableConfig,
 )
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import TypedDict
 
+from open_notebook.config import DEFAULT_MODELS
 from open_notebook.graphs.utils import run_pattern
 
 
@@ -17,7 +16,7 @@ class PatternState(TypedDict):
 
 def call_model(state: dict, config: RunnableConfig) -> dict:
     model_name = config.get("configurable", {}).get(
-        "model_name", os.environ.get("DEFAULT_MODEL")
+        "model_name", DEFAULT_MODELS.default_transformation_model
     )
     return {
         "output": run_pattern(
diff --git a/open_notebook/graphs/recursive_toc.py b/open_notebook/graphs/recursive_toc.py
index 9cffc5d..a9cb795 100644
--- a/open_notebook/graphs/recursive_toc.py
+++ b/open_notebook/graphs/recursive_toc.py
@@ -7,6 +7,7 @@ from langchain_core.runnables import (
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import TypedDict
 
+from open_notebook.config import DEFAULT_MODELS
 from open_notebook.graphs.utils import run_pattern
 from open_notebook.utils import split_text
 
@@ -49,7 +50,7 @@ def chunk_condition(state: TocState) -> Literal["get_chunk", END]:  # type: ignore
 
 def call_model(state: TocState, config: RunnableConfig) -> dict:
     model_name = config.get("configurable", {}).get(
-        "model_name", os.environ.get("SUMMARIZATION_MODEL")
+        "model_name", DEFAULT_MODELS.default_transformation_model
     )
     return {
         "toc": run_pattern(
diff --git a/open_notebook/graphs/summary.py b/open_notebook/graphs/summary.py
index a262c5e..df54ff5 100644
--- a/open_notebook/graphs/summary.py
+++ b/open_notebook/graphs/summary.py
@@ -9,6 +9,7 @@ from langgraph.graph import END, START, StateGraph
 from pydantic import BaseModel
 from typing_extensions import TypedDict
 
+from open_notebook.config import DEFAULT_MODELS
 from open_notebook.graphs.utils import run_pattern
 from open_notebook.utils import split_text
 
@@ -57,9 +58,9 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]:  # type: ignore
     return END
 
 
-def call_model(state: SummaryState, config: RunnableConfig) -> dict:
+def call_model(state: dict, config: RunnableConfig) -> dict:
     model_name = config.get("configurable", {}).get(
-        "model_name", os.environ.get("SUMMARIZATION_MODEL")
+        "model_name", DEFAULT_MODELS.default_transformation_model
     )
     parser = PydanticOutputParser(pydantic_object=SummaryResponse)
     return {
diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py
index 67f6866..9e95fcd 100644
--- a/open_notebook/graphs/utils.py
+++ b/open_notebook/graphs/utils.py
@@ -1,8 +1,7 @@
-import os
-
 from langchain.output_parsers import OutputFixingParser
 
-from open_notebook.llm_router import get_langchain_model
+from open_notebook.config import DEFAULT_MODELS
+from open_notebook.models.llms import get_langchain_model
 from open_notebook.prompter import Prompter
 
 
@@ -15,7 +14,7 @@ def run_pattern(
     output_fixing_model_name=None,
 ) -> dict:
     if not model_name:
-        model_name = os.environ["DEFAULT_MODEL"]
+        model_name = DEFAULT_MODELS.default_transformation_model
 
     chain = get_langchain_model(model_name)
     if parser:
diff --git a/open_notebook/llm_router.py b/open_notebook/llm_router.py
deleted file mode 100644
index 9e1096e..0000000
--- a/open_notebook/llm_router.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from open_notebook.llms import (
-    AnthropicLanguageModel,
-    GeminiLanguageModel,
-    LiteLLMLanguageModel,
-    OllamaLanguageModel,
-    OpenAILanguageModel,
-    OpenRouterLanguageModel,
-    VertexAILanguageModel,
-    VertexAnthropicLanguageModel,
-)
-
-# Map provider names to classes
-PROVIDER_CLASS_MAP = {
-    "ollama": OllamaLanguageModel,
-    "openrouter": OpenRouterLanguageModel,
-    "vertexai-anthropic": VertexAnthropicLanguageModel,
-    "litellm": LiteLLMLanguageModel,
-    "vertexai": VertexAILanguageModel,
-    "anthropic": AnthropicLanguageModel,
-    "openai": OpenAILanguageModel,
-    "gemini": GeminiLanguageModel,
-}
-
-
-def get_langchain_model(model_name, json=False):
-    parts = model_name.split("/")
-    provider = parts[0]
-    model_name_wihout_provider = "/".join(parts[1:])
-    if provider not in PROVIDER_CLASS_MAP.keys():
-        raise ValueError(
-            f"Provider {provider} not found in config. Make sure you use the correct format for model names, example: openai/gpt-4o-mini"
-        )
-    return PROVIDER_CLASS_MAP[provider](
-        model_name=model_name_wihout_provider, json=json
-    ).to_langchain()
diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/open_notebook/models/embedding_models.py b/open_notebook/models/embedding_models.py
new file mode 100644
index 0000000..aaaec65
--- /dev/null
+++ b/open_notebook/models/embedding_models.py
@@ -0,0 +1,62 @@
+"""
+Classes for supporting different embedding models
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional
+
+from openai import OpenAI
+
+from open_notebook.domain.models import Model
+
+
+@dataclass
+class EmbeddingModel(ABC):
+    """
+    Abstract base class for embedding models.
+    """
+
+    model_name: Optional[str] = None
+
+    @abstractmethod
+    def embed(self, text: str) -> List[float]:
+        """
+        Generates an embedding
+        """
+        raise NotImplementedError
+
+
+@dataclass
+class OpenAIEmbeddingModel(EmbeddingModel):
+    model_name: str
+
+    def embed(self, text: str) -> List[float]:
+        """
+        Embeds the content using the OpenAI embeddings API
+        """
+        # todo: make this a singleton
+        client = OpenAI()
+        text = text.replace("\n", " ")
+        return (
+            client.embeddings.create(input=[text], model=self.model_name)
+            .data[0]
+            .embedding
+        )
+
+
+EMBEDDING_CLASS_MAP = {
+    "openai": OpenAIEmbeddingModel,
+}
+
+
+def get_embedding_model(model_id):
+    assert model_id, "Model ID cannot be empty"
+    model = Model.get(model_id)
+    if not model:
+        raise ValueError(f"Model with ID {model_id} not found")
+    if model.provider not in EMBEDDING_CLASS_MAP:
+        raise ValueError(
+            f"Provider {model.provider} not compatible with Embedding Models"
+        )
+    return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name)
diff --git a/open_notebook/llms.py b/open_notebook/models/llms.py
similarity index 86%
rename from open_notebook/llms.py
rename to open_notebook/models/llms.py
index efcc431..64c0213 100644
--- a/open_notebook/llms.py
+++ b/open_notebook/models/llms.py
@@ -1,5 +1,5 @@
 """
-Classes for supporting different language and vector models
+Classes for supporting different language models
 """
 
 import os
@@ -15,9 +15,9 @@ from langchain_google_vertexai import ChatVertexAI
 from langchain_google_vertexai.model_garden import ChatAnthropicVertex
 from langchain_ollama.chat_models import ChatOllama
 from langchain_openai.chat_models import ChatOpenAI
+from pydantic import SecretStr
 
-# from redisvl.utils.vectorize import BaseVectorizer
-# from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer
+from open_notebook.domain.models import Model
 
 
 @dataclass
@@ -186,7 +186,7 @@ class OpenRouterLanguageModel(LanguageModel):
             max_tokens=self.max_tokens,
             model_kwargs=kwargs,
             streaming=self.streaming,
-            api_key=os.environ.get("OPENROUTER_API_KEY", "openrouter"),
+            api_key=SecretStr(os.environ.get("OPENROUTER_API_KEY", "openrouter")),
             top_p=self.top_p,
         )
 
@@ -240,26 +240,26 @@ class OpenAILanguageModel(LanguageModel):
         )
 
 
-# @dataclass
-# class EmbeddingModel(ABC):
-#     model_name: str
-#     dimensions: int
-
-#     def to_redis_vectorizer(self) -> BaseVectorizer:
-#         raise NotImplementedError
+# Map provider names to classes
+PROVIDER_CLASS_MAP = {
+    "ollama": OllamaLanguageModel,
+    "openrouter": OpenRouterLanguageModel,
+    "vertexai-anthropic": VertexAnthropicLanguageModel,
+    "litellm": LiteLLMLanguageModel,
+    "vertexai": VertexAILanguageModel,
+    "anthropic": AnthropicLanguageModel,
+    "openai": OpenAILanguageModel,
+    "gemini": GeminiLanguageModel,
+}
 
 
-# @dataclass
-# class OpenAIEmbeddingModel(EmbeddingModel):
-#     """
-#     Embedding model that uses the OpenAI text embedding model.
-#     """
-
-#     model_name: str
-#     dimensions: int
-
-#     def to_redis_vectorizer(self) -> OpenAITextVectorizer:
-#         """
-#         Convert the embedding model to a Redis vectorizer.
-#         """
-#         return OpenAITextVectorizer(model=self.model_name)
+# todo: make the provider check type specific
+def get_langchain_model(model_id, json=False):
+    model = Model.get(model_id)
+    if not model:
+        raise ValueError(f"Model {model_id} not found")
+    if model.provider not in PROVIDER_CLASS_MAP:
+        raise ValueError(f"Provider {model.provider} not found")
+    return PROVIDER_CLASS_MAP[model.provider](
+        model_name=model.name, json=json
+    ).to_langchain()
diff --git a/open_notebook/models/speech_to_text_models.py b/open_notebook/models/speech_to_text_models.py
new file mode 100644
index 0000000..4191812
--- /dev/null
+++ b/open_notebook/models/speech_to_text_models.py
@@ -0,0 +1,62 @@
+"""
+Classes for supporting different transcription models
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional
+
+from openai import OpenAI
+
+from open_notebook.domain.models import Model
+
+
+@dataclass
+class SpeechToTextModel(ABC):
+    """
+    Abstract base class for speech-to-text models.
+    """
+
+    model_name: Optional[str] = None
+
+    @abstractmethod
+    def transcribe(self, audio_file_path: str) -> str:
+        """
+        Generates a text transcription from audio
+        """
+        raise NotImplementedError
+
+
+@dataclass
+class OpenAISpeechToTextModel(SpeechToTextModel):
+    model_name: str
+
+    def transcribe(self, audio_file_path: str) -> str:
+        """
+        Transcribes an audio file into text
+        """
+        # todo: make this a singleton
+        client = OpenAI()
+        with open(audio_file_path, "rb") as audio:
+            transcription = client.audio.transcriptions.create(
+                model=self.model_name, file=audio
+            )
+        return transcription.text
+
+
+SPEECH_TO_TEXT_CLASS_MAP = {
+    "openai": OpenAISpeechToTextModel,
+}
+
+
+# todo: these get_*_model factories could probably be merged into one
+def get_speech_to_text_model(model_id):
+    assert model_id, "Model ID cannot be empty"
+    model = Model.get(model_id)
+    if not model:
+        raise ValueError(f"Model with ID {model_id} not found")
+    if model.provider not in SPEECH_TO_TEXT_CLASS_MAP:
+        raise ValueError(
+            f"Provider {model.provider} not compatible with Speech to Text Models"
+        )
+    return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name)
diff --git a/open_notebook/plugins/podcasts.py b/open_notebook/plugins/podcasts.py
index 327a969..2645474 100644
--- a/open_notebook/plugins/podcasts.py
+++ b/open_notebook/plugins/podcasts.py
@@ -4,7 +4,7 @@ from loguru import logger
 from podcastfy.client import generate_podcast
 from pydantic import Field, field_validator
 
-from open_notebook.domain import ObjectModel
+from open_notebook.domain.notebook import ObjectModel
 
 
 class PodcastEpisode(ObjectModel):
diff --git a/open_notebook/utils.py b/open_notebook/utils.py
index ab43172..340762e 100644
--- a/open_notebook/utils.py
+++ b/open_notebook/utils.py
@@ -6,11 +6,8 @@ from urllib.parse import urlparse
 import requests
 import tomli
 from langchain_text_splitters import CharacterTextSplitter
-from openai import OpenAI
 from packaging.version import parse as parse_version
 
-client = OpenAI()
-
 
 def split_text(txt: str, chunk=1000, overlap=0, separator=" "):
     """
@@ -63,21 +60,6 @@ def token_cost(token_count, cost_per_million=0.150):
     return cost_per_million * (token_count / 1_000_000)
 
 
-def get_embedding(text, model="text-embedding-3-small"):
-    """
-    Get the embedding for the input text using the specified model.
-
-    Args:
-        text (str): The input text to get the embedding for.
-        model (str): The name of the embedding model to use. Default is "text-embedding-3-small".
-
-    Returns:
-        list: The embedding vector for the input text.
-    """
-    text = text.replace("\n", " ")
-    return client.embeddings.create(input=[text], model=model).data[0].embedding
-
-
 def remove_non_ascii(text):
     return re.sub(r"[^\x00-\x7F]+", "", text)
 
diff --git a/pages/9_⚙️_Settings.py b/pages/9_⚙️_Settings.py
new file mode 100644
index 0000000..110bf7c
--- /dev/null
+++ b/pages/9_⚙️_Settings.py
@@ -0,0 +1,212 @@
+import os
+
+import streamlit as st
+
+from open_notebook.domain.models import DefaultModels, Model
+from stream_app.utils import version_sidebar
+
+st.set_page_config(
+    layout="wide", page_title="⚙️ Settings", initial_sidebar_state="expanded"
+)
+version_sidebar()
+
+
+st.title("Settings")
+
+model_tab, model_defaults_tab = st.tabs(["Models", "Model Defaults"])
+
+provider_status = {}
+
+model_types = [
+    # "vision",
+    "text generation",
+    "embedding",
+    "text to speech",
+    "speech to text",
+]
+
+provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
+provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None
+provider_status["vertexai"] = (
+    os.environ.get("VERTEX_PROJECT") is not None
+    and os.environ.get("VERTEX_LOCATION") is not None
+    and os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
+)
+provider_status["vertexai-anthropic"] = (
+    os.environ.get("VERTEX_PROJECT") is not None
+    and os.environ.get("VERTEX_LOCATION") is not None
+    and os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
+)
+provider_status["gemini"] = os.environ.get("GEMINI_API_KEY") is not None
+provider_status["openrouter"] = (
+    os.environ.get("OPENROUTER_API_KEY") is not None
+    and os.environ.get("OPENAI_API_KEY") is not None
+    and os.environ.get("OPENROUTER_BASE_URL") is not None
+)
+provider_status["anthropic"] = os.environ.get("ANTHROPIC_API_KEY") is not None
+provider_status["eleven_labs"] = os.environ.get("ELEVENLABS_API_KEY") is not None
+provider_status["litellm"] = (
+    provider_status["ollama"]
+    or provider_status["vertexai"]
+    or provider_status["vertexai-anthropic"]
+    or provider_status["anthropic"]
+    or provider_status["openai"]
+    or provider_status["gemini"]
+)
+
+available_providers = [k for k, v in provider_status.items() if v]
+unavailable_providers = [k for k, v in provider_status.items() if not v]
+
+with model_tab:
+    st.subheader("Add Model")
+    provider = st.selectbox("Provider", available_providers)
+    if len(unavailable_providers) > 0:
+        st.caption(
+            f"Unavailable Providers: {', '.join(unavailable_providers)}. Please check the docs page if you wish to enable them."
+        )
+    model_name = st.text_input("Model Name", "")
+    model_type = st.selectbox("Model Type", model_types)
+    if st.button("Save"):
+        model = Model(name=model_name, provider=provider, type=model_type)
+        model.save()
+        st.success("Saved")
+    st.divider()
+    all_models = Model.get_all()
+    st.subheader("Configured Models")
+    model_types_available = {
+        # "vision": False,
+        "text generation": False,
+        "embedding": False,
+        "text to speech": False,
+        "speech to text": False,
+    }
+    for model in all_models:
+        model_types_available[model.type] = True
+        with st.container(border=True):
+            st.markdown(f"{model.name} ({model.provider}, {model.type})")
+            if st.button("Delete", key=f"delete_{model.id}"):
+                model.delete()
+                st.rerun()
+
+    for model_type, available in model_types_available.items():
+        if not available:
+            st.warning(f"No models available for {model_type}")
+
+
+# todo: check for each type of model
+def get_selected_index(models, model_id, default=0):
+    """Returns the index of the selected model in the list of models"""
+    if not model_id or not models:
+        return default
+    for i, model in enumerate(models):
+        if model.id == model_id:
+            return i
+    return default
+
+
+with model_defaults_tab:
+    default_models = DefaultModels.load().model_dump()
+    all_models = Model.get_all()
+    text_generation_models = [
+        model for model in all_models if model.type == "text generation"
+    ]
+
+    text_to_speech_models = [
+        model for model in all_models if model.type == "text to speech"
+    ]
+
+    speech_to_text_models = [
+        model for model in all_models if model.type == "speech to text"
+    ]
+    vision_models = [model for model in all_models if model.type == "vision"]
+    embedding_models = [model for model in all_models if model.type == "embedding"]
+    st.write(
+        "In this section, you can select the default models to be used for the various content operations performed by Open Notebook. Some of these can be overridden in the different modules."
+    )
+    defs = {}
+    defs["default_chat_model"] = st.selectbox(
+        "Default Chat Model",
+        text_generation_models,
+        format_func=lambda x: x.name,
+        help="This model will be used for chat.",
+        index=get_selected_index(
+            text_generation_models, default_models.get("default_chat_model")
+        ),
+    )
+    st.divider()
+    defs["default_transformation_model"] = st.selectbox(
+        "Default Transformation Model",
+        text_generation_models,
+        format_func=lambda x: x.name,
+        help="This model will be used for text transformations such as summaries, insights, etc.",
+        index=get_selected_index(
+            text_generation_models, default_models.get("default_transformation_model")
+        ),
+    )
+    st.caption("You can override this model on individual transformations")
+    st.divider()
+    defs["large_context_model"] = st.selectbox(
+        "Large Context Model",
+        text_generation_models,
+        format_func=lambda x: x.name,
+        help="This model will be used for larger context generation -- recommended: Gemini",
+        index=get_selected_index(
+            text_generation_models, default_models.get("large_context_model")
+        ),
+    )
+    st.caption("Recommended to use Gemini models for larger context processing")
+    st.divider()
+    defs["default_text_to_speech_model"] = st.selectbox(
+        "Default Text to Speech Model",
+        text_to_speech_models,
+        format_func=lambda x: x.name,
+        help="This is the default model for converting text to speech (podcasts, etc)",
+        index=get_selected_index(
+            text_to_speech_models, default_models.get("default_text_to_speech_model")
+        ),
+    )
+    st.caption("You can override this model on different podcasts")
+    st.divider()
+    defs["default_speech_to_text_model"] = st.selectbox(
+        "Default Speech to Text Model",
+        speech_to_text_models,
+        format_func=lambda x: x.name,
+        help="This is the default model for converting speech to text (audio transcriptions, etc)",
+        index=get_selected_index(
+            speech_to_text_models, default_models.get("default_speech_to_text_model")
+        ),
+    )
+    st.divider()
+    # defs["default_vision_model"] = st.selectbox(
+    #     "Default Vision Model",
+    #     vision_models,
+    #     format_func=lambda x: x.name,
+    #     help="This is the default model for vision tasks (image recognition, PDF recognition, etc)",
+    #     index=get_selected_index(
+    #         vision_models, default_models.get("default_vision_model")
+    #     ),
+    # )
+    # st.divider()
+
+    defs["default_embedding_model"] = st.selectbox(
+        "Default Embedding Model",
+        embedding_models,
+        format_func=lambda x: x.name,
+        help="This is the default model for embeddings (semantic search, etc)",
+        index=get_selected_index(
+            embedding_models, default_models.get("default_embedding_model")
+        ),
+    )
+    st.caption(
+        "Caution: if you change the embedding model after embeddings have been created, they will all need to be regenerated"
+    )
+
+    for k, v in defs.items():
+        defs[k] = v.id if v else None
+
+    if st.button("Save Defaults", key="save_defaults"):
+        DefaultModels.update(defs)
+        st.rerun()
+
+# todo: return an error if a selected model is no longer supported
+# todo: do this check on the app homepage as well
diff --git a/stream_app/source.py b/stream_app/source.py
index d8f316a..96109bb 100644
--- a/stream_app/source.py
+++ b/stream_app/source.py
@@ -8,7 +8,7 @@ from humanize import naturaltime
 from loguru import logger
 
 from open_notebook.config import UPLOADS_FOLDER
-from open_notebook.domain import Asset, Source
+from open_notebook.domain.notebook import Asset, Source
 from open_notebook.exceptions import UnsupportedTypeException
 from open_notebook.graphs.content_processing import graph
 from open_notebook.graphs.multipattern import graph as transform_graph