add gemini support

This commit is contained in:
LUIS NOVO 2024-10-26 05:41:06 -03:00
parent 7648caca7b
commit aaa7831ab1
4 changed files with 28 additions and 2 deletions

View file

@ -1,5 +1,6 @@
from open_notebook.llms import (
AnthropicLanguageModel,
GeminiLanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
OpenAILanguageModel,
@ -17,6 +18,7 @@ PROVIDER_CLASS_MAP = {
"vertexai": VertexAILanguageModel,
"anthropic": AnthropicLanguageModel,
"openai": OpenAILanguageModel,
"gemini": GeminiLanguageModel,
}

View file

@ -10,6 +10,7 @@ from typing import Any, Dict, Optional
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
from langchain_ollama.chat_models import ChatOllama
@ -62,7 +63,7 @@ class OllamaLanguageModel(LanguageModel):
base_url=self.base_url,
# keep_alive="10m",
num_predict=self.max_tokens,
temperature=self.temperature,
temperature=self.temperature or 0.5,
verbose=True,
top_p=self.top_p,
)
@ -90,6 +91,7 @@ class VertexAnthropicLanguageModel(LanguageModel):
streaming=False,
kwargs=self.kwargs,
top_p=self.top_p,
temperature=self.temperature or 0.5,
)
@ -136,6 +138,26 @@ class VertexAILanguageModel(LanguageModel):
location=self.location,
project=self.project,
safety_settings=None,
temperature=self.temperature or 0.5,
)
@dataclass
class GeminiLanguageModel(LanguageModel):
"""
Language model that uses the Gemini Family of chat models.
"""
model_name: str
def to_langchain(self) -> ChatGoogleGenerativeAI:
"""
Convert the language model to a LangChain chat model.
"""
return ChatGoogleGenerativeAI(
model=self.model_name,
max_tokens=self.max_tokens,
temperature=self.temperature or 0.5,
)
@ -188,6 +210,7 @@ class AnthropicLanguageModel(LanguageModel):
streaming=False,
timeout=30,
top_p=self.top_p,
temperature=self.temperature or 0.5,
)