fix: capture query embedding before retriever callback and guard empty source_nodes

LangChain: Add on_chain_start to capture query embedding before
on_retriever_end fires in RetrievalQA flow. on_llm_start retained
as fallback for non-chain invocations.

LlamaIndex: Double guard on source_nodes — check both attribute
existence AND non-empty list before computing average ΔS, preventing
ZeroDivisionError on responses without retrieved nodes.

Addresses Codex automated review feedback
This commit is contained in:
Zaious 2026-03-10 09:14:53 +08:00
parent cf78311a25
commit b53de5acd2
2 changed files with 15 additions and 8 deletions

View file

@@ -31,9 +31,15 @@ class WFGYSemanticFirewall(BaseCallbackHandler):
cos_theta = dot_product / (norm_a * norm_b)
return 1.0 - cos_theta
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
    """Cache the query embedding as early as possible in the chain.

    Fires before ``on_retriever_end`` in the RetrievalQA flow, so the
    embedding is already available when retrieved documents are scored.
    """
    # Probe the common LangChain input keys in priority order; leave the
    # text empty when none of them carries a truthy value.
    query_text = ""
    for key in ("query", "question", "input"):
        value = inputs.get(key)
        if value:
            query_text = value
            break
    # Embed only once: an embedding already cached is never overwritten.
    # NOTE(review): a second query on the same handler instance would keep
    # the first embedding — confirm it is reset elsewhere before reuse.
    if query_text and not self.last_query_embedding:
        self.last_query_embedding = self.embedding_model.embed_query(str(query_text))
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
    """Fallback capture of the query embedding for later ΔS calculation.

    ``on_chain_start`` is the primary capture point; this hook only embeds
    the first prompt when no embedding has been cached yet (e.g. for
    non-chain invocations where ``on_chain_start`` never fires).
    """
    # Guard on both conditions: prompts may be empty, and an embedding
    # already cached by on_chain_start must not be overwritten here.
    # (The diff overlay in the source had left a dangling bodyless
    # `if prompts:` from the removed pre-change lines; this is the
    # resolved post-commit version.)
    if prompts and not self.last_query_embedding:
        self.last_query_embedding = self.embedding_model.embed_query(prompts[0])
def on_retriever_end(self, documents: List[Any], **kwargs: Any) -> Any:

View file

@@ -50,19 +50,20 @@ class WFGYSemanticFirewallLlama:
print(f"[WFGY-LlamaIndex] Answer ΔS: {delta_s:.4f}")
# Analyze Source Nodes
if hasattr(response, 'source_nodes'):
if hasattr(response, 'source_nodes') and response.source_nodes:
source_delta_s = []
for node in response.source_nodes:
node_text = node.node.get_content()
node_embedding = self.embedding_model.get_text_embedding(node_text)
source_delta_s.append(self._calculate_delta_s(query_embedding, node_embedding))
avg_source_delta_s = sum(source_delta_s) / len(source_delta_s)
if self.verbose:
print(f"[WFGY-LlamaIndex] Avg Source ΔS: {avg_source_delta_s:.4f}")
if source_delta_s:
avg_source_delta_s = sum(source_delta_s) / len(source_delta_s)
if self.verbose:
print(f"[WFGY-LlamaIndex] Avg Source ΔS: {avg_source_delta_s:.4f}")
if avg_source_delta_s >= self.threshold_danger:
print(f"⚠️ [WFGY] DANGER ZONE: Source Drift (ΔS={avg_source_delta_s:.4f})")
if avg_source_delta_s >= self.threshold_danger:
print(f"⚠️ [WFGY] DANGER ZONE: Source Drift (ΔS={avg_source_delta_s:.4f})")
return response