mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-04 11:39:19 +00:00
updated the connector config after refreshing the token
This commit is contained in:
parent
3d93fe8186
commit
b0b6df0971
3 changed files with 72 additions and 13 deletions
|
@ -8,9 +8,21 @@ import base64
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
from google.auth.transport.requests import Request
|
from google.auth.transport.requests import Request
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
SearchSourceConnector,
|
||||||
|
SearchSourceConnectorType,
|
||||||
|
User,
|
||||||
|
get_async_session,
|
||||||
|
)
|
||||||
|
from app.schemas.google_auth_credentials import GoogleAuthCredentialsBase
|
||||||
|
from app.users import current_active_user
|
||||||
|
|
||||||
|
|
||||||
class GoogleGmailConnector:
|
class GoogleGmailConnector:
|
||||||
|
@ -28,7 +40,11 @@ class GoogleGmailConnector:
|
||||||
self._credentials = credentials
|
self._credentials = credentials
|
||||||
self.service = None
|
self.service = None
|
||||||
|
|
||||||
def _get_credentials(self) -> Credentials:
|
async def _get_credentials(
|
||||||
|
self,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
) -> Credentials:
|
||||||
"""
|
"""
|
||||||
Get valid Google OAuth credentials.
|
Get valid Google OAuth credentials.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -65,6 +81,24 @@ class GoogleGmailConnector:
|
||||||
if self._credentials.expired or not self._credentials.valid:
|
if self._credentials.expired or not self._credentials.valid:
|
||||||
try:
|
try:
|
||||||
self._credentials.refresh(Request())
|
self._credentials.refresh(Request())
|
||||||
|
# Update the connector config in DB
|
||||||
|
|
||||||
|
connector = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.user_id == user.id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.GMAIL,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
connector = connector.scalars().first()
|
||||||
|
|
||||||
|
connector.config = GoogleAuthCredentialsBase(
|
||||||
|
self._credentials
|
||||||
|
).to_json()
|
||||||
|
session.add(connector)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to refresh Google OAuth credentials: {e!s}"
|
f"Failed to refresh Google OAuth credentials: {e!s}"
|
||||||
|
@ -72,7 +106,7 @@ class GoogleGmailConnector:
|
||||||
|
|
||||||
return self._credentials
|
return self._credentials
|
||||||
|
|
||||||
def _get_service(self):
|
async def _get_service(self):
|
||||||
"""
|
"""
|
||||||
Get the Gmail service instance using Google OAuth credentials.
|
Get the Gmail service instance using Google OAuth credentials.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -85,20 +119,20 @@ class GoogleGmailConnector:
|
||||||
return self.service
|
return self.service
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = self._get_credentials()
|
credentials = await self._get_credentials()
|
||||||
self.service = build("gmail", "v1", credentials=credentials)
|
self.service = build("gmail", "v1", credentials=credentials)
|
||||||
return self.service
|
return self.service
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Failed to create Gmail service: {e!s}") from e
|
raise Exception(f"Failed to create Gmail service: {e!s}") from e
|
||||||
|
|
||||||
def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
|
async def get_user_profile(self) -> tuple[dict[str, Any], str | None]:
|
||||||
"""
|
"""
|
||||||
Fetch user's Gmail profile information.
|
Fetch user's Gmail profile information.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing (profile dict, error message or None)
|
Tuple containing (profile dict, error message or None)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
service = self._get_service()
|
service = await self._get_service()
|
||||||
profile = service.users().getProfile(userId="me").execute()
|
profile = service.users().getProfile(userId="me").execute()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -111,7 +145,7 @@ class GoogleGmailConnector:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {}, f"Error fetching user profile: {e!s}"
|
return {}, f"Error fetching user profile: {e!s}"
|
||||||
|
|
||||||
def get_messages_list(
|
async def get_messages_list(
|
||||||
self,
|
self,
|
||||||
max_results: int = 100,
|
max_results: int = 100,
|
||||||
query: str = "",
|
query: str = "",
|
||||||
|
@ -129,7 +163,7 @@ class GoogleGmailConnector:
|
||||||
Tuple containing (messages list, error message or None)
|
Tuple containing (messages list, error message or None)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
service = self._get_service()
|
service = await self._get_service()
|
||||||
|
|
||||||
# Build request parameters
|
# Build request parameters
|
||||||
request_params = {
|
request_params = {
|
||||||
|
@ -152,7 +186,9 @@ class GoogleGmailConnector:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return [], f"Error fetching messages list: {e!s}"
|
return [], f"Error fetching messages list: {e!s}"
|
||||||
|
|
||||||
def get_message_details(self, message_id: str) -> tuple[dict[str, Any], str | None]:
|
async def get_message_details(
|
||||||
|
self, message_id: str
|
||||||
|
) -> tuple[dict[str, Any], str | None]:
|
||||||
"""
|
"""
|
||||||
Fetch detailed information for a specific message.
|
Fetch detailed information for a specific message.
|
||||||
Args:
|
Args:
|
||||||
|
@ -161,7 +197,7 @@ class GoogleGmailConnector:
|
||||||
Tuple containing (message details dict, error message or None)
|
Tuple containing (message details dict, error message or None)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
service = self._get_service()
|
service = await self._get_service()
|
||||||
|
|
||||||
# Get full message details
|
# Get full message details
|
||||||
message = (
|
message = (
|
||||||
|
@ -176,7 +212,7 @@ class GoogleGmailConnector:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {}, f"Error fetching message details: {e!s}"
|
return {}, f"Error fetching message details: {e!s}"
|
||||||
|
|
||||||
def get_recent_messages(
|
async def get_recent_messages(
|
||||||
self,
|
self,
|
||||||
max_results: int = 50,
|
max_results: int = 50,
|
||||||
days_back: int = 30,
|
days_back: int = 30,
|
||||||
|
@ -198,7 +234,7 @@ class GoogleGmailConnector:
|
||||||
query = f"after:{date_query}"
|
query = f"after:{date_query}"
|
||||||
|
|
||||||
# Get messages list
|
# Get messages list
|
||||||
messages_list, error = self.get_messages_list(
|
messages_list, error = await self.get_messages_list(
|
||||||
max_results=max_results, query=query
|
max_results=max_results, query=query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -208,7 +244,9 @@ class GoogleGmailConnector:
|
||||||
# Get detailed information for each message
|
# Get detailed information for each message
|
||||||
detailed_messages = []
|
detailed_messages = []
|
||||||
for msg in messages_list:
|
for msg in messages_list:
|
||||||
message_details, detail_error = self.get_message_details(msg["id"])
|
message_details, detail_error = await self.get_message_details(
|
||||||
|
msg["id"]
|
||||||
|
)
|
||||||
if detail_error:
|
if detail_error:
|
||||||
continue # Skip messages that can't be fetched
|
continue # Skip messages that can't be fetched
|
||||||
detailed_messages.append(message_details)
|
detailed_messages.append(message_details)
|
||||||
|
|
|
@ -60,6 +60,27 @@ async def get_connector_by_id(
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_connector_by_type(
|
||||||
|
session: AsyncSession, connector_type: SearchSourceConnectorType
|
||||||
|
) -> SearchSourceConnector | None:
|
||||||
|
"""
|
||||||
|
Get a connector by type from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
connector_type: Type of the connector
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Connector object if found, None otherwise
|
||||||
|
"""
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.connector_type == connector_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
def calculate_date_range(
|
def calculate_date_range(
|
||||||
connector: SearchSourceConnector,
|
connector: SearchSourceConnector,
|
||||||
start_date: str | None = None,
|
start_date: str | None = None,
|
||||||
|
|
|
@ -129,7 +129,7 @@ async def index_google_gmail_messages(
|
||||||
|
|
||||||
# Fetch recent Google gmail messages
|
# Fetch recent Google gmail messages
|
||||||
logger.info(f"Fetching recent emails for connector {connector_id}")
|
logger.info(f"Fetching recent emails for connector {connector_id}")
|
||||||
messages, error = gmail_connector.get_recent_messages(
|
messages, error = await gmail_connector.get_recent_messages(
|
||||||
max_results=max_messages, days_back=days_back
|
max_results=max_messages, days_back=days_back
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue