open-notebook/open_notebook/models/llms.py

"""
Classes for supporting different language models
"""
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
from langchain_groq.chat_models import ChatGroq
from langchain_ollama.chat_models import ChatOllama
from langchain_openai.chat_models import ChatOpenAI
from pydantic import SecretStr


# future: is there value in returning langchain-specific models?
@dataclass
class LanguageModel(ABC):
    """
    Abstract base class for language models.
    """

    model_name: Optional[str] = None
    max_tokens: Optional[int] = 850
    temperature: Optional[float] = 1.0
    streaming: bool = True
    top_p: Optional[float] = 0.9
    kwargs: Dict[str, Any] = field(default_factory=dict)
    json: bool = False

    @abstractmethod
    def to_langchain(self) -> BaseChatModel:
        """
        Convert the language model to a LangChain chat model.
        """
        raise NotImplementedError
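

# Each concrete subclass below maps the shared fields (model_name, max_tokens,
# temperature, top_p, streaming, json, kwargs) onto the constructor of its
# provider-specific LangChain chat model. Callers instantiate a subclass and
# call .to_langchain() to obtain a BaseChatModel they can .invoke() or .stream().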


@dataclass
class OllamaLanguageModel(LanguageModel):
    """
    Language model that uses the Ollama chat model.
    """

    model_name: str
    base_url: str = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
    max_tokens: Optional[int] = 650
    json: bool = False

    def to_langchain(self) -> ChatOllama:
        """
        Convert the language model to a LangChain chat model.
        """
        return ChatOllama(
            # api_key="ollama",
            model=self.model_name,
            base_url=self.base_url,
            # keep_alive="10m",
            num_predict=self.max_tokens,
            temperature=self.temperature or 0.5,
            verbose=True,
            top_p=self.top_p,
        )


@dataclass
class VertexAnthropicLanguageModel(LanguageModel):
    """
    Language model that uses the Vertex Anthropic chat model.
    """

    model_name: str
    project: Optional[str] = os.environ.get("VERTEX_PROJECT", "no-project")
    location: Optional[str] = os.environ.get("VERTEX_LOCATION", "us-central1")

    def to_langchain(self) -> ChatAnthropicVertex:
        """
        Convert the language model to a LangChain chat model.
        """
        return ChatAnthropicVertex(
            model=self.model_name,
            project=self.project,
            location=self.location,
            max_tokens=self.max_tokens,
            streaming=False,
            kwargs=self.kwargs,
            top_p=self.top_p,
            temperature=self.temperature or 0.5,
        )


@dataclass
class LiteLLMLanguageModel(LanguageModel):
    """
    Language model that uses the LiteLLM chat model.
    """

    model_name: str

    def to_langchain(self) -> ChatLiteLLM:
        """
        Convert the language model to a LangChain chat model.
        """
        return ChatLiteLLM(
            model=self.model_name,
            temperature=self.temperature or 0.5,
            max_tokens=self.max_tokens,
            streaming=self.streaming,
            top_p=self.top_p,
        )


@dataclass
class VertexAILanguageModel(LanguageModel):
    """
    Language model that uses the Vertex AI chat model.
    """

    model_name: str
    project: Optional[str] = os.environ.get("VERTEX_PROJECT", "no-project")
    location: Optional[str] = os.environ.get("VERTEX_LOCATION", "us-central1")

    def to_langchain(self) -> ChatVertexAI:
        """
        Convert the language model to a LangChain chat model.
        """
        return ChatVertexAI(
            model=self.model_name,
            streaming=self.streaming,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            location=self.location,
            project=self.project,
            safety_settings=None,
            temperature=self.temperature or 0.5,
        )


@dataclass
class GeminiLanguageModel(LanguageModel):
    """
    Language model that uses the Gemini family of chat models.
    """

    model_name: str

    def to_langchain(self) -> ChatGoogleGenerativeAI:
        """
        Convert the language model to a LangChain chat model.
        """
        return ChatGoogleGenerativeAI(
            model=self.model_name,
            max_tokens=self.max_tokens,
            temperature=self.temperature or 0.5,
        )


@dataclass
class OpenRouterLanguageModel(LanguageModel):
    """
    Language model that uses OpenRouter through the OpenAI-compatible chat model.
    """

    model_name: str

    def to_langchain(self) -> ChatOpenAI:
        """
        Convert the language model to a LangChain chat model for OpenRouter.
        """
        kwargs = self.kwargs.copy()  # copy so the shared kwargs dict is not mutated
        if self.json:
            kwargs["response_format"] = {"type": "json_object"}
        return ChatOpenAI(
            model=self.model_name,
            temperature=self.temperature or 0.5,
            base_url=os.environ.get(
                "OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"
            ),
            max_tokens=self.max_tokens,
            model_kwargs=kwargs,
            streaming=self.streaming,
            api_key=SecretStr(os.environ.get("OPENROUTER_API_KEY", "openrouter")),
            top_p=self.top_p,
        )


@dataclass
class GroqLanguageModel(LanguageModel):
    """
    Language model that uses the Groq chat model.
    """

    model_name: str

    def to_langchain(self) -> ChatGroq:
        """
        Convert the language model to a LangChain chat model for Groq.
        """
        kwargs = self.kwargs.copy()  # copy so the shared kwargs dict is not mutated
        kwargs["top_p"] = self.top_p
        return ChatGroq(
            model=self.model_name,
            temperature=self.temperature or 0.5,
            max_tokens=self.max_tokens,
            model_kwargs=kwargs,
            stop_sequences=None,
        )


@dataclass
class XAILanguageModel(LanguageModel):
    """
    Language model that uses the X.AI chat models through the OpenAI-compatible client.
    """

    model_name: str

    def to_langchain(self) -> ChatOpenAI:
        """
        Convert the language model to a LangChain chat model for X.AI.
        """
        kwargs = self.kwargs.copy()  # copy so the shared kwargs dict is not mutated
        if self.json:
            kwargs["response_format"] = {"type": "json_object"}
        return ChatOpenAI(
            model=self.model_name,
            temperature=self.temperature or 0.5,
            base_url=os.environ.get("XAI_BASE_URL", "https://api.x.ai/v1"),
            max_tokens=self.max_tokens,
            model_kwargs=kwargs,
            streaming=self.streaming,
            api_key=SecretStr(os.environ.get("XAI_API_KEY", "xai")),
            top_p=self.top_p,
        )


@dataclass
class AnthropicLanguageModel(LanguageModel):
    """
    Language model that uses the Anthropic chat model.
    """

    model_name: str

    def to_langchain(self) -> ChatAnthropic:
        """
        Convert the language model to a LangChain chat model.
        """
        return ChatAnthropic(  # type: ignore[call-arg]
            model_name=self.model_name,
            max_tokens_to_sample=self.max_tokens or 850,
            model_kwargs=self.kwargs,
            streaming=False,
            timeout=30,
            top_p=self.top_p,
            temperature=self.temperature or 0.5,
        )


@dataclass
class OpenAILanguageModel(LanguageModel):
    """
    Language model that uses the OpenAI chat model.
    """

    model_name: str

    def to_langchain(self) -> ChatOpenAI:
        """
        Convert the language model to a LangChain chat model.
        """
        kwargs = self.kwargs.copy()  # Make a copy to avoid modifying the original
        if self.json:
            kwargs["response_format"] = {"type": "json_object"}
        # Set the token limit in kwargs with the appropriate key: the o1 models take
        # max_completion_tokens and do not support streaming or a custom top_p.
        if self.model_name in ["o1-mini", "o1-preview"]:
            kwargs["max_completion_tokens"] = self.max_tokens
            top_p = 1
            streaming = False
            max_tokens = None
        else:
            max_tokens = self.max_tokens
            top_p = self.top_p
            streaming = self.streaming
        return ChatOpenAI(
            model=self.model_name,
            temperature=self.temperature,
            streaming=streaming,
            max_tokens=max_tokens,
            top_p=top_p,
            model_kwargs=kwargs,
        )
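

# Minimal usage sketch: the model name and prompt below are illustrative
# placeholders, and actually running this requires OPENAI_API_KEY to be set,
# since OpenAILanguageModel.to_langchain() reads the key from the environment.
if __name__ == "__main__":
    llm = OpenAILanguageModel(model_name="gpt-4o-mini", max_tokens=256)
    chat = llm.to_langchain()
    print(chat.invoke("Say hello in one short sentence.").content)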