mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
updated the connector config after refreshing the token
This commit is contained in:
parent
3d93fe8186
commit
b0b6df0971
3 changed files with 72 additions and 13 deletions
|
@ -8,9 +8,21 @@ import base64
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
|
||||
|
||||
class GoogleGmailConnector:
|
||||
|
@ -28,7 +40,11 @@ class GoogleGmailConnector:
|
|||
self._credentials = credentials
|
||||
self.service = None
|
||||
|
||||
def _get_credentials(self) -> Credentials:
|
||||
async def _get_credentials(
|
||||
self,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> Credentials:
|
||||
"""
|
||||
Get valid Google OAuth credentials.
|
||||
Returns:
|
||||
|
@ -65,6 +81,24 @@ class GoogleGmailConnector:
|
|||
if self._credentials.expired or not self._credentials.valid:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
# Update the connector config in DB
|
||||
|
||||
connector = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GMAIL,
|
||||
)
|
||||
)
|
||||
|
||||
connector = connector.scalars().first()
|
||||
|
||||
connector.config = GoogleAuthCredentialsBase(
|
||||
self._credentials
|
||||
).to_json()
|
||||
session.add(connector)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to refresh Google OAuth credentials: {e!s}"
|
||||
|
@ -72,7 +106,7 @@ class GoogleGmailConnector:
|
|||
|
||||
return self._credentials
|
||||
|
||||
def _get_service(self):
|
||||
async def _get_service(self):
|
||||
"""
|
||||
Get the Gmail service instance using Google OAuth credentials.
|
||||
Returns:
|
||||
|
@ -85,20 +119,20 @@ class GoogleGmailConnector:
|
|||
return self.service
|
||||
|
||||
try:
|
||||
credentials = self._get_credentials()
|
||||
credentials = await self._get_credentials()
|
||||
self.service = build("gmail", "v1", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create Gmail service: {e!s}") from e
|
||||
|
||||
def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
|
||||
async def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
|
||||
"""
|
||||
Fetch user's Gmail profile information.
|
||||
Returns:
|
||||
Tuple containing (profile dict, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
profile = service.users().getProfile(userId="me").execute()
|
||||
|
||||
return {
|
||||
|
@ -111,7 +145,7 @@ class GoogleGmailConnector:
|
|||
except Exception as e:
|
||||
return {}, f"Error fetching user profile: {e!s}"
|
||||
|
||||
def get_messages_list(
|
||||
async def get_messages_list(
|
||||
self,
|
||||
max_results: int = 100,
|
||||
query: str = "",
|
||||
|
@ -129,7 +163,7 @@ class GoogleGmailConnector:
|
|||
Tuple containing (messages list, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
|
||||
# Build request parameters
|
||||
request_params = {
|
||||
|
@ -152,7 +186,9 @@ class GoogleGmailConnector:
|
|||
except Exception as e:
|
||||
return [], f"Error fetching messages list: {e!s}"
|
||||
|
||||
def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]:
|
||||
async def get_message_details(
|
||||
self, message_id: str
|
||||
) -> tuple[dict[str, Any], str | None]:
|
||||
"""
|
||||
Fetch detailed information for a specific message.
|
||||
Args:
|
||||
|
@ -161,7 +197,7 @@ class GoogleGmailConnector:
|
|||
Tuple containing (message details dict, error message or None)
|
||||
"""
|
||||
try:
|
||||
service = self._get_service()
|
||||
service = await self._get_service()
|
||||
|
||||
# Get full message details
|
||||
message = (
|
||||
|
@ -176,7 +212,7 @@ class GoogleGmailConnector:
|
|||
except Exception as e:
|
||||
return {}, f"Error fetching message details: {e!s}"
|
||||
|
||||
def get_recent_messages(
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
max_results: int = 50,
|
||||
days_back: int = 30,
|
||||
|
@ -198,7 +234,7 @@ class GoogleGmailConnector:
|
|||
query = f"after:{date_query}"
|
||||
|
||||
# Get messages list
|
||||
messages_list, error = self.get_messages_list(
|
||||
messages_list, error = await self.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
|
||||
|
@ -208,7 +244,9 @@ class GoogleGmailConnector:
|
|||
# Get detailed information for each message
|
||||
detailed_messages = []
|
||||
for msg in messages_list:
|
||||
message_details, detail_error = self.get_message_details(msg["id"])
|
||||
message_details, detail_error = await self.get_message_details(
|
||||
msg["id"]
|
||||
)
|
||||
if detail_error:
|
||||
continue # Skip messages that can't be fetched
|
||||
detailed_messages.append(message_details)
|
||||
|
|
|
@ -60,6 +60,27 @@ async def get_connector_by_id(
|
|||
return result.scalars().first()
|
||||
|
||||
|
||||
async def get_connector_by_type(
|
||||
session: AsyncSession, connector_type: SearchSourceConnectorType
|
||||
) -> SearchSourceConnector | None:
|
||||
"""
|
||||
Get a connector by type from the database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_type: Type of the connector
|
||||
|
||||
Returns:
|
||||
Connector object if found, None otherwise
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.connector_type == connector_type,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def calculate_date_range(
|
||||
connector: SearchSourceConnector,
|
||||
start_date: str | None = None,
|
||||
|
|
|
@ -129,7 +129,7 @@ async def index_google_gmail_messages(
|
|||
|
||||
# Fetch recent Google gmail messages
|
||||
logger.info(f"Fetching recent emails for connector {connector_id}")
|
||||
messages, error = gmail_connector.get_recent_messages(
|
||||
messages, error = await gmail_connector.get_recent_messages(
|
||||
max_results=max_messages, days_back=days_back
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue