refresh the token when expired

This commit is contained in:
CREDO23 2025-08-21 01:09:13 +02:00
parent b0b6df0971
commit 9711af2b72
3 changed files with 27 additions and 26 deletions

View file

@ -5,24 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials.
"""
import base64
import json
import re
from typing import Any
from fastapi import Depends
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
from app.users import current_active_user
class GoogleGmailConnector:
@ -31,6 +28,8 @@ class GoogleGmailConnector:
def __init__(
self,
credentials: Credentials,
session: AsyncSession,
user_id: str,
):
"""
Initialize the GoogleGmailConnector class.
@ -38,12 +37,12 @@ class GoogleGmailConnector:
credentials: Google OAuth Credentials object
"""
self._credentials = credentials
self._session = session
self._user_id = user_id
self.service = None
async def _get_credentials(
self,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
) -> Credentials:
"""
Get valid Google OAuth credentials.
@ -75,30 +74,30 @@ class GoogleGmailConnector:
client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret,
scopes=self._credentials.scopes,
expiry=self._credentials.expiry,
)
# Refresh the token if needed
if self._credentials.expired or not self._credentials.valid:
if self._credentials.expired:
try:
self._credentials.refresh(Request())
# Update the connector config in DB
connector = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GMAIL,
if self._session:
result = await self._session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == self._user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
)
)
)
connector = connector.scalars().first()
connector.config = GoogleAuthCredentialsBase(
self._credentials
).to_json()
session.add(connector)
await session.commit()
connector = result.scalars().first()
if connector is None:
raise RuntimeError(
"GMAIL connector not found for current user; cannot persist refreshed token."
)
connector.config = json.loads(self._credentials.to_json())
flag_modified(connector, "config")
await self._session.commit()
except Exception as e:
raise Exception(
f"Failed to refresh Google OAuth credentials: {e!s}"

View file

@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel):
client_secret: str
@property
def is_expired(self) -> bool:
def expired(self) -> bool:
"""Check if the credentials have expired."""
return self.expiry <= datetime.now(UTC)

View file

@ -95,6 +95,7 @@ async def index_google_gmail_messages(
# Create credentials from connector config
config_data = connector.config
exp = config_data.get("expiry").replace("Z", "")
credentials = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
@ -102,6 +103,7 @@ async def index_google_gmail_messages(
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp),
)
if (
@ -125,7 +127,7 @@ async def index_google_gmail_messages(
)
# Initialize Google gmail connector
gmail_connector = GoogleGmailConnector(credentials)
gmail_connector = GoogleGmailConnector(credentials, session, user_id)
# Fetch recent Google gmail messages
logger.info(f"Fetching recent emails for connector {connector_id}")