feat: implement the new model management based on esperanto framework

This commit is contained in:
LUIS NOVO 2025-06-08 19:38:43 -03:00
parent 10049342cb
commit bea43f3ce7
4 changed files with 58 additions and 42 deletions

View file

@ -1,16 +1,18 @@
from typing import ClassVar, Dict, Optional
from typing import ClassVar, Dict, Optional, Union
from open_notebook.database.repository import repo_query
from open_notebook.domain.base import ObjectModel, RecordModel
from open_notebook.models import (
MODEL_CLASS_MAP,
from esperanto import (
AIFactory,
EmbeddingModel,
LanguageModel,
ModelType,
SpeechToTextModel,
TextToSpeechModel,
)
from open_notebook.database.repository import repo_query
from open_notebook.domain.base import ObjectModel, RecordModel
ModelType = Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]
class Model(ObjectModel):
table_name: ClassVar[str] = "model"
@ -75,21 +77,53 @@ class ModelManager:
if not model:
raise ValueError(f"Model with ID {model_id} not found")
if not model.type or model.type not in MODEL_CLASS_MAP:
if not model.type or model.type not in [
"language",
"embedding",
"speech_to_text",
"text_to_speech",
]:
raise ValueError(f"Invalid model type: {model.type}")
provider_map = MODEL_CLASS_MAP[model.type]
if model.provider not in provider_map:
# todo: change to providers in the future
if model.provider not in [
"ollama",
"openrouter",
"vertexai-anthropic",
"litellm",
"vertexai",
"anthropic",
"openai",
"xai",
]:
raise ValueError(
f"Provider {model.provider} not compatible with {model.type} models"
)
model_class = provider_map[model.provider]
model_instance = model_class(model_name=model.name, **kwargs)
# Special handling for language models that need langchain conversion
if model.type == "language":
model_instance = model_instance
model_instance: LanguageModel = AIFactory.create_language(
model_name=model.name,
provider=model.provider,
config=kwargs,
)
elif model.type == "embedding":
model_instance: EmbeddingModel = AIFactory.create_embedding(
model_name=model.name,
provider=model.provider,
config=kwargs,
)
elif model.type == "speech_to_text":
model_instance: SpeechToTextModel = AIFactory.create_speech_to_text(
model_name=model.name,
provider=model.provider,
config=kwargs,
)
elif model.type == "text_to_speech":
model_instance: TextToSpeechModel = AIFactory.create_text_to_speech(
model_name=model.name,
provider=model.provider,
config=kwargs,
)
self._model_cache[cache_key] = model_instance
return model_instance