refactor objectmodel

This commit is contained in:
LUIS NOVO 2024-11-19 19:03:32 -03:00
parent f140a5e228
commit c297dcb809
8 changed files with 186 additions and 68 deletions

View file

@ -28,15 +28,14 @@ class Model(ObjectModel):
class DefaultModels(RecordModel):
record_id: ClassVar[str] = "open_notebook:default_models"
default_chat_model: Optional[str] = None
default_transformation_model: Optional[str] = None
large_context_model: Optional[str] = None
default_text_to_speech_model: Optional[str] = None
default_speech_to_text_model: Optional[str] = None
# default_vision_model: Optional[str] = None
default_embedding_model: Optional[str] = None
default_tools_model: Optional[str] = None
default_chat_model: Optional[str]
default_transformation_model: Optional[str]
large_context_model: Optional[str]
default_text_to_speech_model: Optional[str]
default_speech_to_text_model: Optional[str]
# default_vision_model: Optional[str]
default_embedding_model: Optional[str]
default_tools_model: Optional[str]
class ModelManager:
@ -54,7 +53,10 @@ class ModelManager:
self._default_models = None
self.refresh_defaults()
def get_model(self, model_id: str, **kwargs) -> ModelType:
def get_model(self, model_id: str, **kwargs) -> Optional[ModelType]:
if not model_id:
return None
cache_key = f"{model_id}:{str(kwargs)}"
if cache_key in self._model_cache:
@ -68,9 +70,6 @@ class ModelManager:
)
return cached_model
if not model_id:
return None
model: Model = Model.get(model_id)
if not model:
@ -111,7 +110,10 @@ class ModelManager:
@property
def speech_to_text(self, **kwargs) -> Optional[SpeechToTextModel]:
"""Get the default speech-to-text model"""
model = self.get_default_model("speech_to_text", **kwargs)
model_id = self.defaults.default_speech_to_text_model
if not model_id:
return None
model = self.get_model(model_id, **kwargs)
assert model is None or isinstance(
model, SpeechToTextModel
), f"Expected SpeechToTextModel but got {type(model)}"
@ -120,7 +122,10 @@ class ModelManager:
@property
def text_to_speech(self, **kwargs) -> Optional[TextToSpeechModel]:
"""Get the default text-to-speech model"""
model = self.get_default_model("text_to_speech", **kwargs)
model_id = self.defaults.default_text_to_speech_model
if not model_id:
return None
model = self.get_model(model_id, **kwargs)
assert model is None or isinstance(
model, TextToSpeechModel
), f"Expected TextToSpeechModel but got {type(model)}"
@ -129,13 +134,16 @@ class ModelManager:
@property
def embedding_model(self, **kwargs) -> Optional[EmbeddingModel]:
"""Get the default embedding model"""
model = self.get_default_model("embedding", **kwargs)
model_id = self.defaults.default_embedding_model
if not model_id:
return None
model = self.get_model(model_id, **kwargs)
assert model is None or isinstance(
model, EmbeddingModel
), f"Expected EmbeddingModel but got {type(model)}"
return model
def get_default_model(self, model_type: str, **kwargs) -> ModelType:
def get_default_model(self, model_type: str, **kwargs) -> Optional[ModelType]:
"""
Get the default model for a specific type.
@ -165,6 +173,9 @@ class ModelManager:
elif model_type == "large_context":
model_id = self.defaults.large_context_model
if not model_id:
return None
return self.get_model(model_id, **kwargs)
def clear_cache(self):