Phases 4-6: Dead code removal, performance, minor fixes

Phase 4:
- Remove legacy SessionRecord, _sessions, _msg_to_session from SessionStore
- Fix hardcoded provider in root endpoint (use settings.provider_type)
- Update session store tests

Phase 5:
- Use list-based string accumulation in ThinkingSegment, TextSegment, ToolCallSegment
- Cache MAX_MESSAGE_LOG_ENTRIES_PER_CHAT at SessionStore init
- Use iterative DFS in MessageTree.get_descendants

Phase 6:
- Add comment for abstract async generator workaround in BaseProvider
- Rename TELEGRAM_EDIT log tags to PLATFORM_EDIT in handler

Co-authored-by: Ali Khokhar <alishahryar2@gmail.com>
This commit is contained in:
Cursor Agent 2026-02-17 02:01:01 +00:00
parent 72b7e34999
commit bfc781e0ed
10 changed files with 75 additions and 128 deletions

View file

@ -100,7 +100,7 @@ async def root(settings: Settings = Depends(get_settings)):
"""Root endpoint."""
return {
"status": "ok",
"provider": "nvidia_nim",
"provider": settings.provider_type,
"model": settings.model,
}

View file

@ -479,7 +479,7 @@ class ClaudeMessageHandler:
return
if display and display != last_displayed_text:
logger.debug(
"TELEGRAM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d",
"PLATFORM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d",
node_id,
chat_id,
status_msg_id,
@ -488,13 +488,13 @@ class ClaudeMessageHandler:
len(display),
)
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
logger.debug("TELEGRAM_EDIT_TEXT:\n%s", display)
logger.debug("PLATFORM_EDIT_TEXT:\n%s", display)
else:
head = display[:500]
tail = display[-500:] if len(display) > 500 else ""
logger.debug("TELEGRAM_EDIT_PREVIEW_HEAD:\n%s", head)
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n%s", head)
if tail:
logger.debug("TELEGRAM_EDIT_PREVIEW_TAIL:\n%s", tail)
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n%s", tail)
last_displayed_text = display
try:
await self.platform.queue_edit_message(
@ -504,7 +504,9 @@ class ClaudeMessageHandler:
parse_mode=self._parse_mode(),
)
except Exception as e:
logger.warning(f"Failed to update Telegram for node {node_id}: {e}")
logger.warning(
f"Failed to update platform for node {node_id}: {e}"
)
try:
try:

View file

