mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
Merge pull request #298 from MODSetter/dev
Some checks are pending
pre-commit / pre-commit (push) Waiting to run
Some checks are pending
pre-commit / pre-commit (push) Waiting to run
fix: auto refresh token for google based connectors
This commit is contained in:
commit
a5bd1ebe4f
5 changed files with 100 additions and 24 deletions
|
@ -4,6 +4,7 @@ A module for retrieving calendar events from Google Calendar using Google OAuth
|
|||
Allows fetching events from specified calendars within date ranges using Google OAuth credentials.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
@ -12,6 +13,14 @@ from dateutil.parser import isoparse
|
|||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
|
||||
class GoogleCalendarConnector:
|
||||
|
@ -20,6 +29,8 @@ class GoogleCalendarConnector:
|
|||
def __init__(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Initialize the GoogleCalendarConnector class.
|
||||
|
@ -27,9 +38,13 @@ class GoogleCalendarConnector:
|
|||
credentials: Google OAuth Credentials object
|
||||
"""
|
||||
self._credentials = credentials
|
||||
self._session = session
|
||||
self._user_id = user_id
|
||||
self.service = None
|
||||
|
||||
def _get_credentials(self) -> Credentials:
|
||||
async def _get_credentials(
|
||||
self,
|
||||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
|
@ -60,12 +75,30 @@ class GoogleCalendarConnector:
|
|||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
if self._session:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == self._user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if connector is None:
|
||||
raise RuntimeError(
|
||||
"GOOGLE_CALENDAR_CONNECTOR connector not found for current user; cannot persist refreshed token."
|
||||
)
|
||||
connector.config = json.loads(self._credentials.to_json())
|
||||
flag_modified(connector, "config")
|
||||
await self._session.commit()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to refresh Google OAuth credentials: {e!s}"
|
||||
|
@ -73,7 +106,7 @@ class GoogleCalendarConnector:
|
|||
|
||||
return self._credentials
|
||||
|
||||
def _get_service(self):
|
||||
async def _get_service(self):
|
||||
"""
|
||||
Get the Google Calendar service instance using Google OAuth credentials.
|
||||
Returns:
|
||||
|
@ -86,20 +119,20 @@ class GoogleCalendarConnector:
|
|||
return self.service
|
||||
|
||||
try:
|
||||
credentials = self._get_credentials()
|
||||
credentials = await self._get_credentials()
|
||||
self.service = build("calendar", "v3", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create Google Calendar service: {e!s}") from e
|
||||
|
||||
def get_calendars(self) -> tuple[list[dict[str, Any]], str | None]:
|
||||
async def get_calendars(self) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
Fetch list of user's calendars using Google OAuth credentials.
|
||||
Returns:
|
||||
Tuple containing (calendars list, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
calendars_result = service.calendarList().list().execute()
|
||||
calendars = calendars_result.get("items", [])
|
||||
|
||||
|
@ -122,7 +155,7 @@ class GoogleCalendarConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching calendars: {e!s}"
|
||||
|
||||
def get_all_primary_calendar_events(
|
||||
async def get_all_primary_calendar_events(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
|
@ -136,7 +169,7 @@ class GoogleCalendarConnector:
|
|||
Tuple containing (events list, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
|
||||
# Parse both dates
|
||||
dt_start = isoparse(start_date)
|
||||
|
|
|
@ -5,12 +5,21 @@ Allows fetching emails from Gmail mailbox using Google OAuth credentials.
|
|||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
|
||||
class GoogleGmailConnector:
|
||||
|
@ -19,6 +28,8 @@ class GoogleGmailConnector:
|
|||
def __init__(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Initialize the GoogleGmailConnector class.
|
||||
|
@ -26,9 +37,13 @@ class GoogleGmailConnector:
|
|||
credentials: Google OAuth Credentials object
|
||||
"""
|
||||
self._credentials = credentials
|
||||
self._session = session
|
||||
self._user_id = user_id
|
||||
self.service = None
|
||||
|
||||
def _get_credentials(self) -> Credentials:
|
||||
async def _get_credentials(
|
||||
self,
|
||||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
|
@ -59,12 +74,30 @@ class GoogleGmailConnector:
|
|||
client_id=self._credentials.client_id,
|
||||
client_secret=self._credentials.client_secret,
|
||||
scopes=self._credentials.scopes,
|
||||
expiry=self._credentials.expiry,
|
||||
)
|
||||
|
||||
# Refresh the token if needed
|
||||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
if self._session:
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == self._user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if connector is None:
|
||||
raise RuntimeError(
|
||||
"GMAIL connector not found for current user; cannot persist refreshed token."
|
||||
)
|
||||
connector.config = json.loads(self._credentials.to_json())
|
||||
flag_modified(connector, "config")
|
||||
await self._session.commit()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to refresh Google OAuth credentials: {e!s}"
|
||||
|
@ -72,7 +105,7 @@ class GoogleGmailConnector:
|
|||
|
||||
return self._credentials
|
||||
|
||||
def _get_service(self):
|
||||
async def _get_service(self):
|
||||
"""
|
||||
Get the Gmail service instance using Google OAuth credentials.
|
||||
Returns:
|
||||
|
@ -85,20 +118,20 @@ class GoogleGmailConnector:
|
|||
return self.service
|
||||
|
||||
try:
|
||||
credentials = self._get_credentials()
|
||||
credentials = await self._get_credentials()
|
||||
self.service = build("gmail", "v1", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create Gmail service: {e!s}") from e
|
||||
|
||||
def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
|
||||
async def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
|
||||
"""
|
||||
Fetch user's Gmail profile information.
|
||||
Returns:
|
||||
Tuple containing (profile dict, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
profile = service.users().getProfile(userId="me").execute()
|
||||
|
||||
return {
|
||||
|
@ -111,7 +144,7 @@ class GoogleGmailConnector:
|
|||
except Exception as e:
|
||||
return {}, f"Error fetching user profile: {e!s}"
|
||||
|
||||
def get_messages_list(
|
||||
async def get_messages_list(
|
||||
self,
|
||||
max_results: int = 100,
|
||||
query: str = "",
|
||||
|
@ -129,7 +162,7 @@ class GoogleGmailConnector:
|
|||
Tuple containing (messages list, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
|
||||
# Build request parameters
|
||||
request_params = {
|
||||
|
@ -152,7 +185,9 @@ class GoogleGmailConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching messages list: {e!s}"
|
||||
|
||||
def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]:
|
||||
async def get_message_details(
|
||||
self, message_id: str
|
||||
) -> tuple[dict[str, Any], str | None]:
|
||||
"""
|
||||
Fetch detailed information for a specific message.
|
||||
Args:
|
||||
|
@ -161,7 +196,7 @@ class GoogleGmailConnector:
|
|||
Tuple containing (message details dict, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
|
||||
# Get full message details
|
||||
message = (
|
||||
|
@ -176,7 +211,7 @@ class GoogleGmailConnector:
|
|||
except Exception as e:
|
||||
return {}, f"Error fetching message details: {e!s}"
|
||||
|
||||
def get_recent_messages(
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
max_results: int = 50,
|
||||
days_back: int = 30,
|
||||
|
@ -198,7 +233,7 @@ class GoogleGmailConnector:
|
|||
query = f"after:{date_query}"
|
||||
|
||||
# Get messages list
|
||||
messages_list, error = self.get_messages_list(
|
||||
messages_list, error = await self.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
|
||||
|
@ -208,7 +243,9 @@ class GoogleGmailConnector:
|
|||
# Get detailed information for each message
|
||||
detailed_messages = []
|
||||
for msg in messages_list:
|
||||
message_details, detail_error = self.get_message_details(msg["id"])
|
||||
message_details, detail_error = await self.get_message_details(
|
||||
msg["id"]
|
||||
)
|
||||
if detail_error:
|
||||
continue # Skip messages that can't be fetched
|
||||
detailed_messages.append(message_details)
|
||||
|
|
|
@ -13,6 +13,6 @@ class GoogleAuthCredentialsBase(BaseModel):
|
|||
client_secret: str
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
def expired(self) -> bool:
|
||||
"""Check if the credentials have expired."""
|
||||
return self.expiry <= datetime.now(UTC)
|
||||
|
|
|
@ -81,6 +81,7 @@ async def index_google_calendar_events(
|
|||
return 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Get the Google Calendar credentials from the connector config
|
||||
exp = connector.config.get("expiry").replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=connector.config.get("token"),
|
||||
refresh_token=connector.config.get("refresh_token"),
|
||||
|
@ -88,6 +89,7 @@ async def index_google_calendar_events(
|
|||
client_id=connector.config.get("client_id"),
|
||||
client_secret=connector.config.get("client_secret"),
|
||||
scopes=connector.config.get("scopes"),
|
||||
expiry=datetime.fromisoformat(exp),
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -110,7 +112,9 @@ async def index_google_calendar_events(
|
|||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
calendar_client = GoogleCalendarConnector(credentials=credentials)
|
||||
calendar_client = GoogleCalendarConnector(
|
||||
credentials=credentials, session=session, user_id=user_id
|
||||
)
|
||||
|
||||
# Calculate date range
|
||||
if start_date is None or end_date is None:
|
||||
|
@ -169,7 +173,7 @@ async def index_google_calendar_events(
|
|||
|
||||
# Get events within date range from primary calendar
|
||||
try:
|
||||
events, error = calendar_client.get_all_primary_calendar_events(
|
||||
events, error = await calendar_client.get_all_primary_calendar_events(
|
||||
start_date=start_date_str, end_date=end_date_str
|
||||
)
|
||||
|
||||
|
|
|
@ -95,6 +95,7 @@ async def index_google_gmail_messages(
|
|||
|
||||
# Create credentials from connector config
|
||||
config_data = connector.config
|
||||
exp = config_data.get("expiry").replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
|
@ -102,6 +103,7 @@ async def index_google_gmail_messages(
|
|||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp),
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -125,11 +127,11 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
|
||||
# Initialize Google gmail connector
|
||||
gmail_connector = GoogleGmailConnector(credentials)
|
||||
gmail_connector = GoogleGmailConnector(credentials, session, user_id)
|
||||
|
||||
# Fetch recent Google gmail messages
|
||||
logger.info(f"Fetching recent emails for connector {connector_id}")
|
||||
messages, error = gmail_connector.get_recent_messages(
|
||||
messages, error = await gmail_connector.get_recent_messages(
|
||||
max_results=max_messages, days_back=days_back
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue