mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-29 20:10:07 +00:00
model manager
This commit is contained in:
parent
3b7dd5f25f
commit
a9ac4a6dc8
6 changed files with 141 additions and 43 deletions
|
|
@ -1,4 +1,6 @@
|
|||
from open_notebook.domain.models import Model
|
||||
from typing import Dict, Optional
|
||||
|
||||
from open_notebook.domain.models import DefaultModels, Model
|
||||
from open_notebook.models.embedding_models import (
|
||||
GeminiEmbeddingModel,
|
||||
OllamaEmbeddingModel,
|
||||
|
|
@ -49,34 +51,137 @@ MODEL_CLASS_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def get_model(model_id, **kwargs):
|
||||
"""
|
||||
Get a model instance based on model_id and type.
|
||||
# def get_model(model_id, **kwargs):
|
||||
# """
|
||||
# Get a model instance based on model_id and type.
|
||||
|
||||
Args:
|
||||
model_id: The ID of the model to retrieve
|
||||
**kwargs: Additional arguments to pass to the model constructor
|
||||
"""
|
||||
assert model_id, "Model ID cannot be empty"
|
||||
model: Model = Model.get(model_id)
|
||||
# Args:
|
||||
# model_id: The ID of the model to retrieve
|
||||
# **kwargs: Additional arguments to pass to the model constructor
|
||||
# """
|
||||
# assert model_id, "Model ID cannot be empty"
|
||||
# model: Model = Model.get(model_id)
|
||||
|
||||
if not model:
|
||||
raise ValueError(f"Model with ID {model_id} not found")
|
||||
# if not model:
|
||||
# raise ValueError(f"Model with ID {model_id} not found")
|
||||
|
||||
if not model.type or model.type not in MODEL_CLASS_MAP:
|
||||
raise ValueError(f"Invalid model type: {model.type}")
|
||||
# if not model.type or model.type not in MODEL_CLASS_MAP:
|
||||
# raise ValueError(f"Invalid model type: {model.type}")
|
||||
|
||||
provider_map = MODEL_CLASS_MAP[model.type]
|
||||
if model.provider not in provider_map:
|
||||
raise ValueError(
|
||||
f"Provider {model.provider} not compatible with {model.type} models"
|
||||
)
|
||||
# provider_map = MODEL_CLASS_MAP[model.type]
|
||||
# if model.provider not in provider_map:
|
||||
# raise ValueError(
|
||||
# f"Provider {model.provider} not compatible with {model.type} models"
|
||||
# )
|
||||
|
||||
model_class = provider_map[model.provider]
|
||||
model_instance = model_class(model_name=model.name, **kwargs)
|
||||
# model_class = provider_map[model.provider]
|
||||
# model_instance = model_class(model_name=model.name, **kwargs)
|
||||
|
||||
# Special handling for language models that need langchain conversion
|
||||
if model.type == "language":
|
||||
return model_instance.to_langchain()
|
||||
# # Special handling for language models that need langchain conversion
|
||||
# if model.type == "language":
|
||||
# return model_instance.to_langchain()
|
||||
|
||||
return model_instance
|
||||
# return model_instance
|
||||
|
||||
|
||||
class ModelManager:
|
||||
_instance = None
|
||||
_model_cache: Dict[str, object] = {}
|
||||
_default_models: Optional[DefaultModels] = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ModelManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._initialized = True
|
||||
self.refresh_defaults()
|
||||
|
||||
def refresh_defaults(self):
|
||||
"""Refresh the default models from the database"""
|
||||
self._default_models = DefaultModels.load()
|
||||
|
||||
@property
|
||||
def defaults(self) -> DefaultModels:
|
||||
"""Get the default models configuration"""
|
||||
if not self._default_models:
|
||||
self.refresh_defaults()
|
||||
return self._default_models
|
||||
|
||||
def get_model(self, model_id: str, **kwargs) -> object:
|
||||
"""
|
||||
Get a model instance based on model_id. Uses caching to avoid recreating instances.
|
||||
|
||||
Args:
|
||||
model_id: The ID of the model to retrieve
|
||||
**kwargs: Additional arguments to pass to the model constructor
|
||||
"""
|
||||
cache_key = f"{model_id}:{str(kwargs)}"
|
||||
|
||||
if cache_key in self._model_cache:
|
||||
return self._model_cache[cache_key]
|
||||
|
||||
assert model_id, "Model ID cannot be empty"
|
||||
model: Model = Model.get(model_id)
|
||||
|
||||
if not model:
|
||||
raise ValueError(f"Model with ID {model_id} not found")
|
||||
|
||||
if not model.type or model.type not in MODEL_CLASS_MAP:
|
||||
raise ValueError(f"Invalid model type: {model.type}")
|
||||
|
||||
provider_map = MODEL_CLASS_MAP[model.type]
|
||||
if model.provider not in provider_map:
|
||||
raise ValueError(
|
||||
f"Provider {model.provider} not compatible with {model.type} models"
|
||||
)
|
||||
|
||||
model_class = provider_map[model.provider]
|
||||
model_instance = model_class(model_name=model.name, **kwargs)
|
||||
|
||||
# Special handling for language models that need langchain conversion
|
||||
if model.type == "language":
|
||||
model_instance = model_instance.to_langchain()
|
||||
|
||||
self._model_cache[cache_key] = model_instance
|
||||
return model_instance
|
||||
|
||||
def get_default_model(self, model_type: str, **kwargs) -> object:
|
||||
"""
|
||||
Get the default model for a specific type.
|
||||
|
||||
Args:
|
||||
model_type: The type of model to retrieve (e.g., 'chat', 'embedding', etc.)
|
||||
**kwargs: Additional arguments to pass to the model constructor
|
||||
"""
|
||||
model_id = None
|
||||
|
||||
if model_type == "chat":
|
||||
model_id = self.defaults.default_chat_model
|
||||
elif model_type == "transformation":
|
||||
model_id = (
|
||||
self.defaults.default_transformation_model
|
||||
or self.defaults.default_chat_model
|
||||
)
|
||||
elif model_type == "embedding":
|
||||
model_id = self.defaults.default_embedding_model
|
||||
elif model_type == "text_to_speech":
|
||||
model_id = self.defaults.default_text_to_speech_model
|
||||
elif model_type == "speech_to_text":
|
||||
model_id = self.defaults.default_speech_to_text_model
|
||||
elif model_type == "large_context":
|
||||
model_id = self.defaults.large_context_model
|
||||
|
||||
if not model_id:
|
||||
raise ValueError(f"No default model configured for type: {model_type}")
|
||||
|
||||
return self.get_model(model_id, **kwargs)
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the model cache"""
|
||||
self._model_cache.clear()
|
||||
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue