add x.ai support

This commit is contained in:
LUIS NOVO 2024-11-11 16:49:50 -03:00
parent 2e2a4947b3
commit ac2ea9e554
3 changed files with 32 additions and 1 deletions

View file

@ -17,6 +17,7 @@ from open_notebook.models.llms import (
OpenRouterLanguageModel,
VertexAILanguageModel,
VertexAnthropicLanguageModel,
XAILanguageModel,
)
from open_notebook.models.speech_to_text_models import (
OpenAISpeechToTextModel,
@ -44,6 +45,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = {
"anthropic": AnthropicLanguageModel,
"openai": OpenAILanguageModel,
"gemini": GeminiLanguageModel,
"xai": XAILanguageModel,
},
"embedding": {
"openai": OpenAIEmbeddingModel,

View file

@ -171,7 +171,7 @@ class OpenRouterLanguageModel(LanguageModel):
def to_langchain(self) -> ChatOpenAI:
"""
Convert the language model to a LangChain chat model.
Convert the language model to a LangChain chat model for Open Router.
"""
kwargs = self.kwargs
if self.json:
@ -191,6 +191,34 @@ class OpenRouterLanguageModel(LanguageModel):
)
@dataclass
class XAILanguageModel(LanguageModel):
"""
Language model that uses the OpenAI chat model for X.AI.
"""
model_name: str
def to_langchain(self) -> ChatOpenAI:
"""
Convert the language model to a LangChain chat model.
"""
kwargs = self.kwargs
if self.json:
kwargs["response_format"] = {"type": "json_object"}
return ChatOpenAI(
model=self.model_name,
temperature=self.temperature or 0.5,
base_url=os.environ.get("XAI_BASE_URL", "https://api.x.ai/v1"),
max_tokens=self.max_tokens,
model_kwargs=kwargs,
streaming=self.streaming,
api_key=SecretStr(os.environ.get("XAI_API_KEY", "xai")),
top_p=self.top_p,
)
@dataclass
class AnthropicLanguageModel(LanguageModel):
"""