mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 01:59:06 +00:00
refresh the token when expired
This commit is contained in:
parent
b0b6df0971
commit
9711af2b72
3 changed files with 27 additions and 26 deletions
|
@ -5,24 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials.
|
|||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
|
||||
|
||||
class GoogleGmailConnector:
|
||||
|
@ -31,6 +28,8 @@ class GoogleGmailConnector:
|
|||
def __init__(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Initialize the GoogleGmailConnector class.
|
||||
|
@ -38,12 +37,12 @@ class GoogleGmailConnector:
|
|||
credentials: Google OAuth Credentials object
|
||||
"""
|
||||
self._credentials = credentials
|
||||
self._session = session
|
||||
self._user_id = user_id
|
||||
self.service = None
|
||||
|
||||
async def _get_credentials(
|
||||
self,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
|
@ -75,30 +74,30 @@ class GoogleGmailConnector:
|
|||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
if self._credentials.expired:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
|
||||
connector = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GMAIL,
|
||||
if self._session:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == self._user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
connector = connector.scalars().first()
|
||||
|
||||
connector.config = GoogleAuthCredentialsBase(
|
||||
self._credentials
|
||||
).to_json()
|
||||
session.add(connector)
|
||||
await session.commit()
|
||||
|
||||
connector = result.scalars().first()
|
||||
if connector is None:
|
||||
raise RuntimeError(
|
||||
"GMAIL connector not found for current user; cannot persist refreshed token."
|
||||
)
|
||||
connector.config = json.loads(self._credentials.to_json())
|
||||
flag_modified(connector, "config")
|
||||
await self._session.commit()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to refresh Google OAuth credentials: {e!s}"
|
||||
|
|
|
@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel):
|
|||
client_secret: str
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
def expired(self) -> bool:
|
||||
"""Check if the credentials have expired."""
|
||||
return self.expiry <= datetime.now(UTC)
|
||||
|
|
|
@ -95,6 +95,7 @@ async def index_google_gmail_messages(
|
|||
|
||||
# Create credentials from connector config
|
||||
config_data = connector.config
|
||||
exp = config_data.get("expiry").replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
|
@ -102,6 +103,7 @@ async def index_google_gmail_messages(
|
|||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp),
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -125,7 +127,7 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
|
||||
# Initialize Google gmail connector
|
||||
gmail_connector = GoogleGmailConnector(credentials)
|
||||
gmail_connector = GoogleGmailConnector(credentials, session, user_id)
|
||||
|
||||
# Fetch recent Google gmail messages
|
||||
logger.info(f"Fetching recent emails for connector {connector_id}")
|
||||
|
|
Loading…
Add table
Reference in a new issue