From 3dceaca64ec49d4149b9ae91a0e8e66a4692d132 Mon Sep 17 00:00:00 2001 From: frdel <38891707+frdel@users.noreply.github.com> Date: Wed, 11 Feb 2026 23:04:32 +0100 Subject: [PATCH] caching and ctx window optimizations --- agent.py | 2 + models.py | 14 ++- prompts/fw.topic_summary.sys.md | 3 +- .../_50_recall_memories.py | 6 +- python/helpers/document_query.py | 3 +- python/helpers/history.py | 91 ++++++++++++------- python/helpers/memory_consolidation.py | 2 +- 7 files changed, 82 insertions(+), 39 deletions(-) diff --git a/agent.py b/agent.py index 6cb96e3a7..cffacf6e1 100644 --- a/agent.py +++ b/agent.py @@ -798,6 +798,7 @@ class Agent: response_callback: Callable[[str, str], Awaitable[None]] | None = None, reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None, background: bool = False, + explicit_caching: bool = True, ): response = "" @@ -812,6 +813,7 @@ class Agent: rate_limiter_callback=( self.rate_limiter_callback if not background else None ), + explicit_caching=explicit_caching, ) return response, reasoning diff --git a/models.py b/models.py index c8c3af9fc..bf584f763 100644 --- a/models.py +++ b/models.py @@ -316,7 +316,7 @@ class LiteLLMChatWrapper(SimpleChatModel): def _llm_type(self) -> str: return "litellm-chat" - def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]: + def _convert_messages(self, messages: List[BaseMessage], explicit_caching: bool = False) -> List[dict]: result = [] # Map LangChain message types to LiteLLM roles role_mapping = { @@ -362,6 +362,15 @@ class LiteLLMChatWrapper(SimpleChatModel): message_dict["tool_call_id"] = tool_call_id result.append(message_dict) + + if explicit_caching and result: + if result[0]["role"] == "system": + result[0]["cache_control"] = {"type": "ephemeral"} + for i in range(len(result) - 1, -1, -1): + if result[i]["role"] == "assistant": + result[i]["cache_control"] = {"type": "ephemeral"} + break + return result def _call( @@ -464,6 +473,7 @@ class LiteLLMChatWrapper(SimpleChatModel): rate_limiter_callback: ( Callable[[str, str, int, int], Awaitable[bool]] | None ) = None, + explicit_caching: bool = False, **kwargs: Any, ) -> Tuple[str, str]: @@ -478,7 +488,7 @@ class LiteLLMChatWrapper(SimpleChatModel): messages.append(HumanMessage(content=user_message)) # convert to litellm format - msgs_conv = self._convert_messages(messages) + msgs_conv = self._convert_messages(messages, explicit_caching=explicit_caching) # Apply rate limiting if configured limiter = await apply_rate_limiter( diff --git a/prompts/fw.topic_summary.sys.md b/prompts/fw.topic_summary.sys.md index 46004b15b..b8d676647 100644 --- a/prompts/fw.topic_summary.sys.md +++ b/prompts/fw.topic_summary.sys.md @@ -6,7 +6,8 @@ You must return a single summary of all records # Expected output Your output will be a text of the summary -Length of the text should be one paragraph, approximately 100 words +Summary must be shorter than original messages +Length of the text should be maximum one paragraph, approximately 100 words, shorter if original is shorter No intro No conclusion No formatting diff --git a/python/extensions/message_loop_prompts_after/_50_recall_memories.py b/python/extensions/message_loop_prompts_after/_50_recall_memories.py index 9e2beb9f8..e7fd509f8 100644 --- a/python/extensions/message_loop_prompts_after/_50_recall_memories.py +++ b/python/extensions/message_loop_prompts_after/_50_recall_memories.py @@ -8,6 +8,7 @@ from python.helpers import dirty_json, errors, settings, log DATA_NAME_TASK = "_recall_memories_task" DATA_NAME_ITER = "_recall_memories_iter" +SEARCH_TIMEOUT = 30 class RecallMemories(Extension): @@ -38,7 +39,10 @@ class RecallMemories(Extension): ) task = asyncio.create_task( - self.search_memories(loop_data=loop_data, log_item=log_item, **kwargs) + asyncio.wait_for( + self.search_memories(loop_data=loop_data, log_item=log_item, **kwargs), + timeout=SEARCH_TIMEOUT, + ) ) else: task = None diff --git a/python/helpers/document_query.py b/python/helpers/document_query.py index ff5575ec0..4ba82b159 100644 --- a/python/helpers/document_query.py +++ b/python/helpers/document_query.py @@ -433,7 +433,8 @@ class DocumentQueryHelper: messages=[ SystemMessage(content=qa_system_message), HumanMessage(content=qa_user_message), - ] + ], + explicit_caching=False, ) self.progress_callback(f"Q&A process completed") diff --git a/python/helpers/history.py b/python/helpers/history.py index d8b7d6c5f..0604bb4d9 100644 --- a/python/helpers/history.py +++ b/python/helpers/history.py @@ -10,13 +10,16 @@ from enum import Enum from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage BULK_MERGE_COUNT = 3 -TOPICS_KEEP_COUNT = 3 +TOPICS_MERGE_COUNT = 3 CURRENT_TOPIC_RATIO = 0.5 HISTORY_TOPIC_RATIO = 0.3 HISTORY_BULK_RATIO = 0.2 -TOPIC_COMPRESS_RATIO = 0.65 -LARGE_MESSAGE_TO_TOPIC_RATIO = 0.5 +CURRENT_TOPIC_ATTENTION_COMPRESSION = 0.65 # compress current topic's attention window to 65% of size +HISTORY_TOPIC_ATTENTION_COMPRESSION = 0 # compress history topic's attention window to 0% of size - only request and response remain intact +LARGE_MESSAGE_TO_CURRENT_TOPIC_RATIO = 0.5 +LARGE_MESSAGE_TO_HISTORY_TOPIC_RATIO = 0.2 RAW_MESSAGE_OUTPUT_TEXT_TRIM = 100 +COMPRESSION_TARGET_RATIO = 0.8 class RawMessage(TypedDict): @@ -155,13 +158,12 @@ class Topic(Record): self.summary = await self.summarize_messages(self.messages) return self.summary - async def compress_large_messages(self) -> bool: + def compress_large_messages(self, message_ratio: float = CURRENT_TOPIC_RATIO * LARGE_MESSAGE_TO_CURRENT_TOPIC_RATIO) -> bool: set = settings.get_settings() msg_max_size = ( set["chat_model_ctx_length"] * set["chat_model_ctx_history"] - * CURRENT_TOPIC_RATIO - * LARGE_MESSAGE_TO_TOPIC_RATIO + * message_ratio ) large_msgs = [] for m in (m for m in self.messages if not m.summary): @@ -195,27 +197,29 @@ class Topic(Record): return False async def compress(self) -> bool: - compress = await self.compress_large_messages() + compress = self.compress_large_messages() if not compress: compress = await self.compress_attention() return compress - async def compress_attention(self) -> bool: + async def compress_attention(self, ratio: float = CURRENT_TOPIC_ATTENTION_COMPRESSION) -> bool: - if len(self.messages) > 2: - cnt_to_sum = math.ceil((len(self.messages) - 2) * TOPIC_COMPRESS_RATIO) - msg_to_sum = self.messages[1 : cnt_to_sum + 1] - summary = await self.summarize_messages(msg_to_sum) - sum_msg_content = self.history.agent.parse_prompt( - "fw.msg_summary.md", summary=summary - ) - sum_msg = Message(False, sum_msg_content) - self.messages[1 : cnt_to_sum + 1] = [sum_msg] - return True - return False + middle = len(self.messages) - 2 + if middle < 2: + return False + cnt_to_sum = middle - math.floor(middle * ratio) + if cnt_to_sum < 1: + return False + msg_to_sum = self.messages[1 : cnt_to_sum + 1] + summary = await self.summarize_messages(msg_to_sum) + sum_msg_content = self.history.agent.parse_prompt( + "fw.msg_summary.md", summary=summary + ) + sum_msg = Message(False, sum_msg_content) + self.messages[1 : cnt_to_sum + 1] = [sum_msg] + return True async def summarize_messages(self, messages: list[Message]): - # FIXME: vision bytes are sent to utility LLM, send summary instead msg_txt = [m.output_text() for m in messages] summary = await self.history.agent.call_utility_model( system=self.history.agent.read_prompt("fw.topic_summary.sys.md"), @@ -363,22 +367,38 @@ class History(Record): async def compress(self): compressed = False + total = _get_ctx_size_for_history() + curr, hist, bulk = ( + self.get_current_topic_tokens(), + self.get_topics_tokens(), + self.get_bulks_tokens(), + ) + if (curr + hist + bulk) <= total: + return False + + target = total * COMPRESSION_TARGET_RATIO + prev_total = curr + hist + bulk while True: curr, hist, bulk = ( self.get_current_topic_tokens(), self.get_topics_tokens(), self.get_bulks_tokens(), ) - total = _get_ctx_size_for_history() + + # safeguard against infinite loop in case LLM bloats the summary for some reason + if (curr + hist + bulk) >= prev_total: + break + prev_total = curr + hist + bulk + ratios = [ (curr, CURRENT_TOPIC_RATIO, "current_topic"), (hist, HISTORY_TOPIC_RATIO, "history_topic"), (bulk, HISTORY_BULK_RATIO, "history_bulk"), ] - ratios = sorted(ratios, key=lambda x: (x[0] / total) / x[1], reverse=True) + ratios = sorted(ratios, key=lambda x: (x[0] / target) / x[1], reverse=True) compressed_part = False for ratio in ratios: - if ratio[0] > ratio[1] * total: + if ratio[0] > ratio[1] * target: over_part = ratio[2] if over_part == "current_topic": compressed_part = await self.current.compress() @@ -394,24 +414,29 @@ class History(Record): continue else: return compressed + return compressed async def compress_topics(self) -> bool: - # summarize topics one by one + + # 1. first identify large messages and compress them cheaply for topic in self.topics: - if not topic.summary: - await topic.summarize() + if topic.compress_large_messages(HISTORY_TOPIC_RATIO*LARGE_MESSAGE_TO_HISTORY_TOPIC_RATIO): return True - # move oldest topic to bulks and summarize + # 2. summarize topics attention window one by one for topic in self.topics: + if await topic.compress_attention(HISTORY_TOPIC_ATTENTION_COMPRESSION): + return True + + # 3. move oldest topics to bulks in chunks + if self.topics: + count = TOPICS_MERGE_COUNT if len(self.topics) >= TOPICS_MERGE_COUNT else 1 + chunk = self.topics[:count] bulk = Bulk(history=self) - bulk.records.append(topic) - if topic.summary: - bulk.summary = topic.summary - else: - await bulk.summarize() + bulk.records.extend(chunk) + await bulk.summarize() self.bulks.append(bulk) - self.topics.remove(topic) + self.topics[:count] = [] return True return False diff --git a/python/helpers/memory_consolidation.py b/python/helpers/memory_consolidation.py index 244ebbc8d..86a5cf598 100644 --- a/python/helpers/memory_consolidation.py +++ b/python/helpers/memory_consolidation.py @@ -82,7 +82,7 @@ class MemoryConsolidator: Args: new_memory: The new memory content to process - area: Memory area (MAIN, FRAGMENTS, SOLUTIONS, INSTRUMENTS) + area: Memory area (MAIN, FRAGMENTS, SOLUTIONS) metadata: Initial metadata for the memory log_item: Optional log item for progress tracking