diff --git a/adapters/langchain/firewall.py b/adapters/langchain/firewall.py index 458a5677..acd93346 100644 --- a/adapters/langchain/firewall.py +++ b/adapters/langchain/firewall.py @@ -31,14 +31,13 @@ class WFGYSemanticFirewall(BaseCallbackHandler): cos_theta = dot_product / (norm_a * norm_b) return 1.0 - cos_theta - def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any: - """Capture query embedding early — runs before on_retriever_end.""" - query_text = inputs.get('query') or inputs.get('question') or inputs.get('input', '') - if query_text and not self.last_query_embedding: - self.last_query_embedding = self.embedding_model.embed_query(str(query_text)) + def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any: + # Capture query embedding right before retriever runs + if query and not self.last_query_embedding: + self.last_query_embedding = self.embedding_model.embed_query(query) def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any: - # Fallback: capture query embedding if not yet set by on_chain_start + # Fallback for non-retrieval chains if prompts and not self.last_query_embedding: self.last_query_embedding = self.embedding_model.embed_query(prompts[0])