mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-04 22:30:36 +00:00
refactor objectmodel
This commit is contained in:
parent
f140a5e228
commit
c297dcb809
8 changed files with 186 additions and 68 deletions
|
|
@ -28,15 +28,14 @@ class Model(ObjectModel):
|
|||
|
||||
class DefaultModels(RecordModel):
|
||||
record_id: ClassVar[str] = "open_notebook:default_models"
|
||||
|
||||
default_chat_model: Optional[str] = None
|
||||
default_transformation_model: Optional[str] = None
|
||||
large_context_model: Optional[str] = None
|
||||
default_text_to_speech_model: Optional[str] = None
|
||||
default_speech_to_text_model: Optional[str] = None
|
||||
# default_vision_model: Optional[str] = None
|
||||
default_embedding_model: Optional[str] = None
|
||||
default_tools_model: Optional[str] = None
|
||||
default_chat_model: Optional[str]
|
||||
default_transformation_model: Optional[str]
|
||||
large_context_model: Optional[str]
|
||||
default_text_to_speech_model: Optional[str]
|
||||
default_speech_to_text_model: Optional[str]
|
||||
# default_vision_model: Optional[str]
|
||||
default_embedding_model: Optional[str]
|
||||
default_tools_model: Optional[str]
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
|
@ -54,7 +53,10 @@ class ModelManager:
|
|||
self._default_models = None
|
||||
self.refresh_defaults()
|
||||
|
||||
def get_model(self, model_id: str, **kwargs) -> ModelType:
|
||||
def get_model(self, model_id: str, **kwargs) -> Optional[ModelType]:
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
cache_key = f"{model_id}:{str(kwargs)}"
|
||||
|
||||
if cache_key in self._model_cache:
|
||||
|
|
@ -68,9 +70,6 @@ class ModelManager:
|
|||
)
|
||||
return cached_model
|
||||
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
model: Model = Model.get(model_id)
|
||||
|
||||
if not model:
|
||||
|
|
@ -111,7 +110,10 @@ class ModelManager:
|
|||
@property
|
||||
def speech_to_text(self, **kwargs) -> Optional[SpeechToTextModel]:
|
||||
"""Get the default speech-to-text model"""
|
||||
model = self.get_default_model("speech_to_text", **kwargs)
|
||||
model_id = self.defaults.default_speech_to_text_model
|
||||
if not model_id:
|
||||
return None
|
||||
model = self.get_model(model_id, **kwargs)
|
||||
assert model is None or isinstance(
|
||||
model, SpeechToTextModel
|
||||
), f"Expected SpeechToTextModel but got {type(model)}"
|
||||
|
|
@ -120,7 +122,10 @@ class ModelManager:
|
|||
@property
|
||||
def text_to_speech(self, **kwargs) -> Optional[TextToSpeechModel]:
|
||||
"""Get the default text-to-speech model"""
|
||||
model = self.get_default_model("text_to_speech", **kwargs)
|
||||
model_id = self.defaults.default_text_to_speech_model
|
||||
if not model_id:
|
||||
return None
|
||||
model = self.get_model(model_id, **kwargs)
|
||||
assert model is None or isinstance(
|
||||
model, TextToSpeechModel
|
||||
), f"Expected TextToSpeechModel but got {type(model)}"
|
||||
|
|
@ -129,13 +134,16 @@ class ModelManager:
|
|||
@property
|
||||
def embedding_model(self, **kwargs) -> Optional[EmbeddingModel]:
|
||||
"""Get the default embedding model"""
|
||||
model = self.get_default_model("embedding", **kwargs)
|
||||
model_id = self.defaults.default_embedding_model
|
||||
if not model_id:
|
||||
return None
|
||||
model = self.get_model(model_id, **kwargs)
|
||||
assert model is None or isinstance(
|
||||
model, EmbeddingModel
|
||||
), f"Expected EmbeddingModel but got {type(model)}"
|
||||
return model
|
||||
|
||||
def get_default_model(self, model_type: str, **kwargs) -> ModelType:
|
||||
def get_default_model(self, model_type: str, **kwargs) -> Optional[ModelType]:
|
||||
"""
|
||||
Get the default model for a specific type.
|
||||
|
||||
|
|
@ -165,6 +173,9 @@ class ModelManager:
|
|||
elif model_type == "large_context":
|
||||
model_id = self.defaults.large_context_model
|
||||
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
return self.get_model(model_id, **kwargs)
|
||||
|
||||
def clear_cache(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue