This commit is contained in:
George Freeney Jr. 2026-05-14 08:51:35 +08:00 committed by GitHub
commit 2ea6cb5aed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 209 additions and 5 deletions

28
.github/workflows/pytest.yml vendored Normal file
View file

@ -0,0 +1,28 @@
name: Pytest CI
on:
push:
branches: [ feature/memory-saved-after-hook ]
pull_request:
branches: [ development ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-asyncio pyyaml
- name: Run memory hook tests
run: |
pytest tests/test_memory_hook.py

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, List, Sequence
from langchain.storage import InMemoryByteStore, LocalFileStore
@ -16,9 +18,8 @@ from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores.utils import (
DistanceStrategy,
)
from langchain_core.embeddings import Embeddings
import os, json
import os
import json
import numpy as np
@ -392,9 +393,34 @@ class Memory:
self._save_db() # persist
return rem_docs
async def insert_text(self, text, metadata: dict = {}):
doc = Document(text, metadata=metadata)
async def insert_text(self, text, metadata: dict | None = None):
from helpers.extension import call_extensions_async
metadata = metadata or {}
# memory_save_before: pass mutable object so extensions can edit or skip
obj = {"text": text, "metadata": metadata, "memory_subdir": self.memory_subdir}
await call_extensions_async(
"memory_save_before",
agent=getattr(self, "agent", None),
object=obj,
)
# If an extension set text to None, skip the save
if obj["text"] is None:
return None
doc = Document(obj["text"], metadata=obj["metadata"])
ids = await self.insert_documents([doc])
# memory_save_after: notify extensions after successful persist
obj["doc_id"] = ids[0]
await call_extensions_async(
"memory_save_after",
agent=getattr(self, "agent", None),
object=obj,
)
return ids[0]
async def insert_documents(self, docs: list[Document]):
@ -575,3 +601,4 @@ def get_knowledge_subdirs_by_memory_subdir(
default.append(get_project_meta(memory_subdir[9:], "knowledge"))
return default

149
tests/test_memory_hook.py Normal file
View file

@ -0,0 +1,149 @@
import sys
import pytest
from unittest.mock import patch, AsyncMock, MagicMock, call
from types import SimpleNamespace
# Catch-all mock importer to avoid heavy Agent Zero dependencies locally
from importlib.machinery import ModuleSpec
class MockLoader:
def create_module(self, spec):
if spec.name in sys.modules:
return sys.modules[spec.name]
mock = MagicMock()
# Ensure submodules can be accessed via attributes
mock.__path__ = []
sys.modules[spec.name] = mock
return mock
def exec_module(self, module):
pass
class MockImporter:
def find_spec(self, fullname, path, target=None):
catch_prefixes = [
'langchain', 'faiss', 'simpleeval', 'webcolors', 'litellm',
'openai', 'cryptography', 'nest_asyncio', 'whisper', 'git',
'tiktoken', 'browser_use', 'docker', 'duckduckgo_search', 'bs4',
'html2text', 'yaml', 'aiohttp', 'jinja2', 'markdown', 'requests',
'sentence_transformers', 'regex', 'pydantic', 'rich', 'pymupdf',
'playwright', 'pathspec', 'tenacity', 'dotenv'
]
if any(fullname.startswith(p) for p in catch_prefixes):
return ModuleSpec(fullname, MockLoader(), is_package=True)
return None
sys.meta_path.insert(0, MockImporter())
@pytest.fixture
def mock_memory():
from plugins._memory.helpers.memory import Memory
# Create a dummy Memory object bypassing init to avoid Faiss overhead
mem = Memory.__new__(Memory)
mem.memory_subdir = "test_subdir"
mem.agent = SimpleNamespace(name="TestAgent")
# Mock insert_documents since we only test the hook behavior
mem.insert_documents = AsyncMock(return_value=["doc-123"])
return mem
@pytest.mark.asyncio
async def test_memory_save_before_called_with_object(mock_memory):
"""memory_save_before receives a mutable {object} dict."""
text = "Hello world"
metadata = {"source": "test"}
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
doc_id = await mock_memory.insert_text(text, metadata=metadata)
assert doc_id == "doc-123"
# memory_save_before is the first call
before_call = mock_call_ext.call_args_list[0]
assert before_call == call(
"memory_save_before",
agent=mock_memory.agent,
object={"text": text, "metadata": metadata, "memory_subdir": "test_subdir"},
)
@pytest.mark.asyncio
async def test_memory_save_after_called_with_doc_id(mock_memory):
"""memory_save_after receives the object with doc_id after persist."""
text = "Hello world"
metadata = {"source": "test"}
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
doc_id = await mock_memory.insert_text(text, metadata=metadata)
assert doc_id == "doc-123"
# memory_save_after is the second call
after_call = mock_call_ext.call_args_list[1]
assert after_call == call(
"memory_save_after",
agent=mock_memory.agent,
object={
"text": text,
"metadata": metadata,
"memory_subdir": "test_subdir",
"doc_id": "doc-123",
},
)
@pytest.mark.asyncio
async def test_memory_save_skipped_when_text_none(mock_memory):
"""Save is skipped when memory_save_before sets object['text'] to None."""
text = "Hello world"
metadata = {"source": "test"}
async def nullify_text(*args, **kwargs):
# Simulate an extension setting text to None
obj = kwargs.get("object")
if obj is not None:
obj["text"] = None
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
mock_call_ext.side_effect = nullify_text
doc_id = await mock_memory.insert_text(text, metadata=metadata)
# Save was skipped
assert doc_id is None
# insert_documents was never called
mock_memory.insert_documents.assert_not_called()
# Only memory_save_before was called (no after)
assert mock_call_ext.call_count == 1
@pytest.mark.asyncio
async def test_memory_save_before_can_modify_text(mock_memory):
"""Extensions can modify the text via memory_save_before."""
text = "Original"
metadata = {"source": "test"}
async def modify_text(*args, **kwargs):
obj = kwargs.get("object")
if obj is not None and obj.get("text") == "Original":
obj["text"] = "Modified by extension"
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
mock_call_ext.side_effect = modify_text
doc_id = await mock_memory.insert_text(text, metadata=metadata)
assert doc_id == "doc-123"
# Verify insert_documents was called with the modified text
call_args = mock_memory.insert_documents.call_args
doc = call_args[0][0][0]
assert doc.page_content == "Modified by extension"
@pytest.mark.asyncio
async def test_extension_exceptions_propagate(mock_memory):
"""No try/catch — extension errors propagate to the caller."""
text = "Hello world"
metadata = {"source": "test"}
with patch("helpers.extension.call_extensions", new_callable=AsyncMock) as mock_call_ext:
mock_call_ext.side_effect = RuntimeError("Extension crashed")
with pytest.raises(RuntimeError, match="Extension crashed"):
await mock_memory.insert_text(text, metadata=metadata)