diff --git a/surfsense_backend/app/connectors/google_gmail_connector.py b/surfsense_backend/app/connectors/google_gmail_connector.py index 0b0bf26..f51f4c1 100644 --- a/surfsense_backend/app/connectors/google_gmail_connector.py +++ b/surfsense_backend/app/connectors/google_gmail_connector.py @@ -5,24 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials. """ import base64 +import json import re from typing import Any -from fastapi import Depends from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm.attributes import flag_modified from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, - get_async_session, ) -from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase -from app.users import current_active_user class GoogleGmailConnector: @@ -31,6 +28,8 @@ class GoogleGmailConnector: def __init__( self, credentials: Credentials, + session: AsyncSession, + user_id: str, ): """ Initialize the GoogleGmailConnector class. @@ -38,12 +37,12 @@ class GoogleGmailConnector: credentials: Google OAuth Credentials object """ self._credentials = credentials + self._session = session + self._user_id = user_id self.service = None async def _get_credentials( self, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), ) -> Credentials: """ Get valid Google OAuth credentials. @@ -75,30 +74,30 @@ class GoogleGmailConnector: client_id=self._credentials.client_id, client_secret=self._credentials.client_secret, scopes=self._credentials.scopes, + expiry=self._credentials.expiry, ) # Refresh the token if needed - if self._credentials.expired or not self._credentials.valid: + if self._credentials.expired: try: self._credentials.refresh(Request()) # Update the connector config in DB - - connector = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.user_id == user.id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.GMAIL, + if self._session: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.user_id == self._user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + ) ) - ) - - connector = connector.scalars().first() - - connector.config = GoogleAuthCredentialsBase( - self._credentials - ).to_json() - session.add(connector) - await session.commit() - + connector = result.scalars().first() + if connector is None: + raise RuntimeError( + "GMAIL connector not found for current user; cannot persist refreshed token." + ) + connector.config = json.loads(self._credentials.to_json()) + flag_modified(connector, "config") + await self._session.commit() except Exception as e: raise Exception( f"Failed to refresh Google OAuth credentials: {e!s}" diff --git a/surfsense_backend/app/schemas/google_auth_credentials.py b/surfsense_backend/app/schemas/google_auth_credentials.py index 16e112e..f114aed 100644 --- a/surfsense_backend/app/schemas/google_auth_credentials.py +++ b/surfsense_backend/app/schemas/google_auth_credentials.py @@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel): client_secret: str @property - def is_expired(self) -> bool: + def expired(self) -> bool: """Check if the credentials have expired.""" return self.expiry <= datetime.now(UTC) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index 2c00fc4..531cfca 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -95,6 +95,7 @@ async def index_google_gmail_messages( # Create credentials from connector config config_data = connector.config + exp = config_data.get("expiry").replace("Z", "") credentials = Credentials( token=config_data.get("token"), refresh_token=config_data.get("refresh_token"), @@ -102,6 +103,7 @@ async def index_google_gmail_messages( client_id=config_data.get("client_id"), client_secret=config_data.get("client_secret"), scopes=config_data.get("scopes", []), + expiry=datetime.fromisoformat(exp), ) if ( @@ -125,7 +127,7 @@ async def index_google_gmail_messages( ) # Initialize Google gmail connector - gmail_connector = GoogleGmailConnector(credentials) + gmail_connector = GoogleGmailConnector(credentials, session, user_id) # Fetch recent Google gmail messages logger.info(f"Fetching recent emails for connector {connector_id}")