mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
refresh the token when expired
This commit is contained in:
parent
b0b6df0971
commit
9711af2b72
3 changed files with 27 additions and 26 deletions
|
@ -5,24 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import Depends
|
|
||||||
from google.auth.transport.requests import Request
|
from google.auth.transport.requests import Request
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.db import (
|
from app.db import (
|
||||||
SearchSourceConnector,
|
SearchSourceConnector,
|
||||||
SearchSourceConnectorType,
|
SearchSourceConnectorType,
|
||||||
User,
|
|
||||||
get_async_session,
|
|
||||||
)
|
)
|
||||||
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
|
|
||||||
from app.users import current_active_user
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleGmailConnector:
|
class GoogleGmailConnector:
|
||||||
|
@ -31,6 +28,8 @@ class GoogleGmailConnector:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
credentials: Credentials,
|
credentials: Credentials,
|
||||||
|
session: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the GoogleGmailConnector class.
|
Initialize the GoogleGmailConnector class.
|
||||||
|
@ -38,12 +37,12 @@ class GoogleGmailConnector:
|
||||||
credentials: Google OAuth Credentials object
|
credentials: Google OAuth Credentials object
|
||||||
"""
|
"""
|
||||||
self._credentials = credentials
|
self._credentials = credentials
|
||||||
|
self._session = session
|
||||||
|
self._user_id = user_id
|
||||||
self.service = None
|
self.service = None
|
||||||
|
|
||||||
async def _get_credentials(
|
async def _get_credentials(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""
|
"""
|
||||||
Get valid Google OAuth credentials.
|
Get valid Google OAuth credentials.
|
||||||
|
@ -75,30 +74,30 @@ class GoogleGmailConnector:
|
||||||
client_id=self._credentials.client_id,
|
client_id=self._credentials.client_id,
|
||||||
client_secret=self._credentials.client_secret,
|
client_secret=self._credentials.client_secret,
|
||||||
scopes=self._credentials.scopes,
|
scopes=self._credentials.scopes,
|
||||||
|
expiry=self._credentials.expiry,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Refresh the token if needed
|
# Refresh the token if needed
|
||||||
if self._credentials.expired or not self._credentials.valid:
|
if self._credentials.expired:
|
||||||
try:
|
try:
|
||||||
self._credentials.refresh(Request())
|
self._credentials.refresh(Request())
|
||||||
# Update the connector config in DB
|
# Update the connector config in DB
|
||||||
|
if self._session:
|
||||||
connector = await session.execute(
|
result = await self._session.execute(
|
||||||
select(SearchSourceConnector).filter(
|
select(SearchSourceConnector).filter(
|
||||||
SearchSourceConnector.user_id == user.id,
|
SearchSourceConnector.user_id == self._user_id,
|
||||||
SearchSourceConnector.connector_type
|
SearchSourceConnector.connector_type
|
||||||
== SearchSourceConnectorType.GMAIL,
|
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
connector = result.scalars().first()
|
||||||
|
if connector is None:
|
||||||
connector = connector.scalars().first()
|
raise RuntimeError(
|
||||||
|
"GMAIL connector not found for current user; cannot persist refreshed token."
|
||||||
connector.config = GoogleAuthCredentialsBase(
|
)
|
||||||
self._credentials
|
connector.config = json.loads(self._credentials.to_json())
|
||||||
).to_json()
|
flag_modified(connector, "config")
|
||||||
session.add(connector)
|
await self._session.commit()
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to refresh Google OAuth credentials: {e!s}"
|
f"Failed to refresh Google OAuth credentials: {e!s}"
|
||||||
|
|
|
@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel):
|
||||||
client_secret: str
|
client_secret: str
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_expired(self) -> bool:
|
def expired(self) -> bool:
|
||||||
"""Check if the credentials have expired."""
|
"""Check if the credentials have expired."""
|
||||||
return self.expiry <= datetime.now(UTC)
|
return self.expiry <= datetime.now(UTC)
|
||||||
|
|
|
@ -95,6 +95,7 @@ async def index_google_gmail_messages(
|
||||||
|
|
||||||
# Create credentials from connector config
|
# Create credentials from connector config
|
||||||
config_data = connector.config
|
config_data = connector.config
|
||||||
|
exp = config_data.get("expiry").replace("Z", "")
|
||||||
credentials = Credentials(
|
credentials = Credentials(
|
||||||
token=config_data.get("token"),
|
token=config_data.get("token"),
|
||||||
refresh_token=config_data.get("refresh_token"),
|
refresh_token=config_data.get("refresh_token"),
|
||||||
|
@ -102,6 +103,7 @@ async def index_google_gmail_messages(
|
||||||
client_id=config_data.get("client_id"),
|
client_id=config_data.get("client_id"),
|
||||||
client_secret=config_data.get("client_secret"),
|
client_secret=config_data.get("client_secret"),
|
||||||
scopes=config_data.get("scopes", []),
|
scopes=config_data.get("scopes", []),
|
||||||
|
expiry=datetime.fromisoformat(exp),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -125,7 +127,7 @@ async def index_google_gmail_messages(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Google gmail connector
|
# Initialize Google gmail connector
|
||||||
gmail_connector = GoogleGmailConnector(credentials)
|
gmail_connector = GoogleGmailConnector(credentials, session, user_id)
|
||||||
|
|
||||||
# Fetch recent Google gmail messages
|
# Fetch recent Google gmail messages
|
||||||
logger.info(f"Fetching recent emails for connector {connector_id}")
|
logger.info(f"Fetching recent emails for connector {connector_id}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue