eigent/backend/tests/app/agent/toolkit/test_rag_toolkit.py
Dream ba47db8a84
refactor: move toolkit from utils to agent module (#1045) (#1171)
Co-authored-by: bytecii <994513625@qq.com>
2026-02-06 15:22:21 -08:00

231 lines
8.2 KiB
Python

# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import pytest
from app.agent.toolkit.rag_toolkit import RAGToolkit
@pytest.fixture
def temp_storage_path():
    """Yield a fresh scratch directory as a ``Path`` and delete it afterwards.

    The directory comes from ``tempfile.mkdtemp``; removal after the test
    ignores errors, so a directory that is already gone does not fail
    teardown.
    """
    storage_root = Path(tempfile.mkdtemp())
    yield storage_root
    # Best-effort cleanup: never let teardown mask the test outcome.
    shutil.rmtree(storage_root, ignore_errors=True)
@pytest.fixture
def toolkit(temp_storage_path):
    """Build a ``RAGToolkit`` whose ``AutoRetriever`` dependency is mocked.

    Both patches are active only while the toolkit is constructed; they are
    undone as soon as the instance is returned to the requesting test.
    """
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    with retriever_patch, env_patch:
        instance = RAGToolkit(
            api_task_id="test-task-123",
            storage_path=temp_storage_path,
        )
        return instance
def test_toolkit_initialization(temp_storage_path):
    """Verify constructor arguments are stored and forwarded.

    Checks the task id, storage path and collection name kept on the
    instance, that the storage directory exists, and that ``AutoRetriever``
    is created exactly once with a local path containing the storage path.
    """
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    with retriever_patch as mock_ar, env_patch:
        rag = RAGToolkit(
            api_task_id="test-task-456",
            collection_name="test_collection",
            storage_path=temp_storage_path,
        )

        assert rag.api_task_id == "test-task-456"
        assert rag._storage_path == temp_storage_path
        assert rag._collection_name == "test_collection"
        assert temp_storage_path.exists()

        mock_ar.assert_called_once()
        # ``call_args`` is an (args, kwargs) pair; only kwargs matter here.
        _, call_kwargs = mock_ar.call_args
        local_path = call_kwargs["vector_storage_local_path"]
        assert str(temp_storage_path) in local_path
def test_toolkit_initialization_with_custom_agent(temp_storage_path):
    """Verify an explicit ``agent_name`` is kept on the toolkit instance."""
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    with retriever_patch, env_patch:
        rag = RAGToolkit(
            api_task_id="test-task",
            agent_name="custom_agent",
            storage_path=temp_storage_path,
        )

        assert rag.agent_name == "custom_agent"
def test_list_knowledge_bases_empty(temp_storage_path):
    """Verify the "nothing found" message when the storage path is empty."""
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    with retriever_patch, env_patch:
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )

        listing = rag.list_knowledge_bases()

        assert "No knowledge bases found" in listing
def test_list_knowledge_bases_with_tasks(temp_storage_path):
    """Verify pre-existing task directories show up in the listing."""
    task_dirs = ("task_123", "task_456")
    # Seed the storage root before the toolkit is constructed.
    for dir_name in task_dirs:
        (temp_storage_path / dir_name).mkdir()

    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    with retriever_patch, env_patch:
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )

        listing = rag.list_knowledge_bases()

        for dir_name in task_dirs:
            assert dir_name in listing
def test_get_tools_returns_three_tools(toolkit):
    """Verify ``get_tools`` exposes exactly the three expected RAG tools."""
    exposed = toolkit.get_tools()

    assert len(exposed) == 3

    # Each tool wraps its callable in ``.func``; compare by function name.
    exposed_names = [tool.func.__name__ for tool in exposed]
    for expected in (
        "add_document",
        "query_knowledge_base",
        "information_retrieval",
    ):
        assert expected in exposed_names
