diff --git a/api/routes.py b/api/routes.py index cff45d5..fcf7d4c 100644 --- a/api/routes.py +++ b/api/routes.py @@ -100,7 +100,7 @@ async def root(settings: Settings = Depends(get_settings)): """Root endpoint.""" return { "status": "ok", - "provider": "nvidia_nim", + "provider": settings.provider_type, "model": settings.model, } diff --git a/messaging/handler.py b/messaging/handler.py index fb37c44..fd608d4 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -479,7 +479,7 @@ class ClaudeMessageHandler: return if display and display != last_displayed_text: logger.debug( - "TELEGRAM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d", + "PLATFORM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d", node_id, chat_id, status_msg_id, @@ -488,13 +488,13 @@ class ClaudeMessageHandler: len(display), ) if os.getenv("DEBUG_TELEGRAM_EDITS") == "1": - logger.debug("TELEGRAM_EDIT_TEXT:\n%s", display) + logger.debug("PLATFORM_EDIT_TEXT:\n%s", display) else: head = display[:500] tail = display[-500:] if len(display) > 500 else "" - logger.debug("TELEGRAM_EDIT_PREVIEW_HEAD:\n%s", head) + logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n%s", head) if tail: - logger.debug("TELEGRAM_EDIT_PREVIEW_TAIL:\n%s", tail) + logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n%s", tail) last_displayed_text = display try: await self.platform.queue_edit_message( @@ -504,7 +504,9 @@ class ClaudeMessageHandler: parse_mode=self._parse_mode(), ) except Exception as e: - logger.warning(f"Failed to update Telegram for node {node_id}: {e}") + logger.warning( + f"Failed to update platform for node {node_id}: {e}" + ) try: try: diff --git a/messaging/session.py b/messaging/session.py index 407faf9..de25020 100644 --- a/messaging/session.py +++ b/messaging/session.py @@ -9,24 +9,10 @@ import json import os from datetime import datetime, timezone from typing import Optional, Dict, List, Any -from dataclasses import dataclass, asdict import threading from loguru import logger -@dataclass -class SessionRecord: - """A single session record.""" - - session_id: str - chat_id: str - initial_msg_id: str - last_msg_id: str - platform: str - created_at: str - updated_at: str - - class SessionStore: """ Persistent storage for message ↔ Claude session mappings and message trees. @@ -38,10 +24,6 @@ class SessionStore: def __init__(self, storage_path: str = "sessions.json"): self.storage_path = storage_path self._lock = threading.Lock() - self._sessions: Dict[str, SessionRecord] = {} - self._msg_to_session: Dict[ - str, str - ] = {} # "platform:chat_id:msg_id" -> session_id self._trees: Dict[str, dict] = {} # root_id -> tree data self._node_to_tree: Dict[str, str] = {} # node_id -> root_id # Per-chat message ID log used to support best-effort UI clearing (/clear). @@ -51,12 +33,15 @@ class SessionStore: self._dirty = False self._save_timer: Optional[threading.Timer] = None self._save_debounce_secs = 0.5 + cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip() + try: + self._message_log_cap: Optional[int] = ( + int(cap_raw) if cap_raw else None + ) + except ValueError: + self._message_log_cap = None self._load() - def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str: - """Create a unique key from platform, chat_id and msg_id.""" - return f"{platform}:{chat_id}:{msg_id}" - def _make_chat_key(self, platform: str, chat_id: str) -> str: return f"{platform}:{chat_id}" @@ -69,25 +54,6 @@ class SessionStore: with open(self.storage_path, "r", encoding="utf-8") as f: data = json.load(f) - # Load sessions (legacy support) - for sid, record_data in data.get("sessions", {}).items(): - if "platform" not in record_data: - record_data["platform"] = "telegram" - for field in ["chat_id", "initial_msg_id", "last_msg_id"]: - if isinstance(record_data.get(field), int): - record_data[field] = str(record_data[field]) - - record = SessionRecord(**record_data) - self._sessions[sid] = record - self._msg_to_session[ - self._make_key( - record.platform, record.chat_id, record.initial_msg_id - ) - ] = sid - self._msg_to_session[ - self._make_key(record.platform, record.chat_id, record.last_msg_id) - ] = sid - # Load trees self._trees = data.get("trees", {}) self._node_to_tree = data.get("node_to_tree", {}) @@ -124,8 +90,8 @@ class SessionStore: self._message_log_ids[chat_key] = seen logger.info( - f"Loaded {len(self._sessions)} sessions, {len(self._trees)} trees, " - f"and {sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}" + f"Loaded {len(self._trees)} trees and " + f"{sum(len(v) for v in self._message_log.values())} msg_ids from {self.storage_path}" ) except Exception as e: logger.error(f"Failed to load sessions: {e}") @@ -134,9 +100,6 @@ class SessionStore: """Persist sessions and trees to disk. Caller must hold self._lock.""" try: data = { - "sessions": { - sid: asdict(record) for sid, record in self._sessions.items() - }, "trees": self._trees, "node_to_tree": self._node_to_tree, "message_log": self._message_log, @@ -211,22 +174,14 @@ class SessionStore: seen.add(mid) # Optional cap to prevent unbounded growth if configured. - # Default is unlimited as requested. - try: - cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip() - if cap_raw: - cap = int(cap_raw) - if cap > 0: - items = self._message_log.get(chat_key, []) - if len(items) > cap: - # Drop oldest entries and rebuild seen set. - self._message_log[chat_key] = items[-cap:] - self._message_log_ids[chat_key] = { - str(x.get("message_id")) - for x in self._message_log[chat_key] - } - except Exception: - pass + if self._message_log_cap is not None and self._message_log_cap > 0: + items = self._message_log.get(chat_key, []) + if len(items) > self._message_log_cap: + self._message_log[chat_key] = items[-self._message_log_cap :] + self._message_log_ids[chat_key] = { + str(x.get("message_id")) + for x in self._message_log[chat_key] + } self._schedule_save() @@ -244,8 +199,6 @@ class SessionStore: def clear_all(self) -> None: """Clear all stored sessions/trees/mappings and persist an empty store.""" with self._lock: - self._sessions.clear() - self._msg_to_session.clear() self._trees.clear() self._node_to_tree.clear() self._message_log.clear() diff --git a/messaging/transcript.py b/messaging/transcript.py index b89a408..e69c76e 100644 --- a/messaging/transcript.py +++ b/messaging/transcript.py @@ -33,14 +33,17 @@ class Segment: @dataclass class ThinkingSegment(Segment): - text: str = "" - def __init__(self) -> None: super().__init__(kind="thinking") + self._parts: List[str] = [] def append(self, t: str) -> None: if t: - self.text += t + self._parts.append(t) + + @property + def text(self) -> str: + return "".join(self._parts) def render(self, ctx: "RenderCtx") -> str: raw = self.text or "" @@ -52,14 +55,17 @@ class ThinkingSegment(Segment): @dataclass class TextSegment(Segment): - text: str = "" - def __init__(self) -> None: super().__init__(kind="text") + self._parts: List[str] = [] def append(self, t: str) -> None: if t: - self.text += t + self._parts.append(t) + + @property + def text(self) -> str: + return "".join(self._parts) def render(self, ctx: "RenderCtx") -> str: raw = self.text or "" @@ -72,7 +78,6 @@ class TextSegment(Segment): class ToolCallSegment(Segment): tool_use_id: str name: str - input_text: str = "" closed: bool = False indent_level: int = 0 @@ -81,18 +86,23 @@ class ToolCallSegment(Segment): self.tool_use_id = str(tool_use_id or "") self.name = str(name or "tool") self.indent_level = max(0, int(indent_level)) + self._parts: List[str] = [] def set_initial_input(self, inp: Any) -> None: if inp is None: return if isinstance(inp, str): - self.input_text = inp + self._parts = [inp] else: - self.input_text = _safe_json_dumps(inp) + self._parts = [_safe_json_dumps(inp)] def append_input_delta(self, partial: str) -> None: if partial: - self.input_text += partial + self._parts.append(partial) + + @property + def input_text(self) -> str: + return "".join(self._parts) def render(self, ctx: "RenderCtx") -> str: name = ctx.code_inline(self.name) diff --git a/messaging/tree_data.py b/messaging/tree_data.py index 233e063..9438df2 100644 --- a/messaging/tree_data.py +++ b/messaging/tree_data.py @@ -400,10 +400,14 @@ class MessageTree: """ if node_id not in self._nodes: return [] - result = [node_id] - node = self._nodes[node_id] - for child_id in node.children_ids: - result.extend(self.get_descendants(child_id)) + result: List[str] = [] + stack = [node_id] + while stack: + nid = stack.pop() + result.append(nid) + node = self._nodes.get(nid) + if node: + stack.extend(node.children_ids) return result def remove_branch(self, branch_root_id: str) -> List[MessageNode]: diff --git a/messaging/tree_processor.py b/messaging/tree_processor.py index e4fa79b..b5a8828 100644 --- a/messaging/tree_processor.py +++ b/messaging/tree_processor.py @@ -115,9 +115,7 @@ class TreeQueueProcessor: node = tree.get_node(next_node_id) if node: tree.set_current_task( - asyncio.create_task( - self.process_node(tree, node, processor) - ) + asyncio.create_task(self.process_node(tree, node, processor)) ) # Notify that this node has started processing and refresh queue positions. @@ -155,9 +153,7 @@ class TreeQueueProcessor: node = tree.get_node(node_id) if node: tree.set_current_task( - asyncio.create_task( - self.process_node(tree, node, processor) - ) + asyncio.create_task(self.process_node(tree, node, processor)) ) return False diff --git a/providers/base.py b/providers/base.py index 8d4cd57..58625df 100644 --- a/providers/base.py +++ b/providers/base.py @@ -38,4 +38,4 @@ class BaseProvider(ABC): ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" if False: - yield "" + yield "" # Required for ty/mypy to accept abstract async generator diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 1378232..cd50eba 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -62,7 +62,7 @@ class TestSessionStore: from messaging.session import SessionStore store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - assert store._sessions == {} + assert store._trees == {} # --- Tree Tests --- @@ -127,22 +127,15 @@ class TestSessionStore: # --- Persistence & Edge Cases --- - def test_load_existing_legacy_format(self, tmp_path): - """Test loading legacy session format (int IDs) - backward compat.""" + def test_load_existing_file_with_trees(self, tmp_path): + """Test loading file with trees (legacy sessions ignored).""" from messaging.session import SessionStore data = { - "sessions": { - "s1": { - "session_id": "s1", - "chat_id": 123, # Legacy int - "initial_msg_id": 100, # Legacy int - "last_msg_id": 101, # Legacy int - "created_at": "2024-01-01", - "updated_at": "2024-01-01", - # platform missing -> should default to telegram - } - } + "sessions": {}, + "trees": {"r1": {"root_id": "r1", "nodes": {"r1": {}}}}, + "node_to_tree": {"r1": "r1"}, + "message_log": {}, } p = tmp_path / "sessions.json" @@ -150,11 +143,7 @@ class TestSessionStore: json.dump(data, f) store = SessionStore(storage_path=str(p)) - # Legacy sessions are loaded for backward compat; verify conversion - assert "s1" in store._sessions - rec = store._sessions["s1"] - assert rec.chat_id == "123" # Converted to str - assert rec.platform == "telegram" # Defaulted + assert store.get_tree("r1") is not None def test_load_corrupt_file(self, tmp_path): """Test loading corrupt/invalid json file.""" @@ -166,7 +155,7 @@ class TestSessionStore: # Should log error and start empty, avoiding crash store = SessionStore(storage_path=str(p)) - assert store._sessions == {} + assert store._trees == {} def test_save_error_handling(self, tmp_path): """Test error during save.""" diff --git a/tests/test_nvidia_nim.py b/tests/test_nvidia_nim.py index 355781f..253b388 100644 --- a/tests/test_nvidia_nim.py +++ b/tests/test_nvidia_nim.py @@ -56,9 +56,7 @@ async def test_init(provider_config): with patch("providers.openai_compat.AsyncOpenAI") as mock_openai: from config.nim import NimSettings - provider = NvidiaNimProvider( - provider_config, nim_settings=NimSettings() - ) + provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings()) assert provider._api_key == "test_key" assert provider._base_url == "https://test.api.nvidia.com/v1" mock_openai.assert_called_once() diff --git a/tests/test_session_store_edge_cases.py b/tests/test_session_store_edge_cases.py index ab01332..05962d3 100644 --- a/tests/test_session_store_edge_cases.py +++ b/tests/test_session_store_edge_cases.py @@ -24,7 +24,6 @@ class TestSessionStoreLoadEdgeCases: f.write("{invalid json") store = SessionStore(storage_path=path) - assert len(store._sessions) == 0 assert len(store._trees) == 0 def test_load_truncated_json(self, tmp_path): @@ -34,7 +33,7 @@ class TestSessionStoreLoadEdgeCases: f.write('{"sessions": {"s1": {"session_id": "s1"') store = SessionStore(storage_path=path) - assert len(store._sessions) == 0 + assert len(store._trees) == 0 def test_load_empty_file(self, tmp_path): """Empty file is handled gracefully.""" @@ -43,16 +42,16 @@ class TestSessionStoreLoadEdgeCases: f.write("") store = SessionStore(storage_path=path) - assert len(store._sessions) == 0 + assert len(store._trees) == 0 def test_load_nonexistent_file(self, tmp_path): """Non-existent file starts with empty state.""" path = str(tmp_path / "nonexistent.json") store = SessionStore(storage_path=path) - assert len(store._sessions) == 0 + assert len(store._trees) == 0 - def test_load_legacy_int_fields(self, tmp_path): - """Legacy format with int chat_id/msg_id is converted to string.""" + def test_load_legacy_sessions_ignored(self, tmp_path): + """Legacy sessions in file are ignored; trees and message_log load.""" path = str(tmp_path / "sessions.json") data = { "sessions": { @@ -66,17 +65,15 @@ class TestSessionStoreLoadEdgeCases: "updated_at": "2025-01-01T00:00:00+00:00", } }, - "trees": {}, - "node_to_tree": {}, + "trees": {"r1": {"root_id": "r1", "nodes": {"r1": {}}}}, + "node_to_tree": {"r1": "r1"}, + "message_log": {}, } with open(path, "w") as f: json.dump(data, f) store = SessionStore(storage_path=path) - record = store._sessions["s1"] - assert record.chat_id == "12345" - assert record.initial_msg_id == "100" - assert record.last_msg_id == "200" + assert store.get_tree("r1") is not None class TestSessionStoreSaveEdgeCases: @@ -131,13 +128,11 @@ class TestSessionStoreClearAll: with open(path, "r", encoding="utf-8") as f: data = json.load(f) - assert data["sessions"] == {} assert data["trees"] == {} assert data["node_to_tree"] == {} assert data["message_log"] == {} store2 = SessionStore(storage_path=path) - assert len(store2._sessions) == 0 assert len(store2._trees) == 0 def test_message_log_persists_and_dedups(self, tmp_path):