@ -9,24 +9,10 @@ import json
import os
from datetime import datetime, timezone
from typing import Optional, Dict, List, Any
from dataclasses import dataclass, asdict
import threading
from loguru import logger
@dataclass
class SessionRecord:
"""A single session record."""
session_id: str
chat_id: str
initial_msg_id: str
last_msg_id: str
platform: str
created_at: str
updated_at: str
class SessionStore:
"""
Persistent storage for message <-> Claude session mappings and message trees.
@ -38,10 +24,6 @@ class SessionStore:
def __init__(self, storage_path: str = "sessions.json"):
self.storage_path = storage_path
self._lock = threading.Lock()
self._sessions: Dict[str, SessionRecord] = {}
self._msg_to_session: Dict[
str, str
] = {} # "platform:chat_id:msg_id" -> session_id
self._trees: Dict[str, dict] = {} # root_id -> tree data
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
# Per-chat message ID log used to support best-effort UI clearing (/clear).
@ -51,12 +33,15 @@ class SessionStore:
self._dirty = False
self._save_timer: Optional[threading.Timer] = None
self._save_debounce_secs = 0.5
cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
try:
self._message_log_cap: Optional[int] = (
int(cap_raw) if cap_raw else None
)
except ValueError:
self._message_log_cap = None
self._load()
def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str:
"""Create a unique key from platform, chat_id and msg_id."""
return f"{platform}:{chat_id}:{msg_id}"
def _make_chat_key(self, platform: str, chat_id: str) -> str:
return f"{platform}:{chat_id}"
@ -69,25 +54,6 @@ class SessionStore:
with open(self.storage_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Load sessions (legacy support)
for sid, record_data in data.get("sessions", {}).items():
if "platform" not in record_data:
record_data["platform"] = "telegram"
for field in ["chat_id", "initial_msg_id", "last_msg_id"]:
if isinstance(record_data.get(field), int):
record_data[field] = str(record_data[field])
record = SessionRecord(**record_data)
self._sessions[sid] = record
self._msg_to_session[
self._make_key(
record.platform, record.chat_id, record.initial_msg_id
)
] = sid
self._msg_to_session[
self._make_key(record.platform, record.chat_id, record.last_msg_id)
] = sid
# Load trees
self._trees = data.get("trees", {})
self._node_to_tree = data.get("node_to_tree", {})
@ -124,8 +90,8 @@ class SessionStore:
self._message_log_ids[chat_key] = seen
logger.info(
f"Loaded {len(self._sessions)} sessions, {len(self._trees)} trees, "
f"and {sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}"
f"Loaded {len(self._trees)} trees and "
f"{sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}"
)
except Exception as e:
logger.error(f"Failed to load sessions: {e}")
@ -134,9 +100,6 @@ class SessionStore:
"""Persist sessions and trees to disk. Caller must hold self._lock."""
try:
data = {
"sessions": {
sid: asdict(record) for sid, record in self._sessions.items()
},
"trees": self._trees,
"node_to_tree": self._node_to_tree,
"message_log": self._message_log,
@ -211,22 +174,14 @@ class SessionStore:
seen.add(mid)
# Optional cap to prevent unbounded growth if configured.
# Default is unlimited as requested.
try:
cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
if cap_raw:
cap = int(cap_raw)
if cap > 0:
items = self._message_log.get(chat_key, [])
if len(items) > cap:
# Drop oldest entries and rebuild seen set.
self._message_log[chat_key] = items[-cap:]
self._message_log_ids[chat_key] = {
str(x.get("message_id"))
for x in self._message_log[chat_key]
}
except Exception:
pass
if self._message_log_cap is not None and self._message_log_cap > 0:
items = self._message_log.get(chat_key, [])
if len(items) > self._message_log_cap:
self._message_log[chat_key] = items[-self._message_log_cap :]
self._message_log_ids[chat_key] = {
str(x.get("message_id"))
for x in self._message_log[chat_key]
}
self._schedule_save()
@ -244,8 +199,6 @@ class SessionStore:
def clear_all(self) -> None:
"""Clear all stored sessions/trees/mappings and persist an empty store."""
with self._lock:
self._sessions.clear()
self._msg_to_session.clear()
self._trees.clear()
self._node_to_tree.clear()
self._message_log.clear()

View file

@ -33,14 +33,17 @@ class Segment:
@dataclass
class ThinkingSegment(Segment):
text: str = ""
def __init__(self) -> None:
super().__init__(kind="thinking")
self._parts: List[str] = []
def append(self, t: str) -> None:
if t:
self.text += t
self._parts.append(t)
@property
def text(self) -> str:
return "".join(self._parts)
def render(self, ctx: "RenderCtx") -> str:
raw = self.text or ""
@ -52,14 +55,17 @@ class ThinkingSegment(Segment):
@dataclass
class TextSegment(Segment):
text: str = ""
def __init__(self) -> None:
super().__init__(kind="text")
self._parts: List[str] = []
def append(self, t: str) -> None:
if t:
self.text += t
self._parts.append(t)
@property
def text(self) -> str:
return "".join(self._parts)
def render(self, ctx: "RenderCtx") -> str:
raw = self.text or ""
@ -72,7 +78,6 @@ class TextSegment(Segment):
class ToolCallSegment(Segment):
tool_use_id: str
name: str
input_text: str = ""
closed: bool = False
indent_level: int = 0
@ -81,18 +86,23 @@ class ToolCallSegment(Segment):
self.tool_use_id = str(tool_use_id or "")
self.name = str(name or "tool")
self.indent_level = max(0, int(indent_level))
self._parts: List[str] = []
def set_initial_input(self, inp: Any) -> None:
if inp is None:
return
if isinstance(inp, str):
self.input_text = inp
self._parts = [inp]
else:
self.input_text = _safe_json_dumps(inp)
self._parts = [_safe_json_dumps(inp)]
def append_input_delta(self, partial: str) -> None:
if partial:
self.input_text += partial
self._parts.append(partial)
@property
def input_text(self) -> str:
return "".join(self._parts)
def render(self, ctx: "RenderCtx") -> str:
name = ctx.code_inline(self.name)

View file

@ -400,10 +400,14 @@ class MessageTree:
"""
if node_id not in self._nodes:
return []
result = [node_id]
node = self._nodes[node_id]
for child_id in node.children_ids:
result.extend(self.get_descendants(child_id))
result: List[str] = []
stack = [node_id]
while stack:
nid = stack.pop()
result.append(nid)
node = self._nodes.get(nid)
if node:
stack.extend(node.children_ids)
return result
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:

View file

@ -115,9 +115,7 @@ class TreeQueueProcessor:
node = tree.get_node(next_node_id)
if node:
tree.set_current_task(
asyncio.create_task(
self.process_node(tree, node, processor)
)
asyncio.create_task(self.process_node(tree, node, processor))
)
# Notify that this node has started processing and refresh queue positions.
@ -155,9 +153,7 @@ class TreeQueueProcessor:
node = tree.get_node(node_id)
if node:
tree.set_current_task(
asyncio.create_task(
self.process_node(tree, node, processor)
)
asyncio.create_task(self.process_node(tree, node, processor))
)
return False

View file

