mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 23:42:21 +00:00
feat: implement and test prepare_for_indexing
This commit is contained in:
parent
a0134a5830
commit
579a9e2cb5
4 changed files with 243 additions and 3 deletions
|
|
@ -0,0 +1,56 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
|
||||
|
||||
|
||||
class IndexingPipelineService:
    """Deduplicates incoming connector documents against the database.

    Decides, per document, whether it is brand new, unchanged, or modified,
    and returns only the documents that still need (re)processing.
    """

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def prepare_for_indexing(
        self, connector_docs: list[ConnectorDocument]
    ) -> list[Document]:
        """Upsert ``connector_docs`` and return the ones needing indexing.

        For each incoming document:
        - unseen identifier -> insert a new PENDING row and return it
        - identical content -> skip it (absorbing a title-only rename in place)
        - changed content   -> update the row, reset it to PENDING, return it

        Commits the session once at the end.
        """
        to_process: list[Document] = []

        for incoming in connector_docs:
            identifier_hash = compute_unique_identifier_hash(incoming)
            body_hash = compute_content_hash(incoming)

            # NOTE(review): this per-item SELECT presumably triggers the
            # session's default autoflush, which is what collapses duplicates
            # appearing within the same batch — confirm autoflush is enabled.
            lookup = await self.session.execute(
                select(Document).filter(
                    Document.unique_identifier_hash == identifier_hash
                )
            )
            stored = lookup.scalars().first()

            if stored is None:
                # First time we see this identifier: create a pending row.
                fresh = Document(
                    title=incoming.title,
                    document_type=incoming.document_type,
                    content="Pending...",
                    content_hash=body_hash,
                    unique_identifier_hash=identifier_hash,
                    source_markdown=incoming.source_markdown,
                    document_metadata=incoming.metadata,
                    search_space_id=incoming.search_space_id,
                    connector_id=incoming.connector_id,
                    status=DocumentStatus.pending(),
                )
                self.session.add(fresh)
                to_process.append(fresh)
                continue

            if stored.content_hash == body_hash:
                # Content unchanged: quietly pick up a title-only rename,
                # but nothing needs reprocessing.
                if stored.title != incoming.title:
                    stored.title = incoming.title
                continue

            # Content changed: refresh the stored row and queue it again.
            stored.title = incoming.title
            stored.content_hash = body_hash
            stored.source_markdown = incoming.source_markdown
            stored.status = DocumentStatus.pending()
            to_process.append(stored)

        await self.session.commit()
        return to_process
|
||||
|
|
@ -78,6 +78,7 @@ dev = [
|
|||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
asyncio_default_test_loop_scope = "session"
|
||||
testpaths = ["tests"]
|
||||
markers = [
|
||||
"unit: pure logic tests, no DB or external services",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,7 +9,8 @@ from sqlalchemy import text
|
|||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.db import Base
|
||||
from app.db import Base, SearchSpace
|
||||
from app.db import User
|
||||
|
||||
_EMBEDDING_DIM = 4 # keep vectors tiny; real model uses 768+
|
||||
|
||||
|
|
@ -18,7 +20,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
|
|||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def async_engine():
|
||||
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool, echo=False)
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
# Required for asyncpg + savepoints: disables prepared statement cache
|
||||
# to prevent "another operation is in progress" errors during savepoint rollbacks.
|
||||
connect_args={"prepared_statement_cache_size": 0},
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
|
|
@ -26,8 +35,11 @@ async def async_engine():
|
|||
|
||||
yield engine
|
||||
|
||||
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
|
||||
# DROP SCHEMA CASCADE handles this without needing topological sort.
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.execute(text("DROP SCHEMA public CASCADE"))
|
||||
await conn.execute(text("CREATE SCHEMA public"))
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
|
@ -50,6 +62,32 @@ async def db_session(async_engine) -> AsyncSession:
|
|||
await transaction.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
    """An active, verified user flushed into the test session (not committed)."""
    account = User(
        id=uuid.uuid4(),
        email="test@surfsense.net",
        hashed_password="hashed",
        is_active=True,
        is_superuser=False,
        is_verified=True,
    )
    db_session.add(account)
    # Flush (not commit) so the user's id is usable for FK references
    # while staying inside the test's rollback boundary.
    await db_session.flush()
    return account
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
    """A search space owned by ``db_user``, flushed so its id is available."""
    owned_space = SearchSpace(name="Test Space", user_id=db_user.id)
    db_session.add(owned_space)
    await db_session.flush()
    return owned_space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> AsyncMock:
|
||||
llm = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_new_document_is_persisted_with_pending_status(
    db_session, db_search_space, make_connector_document
):
    """A never-seen connector document is stored and starts out PENDING."""
    incoming = make_connector_document(search_space_id=db_search_space.id)
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing([incoming])

    assert len(prepared) == 1

    # Reload from the database to prove the row was actually persisted.
    lookup = await db_session.execute(
        select(Document).filter(Document.id == prepared[0].id)
    )
    row = lookup.scalars().first()

    assert row is not None
    assert DocumentStatus.is_state(row.status, DocumentStatus.PENDING)
|
||||
|
||||
|
||||
async def test_unchanged_document_is_skipped(
    db_session, db_search_space, make_connector_document
):
    """Re-submitting an identical document yields nothing to process."""
    incoming = make_connector_document(search_space_id=db_search_space.id)
    pipeline = IndexingPipelineService(session=db_session)

    await pipeline.prepare_for_indexing([incoming])
    second_pass = await pipeline.prepare_for_indexing([incoming])

    assert second_pass == []
|
||||
|
||||
|
||||
async def test_title_only_change_updates_title_in_db(
    db_session, db_search_space, make_connector_document
):
    """A title-only rename is absorbed in place without queueing reprocessing."""
    pipeline = IndexingPipelineService(session=db_session)

    created = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, title="Original Title")]
    )
    stored_id = created[0].id

    outcome = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, title="Updated Title")]
    )

    # Nothing to reprocess, but the stored title must have been refreshed.
    assert outcome == []

    lookup = await db_session.execute(
        select(Document).filter(Document.id == stored_id)
    )
    row = lookup.scalars().first()

    assert row.title == "Updated Title"
