From b0b6df09718c1415c2b047ae005a31a81390c8b5 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 20 Aug 2025 20:27:00 +0200 Subject: [PATCH] updated the connector config after refreshing the token --- .../app/connectors/google_gmail_connector.py | 62 +++++++++++++++---- .../app/tasks/connector_indexers/base.py | 21 +++++++ .../google_gmail_indexer.py | 2 +- 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/surfsense_backend/app/connectors/google_gmail_connector.py b/surfsense_backend/app/connectors/google_gmail_connector.py index 0e75080..0b0bf26 100644 --- a/surfsense_backend/app/connectors/google_gmail_connector.py +++ b/surfsense_backend/app/connectors/google_gmail_connector.py @@ -8,9 +8,21 @@ import base64 import re from typing import Any +from fastapi import Depends from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase +from app.users import current_active_user class GoogleGmailConnector: @@ -28,7 +40,11 @@ class GoogleGmailConnector: self._credentials = credentials self.service = None - def _get_credentials(self) -> Credentials: + async def _get_credentials( + self, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), + ) -> Credentials: """ Get valid Google OAuth credentials. Returns: @@ -65,6 +81,24 @@ class GoogleGmailConnector: if self._credentials.expired or not self._credentials.valid: try: self._credentials.refresh(Request()) + # Update the connector config in DB + + connector = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.GMAIL, + ) + ) + + connector = connector.scalars().first() + + connector.config = GoogleAuthCredentialsBase( + self._credentials + ).to_json() + session.add(connector) + await session.commit() + except Exception as e: raise Exception( f"Failed to refresh Google OAuth credentials: {e!s}" @@ -72,7 +106,7 @@ class GoogleGmailConnector: return self._credentials - def _get_service(self): + async def _get_service(self): """ Get the Gmail service instance using Google OAuth credentials. Returns: @@ -85,20 +119,20 @@ class GoogleGmailConnector: return self.service try: - credentials = self._get_credentials() + credentials = await self._get_credentials() self.service = build("gmail", "v1", credentials=credentials) return self.service except Exception as e: raise Exception(f"Failed to create Gmail service: {e!s}") from e - def get_user_profile(self) -> tuple[dict[str, Any], str | None]: + async def get_user_profile(self) -> tuple[dict[str, Any], str | None]: """ Fetch user's Gmail profile information. Returns: Tuple containing (profile dict, error message or None) """ try: - service = self._get_service() + service = await self._get_service() profile = service.users().getProfile(userId="me").execute() return { @@ -111,7 +145,7 @@ class GoogleGmailConnector: except Exception as e: return {}, f"Error fetching user profile: {e!s}" - def get_messages_list( + async def get_messages_list( self, max_results: int = 100, query: str = "", @@ -129,7 +163,7 @@ class GoogleGmailConnector: Tuple containing (messages list, error message or None) """ try: - service = self._get_service() + service = await self._get_service() # Build request parameters request_params = { @@ -152,7 +186,9 @@ class GoogleGmailConnector: except Exception as e: return [], f"Error fetching messages list: {e!s}" - def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]: + async def get_message_details( + self, message_id: str + ) -> tuple[dict[str, Any], str | None]: """ Fetch detailed information for a specific message. Args: @@ -161,7 +197,7 @@ class GoogleGmailConnector: Tuple containing (message details dict, error message or None) """ try: - service = self._get_service() + service = await self._get_service() # Get full message details message = ( @@ -176,7 +212,7 @@ class GoogleGmailConnector: except Exception as e: return {}, f"Error fetching message details: {e!s}" - def get_recent_messages( + async def get_recent_messages( self, max_results: int = 50, days_back: int = 30, @@ -198,7 +234,7 @@ class GoogleGmailConnector: query = f"after:{date_query}" # Get messages list - messages_list, error = self.get_messages_list( + messages_list, error = await self.get_messages_list( max_results=max_results, query=query ) @@ -208,7 +244,9 @@ class GoogleGmailConnector: # Get detailed information for each message detailed_messages = [] for msg in messages_list: - message_details, detail_error = self.get_message_details(msg["id"]) + message_details, detail_error = await self.get_message_details( + msg["id"] + ) if detail_error: continue # Skip messages that can't be fetched detailed_messages.append(message_details) diff --git a/surfsense_backend/app/tasks/connector_indexers/base.py b/surfsense_backend/app/tasks/connector_indexers/base.py index 28cd206..fb639f9 100644 --- a/surfsense_backend/app/tasks/connector_indexers/base.py +++ b/surfsense_backend/app/tasks/connector_indexers/base.py @@ -60,6 +60,27 @@ async def get_connector_by_id( return result.scalars().first() +async def get_connector_by_type( + session: AsyncSession, connector_type: SearchSourceConnectorType +) -> SearchSourceConnector | None: + """ + Get a connector by type from the database. + + Args: + session: Database session + connector_type: Type of the connector + + Returns: + Connector object if found, None otherwise + """ + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.connector_type == connector_type, + ) + ) + return result.scalars().first() + + def calculate_date_range( connector: SearchSourceConnector, start_date: str | None = None, diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index ac85afd..2c00fc4 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -129,7 +129,7 @@ async def index_google_gmail_messages( # Fetch recent Google gmail messages logger.info(f"Fetching recent emails for connector {connector_id}") - messages, error = gmail_connector.get_recent_messages( + messages, error = await gmail_connector.get_recent_messages( max_results=max_messages, days_back=days_back )