Merge pull request #298 from MODSetter/dev
Some checks are pending
pre-commit / pre-commit (push) Waiting to run

fix: auto refresh token for google based connectors
This commit is contained in:
Rohan Verma 2025-08-26 22:22:08 -07:00 committed by GitHub
commit a5bd1ebe4f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 100 additions and 24 deletions

View file

@ -4,6 +4,7 @@ A module for retrieving calendar events from Google Calendar using Google OAuth
Allows fetching events from specified calendars within date ranges using Google OAuth credentials. Allows fetching events from specified calendars within date ranges using Google OAuth credentials.
""" """
import json
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@ -12,6 +13,14 @@ from dateutil.parser import isoparse
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
class GoogleCalendarConnector: class GoogleCalendarConnector:
@ -20,6 +29,8 @@ class GoogleCalendarConnector:
def __init__( def __init__(
self, self,
credentials: Credentials, credentials: Credentials,
session: AsyncSession,
user_id: str,
): ):
""" """
Initialize the GoogleCalendarConnector class. Initialize the GoogleCalendarConnector class.
@ -27,9 +38,13 @@ class GoogleCalendarConnector:
credentials: Google OAuth Credentials object credentials: Google OAuth Credentials object
""" """
self._credentials = credentials self._credentials = credentials
self._session = session
self._user_id = user_id
self.service = None self.service = None
def _get_credentials(self) -> Credentials: async def _get_credentials(
self,
) -> Credentials:
""" """
Get valid Google OAuth credentials. Get valid Google OAuth credentials.
Returns: Returns:
@ -60,12 +75,30 @@ class GoogleCalendarConnector:
client_id=self._credentials.client_id, client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret, client_secret=self._credentials.client_secret,
scopes=self._credentials.scopes, scopes=self._credentials.scopes,
expiry=self._credentials.expiry,
) )
# Refresh the token if needed # Refresh the token if needed
if self._credentials.expired or not self._credentials.valid: if self._credentials.expired or not self._credentials.valid:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
# Update the connector config in DB
if self._session:
result = await self._session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == self._user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
)
)
connector = result.scalars().first()
if connector is None:
raise RuntimeError(
"GOOGLE_CALENDAR_CONNECTOR connector not found for current user; cannot persist refreshed token."
)
connector.config = json.loads(self._credentials.to_json())
flag_modified(connector, "config")
await self._session.commit()
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Failed to refresh Google OAuth credentials: {e!s}" f"Failed to refresh Google OAuth credentials: {e!s}"
@ -73,7 +106,7 @@ class GoogleCalendarConnector:
return self._credentials return self._credentials
def _get_service(self): async def _get_service(self):
""" """
Get the Google Calendar service instance using Google OAuth credentials. Get the Google Calendar service instance using Google OAuth credentials.
Returns: Returns:
@ -86,20 +119,20 @@ class GoogleCalendarConnector:
return self.service return self.service
try: try:
credentials = self._get_credentials() credentials = await self._get_credentials()
self.service = build("calendar", "v3", credentials=credentials) self.service = build("calendar", "v3", credentials=credentials)
return self.service return self.service
except Exception as e: except Exception as e:
raise Exception(f"Failed to create Google Calendar service: {e!s}") from e raise Exception(f"Failed to create Google Calendar service: {e!s}") from e
def get_calendars(self) -> tuple[list[dict[str, Any]], str | None]: async def get_calendars(self) -> tuple[list[dict[str, Any]], str | None]:
""" """
Fetch list of user's calendars using Google OAuth credentials. Fetch list of user's calendars using Google OAuth credentials.
Returns: Returns:
Tuple containing (calendars list, error message or None) Tuple containing (calendars list, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
calendars_result = service.calendarList().list().execute() calendars_result = service.calendarList().list().execute()
calendars = calendars_result.get("items", []) calendars = calendars_result.get("items", [])
@ -122,7 +155,7 @@ class GoogleCalendarConnector:
except Exception as e: except Exception as e:
return [], f"Error fetching calendars: {e!s}" return [], f"Error fetching calendars: {e!s}"
def get_all_primary_calendar_events( async def get_all_primary_calendar_events(
self, self,
start_date: str, start_date: str,
end_date: str, end_date: str,
@ -136,7 +169,7 @@ class GoogleCalendarConnector:
Tuple containing (events list, error message or None) Tuple containing (events list, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
# Parse both dates # Parse both dates
dt_start = isoparse(start_date) dt_start = isoparse(start_date)

View file

@ -5,12 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials.
""" """
import base64 import base64
import json
import re import re
from typing import Any from typing import Any
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
)
class GoogleGmailConnector: class GoogleGmailConnector:
@ -19,6 +28,8 @@ class GoogleGmailConnector:
def __init__( def __init__(
self, self,
credentials: Credentials, credentials: Credentials,
session: AsyncSession,
user_id: str,
): ):
""" """
Initialize the GoogleGmailConnector class. Initialize the GoogleGmailConnector class.
@ -26,9 +37,13 @@ class GoogleGmailConnector:
credentials: Google OAuth Credentials object credentials: Google OAuth Credentials object
""" """
self._credentials = credentials self._credentials = credentials
self._session = session
self._user_id = user_id
self.service = None self.service = None
def _get_credentials(self) -> Credentials: async def _get_credentials(
self,
) -> Credentials:
""" """
Get valid Google OAuth credentials. Get valid Google OAuth credentials.
Returns: Returns:
@ -59,12 +74,30 @@ class GoogleGmailConnector:
client_id=self._credentials.client_id, client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret, client_secret=self._credentials.client_secret,
scopes=self._credentials.scopes, scopes=self._credentials.scopes,
expiry=self._credentials.expiry,
) )
# Refresh the token if needed # Refresh the token if needed
if self._credentials.expired or not self._credentials.valid: if self._credentials.expired or not self._credentials.valid:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
# Update the connector config in DB
if self._session:
result = await self._session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == self._user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
)
)
connector = result.scalars().first()
if connector is None:
raise RuntimeError(
"GMAIL connector not found for current user; cannot persist refreshed token."
)
connector.config = json.loads(self._credentials.to_json())
flag_modified(connector, "config")
await self._session.commit()
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Failed to refresh Google OAuth credentials: {e!s}" f"Failed to refresh Google OAuth credentials: {e!s}"
@ -72,7 +105,7 @@ class GoogleGmailConnector:
return self._credentials return self._credentials
def _get_service(self): async def _get_service(self):
""" """
Get the Gmail service instance using Google OAuth credentials. Get the Gmail service instance using Google OAuth credentials.
Returns: Returns:
@ -85,20 +118,20 @@ class GoogleGmailConnector:
return self.service return self.service
try: try:
credentials = self._get_credentials() credentials = await self._get_credentials()
self.service = build("gmail", "v1", credentials=credentials) self.service = build("gmail", "v1", credentials=credentials)
return self.service return self.service
except Exception as e: except Exception as e:
raise Exception(f"Failed to create Gmail service: {e!s}") from e raise Exception(f"Failed to create Gmail service: {e!s}") from e
def get_user_profile(self) -> tuple[dict[str, Any], str | None]: async def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
""" """
Fetch user's Gmail profile information. Fetch user's Gmail profile information.
Returns: Returns:
Tuple containing (profile dict, error message or None) Tuple containing (profile dict, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
profile = service.users().getProfile(userId="me").execute() profile = service.users().getProfile(userId="me").execute()
return { return {
@ -111,7 +144,7 @@ class GoogleGmailConnector:
except Exception as e: except Exception as e:
return {}, f"Error fetching user profile: {e!s}" return {}, f"Error fetching user profile: {e!s}"
def get_messages_list( async def get_messages_list(
self, self,
max_results: int = 100, max_results: int = 100,
query: str = "", query: str = "",
@ -129,7 +162,7 @@ class GoogleGmailConnector:
Tuple containing (messages list, error message or None) Tuple containing (messages list, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
# Build request parameters # Build request parameters
request_params = { request_params = {
@ -152,7 +185,9 @@ class GoogleGmailConnector:
except Exception as e: except Exception as e:
return [], f"Error fetching messages list: {e!s}" return [], f"Error fetching messages list: {e!s}"
def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]: async def get_message_details(
self, message_id: str
) -> tuple[dict[str, Any], str | None]:
""" """
Fetch detailed information for a specific message. Fetch detailed information for a specific message.
Args: Args:
@ -161,7 +196,7 @@ class GoogleGmailConnector:
Tuple containing (message details dict, error message or None) Tuple containing (message details dict, error message or None)
""" """
try: try:
service = self._get_service() service = await self._get_service()
# Get full message details # Get full message details
message = ( message = (
@ -176,7 +211,7 @@ class GoogleGmailConnector:
except Exception as e: except Exception as e:
return {}, f"Error fetching message details: {e!s}" return {}, f"Error fetching message details: {e!s}"
def get_recent_messages( async def get_recent_messages(
self, self,
max_results: int = 50, max_results: int = 50,
days_back: int = 30, days_back: int = 30,
@ -198,7 +233,7 @@ class GoogleGmailConnector:
query = f"after:{date_query}" query = f"after:{date_query}"
# Get messages list # Get messages list
messages_list, error = self.get_messages_list( messages_list, error = await self.get_messages_list(
max_results=max_results, query=query max_results=max_results, query=query
) )
@ -208,7 +243,9 @@ class GoogleGmailConnector:
# Get detailed information for each message # Get detailed information for each message
detailed_messages = [] detailed_messages = []
for msg in messages_list: for msg in messages_list:
message_details, detail_error = self.get_message_details(msg["id"]) message_details, detail_error = await self.get_message_details(
msg["id"]
)
if detail_error: if detail_error:
continue # Skip messages that can't be fetched continue # Skip messages that can't be fetched
detailed_messages.append(message_details) detailed_messages.append(message_details)

View file

@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel):
client_secret: str client_secret: str
@property @property
def is_expired(self) -> bool: def expired(self) -> bool:
"""Check if the credentials have expired.""" """Check if the credentials have expired."""
return self.expiry <= datetime.now(UTC) return self.expiry <= datetime.now(UTC)

View file

@ -81,6 +81,7 @@ async def index_google_calendar_events(
return 0, f"Connector with ID {connector_id} not found" return 0, f"Connector with ID {connector_id} not found"
# Get the Google Calendar credentials from the connector config # Get the Google Calendar credentials from the connector config
exp = connector.config.get("expiry").replace("Z", "")
credentials = Credentials( credentials = Credentials(
token=connector.config.get("token"), token=connector.config.get("token"),
refresh_token=connector.config.get("refresh_token"), refresh_token=connector.config.get("refresh_token"),
@ -88,6 +89,7 @@ async def index_google_calendar_events(
client_id=connector.config.get("client_id"), client_id=connector.config.get("client_id"),
client_secret=connector.config.get("client_secret"), client_secret=connector.config.get("client_secret"),
scopes=connector.config.get("scopes"), scopes=connector.config.get("scopes"),
expiry=datetime.fromisoformat(exp),
) )
if ( if (
@ -110,7 +112,9 @@ async def index_google_calendar_events(
{"stage": "client_initialization"}, {"stage": "client_initialization"},
) )
calendar_client = GoogleCalendarConnector(credentials=credentials) calendar_client = GoogleCalendarConnector(
credentials=credentials, session=session, user_id=user_id
)
# Calculate date range # Calculate date range
if start_date is None or end_date is None: if start_date is None or end_date is None:
@ -169,7 +173,7 @@ async def index_google_calendar_events(
# Get events within date range from primary calendar # Get events within date range from primary calendar
try: try:
events, error = calendar_client.get_all_primary_calendar_events( events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str start_date=start_date_str, end_date=end_date_str
) )

View file

@ -95,6 +95,7 @@ async def index_google_gmail_messages(
# Create credentials from connector config # Create credentials from connector config
config_data = connector.config config_data = connector.config
exp = config_data.get("expiry").replace("Z", "")
credentials = Credentials( credentials = Credentials(
token=config_data.get("token"), token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"), refresh_token=config_data.get("refresh_token"),
@ -102,6 +103,7 @@ async def index_google_gmail_messages(
client_id=config_data.get("client_id"), client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"), client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []), scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp),
) )
if ( if (
@ -125,11 +127,11 @@ async def index_google_gmail_messages(
) )
# Initialize Google gmail connector # Initialize Google gmail connector
gmail_connector = GoogleGmailConnector(credentials) gmail_connector = GoogleGmailConnector(credentials, session, user_id)
# Fetch recent Google gmail messages # Fetch recent Google gmail messages
logger.info(f"Fetching recent emails for connector {connector_id}") logger.info(f"Fetching recent emails for connector {connector_id}")
messages, error = gmail_connector.get_recent_messages( messages, error = await gmail_connector.get_recent_messages(
max_results=max_messages, days_back=days_back max_results=max_messages, days_back=days_back
) )