|
||||
|
||||
|
||||
async def test_changed_content_is_returned_for_reprocessing(
    db_session, db_search_space, make_connector_document
):
    """Changed markdown updates the same row, resets PENDING, and returns it."""
    pipeline = IndexingPipelineService(session=db_session)

    first_pass = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")]
    )
    stored_id = first_pass[0].id

    second_pass = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")]
    )

    # Same row comes back for reprocessing — no duplicate was created.
    assert len(second_pass) == 1
    assert second_pass[0].id == stored_id

    lookup = await db_session.execute(
        select(Document).filter(Document.id == stored_id)
    )
    row = lookup.scalars().first()

    assert row.source_markdown == "## v2"
    assert DocumentStatus.is_state(row.status, DocumentStatus.PENDING)
|
||||
|
||||
|
||||
async def test_all_documents_in_batch_are_persisted(
    db_session, db_search_space, make_connector_document
):
    """Every distinct document in a batch ends up as its own row."""
    batch = [
        make_connector_document(
            search_space_id=db_search_space.id,
            unique_id=f"id-{i}",
            title=f"Doc {i}",
            source_markdown=f"## Content {i}",
        )
        for i in (1, 2, 3)
    ]
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing(batch)

    assert len(prepared) == 3

    lookup = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    rows = lookup.scalars().all()

    assert len(rows) == 3
|
||||
|
||||
|
||||
async def test_duplicate_in_batch_is_persisted_once(
    db_session, db_search_space, make_connector_document
):
    """The same document appearing twice in one batch produces a single row."""
    incoming = make_connector_document(search_space_id=db_search_space.id)
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing([incoming, incoming])

    assert len(prepared) == 1

    lookup = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    rows = lookup.scalars().all()

    assert len(rows) == 1
|
||||
|
||||
|
||||
async def test_title_and_content_change_updates_both_and_returns_document(
    db_session, db_search_space, make_connector_document
):
    """A combined title+content change updates both fields and queues the doc."""
    pipeline = IndexingPipelineService(session=db_session)

    first_pass = await pipeline.prepare_for_indexing(
        [
            make_connector_document(
                search_space_id=db_search_space.id,
                title="Original Title",
                source_markdown="## v1",
            )
        ]
    )
    stored_id = first_pass[0].id

    second_pass = await pipeline.prepare_for_indexing(
        [
            make_connector_document(
                search_space_id=db_search_space.id,
                title="Updated Title",
                source_markdown="## v2",
            )
        ]
    )

    # The existing row is returned for reprocessing, not duplicated.
    assert len(second_pass) == 1
    assert second_pass[0].id == stored_id

    lookup = await db_session.execute(
        select(Document).filter(Document.id == stored_id)
    )
    row = lookup.scalars().first()

    assert row.title == "Updated Title"
    assert row.source_markdown == "## v2"
|
||||
Loading…
Add table
Add a link
Reference in a new issue