mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Phase 4-6: Dead code removal, performance, minor fixes
Phase 4: - Remove legacy SessionRecord, _sessions, _msg_to_session from SessionStore - Fix hardcoded provider in root endpoint (use settings.provider_type) - Update session store tests Phase 5: - Use list-based string accumulation in ThinkingSegment, TextSegment, ToolCallSegment - Cache MAX_MESSAGE_LOG_ENTRIES_PER_CHAT at SessionStore init - Use iterative DFS in MessageTree.get_descendants Phase 6: - Add comment for abstract async generator workaround in BaseProvider - Rename TELEGRAM_EDIT log tags to PLATFORM_EDIT in handler Co-authored-by: Ali Khokhar <alishahryar2@gmail.com>
This commit is contained in:
parent
72b7e34999
commit
bfc781e0ed
10 changed files with 75 additions and 128 deletions
|
|
@ -100,7 +100,7 @@ async def root(settings: Settings = Depends(get_settings)):
|
|||
"""Root endpoint."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"provider": "nvidia_nim",
|
||||
"provider": settings.provider_type,
|
||||
"model": settings.model,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -479,7 +479,7 @@ class ClaudeMessageHandler:
|
|||
return
|
||||
if display and display != last_displayed_text:
|
||||
logger.debug(
|
||||
"TELEGRAM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d",
|
||||
"PLATFORM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d",
|
||||
node_id,
|
||||
chat_id,
|
||||
status_msg_id,
|
||||
|
|
@ -488,13 +488,13 @@ class ClaudeMessageHandler:
|
|||
len(display),
|
||||
)
|
||||
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
|
||||
logger.debug("TELEGRAM_EDIT_TEXT:\n%s", display)
|
||||
logger.debug("PLATFORM_EDIT_TEXT:\n%s", display)
|
||||
else:
|
||||
head = display[:500]
|
||||
tail = display[-500:] if len(display) > 500 else ""
|
||||
logger.debug("TELEGRAM_EDIT_PREVIEW_HEAD:\n%s", head)
|
||||
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n%s", head)
|
||||
if tail:
|
||||
logger.debug("TELEGRAM_EDIT_PREVIEW_TAIL:\n%s", tail)
|
||||
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n%s", tail)
|
||||
last_displayed_text = display
|
||||
try:
|
||||
await self.platform.queue_edit_message(
|
||||
|
|
@ -504,7 +504,9 @@ class ClaudeMessageHandler:
|
|||
parse_mode=self._parse_mode(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update Telegram for node {node_id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to update platform for node {node_id}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -9,24 +9,10 @@ import json
|
|||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, List, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
import threading
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionRecord:
|
||||
"""A single session record."""
|
||||
|
||||
session_id: str
|
||||
chat_id: str
|
||||
initial_msg_id: str
|
||||
last_msg_id: str
|
||||
platform: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
Persistent storage for message ↔ Claude session mappings and message trees.
|
||||
|
|
@ -38,10 +24,6 @@ class SessionStore:
|
|||
def __init__(self, storage_path: str = "sessions.json"):
|
||||
self.storage_path = storage_path
|
||||
self._lock = threading.Lock()
|
||||
self._sessions: Dict[str, SessionRecord] = {}
|
||||
self._msg_to_session: Dict[
|
||||
str, str
|
||||
] = {} # "platform:chat_id:msg_id" -> session_id
|
||||
self._trees: Dict[str, dict] = {} # root_id -> tree data
|
||||
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
|
||||
# Per-chat message ID log used to support best-effort UI clearing (/clear).
|
||||
|
|
@ -51,12 +33,15 @@ class SessionStore:
|
|||
self._dirty = False
|
||||
self._save_timer: Optional[threading.Timer] = None
|
||||
self._save_debounce_secs = 0.5
|
||||
cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
|
||||
try:
|
||||
self._message_log_cap: Optional[int] = (
|
||||
int(cap_raw) if cap_raw else None
|
||||
)
|
||||
except ValueError:
|
||||
self._message_log_cap = None
|
||||
self._load()
|
||||
|
||||
def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str:
|
||||
"""Create a unique key from platform, chat_id and msg_id."""
|
||||
return f"{platform}:{chat_id}:{msg_id}"
|
||||
|
||||
def _make_chat_key(self, platform: str, chat_id: str) -> str:
|
||||
return f"{platform}:{chat_id}"
|
||||
|
||||
|
|
@ -69,25 +54,6 @@ class SessionStore:
|
|||
with open(self.storage_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Load sessions (legacy support)
|
||||
for sid, record_data in data.get("sessions", {}).items():
|
||||
if "platform" not in record_data:
|
||||
record_data["platform"] = "telegram"
|
||||
for field in ["chat_id", "initial_msg_id", "last_msg_id"]:
|
||||
if isinstance(record_data.get(field), int):
|
||||
record_data[field] = str(record_data[field])
|
||||
|
||||
record = SessionRecord(**record_data)
|
||||
self._sessions[sid] = record
|
||||
self._msg_to_session[
|
||||
self._make_key(
|
||||
record.platform, record.chat_id, record.initial_msg_id
|
||||
)
|
||||
] = sid
|
||||
self._msg_to_session[
|
||||
self._make_key(record.platform, record.chat_id, record.last_msg_id)
|
||||
] = sid
|
||||
|
||||
# Load trees
|
||||
self._trees = data.get("trees", {})
|
||||
self._node_to_tree = data.get("node_to_tree", {})
|
||||
|
|
@ -124,8 +90,8 @@ class SessionStore:
|
|||
self._message_log_ids[chat_key] = seen
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(self._sessions)} sessions, {len(self._trees)} trees, "
|
||||
f"and {sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}"
|
||||
f"Loaded {len(self._trees)} trees and "
|
||||
f"{sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load sessions: {e}")
|
||||
|
|
@ -134,9 +100,6 @@ class SessionStore:
|
|||
"""Persist sessions and trees to disk. Caller must hold self._lock."""
|
||||
try:
|
||||
data = {
|
||||
"sessions": {
|
||||
sid: asdict(record) for sid, record in self._sessions.items()
|
||||
},
|
||||
"trees": self._trees,
|
||||
"node_to_tree": self._node_to_tree,
|
||||
"message_log": self._message_log,
|
||||
|
|
@ -211,22 +174,14 @@ class SessionStore:
|
|||
seen.add(mid)
|
||||
|
||||
# Optional cap to prevent unbounded growth if configured.
|
||||
# Default is unlimited as requested.
|
||||
try:
|
||||
cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
|
||||
if cap_raw:
|
||||
cap = int(cap_raw)
|
||||
if cap > 0:
|
||||
items = self._message_log.get(chat_key, [])
|
||||
if len(items) > cap:
|
||||
# Drop oldest entries and rebuild seen set.
|
||||
self._message_log[chat_key] = items[-cap:]
|
||||
self._message_log_ids[chat_key] = {
|
||||
str(x.get("message_id"))
|
||||
for x in self._message_log[chat_key]
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
if self._message_log_cap is not None and self._message_log_cap > 0:
|
||||
items = self._message_log.get(chat_key, [])
|
||||
if len(items) > self._message_log_cap:
|
||||
self._message_log[chat_key] = items[-self._message_log_cap :]
|
||||
self._message_log_ids[chat_key] = {
|
||||
str(x.get("message_id"))
|
||||
for x in self._message_log[chat_key]
|
||||
}
|
||||
|
||||
self._schedule_save()
|
||||
|
||||
|
|
@ -244,8 +199,6 @@ class SessionStore:
|
|||
def clear_all(self) -> None:
|
||||
"""Clear all stored sessions/trees/mappings and persist an empty store."""
|
||||
with self._lock:
|
||||
self._sessions.clear()
|
||||
self._msg_to_session.clear()
|
||||
self._trees.clear()
|
||||
self._node_to_tree.clear()
|
||||
self._message_log.clear()
|
||||
|
|
|
|||
|
|
@ -33,14 +33,17 @@ class Segment:
|
|||
|
||||
@dataclass
|
||||
class ThinkingSegment(Segment):
|
||||
text: str = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(kind="thinking")
|
||||
self._parts: List[str] = []
|
||||
|
||||
def append(self, t: str) -> None:
|
||||
if t:
|
||||
self.text += t
|
||||
self._parts.append(t)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return "".join(self._parts)
|
||||
|
||||
def render(self, ctx: "RenderCtx") -> str:
|
||||
raw = self.text or ""
|
||||
|
|
@ -52,14 +55,17 @@ class ThinkingSegment(Segment):
|
|||
|
||||
@dataclass
|
||||
class TextSegment(Segment):
|
||||
text: str = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(kind="text")
|
||||
self._parts: List[str] = []
|
||||
|
||||
def append(self, t: str) -> None:
|
||||
if t:
|
||||
self.text += t
|
||||
self._parts.append(t)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return "".join(self._parts)
|
||||
|
||||
def render(self, ctx: "RenderCtx") -> str:
|
||||
raw = self.text or ""
|
||||
|
|
@ -72,7 +78,6 @@ class TextSegment(Segment):
|
|||
class ToolCallSegment(Segment):
|
||||
tool_use_id: str
|
||||
name: str
|
||||
input_text: str = ""
|
||||
closed: bool = False
|
||||
indent_level: int = 0
|
||||
|
||||
|
|
@ -81,18 +86,23 @@ class ToolCallSegment(Segment):
|
|||
self.tool_use_id = str(tool_use_id or "")
|
||||
self.name = str(name or "tool")
|
||||
self.indent_level = max(0, int(indent_level))
|
||||
self._parts: List[str] = []
|
||||
|
||||
def set_initial_input(self, inp: Any) -> None:
|
||||
if inp is None:
|
||||
return
|
||||
if isinstance(inp, str):
|
||||
self.input_text = inp
|
||||
self._parts = [inp]
|
||||
else:
|
||||
self.input_text = _safe_json_dumps(inp)
|
||||
self._parts = [_safe_json_dumps(inp)]
|
||||
|
||||
def append_input_delta(self, partial: str) -> None:
|
||||
if partial:
|
||||
self.input_text += partial
|
||||
self._parts.append(partial)
|
||||
|
||||
@property
|
||||
def input_text(self) -> str:
|
||||
return "".join(self._parts)
|
||||
|
||||
def render(self, ctx: "RenderCtx") -> str:
|
||||
name = ctx.code_inline(self.name)
|
||||
|
|
|
|||
|
|
@ -400,10 +400,14 @@ class MessageTree:
|
|||
"""
|
||||
if node_id not in self._nodes:
|
||||
return []
|
||||
result = [node_id]
|
||||
node = self._nodes[node_id]
|
||||
for child_id in node.children_ids:
|
||||
result.extend(self.get_descendants(child_id))
|
||||
result: List[str] = []
|
||||
stack = [node_id]
|
||||
while stack:
|
||||
nid = stack.pop()
|
||||
result.append(nid)
|
||||
node = self._nodes.get(nid)
|
||||
if node:
|
||||
stack.extend(node.children_ids)
|
||||
return result
|
||||
|
||||
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
|
|
|
|||
|
|
@ -115,9 +115,7 @@ class TreeQueueProcessor:
|
|||
node = tree.get_node(next_node_id)
|
||||
if node:
|
||||
tree.set_current_task(
|
||||
asyncio.create_task(
|
||||
self.process_node(tree, node, processor)
|
||||
)
|
||||
asyncio.create_task(self.process_node(tree, node, processor))
|
||||
)
|
||||
|
||||
# Notify that this node has started processing and refresh queue positions.
|
||||
|
|
@ -155,9 +153,7 @@ class TreeQueueProcessor:
|
|||
node = tree.get_node(node_id)
|
||||
if node:
|
||||
tree.set_current_task(
|
||||
asyncio.create_task(
|
||||
self.process_node(tree, node, processor)
|
||||
)
|
||||
asyncio.create_task(self.process_node(tree, node, processor))
|
||||
)
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -38,4 +38,4 @@ class BaseProvider(ABC):
|
|||
) -> AsyncIterator[str]:
|
||||
"""Stream response in Anthropic SSE format."""
|
||||
if False:
|
||||
yield ""
|
||||
yield "" # Required for ty/mypy to accept abstract async generator
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class TestSessionStore:
|
|||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
assert store._sessions == {}
|
||||
assert store._trees == {}
|
||||
|
||||
# --- Tree Tests ---
|
||||
|
||||
|
|
@ -127,22 +127,15 @@ class TestSessionStore:
|
|||
|
||||
# --- Persistence & Edge Cases ---
|
||||
|
||||
def test_load_existing_legacy_format(self, tmp_path):
|
||||
"""Test loading legacy session format (int IDs) - backward compat."""
|
||||
def test_load_existing_file_with_trees(self, tmp_path):
|
||||
"""Test loading file with trees (legacy sessions ignored)."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
data = {
|
||||
"sessions": {
|
||||
"s1": {
|
||||
"session_id": "s1",
|
||||
"chat_id": 123, # Legacy int
|
||||
"initial_msg_id": 100, # Legacy int
|
||||
"last_msg_id": 101, # Legacy int
|
||||
"created_at": "2024-01-01",
|
||||
"updated_at": "2024-01-01",
|
||||
# platform missing -> should default to telegram
|
||||
}
|
||||
}
|
||||
"sessions": {},
|
||||
"trees": {"r1": {"root_id": "r1", "nodes": {"r1": {}}}},
|
||||
"node_to_tree": {"r1": "r1"},
|
||||
"message_log": {},
|
||||
}
|
||||
|
||||
p = tmp_path / "sessions.json"
|
||||
|
|
@ -150,11 +143,7 @@ class TestSessionStore:
|
|||
json.dump(data, f)
|
||||
|
||||
store = SessionStore(storage_path=str(p))
|
||||
# Legacy sessions are loaded for backward compat; verify conversion
|
||||
assert "s1" in store._sessions
|
||||
rec = store._sessions["s1"]
|
||||
assert rec.chat_id == "123" # Converted to str
|
||||
assert rec.platform == "telegram" # Defaulted
|
||||
assert store.get_tree("r1") is not None
|
||||
|
||||
def test_load_corrupt_file(self, tmp_path):
|
||||
"""Test loading corrupt/invalid json file."""
|
||||
|
|
@ -166,7 +155,7 @@ class TestSessionStore:
|
|||
|
||||
# Should log error and start empty, avoiding crash
|
||||
store = SessionStore(storage_path=str(p))
|
||||
assert store._sessions == {}
|
||||
assert store._trees == {}
|
||||
|
||||
def test_save_error_handling(self, tmp_path):
|
||||
"""Test error during save."""
|
||||
|
|
|
|||
|
|
@ -56,9 +56,7 @@ async def test_init(provider_config):
|
|||
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config, nim_settings=NimSettings()
|
||||
)
|
||||
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
|
||||
assert provider._api_key == "test_key"
|
||||
assert provider._base_url == "https://test.api.nvidia.com/v1"
|
||||
mock_openai.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ class TestSessionStoreLoadEdgeCases:
|
|||
f.write("{invalid json")
|
||||
|
||||
store = SessionStore(storage_path=path)
|
||||
assert len(store._sessions) == 0
|
||||
assert len(store._trees) == 0
|
||||
|
||||
def test_load_truncated_json(self, tmp_path):
|
||||
|
|
@ -34,7 +33,7 @@ class TestSessionStoreLoadEdgeCases:
|
|||
f.write('{"sessions": {"s1": {"session_id": "s1"')
|
||||
|
||||
store = SessionStore(storage_path=path)
|
||||
assert len(store._sessions) == 0
|
||||
assert len(store._trees) == 0
|
||||
|
||||
def test_load_empty_file(self, tmp_path):
|
||||
"""Empty file is handled gracefully."""
|
||||
|
|
@ -43,16 +42,16 @@ class TestSessionStoreLoadEdgeCases:
|
|||
f.write("")
|
||||
|
||||
store = SessionStore(storage_path=path)
|
||||
assert len(store._sessions) == 0
|
||||
assert len(store._trees) == 0
|
||||
|
||||
def test_load_nonexistent_file(self, tmp_path):
|
||||
"""Non-existent file starts with empty state."""
|
||||
path = str(tmp_path / "nonexistent.json")
|
||||
store = SessionStore(storage_path=path)
|
||||
assert len(store._sessions) == 0
|
||||
assert len(store._trees) == 0
|
||||
|
||||
def test_load_legacy_int_fields(self, tmp_path):
|
||||
"""Legacy format with int chat_id/msg_id is converted to string."""
|
||||
def test_load_legacy_sessions_ignored(self, tmp_path):
|
||||
"""Legacy sessions in file are ignored; trees and message_log load."""
|
||||
path = str(tmp_path / "sessions.json")
|
||||
data = {
|
||||
"sessions": {
|
||||
|
|
@ -66,17 +65,15 @@ class TestSessionStoreLoadEdgeCases:
|
|||
"updated_at": "2025-01-01T00:00:00+00:00",
|
||||
}
|
||||
},
|
||||
"trees": {},
|
||||
"node_to_tree": {},
|
||||
"trees": {"r1": {"root_id": "r1", "nodes": {"r1": {}}}},
|
||||
"node_to_tree": {"r1": "r1"},
|
||||
"message_log": {},
|
||||
}
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
store = SessionStore(storage_path=path)
|
||||
record = store._sessions["s1"]
|
||||
assert record.chat_id == "12345"
|
||||
assert record.initial_msg_id == "100"
|
||||
assert record.last_msg_id == "200"
|
||||
assert store.get_tree("r1") is not None
|
||||
|
||||
|
||||
class TestSessionStoreSaveEdgeCases:
|
||||
|
|
@ -131,13 +128,11 @@ class TestSessionStoreClearAll:
|
|||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
assert data["sessions"] == {}
|
||||
assert data["trees"] == {}
|
||||
assert data["node_to_tree"] == {}
|
||||
assert data["message_log"] == {}
|
||||
|
||||
store2 = SessionStore(storage_path=path)
|
||||
assert len(store2._sessions) == 0
|
||||
assert len(store2._trees) == 0
|
||||
|
||||
def test_message_log_persists_and_dedups(self, tmp_path):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue