diff --git a/models.py b/models.py index 4b517f1e2..09d0705a6 100644 --- a/models.py +++ b/models.py @@ -20,7 +20,7 @@ from langchain_huggingface import ( HuggingFaceEndpoint, ) from langchain_google_genai import ( - GoogleGenerativeAI, + ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory, embeddings as google_embeddings, @@ -267,7 +267,7 @@ def get_google_chat( ): if not api_key: api_key = get_api_key("google") - return GoogleGenerativeAI(model=model_name, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE}, **kwargs) # type: ignore + return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE}, **kwargs) # type: ignore def get_google_embedding( @@ -277,7 +277,7 @@ def get_google_embedding( ): if not api_key: api_key = get_api_key("google") - return google_embeddings.GoogleGenerativeAIEmbeddings(model=model_name, api_key=api_key, **kwargs) # type: ignore + return google_embeddings.GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=api_key, **kwargs) # type: ignore # Mistral models