mirror of
https://github.com/onestardao/WFGY.git
synced 2026-04-28 03:29:51 +00:00
fix: capture query embedding before retriever callback and guard empty source_nodes
LangChain: Add on_chain_start to capture the query embedding before on_retriever_end fires in the RetrievalQA flow; on_llm_start is retained as a fallback for non-chain invocations. LlamaIndex: Double-guard source_nodes — check both that the attribute exists AND that the list is non-empty before computing the average ΔS, preventing a ZeroDivisionError on responses without retrieved nodes. Addresses Codex automated review feedback.
This commit is contained in:
parent
cf78311a25
commit
b53de5acd2
2 changed files with 15 additions and 8 deletions
|
|
@ -31,9 +31,15 @@ class WFGYSemanticFirewall(BaseCallbackHandler):
|
|||
cos_theta = dot_product / (norm_a * norm_b)
|
||||
return 1.0 - cos_theta
|
||||
|
||||
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""Capture query embedding early — runs before on_retriever_end."""
|
||||
query_text = inputs.get('query') or inputs.get('question') or inputs.get('input', '')
|
||||
if query_text and not self.last_query_embedding:
|
||||
self.last_query_embedding = self.embedding_model.embed_query(str(query_text))
|
||||
|
||||
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
|
||||
# Cache query embedding for ΔS calculation later
|
||||
if prompts:
|
||||
# Fallback: capture query embedding if not yet set by on_chain_start
|
||||
if prompts and not self.last_query_embedding:
|
||||
self.last_query_embedding = self.embedding_model.embed_query(prompts[0])
|
||||
|
||||
def on_retriever_end(self, documents: List[Any], **kwargs: Any) -> Any:
|
||||
|
|
|
|||
|
|
@ -50,19 +50,20 @@ class WFGYSemanticFirewallLlama:
|
|||
print(f"[WFGY-LlamaIndex] Answer ΔS: {delta_s:.4f}")
|
||||
|
||||
# Analyze Source Nodes
|
||||
if hasattr(response, 'source_nodes'):
|
||||
if hasattr(response, 'source_nodes') and response.source_nodes:
|
||||
source_delta_s = []
|
||||
for node in response.source_nodes:
|
||||
node_text = node.node.get_content()
|
||||
node_embedding = self.embedding_model.get_text_embedding(node_text)
|
||||
source_delta_s.append(self._calculate_delta_s(query_embedding, node_embedding))
|
||||
|
||||
avg_source_delta_s = sum(source_delta_s) / len(source_delta_s)
|
||||
if self.verbose:
|
||||
print(f"[WFGY-LlamaIndex] Avg Source ΔS: {avg_source_delta_s:.4f}")
|
||||
if source_delta_s:
|
||||
avg_source_delta_s = sum(source_delta_s) / len(source_delta_s)
|
||||
if self.verbose:
|
||||
print(f"[WFGY-LlamaIndex] Avg Source ΔS: {avg_source_delta_s:.4f}")
|
||||
|
||||
if avg_source_delta_s >= self.threshold_danger:
|
||||
print(f"⚠️ [WFGY] DANGER ZONE: Source Drift (ΔS={avg_source_delta_s:.4f})")
|
||||
if avg_source_delta_s >= self.threshold_danger:
|
||||
print(f"⚠️ [WFGY] DANGER ZONE: Source Drift (ΔS={avg_source_delta_s:.4f})")
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue