mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-24 05:26:42 +00:00
231 lines
8.2 KiB
Python
231 lines
8.2 KiB
Python
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
|
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from app.agent.toolkit.rag_toolkit import RAGToolkit
|
|
|
|
|
|
@pytest.fixture
def temp_storage_path():
    """Provide a fresh temporary directory, as a ``Path``, for toolkit storage.

    The directory tree is deleted once the test finishes; any error raised
    while deleting it is ignored.
    """
    storage_dir = Path(tempfile.mkdtemp())
    try:
        yield storage_dir
    finally:
        shutil.rmtree(storage_dir, ignore_errors=True)
|
|
|
|
|
|
@pytest.fixture
def toolkit(temp_storage_path):
    """Build a ``RAGToolkit`` whose storage lives in ``temp_storage_path``.

    ``AutoRetriever`` is replaced by a mock and ``OPENAI_API_KEY`` is set to a
    dummy value only while the constructor runs. Both patches are undone
    before the instance is handed to the test.
    """
    fake_env = {"OPENAI_API_KEY": "test-key"}
    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"), patch.dict(
        "os.environ", fake_env
    ):
        instance = RAGToolkit(
            api_task_id="test-task-123",
            storage_path=temp_storage_path,
        )
    return instance
|
|
|
|
|
|
def test_toolkit_initialization(temp_storage_path):
    """Constructor arguments are stored and exactly one retriever is built.

    The retriever must also receive a ``vector_storage_local_path`` keyword
    that contains the supplied storage path.
    """
    retriever_target = "app.agent.toolkit.rag_toolkit.AutoRetriever"
    with patch(retriever_target) as retriever_cls, patch.dict(
        "os.environ", {"OPENAI_API_KEY": "test-key"}
    ):
        rag = RAGToolkit(
            api_task_id="test-task-456",
            collection_name="test_collection",
            storage_path=temp_storage_path,
        )

    assert rag.api_task_id == "test-task-456"
    assert rag._storage_path == temp_storage_path
    assert rag._collection_name == "test_collection"
    assert temp_storage_path.exists()

    retriever_cls.assert_called_once()
    storage_arg = retriever_cls.call_args.kwargs["vector_storage_local_path"]
    assert str(temp_storage_path) in storage_arg
|
|
|
|
|
|
def test_toolkit_initialization_with_custom_agent(temp_storage_path):
    """An explicit ``agent_name`` is exposed unchanged on the toolkit."""
    fake_env = {"OPENAI_API_KEY": "test-key"}
    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"), patch.dict(
        "os.environ", fake_env
    ):
        rag = RAGToolkit(
            api_task_id="test-task",
            agent_name="custom_agent",
            storage_path=temp_storage_path,
        )

    assert rag.agent_name == "custom_agent"
|
|
|
|
|
|
def test_list_knowledge_bases_empty(temp_storage_path):
    """An empty storage directory yields the "no knowledge bases" message."""
    fake_env = {"OPENAI_API_KEY": "test-key"}
    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"), patch.dict(
        "os.environ", fake_env
    ):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )
        listing = rag.list_knowledge_bases()

    assert "No knowledge bases found" in listing
|
|
|
|
|
|
def test_list_knowledge_bases_with_tasks(temp_storage_path):
    """Task directories created under the storage path appear in the listing."""
    task_dirs = ("task_123", "task_456")
    for dirname in task_dirs:
        (temp_storage_path / dirname).mkdir()

    fake_env = {"OPENAI_API_KEY": "test-key"}
    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"), patch.dict(
        "os.environ", fake_env
    ):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )
        listing = rag.list_knowledge_bases()

    for dirname in task_dirs:
        assert dirname in listing
|
|
|
|
|
|
def test_get_tools_returns_three_tools(toolkit):
    """get_tools exposes exactly the three RAG tool functions, once each."""
    expected_names = [
        "add_document",
        "information_retrieval",
        "query_knowledge_base",
    ]
    # Sorted equality is equivalent to "three tools, each expected name present".
    actual_names = sorted(tool.func.__name__ for tool in toolkit.get_tools())
    assert actual_names == expected_names
