Mirror of https://github.com/agent0ai/agent-zero.git, synced 2026-04-28 03:30:23 +00:00
memory: harden FAISS integrity and consolidation scoring
- Add FAISS index integrity checks using a SHA-256 sidecar (`index.faiss.sha256`) and write the hash on save (the pattern is sketched below).
- Harden `memory_load` filter evaluation with input validation (allowlist + length cap) and `simple_eval(..., functions={})`.
- Add score-preserving similarity search and use real relevance scores in consolidation (including best-score dedupe by memory id).
- Prevent utility-model context overflows by truncating memorize input history for fragments and solutions.
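
The integrity check in the first bullet follows a common sidecar-hash pattern. Below is a minimal, self-contained sketch of that pattern for orientation; the helper names and the sample path are placeholders, not the repository's actual functions (the real implementation lives in the `Memory` class diff further down).

~~~python
# Illustrative sketch of the sidecar-hash pattern (placeholder names/paths,
# not the repository's actual helpers).
import hashlib
import os


def write_sidecar_hash(index_path: str) -> None:
    """Hash the index file and store the digest next to it."""
    h = hashlib.sha256()
    with open(index_path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    with open(index_path + ".sha256", "w") as f:
        f.write(h.hexdigest())


def verify_sidecar_hash(index_path: str) -> bool:
    """Return False only when a sidecar exists and does not match the file."""
    sidecar = index_path + ".sha256"
    if not os.path.exists(sidecar):
        return True  # no sidecar yet: nothing to compare against
    with open(sidecar) as f:
        stored = f.read().strip()
    h = hashlib.sha256()
    with open(index_path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest() == stored


if __name__ == "__main__":
    # write on save, verify on load -- the same round trip the commit adds
    path = "index.faiss"  # placeholder path for the demo
    with open(path, "wb") as f:
        f.write(b"fake index bytes")
    write_sidecar_hash(path)
    assert verify_sidecar_hash(path)
~~~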
Parent: 5dc589486a
Commit: 1cbecc241e
5 changed files with 111 additions and 63 deletions
@@ -50,6 +50,10 @@ class MemorizeMemories(Extension):
         # get system message and chat history for util llm
         system = self.agent.read_prompt("memory.memories_sum.sys.md")
         msgs_text = self.agent.concat_messages(self.agent.history)
+        # Keep only recent context to avoid utility-model context-window overflow.
+        MAX_MSGS_CHARS = 80000
+        if len(msgs_text) > MAX_MSGS_CHARS:
+            msgs_text = msgs_text[-MAX_MSGS_CHARS:]
 
         # # log query streamed by LLM
         # async def log_callback(content):
@@ -51,6 +51,10 @@ class MemorizeSolutions(Extension):
         # get system message and chat history for util llm
         system = self.agent.read_prompt("memory.solutions_sum.sys.md")
         msgs_text = self.agent.concat_messages(self.agent.history)
+        # Keep only recent context to avoid utility-model context-window overflow.
+        MAX_MSGS_CHARS = 80000
+        if len(msgs_text) > MAX_MSGS_CHARS:
+            msgs_text = msgs_text[-MAX_MSGS_CHARS:]
 
         # log query streamed by LLM
         # async def log_callback(content):
@@ -18,7 +18,7 @@ from langchain_community.vectorstores.utils import (
 )
 from langchain_core.embeddings import Embeddings
 
-import os, json
+import os, json, hashlib, re
 
 import numpy as np
 
@@ -178,14 +178,19 @@ class Memory:
 
         # if db folder exists and is not empty:
         if os.path.exists(db_dir) and files.exists(db_dir, "index.faiss"):
-            db = MyFaiss.load_local(
-                folder_path=db_dir,
-                embeddings=embedder,
-                allow_dangerous_deserialization=True,
-                distance_strategy=DistanceStrategy.COSINE,
-                # normalize_L2=True,
-                relevance_score_fn=Memory._cosine_normalizer,
-            )  # type: ignore
+            if not Memory._verify_index_hash(db_dir):
+                PrintStyle(font_color="yellow").print(
+                    f"FAISS index hash mismatch in '{db_dir}' — index will be rebuilt."
+                )
+            else:
+                db = MyFaiss.load_local(
+                    folder_path=db_dir,
+                    embeddings=embedder,
+                    allow_dangerous_deserialization=True,
+                    distance_strategy=DistanceStrategy.COSINE,
+                    # normalize_L2=True,
+                    relevance_score_fn=Memory._cosine_normalizer,
+                )  # type: ignore
 
         # if there is a mismatch in embeddings used, re-index the whole DB
         emb_ok = False
@@ -345,6 +350,18 @@ class Memory:
             filter=comparator,
         )
 
+    async def search_similarity_threshold_with_scores(
+        self, query: str, limit: int, threshold: float, filter: str = ""
+    ) -> list[tuple[Document, float]]:
+        comparator = Memory._get_comparator(filter) if filter else None
+
+        return await self.db.asimilarity_search_with_relevance_scores(
+            query,
+            k=limit,
+            score_threshold=threshold,
+            filter=comparator,
+        )
+
     async def delete_documents_by_query(
         self, query: str, threshold: float, filter: str = ""
     ):
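Usage note for the new method above: it returns `(Document, score)` tuples instead of bare documents, so callers can keep the real relevance score alongside each hit. A hedged sketch of a caller follows; the `memory` handle and the query/limit/threshold values are illustrative, not taken from the repository.

~~~python
# Hedged usage sketch -- `memory` stands in for an initialized Memory instance;
# the query, limit, threshold, and filter values are illustrative.
async def show_scored_matches(memory):
    results = await memory.search_similarity_threshold_with_scores(
        query="how did we configure the scheduler?",
        limit=10,
        threshold=0.7,
        filter="area == 'main'",
    )
    for doc, score in results:
        # the real relevance score travels with the document for later ranking/dedupe
        doc.metadata["_consolidation_similarity"] = score
        print(f"{score:.3f}  {doc.page_content[:60]}")
~~~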
@@ -432,12 +449,54 @@ class Memory:
     def _save_db_file(db: MyFaiss, memory_subdir: str):
         abs_dir = abs_db_dir(memory_subdir)
         db.save_local(folder_path=abs_dir)
+        Memory._write_index_hash(abs_dir)
+
+    @staticmethod
+    def _write_index_hash(abs_dir: str) -> None:
+        faiss_path = os.path.join(abs_dir, "index.faiss")
+        hash_path = os.path.join(abs_dir, "index.faiss.sha256")
+        try:
+            h = hashlib.sha256()
+            with open(faiss_path, "rb") as f:
+                for chunk in iter(lambda: f.read(65536), b""):
+                    h.update(chunk)
+            with open(hash_path, "w") as f:
+                f.write(h.hexdigest())
+        except Exception as e:
+            PrintStyle(font_color="yellow").print(f"Warning: could not write FAISS hash: {e}")
+
+    @staticmethod
+    def _verify_index_hash(abs_dir: str) -> bool:
+        faiss_path = os.path.join(abs_dir, "index.faiss")
+        hash_path = os.path.join(abs_dir, "index.faiss.sha256")
+        if not os.path.exists(hash_path):
+            return True
+        try:
+            with open(hash_path, "r") as f:
+                stored = f.read().strip()
+            h = hashlib.sha256()
+            with open(faiss_path, "rb") as f:
+                for chunk in iter(lambda: f.read(65536), b""):
+                    h.update(chunk)
+            return h.hexdigest() == stored
+        except Exception as e:
+            PrintStyle(font_color="yellow").print(f"Warning: FAISS hash check failed: {e}")
+            return True
 
     @staticmethod
     def _get_comparator(condition: str):
+        _FILTER_SAFE = re.compile(
+            r"^[a-zA-Z0-9_\-\.\ \t'\"=<>!()\[\],:\+]+$"
+        )
+        if len(condition) > 512 or not _FILTER_SAFE.match(condition):
+            PrintStyle.error(
+                f"Memory filter rejected (unsafe characters or too long): {condition!r}"
+            )
+            return lambda _data: False
+
         def comparator(data: dict[str, Any]):
             try:
-                result = simple_eval(condition, names=data)
+                result = simple_eval(condition, names=data, functions={})
                 return result
             except Exception as e:
                 PrintStyle.error(f"Error evaluating condition: {e}")
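To illustrate the two hardening layers added to `_get_comparator` above (character allowlist plus `simple_eval` with an empty `functions` table), here is a small standalone demo. The regex is copied from the diff; the helper name, sample metadata, and expressions are made up for the example.

~~~python
# Demo of the two-layer filter hardening: allowlist + simple_eval with no functions.
# Requires the simpleeval package (pip install simpleeval).
import re
from simpleeval import simple_eval

_FILTER_SAFE = re.compile(r"^[a-zA-Z0-9_\-\.\ \t'\"=<>!()\[\],:\+]+$")


def evaluate_filter(condition: str, data: dict) -> bool:
    # Layer 1: length cap + character allowlist
    if len(condition) > 512 or not _FILTER_SAFE.match(condition):
        return False
    # Layer 2: no callables are exposed to the expression
    try:
        return bool(simple_eval(condition, names=data, functions={}))
    except Exception:
        return False


meta = {"area": "main", "timestamp": "2025-01-01"}
print(evaluate_filter("area == 'main' and timestamp > '2024-06-01'", meta))  # True
print(evaluate_filter("area == 'main'; import os", meta))   # False: ';' fails the allowlist
print(evaluate_filter("open('secrets.txt')", meta))         # False: no functions exposed
~~~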
@@ -36,7 +36,7 @@ class ConsolidationConfig:
     keyword_extraction_msg_prompt: str = "memory.keyword_extraction.msg.md"
     processing_timeout_seconds: int = 60
     # Add safety threshold for REPLACE actions
-    replace_similarity_threshold: float = 0.9  # Higher threshold for replacement safety
+    replace_similarity_threshold: float = 0.75  # Threshold tuned for real cosine similarity scores
 
 
 @dataclass
@@ -336,73 +336,52 @@ class MemoryConsolidator:
 
         all_similar = []
 
-        # Step 2: Semantic similarity search with scores
-        semantic_similar = await db.search_similarity_threshold(
+        # Step 2: Semantic similarity search with real scores
+        semantic_results = await db.search_similarity_threshold_with_scores(
             query=new_memory,
             limit=self.config.max_similar_memories,
             threshold=self.config.similarity_threshold,
             filter=f"area == '{area}'"
         )
-        all_similar.extend(semantic_similar)
+        for doc, score in semantic_results:
+            doc.metadata['_consolidation_similarity'] = score
+            all_similar.append(doc)
 
-        # Step 3: Keyword-based searches
+        # Step 3: Keyword-based searches with real scores
        for query in search_queries:
             if query.strip():
-                # Fix division by zero: ensure len(search_queries) > 0
-                queries_count = max(1, len(search_queries))  # Prevent division by zero
-                keyword_similar = await db.search_similarity_threshold(
+                queries_count = max(1, len(search_queries))
+                keyword_results = await db.search_similarity_threshold_with_scores(
                     query=query.strip(),
                     limit=max(3, self.config.max_similar_memories // queries_count),
                     threshold=self.config.similarity_threshold,
                     filter=f"area == '{area}'"
                 )
-                all_similar.extend(keyword_similar)
+                for doc, score in keyword_results:
+                    doc.metadata['_consolidation_similarity'] = score
+                    all_similar.append(doc)
 
-        # Step 4: Deduplicate by document ID and store similarity info
-        seen_ids = set()
-        unique_similar = []
+        # Step 4: Deduplicate by document ID, keep highest score per memory ID
+        best_by_id: Dict[str, Document] = {}
         for doc in all_similar:
             doc_id = doc.metadata.get('id')
-            if doc_id and doc_id not in seen_ids:
-                seen_ids.add(doc_id)
-                unique_similar.append(doc)
-
-        # Step 5: Calculate similarity scores for replacement validation
-        # Since FAISS doesn't directly expose similarity scores, use ranking-based estimation
-        # CRITICAL: All documents must have similarity >= search_threshold since FAISS returned them
-        # FIXED: Use conservative scoring that keeps all scores in safe consolidation range
-        similarity_scores = {}
-        total_docs = len(unique_similar)
-        search_threshold = self.config.similarity_threshold
-        safety_threshold = self.config.replace_similarity_threshold
-
-        for i, doc in enumerate(unique_similar):
-            doc_id = doc.metadata.get('id')
             if doc_id:
-                # Convert ranking to similarity score with conservative distribution
-                if total_docs == 1:
-                    ranking_similarity = 1.0  # Single document gets perfect score
-                else:
-                    # Use conservative scoring: distribute between safety_threshold and 1.0
-                    # This ensures all scores are suitable for consolidation
-                    # First document gets 1.0, last gets safety_threshold (0.9 by default)
-                    ranking_factor = 1.0 - (i / (total_docs - 1))
-                    score_range = 1.0 - safety_threshold  # e.g., 1.0 - 0.9 = 0.1
-                    ranking_similarity = safety_threshold + (score_range * ranking_factor)
+                existing = best_by_id.get(doc_id)
+                if (
+                    existing is None
+                    or doc.metadata.get('_consolidation_similarity', 0)
+                    > existing.metadata.get('_consolidation_similarity', 0)
+                ):
+                    best_by_id[doc_id] = doc
+        unique_similar = list(best_by_id.values())
 
-                # Ensure minimum score is search_threshold for logical consistency
-                ranking_similarity = max(ranking_similarity, search_threshold)
+        # Step 5: Sort by similarity score descending
+        unique_similar.sort(
+            key=lambda d: d.metadata.get('_consolidation_similarity', 0),
+            reverse=True
+        )
 
-                similarity_scores[doc_id] = ranking_similarity
-
-        # Step 6: Add similarity score to document metadata for LLM analysis
-        for doc in unique_similar:
-            doc_id = doc.metadata.get('id')
-            estimated_similarity = similarity_scores.get(doc_id, 0.7)
-            # Store for later validation
-            doc.metadata['_consolidation_similarity'] = estimated_similarity
-
-        # Step 7: Limit to max context for LLM
+        # Step 6: Limit to max context for LLM
         limited_similar = unique_similar[:self.config.max_llm_context_memories]
 
         return limited_similar
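For intuition about the dedupe rule in Step 4 above (keep the highest-scoring hit per memory id, then sort descending), here is a tiny self-contained demo; the ids and scores are made up.

~~~python
# Made-up (id, score) hits from overlapping semantic + keyword searches.
hits = [("m1", 0.71), ("m2", 0.90), ("m1", 0.83), ("m3", 0.76), ("m2", 0.74)]

best: dict[str, float] = {}
for doc_id, score in hits:
    # keep the best score seen for each memory id
    if doc_id not in best or score > best[doc_id]:
        best[doc_id] = score

ranked = sorted(best.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)  # [('m2', 0.9), ('m1', 0.83), ('m3', 0.76)]
~~~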
@@ -782,7 +761,7 @@ def create_memory_consolidator(agent: Agent, **config_overrides) -> MemoryConsol
 
     Available configuration options:
     - similarity_threshold: Discovery threshold for finding related memories (default 0.7)
-    - replace_similarity_threshold: Safety threshold for REPLACE actions (default 0.9)
+    - replace_similarity_threshold: Safety threshold for REPLACE actions (default 0.75)
     - max_similar_memories: Maximum memories to discover (default 10)
     - max_llm_context_memories: Maximum memories to send to LLM (default 5)
     - processing_timeout_seconds: Timeout for consolidation processing (default 30)
@@ -4,14 +4,16 @@ use when durable recall or storage is useful
 - `memory_save`: args `text`, optional `area` and metadata kwargs
 - `memory_delete`: arg `ids` comma-separated ids
 - `memory_forget`: args `query`, optional `threshold`, `filter`
 
 notes:
 - `threshold` is similarity from `0` to `1`
-- `filter` is a python expression over metadata
-- verify destructive memory changes if accuracy matters
+- `filter` is a metadata expression (e.g. `area=='main'`)
+- confirm destructive changes when accuracy matters
 
 example:
 ~~~json
 {
-    "thoughts": ["I should search memory for the relevant prior guidance."],
+    "thoughts": ["I should search memory for relevant prior guidance."],
     "headline": "Loading related memories",
     "tool_name": "memory_load",
     "tool_args": {