mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 05:51:32 +00:00
113 lines
3.3 KiB
Python
113 lines
3.3 KiB
Python
from uuid import UUID
|
|
|
|
import httpx
|
|
import pytest
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import config
|
|
from app.db import (
|
|
SearchSourceConnector,
|
|
SearchSourceConnectorType,
|
|
SearchSpace,
|
|
User,
|
|
)
|
|
from app.utils.oauth_security import OAuthStateManager
|
|
|
|
# Module-level marker: pytest applies ``integration`` to every test in this file.
pytestmark = pytest.mark.integration
|
|
|
|
|
|
def _state_for(space_id: int, user_id: UUID, toolkit_id: str = "googledrive") -> str:
    """Generate an OAuth ``state`` value for the callback requests in this module.

    The value comes from ``OAuthStateManager.generate_secure_state``, with the
    manager constructed from ``config.SECRET_KEY``.

    Args:
        space_id: Search space id to embed in the state.
        user_id: User id to embed in the state.
        toolkit_id: Toolkit identifier to embed; defaults to ``"googledrive"``.

    Returns:
        The state string produced by the state manager.
    """
    state_manager = OAuthStateManager(config.SECRET_KEY)
    return state_manager.generate_secure_state(
        space_id=space_id,
        user_id=user_id,
        toolkit_id=toolkit_id,
    )
|
|
|
|
|
|
async def _drive_connectors(
    session: AsyncSession,
    *,
    user_id: UUID,
    search_space_id: int,
) -> list[SearchSourceConnector]:
    """Return the Composio Google Drive connectors for one user and search space.

    Args:
        session: Async session used to run the query.
        user_id: Owner whose connectors are selected.
        search_space_id: Search space the connectors must belong to.

    Returns:
        All ``SearchSourceConnector`` rows matching the user, the search space
        and the ``COMPOSIO_GOOGLE_DRIVE_CONNECTOR`` type, as a list.
    """
    drive_type = SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
    query = select(SearchSourceConnector).where(
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.connector_type == drive_type,
    )
    rows = await session.execute(query)
    return list(rows.scalars().all())
|
|
|
|
|
|
async def test_callback_with_error_param_redirects_to_denied_page(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A callback carrying an ``error`` parameter redirects and stores nothing.

    Calls the Composio connector callback with a state value plus
    ``error=access_denied`` and checks that:

    * the response is a redirect (302, 303 or 307),
    * the ``Location`` header points at the search space's connector callback
      page with ``error=composio_oauth_denied``,
    * no Composio Google Drive connector exists for the user and search space.
    """
    state = _state_for(db_search_space.id, db_user.id)

    # Send the query through ``params`` so httpx percent-encodes each value;
    # interpolating the state into the URL relies on it being URL-safe as-is.
    response = await client.get(
        "/api/v1/auth/composio/connector/callback",
        params={"state": state, "error": "access_denied"},
    )

    assert response.status_code in {302, 303, 307}
    location = response.headers["location"]
    assert (
        f"/dashboard/{db_search_space.id}/connectors/callback?"
        "error=composio_oauth_denied"
    ) in location

    # The denied flow must not have created a connector row.
    connectors = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert connectors == []
|
|
|
|
|
|
async def test_second_oauth_for_same_toolkit_takes_reconnection_branch(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A second OAuth callback for the same toolkit reuses the existing connector.

    Runs the callback twice for the same user, search space and toolkit with
    different ``connectedAccountId`` values and checks that:

    * each call answers with a redirect (302, 303 or 307),
    * exactly one Composio Google Drive connector exists after each call,
    * the connector after the second call has the same id as after the first,
    * its ``config["composio_connected_account_id"]`` holds the second
      account id.
    """
    callback_path = "/api/v1/auth/composio/connector/callback"

    # Query values go through ``params`` so httpx percent-encodes them rather
    # than relying on the state string being URL-safe when interpolated.
    first_state = _state_for(db_search_space.id, db_user.id)
    first_response = await client.get(
        callback_path,
        params={
            "state": first_state,
            "connectedAccountId": "fake-acct-googledrive-first",
        },
    )

    assert first_response.status_code in {302, 303, 307}
    first_connectors = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert len(first_connectors) == 1
    first_connector = first_connectors[0]
    assert first_connector.config["composio_connected_account_id"] == (
        "fake-acct-googledrive-first"
    )
    # Capture the primary key now: the later identity check should not depend
    # on this ORM instance's state after the second callback has written.
    first_connector_id = first_connector.id

    second_state = _state_for(db_search_space.id, db_user.id)
    second_response = await client.get(
        callback_path,
        params={
            "state": second_state,
            "connectedAccountId": "fake-acct-googledrive-second",
        },
    )

    assert second_response.status_code in {302, 303, 307}
    second_connectors = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    # Still one row, the same row, now pointing at the second account.
    assert len(second_connectors) == 1
    assert second_connectors[0].id == first_connector_id
    assert second_connectors[0].config["composio_connected_account_id"] == (
        "fake-acct-googledrive-second"
    )
|