diff --git a/api/dependencies.py b/api/dependencies.py index 5ae85d9..5eac1a6 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -33,6 +33,8 @@ def get_provider() -> NvidiaNimProvider: async def cleanup_provider(): """Cleanup provider resources.""" global _provider - if _provider and hasattr(_provider, "_client"): - await _provider._client.aclose() + if _provider: + client = getattr(_provider, "_client", None) + if client and hasattr(client, "aclose"): + await client.aclose() _provider = None diff --git a/api/request_utils.py b/api/request_utils.py index 31f83e7..5c5e579 100644 --- a/api/request_utils.py +++ b/api/request_utils.py @@ -40,7 +40,8 @@ def is_quota_check_request(request_data: MessagesRequest) -> bool: # Check list content elif isinstance(content, list): for block in content: - if hasattr(block, "text") and "quota" in block.text.lower(): + text = getattr(block, "text", "") + if text and isinstance(text, str) and "quota" in text.lower(): return True return False @@ -66,7 +67,8 @@ def is_title_generation_request(request_data: MessagesRequest) -> bool: # Check list content elif isinstance(content, list): for block in content: - if hasattr(block, "text") and target_phrase in block.text.lower(): + text = getattr(block, "text", "") + if text and isinstance(text, str) and target_phrase in text.lower(): return True return False @@ -159,8 +161,9 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> Tuple[bool, st content = msg.content elif isinstance(msg.content, list): for block in msg.content: - if hasattr(block, "text"): - content += block.text + text = getattr(block, "text", "") + if text and isinstance(text, str): + content += text if "" in content and "Command:" in content: try: diff --git a/cli/parser.py b/cli/parser.py index b7f6bf1..29baa0c 100644 --- a/cli/parser.py +++ b/cli/parser.py @@ -1,7 +1,7 @@ """CLI event parser for Claude Code CLI output.""" import logging -from typing import Dict, List +from typing import Dict, List, Any logger = logging.getLogger(__name__) @@ -10,7 +10,7 @@ class CLIParser: """Helper to structure raw CLI events.""" @staticmethod - def parse_event(event: Dict) -> List[Dict]: + def parse_event(event: Any) -> List[Dict]: """ Parse a CLI event and return a structured result. diff --git a/cli/session.py b/cli/session.py index 812e19e..8f64400 100644 --- a/cli/session.py +++ b/cli/session.py @@ -4,7 +4,7 @@ import asyncio import os import json import logging -from typing import AsyncGenerator, Optional, Dict, List +from typing import AsyncGenerator, Optional, Dict, List, Any logger = logging.getLogger(__name__) @@ -186,7 +186,7 @@ class CLISession: logger.debug(f"Non-JSON output: {line_str[:100]}") yield {"type": "raw", "content": line_str} - def _extract_session_id(self, event: Dict) -> Optional[str]: + def _extract_session_id(self, event: Any) -> Optional[str]: """Extract session ID from CLI event.""" if not isinstance(event, dict): return None diff --git a/messaging/base.py b/messaging/base.py index f147bf7..1b12a5b 100644 --- a/messaging/base.py +++ b/messaging/base.py @@ -1,17 +1,29 @@ """Abstract base class for messaging platforms.""" from abc import ABC, abstractmethod -from typing import Callable, Awaitable, Optional, Protocol, Tuple, runtime_checkable +from typing import ( + Callable, + Awaitable, + Optional, + Protocol, + Tuple, + runtime_checkable, + AsyncGenerator, + Any, + Dict, +) from .models import IncomingMessage -class CLISession(ABC): - """Abstract base for CLI session - avoid circular import from cli package.""" +@runtime_checkable +class CLISession(Protocol): + """Protocol for CLI session - avoid circular import from cli package.""" - @abstractmethod - async def start_task(self, prompt: str, session_id: Optional[str] = None): + def start_task( + self, prompt: str, session_id: Optional[str] = None + ) -> AsyncGenerator[Dict, Any]: """Start a task in the CLI session.""" - pass + ... @property @abstractmethod @@ -162,6 +174,11 @@ class MessagingPlatform(ABC): """ pass + @abstractmethod + def fire_and_forget(self, task: Awaitable[Any]) -> None: + """Execute a coroutine without awaiting it.""" + pass + @property def is_connected(self) -> bool: """Check if the platform is connected.""" diff --git a/messaging/event_parser.py b/messaging/event_parser.py index b609561..5de5e34 100644 --- a/messaging/event_parser.py +++ b/messaging/event_parser.py @@ -4,12 +4,12 @@ Extracted from cli.parser to avoid tight coupling between messaging and cli pack """ import logging -from typing import Dict, List +from typing import Dict, List, Any logger = logging.getLogger(__name__) -def parse_cli_event(event: Dict) -> List[Dict]: +def parse_cli_event(event: Any) -> List[Dict]: """ Parse a CLI event and return a structured result. diff --git a/messaging/handler.py b/messaging/handler.py index d7b2c50..3c375ee 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -72,14 +72,13 @@ class ClaudeMessageHandler: parent_node_id = None tree = None - if incoming.is_reply(): + if incoming.is_reply() and incoming.reply_to_message_id: # Look up if the replied-to message is in any tree (could be a node or status message) - tree = self.tree_queue.get_tree_for_node(incoming.reply_to_message_id) + reply_id = incoming.reply_to_message_id + tree = self.tree_queue.get_tree_for_node(reply_id) if tree: # Resolve to actual node ID (handles status message replies) - parent_node_id = self.tree_queue.resolve_parent_node_id( - incoming.reply_to_message_id - ) + parent_node_id = self.tree_queue.resolve_parent_node_id(reply_id) if parent_node_id: logger.info(f"Found tree for reply, parent node: {parent_node_id}") else: @@ -106,7 +105,7 @@ class ClaudeMessageHandler: ) # Create or extend tree - if parent_node_id and tree: + if parent_node_id and tree and status_msg_id: # Reply to existing node - add as child tree, node = await self.tree_queue.add_to_tree( parent_node_id=parent_node_id, @@ -118,7 +117,7 @@ class ClaudeMessageHandler: self.tree_queue.register_node(status_msg_id, tree.root_id) self.session_store.register_node(status_msg_id, tree.root_id) self.session_store.register_node(node_id, tree.root_id) - else: + elif status_msg_id: # New conversation - create new tree tree = await self.tree_queue.create_tree( node_id=node_id, @@ -131,7 +130,8 @@ class ClaudeMessageHandler: self.session_store.register_node(status_msg_id, tree.root_id) # Persist tree - self.session_store.save_tree(tree.root_id, tree.to_dict()) + if tree: + self.session_store.save_tree(tree.root_id, tree.to_dict()) # Enqueue for processing was_queued = await self.tree_queue.enqueue( @@ -139,7 +139,7 @@ class ClaudeMessageHandler: processor=self._process_node, ) - if was_queued: + if was_queued and status_msg_id: # Update status to show queue position queue_size = self.tree_queue.get_queue_size(node_id) await self.platform.queue_edit_message( @@ -454,7 +454,7 @@ class ClaudeMessageHandler: parent_node_id: Optional[str], ) -> str: """Get initial status message text.""" - if tree: + if tree and parent_node_id: # Reply to existing tree if self.tree_queue.is_node_tree_busy(parent_node_id): queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1 diff --git a/messaging/limiter.py b/messaging/limiter.py index 727b4b6..985f4df 100644 --- a/messaging/limiter.py +++ b/messaging/limiter.py @@ -115,8 +115,13 @@ class GlobalRateLimiter: logger.error( f"FloodWait detected! Pausing worker for {seconds}s" ) + wait_secs = ( + float(seconds) + if isinstance(seconds, (int, float, str)) + else 30.0 + ) self._paused_until = ( - asyncio.get_event_loop().time() + seconds + asyncio.get_event_loop().time() + wait_secs ) else: logger.error( diff --git a/messaging/models.py b/messaging/models.py index 5337b3d..d1aa11c 100644 --- a/messaging/models.py +++ b/messaging/models.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Optional, Any -from datetime import datetime +from datetime import datetime, timezone @dataclass @@ -22,7 +22,7 @@ class IncomingMessage: # Optional fields reply_to_message_id: Optional[str] = None username: Optional[str] = None - timestamp: datetime = field(default_factory=datetime.utcnow) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) # Platform-specific raw event for edge cases raw_event: Any = None diff --git a/messaging/session.py b/messaging/session.py index 1bb7075..9e82af3 100644 --- a/messaging/session.py +++ b/messaging/session.py @@ -8,7 +8,7 @@ and message trees for conversation continuation. import json import os import logging -from datetime import datetime +from datetime import datetime, timezone from typing import Optional, Dict from dataclasses import dataclass, asdict import threading @@ -116,7 +116,7 @@ class SessionStore: ) -> None: """Save a new session mapping.""" with self._lock: - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() record = SessionRecord( session_id=session_id, chat_id=str(chat_id), @@ -152,7 +152,7 @@ class SessionStore: record = self._sessions[session_id] record.last_msg_id = str(msg_id) - record.updated_at = datetime.utcnow().isoformat() + record.updated_at = datetime.now(timezone.utc).isoformat() new_key = self._make_key(record.platform, record.chat_id, str(msg_id)) self._msg_to_session[new_key] = session_id self._save() @@ -167,7 +167,7 @@ class SessionStore: record = self._sessions.pop(old_id) record.session_id = new_id - record.updated_at = datetime.utcnow().isoformat() + record.updated_at = datetime.now(timezone.utc).isoformat() self._sessions[new_id] = record items_to_update = [ @@ -190,7 +190,7 @@ class SessionStore: def cleanup_old_sessions(self, max_age_days: int = 30) -> int: """Remove sessions older than max_age_days.""" with self._lock: - cutoff = datetime.utcnow() + cutoff = datetime.now(timezone.utc) removed = 0 to_remove = [] @@ -284,7 +284,7 @@ class SessionStore: def cleanup_old_trees(self, max_age_days: int = 30) -> int: """Remove trees older than max_age_days.""" with self._lock: - cutoff = datetime.utcnow() + cutoff = datetime.now(timezone.utc) removed = 0 to_remove = [] diff --git a/messaging/telegram.py b/messaging/telegram.py index e987a51..cc40921 100644 --- a/messaging/telegram.py +++ b/messaging/telegram.py @@ -105,9 +105,10 @@ class TelegramPlatform(MessagingPlatform): await self._application.start() # Start polling (non-blocking way for integration) - await self._application.updater.start_polling( - drop_pending_updates=False - ) + if self._application.updater: + await self._application.updater.start_polling( + drop_pending_updates=False + ) self._connected = True break @@ -141,7 +142,7 @@ class TelegramPlatform(MessagingPlatform): async def stop(self) -> None: """Stop the bot.""" - if self._application: + if self._application and self._application.updater: await self._application.updater.stop() await self._application.stop() await self._application.shutdown() @@ -175,9 +176,11 @@ class TelegramPlatform(MessagingPlatform): raise except RetryAfter as e: # Telegram explicitly tells us to wait - wait_secs = e.retry_after - if hasattr(wait_secs, "total_seconds"): - wait_secs = wait_secs.total_seconds() + retry_after = e.retry_after + if hasattr(retry_after, "total_seconds"): + wait_secs = float(retry_after.total_seconds()) # type: ignore + else: + wait_secs = float(retry_after) logger.warning(f"Rate limited by Telegram, waiting {wait_secs}s...") await asyncio.sleep(wait_secs) @@ -201,11 +204,12 @@ class TelegramPlatform(MessagingPlatform): parse_mode: Optional[str] = "Markdown", ) -> str: """Send a message to a chat.""" - if not self._application: - raise RuntimeError("Telegram application not initialized") + if not self._application or not self._application.bot: + raise RuntimeError("Telegram application or bot not initialized") async def _do_send(mode=parse_mode): - msg = await self._application.bot.send_message( + bot = self._application.bot # type: ignore + msg = await bot.send_message( chat_id=chat_id, text=text, reply_to_message_id=int(reply_to) if reply_to else None, @@ -223,11 +227,12 @@ class TelegramPlatform(MessagingPlatform): parse_mode: Optional[str] = "Markdown", ) -> None: """Edit an existing message.""" - if not self._application: - raise RuntimeError("Telegram application not initialized") + if not self._application or not self._application.bot: + raise RuntimeError("Telegram application or bot not initialized") async def _do_edit(mode=parse_mode): - await self._application.bot.edit_message_text( + bot = self._application.bot # type: ignore + await bot.edit_message_text( chat_id=chat_id, message_id=int(message_id), text=text, @@ -281,7 +286,10 @@ class TelegramPlatform(MessagingPlatform): def fire_and_forget(self, task: Awaitable[Any]) -> None: """Execute a coroutine without awaiting it.""" - asyncio.create_task(task) + if asyncio.iscoroutine(task): + asyncio.create_task(task) # type: ignore + else: + asyncio.ensure_future(task) def on_message( self, @@ -299,7 +307,8 @@ class TelegramPlatform(MessagingPlatform): self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: """Handle /start command.""" - await update.message.reply_text("👋 Hello! I am the Claude Code Proxy Bot.") + if update.message: + await update.message.reply_text("👋 Hello! I am the Claude Code Proxy Bot.") # We can also treat this as a message if we want it to trigger something await self._on_telegram_message(update, context) @@ -307,7 +316,12 @@ class TelegramPlatform(MessagingPlatform): self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: """Handle incoming updates.""" - if not update.message or not update.message.text: + if ( + not update.message + or not update.message.text + or not update.effective_user + or not update.effective_chat + ): return user_id = str(update.effective_user.id) diff --git a/messaging/tree_data.py b/messaging/tree_data.py index a51d6fa..e51dc9e 100644 --- a/messaging/tree_data.py +++ b/messaging/tree_data.py @@ -6,7 +6,7 @@ Contains MessageState, MessageNode, and MessageTree classes. import asyncio import logging from enum import Enum -from datetime import datetime +from datetime import datetime, timezone from typing import Dict, Optional, List, Any from dataclasses import dataclass, field @@ -42,7 +42,7 @@ class MessageNode: parent_id: Optional[str] = None # Parent node ID (None for root) session_id: Optional[str] = None # Claude session ID (forked from parent) children_ids: List[str] = field(default_factory=list) - created_at: datetime = field(default_factory=datetime.utcnow) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) completed_at: Optional[datetime] = None error_message: Optional[str] = None context: Any = None # Additional context if needed @@ -221,7 +221,7 @@ class MessageTree: if error_message: node.error_message = error_message if state in (MessageState.COMPLETED, MessageState.ERROR): - node.completed_at = datetime.utcnow() + node.completed_at = datetime.now(timezone.utc) logger.debug(f"Node {node_id} state -> {state.value}") diff --git a/providers/base.py b/providers/base.py index fe2b610..550eec9 100644 --- a/providers/base.py +++ b/providers/base.py @@ -30,7 +30,8 @@ class BaseProvider(ABC): self, request: Any, input_tokens: int = 0 ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" - pass + if False: + yield "" @abstractmethod def convert_response(self, response_json: dict, original_request: Any) -> Any: diff --git a/providers/logging_utils.py b/providers/logging_utils.py index 906aba3..348c859 100644 --- a/providers/logging_utils.py +++ b/providers/logging_utils.py @@ -8,7 +8,7 @@ import hashlib import json import logging import os -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) @@ -133,7 +133,7 @@ def log_full_payload(request_id: str, payload: Dict[str, Any]) -> None: try: handler = _get_debug_handler() record = { - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(timezone.utc).isoformat(), "request_id": request_id, "payload": payload, } diff --git a/tests/conftest.py b/tests/conftest.py index f4ffb15..3d5b090 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -92,6 +92,10 @@ def incoming_message_factory(): "platform": "telegram", } defaults.update(kwargs) - return IncomingMessage(**defaults) + if "timestamp" in defaults and isinstance(defaults["timestamp"], str): + from datetime import datetime + + defaults["timestamp"] = datetime.fromisoformat(defaults["timestamp"]) + return IncomingMessage(**defaults) # type: ignore return _create diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 5f10bd6..8cd8711 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -3,7 +3,7 @@ import pytest import json import os -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest.mock import patch # --- Existing Tests --- @@ -125,6 +125,7 @@ class TestSessionStore: # Verify record updated record = store.get_session_record("sess_1") + assert record is not None assert record.last_msg_id == "msg_2" def test_update_last_message_unknown_session(self, tmp_path): @@ -194,7 +195,7 @@ class TestSessionStore: store = SessionStore(storage_path=str(tmp_path / "sessions.json")) # Create an old session manually - old_date = (datetime.utcnow() - timedelta(days=40)).isoformat() + old_date = (datetime.now(timezone.utc) - timedelta(days=40)).isoformat() store.save_session("old_sess", "c_old", "m_old") # Manipulate the created_at directly store._sessions["old_sess"].created_at = old_date @@ -261,6 +262,7 @@ class TestSessionStore: store.update_tree_node("r1", "n2", {"data": "test"}) tree = store.get_tree("r1") + assert tree is not None assert "n2" in tree["nodes"] assert tree["nodes"]["n2"]["data"] == "test" assert store.get_tree_root_for_node("n2") == "r1" @@ -286,7 +288,7 @@ class TestSessionStore: store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - old_date = (datetime.utcnow() - timedelta(days=40)).isoformat() + old_date = (datetime.now(timezone.utc) - timedelta(days=40)).isoformat() # Old tree store.save_tree( @@ -296,7 +298,11 @@ class TestSessionStore: # New tree store.save_tree( "new_root", - {"nodes": {"new_root": {"created_at": datetime.utcnow().isoformat()}}}, + { + "nodes": { + "new_root": {"created_at": datetime.now(timezone.utc).isoformat()} + } + }, ) removed = store.cleanup_old_trees(30) @@ -334,6 +340,7 @@ class TestSessionStore: store = SessionStore(storage_path=str(p)) rec = store.get_session_record("s1") + assert rec is not None assert rec.chat_id == "123" # Converted to str assert rec.platform == "telegram" # Defaulted diff --git a/tests/test_tree_queue.py b/tests/test_tree_queue.py index 7556ce5..2d6194b 100644 --- a/tests/test_tree_queue.py +++ b/tests/test_tree_queue.py @@ -155,7 +155,9 @@ class TestMessageTree: assert child.node_id == "child" assert child.parent_id == "root" assert "child" in tree.get_root().children_ids - assert tree.get_parent("child").node_id == "root" + parent = tree.get_parent("child") + assert parent is not None + assert parent.node_id == "root" @pytest.mark.asyncio async def test_update_state(self): @@ -171,10 +173,13 @@ class TestMessageTree: tree = MessageTree(root) await tree.update_state("m1", MessageState.IN_PROGRESS) - assert tree.get_node("m1").state == MessageState.IN_PROGRESS + node = tree.get_node("m1") + assert node is not None + assert node.state == MessageState.IN_PROGRESS await tree.update_state("m1", MessageState.COMPLETED, session_id="sess_abc") node = tree.get_node("m1") + assert node is not None assert node.state == MessageState.COMPLETED assert node.session_id == "sess_abc" assert node.completed_at is not None @@ -224,7 +229,9 @@ class TestMessageTree: restored = MessageTree.from_dict(data) assert restored.root_id == "m1" - assert restored.get_node("m1").session_id == "sess_1" + node = restored.get_node("m1") + assert node is not None + assert node.session_id == "sess_1" class TestTreeQueueManager: diff --git a/tests/test_tree_repository.py b/tests/test_tree_repository.py index 048d8b4..3fc0b8b 100644 --- a/tests/test_tree_repository.py +++ b/tests/test_tree_repository.py @@ -129,8 +129,9 @@ def test_to_from_dict(repository, sample_tree): assert data["node_to_tree"]["root_id"] == "root_id" new_repo = TreeRepository.from_dict(data) - assert new_repo.get_tree("root_id") is not None - assert new_repo.get_tree("root_id").root_id == "root_id" + tree = new_repo.get_tree("root_id") + assert tree is not None + assert tree.root_id == "root_id" assert new_repo._node_to_tree["root_id"] == "root_id"