fix: minimize runtime intervention by scoping query capture to retriever

Replace over-broad on_chain_start with on_retriever_start to capture
the query embedding as late and as safely as possible in the RetrievalQA
flow, minimizing framework footprint.

Addresses maintainer review feedback for minimal & safe implementation.
This commit is contained in:
Zaious 2026-03-11 11:45:49 +08:00
parent b53de5acd2
commit 0d8a6437a0

View file

@ -31,14 +31,13 @@ class WFGYSemanticFirewall(BaseCallbackHandler):
cos_theta = dot_product / (norm_a * norm_b)
return 1.0 - cos_theta
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Capture query embedding early — runs before on_retriever_end."""
query_text = inputs.get('query') or inputs.get('question') or inputs.get('input', '')
if query_text and not self.last_query_embedding:
self.last_query_embedding = self.embedding_model.embed_query(str(query_text))
def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any:
# Capture query embedding right before retriever runs
if query and not self.last_query_embedding:
self.last_query_embedding = self.embedding_model.embed_query(query)
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
# Fallback: capture query embedding if not yet set by on_chain_start
# Fallback for non-retrieval chains
if prompts and not self.last_query_embedding:
self.last_query_embedding = self.embedding_model.embed_query(prompts[0])