diff --git a/pyproject.toml b/pyproject.toml
index b73082d..1114e44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,8 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"pre-commit>=4.1.0",
- "types-requests>=2.32.4.20250913"
+ "pytest-asyncio>=1.2.0",
+ "types-requests>=2.32.4.20250913",
]
 
[tool.isort]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..6e316e6
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,20 @@
+"""
+Pytest configuration file.
+
+This file ensures that the project root is in the Python path,
+allowing tests to import from the api and open_notebook modules.
+"""
+
+import os
+import sys
+from pathlib import Path
+
+# Add the project root to the Python path
+project_root = Path(__file__).parent.parent
+sys.path.insert(0, str(project_root))
+
+# Ensure password auth is disabled for tests
+# The PasswordAuthMiddleware skips auth when this env var is not set
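+# conftest.py is imported before any test module, so this runs before the API app is created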
+if "OPEN_NOTEBOOK_PASSWORD" in os.environ:
+ del os.environ["OPEN_NOTEBOOK_PASSWORD"]
diff --git a/tests/test_domain.py b/tests/test_domain.py
new file mode 100644
index 0000000..2ae8e18
--- /dev/null
+++ b/tests/test_domain.py
@@ -0,0 +1,319 @@
+"""
+Unit tests for the open_notebook.domain module.
+
+This test suite focuses on validation logic, business rules, and data structures
+that can be tested without database mocking.
+"""
+
+import pytest
+from pydantic import ValidationError
+
+from open_notebook.domain.base import RecordModel
+from open_notebook.domain.content_settings import ContentSettings
+from open_notebook.domain.models import ModelManager
+from open_notebook.domain.notebook import Note, Notebook, Source
+from open_notebook.domain.podcast import EpisodeProfile, SpeakerProfile
+from open_notebook.domain.transformation import Transformation
+from open_notebook.exceptions import InvalidInputError
+
+
+# ============================================================================
+# TEST SUITE 1: RecordModel Singleton Pattern
+# ============================================================================
+
+
+class TestRecordModelSingleton:
+ """Test suite for RecordModel singleton behavior."""
+
+ def test_recordmodel_singleton_behavior(self):
+ """Test that same instance is returned for same record_id."""
+
+ class TestRecord(RecordModel):
+ record_id = "test:singleton"
+ value: int = 0
+
+ # Clear any existing instance
+ TestRecord.clear_instance()
+
+ # Create first instance
+ instance1 = TestRecord(value=42)
+ assert instance1.value == 42
+
+ # Create second instance - should return same object
+ instance2 = TestRecord(value=99)
+ assert instance1 is instance2
+ assert instance2.value == 99 # Value was updated
+
+ # Cleanup
+ TestRecord.clear_instance()
+
+
+# ============================================================================
+# TEST SUITE 2: ModelManager Singleton
+# ============================================================================
+
+
+class TestModelManager:
+ """Test suite for ModelManager singleton pattern."""
+
+ def test_model_manager_singleton(self):
+ """Test ModelManager implements singleton pattern correctly."""
+ manager1 = ModelManager()
+ manager2 = ModelManager()
+
+ assert manager1 is manager2
+ assert id(manager1) == id(manager2)
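+
+    # A minimal sketch of the pattern under test (an assumed shape for
+    # illustration, not the actual ModelManager implementation):
+    #
+    #   class Singleton:
+    #       _instance = None
+    #
+    #       def __new__(cls):
+    #           if cls._instance is None:
+    #               cls._instance = super().__new__(cls)
+    #           return cls._instance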
+
+
+# ============================================================================
+# TEST SUITE 3: Notebook Domain Logic
+# ============================================================================
+
+
+class TestNotebookDomain:
+ """Test suite for Notebook validation and business rules."""
+
+ def test_notebook_name_validation(self):
+ """Test empty/whitespace names are rejected."""
+ # Empty name should raise error
+ with pytest.raises(InvalidInputError, match="Notebook name cannot be empty"):
+ Notebook(name="", description="Test")
+
+ # Whitespace-only name should raise error
+ with pytest.raises(InvalidInputError, match="Notebook name cannot be empty"):
+ Notebook(name=" ", description="Test")
+
+ # Valid name should work
+ notebook = Notebook(name="Valid Name", description="Test")
+ assert notebook.name == "Valid Name"
+
+ def test_notebook_archived_flag(self):
+ """Test archived flag defaults to False."""
+ notebook = Notebook(name="Test", description="Test")
+ assert notebook.archived is False
+
+ notebook_archived = Notebook(name="Test", description="Test", archived=True)
+ assert notebook_archived.archived is True
+
+
+# ============================================================================
+# TEST SUITE 4: Source Domain
+# ============================================================================
+
+
+class TestSourceDomain:
+ """Test suite for Source domain model."""
+
+ def test_source_command_field_parsing(self):
+ """Test RecordID parsing for command field."""
+ # Test with string command
+ source = Source(title="Test", command="command:123")
+ assert source.command is not None
+
+ # Test with None command
+ source2 = Source(title="Test", command=None)
+ assert source2.command is None
+
+ # Test command is included in save data prep
+ source3 = Source(id="source:123", title="Test", command="command:456")
+ save_data = source3._prepare_save_data()
+ assert "command" in save_data
+
+
+# ============================================================================
+# TEST SUITE 5: Note Domain
+# ============================================================================
+
+
+class TestNoteDomain:
+ """Test suite for Note validation."""
+
+ def test_note_content_validation(self):
+ """Test empty content is rejected."""
+ # None content is allowed
+ note = Note(title="Test", content=None)
+ assert note.content is None
+
+ # Non-empty content is valid
+ note2 = Note(title="Test", content="Valid content")
+ assert note2.content == "Valid content"
+
+ # Empty string should raise error
+ with pytest.raises(InvalidInputError, match="Note content cannot be empty"):
+ Note(title="Test", content="")
+
+ # Whitespace-only should raise error
+ with pytest.raises(InvalidInputError, match="Note content cannot be empty"):
+ Note(title="Test", content=" ")
+
+ def test_note_embedding_enabled(self):
+ """Test notes have embedding enabled by default."""
+ note = Note(title="Test", content="Test content")
+
+ assert note.needs_embedding() is True
+ assert note.get_embedding_content() == "Test content"
+
+ # Test with None content
+ note2 = Note(title="Test", content=None)
+ assert note2.get_embedding_content() is None
+
+
+# ============================================================================
+# TEST SUITE 6: Podcast Domain Validation
+# ============================================================================
+
+
+class TestPodcastDomain:
+ """Test suite for Podcast domain validation."""
+
+ def test_speaker_profile_validation(self):
+ """Test speaker profile validates count and required fields."""
+ # Test invalid - no speakers
+ with pytest.raises(ValidationError):
+ SpeakerProfile(
+ name="Test",
+ tts_provider="openai",
+ tts_model="tts-1",
+ speakers=[],
+ )
+
+ # Test invalid - too many speakers (> 4)
+ with pytest.raises(ValidationError):
+ SpeakerProfile(
+ name="Test",
+ tts_provider="openai",
+ tts_model="tts-1",
+ speakers=[{"name": f"Speaker{i}"} for i in range(5)],
+ )
+
+ # Test invalid - missing required fields
+ with pytest.raises(ValidationError):
+ SpeakerProfile(
+ name="Test",
+ tts_provider="openai",
+ tts_model="tts-1",
+ speakers=[{"name": "Speaker 1"}], # Missing voice_id, backstory, personality
+ )
+
+ # Test valid - single speaker with all fields
+ profile = SpeakerProfile(
+ name="Test",
+ tts_provider="openai",
+ tts_model="tts-1",
+ speakers=[
+ {
+ "name": "Host",
+ "voice_id": "voice123",
+ "backstory": "A friendly host",
+ "personality": "Enthusiastic and welcoming",
+ }
+ ],
+ )
+ assert len(profile.speakers) == 1
+ assert profile.speakers[0]["name"] == "Host"
+
+
+# ============================================================================
+# TEST SUITE 7: Transformation Domain
+# ============================================================================
+
+
+class TestTransformationDomain:
+ """Test suite for Transformation domain model."""
+
+ def test_transformation_creation(self):
+ """Test transformation model creation."""
+ transform = Transformation(
+ name="summarize",
+ title="Summarize Content",
+ description="Creates a summary",
+ prompt="Summarize the following text: {content}",
+ apply_default=True,
+ )
+
+ assert transform.name == "summarize"
+ assert transform.apply_default is True
+
+
+# ============================================================================
+# TEST SUITE 8: Content Settings
+# ============================================================================
+
+
+class TestContentSettings:
+ """Test suite for ContentSettings defaults."""
+
+ def test_content_settings_defaults(self):
+ """Test ContentSettings has proper defaults."""
+ settings = ContentSettings()
+
+ assert settings.record_id == "open_notebook:content_settings"
+ assert settings.default_content_processing_engine_doc == "auto"
+ assert settings.default_embedding_option == "ask"
+ assert settings.auto_delete_files == "yes"
+ assert len(settings.youtube_preferred_languages) > 0
+
+
+# ============================================================================
+# TEST SUITE 9: Episode Profile Validation
+# ============================================================================
+
+
+class TestEpisodeProfile:
+ """Test suite for EpisodeProfile validation."""
+
+ def test_episode_profile_segment_validation(self):
+ """Test segment count validation (3-20)."""
+ # Test invalid - too few segments
+ with pytest.raises(ValidationError, match="Number of segments must be between 3 and 20"):
+ EpisodeProfile(
+ name="Test",
+ speaker_config="default",
+ outline_provider="openai",
+ outline_model="gpt-4",
+ transcript_provider="openai",
+ transcript_model="gpt-4",
+ default_briefing="Test briefing",
+ num_segments=2,
+ )
+
+ # Test invalid - too many segments
+ with pytest.raises(ValidationError, match="Number of segments must be between 3 and 20"):
+ EpisodeProfile(
+ name="Test",
+ speaker_config="default",
+ outline_provider="openai",
+ outline_model="gpt-4",
+ transcript_provider="openai",
+ transcript_model="gpt-4",
+ default_briefing="Test briefing",
+ num_segments=21,
+ )
+
+ # Test valid segment count
+ profile = EpisodeProfile(
+ name="Test",
+ speaker_config="default",
+ outline_provider="openai",
+ outline_model="gpt-4",
+ transcript_provider="openai",
+ transcript_model="gpt-4",
+ default_briefing="Test briefing",
+ num_segments=5,
+ )
+ assert profile.num_segments == 5
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_graphs.py b/tests/test_graphs.py
new file mode 100644
index 0000000..d4b5a49
--- /dev/null
+++ b/tests/test_graphs.py
@@ -0,0 +1,158 @@
+"""
+Unit tests for the open_notebook.graphs module.
+
+This test suite focuses on testing graph structures, tools, and validation
+without heavy mocking of the actual processing logic.
+"""
+
+from datetime import datetime
+
+import pytest
+
+from open_notebook.graphs.prompt import PatternChainState, graph
+from open_notebook.graphs.tools import get_current_timestamp
+from open_notebook.graphs.transformation import (
+ TransformationState,
+ run_transformation,
+ graph as transformation_graph,
+)
+
+
+# ============================================================================
+# TEST SUITE 1: Graph Tools
+# ============================================================================
+
+
+class TestGraphTools:
+ """Test suite for graph tool definitions."""
+
+ def test_get_current_timestamp_format(self):
+ """Test timestamp tool returns correct format."""
+ timestamp = get_current_timestamp.func()
+
+ assert isinstance(timestamp, str)
+ assert len(timestamp) == 14 # YYYYMMDDHHmmss format
+ assert timestamp.isdigit()
+
+ def test_get_current_timestamp_validity(self):
+ """Test timestamp represents valid datetime."""
+ timestamp = get_current_timestamp.func()
+
+ # Parse it back to datetime to verify validity
+ year = int(timestamp[0:4])
+ month = int(timestamp[4:6])
+ day = int(timestamp[6:8])
+ hour = int(timestamp[8:10])
+ minute = int(timestamp[10:12])
+ second = int(timestamp[12:14])
+
+ # Should be valid date components
+ assert 2020 <= year <= 2100
+ assert 1 <= month <= 12
+ assert 1 <= day <= 31
+ assert 0 <= hour <= 23
+ assert 0 <= minute <= 59
+ assert 0 <= second <= 59
+
+ # Should parse as datetime
+ dt = datetime.strptime(timestamp, "%Y%m%d%H%M%S")
+ assert isinstance(dt, datetime)
+
+ def test_get_current_timestamp_is_tool(self):
+ """Test that function is properly decorated as a tool."""
+ # Check it has tool attributes
+ assert hasattr(get_current_timestamp, "name")
+ assert hasattr(get_current_timestamp, "description")
+
+
+# ============================================================================
+# TEST SUITE 2: Prompt Graph State
+# ============================================================================
+
+
+class TestPromptGraph:
+ """Test suite for prompt pattern chain graph."""
+
+ def test_pattern_chain_state_structure(self):
+ """Test PatternChainState structure and fields."""
+ state = PatternChainState(
+ prompt="Test prompt",
+ parser=None,
+ input_text="Test input",
+ output=""
+ )
+
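+        # LangGraph states are typically TypedDicts, which is why plain dict access works below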
+ assert state["prompt"] == "Test prompt"
+ assert state["parser"] is None
+ assert state["input_text"] == "Test input"
+ assert state["output"] == ""
+
+ def test_prompt_graph_compilation(self):
+ """Test that prompt graph compiles correctly."""
+ assert graph is not None
+
+ # Graph should have the expected structure
+ assert hasattr(graph, "invoke")
+ assert hasattr(graph, "ainvoke")
+
+
+# ============================================================================
+# TEST SUITE 3: Transformation Graph
+# ============================================================================
+
+
+class TestTransformationGraph:
+ """Test suite for transformation graph workflows."""
+
+ def test_transformation_state_structure(self):
+ """Test TransformationState structure and fields."""
+ from unittest.mock import MagicMock
+ from open_notebook.domain.notebook import Source
+ from open_notebook.domain.transformation import Transformation
+
+ mock_source = MagicMock(spec=Source)
+ mock_transformation = MagicMock(spec=Transformation)
+
+ state = TransformationState(
+ input_text="Test text",
+ source=mock_source,
+ transformation=mock_transformation,
+ output=""
+ )
+
+ assert state["input_text"] == "Test text"
+ assert state["source"] == mock_source
+ assert state["transformation"] == mock_transformation
+ assert state["output"] == ""
+
+ @pytest.mark.asyncio
+ async def test_run_transformation_assertion_no_content(self):
+ """Test transformation raises assertion with no content."""
+ from unittest.mock import MagicMock
+ from open_notebook.domain.transformation import Transformation
+
+ mock_transformation = MagicMock(spec=Transformation)
+
+ state = {
+ "input_text": None,
+ "transformation": mock_transformation,
+ "source": None
+ }
+
+ config = {"configurable": {"model_id": None}}
+
+ with pytest.raises(AssertionError, match="No content to transform"):
+ await run_transformation(state, config)
+
+ def test_transformation_graph_compilation(self):
+ """Test that transformation graph compiles correctly."""
+ assert transformation_graph is not None
+ assert hasattr(transformation_graph, "invoke")
+ assert hasattr(transformation_graph, "ainvoke")
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_source_chat.py b/tests/test_source_chat.py
deleted file mode 100644
index fb3e7d7..0000000
--- a/tests/test_source_chat.py
+++ /dev/null
@@ -1,296 +0,0 @@
-"""
-Integration tests for Source Chat Langgraph.
-
-These tests verify that the Source Chat Langgraph integrates correctly
-with the existing Open Notebook infrastructure.
-"""
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from langchain_core.messages import AIMessage, HumanMessage
-
-from open_notebook.domain.notebook import Source, SourceInsight
-from open_notebook.graphs.source_chat import (
- SourceChatState,
- _format_source_context,
- call_model_with_source_context,
- source_chat_graph,
-)
-
-
-@pytest.fixture
-def mock_source():
- """Create a mock Source object for testing."""
- source = MagicMock(spec=Source)
- source.id = "source:test123"
- source.title = "Test Source"
- source.topics = ["AI", "Machine Learning"]
- source.full_text = "This is test content for the source."
- source.model_dump.return_value = {
- "id": "source:test123",
- "title": "Test Source",
- "topics": ["AI", "Machine Learning"],
- "full_text": "This is test content for the source."
- }
- return source
-
-
-@pytest.fixture
-def mock_insight():
- """Create a mock SourceInsight object for testing."""
- insight = MagicMock(spec=SourceInsight)
- insight.id = "insight:test456"
- insight.insight_type = "summary"
- insight.content = "This is a test insight about the source."
- insight.model_dump.return_value = {
- "id": "insight:test456",
- "insight_type": "summary",
- "content": "This is a test insight about the source."
- }
- return insight
-
-
-@pytest.fixture
-def sample_state():
- """Create a sample SourceChatState for testing."""
- return SourceChatState(
- messages=[HumanMessage(content="What are the main topics in this source?")],
- source_id="source:test123",
- source=None,
- insights=None,
- context=None,
- model_override=None,
- context_indicators=None
- )
-
-
-@pytest.fixture
-def sample_config():
- """Create a sample configuration for testing."""
- return {
- "configurable": {
- "thread_id": "test_thread",
- "model_id": "test_model"
- }
- }
-
-
-class TestSourceChatState:
- """Test the SourceChatState TypedDict structure."""
-
- def test_source_chat_state_creation(self, sample_state):
- """Test that SourceChatState can be created with required fields."""
- assert sample_state["source_id"] == "source:test123"
- assert len(sample_state["messages"]) == 1
- assert sample_state["source"] is None
- assert sample_state["insights"] is None
-
-
-class TestContextFormatting:
- """Test the context formatting functionality."""
-
- def test_format_source_context_with_sources(self):
- """Test formatting context data containing sources."""
- context_data = {
- "sources": [{
- "id": "source:test123",
- "title": "Test Source",
- "full_text": "This is test content."
- }],
- "insights": [],
- "metadata": {
- "source_count": 1,
- "insight_count": 0
- },
- "total_tokens": 100
- }
-
- result = _format_source_context(context_data)
-
- assert "## SOURCE CONTENT" in result
- assert "source:test123" in result
- assert "Test Source" in result
- assert "This is test content." in result
- assert "## CONTEXT METADATA" in result
-
- def test_format_source_context_with_insights(self):
- """Test formatting context data containing insights."""
- context_data = {
- "sources": [],
- "insights": [{
- "id": "insight:test456",
- "insight_type": "summary",
- "content": "Test insight content."
- }],
- "metadata": {
- "source_count": 0,
- "insight_count": 1
- },
- "total_tokens": 50
- }
-
- result = _format_source_context(context_data)
-
- assert "## SOURCE INSIGHTS" in result
- assert "insight:test456" in result
- assert "summary" in result
- assert "Test insight content." in result
-
- def test_format_source_context_empty(self):
- """Test formatting empty context data."""
- context_data = {
- "sources": [],
- "insights": [],
- "metadata": {
- "source_count": 0,
- "insight_count": 0
- },
- "total_tokens": 0
- }
-
- result = _format_source_context(context_data)
-
- assert "## CONTEXT METADATA" in result
- assert "Source count: 0" in result
- assert "Insight count: 0" in result
-
-
-class TestSourceChatIntegration:
- """Test the integration of source chat components."""
-
- @patch('open_notebook.graphs.source_chat.ContextBuilder')
- @patch('open_notebook.graphs.source_chat.provision_langchain_model')
- @patch('open_notebook.graphs.source_chat.Prompter')
- async def test_call_model_with_source_context(
- self,
- mock_prompter,
- mock_provision_model,
- mock_context_builder,
- sample_state,
- sample_config,
- mock_source,
- mock_insight
- ):
- """Test the main model calling function with mocked dependencies."""
-
- # Mock the ContextBuilder
- mock_builder_instance = AsyncMock()
- mock_builder_instance.build.return_value = {
- "sources": [mock_source.model_dump()],
- "insights": [mock_insight.model_dump()],
- "metadata": {"source_count": 1, "insight_count": 1},
- "total_tokens": 150
- }
- mock_context_builder.return_value = mock_builder_instance
-
- # Mock the Prompter
- mock_prompter_instance = MagicMock()
- mock_prompter_instance.render.return_value = "Rendered prompt"
- mock_prompter.return_value = mock_prompter_instance
-
- # Mock the model
- mock_model = AsyncMock()
- mock_ai_message = AIMessage(content="Test response from AI")
- mock_model.invoke.return_value = mock_ai_message
- mock_provision_model.return_value = mock_model
-
- # Call the function
- result = await call_model_with_source_context(sample_state, sample_config) # type: ignore[misc]
-
- # Verify the result
- assert "messages" in result
- assert result["messages"] == mock_ai_message
- assert "source" in result
- assert "insights" in result
- assert "context" in result
- assert "context_indicators" in result
-
- # Verify mocks were called correctly
- mock_context_builder.assert_called_once()
- mock_builder_instance.build.assert_called_once()
- mock_prompter.assert_called_once_with(prompt_template="source_chat")
- mock_provision_model.assert_called_once()
-
- def test_source_chat_graph_structure(self):
- """Test that the source chat graph is properly structured."""
- # Verify the graph has the expected structure
- assert source_chat_graph is not None
-
- # Check that the graph has nodes
- nodes = source_chat_graph.get_graph().nodes
- assert "source_chat_agent" in [node for node in nodes]
-
- # Check that the graph has the checkpointer
- assert source_chat_graph.checkpointer is not None
-
- @pytest.mark.asyncio
- async def test_source_chat_state_validation(self):
- """Test that the source chat validates required state fields."""
- # Test with missing source_id
- invalid_state = SourceChatState(
- messages=[HumanMessage(content="Test")],
- source_id="", # Empty source_id should cause error
- source=None,
- insights=None,
- context=None,
- model_override=None,
- context_indicators=None
- )
-
- config = {"configurable": {"thread_id": "test"}}
-
- # This should raise an error due to missing source_id
- with pytest.raises(ValueError, match="source_id is required"):
- await call_model_with_source_context(invalid_state, config) # type: ignore[misc, arg-type]
-
-
-class TestSourceChatGraphExecution:
- """Test the execution of the source chat graph."""
-
- @patch('open_notebook.graphs.source_chat.Source')
- @patch('open_notebook.graphs.source_chat.ContextBuilder')
- @patch('open_notebook.graphs.source_chat.provision_langchain_model')
- @patch('open_notebook.graphs.source_chat.Prompter')
- @pytest.mark.asyncio
- async def test_graph_execution_flow(
- self,
- mock_prompter,
- mock_provision_model,
- mock_context_builder,
- mock_source_class,
- sample_state,
- sample_config
- ):
- """Test the complete graph execution flow with mocked dependencies."""
-
- # Setup mocks (similar to previous test but for full graph execution)
- mock_builder_instance = AsyncMock()
- mock_builder_instance.build.return_value = {
- "sources": [{"id": "source:test123", "title": "Test"}],
- "insights": [{"id": "insight:test456", "content": "Test insight"}],
- "metadata": {"source_count": 1, "insight_count": 1},
- "total_tokens": 100
- }
- mock_context_builder.return_value = mock_builder_instance
-
- mock_prompter_instance = MagicMock()
- mock_prompter_instance.render.return_value = "Test prompt"
- mock_prompter.return_value = mock_prompter_instance
-
- mock_model = AsyncMock()
- mock_model.invoke.return_value = AIMessage(content="AI response")
- mock_provision_model.return_value = mock_model
-
- # Execute the graph
- result = await source_chat_graph.ainvoke(sample_state, sample_config)
-
- # Verify the result structure
- assert "messages" in result
- assert "source_id" in result
- assert result["source_id"] == "source:test123"
-
-
-if __name__ == "__main__":
- # Run the tests
- pytest.main([__file__, "-v"])
\ No newline at end of file
diff --git a/tests/test_source_chat_api.py b/tests/test_source_chat_api.py
deleted file mode 100644
index 1621ac4..0000000
--- a/tests/test_source_chat_api.py
+++ /dev/null
@@ -1,223 +0,0 @@
-from unittest.mock import AsyncMock, patch
-
-import pytest
-from fastapi.testclient import TestClient
-
-from api.main import app
-
-client = TestClient(app)
-
-
-class TestSourceChatAPI:
- """Test suite for Source Chat API endpoints."""
-
- @pytest.fixture
- def sample_source_id(self):
- return "test_source_123"
-
- @pytest.fixture
- def sample_session_id(self):
- return "test_session_456"
-
- @patch('api.routers.source_chat.Source.get')
- @patch('api.routers.source_chat.ChatSession.save')
- @patch('api.routers.source_chat.ChatSession.relate')
- def test_create_source_chat_session(self, mock_relate, mock_save, mock_source_get, sample_source_id):
- """Test creating a new source chat session."""
- # Mock source exists
- mock_source = AsyncMock()
- mock_source.id = f"source:{sample_source_id}"
- mock_source_get.return_value = mock_source
-
- # Mock session save and relate
- mock_save.return_value = None
- mock_relate.return_value = None
-
- # Create session request
- request_data = {
- "source_id": sample_source_id,
- "title": "Test Chat Session",
- "model_override": "gpt-4"
- }
-
- response = client.post(
- f"/api/sources/{sample_source_id}/chat/sessions",
- json=request_data
- )
-
- assert response.status_code == 200
- data = response.json()
- assert data["title"] == "Test Chat Session"
- assert data["source_id"] == sample_source_id
- assert data["model_override"] == "gpt-4"
- assert "id" in data
- assert "created" in data
-
- @patch('api.routers.source_chat.Source.get')
- def test_create_session_source_not_found(self, mock_source_get, sample_source_id):
- """Test creating session for non-existent source."""
- mock_source_get.return_value = None
-
- request_data = {
- "source_id": sample_source_id,
- "title": "Test Chat Session"
- }
-
- response = client.post(
- f"/api/sources/{sample_source_id}/chat/sessions",
- json=request_data
- )
-
- assert response.status_code == 404
- assert "Source not found" in response.json()["detail"]
-
- @patch('api.routers.source_chat.Source.get')
- @patch('api.routers.source_chat.repo_query')
- def test_get_source_chat_sessions(self, mock_repo_query, mock_source_get, sample_source_id):
- """Test getting all chat sessions for a source."""
- # Mock source exists
- mock_source = AsyncMock()
- mock_source.id = f"source:{sample_source_id}"
- mock_source_get.return_value = mock_source
-
- # Mock query returns sessions
- mock_repo_query.return_value = [
- {"in": "chat_session:session1"},
- {"in": "chat_session:session2"}
- ]
-
- # Mock ChatSession.get for each session
- with patch('api.routers.source_chat.ChatSession.get') as mock_session_get:
- mock_session1 = AsyncMock()
- mock_session1.id = "chat_session:session1"
- mock_session1.title = "Session 1"
- mock_session1.created = "2024-01-01T00:00:00Z"
- mock_session1.updated = "2024-01-01T00:00:00Z"
- mock_session1.model_override = None
-
- mock_session2 = AsyncMock()
- mock_session2.id = "chat_session:session2"
- mock_session2.title = "Session 2"
- mock_session2.created = "2024-01-01T00:00:00Z"
- mock_session2.updated = "2024-01-01T00:00:00Z"
- mock_session2.model_override = "gpt-4"
-
- mock_session_get.side_effect = [mock_session1, mock_session2]
-
- response = client.get(f"/api/sources/{sample_source_id}/chat/sessions")
-
- assert response.status_code == 200
- data = response.json()
- assert len(data) == 2
- assert data[0]["title"] == "Session 1"
- assert data[1]["title"] == "Session 2"
- assert data[1]["model_override"] == "gpt-4"
-
- @patch('api.routers.source_chat.Source.get')
- @patch('api.routers.source_chat.ChatSession.get')
- @patch('api.routers.source_chat.repo_query')
- @patch('api.routers.source_chat.source_chat_graph.get_state')
- def test_get_source_chat_session_with_messages(
- self, mock_get_state, mock_repo_query, mock_session_get, mock_source_get,
- sample_source_id, sample_session_id
- ):
- """Test getting a specific chat session with messages."""
- # Mock source exists
- mock_source = AsyncMock()
- mock_source.id = f"source:{sample_source_id}"
- mock_source_get.return_value = mock_source
-
- # Mock session exists
- mock_session = AsyncMock()
- mock_session.id = f"chat_session:{sample_session_id}"
- mock_session.title = "Test Session"
- mock_session.created = "2024-01-01T00:00:00Z"
- mock_session.updated = "2024-01-01T00:00:00Z"
- mock_session.model_override = "gpt-4"
- mock_session_get.return_value = mock_session
-
- # Mock relation exists
- mock_repo_query.return_value = [{"relation": "exists"}]
-
- # Mock graph state with messages
- mock_message = AsyncMock()
- mock_message.type = "human"
- mock_message.content = "Hello"
- mock_message.id = "msg_1"
-
- mock_state = AsyncMock()
- mock_state.values = {
- "messages": [mock_message],
- "context_indicators": {"sources": ["source:123"], "insights": ["insight:456"], "notes": []}
- }
- mock_get_state.return_value = mock_state
-
- response = client.get(f"/api/sources/{sample_source_id}/chat/sessions/{sample_session_id}")
-
- assert response.status_code == 200
- data = response.json()
- assert data["title"] == "Test Session"
- assert data["model_override"] == "gpt-4"
- assert len(data["messages"]) == 1
- assert data["messages"][0]["content"] == "Hello"
- assert data["context_indicators"]["sources"] == ["source:123"]
-
- @patch('api.routers.source_chat.Source.get')
- @patch('api.routers.source_chat.ChatSession.get')
- @patch('api.routers.source_chat.repo_query')
- @patch('api.routers.source_chat.ChatSession.save')
- def test_update_source_chat_session(
- self, mock_save, mock_repo_query, mock_session_get, mock_source_get,
- sample_source_id, sample_session_id
- ):
- """Test updating a source chat session."""
- # Mock source exists
- mock_source = AsyncMock()
- mock_source.id = f"source:{sample_source_id}"
- mock_source_get.return_value = mock_source
-
- # Mock session exists
- mock_session = AsyncMock()
- mock_session.id = f"chat_session:{sample_session_id}"
- mock_session.title = "Old Title"
- mock_session.created = "2024-01-01T00:00:00Z"
- mock_session.updated = "2024-01-01T00:00:00Z"
- mock_session.model_override = None
- mock_session_get.return_value = mock_session
-
- # Mock relation exists
- mock_repo_query.return_value = [{"relation": "exists"}]
-
- # Mock save
- mock_save.return_value = None
-
- request_data = {
- "title": "New Title",
- "model_override": "gpt-4"
- }
-
- response = client.put(
- f"/api/sources/{sample_source_id}/chat/sessions/{sample_session_id}",
- json=request_data
- )
-
- assert response.status_code == 200
- data = response.json()
- assert data["title"] == "New Title"
- # Note: The mock will still return the original values unless we update them
- # In a real test, we'd want to verify the session was updated properly
-
- def test_api_endpoints_structure(self):
- """Test that all expected endpoints are properly structured."""
- # Test endpoint paths are correctly formed
- from api.routers.source_chat import router
-
- routes = [route.path for route in router.routes] # type: ignore[attr-defined]
- expected_routes = [
- "/sources/{source_id}/chat/sessions",
- "/sources/{source_id}/chat/sessions/{session_id}",
- "/sources/{source_id}/chat/sessions/{session_id}/messages"
- ]
-
- for expected_route in expected_routes:
- assert any(expected_route in route for route in routes), f"Route {expected_route} not found"
\ No newline at end of file
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..fb2b018
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,262 @@
+"""
+Unit tests for the open_notebook.utils module.
+
+This test suite focuses on utility functions that perform actual logic
+without heavy mocking: string processing, validation, and algorithms.
+"""
+
+import pytest
+
+from open_notebook.utils import (
+ clean_thinking_content,
+ compare_versions,
+ get_installed_version,
+ parse_thinking_content,
+ remove_non_ascii,
+ remove_non_printable,
+ split_text,
+ token_count,
+)
+from open_notebook.utils.context_builder import ContextBuilder, ContextConfig
+
+
+# ============================================================================
+# TEST SUITE 1: Text Utilities
+# ============================================================================
+
+
+class TestTextUtilities:
+ """Test suite for text utility functions."""
+
+ def test_split_text_empty_string(self):
+ """Test splitting empty or very short strings."""
+ assert split_text("") == []
+ assert split_text("short") == ["short"]
+
+ def test_remove_non_ascii(self):
+ """Test removal of non-ASCII characters."""
+ # Text with various non-ASCII characters
+ text_with_unicode = "Hello 世界 café naïve émoji 🎉"
+ result = remove_non_ascii(text_with_unicode)
+
+ # Should only contain ASCII characters
+ assert result == "Hello caf nave moji "
+ # All characters should be in ASCII range
+ assert all(ord(char) < 128 for char in result)
+
+ def test_remove_non_ascii_pure_ascii(self):
+ """Test that pure ASCII text is unchanged."""
+ text = "Hello World 123 !@#"
+ result = remove_non_ascii(text)
+ assert result == text
+
+ def test_remove_non_printable(self):
+ """Test removal of non-printable characters."""
+ # Text with various Unicode whitespace and control chars
+ text = "Hello\u2000World\u200B\u202FTest"
+ result = remove_non_printable(text)
+
+ # Should have regular spaces and printable chars only
+ assert "Hello" in result
+ assert "World" in result
+ assert "Test" in result
+
+ def test_remove_non_printable_preserves_newlines(self):
+ """Test that newlines and tabs are preserved."""
+ text = "Line1\nLine2\tTabbed"
+ result = remove_non_printable(text)
+ assert "\n" in result
+ assert "\t" in result
+
+ def test_parse_thinking_content_basic(self):
+ """Test parsing single thinking block."""
+ content = "This is my thinkingHere is my answer"
+ thinking, cleaned = parse_thinking_content(content)
+
+ assert thinking == "This is my thinking"
+ assert cleaned == "Here is my answer"
+
+ def test_parse_thinking_content_multiple_tags(self):
+ """Test parsing multiple thinking blocks."""
+ content = "First thoughtAnswerSecond thoughtMore"
+ thinking, cleaned = parse_thinking_content(content)
+
+ assert "First thought" in thinking
+ assert "Second thought" in thinking
+ assert "" not in cleaned
+ assert "Answer" in cleaned
+ assert "More" in cleaned
+
+ def test_parse_thinking_content_no_tags(self):
+ """Test parsing content without thinking tags."""
+ content = "Just regular content"
+ thinking, cleaned = parse_thinking_content(content)
+
+ assert thinking == ""
+ assert cleaned == "Just regular content"
+
+ def test_parse_thinking_content_invalid_input(self):
+ """Test parsing with invalid input types."""
+ # Non-string input
+ thinking, cleaned = parse_thinking_content(None)
+ assert thinking == ""
+ assert cleaned == ""
+
+ # Integer input
+ thinking, cleaned = parse_thinking_content(123)
+ assert thinking == ""
+ assert cleaned == "123"
+
+ def test_parse_thinking_content_large_content(self):
+ """Test that very large content is not processed."""
+ large_content = "x" * 200000 # > 100KB limit
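+        # skipping oversized input presumably guards against costly regex scans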
+ thinking, cleaned = parse_thinking_content(large_content)
+
+ # Should return unchanged due to size limit
+ assert thinking == ""
+ assert cleaned == large_content
+
+ def test_clean_thinking_content(self):
+ """Test convenience function for cleaning thinking content."""
+ content = "Internal thoughtsPublic response"
+ result = clean_thinking_content(content)
+
+ assert "" not in result
+ assert "Public response" in result
+ assert "Internal thoughts" not in result
+
+
+# ============================================================================
+# TEST SUITE 2: Token Utilities
+# ============================================================================
+
+
+class TestTokenUtilities:
+ """Test suite for token counting fallback behavior."""
+
+ def test_token_count_fallback(self):
+ """Test fallback when tiktoken raises an error."""
+ from unittest.mock import patch
+
+ # Make tiktoken raise an ImportError to trigger fallback
+ with patch("tiktoken.get_encoding", side_effect=ImportError("tiktoken not available")):
+ text = "one two three four five"
+ count = token_count(text)
+
+ # Fallback uses word count * 1.3
+ # 5 words * 1.3 = 6.5 -> 6
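+        # only positivity is asserted, keeping the test agnostic to the exact fallback formula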
+ assert isinstance(count, int)
+ assert count > 0
+
+
+# ============================================================================
+# TEST SUITE 3: Version Utilities
+# ============================================================================
+
+
+class TestVersionUtilities:
+ """Test suite for version management functions."""
+
+ def test_compare_versions_equal(self):
+ """Test comparing equal versions."""
+ result = compare_versions("1.0.0", "1.0.0")
+ assert result == 0
+
+ def test_compare_versions_less_than(self):
+ """Test comparing when first version is less."""
+ result = compare_versions("1.0.0", "2.0.0")
+ assert result == -1
+
+ result = compare_versions("1.0.0", "1.1.0")
+ assert result == -1
+
+ result = compare_versions("1.0.0", "1.0.1")
+ assert result == -1
+
+ def test_compare_versions_greater_than(self):
+ """Test comparing when first version is greater."""
+ result = compare_versions("2.0.0", "1.0.0")
+ assert result == 1
+
+ result = compare_versions("1.1.0", "1.0.0")
+ assert result == 1
+
+ result = compare_versions("1.0.1", "1.0.0")
+ assert result == 1
+
+ def test_compare_versions_prerelease(self):
+ """Test comparing versions with pre-release tags."""
+ result = compare_versions("1.0.0", "1.0.0-alpha")
+ assert result == 1 # Release > pre-release
+
+ result = compare_versions("1.0.0-beta", "1.0.0-alpha")
+ assert result == 1 # beta > alpha
+
+ def test_get_installed_version_success(self):
+ """Test getting installed package version."""
+ # Test with a known installed package
+ version = get_installed_version("pytest")
+ assert isinstance(version, str)
+ assert len(version) > 0
+ # Should look like a version (has dots)
+ assert "." in version
+
+ def test_get_installed_version_not_found(self):
+ """Test getting version of non-existent package."""
+ from importlib.metadata import PackageNotFoundError
+
+ with pytest.raises(PackageNotFoundError):
+ get_installed_version("this-package-does-not-exist-12345")
+
+ def test_get_version_from_github_invalid_url(self):
+ """Test GitHub version fetch with invalid URL."""
+ from open_notebook.utils.version_utils import get_version_from_github
+
+ with pytest.raises(ValueError, match="Not a GitHub URL"):
+ get_version_from_github("https://example.com/repo")
+
+ with pytest.raises(ValueError, match="Invalid GitHub repository URL"):
+ get_version_from_github("https://github.com/")
+
+
+# ============================================================================
+# TEST SUITE 4: Context Builder Configuration
+# ============================================================================
+
+
+class TestContextBuilder:
+ """Test suite for ContextBuilder initialization and configuration."""
+
+ def test_context_config_defaults(self):
+ """Test ContextConfig default values."""
+ config = ContextConfig()
+
+ assert config.sources == {}
+ assert config.notes == {}
+ assert config.include_insights is True
+ assert config.include_notes is True
+ assert config.priority_weights is not None
+ assert "source" in config.priority_weights
+ assert "note" in config.priority_weights
+ assert "insight" in config.priority_weights
+
+ def test_context_builder_initialization(self):
+ """Test ContextBuilder initialization with various params."""
+ builder = ContextBuilder(
+ source_id="source:123",
+ notebook_id="notebook:456",
+ max_tokens=1000,
+ include_insights=False
+ )
+
+ assert builder.source_id == "source:123"
+ assert builder.notebook_id == "notebook:456"
+ assert builder.max_tokens == 1000
+ assert builder.include_insights is False
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/uv.lock b/uv.lock
index 5c14f20..b868ff6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2256,6 +2256,7 @@ dev = [
[package.dev-dependencies]
dev = [
{ name = "pre-commit" },
+ { name = "pytest-asyncio" },
{ name = "types-requests" },
]
 
@@ -2302,6 +2303,7 @@ provides-extras = ["dev"]
[package.metadata.requires-dev]
dev = [
{ name = "pre-commit", specifier = ">=4.1.0" },
+ { name = "pytest-asyncio", specifier = ">=1.2.0" },
{ name = "types-requests", specifier = ">=2.32.4.20250913" },
]
 
@@ -2973,6 +2975,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" },
]
 
+[[package]]
+name = "pytest-asyncio"
+version = "1.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pytest" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" },
+]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"