diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 000000000..732411249 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,28 @@ +name: Pytest CI + +on: + push: + branches: [ feature/memory-saved-after-hook ] + pull_request: + branches: [ development ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pytest-asyncio pyyaml + + - name: Run memory hook tests + run: | + pytest tests/test_memory_hook.py diff --git a/plugins/_memory/helpers/memory.py b/plugins/_memory/helpers/memory.py index 43518338e..b7d8fd968 100644 --- a/plugins/_memory/helpers/memory.py +++ b/plugins/_memory/helpers/memory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import Any, List, Sequence from langchain.storage import InMemoryByteStore, LocalFileStore @@ -16,9 +18,8 @@ from langchain_community.docstore.in_memory import InMemoryDocstore from langchain_community.vectorstores.utils import ( DistanceStrategy, ) -from langchain_core.embeddings import Embeddings - -import os, json +import os +import json import numpy as np @@ -392,9 +393,34 @@ class Memory: self._save_db() # persist return rem_docs - async def insert_text(self, text, metadata: dict = {}): - doc = Document(text, metadata=metadata) + async def insert_text(self, text, metadata: dict | None = None): + from helpers.extension import call_extensions_async + + metadata = metadata or {} + + # memory_save_before: pass mutable object so extensions can edit or skip + obj = {"text": text, "metadata": metadata, "memory_subdir": self.memory_subdir} + await call_extensions_async( + "memory_save_before", + agent=getattr(self, "agent", None), + object=obj, + ) + + # If an extension set text to None, skip the save + if obj["text"] is None: + return None + + doc = Document(obj["text"], metadata=obj["metadata"]) ids = await self.insert_documents([doc]) + + # memory_save_after: notify extensions after successful persist + obj["doc_id"] = ids[0] + await call_extensions_async( + "memory_save_after", + agent=getattr(self, "agent", None), + object=obj, + ) + return ids[0] async def insert_documents(self, docs: list[Document]): @@ -575,3 +601,4 @@ def get_knowledge_subdirs_by_memory_subdir( default.append(get_project_meta(memory_subdir[9:], "knowledge")) return default + diff --git a/tests/test_memory_hook.py b/tests/test_memory_hook.py new file mode 100644 index 000000000..def90ec22 --- /dev/null +++ b/tests/test_memory_hook.py @@ -0,0 +1,149 @@ +import sys +import pytest +from unittest.mock import patch, AsyncMock, MagicMock, call +from types import SimpleNamespace + +# Catch-all mock importer to avoid heavy Agent Zero dependencies locally +from importlib.machinery import ModuleSpec + +class MockLoader: + def create_module(self, spec): + if spec.name in sys.modules: + return sys.modules[spec.name] + mock = MagicMock() + # Ensure submodules can be accessed via attributes + mock.__path__ = [] + sys.modules[spec.name] = mock + return mock + + def exec_module(self, module): + pass + +class MockImporter: + def find_spec(self, fullname, path, target=None): + catch_prefixes = [ + 'langchain', 'faiss', 'simpleeval', 'webcolors', 'litellm', + 'openai', 'cryptography', 'nest_asyncio', 'whisper', 'git', + 'tiktoken', 'browser_use', 'docker', 'duckduckgo_search', 'bs4', + 'html2text', 'yaml', 'aiohttp', 'jinja2', 'markdown', 'requests', + 'sentence_transformers', 'regex', 'pydantic', 'rich', 'pymupdf', + 'playwright', 'pathspec', 'tenacity', 'dotenv' + ] + if any(fullname.startswith(p) for p in catch_prefixes): + return ModuleSpec(fullname, MockLoader(), is_package=True) + return None + +sys.meta_path.insert(0, MockImporter()) + +@pytest.fixture +def mock_memory(): + from plugins._memory.helpers.memory import Memory + # Create a dummy Memory object bypassing init to avoid Faiss overhead + mem = Memory.__new__(Memory) + mem.memory_subdir = "test_subdir" + mem.agent = SimpleNamespace(name="TestAgent") + # Mock insert_documents since we only test the hook behavior + mem.insert_documents = AsyncMock(return_value=["doc-123"]) + return mem + + +@pytest.mark.asyncio +async def test_memory_save_before_called_with_object(mock_memory): + """memory_save_before receives a mutable {object} dict.""" + text = "Hello world" + metadata = {"source": "test"} + + with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext: + doc_id = await mock_memory.insert_text(text, metadata=metadata) + + assert doc_id == "doc-123" + # memory_save_before is the first call + before_call = mock_call_ext.call_args_list[0] + assert before_call == call( + "memory_save_before", + agent=mock_memory.agent, + object={"text": text, "metadata": metadata, "memory_subdir": "test_subdir"}, + ) + + +@pytest.mark.asyncio +async def test_memory_save_after_called_with_doc_id(mock_memory): + """memory_save_after receives the object with doc_id after persist.""" + text = "Hello world" + metadata = {"source": "test"} + + with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext: + doc_id = await mock_memory.insert_text(text, metadata=metadata) + + assert doc_id == "doc-123" + # memory_save_after is the second call + after_call = mock_call_ext.call_args_list[1] + assert after_call == call( + "memory_save_after", + agent=mock_memory.agent, + object={ + "text": text, + "metadata": metadata, + "memory_subdir": "test_subdir", + "doc_id": "doc-123", + }, + ) + + +@pytest.mark.asyncio +async def test_memory_save_skipped_when_text_none(mock_memory): + """Save is skipped when memory_save_before sets object['text'] to None.""" + text = "Hello world" + metadata = {"source": "test"} + + async def nullify_text(*args, **kwargs): + # Simulate an extension setting text to None + obj = kwargs.get("object") + if obj is not None: + obj["text"] = None + + with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext: + mock_call_ext.side_effect = nullify_text + doc_id = await mock_memory.insert_text(text, metadata=metadata) + + # Save was skipped + assert doc_id is None + # insert_documents was never called + mock_memory.insert_documents.assert_not_called() + # Only memory_save_before was called (no after) + assert mock_call_ext.call_count == 1 + + +@pytest.mark.asyncio +async def test_memory_save_before_can_modify_text(mock_memory): + """Extensions can modify the text via memory_save_before.""" + text = "Original" + metadata = {"source": "test"} + + async def modify_text(*args, **kwargs): + obj = kwargs.get("object") + if obj is not None and obj.get("text") == "Original": + obj["text"] = "Modified by extension" + + with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext: + mock_call_ext.side_effect = modify_text + doc_id = await mock_memory.insert_text(text, metadata=metadata) + + assert doc_id == "doc-123" + # Verify insert_documents was called with the modified text + call_args = mock_memory.insert_documents.call_args + doc = call_args[0][0][0] + assert doc.page_content == "Modified by extension" + + +@pytest.mark.asyncio +async def test_extension_exceptions_propagate(mock_memory): + """No try/catch — extension errors propagate to the caller.""" + text = "Hello world" + metadata = {"source": "test"} + + with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext: + mock_call_ext.side_effect = RuntimeError("Extension crashed") + + with pytest.raises(RuntimeError, match="Extension crashed"): + await mock_memory.insert_text(text, metadata=metadata)