mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 23:42:21 +00:00
feat: implement and test index happy path
This commit is contained in:
parent
579a9e2cb5
commit
497ed681d5
3 changed files with 100 additions and 5 deletions
|
|
@@ -1,9 +1,24 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.config import config
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
|
||||
from app.utils.document_converters import create_document_chunks, generate_document_summary
|
||||
|
||||
|
||||
def _safe_set_chunks(document: Document, chunks: list) -> None:
    """Attach *chunks* to *document* without emitting a lazy load.

    ``set_committed_value`` writes the relationship straight into the ORM
    state, so accessing ``document.chunks`` later will not trigger a SELECT.
    When the document is already attached to a session and has a primary key,
    each chunk is linked by foreign key and registered with that session.
    """
    set_committed_value(document, "chunks", chunks)

    owning_session = object_session(document)
    if owning_session is None:
        return
    if document.id is None:
        return

    for chunk in chunks:
        chunk.document_id = document.id
    owning_session.add_all(chunks)
|
||||
|
||||
|
||||
class IndexingPipelineService:
|
||||
|
|
@@ -54,3 +69,33 @@ class IndexingPipelineService:
|
|||
|
||||
await self.session.commit()
|
||||
return documents
|
||||
|
||||
    async def index(
        self, document: Document, connector_doc: ConnectorDocument, llm
    ) -> None:
        """Run the indexing pipeline for a single prepared document.

        Moves ``document`` through ``processing`` -> ``ready``, filling in the
        content (summary or raw markdown), its embedding, and the chunk rows.
        On any failure the session is rolled back and the document is marked
        ``failed`` with the error message, so a terminal status is always
        persisted.

        Args:
            document: Persisted ``Document`` row to populate.
            connector_doc: Source payload carrying the markdown and metadata.
            llm: Language-model handle passed through to the summarizer.
        """
        try:
            # Commit the "processing" state first so observers can see progress.
            document.status = DocumentStatus.processing()
            await self.session.commit()

            if connector_doc.should_summarize:
                # The summarizer returns both the summary text and its embedding.
                content, embedding = await generate_document_summary(
                    connector_doc.source_markdown, llm, connector_doc.metadata
                )
            else:
                # No summarization: index the raw markdown and embed it directly.
                content = connector_doc.source_markdown
                embedding = config.embedding_model_instance.embed(content)

            chunks = await create_document_chunks(connector_doc.source_markdown)

            document.source_markdown = connector_doc.source_markdown
            document.content = content
            document.embedding = embedding
            # Attach chunks without triggering a lazy load of document.chunks.
            _safe_set_chunks(document, chunks)
            document.status = DocumentStatus.ready()
            await self.session.commit()

        except Exception as e:
            # Drop the failed transaction, re-read the row's committed state,
            # then record the failure on the row itself so the error surfaces
            # in the DB instead of being silently swallowed.
            await self.session.rollback()
            await self.session.refresh(document)
            document.status = DocumentStatus.failed(str(e))
            await self.session.commit()
|
||||
|
|
|
|||
|
|
@@ -12,7 +12,7 @@ from sqlalchemy.pool import NullPool
|
|||
from app.db import Base, SearchSpace
|
||||
from app.db import User
|
||||
|
||||
_EMBEDDING_DIM = 4 # keep vectors tiny; real model uses 768+
|
||||
_EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation
|
||||
|
||||
_DEFAULT_TEST_DB = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
|
||||
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
|
||||
|
|
@@ -96,9 +96,33 @@ def mock_llm() -> AsyncMock:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_model() -> MagicMock:
|
||||
model = MagicMock()
|
||||
model.embed = MagicMock(
|
||||
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
|
||||
@pytest.fixture
def patched_generate_summary(monkeypatch) -> AsyncMock:
    """Replace ``generate_document_summary`` with a deterministic async stub.

    Returns the mock so tests can assert on call counts/arguments.

    NOTE(review): the ``@pytest.fixture`` decorator was absent on this
    definition while every sibling fixture has one; without it pytest treats
    this as a plain function and never injects ``monkeypatch``.
    """
    mock = AsyncMock(return_value=("Mocked summary.", [0.1] * _EMBEDDING_DIM))
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.generate_document_summary",
        mock,
    )
    return mock
|
||||
|
||||
|
||||
@pytest.fixture
def patched_create_chunks(monkeypatch) -> MagicMock:
    """Stub out chunking so tests receive one small, predictable chunk."""
    from app.db import Chunk

    single_chunk = Chunk(
        content="Test chunk content.", embedding=[0.1] * _EMBEDDING_DIM
    )
    stub = AsyncMock(return_value=[single_chunk])
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.create_document_chunks",
        stub,
    )
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def patched_embedding_model(monkeypatch) -> MagicMock:
    """Swap the real embedding model for a stub emitting fixed-size vectors."""
    from app.config import config

    stub_model = MagicMock()
    stub_model.embed = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
    monkeypatch.setattr(config, "embedding_model_instance", stub_model)
    return stub_model
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,26 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
# Mark every test in this module as an integration test (needs the test DB).
pytestmark = pytest.mark.integration
||||
|
||||
|
||||
async def test_sets_status_ready(
    db_session, db_search_space, make_connector_document,
    mock_llm, patched_generate_summary, patched_create_chunks,
):
    """Happy path: ``index()`` leaves the document in the READY state."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    created = await service.prepare_for_indexing([connector_doc])
    doc = created[0]
    doc_id = doc.id

    await service.index(doc, connector_doc, mock_llm)

    # Re-read the row from the DB so the assertion sees persisted state.
    query = select(Document).filter(Document.id == doc_id)
    reloaded = (await db_session.execute(query)).scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
Loading…
Add table
Add a link
Reference in a new issue