|
|
|
|
|
|
def test_get_can_use_tools_returns_tools(temp_storage_path):
    """get_can_use_tools returns the tools produced by get_tools.

    ``get_can_use_tools`` constructs the toolkit itself and does not pass a
    ``storage_path``. The real constructor is therefore wrapped to default
    that argument to ``temp_storage_path``. Any storage the constructor may
    set up lands in the per-test temporary directory rather than the
    toolkit's default location.

    ``OPENAI_API_KEY`` is set to a dummy value, matching the other tests that
    run the real constructor, so the outcome does not depend on the
    developer's environment.
    """
    real_init = RAGToolkit.__init__

    def init_with_temp_storage(self, *args, **kwargs):
        # Fill in storage_path only when the caller left it out.
        kwargs.setdefault("storage_path", temp_storage_path)
        real_init(self, *args, **kwargs)

    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"):
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            with patch.object(
                RAGToolkit, "__init__", init_with_temp_storage
            ), patch.object(RAGToolkit, "get_tools") as mock_get_tools:
                mock_get_tools.return_value = [Mock(), Mock()]
                tools = RAGToolkit.get_can_use_tools("test-task")

    assert len(tools) == 2
|
|
|
|
|
|
def test_get_can_use_tools_auto_derives_collection_name(temp_storage_path):
    """get_can_use_tools passes a collection name derived from the task id.

    ``__init__`` is stubbed out so that only the constructor arguments are
    checked; nothing is actually built.
    """
    task_id = "test-task-123"
    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"), patch.object(
        RAGToolkit, "__init__", return_value=None
    ) as init_spy, patch.object(RAGToolkit, "get_tools", return_value=[]):
        RAGToolkit.get_can_use_tools(task_id)

    init_spy.assert_called_once_with(
        api_task_id="test-task-123",
        collection_name="task_test-task-123",
    )
|
|
|
|
|
|
def test_default_collection_name(temp_storage_path):
    """Omitting ``collection_name`` falls back to ``"default"``."""
    fake_env = {"OPENAI_API_KEY": "test-key"}
    with patch("app.agent.toolkit.rag_toolkit.AutoRetriever"), patch.dict(
        "os.environ", fake_env
    ):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )
        assert rag._collection_name == "default"
|
|
|
|
|
|
@patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
def test_information_retrieval_success(
    mock_auto_retriever_class, temp_storage_path
):
    """A successful lookup returns a string and hits the retriever once."""
    retriever = MagicMock()
    retriever.run_vector_retriever.return_value = {
        "text": ["Relevant content about the query"]
    }
    # Every AutoRetriever the toolkit builds is this mock.
    mock_auto_retriever_class.return_value = retriever

    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )
        answer = rag.information_retrieval(
            query="What is the content?",
            contents="/path/to/document.pdf",
            top_k=5,
        )

    assert isinstance(answer, str)
    retriever.run_vector_retriever.assert_called_once()
|
|
|
|
|
|
@patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
def test_information_retrieval_with_error(
    mock_auto_retriever_class, temp_storage_path
):
    """A retriever failure is reported in the returned text, not raised."""
    failing_retriever = MagicMock()
    failing_retriever.run_vector_retriever.side_effect = Exception(
        "Test error"
    )
    mock_auto_retriever_class.return_value = failing_retriever

    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )
        answer = rag.information_retrieval(
            query="What is the content?",
            contents="/path/to/document.pdf",
        )

    # Both the generic marker and the original exception message must appear.
    for fragment in ("Error", "Test error"):
        assert fragment in answer
|
|
|
|
|
|
@patch("app.agent.toolkit.rag_toolkit.AutoRetriever")
def test_information_retrieval_with_list_contents(
    mock_auto_retriever_class, temp_storage_path
):
    """Retrieval accepts a list of sources and still queries only once."""
    retriever = MagicMock()
    retriever.run_vector_retriever.return_value = {
        "text": ["Combined results from multiple sources"]
    }
    mock_auto_retriever_class.return_value = retriever

    sources = ["/path/to/doc1.pdf", "/path/to/doc2.pdf"]
    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        rag = RAGToolkit(
            api_task_id="test-task",
            storage_path=temp_storage_path,
        )
        answer = rag.information_retrieval(
            query="What is the content?",
            contents=sources,
        )

    assert isinstance(answer, str)
    retriever.run_vector_retriever.assert_called_once()
|