diff --git a/open_notebook/llm_router.py b/open_notebook/llm_router.py
index 9fdb85d..9e1096e 100644
--- a/open_notebook/llm_router.py
+++ b/open_notebook/llm_router.py
@@ -1,5 +1,6 @@
 from open_notebook.llms import (
     AnthropicLanguageModel,
+    GeminiLanguageModel,
     LiteLLMLanguageModel,
     OllamaLanguageModel,
     OpenAILanguageModel,
@@ -17,6 +18,7 @@ PROVIDER_CLASS_MAP = {
     "vertexai": VertexAILanguageModel,
     "anthropic": AnthropicLanguageModel,
     "openai": OpenAILanguageModel,
+    "gemini": GeminiLanguageModel,
 }
 
 
diff --git a/open_notebook/llms.py b/open_notebook/llms.py
index c6357d1..efcc431 100644
--- a/open_notebook/llms.py
+++ b/open_notebook/llms.py
@@ -10,6 +10,7 @@ from typing import Any, Dict, Optional
 from langchain_anthropic import ChatAnthropic
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_vertexai import ChatVertexAI
 from langchain_google_vertexai.model_garden import ChatAnthropicVertex
 from langchain_ollama.chat_models import ChatOllama
@@ -62,7 +63,7 @@ class OllamaLanguageModel(LanguageModel):
             base_url=self.base_url,
             # keep_alive="10m",
             num_predict=self.max_tokens,
-            temperature=self.temperature,
+            temperature=0.5 if self.temperature is None else self.temperature,
             verbose=True,
             top_p=self.top_p,
         )
@@ -90,6 +91,7 @@ class VertexAnthropicLanguageModel(LanguageModel):
             streaming=False,
             kwargs=self.kwargs,
             top_p=self.top_p,
+            temperature=0.5 if self.temperature is None else self.temperature,
         )
 
 
@@ -136,6 +138,26 @@ class VertexAILanguageModel(LanguageModel):
             location=self.location,
             project=self.project,
             safety_settings=None,
+            temperature=0.5 if self.temperature is None else self.temperature,
         )
+
+
+@dataclass
+class GeminiLanguageModel(LanguageModel):
+    """
+    Language model that uses the Gemini Family of chat models.
+    """
+
+    model_name: str
+
+    def to_langchain(self) -> ChatGoogleGenerativeAI:
+        """
+        Convert the language model to a LangChain chat model.
+        """
+        return ChatGoogleGenerativeAI(
+            model=self.model_name,
+            max_tokens=self.max_tokens,
+            temperature=0.5 if self.temperature is None else self.temperature,
+        )
 
 
@@ -188,6 +210,7 @@ class AnthropicLanguageModel(LanguageModel):
             streaming=False,
             timeout=30,
             top_p=self.top_p,
+            temperature=0.5 if self.temperature is None else self.temperature,
         )
 
 
diff --git a/poetry.lock b/poetry.lock
index bf9b147..590313e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -6063,4 +6063,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "b92bbd2ce61e78ccc2e182627cf0ba5d98ccf849898e5e941d5d17e74a7827ab"
+content-hash = "5f7bdea405c6c6433fa805b3321ac1550e13deee0d3a3c04e38136cd6992f5b1"
diff --git a/pyproject.toml b/pyproject.toml
index 3d30ed6..507708d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ langchain-anthropic = "^0.2.3"
 langchain-ollama = "^0.2.0"
 langchain-google-vertexai = "^2.0.5"
 sdblpy = "^0.3.0"
+langchain-google-genai = "^2.0.1"
 podcastfy = "^0.2.8"
 
 [tool.poetry.group.dev.dependencies]