diff --git a/surfsense_backend/app/connectors/google_calendar_connector.py b/surfsense_backend/app/connectors/google_calendar_connector.py index 3d7ca2e..2f8846c 100644 --- a/surfsense_backend/app/connectors/google_calendar_connector.py +++ b/surfsense_backend/app/connectors/google_calendar_connector.py @@ -4,6 +4,7 @@ A module for retrieving calendar events from Google Calendar using Google OAuth Allows fetching events from specified calendars within date ranges using Google OAuth credentials. """ +import json from datetime import datetime from typing import Any @@ -12,6 +13,14 @@ from dateutil.parser import isoparse from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm.attributes import flag_modified + +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, +) class GoogleCalendarConnector: @@ -20,6 +29,8 @@ class GoogleCalendarConnector: def __init__( self, credentials: Credentials, + session: AsyncSession, + user_id: str, ): """ Initialize the GoogleCalendarConnector class. @@ -27,9 +38,13 @@ class GoogleCalendarConnector: credentials: Google OAuth Credentials object """ self._credentials = credentials + self._session = session + self._user_id = user_id self.service = None - def _get_credentials(self) -> Credentials: + async def _get_credentials( + self, + ) -> Credentials: """ Get valid Google OAuth credentials. Returns: @@ -60,12 +75,30 @@ class GoogleCalendarConnector: client_id=self._credentials.client_id, client_secret=self._credentials.client_secret, scopes=self._credentials.scopes, + expiry=self._credentials.expiry, ) # Refresh the token if needed if self._credentials.expired or not self._credentials.valid: try: self._credentials.refresh(Request()) + # Update the connector config in DB + if self._session: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.user_id == self._user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + ) + ) + connector = result.scalars().first() + if connector is None: + raise RuntimeError( + "GOOGLE_CALENDAR_CONNECTOR connector not found for current user; cannot persist refreshed token." + ) + connector.config = json.loads(self._credentials.to_json()) + flag_modified(connector, "config") + await self._session.commit() except Exception as e: raise Exception( f"Failed to refresh Google OAuth credentials: {e!s}" @@ -73,7 +106,7 @@ class GoogleCalendarConnector: return self._credentials - def _get_service(self): + async def _get_service(self): """ Get the Google Calendar service instance using Google OAuth credentials. Returns: @@ -86,20 +119,20 @@ class GoogleCalendarConnector: return self.service try: - credentials = self._get_credentials() + credentials = await self._get_credentials() self.service = build("calendar", "v3", credentials=credentials) return self.service except Exception as e: raise Exception(f"Failed to create Google Calendar service: {e!s}") from e - def get_calendars(self) -> tuple[list[dict[str, Any]], str | None]: + async def get_calendars(self) -> tuple[list[dict[str, Any]], str | None]: """ Fetch list of user's calendars using Google OAuth credentials. Returns: Tuple containing (calendars list, error message or None) """ try: - service = self._get_service() + service = await self._get_service() calendars_result = service.calendarList().list().execute() calendars = calendars_result.get("items", []) @@ -122,7 +155,7 @@ class GoogleCalendarConnector: except Exception as e: return [], f"Error fetching calendars: {e!s}" - def get_all_primary_calendar_events( + async def get_all_primary_calendar_events( self, start_date: str, end_date: str, @@ -136,7 +169,7 @@ class GoogleCalendarConnector: Tuple containing (events list, error message or None) """ try: - service = self._get_service() + service = await self._get_service() # Parse both dates dt_start = isoparse(start_date) diff --git a/surfsense_backend/app/connectors/google_gmail_connector.py b/surfsense_backend/app/connectors/google_gmail_connector.py index 0e75080..d012ade 100644 --- a/surfsense_backend/app/connectors/google_gmail_connector.py +++ b/surfsense_backend/app/connectors/google_gmail_connector.py @@ -5,12 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials. """ import base64 +import json import re from typing import Any from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm.attributes import flag_modified + +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, +) class GoogleGmailConnector: @@ -19,6 +28,8 @@ class GoogleGmailConnector: def __init__( self, credentials: Credentials, + session: AsyncSession, + user_id: str, ): """ Initialize the GoogleGmailConnector class. @@ -26,9 +37,13 @@ class GoogleGmailConnector: credentials: Google OAuth Credentials object """ self._credentials = credentials + self._session = session + self._user_id = user_id self.service = None - def _get_credentials(self) -> Credentials: + async def _get_credentials( + self, + ) -> Credentials: """ Get valid Google OAuth credentials. Returns: @@ -59,12 +74,30 @@ class GoogleGmailConnector: client_id=self._credentials.client_id, client_secret=self._credentials.client_secret, scopes=self._credentials.scopes, + expiry=self._credentials.expiry, ) # Refresh the token if needed if self._credentials.expired or not self._credentials.valid: try: self._credentials.refresh(Request()) + # Update the connector config in DB + if self._session: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.user_id == self._user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + ) + ) + connector = result.scalars().first() + if connector is None: + raise RuntimeError( + "GMAIL connector not found for current user; cannot persist refreshed token." + ) + connector.config = json.loads(self._credentials.to_json()) + flag_modified(connector, "config") + await self._session.commit() except Exception as e: raise Exception( f"Failed to refresh Google OAuth credentials: {e!s}" @@ -72,7 +105,7 @@ class GoogleGmailConnector: return self._credentials - def _get_service(self): + async def _get_service(self): """ Get the Gmail service instance using Google OAuth credentials. Returns: @@ -85,20 +118,20 @@ class GoogleGmailConnector: return self.service try: - credentials = self._get_credentials() + credentials = await self._get_credentials() self.service = build("gmail", "v1", credentials=credentials) return self.service except Exception as e: raise Exception(f"Failed to create Gmail service: {e!s}") from e - def get_user_profile(self) -> tuple[dict[str, Any], str | None]: + async def get_user_profile(self) -> tuple[dict[str, Any], str | None]: """ Fetch user's Gmail profile information. Returns: Tuple containing (profile dict, error message or None) """ try: - service = self._get_service() + service = await self._get_service() profile = service.users().getProfile(userId="me").execute() return { @@ -111,7 +144,7 @@ class GoogleGmailConnector: except Exception as e: return {}, f"Error fetching user profile: {e!s}" - def get_messages_list( + async def get_messages_list( self, max_results: int = 100, query: str = "", @@ -129,7 +162,7 @@ class GoogleGmailConnector: Tuple containing (messages list, error message or None) """ try: - service = self._get_service() + service = await self._get_service() # Build request parameters request_params = { @@ -152,7 +185,9 @@ class GoogleGmailConnector: except Exception as e: return [], f"Error fetching messages list: {e!s}" - def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]: + async def get_message_details( + self, message_id: str + ) -> tuple[dict[str, Any], str | None]: """ Fetch detailed information for a specific message. Args: @@ -161,7 +196,7 @@ class GoogleGmailConnector: Tuple containing (message details dict, error message or None) """ try: - service = self._get_service() + service = await self._get_service() # Get full message details message = ( @@ -176,7 +211,7 @@ class GoogleGmailConnector: except Exception as e: return {}, f"Error fetching message details: {e!s}" - def get_recent_messages( + async def get_recent_messages( self, max_results: int = 50, days_back: int = 30, @@ -198,7 +233,7 @@ class GoogleGmailConnector: query = f"after:{date_query}" # Get messages list - messages_list, error = self.get_messages_list( + messages_list, error = await self.get_messages_list( max_results=max_results, query=query ) @@ -208,7 +243,9 @@ class GoogleGmailConnector: # Get detailed information for each message detailed_messages = [] for msg in messages_list: - message_details, detail_error = self.get_message_details(msg["id"]) + message_details, detail_error = await self.get_message_details( + msg["id"] + ) if detail_error: continue # Skip messages that can't be fetched detailed_messages.append(message_details) diff --git a/surfsense_backend/app/schemas/google_auth_credentials.py b/surfsense_backend/app/schemas/google_auth_credentials.py index 16e112e..f114aed 100644 --- a/surfsense_backend/app/schemas/google_auth_credentials.py +++ b/surfsense_backend/app/schemas/google_auth_credentials.py @@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel): client_secret: str @property - def is_expired(self) -> bool: + def expired(self) -> bool: """Check if the credentials have expired.""" return self.expiry <= datetime.now(UTC) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index d3ac800..abc4925 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -81,6 +81,7 @@ async def index_google_calendar_events( return 0, f"Connector with ID {connector_id} not found" # Get the Google Calendar credentials from the connector config + exp = connector.config.get("expiry").replace("Z", "") credentials = Credentials( token=connector.config.get("token"), refresh_token=connector.config.get("refresh_token"), @@ -88,6 +89,7 @@ async def index_google_calendar_events( client_id=connector.config.get("client_id"), client_secret=connector.config.get("client_secret"), scopes=connector.config.get("scopes"), + expiry=datetime.fromisoformat(exp), ) if ( @@ -110,7 +112,9 @@ async def index_google_calendar_events( {"stage": "client_initialization"}, ) - calendar_client = GoogleCalendarConnector(credentials=credentials) + calendar_client = GoogleCalendarConnector( + credentials=credentials, session=session, user_id=user_id + ) # Calculate date range if start_date is None or end_date is None: @@ -169,7 +173,7 @@ async def index_google_calendar_events( # Get events within date range from primary calendar try: - events, error = calendar_client.get_all_primary_calendar_events( + events, error = await calendar_client.get_all_primary_calendar_events( start_date=start_date_str, end_date=end_date_str ) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index ac85afd..531cfca 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -95,6 +95,7 @@ async def index_google_gmail_messages( # Create credentials from connector config config_data = connector.config + exp = config_data.get("expiry").replace("Z", "") credentials = Credentials( token=config_data.get("token"), refresh_token=config_data.get("refresh_token"), @@ -102,6 +103,7 @@ async def index_google_gmail_messages( client_id=config_data.get("client_id"), client_secret=config_data.get("client_secret"), scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp), ) if ( @@ -125,11 +127,11 @@ async def index_google_gmail_messages( ) # Initialize Google gmail connector - gmail_connector = GoogleGmailConnector(credentials) + gmail_connector = GoogleGmailConnector(credentials, session, user_id) # Fetch recent Google gmail messages logger.info(f"Fetching recent emails for connector {connector_id}") - messages, error = gmail_connector.get_recent_messages( + messages, error = await gmail_connector.get_recent_messages( max_results=max_messages, days_back=days_back )