simplify the model selector

This commit is contained in:
LUIS NOVO 2024-10-30 14:30:29 -03:00
parent 8bb5db158f
commit 859b7f6e7e
6 changed files with 151 additions and 75 deletions

View file

@ -8,8 +8,6 @@ from typing import List, Optional
from openai import OpenAI
from open_notebook.domain.models import Model
@dataclass
class EmbeddingModel(ABC):
@ -43,20 +41,3 @@ class OpenAIEmbeddingModel(EmbeddingModel):
.data[0]
.embedding
)
EMBEDDING_CLASS_MAP = {
"openai": OpenAIEmbeddingModel,
}
def get_embedding_model(model_id):
assert model_id, "Model ID cannot be empty"
model = Model.get(model_id)
if not model:
raise ValueError(f"Model with ID {model_id} not found")
if model.provider not in EMBEDDING_CLASS_MAP.keys():
raise ValueError(
f"Provider {model.provider} not compatible with Embedding Models"
)
return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name)