mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-22 11:15:42 +00:00
Merge f98ace6d28 into e911e5e03a
This commit is contained in:
commit
2ea6cb5aed
3 changed files with 209 additions and 5 deletions
28
.github/workflows/pytest.yml
vendored
Normal file
28
.github/workflows/pytest.yml
vendored
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
name: Pytest CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ feature/memory-saved-after-hook ]
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-asyncio pyyaml
|
||||
|
||||
- name: Run memory hook tests
|
||||
run: |
|
||||
pytest tests/test_memory_hook.py
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Sequence
|
||||
from langchain.storage import InMemoryByteStore, LocalFileStore
|
||||
|
|
@ -16,9 +18,8 @@ from langchain_community.docstore.in_memory import InMemoryDocstore
|
|||
from langchain_community.vectorstores.utils import (
|
||||
DistanceStrategy,
|
||||
)
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
import os, json
|
||||
import os
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -392,9 +393,34 @@ class Memory:
|
|||
self._save_db() # persist
|
||||
return rem_docs
|
||||
|
||||
async def insert_text(self, text, metadata: dict = {}):
|
||||
doc = Document(text, metadata=metadata)
|
||||
async def insert_text(self, text, metadata: dict | None = None):
|
||||
from helpers.extension import call_extensions_async
|
||||
|
||||
metadata = metadata or {}
|
||||
|
||||
# memory_save_before: pass mutable object so extensions can edit or skip
|
||||
obj = {"text": text, "metadata": metadata, "memory_subdir": self.memory_subdir}
|
||||
await call_extensions_async(
|
||||
"memory_save_before",
|
||||
agent=getattr(self, "agent", None),
|
||||
object=obj,
|
||||
)
|
||||
|
||||
# If an extension set text to None, skip the save
|
||||
if obj["text"] is None:
|
||||
return None
|
||||
|
||||
doc = Document(obj["text"], metadata=obj["metadata"])
|
||||
ids = await self.insert_documents([doc])
|
||||
|
||||
# memory_save_after: notify extensions after successful persist
|
||||
obj["doc_id"] = ids[0]
|
||||
await call_extensions_async(
|
||||
"memory_save_after",
|
||||
agent=getattr(self, "agent", None),
|
||||
object=obj,
|
||||
)
|
||||
|
||||
return ids[0]
|
||||
|
||||
async def insert_documents(self, docs: list[Document]):
|
||||
|
|
@ -575,3 +601,4 @@ def get_knowledge_subdirs_by_memory_subdir(
|
|||
|
||||
default.append(get_project_meta(memory_subdir[9:], "knowledge"))
|
||||
return default
|
||||
|
||||
|
|
|
|||
149
tests/test_memory_hook.py
Normal file
149
tests/test_memory_hook.py
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
import sys
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock, call
|
||||
from types import SimpleNamespace
|
||||
|
||||
# Catch-all mock importer to avoid heavy Agent Zero dependencies locally
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
class MockLoader:
|
||||
def create_module(self, spec):
|
||||
if spec.name in sys.modules:
|
||||
return sys.modules[spec.name]
|
||||
mock = MagicMock()
|
||||
# Ensure submodules can be accessed via attributes
|
||||
mock.__path__ = []
|
||||
sys.modules[spec.name] = mock
|
||||
return mock
|
||||
|
||||
def exec_module(self, module):
|
||||
pass
|
||||
|
||||
class MockImporter:
|
||||
def find_spec(self, fullname, path, target=None):
|
||||
catch_prefixes = [
|
||||
'langchain', 'faiss', 'simpleeval', 'webcolors', 'litellm',
|
||||
'openai', 'cryptography', 'nest_asyncio', 'whisper', 'git',
|
||||
'tiktoken', 'browser_use', 'docker', 'duckduckgo_search', 'bs4',
|
||||
'html2text', 'yaml', 'aiohttp', 'jinja2', 'markdown', 'requests',
|
||||
'sentence_transformers', 'regex', 'pydantic', 'rich', 'pymupdf',
|
||||
'playwright', 'pathspec', 'tenacity', 'dotenv'
|
||||
]
|
||||
if any(fullname.startswith(p) for p in catch_prefixes):
|
||||
return ModuleSpec(fullname, MockLoader(), is_package=True)
|
||||
return None
|
||||
|
||||
sys.meta_path.insert(0, MockImporter())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory():
|
||||
from plugins._memory.helpers.memory import Memory
|
||||
# Create a dummy Memory object bypassing init to avoid Faiss overhead
|
||||
mem = Memory.__new__(Memory)
|
||||
mem.memory_subdir = "test_subdir"
|
||||
mem.agent = SimpleNamespace(name="TestAgent")
|
||||
# Mock insert_documents since we only test the hook behavior
|
||||
mem.insert_documents = AsyncMock(return_value=["doc-123"])
|
||||
return mem
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_save_before_called_with_object(mock_memory):
|
||||
"""memory_save_before receives a mutable {object} dict."""
|
||||
text = "Hello world"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
|
||||
doc_id = await mock_memory.insert_text(text, metadata=metadata)
|
||||
|
||||
assert doc_id == "doc-123"
|
||||
# memory_save_before is the first call
|
||||
before_call = mock_call_ext.call_args_list[0]
|
||||
assert before_call == call(
|
||||
"memory_save_before",
|
||||
agent=mock_memory.agent,
|
||||
object={"text": text, "metadata": metadata, "memory_subdir": "test_subdir"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_save_after_called_with_doc_id(mock_memory):
|
||||
"""memory_save_after receives the object with doc_id after persist."""
|
||||
text = "Hello world"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
|
||||
doc_id = await mock_memory.insert_text(text, metadata=metadata)
|
||||
|
||||
assert doc_id == "doc-123"
|
||||
# memory_save_after is the second call
|
||||
after_call = mock_call_ext.call_args_list[1]
|
||||
assert after_call == call(
|
||||
"memory_save_after",
|
||||
agent=mock_memory.agent,
|
||||
object={
|
||||
"text": text,
|
||||
"metadata": metadata,
|
||||
"memory_subdir": "test_subdir",
|
||||
"doc_id": "doc-123",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_save_skipped_when_text_none(mock_memory):
|
||||
"""Save is skipped when memory_save_before sets object['text'] to None."""
|
||||
text = "Hello world"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
async def nullify_text(*args, **kwargs):
|
||||
# Simulate an extension setting text to None
|
||||
obj = kwargs.get("object")
|
||||
if obj is not None:
|
||||
obj["text"] = None
|
||||
|
||||
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
|
||||
mock_call_ext.side_effect = nullify_text
|
||||
doc_id = await mock_memory.insert_text(text, metadata=metadata)
|
||||
|
||||
# Save was skipped
|
||||
assert doc_id is None
|
||||
# insert_documents was never called
|
||||
mock_memory.insert_documents.assert_not_called()
|
||||
# Only memory_save_before was called (no after)
|
||||
assert mock_call_ext.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_save_before_can_modify_text(mock_memory):
|
||||
"""Extensions can modify the text via memory_save_before."""
|
||||
text = "Original"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
async def modify_text(*args, **kwargs):
|
||||
obj = kwargs.get("object")
|
||||
if obj is not None and obj.get("text") == "Original":
|
||||
obj["text"] = "Modified by extension"
|
||||
|
||||
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
|
||||
mock_call_ext.side_effect = modify_text
|
||||
doc_id = await mock_memory.insert_text(text, metadata=metadata)
|
||||
|
||||
assert doc_id == "doc-123"
|
||||
# Verify insert_documents was called with the modified text
|
||||
call_args = mock_memory.insert_documents.call_args
|
||||
doc = call_args[0][0][0]
|
||||
assert doc.page_content == "Modified by extension"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extension_exceptions_propagate(mock_memory):
|
||||
"""No try/catch — extension errors propagate to the caller."""
|
||||
text = "Hello world"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
|
||||
mock_call_ext.side_effect = RuntimeError("Extension crashed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Extension crashed"):
|
||||
await mock_memory.insert_text(text, metadata=metadata)
|
||||
Loading…
Add table
Add a link
Reference in a new issue