# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import shutil import tempfile from pathlib import Path from unittest.mock import MagicMock, Mock, patch import pytest from app.agent.toolkit.rag_toolkit import RAGToolkit @pytest.fixture def temp_storage_path(): """Create a temporary storage path for tests.""" temp_dir = tempfile.mkdtemp() yield Path(temp_dir) shutil.rmtree(temp_dir, ignore_errors=True) @pytest.fixture def toolkit(temp_storage_path): """Create a RAGToolkit instance with mocked AutoRetriever.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task-123", storage_path=temp_storage_path, ) return toolkit def test_toolkit_initialization(temp_storage_path): """Test RAGToolkit initializes correctly.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever") as mock_ar: with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task-456", collection_name="test_collection", storage_path=temp_storage_path, ) assert toolkit.api_task_id == "test-task-456" assert toolkit._storage_path == temp_storage_path assert toolkit._collection_name == "test_collection" assert temp_storage_path.exists() mock_ar.assert_called_once() call_kwargs = mock_ar.call_args[1] assert ( str(temp_storage_path) in call_kwargs["vector_storage_local_path"] ) def test_toolkit_initialization_with_custom_agent(temp_storage_path): """Test RAGToolkit with custom agent name.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", agent_name="custom_agent", storage_path=temp_storage_path, ) assert toolkit.agent_name == "custom_agent" def test_list_knowledge_bases_empty(temp_storage_path): """Test list_knowledge_bases when no KBs exist.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", storage_path=temp_storage_path, ) result = toolkit.list_knowledge_bases() assert "No knowledge bases found" in result def test_list_knowledge_bases_with_tasks(temp_storage_path): """Test list_knowledge_bases when task directories exist.""" (temp_storage_path / "task_123").mkdir() (temp_storage_path / "task_456").mkdir() with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", storage_path=temp_storage_path, ) result = toolkit.list_knowledge_bases() assert "task_123" in result assert "task_456" in result def test_get_tools_returns_three_tools(toolkit): """Test get_tools returns RAG tools.""" tools = toolkit.get_tools() assert len(tools) == 3 tool_names = [t.func.__name__ for t in tools] assert "add_document" in tool_names assert "query_knowledge_base" in tool_names assert "information_retrieval" in tool_names def test_get_can_use_tools_returns_tools(temp_storage_path): """Test get_can_use_tools returns tools.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.object(RAGToolkit, "get_tools") as mock_get_tools: mock_get_tools.return_value = [Mock(), Mock()] tools = RAGToolkit.get_can_use_tools("test-task") assert len(tools) == 2 def test_get_can_use_tools_auto_derives_collection_name(temp_storage_path): """Test get_can_use_tools auto-derives collection_name from api_task_id.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.object( RAGToolkit, "__init__", return_value=None ) as mock_init: with patch.object(RAGToolkit, "get_tools", return_value=[]): RAGToolkit.get_can_use_tools("test-task-123") mock_init.assert_called_once_with( api_task_id="test-task-123", collection_name="task_test-task-123", ) def test_default_collection_name(temp_storage_path): """Test default collection_name when not provided.""" with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"): with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", storage_path=temp_storage_path, ) assert toolkit._collection_name == "default" @patch("app.agent.toolkit.rag_toolkit.AutoRetriever") def test_information_retrieval_success( mock_auto_retriever_class, temp_storage_path ): """Test successful information retrieval.""" mock_auto_retriever = MagicMock() mock_auto_retriever.run_vector_retriever.return_value = { "text": ["Relevant content about the query"] } mock_auto_retriever_class.return_value = mock_auto_retriever with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", storage_path=temp_storage_path, ) result = toolkit.information_retrieval( query="What is the content?", contents="/path/to/document.pdf", top_k=5, ) assert isinstance(result, str) mock_auto_retriever.run_vector_retriever.assert_called_once() @patch("app.agent.toolkit.rag_toolkit.AutoRetriever") def test_information_retrieval_with_error( mock_auto_retriever_class, temp_storage_path ): """Test information retrieval handles errors gracefully.""" mock_auto_retriever = MagicMock() mock_auto_retriever.run_vector_retriever.side_effect = Exception( "Test error" ) mock_auto_retriever_class.return_value = mock_auto_retriever with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", storage_path=temp_storage_path, ) result = toolkit.information_retrieval( query="What is the content?", contents="/path/to/document.pdf", ) assert "Error" in result assert "Test error" in result @patch("app.agent.toolkit.rag_toolkit.AutoRetriever") def test_information_retrieval_with_list_contents( mock_auto_retriever_class, temp_storage_path ): """Test information retrieval with multiple content sources.""" mock_auto_retriever = MagicMock() mock_auto_retriever.run_vector_retriever.return_value = { "text": ["Combined results from multiple sources"] } mock_auto_retriever_class.return_value = mock_auto_retriever with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): toolkit = RAGToolkit( api_task_id="test-task", storage_path=temp_storage_path, ) result = toolkit.information_retrieval( query="What is the content?", contents=["/path/to/doc1.pdf", "/path/to/doc2.pdf"], ) assert isinstance(result, str) mock_auto_retriever.run_vector_retriever.assert_called_once()