updated the connector config after refreshing the token

This commit is contained in:
CREDO23 2025-08-20 20:27:00 +02:00
parent 3d93fe8186
commit b0b6df0971
3 changed files with 72 additions and 13 deletions

View file

@ -8,9 +8,21 @@ import base64
import re import re
from typing import Any from typing import Any
from fastapi import Depends
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
from app.users import current_active_user
class GoogleGmailConnector: class GoogleGmailConnector:
@ -28,7 +40,11 @@ class GoogleGmailConnector:
self._credentials = credentials self._credentials = credentials
self.service = None self.service = None
def _get_credentials(self) -> Credentials: async def _get_credentials(
self,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
) -> Credentials:
""" """
Get valid Google OAuth credentials. Get valid Google OAuth credentials.
Returns: Returns:
@ -65,6 +81,24 @@ class GoogleGmailConnector:
if self._credentials.expired or not self._credentials.valid: if self._credentials.expired or not self._credentials.valid:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
# Update the connector config in DB
connector = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GMAIL,
)
)
connector = connector.scalars().first()
connector.config = GoogleAuthCredentialsBase(
self._credentials
).to_json()
session.add(connector)
await session.commit()
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Failed to refresh Google OAuth credentials: {e!s}" f"Failed to refresh Google OAuth credentials: {e!s}"
@ -72,7 +106,7 @@ class GoogleGmailConnector:
return self._credentials return self._credentials
def _get_service(self): async def _get_service(self):
""" """
Get the Gmail service instance using Google OAuth credentials. Get the Gmail service instance using Google OAuth credentials.
Returns: Returns:
@ -85,20 +119,20 @@ class GoogleGmailConnector:
return self.service return self.service
try: try:
credentials = self._get_credentials() credentials = await self._get_credentials()
self.service = build("gmail", "v1", credentials=credentials) self.service = build("gmail", "v1", credentials=credentials)
return self.service return self.service
except Exception as e: except Exception as e:
raise Exception(f"Failed to create Gmail service: {e!s}") from e raise Exception(f"Failed to create Gmail service: {e!s}") from e
def get_user_profile(self) -> tuple[dict[str, Any], str | None]: async def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
""" """
Fetch user's Gmail profile information. Fetch user's Gmail profile information.
Returns: Returns:
Tuple containing (profile dict, error message or None) Tuple containing (profile dict, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
profile = service.users().getProfile(userId="me").execute() profile = service.users().getProfile(userId="me").execute()
return { return {
@ -111,7 +145,7 @@ class GoogleGmailConnector:
except Exception as e: except Exception as e:
return {}, f"Error fetching user profile: {e!s}" return {}, f"Error fetching user profile: {e!s}"
def get_messages_list( async def get_messages_list(
self, self,
max_results: int = 100, max_results: int = 100,
query: str = "", query: str = "",
@ -129,7 +163,7 @@ class GoogleGmailConnector:
Tuple containing (messages list, error message or None) Tuple containing (messages list, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
# Build request parameters # Build request parameters
request_params = { request_params = {
@ -152,7 +186,9 @@ class GoogleGmailConnector:
except Exception as e: except Exception as e:
return [], f"Error fetching messages list: {e!s}" return [], f"Error fetching messages list: {e!s}"
def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]: async def get_message_details(
self, message_id: str
) -> tuple[dict[str, Any], str | None]:
""" """
Fetch detailed information for a specific message. Fetch detailed information for a specific message.
Args: Args:
@ -161,7 +197,7 @@ class GoogleGmailConnector:
Tuple containing (message details dict, error message or None) Tuple containing (message details dict, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
# Get full message details # Get full message details
message = ( message = (
@ -176,7 +212,7 @@ class GoogleGmailConnector:
except Exception as e: except Exception as e:
return {}, f"Error fetching message details: {e!s}" return {}, f"Error fetching message details: {e!s}"
def get_recent_messages( async def get_recent_messages(
self, self,
max_results: int = 50, max_results: int = 50,
days_back: int = 30, days_back: int = 30,
@ -198,7 +234,7 @@ class GoogleGmailConnector:
query = f"after:{date_query}" query = f"after:{date_query}"
# Get messages list # Get messages list
messages_list, error = self.get_messages_list( messages_list, error = await self.get_messages_list(
max_results=max_results, query=query max_results=max_results, query=query
) )
@ -208,7 +244,9 @@ class GoogleGmailConnector:
# Get detailed information for each message # Get detailed information for each message
detailed_messages = [] detailed_messages = []
for msg in messages_list: for msg in messages_list:
message_details, detail_error = self.get_message_details(msg["id"]) message_details, detail_error = await self.get_message_details(
msg["id"]
)
if detail_error: if detail_error:
continue # Skip messages that can't be fetched continue # Skip messages that can't be fetched
detailed_messages.append(message_details) detailed_messages.append(message_details)

View file

@ -60,6 +60,27 @@ async def get_connector_by_id(
return result.scalars().first() return result.scalars().first()
async def get_connector_by_type(
session: AsyncSession, connector_type: SearchSourceConnectorType
) -> SearchSourceConnector | None:
"""
Get a connector by type from the database.
Args:
session: Database session
connector_type: Type of the connector
Returns:
Connector object if found, None otherwise
"""
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type == connector_type,
)
)
return result.scalars().first()
def calculate_date_range( def calculate_date_range(
connector: SearchSourceConnector, connector: SearchSourceConnector,
start_date: str | None = None, start_date: str | None = None,

View file

@ -129,7 +129,7 @@ async def index_google_gmail_messages(
# Fetch recent Google gmail messages # Fetch recent Google gmail messages
logger.info(f"Fetching recent emails for connector {connector_id}") logger.info(f"Fetching recent emails for connector {connector_id}")
messages, error = gmail_connector.get_recent_messages( messages, error = await gmail_connector.get_recent_messages(
max_results=max_messages, days_back=days_back max_results=max_messages, days_back=days_back
) )