refresh the token when expired

This commit is contained in:
CREDO23 2025-08-21 01:09:13 +02:00
parent b0b6df0971
commit 9711af2b72
3 changed files with 27 additions and 26 deletions

View file

@ -5,24 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials.
""" """
import base64 import base64
import json
import re import re
from typing import Any from typing import Any
from fastapi import Depends
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import ( from app.db import (
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
User,
get_async_session,
) )
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
from app.users import current_active_user
class GoogleGmailConnector: class GoogleGmailConnector:
@ -31,6 +28,8 @@ class GoogleGmailConnector:
def __init__( def __init__(
self, self,
credentials: Credentials, credentials: Credentials,
session: AsyncSession,
user_id: str,
): ):
""" """
Initialize the GoogleGmailConnector class. Initialize the GoogleGmailConnector class.
@ -38,12 +37,12 @@ class GoogleGmailConnector:
credentials: Google OAuth Credentials object credentials: Google OAuth Credentials object
""" """
self._credentials = credentials self._credentials = credentials
self._session = session
self._user_id = user_id
self.service = None self.service = None
async def _get_credentials( async def _get_credentials(
self, self,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
) -> Credentials: ) -> Credentials:
""" """
Get valid Google OAuth credentials. Get valid Google OAuth credentials.
@ -75,30 +74,30 @@ class GoogleGmailConnector:
client_id=self._credentials.client_id, client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret, client_secret=self._credentials.client_secret,
scopes=self._credentials.scopes, scopes=self._credentials.scopes,
expiry=self._credentials.expiry,
) )
# Refresh the token if needed # Refresh the token if needed
if self._credentials.expired or not self._credentials.valid: if self._credentials.expired:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
# Update the connector config in DB # Update the connector config in DB
if self._session:
connector = await session.execute( result = await self._session.execute(
select(SearchSourceConnector).filter( select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user.id, SearchSourceConnector.user_id == self._user_id,
SearchSourceConnector.connector_type SearchSourceConnector.connector_type
== SearchSourceConnectorType.GMAIL, == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
)
) )
) connector = result.scalars().first()
if connector is None:
connector = connector.scalars().first() raise RuntimeError(
"GMAIL connector not found for current user; cannot persist refreshed token."
connector.config = GoogleAuthCredentialsBase( )
self._credentials connector.config = json.loads(self._credentials.to_json())
).to_json() flag_modified(connector, "config")
session.add(connector) await self._session.commit()
await session.commit()
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Failed to refresh Google OAuth credentials: {e!s}" f"Failed to refresh Google OAuth credentials: {e!s}"

View file

@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel):
client_secret: str client_secret: str
@property @property
def is_expired(self) -> bool: def expired(self) -> bool:
"""Check if the credentials have expired.""" """Check if the credentials have expired."""
return self.expiry <= datetime.now(UTC) return self.expiry <= datetime.now(UTC)

View file

@ -95,6 +95,7 @@ async def index_google_gmail_messages(
# Create credentials from connector config # Create credentials from connector config
config_data = connector.config config_data = connector.config
exp = config_data.get("expiry").replace("Z", "")
credentials = Credentials( credentials = Credentials(
token=config_data.get("token"), token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"), refresh_token=config_data.get("refresh_token"),
@ -102,6 +103,7 @@ async def index_google_gmail_messages(
client_id=config_data.get("client_id"), client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"), client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []), scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp),
) )
if ( if (
@ -125,7 +127,7 @@ async def index_google_gmail_messages(
) )
# Initialize Google gmail connector # Initialize Google gmail connector
gmail_connector = GoogleGmailConnector(credentials) gmail_connector = GoogleGmailConnector(credentials, session, user_id)
# Fetch recent Google gmail messages # Fetch recent Google gmail messages
logger.info(f"Fetching recent emails for connector {connector_id}") logger.info(f"Fetching recent emails for connector {connector_id}")