From 0d8a6437a068a2490962f3347a77726461ecb9aa Mon Sep 17 00:00:00 2001 From: Zaious Date: Wed, 11 Mar 2026 11:45:49 +0800 Subject: [PATCH] fix: minimize runtime intervention by scoping query capture to retriever Replace over-broad on_chain_start with on_retriever_start to capture the query embedding as late and as safely as possible in the RetrievalQA flow, minimizing framework footprint. Addresses maintainer review feedback for minimal & safe implementation. --- adapters/langchain/firewall.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/adapters/langchain/firewall.py b/adapters/langchain/firewall.py index 458a5677..acd93346 100644 --- a/adapters/langchain/firewall.py +++ b/adapters/langchain/firewall.py @@ -31,14 +31,13 @@ class WFGYSemanticFirewall(BaseCallbackHandler): cos_theta = dot_product / (norm_a * norm_b) return 1.0 - cos_theta - def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any: - """Capture query embedding early — runs before on_retriever_end.""" - query_text = inputs.get('query') or inputs.get('question') or inputs.get('input', '') - if query_text and not self.last_query_embedding: - self.last_query_embedding = self.embedding_model.embed_query(str(query_text)) + def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any: + # Capture query embedding right before retriever runs + if query and not self.last_query_embedding: + self.last_query_embedding = self.embedding_model.embed_query(query) def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any: - # Fallback: capture query embedding if not yet set by on_chain_start + # Fallback for non-retrieval chains if prompts and not self.last_query_embedding: self.last_query_embedding = self.embedding_model.embed_query(prompts[0])