def test_get_can_use_tools_returns_tools(temp_storage_path):
    """Verify ``get_can_use_tools`` returns the list built by ``get_tools``.

    ``AutoRetriever`` and ``RAGToolkit.get_tools`` are mocked, so only the
    pass-through of the tool list is checked. ``OPENAI_API_KEY`` is patched
    the same way as in the other tests that run the real constructor, so
    the outcome does not depend on the ambient environment.

    NOTE(review): ``temp_storage_path`` is requested but never handed to the
    toolkit; presumably the toolkit falls back to its default storage
    location here -- confirm this does not write outside the test sandbox.
    """
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    tools_patch = patch.object(RAGToolkit, "get_tools")
    with retriever_patch, env_patch, tools_patch as mock_get_tools:
        mock_get_tools.return_value = [Mock(), Mock()]

        tools = RAGToolkit.get_can_use_tools("test-task")

        assert len(tools) == 2
def test_get_can_use_tools_auto_derives_collection_name(temp_storage_path):
    """Verify the collection name is derived as ``task_<api_task_id>``.

    ``__init__`` is replaced by a mock so the keyword arguments passed by
    ``get_can_use_tools`` can be inspected without building a real toolkit.
    """
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    init_patch = patch.object(RAGToolkit, "__init__", return_value=None)
    tools_patch = patch.object(RAGToolkit, "get_tools", return_value=[])
    with retriever_patch, init_patch as mock_init, tools_patch:
        RAGToolkit.get_can_use_tools("test-task-123")

        mock_init.assert_called_once_with(
            api_task_id="test-task-123",
            collection_name="task_test-task-123",
        )
def test_default_collection_name(temp_storage_path):
    """Verify ``_collection_name`` is ``"default"`` when none is passed."""
    retriever_patch = patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
    env_patch = patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    with retriever_patch, env_patch:
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )

        assert rag._collection_name == "default"
@patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
def test_information_retrieval_success(
    mock_auto_retriever_class, temp_storage_path
):
    """Verify a successful retrieval yields a string from one retriever call.

    The patched ``AutoRetriever`` class returns a stub whose
    ``run_vector_retriever`` gives back a canned ``{"text": [...]}`` payload.
    """
    retriever_stub = MagicMock()
    retriever_stub.run_vector_retriever.return_value = {
        "text": ["Relevant content about the query"]
    }
    mock_auto_retriever_class.return_value = retriever_stub

    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )

        answer = rag.information_retrieval(
            query="What is the content?",
            contents="/path/to/document.pdf",
            top_k=5,
        )

        assert isinstance(answer, str)
        retriever_stub.run_vector_retriever.assert_called_once()
@patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
def test_information_retrieval_with_error(
    mock_auto_retriever_class, temp_storage_path
):
    """Verify a retriever exception is reported in the returned text.

    The stubbed ``run_vector_retriever`` raises ``Exception("Test error")``;
    the result is expected to contain both "Error" and the original message.
    """
    failure = Exception("Test error")
    retriever_stub = MagicMock()
    retriever_stub.run_vector_retriever.side_effect = failure
    mock_auto_retriever_class.return_value = retriever_stub

    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )

        answer = rag.information_retrieval(
            query="What is the content?",
            contents="/path/to/document.pdf",
        )

        for fragment in ("Error", "Test error"):
            assert fragment in answer
@patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
def test_information_retrieval_with_list_contents(
    mock_auto_retriever_class, temp_storage_path
):
    """Verify a list of content sources results in one retriever call.

    ``contents`` is given as a list of two paths; the result must be a
    string and ``run_vector_retriever`` must have been invoked exactly once.
    """
    retriever_stub = MagicMock()
    retriever_stub.run_vector_retriever.return_value = {
        "text": ["Combined results from multiple sources"]
    }
    mock_auto_retriever_class.return_value = retriever_stub

    sources = ["/path/to/doc1.pdf", "/path/to/doc2.pdf"]
    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )

        answer = rag.information_retrieval(
            query="What is the content?",
            contents=sources,
        )

        assert isinstance(answer, str)
        retriever_stub.run_vector_retriever.assert_called_once()