@ -38,4 +38,4 @@ class BaseProvider(ABC):
) -> AsyncIterator[str]:
"""Stream response in Anthropic SSE format."""
if False:
yield ""
yield "" # Required for ty/mypy to accept abstract async generator

View file

@ -62,7 +62,7 @@ class TestSessionStore:
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
assert store._sessions == {}
assert store._trees == {}
# --- Tree Tests ---
@ -127,22 +127,15 @@ class TestSessionStore:
# --- Persistence & Edge Cases ---
def test_load_existing_legacy_format(self, tmp_path):
"""Test loading legacy session format (int IDs) - backward compat."""
def test_load_existing_file_with_trees(self, tmp_path):
"""Test loading file with trees (legacy sessions ignored)."""
from messaging.session import SessionStore
data = {
"sessions": {
"s1": {
"session_id": "s1",
"chat_id": 123, # Legacy int
"initial_msg_id": 100, # Legacy int
"last_msg_id": 101, # Legacy int
"created_at": "2024-01-01",
"updated_at": "2024-01-01",
# platform missing -> should default to telegram
}
}
"sessions": {},
"trees": {"r1": {"root_id": "r1", "nodes": {"r1": {}}}},
"node_to_tree": {"r1": "r1"},
"message_log": {},
}
p = tmp_path / "sessions.json"
@ -150,11 +143,7 @@ class TestSessionStore:
json.dump(data, f)
store = SessionStore(storage_path=str(p))
# Legacy sessions are loaded for backward compat; verify conversion
assert "s1" in store._sessions
rec = store._sessions["s1"]
assert rec.chat_id == "123" # Converted to str
assert rec.platform == "telegram" # Defaulted
assert store.get_tree("r1") is not None
def test_load_corrupt_file(self, tmp_path):
"""Test loading corrupt/invalid json file."""
@ -166,7 +155,7 @@ class TestSessionStore:
# Should log error and start empty, avoiding crash
store = SessionStore(storage_path=str(p))
assert store._sessions == {}
assert store._trees == {}
def test_save_error_handling(self, tmp_path):
"""Test error during save."""

View file

@ -56,9 +56,7 @@ async def test_init(provider_config):
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
from config.nim import NimSettings
provider = NvidiaNimProvider(
provider_config, nim_settings=NimSettings()
)
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
assert provider._api_key == "test_key"
assert provider._base_url == "https://test.api.nvidia.com/v1"
mock_openai.assert_called_once()

View file

@ -24,7 +24,6 @@ class TestSessionStoreLoadEdgeCases:
f.write("{invalid json")
store = SessionStore(storage_path=path)
assert len(store._sessions) == 0
assert len(store._trees) == 0
def test_load_truncated_json(self, tmp_path):
@ -34,7 +33,7 @@ class TestSessionStoreLoadEdgeCases:
f.write('{"sessions": {"s1": {"session_id": "s1"')
store = SessionStore(storage_path=path)
assert len(store._sessions) == 0
assert len(store._trees) == 0
def test_load_empty_file(self, tmp_path):
"""Empty file is handled gracefully."""
@ -43,16 +42,16 @@ class TestSessionStoreLoadEdgeCases:
f.write("")
store = SessionStore(storage_path=path)
assert len(store._sessions) == 0
assert len(store._trees) == 0
def test_load_nonexistent_file(self, tmp_path):
"""Non-existent file starts with empty state."""
path = str(tmp_path / "nonexistent.json")
store = SessionStore(storage_path=path)
assert len(store._sessions) == 0
assert len(store._trees) == 0
def test_load_legacy_int_fields(self, tmp_path):
"""Legacy format with int chat_id/msg_id is converted to string."""
def test_load_legacy_sessions_ignored(self, tmp_path):
"""Legacy sessions in file are ignored; trees and message_log load."""
path = str(tmp_path / "sessions.json")
data = {
"sessions": {
@ -66,17 +65,15 @@ class TestSessionStoreLoadEdgeCases:
"updated_at": "2025-01-01T00:00:00+00:00",
}
},
"trees": {},
"node_to_tree": {},
"trees": {"r1": {"root_id": "r1", "nodes": {"r1": {}}}},
"node_to_tree": {"r1": "r1"},
"message_log": {},
}
with open(path, "w") as f:
json.dump(data, f)
store = SessionStore(storage_path=path)
record = store._sessions["s1"]
assert record.chat_id == "12345"
assert record.initial_msg_id == "100"
assert record.last_msg_id == "200"
assert store.get_tree("r1") is not None
class TestSessionStoreSaveEdgeCases:
@ -131,13 +128,11 @@ class TestSessionStoreClearAll:
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
assert data["sessions"] == {}
assert data["trees"] == {}
assert data["node_to_tree"] == {}
assert data["message_log"] == {}
store2 = SessionStore(storage_path=path)
assert len(store2._sessions) == 0
assert len(store2._trees) == 0
def test_message_log_persists_and_dedups(self, tmp_path):