mirror of
https://github.com/onestardao/WFGY.git
synced 2026-04-26 10:40:55 +00:00
fix: minimize runtime intervention by scoping query capture to retriever
Replace over-broad on_chain_start with on_retriever_start to capture the query embedding as late and as safely as possible in the RetrievalQA flow, minimizing framework footprint. Addresses maintainer review feedback for minimal & safe implementation.
This commit is contained in:
parent
b53de5acd2
commit
0d8a6437a0
1 changed files with 5 additions and 6 deletions
|
|
@ -31,14 +31,13 @@ class WFGYSemanticFirewall(BaseCallbackHandler):
|
|||
cos_theta = dot_product / (norm_a * norm_b)
|
||||
return 1.0 - cos_theta
|
||||
|
||||
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""Capture query embedding early — runs before on_retriever_end."""
|
||||
query_text = inputs.get('query') or inputs.get('question') or inputs.get('input', '')
|
||||
if query_text and not self.last_query_embedding:
|
||||
self.last_query_embedding = self.embedding_model.embed_query(str(query_text))
|
||||
def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any:
|
||||
# Capture query embedding right before retriever runs
|
||||
if query and not self.last_query_embedding:
|
||||
self.last_query_embedding = self.embedding_model.embed_query(query)
|
||||
|
||||
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
|
||||
# Fallback: capture query embedding if not yet set by on_chain_start
|
||||
# Fallback for non-retrieval chains
|
||||
if prompts and not self.last_query_embedding:
|
||||
self.last_query_embedding = self.embedding_model.embed_query(prompts[0])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue