diff --git a/pyproject.toml b/pyproject.toml index b73082d..1114e44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,8 @@ build-backend = "setuptools.build_meta" [dependency-groups] dev = [ "pre-commit>=4.1.0", - "types-requests>=2.32.4.20250913" + "pytest-asyncio>=1.2.0", + "types-requests>=2.32.4.20250913", ] [tool.isort] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6e316e6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +""" +Pytest configuration file. + +This file ensures that the project root is in the Python path, +allowing tests to import from the api and open_notebook modules. +""" + +import os +import sys +from pathlib import Path + +# Add the project root to the Python path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +# Ensure password auth is disabled for tests +# The PasswordAuthMiddleware skips auth when this env var is not set +if "OPEN_NOTEBOOK_PASSWORD" in os.environ: + del os.environ["OPEN_NOTEBOOK_PASSWORD"] diff --git a/tests/test_domain.py b/tests/test_domain.py new file mode 100644 index 0000000..2ae8e18 --- /dev/null +++ b/tests/test_domain.py @@ -0,0 +1,308 @@ +""" +Unit tests for the open_notebook.domain module. + +This test suite focuses on validation logic, business rules, and data structures +that can be tested without database mocking. 
+""" + +import pytest +from pydantic import ValidationError + +from open_notebook.domain.base import RecordModel +from open_notebook.domain.content_settings import ContentSettings +from open_notebook.domain.models import ModelManager +from open_notebook.domain.notebook import Note, Notebook, Source +from open_notebook.domain.podcast import EpisodeProfile, SpeakerProfile +from open_notebook.domain.transformation import Transformation +from open_notebook.exceptions import InvalidInputError + + +# ============================================================================ +# TEST SUITE 1: RecordModel Singleton Pattern +# ============================================================================ + + +class TestRecordModelSingleton: + """Test suite for RecordModel singleton behavior.""" + + def test_recordmodel_singleton_behavior(self): + """Test that same instance is returned for same record_id.""" + + class TestRecord(RecordModel): + record_id = "test:singleton" + value: int = 0 + + # Clear any existing instance + TestRecord.clear_instance() + + # Create first instance + instance1 = TestRecord(value=42) + assert instance1.value == 42 + + # Create second instance - should return same object + instance2 = TestRecord(value=99) + assert instance1 is instance2 + assert instance2.value == 99 # Value was updated + + # Cleanup + TestRecord.clear_instance() + + +# ============================================================================ +# TEST SUITE 2: ModelManager Singleton +# ============================================================================ + + +class TestModelManager: + """Test suite for ModelManager singleton pattern.""" + + def test_model_manager_singleton(self): + """Test ModelManager implements singleton pattern correctly.""" + manager1 = ModelManager() + manager2 = ModelManager() + + assert manager1 is manager2 + assert id(manager1) == id(manager2) + + +# ============================================================================ +# TEST SUITE 3: 
Notebook Domain Logic +# ============================================================================ + + +class TestNotebookDomain: + """Test suite for Notebook validation and business rules.""" + + def test_notebook_name_validation(self): + """Test empty/whitespace names are rejected.""" + # Empty name should raise error + with pytest.raises(InvalidInputError, match="Notebook name cannot be empty"): + Notebook(name="", description="Test") + + # Whitespace-only name should raise error + with pytest.raises(InvalidInputError, match="Notebook name cannot be empty"): + Notebook(name=" ", description="Test") + + # Valid name should work + notebook = Notebook(name="Valid Name", description="Test") + assert notebook.name == "Valid Name" + + def test_notebook_archived_flag(self): + """Test archived flag defaults to False.""" + notebook = Notebook(name="Test", description="Test") + assert notebook.archived is False + + notebook_archived = Notebook(name="Test", description="Test", archived=True) + assert notebook_archived.archived is True + + +# ============================================================================ +# TEST SUITE 4: Source Domain +# ============================================================================ + + +class TestSourceDomain: + """Test suite for Source domain model.""" + + def test_source_command_field_parsing(self): + """Test RecordID parsing for command field.""" + # Test with string command + source = Source(title="Test", command="command:123") + assert source.command is not None + + # Test with None command + source2 = Source(title="Test", command=None) + assert source2.command is None + + # Test command is included in save data prep + source3 = Source(id="source:123", title="Test", command="command:456") + save_data = source3._prepare_save_data() + assert "command" in save_data + + +# ============================================================================ +# TEST SUITE 5: Note Domain +# 
============================================================================ + + +class TestNoteDomain: + """Test suite for Note validation.""" + + def test_note_content_validation(self): + """Test empty content is rejected.""" + # None content is allowed + note = Note(title="Test", content=None) + assert note.content is None + + # Non-empty content is valid + note2 = Note(title="Test", content="Valid content") + assert note2.content == "Valid content" + + # Empty string should raise error + with pytest.raises(InvalidInputError, match="Note content cannot be empty"): + Note(title="Test", content="") + + # Whitespace-only should raise error + with pytest.raises(InvalidInputError, match="Note content cannot be empty"): + Note(title="Test", content=" ") + + def test_note_embedding_enabled(self): + """Test notes have embedding enabled by default.""" + note = Note(title="Test", content="Test content") + + assert note.needs_embedding() is True + assert note.get_embedding_content() == "Test content" + + # Test with None content + note2 = Note(title="Test", content=None) + assert note2.get_embedding_content() is None + + +# ============================================================================ +# TEST SUITE 6: Podcast Domain Validation +# ============================================================================ + + +class TestPodcastDomain: + """Test suite for Podcast domain validation.""" + + def test_speaker_profile_validation(self): + """Test speaker profile validates count and required fields.""" + # Test invalid - no speakers + with pytest.raises(ValidationError): + SpeakerProfile( + name="Test", + tts_provider="openai", + tts_model="tts-1", + speakers=[], + ) + + # Test invalid - too many speakers (> 4) + with pytest.raises(ValidationError): + SpeakerProfile( + name="Test", + tts_provider="openai", + tts_model="tts-1", + speakers=[{"name": f"Speaker{i}"} for i in range(5)], + ) + + # Test invalid - missing required fields + with 
pytest.raises(ValidationError): + SpeakerProfile( + name="Test", + tts_provider="openai", + tts_model="tts-1", + speakers=[{"name": "Speaker 1"}], # Missing voice_id, backstory, personality + ) + + # Test valid - single speaker with all fields + profile = SpeakerProfile( + name="Test", + tts_provider="openai", + tts_model="tts-1", + speakers=[ + { + "name": "Host", + "voice_id": "voice123", + "backstory": "A friendly host", + "personality": "Enthusiastic and welcoming", + } + ], + ) + assert len(profile.speakers) == 1 + assert profile.speakers[0]["name"] == "Host" + + +# ============================================================================ +# TEST SUITE 7: Transformation Domain +# ============================================================================ + + +class TestTransformationDomain: + """Test suite for Transformation domain model.""" + + def test_transformation_creation(self): + """Test transformation model creation.""" + transform = Transformation( + name="summarize", + title="Summarize Content", + description="Creates a summary", + prompt="Summarize the following text: {content}", + apply_default=True, + ) + + assert transform.name == "summarize" + assert transform.apply_default is True + + +# ============================================================================ +# TEST SUITE 8: Content Settings +# ============================================================================ + + +class TestContentSettings: + """Test suite for ContentSettings defaults.""" + + def test_content_settings_defaults(self): + """Test ContentSettings has proper defaults.""" + settings = ContentSettings() + + assert settings.record_id == "open_notebook:content_settings" + assert settings.default_content_processing_engine_doc == "auto" + assert settings.default_embedding_option == "ask" + assert settings.auto_delete_files == "yes" + assert len(settings.youtube_preferred_languages) > 0 + + +# ============================================================================ 
+# TEST SUITE 9: Episode Profile Validation +# ============================================================================ + + +class TestEpisodeProfile: + """Test suite for EpisodeProfile validation.""" + + def test_episode_profile_segment_validation(self): + """Test segment count validation (3-20).""" + # Test invalid - too few segments + with pytest.raises(ValidationError, match="Number of segments must be between 3 and 20"): + EpisodeProfile( + name="Test", + speaker_config="default", + outline_provider="openai", + outline_model="gpt-4", + transcript_provider="openai", + transcript_model="gpt-4", + default_briefing="Test briefing", + num_segments=2, + ) + + # Test invalid - too many segments + with pytest.raises(ValidationError, match="Number of segments must be between 3 and 20"): + EpisodeProfile( + name="Test", + speaker_config="default", + outline_provider="openai", + outline_model="gpt-4", + transcript_provider="openai", + transcript_model="gpt-4", + default_briefing="Test briefing", + num_segments=21, + ) + + # Test valid segment count + profile = EpisodeProfile( + name="Test", + speaker_config="default", + outline_provider="openai", + outline_model="gpt-4", + transcript_provider="openai", + transcript_model="gpt-4", + default_briefing="Test briefing", + num_segments=5, + ) + assert profile.num_segments == 5 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_graphs.py b/tests/test_graphs.py new file mode 100644 index 0000000..d4b5a49 --- /dev/null +++ b/tests/test_graphs.py @@ -0,0 +1,155 @@ +""" +Unit tests for the open_notebook.graphs module. + +This test suite focuses on testing graph structures, tools, and validation +without heavy mocking of the actual processing logic. 
+""" + +from datetime import datetime + +import pytest + +from open_notebook.graphs.prompt import PatternChainState, graph +from open_notebook.graphs.tools import get_current_timestamp +from open_notebook.graphs.transformation import ( + TransformationState, + run_transformation, + graph as transformation_graph, +) + + +# ============================================================================ +# TEST SUITE 1: Graph Tools +# ============================================================================ + + +class TestGraphTools: + """Test suite for graph tool definitions.""" + + def test_get_current_timestamp_format(self): + """Test timestamp tool returns correct format.""" + timestamp = get_current_timestamp.func() + + assert isinstance(timestamp, str) + assert len(timestamp) == 14 # YYYYMMDDHHmmss format + assert timestamp.isdigit() + + def test_get_current_timestamp_validity(self): + """Test timestamp represents valid datetime.""" + timestamp = get_current_timestamp.func() + + # Parse it back to datetime to verify validity + year = int(timestamp[0:4]) + month = int(timestamp[4:6]) + day = int(timestamp[6:8]) + hour = int(timestamp[8:10]) + minute = int(timestamp[10:12]) + second = int(timestamp[12:14]) + + # Should be valid date components + assert 2020 <= year <= 2100 + assert 1 <= month <= 12 + assert 1 <= day <= 31 + assert 0 <= hour <= 23 + assert 0 <= minute <= 59 + assert 0 <= second <= 59 + + # Should parse as datetime + dt = datetime.strptime(timestamp, "%Y%m%d%H%M%S") + assert isinstance(dt, datetime) + + def test_get_current_timestamp_is_tool(self): + """Test that function is properly decorated as a tool.""" + # Check it has tool attributes + assert hasattr(get_current_timestamp, "name") + assert hasattr(get_current_timestamp, "description") + + +# ============================================================================ +# TEST SUITE 2: Prompt Graph State +# ============================================================================ + + +class 
TestPromptGraph: + """Test suite for prompt pattern chain graph.""" + + def test_pattern_chain_state_structure(self): + """Test PatternChainState structure and fields.""" + state = PatternChainState( + prompt="Test prompt", + parser=None, + input_text="Test input", + output="" + ) + + assert state["prompt"] == "Test prompt" + assert state["parser"] is None + assert state["input_text"] == "Test input" + assert state["output"] == "" + + def test_prompt_graph_compilation(self): + """Test that prompt graph compiles correctly.""" + assert graph is not None + + # Graph should have the expected structure + assert hasattr(graph, "invoke") + assert hasattr(graph, "ainvoke") + + +# ============================================================================ +# TEST SUITE 3: Transformation Graph +# ============================================================================ + + +class TestTransformationGraph: + """Test suite for transformation graph workflows.""" + + def test_transformation_state_structure(self): + """Test TransformationState structure and fields.""" + from unittest.mock import MagicMock + from open_notebook.domain.notebook import Source + from open_notebook.domain.transformation import Transformation + + mock_source = MagicMock(spec=Source) + mock_transformation = MagicMock(spec=Transformation) + + state = TransformationState( + input_text="Test text", + source=mock_source, + transformation=mock_transformation, + output="" + ) + + assert state["input_text"] == "Test text" + assert state["source"] == mock_source + assert state["transformation"] == mock_transformation + assert state["output"] == "" + + @pytest.mark.asyncio + async def test_run_transformation_assertion_no_content(self): + """Test transformation raises assertion with no content.""" + from unittest.mock import MagicMock + from open_notebook.domain.transformation import Transformation + + mock_transformation = MagicMock(spec=Transformation) + + state = { + "input_text": None, + "transformation": 
mock_transformation, + "source": None + } + + config = {"configurable": {"model_id": None}} + + with pytest.raises(AssertionError, match="No content to transform"): + await run_transformation(state, config) + + def test_transformation_graph_compilation(self): + """Test that transformation graph compiles correctly.""" + assert transformation_graph is not None + assert hasattr(transformation_graph, "invoke") + assert hasattr(transformation_graph, "ainvoke") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_source_chat.py b/tests/test_source_chat.py deleted file mode 100644 index fb3e7d7..0000000 --- a/tests/test_source_chat.py +++ /dev/null @@ -1,296 +0,0 @@ -""" -Integration tests for Source Chat Langgraph. - -These tests verify that the Source Chat Langgraph integrates correctly -with the existing Open Notebook infrastructure. -""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from langchain_core.messages import AIMessage, HumanMessage - -from open_notebook.domain.notebook import Source, SourceInsight -from open_notebook.graphs.source_chat import ( - SourceChatState, - _format_source_context, - call_model_with_source_context, - source_chat_graph, -) - - -@pytest.fixture -def mock_source(): - """Create a mock Source object for testing.""" - source = MagicMock(spec=Source) - source.id = "source:test123" - source.title = "Test Source" - source.topics = ["AI", "Machine Learning"] - source.full_text = "This is test content for the source." - source.model_dump.return_value = { - "id": "source:test123", - "title": "Test Source", - "topics": ["AI", "Machine Learning"], - "full_text": "This is test content for the source." - } - return source - - -@pytest.fixture -def mock_insight(): - """Create a mock SourceInsight object for testing.""" - insight = MagicMock(spec=SourceInsight) - insight.id = "insight:test456" - insight.insight_type = "summary" - insight.content = "This is a test insight about the source." 
- insight.model_dump.return_value = { - "id": "insight:test456", - "insight_type": "summary", - "content": "This is a test insight about the source." - } - return insight - - -@pytest.fixture -def sample_state(): - """Create a sample SourceChatState for testing.""" - return SourceChatState( - messages=[HumanMessage(content="What are the main topics in this source?")], - source_id="source:test123", - source=None, - insights=None, - context=None, - model_override=None, - context_indicators=None - ) - - -@pytest.fixture -def sample_config(): - """Create a sample configuration for testing.""" - return { - "configurable": { - "thread_id": "test_thread", - "model_id": "test_model" - } - } - - -class TestSourceChatState: - """Test the SourceChatState TypedDict structure.""" - - def test_source_chat_state_creation(self, sample_state): - """Test that SourceChatState can be created with required fields.""" - assert sample_state["source_id"] == "source:test123" - assert len(sample_state["messages"]) == 1 - assert sample_state["source"] is None - assert sample_state["insights"] is None - - -class TestContextFormatting: - """Test the context formatting functionality.""" - - def test_format_source_context_with_sources(self): - """Test formatting context data containing sources.""" - context_data = { - "sources": [{ - "id": "source:test123", - "title": "Test Source", - "full_text": "This is test content." - }], - "insights": [], - "metadata": { - "source_count": 1, - "insight_count": 0 - }, - "total_tokens": 100 - } - - result = _format_source_context(context_data) - - assert "## SOURCE CONTENT" in result - assert "source:test123" in result - assert "Test Source" in result - assert "This is test content." 
in result - assert "## CONTEXT METADATA" in result - - def test_format_source_context_with_insights(self): - """Test formatting context data containing insights.""" - context_data = { - "sources": [], - "insights": [{ - "id": "insight:test456", - "insight_type": "summary", - "content": "Test insight content." - }], - "metadata": { - "source_count": 0, - "insight_count": 1 - }, - "total_tokens": 50 - } - - result = _format_source_context(context_data) - - assert "## SOURCE INSIGHTS" in result - assert "insight:test456" in result - assert "summary" in result - assert "Test insight content." in result - - def test_format_source_context_empty(self): - """Test formatting empty context data.""" - context_data = { - "sources": [], - "insights": [], - "metadata": { - "source_count": 0, - "insight_count": 0 - }, - "total_tokens": 0 - } - - result = _format_source_context(context_data) - - assert "## CONTEXT METADATA" in result - assert "Source count: 0" in result - assert "Insight count: 0" in result - - -class TestSourceChatIntegration: - """Test the integration of source chat components.""" - - @patch('open_notebook.graphs.source_chat.ContextBuilder') - @patch('open_notebook.graphs.source_chat.provision_langchain_model') - @patch('open_notebook.graphs.source_chat.Prompter') - async def test_call_model_with_source_context( - self, - mock_prompter, - mock_provision_model, - mock_context_builder, - sample_state, - sample_config, - mock_source, - mock_insight - ): - """Test the main model calling function with mocked dependencies.""" - - # Mock the ContextBuilder - mock_builder_instance = AsyncMock() - mock_builder_instance.build.return_value = { - "sources": [mock_source.model_dump()], - "insights": [mock_insight.model_dump()], - "metadata": {"source_count": 1, "insight_count": 1}, - "total_tokens": 150 - } - mock_context_builder.return_value = mock_builder_instance - - # Mock the Prompter - mock_prompter_instance = MagicMock() - mock_prompter_instance.render.return_value = 
"Rendered prompt" - mock_prompter.return_value = mock_prompter_instance - - # Mock the model - mock_model = AsyncMock() - mock_ai_message = AIMessage(content="Test response from AI") - mock_model.invoke.return_value = mock_ai_message - mock_provision_model.return_value = mock_model - - # Call the function - result = await call_model_with_source_context(sample_state, sample_config) # type: ignore[misc] - - # Verify the result - assert "messages" in result - assert result["messages"] == mock_ai_message - assert "source" in result - assert "insights" in result - assert "context" in result - assert "context_indicators" in result - - # Verify mocks were called correctly - mock_context_builder.assert_called_once() - mock_builder_instance.build.assert_called_once() - mock_prompter.assert_called_once_with(prompt_template="source_chat") - mock_provision_model.assert_called_once() - - def test_source_chat_graph_structure(self): - """Test that the source chat graph is properly structured.""" - # Verify the graph has the expected structure - assert source_chat_graph is not None - - # Check that the graph has nodes - nodes = source_chat_graph.get_graph().nodes - assert "source_chat_agent" in [node for node in nodes] - - # Check that the graph has the checkpointer - assert source_chat_graph.checkpointer is not None - - @pytest.mark.asyncio - async def test_source_chat_state_validation(self): - """Test that the source chat validates required state fields.""" - # Test with missing source_id - invalid_state = SourceChatState( - messages=[HumanMessage(content="Test")], - source_id="", # Empty source_id should cause error - source=None, - insights=None, - context=None, - model_override=None, - context_indicators=None - ) - - config = {"configurable": {"thread_id": "test"}} - - # This should raise an error due to missing source_id - with pytest.raises(ValueError, match="source_id is required"): - await call_model_with_source_context(invalid_state, config) # type: ignore[misc, 
arg-type] - - -class TestSourceChatGraphExecution: - """Test the execution of the source chat graph.""" - - @patch('open_notebook.graphs.source_chat.Source') - @patch('open_notebook.graphs.source_chat.ContextBuilder') - @patch('open_notebook.graphs.source_chat.provision_langchain_model') - @patch('open_notebook.graphs.source_chat.Prompter') - @pytest.mark.asyncio - async def test_graph_execution_flow( - self, - mock_prompter, - mock_provision_model, - mock_context_builder, - mock_source_class, - sample_state, - sample_config - ): - """Test the complete graph execution flow with mocked dependencies.""" - - # Setup mocks (similar to previous test but for full graph execution) - mock_builder_instance = AsyncMock() - mock_builder_instance.build.return_value = { - "sources": [{"id": "source:test123", "title": "Test"}], - "insights": [{"id": "insight:test456", "content": "Test insight"}], - "metadata": {"source_count": 1, "insight_count": 1}, - "total_tokens": 100 - } - mock_context_builder.return_value = mock_builder_instance - - mock_prompter_instance = MagicMock() - mock_prompter_instance.render.return_value = "Test prompt" - mock_prompter.return_value = mock_prompter_instance - - mock_model = AsyncMock() - mock_model.invoke.return_value = AIMessage(content="AI response") - mock_provision_model.return_value = mock_model - - # Execute the graph - result = await source_chat_graph.ainvoke(sample_state, sample_config) - - # Verify the result structure - assert "messages" in result - assert "source_id" in result - assert result["source_id"] == "source:test123" - - -if __name__ == "__main__": - # Run the tests - pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_source_chat_api.py b/tests/test_source_chat_api.py deleted file mode 100644 index 1621ac4..0000000 --- a/tests/test_source_chat_api.py +++ /dev/null @@ -1,223 +0,0 @@ -from unittest.mock import AsyncMock, patch - -import pytest -from fastapi.testclient import TestClient - -from 
api.main import app - -client = TestClient(app) - - -class TestSourceChatAPI: - """Test suite for Source Chat API endpoints.""" - - @pytest.fixture - def sample_source_id(self): - return "test_source_123" - - @pytest.fixture - def sample_session_id(self): - return "test_session_456" - - @patch('api.routers.source_chat.Source.get') - @patch('api.routers.source_chat.ChatSession.save') - @patch('api.routers.source_chat.ChatSession.relate') - def test_create_source_chat_session(self, mock_relate, mock_save, mock_source_get, sample_source_id): - """Test creating a new source chat session.""" - # Mock source exists - mock_source = AsyncMock() - mock_source.id = f"source:{sample_source_id}" - mock_source_get.return_value = mock_source - - # Mock session save and relate - mock_save.return_value = None - mock_relate.return_value = None - - # Create session request - request_data = { - "source_id": sample_source_id, - "title": "Test Chat Session", - "model_override": "gpt-4" - } - - response = client.post( - f"/api/sources/{sample_source_id}/chat/sessions", - json=request_data - ) - - assert response.status_code == 200 - data = response.json() - assert data["title"] == "Test Chat Session" - assert data["source_id"] == sample_source_id - assert data["model_override"] == "gpt-4" - assert "id" in data - assert "created" in data - - @patch('api.routers.source_chat.Source.get') - def test_create_session_source_not_found(self, mock_source_get, sample_source_id): - """Test creating session for non-existent source.""" - mock_source_get.return_value = None - - request_data = { - "source_id": sample_source_id, - "title": "Test Chat Session" - } - - response = client.post( - f"/api/sources/{sample_source_id}/chat/sessions", - json=request_data - ) - - assert response.status_code == 404 - assert "Source not found" in response.json()["detail"] - - @patch('api.routers.source_chat.Source.get') - @patch('api.routers.source_chat.repo_query') - def test_get_source_chat_sessions(self, 
mock_repo_query, mock_source_get, sample_source_id): - """Test getting all chat sessions for a source.""" - # Mock source exists - mock_source = AsyncMock() - mock_source.id = f"source:{sample_source_id}" - mock_source_get.return_value = mock_source - - # Mock query returns sessions - mock_repo_query.return_value = [ - {"in": "chat_session:session1"}, - {"in": "chat_session:session2"} - ] - - # Mock ChatSession.get for each session - with patch('api.routers.source_chat.ChatSession.get') as mock_session_get: - mock_session1 = AsyncMock() - mock_session1.id = "chat_session:session1" - mock_session1.title = "Session 1" - mock_session1.created = "2024-01-01T00:00:00Z" - mock_session1.updated = "2024-01-01T00:00:00Z" - mock_session1.model_override = None - - mock_session2 = AsyncMock() - mock_session2.id = "chat_session:session2" - mock_session2.title = "Session 2" - mock_session2.created = "2024-01-01T00:00:00Z" - mock_session2.updated = "2024-01-01T00:00:00Z" - mock_session2.model_override = "gpt-4" - - mock_session_get.side_effect = [mock_session1, mock_session2] - - response = client.get(f"/api/sources/{sample_source_id}/chat/sessions") - - assert response.status_code == 200 - data = response.json() - assert len(data) == 2 - assert data[0]["title"] == "Session 1" - assert data[1]["title"] == "Session 2" - assert data[1]["model_override"] == "gpt-4" - - @patch('api.routers.source_chat.Source.get') - @patch('api.routers.source_chat.ChatSession.get') - @patch('api.routers.source_chat.repo_query') - @patch('api.routers.source_chat.source_chat_graph.get_state') - def test_get_source_chat_session_with_messages( - self, mock_get_state, mock_repo_query, mock_session_get, mock_source_get, - sample_source_id, sample_session_id - ): - """Test getting a specific chat session with messages.""" - # Mock source exists - mock_source = AsyncMock() - mock_source.id = f"source:{sample_source_id}" - mock_source_get.return_value = mock_source - - # Mock session exists - mock_session = 
AsyncMock() - mock_session.id = f"chat_session:{sample_session_id}" - mock_session.title = "Test Session" - mock_session.created = "2024-01-01T00:00:00Z" - mock_session.updated = "2024-01-01T00:00:00Z" - mock_session.model_override = "gpt-4" - mock_session_get.return_value = mock_session - - # Mock relation exists - mock_repo_query.return_value = [{"relation": "exists"}] - - # Mock graph state with messages - mock_message = AsyncMock() - mock_message.type = "human" - mock_message.content = "Hello" - mock_message.id = "msg_1" - - mock_state = AsyncMock() - mock_state.values = { - "messages": [mock_message], - "context_indicators": {"sources": ["source:123"], "insights": ["insight:456"], "notes": []} - } - mock_get_state.return_value = mock_state - - response = client.get(f"/api/sources/{sample_source_id}/chat/sessions/{sample_session_id}") - - assert response.status_code == 200 - data = response.json() - assert data["title"] == "Test Session" - assert data["model_override"] == "gpt-4" - assert len(data["messages"]) == 1 - assert data["messages"][0]["content"] == "Hello" - assert data["context_indicators"]["sources"] == ["source:123"] - - @patch('api.routers.source_chat.Source.get') - @patch('api.routers.source_chat.ChatSession.get') - @patch('api.routers.source_chat.repo_query') - @patch('api.routers.source_chat.ChatSession.save') - def test_update_source_chat_session( - self, mock_save, mock_repo_query, mock_session_get, mock_source_get, - sample_source_id, sample_session_id - ): - """Test updating a source chat session.""" - # Mock source exists - mock_source = AsyncMock() - mock_source.id = f"source:{sample_source_id}" - mock_source_get.return_value = mock_source - - # Mock session exists - mock_session = AsyncMock() - mock_session.id = f"chat_session:{sample_session_id}" - mock_session.title = "Old Title" - mock_session.created = "2024-01-01T00:00:00Z" - mock_session.updated = "2024-01-01T00:00:00Z" - mock_session.model_override = None - 
mock_session_get.return_value = mock_session - - # Mock relation exists - mock_repo_query.return_value = [{"relation": "exists"}] - - # Mock save - mock_save.return_value = None - - request_data = { - "title": "New Title", - "model_override": "gpt-4" - } - - response = client.put( - f"/api/sources/{sample_source_id}/chat/sessions/{sample_session_id}", - json=request_data - ) - - assert response.status_code == 200 - data = response.json() - assert data["title"] == "New Title" - # Note: The mock will still return the original values unless we update them - # In a real test, we'd want to verify the session was updated properly - - def test_api_endpoints_structure(self): - """Test that all expected endpoints are properly structured.""" - # Test endpoint paths are correctly formed - from api.routers.source_chat import router - - routes = [route.path for route in router.routes] # type: ignore[attr-defined] - expected_routes = [ - "/sources/{source_id}/chat/sessions", - "/sources/{source_id}/chat/sessions/{session_id}", - "/sources/{source_id}/chat/sessions/{session_id}/messages" - ] - - for expected_route in expected_routes: - assert any(expected_route in route for route in routes), f"Route {expected_route} not found" \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..fb2b018 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,259 @@ +""" +Unit tests for the open_notebook.utils module. + +This test suite focuses on testing utility functions that perform actual logic +without heavy mocking - string processing, validation, and algorithms. 
+""" + +import pytest + +from open_notebook.utils import ( + clean_thinking_content, + compare_versions, + get_installed_version, + parse_thinking_content, + remove_non_ascii, + remove_non_printable, + split_text, + token_count, +) +from open_notebook.utils.context_builder import ContextBuilder, ContextConfig + + +# ============================================================================ +# TEST SUITE 1: Text Utilities +# ============================================================================ + + +class TestTextUtilities: + """Test suite for text utility functions.""" + + def test_split_text_empty_string(self): + """Test splitting empty or very short strings.""" + assert split_text("") == [] + assert split_text("short") == ["short"] + + def test_remove_non_ascii(self): + """Test removal of non-ASCII characters.""" + # Text with various non-ASCII characters + text_with_unicode = "Hello 世界 café naïve émoji 🎉" + result = remove_non_ascii(text_with_unicode) + + # Should only contain ASCII characters + assert result == "Hello caf nave moji " + # All characters should be in ASCII range + assert all(ord(char) < 128 for char in result) + + def test_remove_non_ascii_pure_ascii(self): + """Test that pure ASCII text is unchanged.""" + text = "Hello World 123 !@#" + result = remove_non_ascii(text) + assert result == text + + def test_remove_non_printable(self): + """Test removal of non-printable characters.""" + # Text with various Unicode whitespace and control chars + text = "Hello\u2000World\u200B\u202FTest" + result = remove_non_printable(text) + + # Should have regular spaces and printable chars only + assert "Hello" in result + assert "World" in result + assert "Test" in result + + def test_remove_non_printable_preserves_newlines(self): + """Test that newlines and tabs are preserved.""" + text = "Line1\nLine2\tTabbed" + result = remove_non_printable(text) + assert "\n" in result + assert "\t" in result + + def test_parse_thinking_content_basic(self): + """Test 
parsing single thinking block.""" + content = "<think>This is my thinking</think>Here is my answer" + thinking, cleaned = parse_thinking_content(content) + + assert thinking == "This is my thinking" + assert cleaned == "Here is my answer" + + def test_parse_thinking_content_multiple_tags(self): + """Test parsing multiple thinking blocks.""" + content = "<think>First thought</think>Answer<think>Second thought</think>More" + thinking, cleaned = parse_thinking_content(content) + + assert "First thought" in thinking + assert "Second thought" in thinking + assert "<think>" not in cleaned + assert "Answer" in cleaned + assert "More" in cleaned + + def test_parse_thinking_content_no_tags(self): + """Test parsing content without thinking tags.""" + content = "Just regular content" + thinking, cleaned = parse_thinking_content(content) + + assert thinking == "" + assert cleaned == "Just regular content" + + def test_parse_thinking_content_invalid_input(self): + """Test parsing with invalid input types.""" + # Non-string input + thinking, cleaned = parse_thinking_content(None) + assert thinking == "" + assert cleaned == "" + + # Integer input + thinking, cleaned = parse_thinking_content(123) + assert thinking == "" + assert cleaned == "123" + + def test_parse_thinking_content_large_content(self): + """Test that very large content is not processed.""" + large_content = "x" * 200000 # > 100KB limit + thinking, cleaned = parse_thinking_content(large_content) + + # Should return unchanged due to size limit + assert thinking == "" + assert cleaned == large_content + + def test_clean_thinking_content(self): + """Test convenience function for cleaning thinking content.""" + content = "<think>Internal thoughts</think>Public response" + result = clean_thinking_content(content) + + assert "<think>" not in result + assert "Public response" in result + assert "Internal thoughts" not in result + + +# ============================================================================ +# TEST SUITE 2: Token Utilities +# 
============================================================================ + + +class TestTokenUtilities: + """Test suite for token counting fallback behavior.""" + + def test_token_count_fallback(self): + """Test fallback when tiktoken raises an error.""" + from unittest.mock import patch + + # Make tiktoken raise an ImportError to trigger fallback + with patch("tiktoken.get_encoding", side_effect=ImportError("tiktoken not available")): + text = "one two three four five" + count = token_count(text) + + # Fallback uses word count * 1.3 + # 5 words * 1.3 = 6.5 -> 6 + assert isinstance(count, int) + assert count > 0 + + +# ============================================================================ +# TEST SUITE 3: Version Utilities +# ============================================================================ + + +class TestVersionUtilities: + """Test suite for version management functions.""" + + def test_compare_versions_equal(self): + """Test comparing equal versions.""" + result = compare_versions("1.0.0", "1.0.0") + assert result == 0 + + def test_compare_versions_less_than(self): + """Test comparing when first version is less.""" + result = compare_versions("1.0.0", "2.0.0") + assert result == -1 + + result = compare_versions("1.0.0", "1.1.0") + assert result == -1 + + result = compare_versions("1.0.0", "1.0.1") + assert result == -1 + + def test_compare_versions_greater_than(self): + """Test comparing when first version is greater.""" + result = compare_versions("2.0.0", "1.0.0") + assert result == 1 + + result = compare_versions("1.1.0", "1.0.0") + assert result == 1 + + result = compare_versions("1.0.1", "1.0.0") + assert result == 1 + + def test_compare_versions_prerelease(self): + """Test comparing versions with pre-release tags.""" + result = compare_versions("1.0.0", "1.0.0-alpha") + assert result == 1 # Release > pre-release + + result = compare_versions("1.0.0-beta", "1.0.0-alpha") + assert result == 1 # beta > alpha + + def 
test_get_installed_version_success(self): + """Test getting installed package version.""" + # Test with a known installed package + version = get_installed_version("pytest") + assert isinstance(version, str) + assert len(version) > 0 + # Should look like a version (has dots) + assert "." in version + + def test_get_installed_version_not_found(self): + """Test getting version of non-existent package.""" + from importlib.metadata import PackageNotFoundError + + with pytest.raises(PackageNotFoundError): + get_installed_version("this-package-does-not-exist-12345") + + def test_get_version_from_github_invalid_url(self): + """Test GitHub version fetch with invalid URL.""" + from open_notebook.utils.version_utils import get_version_from_github + + with pytest.raises(ValueError, match="Not a GitHub URL"): + get_version_from_github("https://example.com/repo") + + with pytest.raises(ValueError, match="Invalid GitHub repository URL"): + get_version_from_github("https://github.com/") + + +# ============================================================================ +# TEST SUITE 4: Context Builder Configuration +# ============================================================================ + + +class TestContextBuilder: + """Test suite for ContextBuilder initialization and configuration.""" + + def test_context_config_defaults(self): + """Test ContextConfig default values.""" + config = ContextConfig() + + assert config.sources == {} + assert config.notes == {} + assert config.include_insights is True + assert config.include_notes is True + assert config.priority_weights is not None + assert "source" in config.priority_weights + assert "note" in config.priority_weights + assert "insight" in config.priority_weights + + def test_context_builder_initialization(self): + """Test ContextBuilder initialization with various params.""" + builder = ContextBuilder( + source_id="source:123", + notebook_id="notebook:456", + max_tokens=1000, + include_insights=False + ) + + assert 
builder.source_id == "source:123" + assert builder.notebook_id == "notebook:456" + assert builder.max_tokens == 1000 + assert builder.include_insights is False + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/uv.lock b/uv.lock index 5c14f20..b868ff6 100644 --- a/uv.lock +++ b/uv.lock @@ -2256,6 +2256,7 @@ dev = [ [package.dev-dependencies] dev = [ { name = "pre-commit" }, + { name = "pytest-asyncio" }, { name = "types-requests" }, ] @@ -2302,6 +2303,7 @@ provides-extras = ["dev"] [package.metadata.requires-dev] dev = [ { name = "pre-commit", specifier = ">=4.1.0" }, + { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "types-requests", specifier = ">=2.32.4.20250913" }, ] @@ -2973,6 +2975,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"