mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-04 06:10:35 +00:00
feat: implement the new model management based on esperanto framework
This commit is contained in:
parent
10049342cb
commit
bea43f3ce7
4 changed files with 58 additions and 42 deletions
|
|
@ -1,16 +1,18 @@
|
|||
from typing import ClassVar, Dict, Optional
|
||||
from typing import ClassVar, Dict, Optional, Union
|
||||
|
||||
from open_notebook.database.repository import repo_query
|
||||
from open_notebook.domain.base import ObjectModel, RecordModel
|
||||
from open_notebook.models import (
|
||||
MODEL_CLASS_MAP,
|
||||
from esperanto import (
|
||||
AIFactory,
|
||||
EmbeddingModel,
|
||||
LanguageModel,
|
||||
ModelType,
|
||||
SpeechToTextModel,
|
||||
TextToSpeechModel,
|
||||
)
|
||||
|
||||
from open_notebook.database.repository import repo_query
|
||||
from open_notebook.domain.base import ObjectModel, RecordModel
|
||||
|
||||
ModelType = Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]
|
||||
|
||||
|
||||
class Model(ObjectModel):
|
||||
table_name: ClassVar[str] = "model"
|
||||
|
|
@ -75,21 +77,53 @@ class ModelManager:
|
|||
if not model:
|
||||
raise ValueError(f"Model with ID {model_id} not found")
|
||||
|
||||
if not model.type or model.type not in MODEL_CLASS_MAP:
|
||||
if not model.type or model.type not in [
|
||||
"language",
|
||||
"embedding",
|
||||
"speech_to_text",
|
||||
"text_to_speech",
|
||||
]:
|
||||
raise ValueError(f"Invalid model type: {model.type}")
|
||||
|
||||
provider_map = MODEL_CLASS_MAP[model.type]
|
||||
if model.provider not in provider_map:
|
||||
# todo: change to providers in the future
|
||||
if model.provider not in [
|
||||
"ollama",
|
||||
"openrouter",
|
||||
"vertexai-anthropic",
|
||||
"litellm",
|
||||
"vertexai",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"xai",
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Provider {model.provider} not compatible with {model.type} models"
|
||||
)
|
||||
|
||||
model_class = provider_map[model.provider]
|
||||
model_instance = model_class(model_name=model.name, **kwargs)
|
||||
|
||||
# Special handling for language models that need langchain conversion
|
||||
if model.type == "language":
|
||||
model_instance = model_instance
|
||||
model_instance: LanguageModel = AIFactory.create_language(
|
||||
model_name=model.name,
|
||||
provider=model.provider,
|
||||
config=kwargs,
|
||||
)
|
||||
elif model.type == "embedding":
|
||||
model_instance: EmbeddingModel = AIFactory.create_embedding(
|
||||
model_name=model.name,
|
||||
provider=model.provider,
|
||||
config=kwargs,
|
||||
)
|
||||
elif model.type == "speech_to_text":
|
||||
model_instance: SpeechToTextModel = AIFactory.create_speech_to_text(
|
||||
model_name=model.name,
|
||||
provider=model.provider,
|
||||
config=kwargs,
|
||||
)
|
||||
elif model.type == "text_to_speech":
|
||||
model_instance: TextToSpeechModel = AIFactory.create_text_to_speech(
|
||||
model_name=model.name,
|
||||
provider=model.provider,
|
||||
config=kwargs,
|
||||
)
|
||||
|
||||
self._model_cache[cache_key] = model_instance
|
||||
return model_instance
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue