mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Phase 7: Directory restructuring (messaging/ and tests/)
- Create messaging/platforms/ (base, discord, telegram, factory) - Create messaging/rendering/ (discord_markdown, telegram_markdown) - Create messaging/trees/ (data, repository, processor, queue_manager) - Organize tests/ into api/, providers/, messaging/, cli/, config/ - Add backward-compatible re-exports at old locations - Update handler.py and test_messaging_factory.py imports - Fix Telegram type hints for TELEGRAM_AVAILABLE=False case - Fix Python 3 except syntax in discord_markdown Co-authored-by: Ali Khokhar <alishahryar2@gmail.com>
This commit is contained in:
parent
38a7980546
commit
4b4f87515d
76 changed files with 3294 additions and 3124 deletions
|
|
@ -1,219 +1,9 @@
|
|||
"""Abstract base class for messaging platforms."""
|
||||
"""Backward-compatible re-export. Use messaging.platforms.base for new code."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Callable,
|
||||
Awaitable,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
runtime_checkable,
|
||||
AsyncGenerator,
|
||||
Any,
|
||||
Dict,
|
||||
from .platforms.base import (
|
||||
MessagingPlatform,
|
||||
SessionManagerInterface,
|
||||
CLISession,
|
||||
)
|
||||
from .models import IncomingMessage
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CLISession(Protocol):
|
||||
"""Protocol for CLI session - avoid circular import from cli package."""
|
||||
|
||||
def start_task(
|
||||
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
|
||||
) -> AsyncGenerator[Dict, Any]:
|
||||
"""Start a task in the CLI session."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_busy(self) -> bool:
|
||||
"""Check if session is busy."""
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SessionManagerInterface(Protocol):
|
||||
"""
|
||||
Protocol for session managers to avoid tight coupling with cli package.
|
||||
|
||||
Implementations: CLISessionManager
|
||||
"""
|
||||
|
||||
async def get_or_create_session(
|
||||
self, session_id: Optional[str] = None
|
||||
) -> Tuple[CLISession, str, bool]:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
Returns: Tuple of (session, session_id, is_new_session)
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_real_session_id(
|
||||
self, temp_id: str, real_session_id: str
|
||||
) -> bool:
|
||||
"""Register the real session ID from CLI output."""
|
||||
...
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all sessions."""
|
||||
...
|
||||
|
||||
async def remove_session(self, session_id: str) -> bool:
|
||||
"""Remove a session from the manager."""
|
||||
...
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get session statistics."""
|
||||
...
|
||||
|
||||
|
||||
class MessagingPlatform(ABC):
|
||||
"""
|
||||
Base class for all messaging platform adapters.
|
||||
|
||||
Implement this to add support for Telegram, Discord, Slack, etc.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to the messaging platform."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Disconnect and cleanup resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
text: Message content
|
||||
reply_to: Optional message ID to reply to
|
||||
parse_mode: Optional formatting mode ("markdown", "html")
|
||||
|
||||
Returns:
|
||||
The message ID of the sent message
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Edit an existing message.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID
|
||||
message_id: The message ID to edit
|
||||
text: New message content
|
||||
parse_mode: Optional formatting mode
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a message from a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID
|
||||
message_id: The message ID to delete
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Enqueue a message to be sent.
|
||||
|
||||
If fire_and_forget is True, returns None immediately.
|
||||
Otherwise, waits for the rate limiter and returns message ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Enqueue a message edit.
|
||||
|
||||
If fire_and_forget is True, returns immediately.
|
||||
Otherwise, waits for the rate limiter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Enqueue a message deletion.
|
||||
|
||||
If fire_and_forget is True, returns immediately.
|
||||
Otherwise, waits for the rate limiter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""
|
||||
Register a message handler callback.
|
||||
|
||||
The handler will be called for each incoming message.
|
||||
|
||||
Args:
|
||||
handler: Async function that processes incoming messages
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fire_and_forget(self, task: Awaitable[Any]) -> None:
|
||||
"""Execute a coroutine without awaiting it."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the platform is connected."""
|
||||
return False
|
||||
__all__ = ["MessagingPlatform", "SessionManagerInterface", "CLISession"]
|
||||
|
|
|
|||
|
|
@ -1,394 +1,9 @@
|
|||
"""
|
||||
Discord Platform Adapter
|
||||
"""Backward-compatible re-export. Use messaging.platforms.discord for new code."""
|
||||
|
||||
Implements MessagingPlatform for Discord using discord.py.
|
||||
"""
|
||||
from .platforms.discord import (
|
||||
DiscordPlatform,
|
||||
DISCORD_AVAILABLE,
|
||||
DISCORD_MESSAGE_LIMIT,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Callable, Awaitable, Optional, Any, Set, cast
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from .models import IncomingMessage
|
||||
from .discord_markdown import format_status_discord
|
||||
|
||||
_discord_module: Any = None
|
||||
try:
|
||||
import discord as _discord_import
|
||||
|
||||
_discord_module = _discord_import
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
|
||||
DISCORD_MESSAGE_LIMIT = 2000
|
||||
|
||||
|
||||
def _get_discord() -> Any:
|
||||
"""Return the discord module. Raises if not available."""
|
||||
if not DISCORD_AVAILABLE or _discord_module is None:
|
||||
raise ImportError(
|
||||
"discord.py is required. Install with: pip install discord.py"
|
||||
)
|
||||
return _discord_module
|
||||
|
||||
|
||||
def _parse_allowed_channels(raw: Optional[str]) -> Set[str]:
|
||||
"""Parse comma-separated channel IDs into a set of strings."""
|
||||
if not raw or not raw.strip():
|
||||
return set()
|
||||
return {s.strip() for s in raw.split(",") if s.strip()}
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE and _discord_module is not None:
|
||||
_discord = _discord_module
|
||||
|
||||
class _DiscordClient(_discord.Client):
|
||||
"""Internal Discord client that forwards events to DiscordPlatform."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
platform: "DiscordPlatform",
|
||||
intents: _discord.Intents,
|
||||
) -> None:
|
||||
super().__init__(intents=intents)
|
||||
self._platform = platform
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Called when the bot is ready."""
|
||||
self._platform._connected = True
|
||||
logger.info("Discord platform connected")
|
||||
|
||||
async def on_message(self, message: Any) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
await self._platform._on_discord_message(message)
|
||||
else:
|
||||
_DiscordClient = None
|
||||
|
||||
|
||||
class DiscordPlatform(MessagingPlatform):
|
||||
"""
|
||||
Discord messaging platform adapter.
|
||||
|
||||
Uses discord.py for Discord access.
|
||||
Requires a Bot Token from Discord Developer Portal and message_content intent.
|
||||
"""
|
||||
|
||||
name = "discord"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: Optional[str] = None,
|
||||
allowed_channel_ids: Optional[str] = None,
|
||||
):
|
||||
if not DISCORD_AVAILABLE:
|
||||
raise ImportError(
|
||||
"discord.py is required. Install with: pip install discord.py"
|
||||
)
|
||||
|
||||
self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
|
||||
raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
|
||||
self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
|
||||
|
||||
if not self.bot_token:
|
||||
logger.warning("DISCORD_BOT_TOKEN not set")
|
||||
|
||||
discord = _get_discord()
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
assert _DiscordClient is not None
|
||||
self._client = _DiscordClient(self, intents)
|
||||
self._message_handler: Optional[
|
||||
Callable[[IncomingMessage], Awaitable[None]]
|
||||
] = None
|
||||
self._connected = False
|
||||
self._limiter: Optional[Any] = None
|
||||
self._start_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def _on_discord_message(self, message: Any) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
if message.author.bot:
|
||||
return
|
||||
if not message.content:
|
||||
return
|
||||
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
if not self.allowed_channel_ids or channel_id not in self.allowed_channel_ids:
|
||||
return
|
||||
|
||||
user_id = str(message.author.id)
|
||||
message_id = str(message.id)
|
||||
reply_to = (
|
||||
str(message.reference.message_id)
|
||||
if message.reference and message.reference.message_id
|
||||
else None
|
||||
)
|
||||
|
||||
text_preview = (message.content or "")[:80]
|
||||
if len(message.content or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"DISCORD_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
channel_id,
|
||||
message_id,
|
||||
reply_to,
|
||||
text_preview,
|
||||
)
|
||||
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
incoming = IncomingMessage(
|
||||
text=message.content,
|
||||
chat_id=channel_id,
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
platform="discord",
|
||||
reply_to_message_id=reply_to,
|
||||
username=message.author.display_name,
|
||||
raw_event=message,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._message_handler(incoming)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
try:
|
||||
await self.send_message(
|
||||
channel_id,
|
||||
format_status_discord("Error:", str(e)[:200]),
|
||||
reply_to=message_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _truncate(self, text: str, limit: int = DISCORD_MESSAGE_LIMIT) -> str:
|
||||
"""Truncate text to Discord's message limit."""
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[: limit - 3] + "..."
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to Discord."""
|
||||
if not self.bot_token:
|
||||
raise ValueError("DISCORD_BOT_TOKEN is required")
|
||||
|
||||
from .limiter import MessagingRateLimiter
|
||||
|
||||
self._limiter = await MessagingRateLimiter.get_instance()
|
||||
|
||||
self._start_task = asyncio.create_task(
|
||||
self._client.start(self.bot_token),
|
||||
name="discord-client-start",
|
||||
)
|
||||
|
||||
max_wait = 30
|
||||
waited = 0
|
||||
while not self._connected and waited < max_wait:
|
||||
await asyncio.sleep(0.5)
|
||||
waited += 0.5
|
||||
|
||||
if not self._connected:
|
||||
raise RuntimeError("Discord client failed to connect within timeout")
|
||||
|
||||
logger.info("Discord platform started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the bot."""
|
||||
if self._client.is_closed():
|
||||
self._connected = False
|
||||
return
|
||||
|
||||
await self._client.close()
|
||||
if self._start_task and not self._start_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._start_task, timeout=5.0)
|
||||
except asyncio.TimeoutError, asyncio.CancelledError:
|
||||
self._start_task.cancel()
|
||||
try:
|
||||
await self._start_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._connected = False
|
||||
logger.info("Discord platform stopped")
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Send a message to a channel."""
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel or not hasattr(channel, "send"):
|
||||
raise RuntimeError(f"Channel {chat_id} not found")
|
||||
|
||||
text = self._truncate(text)
|
||||
channel = cast(Any, channel)
|
||||
|
||||
discord = _get_discord()
|
||||
if reply_to:
|
||||
ref = discord.MessageReference(
|
||||
message_id=int(reply_to),
|
||||
channel_id=int(chat_id),
|
||||
)
|
||||
msg = await channel.send(content=text, reference=ref)
|
||||
else:
|
||||
msg = await channel.send(content=text)
|
||||
|
||||
return str(msg.id)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Edit an existing message."""
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel or not hasattr(channel, "fetch_message"):
|
||||
raise RuntimeError(f"Channel {chat_id} not found")
|
||||
|
||||
discord = _get_discord()
|
||||
channel = cast(Any, channel)
|
||||
try:
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
except discord.NotFound:
|
||||
return
|
||||
|
||||
text = self._truncate(text)
|
||||
await msg.edit(content=text)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""Delete a message from a channel."""
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel or not hasattr(channel, "fetch_message"):
|
||||
return
|
||||
|
||||
discord = _get_discord()
|
||||
channel = cast(Any, channel)
|
||||
try:
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
await msg.delete()
|
||||
except discord.NotFound, discord.Forbidden:
|
||||
pass
|
||||
|
||||
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
|
||||
"""Delete multiple messages (best-effort)."""
|
||||
for mid in message_ids:
|
||||
await self.delete_message(chat_id, mid)
|
||||
|
||||
async def queue_send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""Enqueue a message to be sent."""
|
||||
if not self._limiter:
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
async def _send():
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_send)
|
||||
return None
|
||||
return await self._limiter.enqueue(_send)
|
||||
|
||||
async def queue_edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message edit."""
|
||||
if not self._limiter:
|
||||
await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
return
|
||||
|
||||
async def _edit():
|
||||
await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
|
||||
dedup_key = f"edit:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message delete."""
|
||||
if not self._limiter:
|
||||
await self.delete_message(chat_id, message_id)
|
||||
return
|
||||
|
||||
async def _delete():
|
||||
await self.delete_message(chat_id, message_id)
|
||||
|
||||
dedup_key = f"del:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_messages(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_ids: list[str],
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a bulk delete."""
|
||||
if not message_ids:
|
||||
return
|
||||
|
||||
if not self._limiter:
|
||||
await self.delete_messages(chat_id, message_ids)
|
||||
return
|
||||
|
||||
async def _bulk():
|
||||
await self.delete_messages(chat_id, message_ids)
|
||||
|
||||
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
|
||||
|
||||
def fire_and_forget(self, task: Awaitable[Any]) -> None:
|
||||
"""Execute a coroutine without awaiting it."""
|
||||
if asyncio.iscoroutine(task):
|
||||
asyncio.create_task(task)
|
||||
else:
|
||||
asyncio.ensure_future(task)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Register a message handler callback."""
|
||||
self._message_handler = handler
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected."""
|
||||
return self._connected
|
||||
__all__ = ["DiscordPlatform", "DISCORD_AVAILABLE", "DISCORD_MESSAGE_LIMIT"]
|
||||
|
|
|
|||
|
|
@ -1,367 +1,14 @@
|
|||
"""Discord markdown utilities.
|
||||
|
||||
Discord uses standard markdown: **bold**, *italic*, `code`, ```code block```.
|
||||
Used by the message handler and Discord platform adapter.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
|
||||
# Discord escapes: \ * _ ` ~ | >
|
||||
DISCORD_SPECIAL = set("\\*_`~|>")
|
||||
|
||||
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
|
||||
_MD.enable("strikethrough")
|
||||
_MD.enable("table")
|
||||
|
||||
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
|
||||
_FENCE_RE = re.compile(r"^\s*```")
|
||||
|
||||
|
||||
def _is_gfm_table_header_line(line: str) -> bool:
|
||||
"""Check if line is a GFM table header."""
|
||||
if "|" not in line:
|
||||
return False
|
||||
if _TABLE_SEP_RE.match(line):
|
||||
return False
|
||||
stripped = line.strip()
|
||||
parts = [p.strip() for p in stripped.strip("|").split("|")]
|
||||
parts = [p for p in parts if p != ""]
|
||||
return len(parts) >= 2
|
||||
|
||||
|
||||
def _normalize_gfm_tables(text: str) -> str:
|
||||
"""Insert blank line before detected tables outside code blocks."""
|
||||
lines = text.splitlines()
|
||||
if len(lines) < 2:
|
||||
return text
|
||||
|
||||
out_lines: List[str] = []
|
||||
in_fence = False
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
if _FENCE_RE.match(line):
|
||||
in_fence = not in_fence
|
||||
out_lines.append(line)
|
||||
continue
|
||||
|
||||
if (
|
||||
not in_fence
|
||||
and idx + 1 < len(lines)
|
||||
and _is_gfm_table_header_line(line)
|
||||
and _TABLE_SEP_RE.match(lines[idx + 1])
|
||||
):
|
||||
if out_lines and out_lines[-1].strip() != "":
|
||||
m = re.match(r"^(\s*)", line)
|
||||
indent = m.group(1) if m else ""
|
||||
out_lines.append(indent)
|
||||
|
||||
out_lines.append(line)
|
||||
|
||||
return "\n".join(out_lines)
|
||||
|
||||
|
||||
def escape_discord(text: str) -> str:
|
||||
"""Escape text for Discord markdown (bold, italic, etc.)."""
|
||||
return "".join(f"\\{ch}" if ch in DISCORD_SPECIAL else ch for ch in text)
|
||||
|
||||
|
||||
def escape_discord_code(text: str) -> str:
|
||||
"""Escape text for Discord code spans/blocks."""
|
||||
return text.replace("\\", "\\\\").replace("`", "\\`")
|
||||
|
||||
|
||||
def discord_bold(text: str) -> str:
|
||||
"""Format text as bold in Discord (uses **)."""
|
||||
return f"**{escape_discord(text)}**"
|
||||
|
||||
|
||||
def discord_code_inline(text: str) -> str:
|
||||
"""Format text as inline code in Discord."""
|
||||
return f"`{escape_discord_code(text)}`"
|
||||
|
||||
|
||||
def format_status_discord(label: str, suffix: Optional[str] = None) -> str:
|
||||
"""Format a status message for Discord (label in bold, optional suffix)."""
|
||||
base = discord_bold(label)
|
||||
if suffix:
|
||||
return f"{base} {escape_discord(suffix)}"
|
||||
return base
|
||||
|
||||
|
||||
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
|
||||
"""Format a status message with emoji for Discord (matches Telegram API)."""
|
||||
base = f"{emoji} {discord_bold(label)}"
|
||||
if suffix:
|
||||
return f"{base} {escape_discord(suffix)}"
|
||||
return base
|
||||
|
||||
|
||||
def render_markdown_to_discord(text: str) -> str:
|
||||
"""Render common Markdown into Discord-compatible format."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = _normalize_gfm_tables(text)
|
||||
tokens = _MD.parse(text)
|
||||
|
||||
def render_inline_table_plain(children) -> str:
|
||||
out: List[str] = []
|
||||
for tok in children:
|
||||
if tok.type == "text":
|
||||
out.append(tok.content)
|
||||
elif tok.type == "code_inline":
|
||||
out.append(tok.content)
|
||||
elif tok.type in {"softbreak", "hardbreak"}:
|
||||
out.append(" ")
|
||||
elif tok.type == "image":
|
||||
if tok.content:
|
||||
out.append(tok.content)
|
||||
return "".join(out)
|
||||
|
||||
def render_inline(children) -> str:
|
||||
out: List[str] = []
|
||||
i = 0
|
||||
while i < len(children):
|
||||
tok = children[i]
|
||||
t = tok.type
|
||||
if t == "text":
|
||||
out.append(escape_discord(tok.content))
|
||||
elif t in {"softbreak", "hardbreak"}:
|
||||
out.append("\n")
|
||||
elif t == "em_open":
|
||||
out.append("*")
|
||||
elif t == "em_close":
|
||||
out.append("*")
|
||||
elif t == "strong_open":
|
||||
out.append("**")
|
||||
elif t == "strong_close":
|
||||
out.append("**")
|
||||
elif t == "s_open":
|
||||
out.append("~~")
|
||||
elif t == "s_close":
|
||||
out.append("~~")
|
||||
elif t == "code_inline":
|
||||
out.append(f"`{escape_discord_code(tok.content)}`")
|
||||
elif t == "link_open":
|
||||
href = ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("href", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "href":
|
||||
href = val
|
||||
break
|
||||
inner_tokens = []
|
||||
i += 1
|
||||
while i < len(children) and children[i].type != "link_close":
|
||||
inner_tokens.append(children[i])
|
||||
i += 1
|
||||
link_text = ""
|
||||
for child in inner_tokens:
|
||||
if child.type == "text":
|
||||
link_text += child.content
|
||||
elif child.type == "code_inline":
|
||||
link_text += child.content
|
||||
out.append(f"[{escape_discord(link_text)}]({href})")
|
||||
elif t == "image":
|
||||
href = ""
|
||||
alt = tok.content or ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("src", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "src":
|
||||
href = val
|
||||
break
|
||||
if alt:
|
||||
out.append(f"{escape_discord(alt)} ({href})")
|
||||
else:
|
||||
out.append(href)
|
||||
else:
|
||||
out.append(escape_discord(tok.content or ""))
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
out: List[str] = []
|
||||
list_stack: List[dict] = []
|
||||
pending_prefix: Optional[str] = None
|
||||
blockquote_level = 0
|
||||
in_heading = False
|
||||
|
||||
def apply_blockquote(val: str) -> str:
|
||||
if blockquote_level <= 0:
|
||||
return val
|
||||
prefix = "> " * blockquote_level
|
||||
return prefix + val.replace("\n", "\n" + prefix)
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
tok = tokens[i]
|
||||
t = tok.type
|
||||
if t == "paragraph_open":
|
||||
pass
|
||||
elif t == "paragraph_close":
|
||||
out.append("\n")
|
||||
elif t == "heading_open":
|
||||
in_heading = True
|
||||
elif t == "heading_close":
|
||||
in_heading = False
|
||||
out.append("\n")
|
||||
elif t == "bullet_list_open":
|
||||
list_stack.append({"type": "bullet", "index": 1})
|
||||
elif t == "bullet_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "ordered_list_open":
|
||||
start = 1
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
val = tok.attrs.get("start")
|
||||
if val is not None:
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "start":
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
break
|
||||
list_stack.append({"type": "ordered", "index": start})
|
||||
elif t == "ordered_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "list_item_open":
|
||||
if list_stack:
|
||||
top = list_stack[-1]
|
||||
if top["type"] == "bullet":
|
||||
pending_prefix = "- "
|
||||
else:
|
||||
pending_prefix = f"{top['index']}. "
|
||||
top["index"] += 1
|
||||
elif t == "list_item_close":
|
||||
out.append("\n")
|
||||
elif t == "blockquote_open":
|
||||
blockquote_level += 1
|
||||
elif t == "blockquote_close":
|
||||
blockquote_level = max(0, blockquote_level - 1)
|
||||
out.append("\n")
|
||||
elif t == "table_open":
|
||||
if pending_prefix:
|
||||
out.append(apply_blockquote(pending_prefix.rstrip()))
|
||||
out.append("\n")
|
||||
pending_prefix = None
|
||||
|
||||
rows: List[List[str]] = []
|
||||
row_is_header: List[bool] = []
|
||||
|
||||
j = i + 1
|
||||
in_thead = False
|
||||
in_row = False
|
||||
current_row: List[str] = []
|
||||
current_row_header = False
|
||||
|
||||
in_cell = False
|
||||
cell_parts: List[str] = []
|
||||
|
||||
while j < len(tokens):
|
||||
tt = tokens[j].type
|
||||
if tt == "thead_open":
|
||||
in_thead = True
|
||||
elif tt == "thead_close":
|
||||
in_thead = False
|
||||
elif tt == "tr_open":
|
||||
in_row = True
|
||||
current_row = []
|
||||
current_row_header = in_thead
|
||||
elif tt in {"th_open", "td_open"}:
|
||||
in_cell = True
|
||||
cell_parts = []
|
||||
elif tt == "inline" and in_cell:
|
||||
cell_parts.append(
|
||||
render_inline_table_plain(tokens[j].children or [])
|
||||
)
|
||||
elif tt in {"th_close", "td_close"} and in_cell:
|
||||
cell = " ".join(cell_parts).strip()
|
||||
current_row.append(cell)
|
||||
in_cell = False
|
||||
cell_parts = []
|
||||
elif tt == "tr_close" and in_row:
|
||||
rows.append(current_row)
|
||||
row_is_header.append(bool(current_row_header))
|
||||
in_row = False
|
||||
elif tt == "table_close":
|
||||
break
|
||||
j += 1
|
||||
|
||||
if rows:
|
||||
col_count = max((len(r) for r in rows), default=0)
|
||||
norm_rows: List[List[str]] = []
|
||||
for r in rows:
|
||||
if len(r) < col_count:
|
||||
r = r + [""] * (col_count - len(r))
|
||||
norm_rows.append(r)
|
||||
|
||||
widths: List[int] = []
|
||||
for c in range(col_count):
|
||||
w = max((len(r[c]) for r in norm_rows), default=0)
|
||||
widths.append(max(w, 3))
|
||||
|
||||
def fmt_row(r: List[str]) -> str:
|
||||
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
def fmt_sep() -> str:
|
||||
cells = ["-" * widths[c] for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
last_header_idx = -1
|
||||
for idx, is_h in enumerate(row_is_header):
|
||||
if is_h:
|
||||
last_header_idx = idx
|
||||
|
||||
lines: List[str] = []
|
||||
for idx, r in enumerate(norm_rows):
|
||||
lines.append(fmt_row(r))
|
||||
if idx == last_header_idx:
|
||||
lines.append(fmt_sep())
|
||||
|
||||
table_text = "\n".join(lines).rstrip()
|
||||
out.append(f"```\n{escape_discord_code(table_text)}\n```")
|
||||
out.append("\n")
|
||||
|
||||
i = j + 1
|
||||
continue
|
||||
elif t in {"code_block", "fence"}:
|
||||
code = escape_discord_code(tok.content.rstrip("\n"))
|
||||
out.append(f"```\n{code}\n```")
|
||||
out.append("\n")
|
||||
elif t == "inline":
|
||||
rendered = render_inline(tok.children or [])
|
||||
if in_heading:
|
||||
rendered = f"**{render_inline(tok.children or [])}**"
|
||||
if pending_prefix:
|
||||
rendered = pending_prefix + rendered
|
||||
pending_prefix = None
|
||||
rendered = apply_blockquote(rendered)
|
||||
out.append(rendered)
|
||||
else:
|
||||
if tok.content:
|
||||
out.append(escape_discord(tok.content))
|
||||
i += 1
|
||||
|
||||
return "".join(out).rstrip()
|
||||
"""Backward-compatible re-export. Use messaging.rendering.discord_markdown for new code."""
|
||||
|
||||
from .rendering.discord_markdown import (
|
||||
escape_discord,
|
||||
escape_discord_code,
|
||||
discord_bold,
|
||||
discord_code_inline,
|
||||
format_status,
|
||||
format_status_discord,
|
||||
render_markdown_to_discord,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"escape_discord",
|
||||
|
|
|
|||
|
|
@ -1,58 +1,5 @@
|
|||
"""Messaging platform factory.
|
||||
"""Backward-compatible re-export. Use messaging.platforms.factory for new code."""
|
||||
|
||||
Creates the appropriate messaging platform adapter based on configuration.
|
||||
To add a new platform (e.g. Discord, Slack):
|
||||
1. Create a new class implementing MessagingPlatform in messaging/
|
||||
2. Add a case to create_messaging_platform() below
|
||||
"""
|
||||
from .platforms.factory import create_messaging_platform
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .base import MessagingPlatform
|
||||
|
||||
|
||||
def create_messaging_platform(
|
||||
platform_type: str,
|
||||
**kwargs,
|
||||
) -> Optional[MessagingPlatform]:
|
||||
"""Create a messaging platform instance based on type.
|
||||
|
||||
Args:
|
||||
platform_type: Platform identifier ("telegram", "discord", etc.)
|
||||
**kwargs: Platform-specific configuration passed to the constructor.
|
||||
|
||||
Returns:
|
||||
Configured MessagingPlatform instance, or None if not configured.
|
||||
"""
|
||||
if platform_type == "telegram":
|
||||
bot_token = kwargs.get("bot_token")
|
||||
if not bot_token:
|
||||
logger.info("No Telegram bot token configured, skipping platform setup")
|
||||
return None
|
||||
|
||||
from .telegram import TelegramPlatform
|
||||
|
||||
return TelegramPlatform(
|
||||
bot_token=bot_token,
|
||||
allowed_user_id=kwargs.get("allowed_user_id"),
|
||||
)
|
||||
|
||||
if platform_type == "discord":
|
||||
bot_token = kwargs.get("discord_bot_token")
|
||||
if not bot_token:
|
||||
logger.info("No Discord bot token configured, skipping platform setup")
|
||||
return None
|
||||
|
||||
from .discord import DiscordPlatform
|
||||
|
||||
return DiscordPlatform(
|
||||
bot_token=bot_token,
|
||||
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram', 'discord'"
|
||||
)
|
||||
return None
|
||||
__all__ = ["create_messaging_platform"]
|
||||
|
|
|
|||
|
|
@ -11,13 +11,18 @@ import asyncio
|
|||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .base import MessagingPlatform, SessionManagerInterface
|
||||
from .platforms.base import MessagingPlatform, SessionManagerInterface
|
||||
from .models import IncomingMessage
|
||||
from .session import SessionStore
|
||||
from .tree_queue import TreeQueueManager, MessageNode, MessageState, MessageTree
|
||||
from .trees.queue_manager import (
|
||||
TreeQueueManager,
|
||||
MessageNode,
|
||||
MessageState,
|
||||
MessageTree,
|
||||
)
|
||||
from .event_parser import parse_cli_event
|
||||
from .transcript import TranscriptBuffer, RenderCtx
|
||||
from .telegram_markdown import (
|
||||
from .rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
mdv2_bold,
|
||||
|
|
@ -25,7 +30,7 @@ from .telegram_markdown import (
|
|||
format_status as format_status_telegram,
|
||||
render_markdown_to_mdv2,
|
||||
)
|
||||
from .discord_markdown import (
|
||||
from .rendering.discord_markdown import (
|
||||
escape_discord,
|
||||
escape_discord_code,
|
||||
discord_bold,
|
||||
|
|
|
|||
11
messaging/platforms/__init__.py
Normal file
11
messaging/platforms/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Messaging platform adapters (Telegram, Discord, etc.)."""
|
||||
|
||||
from .base import MessagingPlatform, SessionManagerInterface, CLISession
|
||||
from .factory import create_messaging_platform
|
||||
|
||||
__all__ = [
|
||||
"MessagingPlatform",
|
||||
"SessionManagerInterface",
|
||||
"CLISession",
|
||||
"create_messaging_platform",
|
||||
]
|
||||
219
messaging/platforms/base.py
Normal file
219
messaging/platforms/base.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""Abstract base class for messaging platforms."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Callable,
|
||||
Awaitable,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
runtime_checkable,
|
||||
AsyncGenerator,
|
||||
Any,
|
||||
Dict,
|
||||
)
|
||||
from ..models import IncomingMessage
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CLISession(Protocol):
|
||||
"""Protocol for CLI session - avoid circular import from cli package."""
|
||||
|
||||
def start_task(
|
||||
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
|
||||
) -> AsyncGenerator[Dict, Any]:
|
||||
"""Start a task in the CLI session."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_busy(self) -> bool:
|
||||
"""Check if session is busy."""
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SessionManagerInterface(Protocol):
|
||||
"""
|
||||
Protocol for session managers to avoid tight coupling with cli package.
|
||||
|
||||
Implementations: CLISessionManager
|
||||
"""
|
||||
|
||||
async def get_or_create_session(
|
||||
self, session_id: Optional[str] = None
|
||||
) -> Tuple[CLISession, str, bool]:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
Returns: Tuple of (session, session_id, is_new_session)
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_real_session_id(
|
||||
self, temp_id: str, real_session_id: str
|
||||
) -> bool:
|
||||
"""Register the real session ID from CLI output."""
|
||||
...
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all sessions."""
|
||||
...
|
||||
|
||||
async def remove_session(self, session_id: str) -> bool:
|
||||
"""Remove a session from the manager."""
|
||||
...
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get session statistics."""
|
||||
...
|
||||
|
||||
|
||||
class MessagingPlatform(ABC):
|
||||
"""
|
||||
Base class for all messaging platform adapters.
|
||||
|
||||
Implement this to add support for Telegram, Discord, Slack, etc.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to the messaging platform."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Disconnect and cleanup resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
text: Message content
|
||||
reply_to: Optional message ID to reply to
|
||||
parse_mode: Optional formatting mode ("markdown", "html")
|
||||
|
||||
Returns:
|
||||
The message ID of the sent message
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Edit an existing message.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID
|
||||
message_id: The message ID to edit
|
||||
text: New message content
|
||||
parse_mode: Optional formatting mode
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a message from a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID
|
||||
message_id: The message ID to delete
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Enqueue a message to be sent.
|
||||
|
||||
If fire_and_forget is True, returns None immediately.
|
||||
Otherwise, waits for the rate limiter and returns message ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Enqueue a message edit.
|
||||
|
||||
If fire_and_forget is True, returns immediately.
|
||||
Otherwise, waits for the rate limiter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Enqueue a message deletion.
|
||||
|
||||
If fire_and_forget is True, returns immediately.
|
||||
Otherwise, waits for the rate limiter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""
|
||||
Register a message handler callback.
|
||||
|
||||
The handler will be called for each incoming message.
|
||||
|
||||
Args:
|
||||
handler: Async function that processes incoming messages
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fire_and_forget(self, task: Awaitable[Any]) -> None:
|
||||
"""Execute a coroutine without awaiting it."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the platform is connected."""
|
||||
return False
|
||||
394
messaging/platforms/discord.py
Normal file
394
messaging/platforms/discord.py
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
"""
|
||||
Discord Platform Adapter
|
||||
|
||||
Implements MessagingPlatform for Discord using discord.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Callable, Awaitable, Optional, Any, Set, cast
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from ..models import IncomingMessage
|
||||
from ..rendering.discord_markdown import format_status_discord
|
||||
|
||||
_discord_module: Any = None
|
||||
try:
|
||||
import discord as _discord_import
|
||||
|
||||
_discord_module = _discord_import
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
|
||||
DISCORD_MESSAGE_LIMIT = 2000
|
||||
|
||||
|
||||
def _get_discord() -> Any:
|
||||
"""Return the discord module. Raises if not available."""
|
||||
if not DISCORD_AVAILABLE or _discord_module is None:
|
||||
raise ImportError(
|
||||
"discord.py is required. Install with: pip install discord.py"
|
||||
)
|
||||
return _discord_module
|
||||
|
||||
|
||||
def _parse_allowed_channels(raw: Optional[str]) -> Set[str]:
|
||||
"""Parse comma-separated channel IDs into a set of strings."""
|
||||
if not raw or not raw.strip():
|
||||
return set()
|
||||
return {s.strip() for s in raw.split(",") if s.strip()}
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE and _discord_module is not None:
|
||||
_discord = _discord_module
|
||||
|
||||
class _DiscordClient(_discord.Client):
|
||||
"""Internal Discord client that forwards events to DiscordPlatform."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
platform: "DiscordPlatform",
|
||||
intents: _discord.Intents,
|
||||
) -> None:
|
||||
super().__init__(intents=intents)
|
||||
self._platform = platform
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Called when the bot is ready."""
|
||||
self._platform._connected = True
|
||||
logger.info("Discord platform connected")
|
||||
|
||||
async def on_message(self, message: Any) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
await self._platform._on_discord_message(message)
|
||||
else:
|
||||
_DiscordClient = None
|
||||
|
||||
|
||||
class DiscordPlatform(MessagingPlatform):
|
||||
"""
|
||||
Discord messaging platform adapter.
|
||||
|
||||
Uses discord.py for Discord access.
|
||||
Requires a Bot Token from Discord Developer Portal and message_content intent.
|
||||
"""
|
||||
|
||||
name = "discord"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: Optional[str] = None,
|
||||
allowed_channel_ids: Optional[str] = None,
|
||||
):
|
||||
if not DISCORD_AVAILABLE:
|
||||
raise ImportError(
|
||||
"discord.py is required. Install with: pip install discord.py"
|
||||
)
|
||||
|
||||
self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
|
||||
raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
|
||||
self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
|
||||
|
||||
if not self.bot_token:
|
||||
logger.warning("DISCORD_BOT_TOKEN not set")
|
||||
|
||||
discord = _get_discord()
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
assert _DiscordClient is not None
|
||||
self._client = _DiscordClient(self, intents)
|
||||
self._message_handler: Optional[
|
||||
Callable[[IncomingMessage], Awaitable[None]]
|
||||
] = None
|
||||
self._connected = False
|
||||
self._limiter: Optional[Any] = None
|
||||
self._start_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def _on_discord_message(self, message: Any) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
if message.author.bot:
|
||||
return
|
||||
if not message.content:
|
||||
return
|
||||
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
if not self.allowed_channel_ids or channel_id not in self.allowed_channel_ids:
|
||||
return
|
||||
|
||||
user_id = str(message.author.id)
|
||||
message_id = str(message.id)
|
||||
reply_to = (
|
||||
str(message.reference.message_id)
|
||||
if message.reference and message.reference.message_id
|
||||
else None
|
||||
)
|
||||
|
||||
text_preview = (message.content or "")[:80]
|
||||
if len(message.content or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"DISCORD_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
channel_id,
|
||||
message_id,
|
||||
reply_to,
|
||||
text_preview,
|
||||
)
|
||||
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
incoming = IncomingMessage(
|
||||
text=message.content,
|
||||
chat_id=channel_id,
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
platform="discord",
|
||||
reply_to_message_id=reply_to,
|
||||
username=message.author.display_name,
|
||||
raw_event=message,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._message_handler(incoming)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
try:
|
||||
await self.send_message(
|
||||
channel_id,
|
||||
format_status_discord("Error:", str(e)[:200]),
|
||||
reply_to=message_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _truncate(self, text: str, limit: int = DISCORD_MESSAGE_LIMIT) -> str:
|
||||
"""Truncate text to Discord's message limit."""
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[: limit - 3] + "..."
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to Discord."""
|
||||
if not self.bot_token:
|
||||
raise ValueError("DISCORD_BOT_TOKEN is required")
|
||||
|
||||
from ..limiter import MessagingRateLimiter
|
||||
|
||||
self._limiter = await MessagingRateLimiter.get_instance()
|
||||
|
||||
self._start_task = asyncio.create_task(
|
||||
self._client.start(self.bot_token),
|
||||
name="discord-client-start",
|
||||
)
|
||||
|
||||
max_wait = 30
|
||||
waited = 0
|
||||
while not self._connected and waited < max_wait:
|
||||
await asyncio.sleep(0.5)
|
||||
waited += 0.5
|
||||
|
||||
if not self._connected:
|
||||
raise RuntimeError("Discord client failed to connect within timeout")
|
||||
|
||||
logger.info("Discord platform started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the bot."""
|
||||
if self._client.is_closed():
|
||||
self._connected = False
|
||||
return
|
||||
|
||||
await self._client.close()
|
||||
if self._start_task and not self._start_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._start_task, timeout=5.0)
|
||||
except asyncio.TimeoutError, asyncio.CancelledError:
|
||||
self._start_task.cancel()
|
||||
try:
|
||||
await self._start_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._connected = False
|
||||
logger.info("Discord platform stopped")
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Send a message to a channel."""
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel or not hasattr(channel, "send"):
|
||||
raise RuntimeError(f"Channel {chat_id} not found")
|
||||
|
||||
text = self._truncate(text)
|
||||
channel = cast(Any, channel)
|
||||
|
||||
discord = _get_discord()
|
||||
if reply_to:
|
||||
ref = discord.MessageReference(
|
||||
message_id=int(reply_to),
|
||||
channel_id=int(chat_id),
|
||||
)
|
||||
msg = await channel.send(content=text, reference=ref)
|
||||
else:
|
||||
msg = await channel.send(content=text)
|
||||
|
||||
return str(msg.id)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Edit an existing message."""
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel or not hasattr(channel, "fetch_message"):
|
||||
raise RuntimeError(f"Channel {chat_id} not found")
|
||||
|
||||
discord = _get_discord()
|
||||
channel = cast(Any, channel)
|
||||
try:
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
except discord.NotFound:
|
||||
return
|
||||
|
||||
text = self._truncate(text)
|
||||
await msg.edit(content=text)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""Delete a message from a channel."""
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel or not hasattr(channel, "fetch_message"):
|
||||
return
|
||||
|
||||
discord = _get_discord()
|
||||
channel = cast(Any, channel)
|
||||
try:
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
await msg.delete()
|
||||
except discord.NotFound, discord.Forbidden:
|
||||
pass
|
||||
|
||||
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
|
||||
"""Delete multiple messages (best-effort)."""
|
||||
for mid in message_ids:
|
||||
await self.delete_message(chat_id, mid)
|
||||
|
||||
async def queue_send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""Enqueue a message to be sent."""
|
||||
if not self._limiter:
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
async def _send():
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_send)
|
||||
return None
|
||||
return await self._limiter.enqueue(_send)
|
||||
|
||||
async def queue_edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message edit."""
|
||||
if not self._limiter:
|
||||
await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
return
|
||||
|
||||
async def _edit():
|
||||
await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
|
||||
dedup_key = f"edit:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message delete."""
|
||||
if not self._limiter:
|
||||
await self.delete_message(chat_id, message_id)
|
||||
return
|
||||
|
||||
async def _delete():
|
||||
await self.delete_message(chat_id, message_id)
|
||||
|
||||
dedup_key = f"del:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_messages(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_ids: list[str],
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a bulk delete."""
|
||||
if not message_ids:
|
||||
return
|
||||
|
||||
if not self._limiter:
|
||||
await self.delete_messages(chat_id, message_ids)
|
||||
return
|
||||
|
||||
async def _bulk():
|
||||
await self.delete_messages(chat_id, message_ids)
|
||||
|
||||
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
|
||||
|
||||
def fire_and_forget(self, task: Awaitable[Any]) -> None:
|
||||
"""Execute a coroutine without awaiting it."""
|
||||
if asyncio.iscoroutine(task):
|
||||
asyncio.create_task(task)
|
||||
else:
|
||||
asyncio.ensure_future(task)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Register a message handler callback."""
|
||||
self._message_handler = handler
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected."""
|
||||
return self._connected
|
||||
58
messaging/platforms/factory.py
Normal file
58
messaging/platforms/factory.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Messaging platform factory.
|
||||
|
||||
Creates the appropriate messaging platform adapter based on configuration.
|
||||
To add a new platform (e.g. Discord, Slack):
|
||||
1. Create a new class implementing MessagingPlatform in messaging/platforms/
|
||||
2. Add a case to create_messaging_platform() below
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .base import MessagingPlatform
|
||||
|
||||
|
||||
def create_messaging_platform(
|
||||
platform_type: str,
|
||||
**kwargs,
|
||||
) -> Optional[MessagingPlatform]:
|
||||
"""Create a messaging platform instance based on type.
|
||||
|
||||
Args:
|
||||
platform_type: Platform identifier ("telegram", "discord", etc.)
|
||||
**kwargs: Platform-specific configuration passed to the constructor.
|
||||
|
||||
Returns:
|
||||
Configured MessagingPlatform instance, or None if not configured.
|
||||
"""
|
||||
if platform_type == "telegram":
|
||||
bot_token = kwargs.get("bot_token")
|
||||
if not bot_token:
|
||||
logger.info("No Telegram bot token configured, skipping platform setup")
|
||||
return None
|
||||
|
||||
from .telegram import TelegramPlatform
|
||||
|
||||
return TelegramPlatform(
|
||||
bot_token=bot_token,
|
||||
allowed_user_id=kwargs.get("allowed_user_id"),
|
||||
)
|
||||
|
||||
if platform_type == "discord":
|
||||
bot_token = kwargs.get("discord_bot_token")
|
||||
if not bot_token:
|
||||
logger.info("No Discord bot token configured, skipping platform setup")
|
||||
return None
|
||||
|
||||
from .discord import DiscordPlatform
|
||||
|
||||
return DiscordPlatform(
|
||||
bot_token=bot_token,
|
||||
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram', 'discord'"
|
||||
)
|
||||
return None
|
||||
491
messaging/platforms/telegram.py
Normal file
491
messaging/platforms/telegram.py
Normal file
|
|
@ -0,0 +1,491 @@
|
|||
"""
|
||||
Telegram Platform Adapter
|
||||
|
||||
Implements MessagingPlatform for Telegram using python-telegram-bot.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
# Opt-in to future behavior for python-telegram-bot (retry_after as timedelta)
|
||||
# This must be set BEFORE importing telegram.error
|
||||
os.environ["PTB_TIMEDELTA"] = "1"
|
||||
|
||||
from typing import Callable, Awaitable, Optional, Any, TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from ..models import IncomingMessage
|
||||
from ..rendering.telegram_markdown import escape_md_v2
|
||||
|
||||
|
||||
# Optional import - python-telegram-bot may not be installed
|
||||
try:
|
||||
from telegram import Update
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
MessageHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.error import TelegramError, RetryAfter, NetworkError
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
|
||||
|
||||
class TelegramPlatform(MessagingPlatform):
|
||||
"""
|
||||
Telegram messaging platform adapter.
|
||||
|
||||
Uses python-telegram-bot (BoT API) for Telegram access.
|
||||
Requires a Bot Token from @BotFather.
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: Optional[str] = None,
|
||||
allowed_user_id: Optional[str] = None,
|
||||
):
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
raise ImportError(
|
||||
"python-telegram-bot is required. Install with: pip install python-telegram-bot"
|
||||
)
|
||||
|
||||
self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
|
||||
|
||||
if not self.bot_token:
|
||||
# We don't raise here to allow instantiation for testing/conditional logic,
|
||||
# but start() will fail.
|
||||
logger.warning("TELEGRAM_BOT_TOKEN not set")
|
||||
|
||||
self._application: Optional[Application] = None
|
||||
self._message_handler: Optional[
|
||||
Callable[[IncomingMessage], Awaitable[None]]
|
||||
] = None
|
||||
self._connected = False
|
||||
self._limiter: Optional[Any] = None # Will be MessagingRateLimiter
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to Telegram."""
|
||||
if not self.bot_token:
|
||||
raise ValueError("TELEGRAM_BOT_TOKEN is required")
|
||||
|
||||
# Configure request with longer timeouts
|
||||
request = HTTPXRequest(
|
||||
connection_pool_size=8, connect_timeout=30.0, read_timeout=30.0
|
||||
)
|
||||
|
||||
# Build Application
|
||||
builder = Application.builder().token(self.bot_token).request(request)
|
||||
self._application = builder.build()
|
||||
|
||||
# Register Internal Handlers
|
||||
# We catch ALL text messages and commands to forward them
|
||||
self._application.add_handler(
|
||||
MessageHandler(filters.TEXT & (~filters.COMMAND), self._on_telegram_message)
|
||||
)
|
||||
self._application.add_handler(CommandHandler("start", self._on_start_command))
|
||||
# Catch-all for other commands if needed, or let them fall through
|
||||
self._application.add_handler(
|
||||
MessageHandler(filters.COMMAND, self._on_telegram_message)
|
||||
)
|
||||
|
||||
# Initialize internal components with retry logic
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
await self._application.initialize()
|
||||
await self._application.start()
|
||||
|
||||
# Start polling (non-blocking way for integration)
|
||||
if self._application.updater:
|
||||
await self._application.updater.start_polling(
|
||||
drop_pending_updates=False
|
||||
)
|
||||
|
||||
self._connected = True
|
||||
break
|
||||
except (NetworkError, Exception) as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 * (attempt + 1)
|
||||
logger.warning(
|
||||
f"Connection failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"Failed to connect after {max_retries} attempts")
|
||||
raise
|
||||
|
||||
# Initialize rate limiter
|
||||
from ..limiter import MessagingRateLimiter
|
||||
|
||||
self._limiter = await MessagingRateLimiter.get_instance()
|
||||
|
||||
# Send startup notification
|
||||
try:
|
||||
target = self.allowed_user_id
|
||||
if target:
|
||||
startup_text = (
|
||||
f"🚀 *{escape_md_v2('Claude Code Proxy is online!')}* "
|
||||
f"{escape_md_v2('(Bot API)')}"
|
||||
)
|
||||
await self.send_message(
|
||||
target,
|
||||
startup_text,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not send startup message: {e}")
|
||||
|
||||
logger.info("Telegram platform started (Bot API)")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the bot."""
|
||||
if self._application and self._application.updater:
|
||||
await self._application.updater.stop()
|
||||
await self._application.stop()
|
||||
await self._application.shutdown()
|
||||
|
||||
self._connected = False
|
||||
logger.info("Telegram platform stopped")
|
||||
|
||||
async def _with_retry(
|
||||
self, func: Callable[..., Awaitable[Any]], *args, **kwargs
|
||||
) -> Any:
|
||||
"""Helper to execute a function with exponential backoff on network errors."""
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except (NetworkError, asyncio.TimeoutError) as e:
|
||||
if "Message is not modified" in str(e):
|
||||
return None
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt # 1s, 2s, 4s
|
||||
logger.warning(
|
||||
f"Telegram API network error (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(
|
||||
f"Telegram API failed after {max_retries} attempts: {e}"
|
||||
)
|
||||
raise
|
||||
except RetryAfter as e:
|
||||
# Telegram explicitly tells us to wait (PTB_TIMEDELTA: retry_after is timedelta)
|
||||
from datetime import timedelta
|
||||
|
||||
retry_after = e.retry_after
|
||||
if isinstance(retry_after, timedelta):
|
||||
wait_secs = retry_after.total_seconds()
|
||||
else:
|
||||
wait_secs = float(retry_after)
|
||||
|
||||
logger.warning(f"Rate limited by Telegram, waiting {wait_secs}s...")
|
||||
await asyncio.sleep(wait_secs)
|
||||
# We don't increment attempt here, as this is a specific instruction
|
||||
return await func(*args, **kwargs)
|
||||
except TelegramError as e:
|
||||
# Non-network Telegram errors
|
||||
err_lower = str(e).lower()
|
||||
if "message is not modified" in err_lower:
|
||||
return None
|
||||
# Best-effort no-op cases (common during chat cleanup / /clear).
|
||||
if any(
|
||||
x in err_lower
|
||||
for x in [
|
||||
"message to edit not found",
|
||||
"message to delete not found",
|
||||
"message can't be deleted",
|
||||
"message can't be edited",
|
||||
"not enough rights to delete",
|
||||
]
|
||||
):
|
||||
return None
|
||||
if "Can't parse entities" in str(e) and kwargs.get("parse_mode"):
|
||||
logger.warning("Markdown failed, retrying without parse_mode")
|
||||
kwargs["parse_mode"] = None
|
||||
return await func(*args, **kwargs)
|
||||
raise
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
) -> str:
|
||||
"""Send a message to a chat."""
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
async def _do_send(parse_mode=parse_mode):
|
||||
bot = app.bot
|
||||
msg = await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
return str(msg.message_id)
|
||||
|
||||
return await self._with_retry(_do_send, parse_mode=parse_mode)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
) -> None:
|
||||
"""Edit an existing message."""
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
async def _do_edit(parse_mode=parse_mode):
|
||||
bot = app.bot
|
||||
await bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=int(message_id),
|
||||
text=text,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
|
||||
await self._with_retry(_do_edit, parse_mode=parse_mode)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""Delete a message from a chat."""
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
async def _do_delete():
|
||||
bot = app.bot
|
||||
await bot.delete_message(chat_id=chat_id, message_id=int(message_id))
|
||||
|
||||
await self._with_retry(_do_delete)
|
||||
|
||||
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
|
||||
"""Delete multiple messages (best-effort)."""
|
||||
if not message_ids:
|
||||
return
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
# PTB supports bulk deletion via delete_messages; fall back to per-message.
|
||||
bot = app.bot
|
||||
if hasattr(bot, "delete_messages"):
|
||||
|
||||
async def _do_bulk():
|
||||
mids = []
|
||||
for mid in message_ids:
|
||||
try:
|
||||
mids.append(int(mid))
|
||||
except Exception:
|
||||
continue
|
||||
if not mids:
|
||||
return None
|
||||
# delete_messages accepts a sequence of ints (up to 100).
|
||||
await bot.delete_messages(chat_id=chat_id, message_ids=mids)
|
||||
|
||||
await self._with_retry(_do_bulk)
|
||||
return
|
||||
|
||||
for mid in message_ids:
|
||||
await self.delete_message(chat_id, mid)
|
||||
|
||||
async def queue_send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
fire_and_forget: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""Enqueue a message to be sent (using limiter)."""
|
||||
# Note: Bot API handles limits better, but we still use our limiter for nice queuing
|
||||
if not self._limiter:
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
async def _send():
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_send)
|
||||
return None
|
||||
else:
|
||||
return await self._limiter.enqueue(_send)
|
||||
|
||||
async def queue_edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message edit."""
|
||||
if not self._limiter:
|
||||
return await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
|
||||
async def _edit():
|
||||
return await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
|
||||
dedup_key = f"edit:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message delete."""
|
||||
if not self._limiter:
|
||||
return await self.delete_message(chat_id, message_id)
|
||||
|
||||
async def _delete():
|
||||
return await self.delete_message(chat_id, message_id)
|
||||
|
||||
dedup_key = f"del:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_messages(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_ids: list[str],
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a bulk delete (if supported) or a sequence of deletes."""
|
||||
if not message_ids:
|
||||
return
|
||||
|
||||
if not self._limiter:
|
||||
return await self.delete_messages(chat_id, message_ids)
|
||||
|
||||
async def _bulk():
|
||||
return await self.delete_messages(chat_id, message_ids)
|
||||
|
||||
# Dedup by the chunk content; okay to be coarse here.
|
||||
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
|
||||
|
||||
def fire_and_forget(self, task: Awaitable[Any]) -> None:
|
||||
"""Execute a coroutine without awaiting it."""
|
||||
if asyncio.iscoroutine(task):
|
||||
asyncio.create_task(task)
|
||||
else:
|
||||
asyncio.ensure_future(task)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Register a message handler callback."""
|
||||
self._message_handler = handler
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected."""
|
||||
return self._connected
|
||||
|
||||
async def _on_start_command(
|
||||
self, update: "Update", context: "ContextTypes.DEFAULT_TYPE"
|
||||
) -> None:
|
||||
"""Handle /start command."""
|
||||
if update.message:
|
||||
await update.message.reply_text("👋 Hello! I am the Claude Code Proxy Bot.")
|
||||
# We can also treat this as a message if we want it to trigger something
|
||||
await self._on_telegram_message(update, context)
|
||||
|
||||
async def _on_telegram_message(
|
||||
self, update: "Update", context: "ContextTypes.DEFAULT_TYPE"
|
||||
) -> None:
|
||||
"""Handle incoming updates."""
|
||||
if (
|
||||
not update.message
|
||||
or not update.message.text
|
||||
or not update.effective_user
|
||||
or not update.effective_chat
|
||||
):
|
||||
return
|
||||
|
||||
user_id = str(update.effective_user.id)
|
||||
chat_id = str(update.effective_chat.id)
|
||||
|
||||
# Security check
|
||||
if self.allowed_user_id:
|
||||
if user_id != str(self.allowed_user_id).strip():
|
||||
logger.warning(f"Unauthorized access attempt from {user_id}")
|
||||
return
|
||||
|
||||
message_id = str(update.message.message_id)
|
||||
reply_to = (
|
||||
str(update.message.reply_to_message.message_id)
|
||||
if update.message.reply_to_message
|
||||
else None
|
||||
)
|
||||
text_preview = (update.message.text or "")[:80]
|
||||
if len(update.message.text or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"TELEGRAM_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
chat_id,
|
||||
message_id,
|
||||
reply_to,
|
||||
text_preview,
|
||||
)
|
||||
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
incoming = IncomingMessage(
|
||||
text=update.message.text,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
platform="telegram",
|
||||
reply_to_message_id=reply_to,
|
||||
raw_event=update,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._message_handler(incoming)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
try:
|
||||
await self.send_message(
|
||||
chat_id,
|
||||
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(str(e)[:200])}",
|
||||
reply_to=incoming.message_id,
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
37
messaging/rendering/__init__.py
Normal file
37
messaging/rendering/__init__.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Markdown rendering utilities for messaging platforms."""
|
||||
|
||||
from .discord_markdown import (
|
||||
escape_discord,
|
||||
escape_discord_code,
|
||||
discord_bold,
|
||||
discord_code_inline,
|
||||
format_status as format_status_discord_fn,
|
||||
format_status_discord,
|
||||
render_markdown_to_discord,
|
||||
)
|
||||
from .telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
escape_md_v2_link_url,
|
||||
mdv2_bold,
|
||||
mdv2_code_inline,
|
||||
format_status as format_status_telegram_fn,
|
||||
render_markdown_to_mdv2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"escape_discord",
|
||||
"escape_discord_code",
|
||||
"discord_bold",
|
||||
"discord_code_inline",
|
||||
"format_status_discord_fn",
|
||||
"format_status_discord",
|
||||
"render_markdown_to_discord",
|
||||
"escape_md_v2",
|
||||
"escape_md_v2_code",
|
||||
"escape_md_v2_link_url",
|
||||
"mdv2_bold",
|
||||
"mdv2_code_inline",
|
||||
"format_status_telegram_fn",
|
||||
"render_markdown_to_mdv2",
|
||||
]
|
||||
374
messaging/rendering/discord_markdown.py
Normal file
374
messaging/rendering/discord_markdown.py
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
"""Discord markdown utilities.
|
||||
|
||||
Discord uses standard markdown: **bold**, *italic*, `code`, ```code block```.
|
||||
Used by the message handler and Discord platform adapter.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
|
||||
# Discord escapes: \ * _ ` ~ | >
|
||||
DISCORD_SPECIAL = set("\\*_`~|>")
|
||||
|
||||
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
|
||||
_MD.enable("strikethrough")
|
||||
_MD.enable("table")
|
||||
|
||||
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
|
||||
_FENCE_RE = re.compile(r"^\s*```")
|
||||
|
||||
|
||||
def _is_gfm_table_header_line(line: str) -> bool:
|
||||
"""Check if line is a GFM table header."""
|
||||
if "|" not in line:
|
||||
return False
|
||||
if _TABLE_SEP_RE.match(line):
|
||||
return False
|
||||
stripped = line.strip()
|
||||
parts = [p.strip() for p in stripped.strip("|").split("|")]
|
||||
parts = [p for p in parts if p != ""]
|
||||
return len(parts) >= 2
|
||||
|
||||
|
||||
def _normalize_gfm_tables(text: str) -> str:
|
||||
"""Insert blank line before detected tables outside code blocks."""
|
||||
lines = text.splitlines()
|
||||
if len(lines) < 2:
|
||||
return text
|
||||
|
||||
out_lines: List[str] = []
|
||||
in_fence = False
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
if _FENCE_RE.match(line):
|
||||
in_fence = not in_fence
|
||||
out_lines.append(line)
|
||||
continue
|
||||
|
||||
if (
|
||||
not in_fence
|
||||
and idx + 1 < len(lines)
|
||||
and _is_gfm_table_header_line(line)
|
||||
and _TABLE_SEP_RE.match(lines[idx + 1])
|
||||
):
|
||||
if out_lines and out_lines[-1].strip() != "":
|
||||
m = re.match(r"^(\s*)", line)
|
||||
indent = m.group(1) if m else ""
|
||||
out_lines.append(indent)
|
||||
|
||||
out_lines.append(line)
|
||||
|
||||
return "\n".join(out_lines)
|
||||
|
||||
|
||||
def escape_discord(text: str) -> str:
|
||||
"""Escape text for Discord markdown (bold, italic, etc.)."""
|
||||
return "".join(f"\\{ch}" if ch in DISCORD_SPECIAL else ch for ch in text)
|
||||
|
||||
|
||||
def escape_discord_code(text: str) -> str:
|
||||
"""Escape text for Discord code spans/blocks."""
|
||||
return text.replace("\\", "\\\\").replace("`", "\\`")
|
||||
|
||||
|
||||
def discord_bold(text: str) -> str:
|
||||
"""Format text as bold in Discord (uses **)."""
|
||||
return f"**{escape_discord(text)}**"
|
||||
|
||||
|
||||
def discord_code_inline(text: str) -> str:
|
||||
"""Format text as inline code in Discord."""
|
||||
return f"`{escape_discord_code(text)}`"
|
||||
|
||||
|
||||
def format_status_discord(label: str, suffix: Optional[str] = None) -> str:
|
||||
"""Format a status message for Discord (label in bold, optional suffix)."""
|
||||
base = discord_bold(label)
|
||||
if suffix:
|
||||
return f"{base} {escape_discord(suffix)}"
|
||||
return base
|
||||
|
||||
|
||||
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
|
||||
"""Format a status message with emoji for Discord (matches Telegram API)."""
|
||||
base = f"{emoji} {discord_bold(label)}"
|
||||
if suffix:
|
||||
return f"{base} {escape_discord(suffix)}"
|
||||
return base
|
||||
|
||||
|
||||
def render_markdown_to_discord(text: str) -> str:
|
||||
"""Render common Markdown into Discord-compatible format."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = _normalize_gfm_tables(text)
|
||||
tokens = _MD.parse(text)
|
||||
|
||||
def render_inline_table_plain(children) -> str:
|
||||
out: List[str] = []
|
||||
for tok in children:
|
||||
if tok.type == "text":
|
||||
out.append(tok.content)
|
||||
elif tok.type == "code_inline":
|
||||
out.append(tok.content)
|
||||
elif tok.type in {"softbreak", "hardbreak"}:
|
||||
out.append(" ")
|
||||
elif tok.type == "image":
|
||||
if tok.content:
|
||||
out.append(tok.content)
|
||||
return "".join(out)
|
||||
|
||||
def render_inline(children) -> str:
|
||||
out: List[str] = []
|
||||
i = 0
|
||||
while i < len(children):
|
||||
tok = children[i]
|
||||
t = tok.type
|
||||
if t == "text":
|
||||
out.append(escape_discord(tok.content))
|
||||
elif t in {"softbreak", "hardbreak"}:
|
||||
out.append("\n")
|
||||
elif t == "em_open":
|
||||
out.append("*")
|
||||
elif t == "em_close":
|
||||
out.append("*")
|
||||
elif t == "strong_open":
|
||||
out.append("**")
|
||||
elif t == "strong_close":
|
||||
out.append("**")
|
||||
elif t == "s_open":
|
||||
out.append("~~")
|
||||
elif t == "s_close":
|
||||
out.append("~~")
|
||||
elif t == "code_inline":
|
||||
out.append(f"`{escape_discord_code(tok.content)}`")
|
||||
elif t == "link_open":
|
||||
href = ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("href", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "href":
|
||||
href = val
|
||||
break
|
||||
inner_tokens = []
|
||||
i += 1
|
||||
while i < len(children) and children[i].type != "link_close":
|
||||
inner_tokens.append(children[i])
|
||||
i += 1
|
||||
link_text = ""
|
||||
for child in inner_tokens:
|
||||
if child.type == "text":
|
||||
link_text += child.content
|
||||
elif child.type == "code_inline":
|
||||
link_text += child.content
|
||||
out.append(f"[{escape_discord(link_text)}]({href})")
|
||||
elif t == "image":
|
||||
href = ""
|
||||
alt = tok.content or ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("src", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "src":
|
||||
href = val
|
||||
break
|
||||
if alt:
|
||||
out.append(f"{escape_discord(alt)} ({href})")
|
||||
else:
|
||||
out.append(href)
|
||||
else:
|
||||
out.append(escape_discord(tok.content or ""))
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
out: List[str] = []
|
||||
list_stack: List[dict] = []
|
||||
pending_prefix: Optional[str] = None
|
||||
blockquote_level = 0
|
||||
in_heading = False
|
||||
|
||||
def apply_blockquote(val: str) -> str:
|
||||
if blockquote_level <= 0:
|
||||
return val
|
||||
prefix = "> " * blockquote_level
|
||||
return prefix + val.replace("\n", "\n" + prefix)
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
tok = tokens[i]
|
||||
t = tok.type
|
||||
if t == "paragraph_open":
|
||||
pass
|
||||
elif t == "paragraph_close":
|
||||
out.append("\n")
|
||||
elif t == "heading_open":
|
||||
in_heading = True
|
||||
elif t == "heading_close":
|
||||
in_heading = False
|
||||
out.append("\n")
|
||||
elif t == "bullet_list_open":
|
||||
list_stack.append({"type": "bullet", "index": 1})
|
||||
elif t == "bullet_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "ordered_list_open":
|
||||
start = 1
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
val = tok.attrs.get("start")
|
||||
if val is not None:
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "start":
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
break
|
||||
list_stack.append({"type": "ordered", "index": start})
|
||||
elif t == "ordered_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "list_item_open":
|
||||
if list_stack:
|
||||
top = list_stack[-1]
|
||||
if top["type"] == "bullet":
|
||||
pending_prefix = "- "
|
||||
else:
|
||||
pending_prefix = f"{top['index']}. "
|
||||
top["index"] += 1
|
||||
elif t == "list_item_close":
|
||||
out.append("\n")
|
||||
elif t == "blockquote_open":
|
||||
blockquote_level += 1
|
||||
elif t == "blockquote_close":
|
||||
blockquote_level = max(0, blockquote_level - 1)
|
||||
out.append("\n")
|
||||
elif t == "table_open":
|
||||
if pending_prefix:
|
||||
out.append(apply_blockquote(pending_prefix.rstrip()))
|
||||
out.append("\n")
|
||||
pending_prefix = None
|
||||
|
||||
rows: List[List[str]] = []
|
||||
row_is_header: List[bool] = []
|
||||
|
||||
j = i + 1
|
||||
in_thead = False
|
||||
in_row = False
|
||||
current_row: List[str] = []
|
||||
current_row_header = False
|
||||
|
||||
in_cell = False
|
||||
cell_parts: List[str] = []
|
||||
|
||||
while j < len(tokens):
|
||||
tt = tokens[j].type
|
||||
if tt == "thead_open":
|
||||
in_thead = True
|
||||
elif tt == "thead_close":
|
||||
in_thead = False
|
||||
elif tt == "tr_open":
|
||||
in_row = True
|
||||
current_row = []
|
||||
current_row_header = in_thead
|
||||
elif tt in {"th_open", "td_open"}:
|
||||
in_cell = True
|
||||
cell_parts = []
|
||||
elif tt == "inline" and in_cell:
|
||||
cell_parts.append(
|
||||
render_inline_table_plain(tokens[j].children or [])
|
||||
)
|
||||
elif tt in {"th_close", "td_close"} and in_cell:
|
||||
cell = " ".join(cell_parts).strip()
|
||||
current_row.append(cell)
|
||||
in_cell = False
|
||||
cell_parts = []
|
||||
elif tt == "tr_close" and in_row:
|
||||
rows.append(current_row)
|
||||
row_is_header.append(bool(current_row_header))
|
||||
in_row = False
|
||||
elif tt == "table_close":
|
||||
break
|
||||
j += 1
|
||||
|
||||
if rows:
|
||||
col_count = max((len(r) for r in rows), default=0)
|
||||
norm_rows: List[List[str]] = []
|
||||
for r in rows:
|
||||
if len(r) < col_count:
|
||||
r = r + [""] * (col_count - len(r))
|
||||
norm_rows.append(r)
|
||||
|
||||
widths: List[int] = []
|
||||
for c in range(col_count):
|
||||
w = max((len(r[c]) for r in norm_rows), default=0)
|
||||
widths.append(max(w, 3))
|
||||
|
||||
def fmt_row(r: List[str]) -> str:
|
||||
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
def fmt_sep() -> str:
|
||||
cells = ["-" * widths[c] for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
last_header_idx = -1
|
||||
for idx, is_h in enumerate(row_is_header):
|
||||
if is_h:
|
||||
last_header_idx = idx
|
||||
|
||||
lines: List[str] = []
|
||||
for idx, r in enumerate(norm_rows):
|
||||
lines.append(fmt_row(r))
|
||||
if idx == last_header_idx:
|
||||
lines.append(fmt_sep())
|
||||
|
||||
table_text = "\n".join(lines).rstrip()
|
||||
out.append(f"```\n{escape_discord_code(table_text)}\n```")
|
||||
out.append("\n")
|
||||
|
||||
i = j + 1
|
||||
continue
|
||||
elif t in {"code_block", "fence"}:
|
||||
code = escape_discord_code(tok.content.rstrip("\n"))
|
||||
out.append(f"```\n{code}\n```")
|
||||
out.append("\n")
|
||||
elif t == "inline":
|
||||
rendered = render_inline(tok.children or [])
|
||||
if in_heading:
|
||||
rendered = f"**{render_inline(tok.children or [])}**"
|
||||
if pending_prefix:
|
||||
rendered = pending_prefix + rendered
|
||||
pending_prefix = None
|
||||
rendered = apply_blockquote(rendered)
|
||||
out.append(rendered)
|
||||
else:
|
||||
if tok.content:
|
||||
out.append(escape_discord(tok.content))
|
||||
i += 1
|
||||
|
||||
return "".join(out).rstrip()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"escape_discord",
|
||||
"escape_discord_code",
|
||||
"discord_bold",
|
||||
"discord_code_inline",
|
||||
"format_status",
|
||||
"format_status_discord",
|
||||
"render_markdown_to_discord",
|
||||
]
|
||||
391
messaging/rendering/telegram_markdown.py
Normal file
391
messaging/rendering/telegram_markdown.py
Normal file
|
|
@ -0,0 +1,391 @@
|
|||
"""Telegram MarkdownV2 utilities.
|
||||
|
||||
Renders common Markdown into Telegram MarkdownV2 format.
|
||||
Used by the message handler and Telegram platform adapter.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
|
||||
MDV2_SPECIAL_CHARS = set("\\_*[]()~`>#+-=|{}.!")
|
||||
MDV2_LINK_ESCAPE = set("\\)")
|
||||
|
||||
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
|
||||
_MD.enable("strikethrough")
|
||||
_MD.enable("table")
|
||||
|
||||
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
|
||||
_FENCE_RE = re.compile(r"^\s*```")
|
||||
|
||||
|
||||
def _is_gfm_table_header_line(line: str) -> bool:
|
||||
"""Check if line is a GFM table header (pipe-delimited, not separator)."""
|
||||
if "|" not in line:
|
||||
return False
|
||||
if _TABLE_SEP_RE.match(line):
|
||||
return False
|
||||
stripped = line.strip()
|
||||
parts = [p.strip() for p in stripped.strip("|").split("|")]
|
||||
parts = [p for p in parts if p != ""]
|
||||
return len(parts) >= 2
|
||||
|
||||
|
||||
def _normalize_gfm_tables(text: str) -> str:
|
||||
"""
|
||||
Many LLMs emit tables immediately after a paragraph line (no blank line).
|
||||
Markdown-it will treat that as a softbreak within the paragraph, so the
|
||||
table extension won't trigger. Insert a blank line before detected tables.
|
||||
|
||||
We only do this outside fenced code blocks.
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
if len(lines) < 2:
|
||||
return text
|
||||
|
||||
out_lines: List[str] = []
|
||||
in_fence = False
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
if _FENCE_RE.match(line):
|
||||
in_fence = not in_fence
|
||||
out_lines.append(line)
|
||||
continue
|
||||
|
||||
if (
|
||||
not in_fence
|
||||
and idx + 1 < len(lines)
|
||||
and _is_gfm_table_header_line(line)
|
||||
and _TABLE_SEP_RE.match(lines[idx + 1])
|
||||
):
|
||||
if out_lines and out_lines[-1].strip() != "":
|
||||
m = re.match(r"^(\s*)", line)
|
||||
indent = m.group(1) if m else ""
|
||||
out_lines.append(indent)
|
||||
|
||||
out_lines.append(line)
|
||||
|
||||
return "\n".join(out_lines)
|
||||
|
||||
|
||||
def escape_md_v2(text: str) -> str:
|
||||
"""Escape text for Telegram MarkdownV2."""
|
||||
return "".join(f"\\{ch}" if ch in MDV2_SPECIAL_CHARS else ch for ch in text)
|
||||
|
||||
|
||||
def escape_md_v2_code(text: str) -> str:
|
||||
"""Escape text for Telegram MarkdownV2 code spans/blocks."""
|
||||
return text.replace("\\", "\\\\").replace("`", "\\`")
|
||||
|
||||
|
||||
def escape_md_v2_link_url(text: str) -> str:
|
||||
"""Escape URL for Telegram MarkdownV2 link destination."""
|
||||
return "".join(f"\\{ch}" if ch in MDV2_LINK_ESCAPE else ch for ch in text)
|
||||
|
||||
|
||||
def mdv2_bold(text: str) -> str:
|
||||
"""Format text as bold in MarkdownV2."""
|
||||
return f"*{escape_md_v2(text)}*"
|
||||
|
||||
|
||||
def mdv2_code_inline(text: str) -> str:
|
||||
"""Format text as inline code in MarkdownV2."""
|
||||
return f"`{escape_md_v2_code(text)}`"
|
||||
|
||||
|
||||
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
|
||||
"""Format a status message with emoji and optional suffix."""
|
||||
base = f"{emoji} {mdv2_bold(label)}"
|
||||
if suffix:
|
||||
return f"{base} {escape_md_v2(suffix)}"
|
||||
return base
|
||||
|
||||
|
||||
def render_markdown_to_mdv2(text: str) -> str:
|
||||
"""Render common Markdown into Telegram MarkdownV2."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = _normalize_gfm_tables(text)
|
||||
tokens = _MD.parse(text)
|
||||
|
||||
def render_inline_table_plain(children) -> str:
|
||||
out: List[str] = []
|
||||
for tok in children:
|
||||
if tok.type == "text":
|
||||
out.append(tok.content)
|
||||
elif tok.type == "code_inline":
|
||||
out.append(tok.content)
|
||||
elif tok.type in {"softbreak", "hardbreak"}:
|
||||
out.append(" ")
|
||||
elif tok.type == "image":
|
||||
if tok.content:
|
||||
out.append(tok.content)
|
||||
return "".join(out)
|
||||
|
||||
def render_inline_plain(children) -> str:
|
||||
out: List[str] = []
|
||||
for tok in children:
|
||||
if tok.type == "text":
|
||||
out.append(escape_md_v2(tok.content))
|
||||
elif tok.type == "code_inline":
|
||||
out.append(escape_md_v2(tok.content))
|
||||
elif tok.type in {"softbreak", "hardbreak"}:
|
||||
out.append("\n")
|
||||
return "".join(out)
|
||||
|
||||
def render_inline(children) -> str:
|
||||
out: List[str] = []
|
||||
i = 0
|
||||
while i < len(children):
|
||||
tok = children[i]
|
||||
t = tok.type
|
||||
if t == "text":
|
||||
out.append(escape_md_v2(tok.content))
|
||||
elif t in {"softbreak", "hardbreak"}:
|
||||
out.append("\n")
|
||||
elif t == "em_open":
|
||||
out.append("_")
|
||||
elif t == "em_close":
|
||||
out.append("_")
|
||||
elif t == "strong_open":
|
||||
out.append("*")
|
||||
elif t == "strong_close":
|
||||
out.append("*")
|
||||
elif t == "s_open":
|
||||
out.append("~")
|
||||
elif t == "s_close":
|
||||
out.append("~")
|
||||
elif t == "code_inline":
|
||||
out.append(f"`{escape_md_v2_code(tok.content)}`")
|
||||
elif t == "link_open":
|
||||
href = ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("href", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "href":
|
||||
href = val
|
||||
break
|
||||
inner_tokens = []
|
||||
i += 1
|
||||
while i < len(children) and children[i].type != "link_close":
|
||||
inner_tokens.append(children[i])
|
||||
i += 1
|
||||
link_text = ""
|
||||
for child in inner_tokens:
|
||||
if child.type == "text":
|
||||
link_text += child.content
|
||||
elif child.type == "code_inline":
|
||||
link_text += child.content
|
||||
out.append(
|
||||
f"[{escape_md_v2(link_text)}]({escape_md_v2_link_url(href)})"
|
||||
)
|
||||
elif t == "image":
|
||||
href = ""
|
||||
alt = tok.content or ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("src", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "src":
|
||||
href = val
|
||||
break
|
||||
if alt:
|
||||
out.append(f"{escape_md_v2(alt)} ({escape_md_v2_link_url(href)})")
|
||||
else:
|
||||
out.append(escape_md_v2_link_url(href))
|
||||
else:
|
||||
out.append(escape_md_v2(tok.content or ""))
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
out: List[str] = []
|
||||
list_stack: List[dict] = []
|
||||
pending_prefix: Optional[str] = None
|
||||
blockquote_level = 0
|
||||
in_heading = False
|
||||
|
||||
def apply_blockquote(val: str) -> str:
|
||||
if blockquote_level <= 0:
|
||||
return val
|
||||
prefix = "> " * blockquote_level
|
||||
return prefix + val.replace("\n", "\n" + prefix)
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
tok = tokens[i]
|
||||
t = tok.type
|
||||
if t == "paragraph_open":
|
||||
pass
|
||||
elif t == "paragraph_close":
|
||||
out.append("\n")
|
||||
elif t == "heading_open":
|
||||
in_heading = True
|
||||
elif t == "heading_close":
|
||||
in_heading = False
|
||||
out.append("\n")
|
||||
elif t == "bullet_list_open":
|
||||
list_stack.append({"type": "bullet", "index": 1})
|
||||
elif t == "bullet_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "ordered_list_open":
|
||||
start = 1
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
val = tok.attrs.get("start")
|
||||
if val is not None:
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "start":
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
break
|
||||
list_stack.append({"type": "ordered", "index": start})
|
||||
elif t == "ordered_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "list_item_open":
|
||||
if list_stack:
|
||||
top = list_stack[-1]
|
||||
if top["type"] == "bullet":
|
||||
pending_prefix = "\\- "
|
||||
else:
|
||||
pending_prefix = f"{top['index']}\\."
|
||||
top["index"] += 1
|
||||
pending_prefix += " "
|
||||
elif t == "list_item_close":
|
||||
out.append("\n")
|
||||
elif t == "blockquote_open":
|
||||
blockquote_level += 1
|
||||
elif t == "blockquote_close":
|
||||
blockquote_level = max(0, blockquote_level - 1)
|
||||
out.append("\n")
|
||||
elif t == "table_open":
|
||||
if pending_prefix:
|
||||
out.append(apply_blockquote(pending_prefix.rstrip()))
|
||||
out.append("\n")
|
||||
pending_prefix = None
|
||||
|
||||
rows: List[List[str]] = []
|
||||
row_is_header: List[bool] = []
|
||||
|
||||
j = i + 1
|
||||
in_thead = False
|
||||
in_row = False
|
||||
current_row: List[str] = []
|
||||
current_row_header = False
|
||||
|
||||
in_cell = False
|
||||
cell_parts: List[str] = []
|
||||
|
||||
while j < len(tokens):
|
||||
tt = tokens[j].type
|
||||
if tt == "thead_open":
|
||||
in_thead = True
|
||||
elif tt == "thead_close":
|
||||
in_thead = False
|
||||
elif tt == "tr_open":
|
||||
in_row = True
|
||||
current_row = []
|
||||
current_row_header = in_thead
|
||||
elif tt in {"th_open", "td_open"}:
|
||||
in_cell = True
|
||||
cell_parts = []
|
||||
elif tt == "inline" and in_cell:
|
||||
cell_parts.append(
|
||||
render_inline_table_plain(tokens[j].children or [])
|
||||
)
|
||||
elif tt in {"th_close", "td_close"} and in_cell:
|
||||
cell = " ".join(cell_parts).strip()
|
||||
current_row.append(cell)
|
||||
in_cell = False
|
||||
cell_parts = []
|
||||
elif tt == "tr_close" and in_row:
|
||||
rows.append(current_row)
|
||||
row_is_header.append(bool(current_row_header))
|
||||
in_row = False
|
||||
elif tt == "table_close":
|
||||
break
|
||||
j += 1
|
||||
|
||||
if rows:
|
||||
col_count = max((len(r) for r in rows), default=0)
|
||||
norm_rows: List[List[str]] = []
|
||||
for r in rows:
|
||||
if len(r) < col_count:
|
||||
r = r + [""] * (col_count - len(r))
|
||||
norm_rows.append(r)
|
||||
|
||||
widths: List[int] = []
|
||||
for c in range(col_count):
|
||||
w = max((len(r[c]) for r in norm_rows), default=0)
|
||||
widths.append(max(w, 3))
|
||||
|
||||
def fmt_row(r: List[str]) -> str:
|
||||
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
def fmt_sep() -> str:
|
||||
cells = ["-" * widths[c] for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
last_header_idx = -1
|
||||
for idx, is_h in enumerate(row_is_header):
|
||||
if is_h:
|
||||
last_header_idx = idx
|
||||
|
||||
lines: List[str] = []
|
||||
for idx, r in enumerate(norm_rows):
|
||||
lines.append(fmt_row(r))
|
||||
if idx == last_header_idx:
|
||||
lines.append(fmt_sep())
|
||||
|
||||
table_text = "\n".join(lines).rstrip()
|
||||
out.append(f"```\n{escape_md_v2_code(table_text)}\n```")
|
||||
out.append("\n")
|
||||
|
||||
i = j + 1
|
||||
continue
|
||||
elif t in {"code_block", "fence"}:
|
||||
code = escape_md_v2_code(tok.content.rstrip("\n"))
|
||||
out.append(f"```\n{code}\n```")
|
||||
out.append("\n")
|
||||
elif t == "inline":
|
||||
rendered = render_inline(tok.children or [])
|
||||
if in_heading:
|
||||
rendered = f"*{render_inline_plain(tok.children or [])}*"
|
||||
if pending_prefix:
|
||||
rendered = pending_prefix + rendered
|
||||
pending_prefix = None
|
||||
rendered = apply_blockquote(rendered)
|
||||
out.append(rendered)
|
||||
else:
|
||||
if tok.content:
|
||||
out.append(escape_md_v2(tok.content))
|
||||
i += 1
|
||||
|
||||
return "".join(out).rstrip()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"escape_md_v2",
|
||||
"escape_md_v2_code",
|
||||
"escape_md_v2_link_url",
|
||||
"mdv2_bold",
|
||||
"mdv2_code_inline",
|
||||
"format_status",
|
||||
"render_markdown_to_mdv2",
|
||||
]
|
||||
|
|
@ -1,487 +1,8 @@
|
|||
"""
|
||||
Telegram Platform Adapter
|
||||
"""Backward-compatible re-export. Use messaging.platforms.telegram for new code."""
|
||||
|
||||
Implements MessagingPlatform for Telegram using python-telegram-bot.
|
||||
"""
|
||||
from .platforms.telegram import (
|
||||
TelegramPlatform,
|
||||
TELEGRAM_AVAILABLE,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
# Opt-in to future behavior for python-telegram-bot (retry_after as timedelta)
|
||||
# This must be set BEFORE importing telegram.error
|
||||
os.environ["PTB_TIMEDELTA"] = "1"
|
||||
|
||||
from typing import Callable, Awaitable, Optional, Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from .models import IncomingMessage
|
||||
from .telegram_markdown import escape_md_v2
|
||||
|
||||
|
||||
# Optional import - python-telegram-bot may not be installed
|
||||
try:
|
||||
from telegram import Update
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
MessageHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.error import TelegramError, RetryAfter, NetworkError
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
|
||||
|
||||
class TelegramPlatform(MessagingPlatform):
|
||||
"""
|
||||
Telegram messaging platform adapter.
|
||||
|
||||
Uses python-telegram-bot (BoT API) for Telegram access.
|
||||
Requires a Bot Token from @BotFather.
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: Optional[str] = None,
|
||||
allowed_user_id: Optional[str] = None,
|
||||
):
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
raise ImportError(
|
||||
"python-telegram-bot is required. Install with: pip install python-telegram-bot"
|
||||
)
|
||||
|
||||
self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
|
||||
|
||||
if not self.bot_token:
|
||||
# We don't raise here to allow instantiation for testing/conditional logic,
|
||||
# but start() will fail.
|
||||
logger.warning("TELEGRAM_BOT_TOKEN not set")
|
||||
|
||||
self._application: Optional[Application] = None
|
||||
self._message_handler: Optional[
|
||||
Callable[[IncomingMessage], Awaitable[None]]
|
||||
] = None
|
||||
self._connected = False
|
||||
self._limiter: Optional[Any] = None # Will be MessagingRateLimiter
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to Telegram."""
|
||||
if not self.bot_token:
|
||||
raise ValueError("TELEGRAM_BOT_TOKEN is required")
|
||||
|
||||
# Configure request with longer timeouts
|
||||
request = HTTPXRequest(
|
||||
connection_pool_size=8, connect_timeout=30.0, read_timeout=30.0
|
||||
)
|
||||
|
||||
# Build Application
|
||||
builder = Application.builder().token(self.bot_token).request(request)
|
||||
self._application = builder.build()
|
||||
|
||||
# Register Internal Handlers
|
||||
# We catch ALL text messages and commands to forward them
|
||||
self._application.add_handler(
|
||||
MessageHandler(filters.TEXT & (~filters.COMMAND), self._on_telegram_message)
|
||||
)
|
||||
self._application.add_handler(CommandHandler("start", self._on_start_command))
|
||||
# Catch-all for other commands if needed, or let them fall through
|
||||
self._application.add_handler(
|
||||
MessageHandler(filters.COMMAND, self._on_telegram_message)
|
||||
)
|
||||
|
||||
# Initialize internal components with retry logic
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
await self._application.initialize()
|
||||
await self._application.start()
|
||||
|
||||
# Start polling (non-blocking way for integration)
|
||||
if self._application.updater:
|
||||
await self._application.updater.start_polling(
|
||||
drop_pending_updates=False
|
||||
)
|
||||
|
||||
self._connected = True
|
||||
break
|
||||
except (NetworkError, Exception) as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 * (attempt + 1)
|
||||
logger.warning(
|
||||
f"Connection failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"Failed to connect after {max_retries} attempts")
|
||||
raise
|
||||
|
||||
# Initialize rate limiter
|
||||
from .limiter import MessagingRateLimiter
|
||||
|
||||
self._limiter = await MessagingRateLimiter.get_instance()
|
||||
|
||||
# Send startup notification
|
||||
try:
|
||||
target = self.allowed_user_id
|
||||
if target:
|
||||
startup_text = (
|
||||
f"🚀 *{escape_md_v2('Claude Code Proxy is online!')}* "
|
||||
f"{escape_md_v2('(Bot API)')}"
|
||||
)
|
||||
await self.send_message(
|
||||
target,
|
||||
startup_text,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not send startup message: {e}")
|
||||
|
||||
logger.info("Telegram platform started (Bot API)")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the bot."""
|
||||
if self._application and self._application.updater:
|
||||
await self._application.updater.stop()
|
||||
await self._application.stop()
|
||||
await self._application.shutdown()
|
||||
|
||||
self._connected = False
|
||||
logger.info("Telegram platform stopped")
|
||||
|
||||
async def _with_retry(
|
||||
self, func: Callable[..., Awaitable[Any]], *args, **kwargs
|
||||
) -> Any:
|
||||
"""Helper to execute a function with exponential backoff on network errors."""
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except (NetworkError, asyncio.TimeoutError) as e:
|
||||
if "Message is not modified" in str(e):
|
||||
return None
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt # 1s, 2s, 4s
|
||||
logger.warning(
|
||||
f"Telegram API network error (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(
|
||||
f"Telegram API failed after {max_retries} attempts: {e}"
|
||||
)
|
||||
raise
|
||||
except RetryAfter as e:
|
||||
# Telegram explicitly tells us to wait (PTB_TIMEDELTA: retry_after is timedelta)
|
||||
from datetime import timedelta
|
||||
|
||||
retry_after = e.retry_after
|
||||
if isinstance(retry_after, timedelta):
|
||||
wait_secs = retry_after.total_seconds()
|
||||
else:
|
||||
wait_secs = float(retry_after)
|
||||
|
||||
logger.warning(f"Rate limited by Telegram, waiting {wait_secs}s...")
|
||||
await asyncio.sleep(wait_secs)
|
||||
# We don't increment attempt here, as this is a specific instruction
|
||||
return await func(*args, **kwargs)
|
||||
except TelegramError as e:
|
||||
# Non-network Telegram errors
|
||||
err_lower = str(e).lower()
|
||||
if "message is not modified" in err_lower:
|
||||
return None
|
||||
# Best-effort no-op cases (common during chat cleanup / /clear).
|
||||
if any(
|
||||
x in err_lower
|
||||
for x in [
|
||||
"message to edit not found",
|
||||
"message to delete not found",
|
||||
"message can't be deleted",
|
||||
"message can't be edited",
|
||||
"not enough rights to delete",
|
||||
]
|
||||
):
|
||||
return None
|
||||
if "Can't parse entities" in str(e) and kwargs.get("parse_mode"):
|
||||
logger.warning("Markdown failed, retrying without parse_mode")
|
||||
kwargs["parse_mode"] = None
|
||||
return await func(*args, **kwargs)
|
||||
raise
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
) -> str:
|
||||
"""Send a message to a chat."""
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
async def _do_send(parse_mode=parse_mode):
|
||||
bot = app.bot
|
||||
msg = await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
return str(msg.message_id)
|
||||
|
||||
return await self._with_retry(_do_send, parse_mode=parse_mode)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
) -> None:
|
||||
"""Edit an existing message."""
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
async def _do_edit(parse_mode=parse_mode):
|
||||
bot = app.bot
|
||||
await bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=int(message_id),
|
||||
text=text,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
|
||||
await self._with_retry(_do_edit, parse_mode=parse_mode)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
) -> None:
|
||||
"""Delete a message from a chat."""
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
async def _do_delete():
|
||||
bot = app.bot
|
||||
await bot.delete_message(chat_id=chat_id, message_id=int(message_id))
|
||||
|
||||
await self._with_retry(_do_delete)
|
||||
|
||||
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
|
||||
"""Delete multiple messages (best-effort)."""
|
||||
if not message_ids:
|
||||
return
|
||||
app = self._application
|
||||
if not app or not app.bot:
|
||||
raise RuntimeError("Telegram application or bot not initialized")
|
||||
|
||||
# PTB supports bulk deletion via delete_messages; fall back to per-message.
|
||||
bot = app.bot
|
||||
if hasattr(bot, "delete_messages"):
|
||||
|
||||
async def _do_bulk():
|
||||
mids = []
|
||||
for mid in message_ids:
|
||||
try:
|
||||
mids.append(int(mid))
|
||||
except Exception:
|
||||
continue
|
||||
if not mids:
|
||||
return None
|
||||
# delete_messages accepts a sequence of ints (up to 100).
|
||||
await bot.delete_messages(chat_id=chat_id, message_ids=mids)
|
||||
|
||||
await self._with_retry(_do_bulk)
|
||||
return
|
||||
|
||||
for mid in message_ids:
|
||||
await self.delete_message(chat_id, mid)
|
||||
|
||||
async def queue_send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
fire_and_forget: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""Enqueue a message to be sent (using limiter)."""
|
||||
# Note: Bot API handles limits better, but we still use our limiter for nice queuing
|
||||
if not self._limiter:
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
async def _send():
|
||||
return await self.send_message(chat_id, text, reply_to, parse_mode)
|
||||
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_send)
|
||||
return None
|
||||
else:
|
||||
return await self._limiter.enqueue(_send)
|
||||
|
||||
async def queue_edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = "MarkdownV2",
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message edit."""
|
||||
if not self._limiter:
|
||||
return await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
|
||||
async def _edit():
|
||||
return await self.edit_message(chat_id, message_id, text, parse_mode)
|
||||
|
||||
dedup_key = f"edit:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a message delete."""
|
||||
if not self._limiter:
|
||||
return await self.delete_message(chat_id, message_id)
|
||||
|
||||
async def _delete():
|
||||
return await self.delete_message(chat_id, message_id)
|
||||
|
||||
dedup_key = f"del:{chat_id}:{message_id}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
|
||||
|
||||
async def queue_delete_messages(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_ids: list[str],
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Enqueue a bulk delete (if supported) or a sequence of deletes."""
|
||||
if not message_ids:
|
||||
return
|
||||
|
||||
if not self._limiter:
|
||||
return await self.delete_messages(chat_id, message_ids)
|
||||
|
||||
async def _bulk():
|
||||
return await self.delete_messages(chat_id, message_ids)
|
||||
|
||||
# Dedup by the chunk content; okay to be coarse here.
|
||||
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
|
||||
if fire_and_forget:
|
||||
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
|
||||
else:
|
||||
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
|
||||
|
||||
def fire_and_forget(self, task: Awaitable[Any]) -> None:
|
||||
"""Execute a coroutine without awaiting it."""
|
||||
if asyncio.iscoroutine(task):
|
||||
asyncio.create_task(task)
|
||||
else:
|
||||
asyncio.ensure_future(task)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Register a message handler callback."""
|
||||
self._message_handler = handler
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected."""
|
||||
return self._connected
|
||||
|
||||
async def _on_start_command(
|
||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
"""Handle /start command."""
|
||||
if update.message:
|
||||
await update.message.reply_text("👋 Hello! I am the Claude Code Proxy Bot.")
|
||||
# We can also treat this as a message if we want it to trigger something
|
||||
await self._on_telegram_message(update, context)
|
||||
|
||||
async def _on_telegram_message(
|
||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
"""Handle incoming updates."""
|
||||
if (
|
||||
not update.message
|
||||
or not update.message.text
|
||||
or not update.effective_user
|
||||
or not update.effective_chat
|
||||
):
|
||||
return
|
||||
|
||||
user_id = str(update.effective_user.id)
|
||||
chat_id = str(update.effective_chat.id)
|
||||
|
||||
# Security check
|
||||
if self.allowed_user_id:
|
||||
if user_id != str(self.allowed_user_id).strip():
|
||||
logger.warning(f"Unauthorized access attempt from {user_id}")
|
||||
return
|
||||
|
||||
message_id = str(update.message.message_id)
|
||||
reply_to = (
|
||||
str(update.message.reply_to_message.message_id)
|
||||
if update.message.reply_to_message
|
||||
else None
|
||||
)
|
||||
text_preview = (update.message.text or "")[:80]
|
||||
if len(update.message.text or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"TELEGRAM_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
chat_id,
|
||||
message_id,
|
||||
reply_to,
|
||||
text_preview,
|
||||
)
|
||||
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
incoming = IncomingMessage(
|
||||
text=update.message.text,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
platform="telegram",
|
||||
reply_to_message_id=reply_to,
|
||||
raw_event=update,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._message_handler(incoming)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
try:
|
||||
await self.send_message(
|
||||
chat_id,
|
||||
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(str(e)[:200])}",
|
||||
reply_to=incoming.message_id,
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
__all__ = ["TelegramPlatform", "TELEGRAM_AVAILABLE"]
|
||||
|
|
|
|||
|
|
@ -1,384 +1,14 @@
|
|||
"""Telegram MarkdownV2 utilities.
|
||||
|
||||
Renders common Markdown into Telegram MarkdownV2 format.
|
||||
Used by the message handler and Telegram platform adapter.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
|
||||
MDV2_SPECIAL_CHARS = set("\\_*[]()~`>#+-=|{}.!")
|
||||
MDV2_LINK_ESCAPE = set("\\)")
|
||||
|
||||
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
|
||||
_MD.enable("strikethrough")
|
||||
_MD.enable("table")
|
||||
|
||||
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
|
||||
_FENCE_RE = re.compile(r"^\s*```")
|
||||
|
||||
|
||||
def _is_gfm_table_header_line(line: str) -> bool:
|
||||
"""Check if line is a GFM table header (pipe-delimited, not separator)."""
|
||||
if "|" not in line:
|
||||
return False
|
||||
if _TABLE_SEP_RE.match(line):
|
||||
return False
|
||||
stripped = line.strip()
|
||||
parts = [p.strip() for p in stripped.strip("|").split("|")]
|
||||
parts = [p for p in parts if p != ""]
|
||||
return len(parts) >= 2
|
||||
|
||||
|
||||
def _normalize_gfm_tables(text: str) -> str:
|
||||
"""
|
||||
Many LLMs emit tables immediately after a paragraph line (no blank line).
|
||||
Markdown-it will treat that as a softbreak within the paragraph, so the
|
||||
table extension won't trigger. Insert a blank line before detected tables.
|
||||
|
||||
We only do this outside fenced code blocks.
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
if len(lines) < 2:
|
||||
return text
|
||||
|
||||
out_lines: List[str] = []
|
||||
in_fence = False
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
if _FENCE_RE.match(line):
|
||||
in_fence = not in_fence
|
||||
out_lines.append(line)
|
||||
continue
|
||||
|
||||
if (
|
||||
not in_fence
|
||||
and idx + 1 < len(lines)
|
||||
and _is_gfm_table_header_line(line)
|
||||
and _TABLE_SEP_RE.match(lines[idx + 1])
|
||||
):
|
||||
if out_lines and out_lines[-1].strip() != "":
|
||||
m = re.match(r"^(\s*)", line)
|
||||
indent = m.group(1) if m else ""
|
||||
out_lines.append(indent)
|
||||
|
||||
out_lines.append(line)
|
||||
|
||||
return "\n".join(out_lines)
|
||||
|
||||
|
||||
def escape_md_v2(text: str) -> str:
|
||||
"""Escape text for Telegram MarkdownV2."""
|
||||
return "".join(f"\\{ch}" if ch in MDV2_SPECIAL_CHARS else ch for ch in text)
|
||||
|
||||
|
||||
def escape_md_v2_code(text: str) -> str:
|
||||
"""Escape text for Telegram MarkdownV2 code spans/blocks."""
|
||||
return text.replace("\\", "\\\\").replace("`", "\\`")
|
||||
|
||||
|
||||
def escape_md_v2_link_url(text: str) -> str:
|
||||
"""Escape URL for Telegram MarkdownV2 link destination."""
|
||||
return "".join(f"\\{ch}" if ch in MDV2_LINK_ESCAPE else ch for ch in text)
|
||||
|
||||
|
||||
def mdv2_bold(text: str) -> str:
|
||||
"""Format text as bold in MarkdownV2."""
|
||||
return f"*{escape_md_v2(text)}*"
|
||||
|
||||
|
||||
def mdv2_code_inline(text: str) -> str:
|
||||
"""Format text as inline code in MarkdownV2."""
|
||||
return f"`{escape_md_v2_code(text)}`"
|
||||
|
||||
|
||||
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
|
||||
"""Format a status message with emoji and optional suffix."""
|
||||
base = f"{emoji} {mdv2_bold(label)}"
|
||||
if suffix:
|
||||
return f"{base} {escape_md_v2(suffix)}"
|
||||
return base
|
||||
|
||||
|
||||
def render_markdown_to_mdv2(text: str) -> str:
|
||||
"""Render common Markdown into Telegram MarkdownV2."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = _normalize_gfm_tables(text)
|
||||
tokens = _MD.parse(text)
|
||||
|
||||
def render_inline_table_plain(children) -> str:
|
||||
out: List[str] = []
|
||||
for tok in children:
|
||||
if tok.type == "text":
|
||||
out.append(tok.content)
|
||||
elif tok.type == "code_inline":
|
||||
out.append(tok.content)
|
||||
elif tok.type in {"softbreak", "hardbreak"}:
|
||||
out.append(" ")
|
||||
elif tok.type == "image":
|
||||
if tok.content:
|
||||
out.append(tok.content)
|
||||
return "".join(out)
|
||||
|
||||
def render_inline_plain(children) -> str:
|
||||
out: List[str] = []
|
||||
for tok in children:
|
||||
if tok.type == "text":
|
||||
out.append(escape_md_v2(tok.content))
|
||||
elif tok.type == "code_inline":
|
||||
out.append(escape_md_v2(tok.content))
|
||||
elif tok.type in {"softbreak", "hardbreak"}:
|
||||
out.append("\n")
|
||||
return "".join(out)
|
||||
|
||||
def render_inline(children) -> str:
|
||||
out: List[str] = []
|
||||
i = 0
|
||||
while i < len(children):
|
||||
tok = children[i]
|
||||
t = tok.type
|
||||
if t == "text":
|
||||
out.append(escape_md_v2(tok.content))
|
||||
elif t in {"softbreak", "hardbreak"}:
|
||||
out.append("\n")
|
||||
elif t == "em_open":
|
||||
out.append("_")
|
||||
elif t == "em_close":
|
||||
out.append("_")
|
||||
elif t == "strong_open":
|
||||
out.append("*")
|
||||
elif t == "strong_close":
|
||||
out.append("*")
|
||||
elif t == "s_open":
|
||||
out.append("~")
|
||||
elif t == "s_close":
|
||||
out.append("~")
|
||||
elif t == "code_inline":
|
||||
out.append(f"`{escape_md_v2_code(tok.content)}`")
|
||||
elif t == "link_open":
|
||||
href = ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("href", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "href":
|
||||
href = val
|
||||
break
|
||||
inner_tokens = []
|
||||
i += 1
|
||||
while i < len(children) and children[i].type != "link_close":
|
||||
inner_tokens.append(children[i])
|
||||
i += 1
|
||||
link_text = ""
|
||||
for child in inner_tokens:
|
||||
if child.type == "text":
|
||||
link_text += child.content
|
||||
elif child.type == "code_inline":
|
||||
link_text += child.content
|
||||
out.append(
|
||||
f"[{escape_md_v2(link_text)}]({escape_md_v2_link_url(href)})"
|
||||
)
|
||||
elif t == "image":
|
||||
href = ""
|
||||
alt = tok.content or ""
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
href = tok.attrs.get("src", "")
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "src":
|
||||
href = val
|
||||
break
|
||||
if alt:
|
||||
out.append(f"{escape_md_v2(alt)} ({escape_md_v2_link_url(href)})")
|
||||
else:
|
||||
out.append(escape_md_v2_link_url(href))
|
||||
else:
|
||||
out.append(escape_md_v2(tok.content or ""))
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
out: List[str] = []
|
||||
list_stack: List[dict] = []
|
||||
pending_prefix: Optional[str] = None
|
||||
blockquote_level = 0
|
||||
in_heading = False
|
||||
|
||||
def apply_blockquote(val: str) -> str:
|
||||
if blockquote_level <= 0:
|
||||
return val
|
||||
prefix = "> " * blockquote_level
|
||||
return prefix + val.replace("\n", "\n" + prefix)
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
tok = tokens[i]
|
||||
t = tok.type
|
||||
if t == "paragraph_open":
|
||||
pass
|
||||
elif t == "paragraph_close":
|
||||
out.append("\n")
|
||||
elif t == "heading_open":
|
||||
in_heading = True
|
||||
elif t == "heading_close":
|
||||
in_heading = False
|
||||
out.append("\n")
|
||||
elif t == "bullet_list_open":
|
||||
list_stack.append({"type": "bullet", "index": 1})
|
||||
elif t == "bullet_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "ordered_list_open":
|
||||
start = 1
|
||||
if tok.attrs:
|
||||
if isinstance(tok.attrs, dict):
|
||||
val = tok.attrs.get("start")
|
||||
if val is not None:
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
else:
|
||||
for key, val in tok.attrs:
|
||||
if key == "start":
|
||||
try:
|
||||
start = int(val)
|
||||
except TypeError, ValueError:
|
||||
start = 1
|
||||
break
|
||||
list_stack.append({"type": "ordered", "index": start})
|
||||
elif t == "ordered_list_close":
|
||||
if list_stack:
|
||||
list_stack.pop()
|
||||
out.append("\n")
|
||||
elif t == "list_item_open":
|
||||
if list_stack:
|
||||
top = list_stack[-1]
|
||||
if top["type"] == "bullet":
|
||||
pending_prefix = "\\- "
|
||||
else:
|
||||
pending_prefix = f"{top['index']}\\."
|
||||
top["index"] += 1
|
||||
pending_prefix += " "
|
||||
elif t == "list_item_close":
|
||||
out.append("\n")
|
||||
elif t == "blockquote_open":
|
||||
blockquote_level += 1
|
||||
elif t == "blockquote_close":
|
||||
blockquote_level = max(0, blockquote_level - 1)
|
||||
out.append("\n")
|
||||
elif t == "table_open":
|
||||
if pending_prefix:
|
||||
out.append(apply_blockquote(pending_prefix.rstrip()))
|
||||
out.append("\n")
|
||||
pending_prefix = None
|
||||
|
||||
rows: List[List[str]] = []
|
||||
row_is_header: List[bool] = []
|
||||
|
||||
j = i + 1
|
||||
in_thead = False
|
||||
in_row = False
|
||||
current_row: List[str] = []
|
||||
current_row_header = False
|
||||
|
||||
in_cell = False
|
||||
cell_parts: List[str] = []
|
||||
|
||||
while j < len(tokens):
|
||||
tt = tokens[j].type
|
||||
if tt == "thead_open":
|
||||
in_thead = True
|
||||
elif tt == "thead_close":
|
||||
in_thead = False
|
||||
elif tt == "tr_open":
|
||||
in_row = True
|
||||
current_row = []
|
||||
current_row_header = in_thead
|
||||
elif tt in {"th_open", "td_open"}:
|
||||
in_cell = True
|
||||
cell_parts = []
|
||||
elif tt == "inline" and in_cell:
|
||||
cell_parts.append(
|
||||
render_inline_table_plain(tokens[j].children or [])
|
||||
)
|
||||
elif tt in {"th_close", "td_close"} and in_cell:
|
||||
cell = " ".join(cell_parts).strip()
|
||||
current_row.append(cell)
|
||||
in_cell = False
|
||||
cell_parts = []
|
||||
elif tt == "tr_close" and in_row:
|
||||
rows.append(current_row)
|
||||
row_is_header.append(bool(current_row_header))
|
||||
in_row = False
|
||||
elif tt == "table_close":
|
||||
break
|
||||
j += 1
|
||||
|
||||
if rows:
|
||||
col_count = max((len(r) for r in rows), default=0)
|
||||
norm_rows: List[List[str]] = []
|
||||
for r in rows:
|
||||
if len(r) < col_count:
|
||||
r = r + [""] * (col_count - len(r))
|
||||
norm_rows.append(r)
|
||||
|
||||
widths: List[int] = []
|
||||
for c in range(col_count):
|
||||
w = max((len(r[c]) for r in norm_rows), default=0)
|
||||
widths.append(max(w, 3))
|
||||
|
||||
def fmt_row(r: List[str]) -> str:
|
||||
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
def fmt_sep() -> str:
|
||||
cells = ["-" * widths[c] for c in range(col_count)]
|
||||
return "| " + " | ".join(cells) + " |"
|
||||
|
||||
last_header_idx = -1
|
||||
for idx, is_h in enumerate(row_is_header):
|
||||
if is_h:
|
||||
last_header_idx = idx
|
||||
|
||||
lines: List[str] = []
|
||||
for idx, r in enumerate(norm_rows):
|
||||
lines.append(fmt_row(r))
|
||||
if idx == last_header_idx:
|
||||
lines.append(fmt_sep())
|
||||
|
||||
table_text = "\n".join(lines).rstrip()
|
||||
out.append(f"```\n{escape_md_v2_code(table_text)}\n```")
|
||||
out.append("\n")
|
||||
|
||||
i = j + 1
|
||||
continue
|
||||
elif t in {"code_block", "fence"}:
|
||||
code = escape_md_v2_code(tok.content.rstrip("\n"))
|
||||
out.append(f"```\n{code}\n```")
|
||||
out.append("\n")
|
||||
elif t == "inline":
|
||||
rendered = render_inline(tok.children or [])
|
||||
if in_heading:
|
||||
rendered = f"*{render_inline_plain(tok.children or [])}*"
|
||||
if pending_prefix:
|
||||
rendered = pending_prefix + rendered
|
||||
pending_prefix = None
|
||||
rendered = apply_blockquote(rendered)
|
||||
out.append(rendered)
|
||||
else:
|
||||
if tok.content:
|
||||
out.append(escape_md_v2(tok.content))
|
||||
i += 1
|
||||
|
||||
return "".join(out).rstrip()
|
||||
"""Backward-compatible re-export. Use messaging.rendering.telegram_markdown for new code."""
|
||||
|
||||
from .rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
escape_md_v2_link_url,
|
||||
mdv2_bold,
|
||||
mdv2_code_inline,
|
||||
format_status,
|
||||
render_markdown_to_mdv2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"escape_md_v2",
|
||||
|
|
|
|||
|
|
@ -1,441 +1,5 @@
|
|||
"""Tree data structures for message queue.
|
||||
"""Backward-compatible re-export. Use messaging.trees.data for new code."""
|
||||
|
||||
Contains MessageState, MessageNode, and MessageTree classes.
|
||||
"""
|
||||
from .trees.data import MessageTree, MessageNode, MessageState
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional, List, Any, cast
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .models import IncomingMessage
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageState(Enum):
|
||||
"""State of a message node in the tree."""
|
||||
|
||||
PENDING = "pending" # Queued, waiting to be processed
|
||||
IN_PROGRESS = "in_progress" # Currently being processed by Claude
|
||||
COMPLETED = "completed" # Processing finished successfully
|
||||
ERROR = "error" # Processing failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageNode:
|
||||
"""
|
||||
A node in the message tree.
|
||||
|
||||
Each node represents a single message and tracks:
|
||||
- Its relationship to parent/children
|
||||
- Its processing state
|
||||
- Claude session information
|
||||
"""
|
||||
|
||||
node_id: str # Unique ID (typically message_id)
|
||||
incoming: IncomingMessage # The original message
|
||||
status_message_id: str # Bot's status message ID
|
||||
state: MessageState = MessageState.PENDING
|
||||
parent_id: Optional[str] = None # Parent node ID (None for root)
|
||||
session_id: Optional[str] = None # Claude session ID (forked from parent)
|
||||
children_ids: List[str] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
context: Any = None # Additional context if needed
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
"incoming": {
|
||||
"text": self.incoming.text,
|
||||
"chat_id": self.incoming.chat_id,
|
||||
"user_id": self.incoming.user_id,
|
||||
"message_id": self.incoming.message_id,
|
||||
"platform": self.incoming.platform,
|
||||
"reply_to_message_id": self.incoming.reply_to_message_id,
|
||||
"username": self.incoming.username,
|
||||
},
|
||||
"status_message_id": self.status_message_id,
|
||||
"state": self.state.value,
|
||||
"parent_id": self.parent_id,
|
||||
"session_id": self.session_id,
|
||||
"children_ids": self.children_ids,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"error_message": self.error_message,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "MessageNode":
|
||||
"""Create from dictionary (JSON deserialization)."""
|
||||
incoming_data = data["incoming"]
|
||||
incoming = IncomingMessage(
|
||||
text=incoming_data["text"],
|
||||
chat_id=incoming_data["chat_id"],
|
||||
user_id=incoming_data["user_id"],
|
||||
message_id=incoming_data["message_id"],
|
||||
platform=incoming_data["platform"],
|
||||
reply_to_message_id=incoming_data.get("reply_to_message_id"),
|
||||
username=incoming_data.get("username"),
|
||||
)
|
||||
return cls(
|
||||
node_id=data["node_id"],
|
||||
incoming=incoming,
|
||||
status_message_id=data["status_message_id"],
|
||||
state=MessageState(data["state"]),
|
||||
parent_id=data.get("parent_id"),
|
||||
session_id=data.get("session_id"),
|
||||
children_ids=data.get("children_ids", []),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
completed_at=datetime.fromisoformat(data["completed_at"])
|
||||
if data.get("completed_at")
|
||||
else None,
|
||||
error_message=data.get("error_message"),
|
||||
)
|
||||
|
||||
|
||||
class MessageTree:
|
||||
"""
|
||||
A tree of message nodes with queue functionality.
|
||||
|
||||
Provides:
|
||||
- O(1) node lookup via hashmap
|
||||
- Per-tree message queue
|
||||
- Thread-safe operations via asyncio.Lock
|
||||
"""
|
||||
|
||||
def __init__(self, root_node: MessageNode):
|
||||
"""
|
||||
Initialize tree with a root node.
|
||||
|
||||
Args:
|
||||
root_node: The root message node
|
||||
"""
|
||||
self.root_id = root_node.node_id
|
||||
self._nodes: Dict[str, MessageNode] = {root_node.node_id: root_node}
|
||||
self._status_to_node: Dict[str, str] = {
|
||||
root_node.status_message_id: root_node.node_id
|
||||
}
|
||||
self._queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._lock = asyncio.Lock()
|
||||
self._is_processing = False
|
||||
self._current_node_id: Optional[str] = None
|
||||
self._current_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.debug(f"Created MessageTree with root {self.root_id}")
|
||||
|
||||
def set_current_task(self, task: Optional[asyncio.Task]) -> None:
|
||||
"""Set the current processing task. Caller must hold lock."""
|
||||
self._current_task = task
|
||||
|
||||
@property
|
||||
def is_processing(self) -> bool:
|
||||
"""Check if tree is currently processing a message."""
|
||||
return self._is_processing
|
||||
|
||||
async def add_node(
|
||||
self,
|
||||
node_id: str,
|
||||
incoming: IncomingMessage,
|
||||
status_message_id: str,
|
||||
parent_id: str,
|
||||
) -> MessageNode:
|
||||
"""
|
||||
Add a child node to the tree.
|
||||
|
||||
Args:
|
||||
node_id: Unique ID for the new node
|
||||
incoming: The incoming message
|
||||
status_message_id: Bot's status message ID
|
||||
parent_id: Parent node ID
|
||||
|
||||
Returns:
|
||||
The created MessageNode
|
||||
"""
|
||||
async with self._lock:
|
||||
if parent_id not in self._nodes:
|
||||
raise ValueError(f"Parent node {parent_id} not found in tree")
|
||||
|
||||
node = MessageNode(
|
||||
node_id=node_id,
|
||||
incoming=incoming,
|
||||
status_message_id=status_message_id,
|
||||
parent_id=parent_id,
|
||||
state=MessageState.PENDING,
|
||||
)
|
||||
|
||||
self._nodes[node_id] = node
|
||||
self._status_to_node[status_message_id] = node_id
|
||||
self._nodes[parent_id].children_ids.append(node_id)
|
||||
|
||||
logger.debug(f"Added node {node_id} as child of {parent_id}")
|
||||
return node
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get a node by ID (O(1) lookup)."""
|
||||
return self._nodes.get(node_id)
|
||||
|
||||
def get_root(self) -> MessageNode:
|
||||
"""Get the root node."""
|
||||
return self._nodes[self.root_id]
|
||||
|
||||
def get_children(self, node_id: str) -> List[MessageNode]:
|
||||
"""Get all child nodes of a given node."""
|
||||
node = self._nodes.get(node_id)
|
||||
if not node:
|
||||
return []
|
||||
return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]
|
||||
|
||||
def get_parent(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get the parent node."""
|
||||
node = self._nodes.get(node_id)
|
||||
if not node or not node.parent_id:
|
||||
return None
|
||||
return self._nodes.get(node.parent_id)
|
||||
|
||||
def get_parent_session_id(self, node_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the parent's session ID for forking.
|
||||
|
||||
Returns None for root nodes.
|
||||
"""
|
||||
parent = self.get_parent(node_id)
|
||||
return parent.session_id if parent else None
|
||||
|
||||
async def update_state(
|
||||
self,
|
||||
node_id: str,
|
||||
state: MessageState,
|
||||
session_id: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update a node's state."""
|
||||
async with self._lock:
|
||||
node = self._nodes.get(node_id)
|
||||
if not node:
|
||||
logger.warning(f"Node {node_id} not found for state update")
|
||||
return
|
||||
|
||||
node.state = state
|
||||
if session_id:
|
||||
node.session_id = session_id
|
||||
if error_message:
|
||||
node.error_message = error_message
|
||||
if state in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
logger.debug(f"Node {node_id} state -> {state.value}")
|
||||
|
||||
async def enqueue(self, node_id: str) -> int:
|
||||
"""
|
||||
Add a node to the processing queue.
|
||||
|
||||
Returns:
|
||||
Queue position (1-indexed)
|
||||
"""
|
||||
async with self._lock:
|
||||
await self._queue.put(node_id)
|
||||
position = self._queue.qsize()
|
||||
logger.debug(f"Enqueued node {node_id}, position {position}")
|
||||
return position
|
||||
|
||||
async def dequeue(self) -> Optional[str]:
|
||||
"""
|
||||
Get the next node ID from the queue.
|
||||
|
||||
Returns None if queue is empty.
|
||||
"""
|
||||
try:
|
||||
return self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
return None
|
||||
|
||||
async def get_queue_snapshot(self) -> List[str]:
|
||||
"""
|
||||
Get a snapshot of the current queue order.
|
||||
|
||||
Returns:
|
||||
List of node IDs in FIFO order.
|
||||
"""
|
||||
async with self._lock:
|
||||
# Read internal deque directly to avoid mutating queue state.
|
||||
# Drain/put approach would inflate _unfinished_tasks without task_done().
|
||||
queue_deque = cast(deque, getattr(self._queue, "_queue"))
|
||||
return list(queue_deque)
|
||||
|
||||
def get_queue_size(self) -> int:
|
||||
"""Get number of messages waiting in queue."""
|
||||
return self._queue.qsize()
|
||||
|
||||
def remove_from_queue(self, node_id: str) -> bool:
|
||||
"""
|
||||
Remove node_id from the internal queue if present.
|
||||
|
||||
Caller must hold the tree lock (e.g. via with_lock).
|
||||
Returns True if node was removed, False if not in queue.
|
||||
|
||||
Note: asyncio.Queue has no built-in remove; we filter via the internal
|
||||
deque. O(n) in queue size; acceptable for typical tree queue sizes.
|
||||
"""
|
||||
queue_deque = cast(deque, getattr(self._queue, "_queue"))
|
||||
if node_id not in queue_deque:
|
||||
return False
|
||||
object.__setattr__(
|
||||
self._queue, "_queue", deque(x for x in queue_deque if x != node_id)
|
||||
)
|
||||
return True
|
||||
|
||||
@asynccontextmanager
|
||||
async def with_lock(self):
|
||||
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
|
||||
async with self._lock:
|
||||
yield
|
||||
|
||||
def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None:
|
||||
"""Set processing state. Caller must hold lock for consistency with queue operations."""
|
||||
self._is_processing = is_processing
|
||||
self._current_node_id = node_id if is_processing else None
|
||||
|
||||
def clear_current_node(self) -> None:
|
||||
"""Clear the currently processing node ID. Caller must hold lock."""
|
||||
self._current_node_id = None
|
||||
|
||||
def is_current_node(self, node_id: str) -> bool:
|
||||
"""Check if node_id is the currently processing node."""
|
||||
return self._current_node_id == node_id
|
||||
|
||||
def put_queue_unlocked(self, node_id: str) -> None:
|
||||
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
|
||||
self._queue.put_nowait(node_id)
|
||||
|
||||
def cancel_current_task(self) -> bool:
|
||||
"""Cancel the currently running task. Returns True if a task was cancelled."""
|
||||
if self._current_task and not self._current_task.done():
|
||||
self._current_task.cancel()
|
||||
return True
|
||||
return False
|
||||
|
||||
def drain_queue_and_mark_cancelled(
|
||||
self, error_message: str = "Cancelled by user"
|
||||
) -> List["MessageNode"]:
|
||||
"""
|
||||
Drain the queue, mark each node as ERROR, and return affected nodes.
|
||||
Does not acquire lock; caller must ensure no concurrent queue access.
|
||||
"""
|
||||
nodes: List[MessageNode] = []
|
||||
while True:
|
||||
try:
|
||||
node_id = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
node = self._nodes.get(node_id)
|
||||
if node:
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = error_message
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def reset_processing_state(self) -> None:
|
||||
"""Reset processing flags after cancel/cleanup."""
|
||||
self._is_processing = False
|
||||
self._current_node_id = None
|
||||
|
||||
@property
|
||||
def current_node_id(self) -> Optional[str]:
|
||||
"""Get the ID of the node currently being processed."""
|
||||
return self._current_node_id
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize tree to dictionary."""
|
||||
return {
|
||||
"root_id": self.root_id,
|
||||
"nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "MessageTree":
|
||||
"""Deserialize tree from dictionary."""
|
||||
root_id = data["root_id"]
|
||||
nodes_data = data["nodes"]
|
||||
|
||||
# Create root node first
|
||||
root_node = MessageNode.from_dict(nodes_data[root_id])
|
||||
tree = cls(root_node)
|
||||
|
||||
# Add remaining nodes and build status->node index
|
||||
for node_id, node_data in nodes_data.items():
|
||||
if node_id != root_id:
|
||||
node = MessageNode.from_dict(node_data)
|
||||
tree._nodes[node_id] = node
|
||||
tree._status_to_node[node.status_message_id] = node_id
|
||||
|
||||
return tree
|
||||
|
||||
def all_nodes(self) -> List[MessageNode]:
|
||||
"""Get all nodes in the tree."""
|
||||
return list(self._nodes.values())
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node exists in this tree."""
|
||||
return node_id in self._nodes
|
||||
|
||||
def find_node_by_status_message(self, status_msg_id: str) -> Optional[MessageNode]:
|
||||
"""Find the node that has this status message ID (O(1) lookup)."""
|
||||
node_id = self._status_to_node.get(status_msg_id)
|
||||
return self._nodes.get(node_id) if node_id else None
|
||||
|
||||
def get_descendants(self, node_id: str) -> List[str]:
|
||||
"""
|
||||
Get node_id and all descendant IDs (subtree).
|
||||
|
||||
Returns:
|
||||
List of node IDs including the given node.
|
||||
"""
|
||||
if node_id not in self._nodes:
|
||||
return []
|
||||
result: List[str] = []
|
||||
stack = [node_id]
|
||||
while stack:
|
||||
nid = stack.pop()
|
||||
result.append(nid)
|
||||
node = self._nodes.get(nid)
|
||||
if node:
|
||||
stack.extend(node.children_ids)
|
||||
return result
|
||||
|
||||
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Remove a subtree (branch_root and all descendants) from the tree.
|
||||
|
||||
Updates parent's children_ids. Caller must hold lock for consistency.
|
||||
Does not acquire lock internally.
|
||||
|
||||
Returns:
|
||||
List of removed nodes.
|
||||
"""
|
||||
if branch_root_id not in self._nodes:
|
||||
return []
|
||||
|
||||
parent = self.get_parent(branch_root_id)
|
||||
removed = []
|
||||
for nid in self.get_descendants(branch_root_id):
|
||||
node = self._nodes.get(nid)
|
||||
if node:
|
||||
removed.append(node)
|
||||
del self._nodes[nid]
|
||||
del self._status_to_node[node.status_message_id]
|
||||
|
||||
if parent and branch_root_id in parent.children_ids:
|
||||
parent.children_ids = [
|
||||
c for c in parent.children_ids if c != branch_root_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
|
||||
return removed
|
||||
__all__ = ["MessageTree", "MessageNode", "MessageState"]
|
||||
|
|
|
|||
|
|
@ -1,162 +1,5 @@
|
|||
"""Async queue processor for message trees.
|
||||
"""Backward-compatible re-export. Use messaging.trees.processor for new code."""
|
||||
|
||||
Handles the async processing lifecycle of tree nodes.
|
||||
"""
|
||||
from .trees.processor import TreeQueueProcessor
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Awaitable, Optional
|
||||
|
||||
from .tree_data import MessageTree, MessageNode, MessageState
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TreeQueueProcessor:
|
||||
"""
|
||||
Handles async queue processing for a single tree.
|
||||
|
||||
Separates the async processing logic from the data management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_update_callback: Optional[
|
||||
Callable[[MessageTree], Awaitable[None]]
|
||||
] = None,
|
||||
node_started_callback: Optional[
|
||||
Callable[[MessageTree, str], Awaitable[None]]
|
||||
] = None,
|
||||
):
|
||||
self._queue_update_callback = queue_update_callback
|
||||
self._node_started_callback = node_started_callback
|
||||
|
||||
def set_queue_update_callback(
|
||||
self,
|
||||
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Update the callback used to refresh queue positions."""
|
||||
self._queue_update_callback = queue_update_callback
|
||||
|
||||
def set_node_started_callback(
|
||||
self,
|
||||
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Update the callback used when a queued node starts processing."""
|
||||
self._node_started_callback = node_started_callback
|
||||
|
||||
async def _notify_queue_updated(self, tree: MessageTree) -> None:
|
||||
"""Invoke queue update callback if set."""
|
||||
if not self._queue_update_callback:
|
||||
return
|
||||
try:
|
||||
await self._queue_update_callback(tree)
|
||||
except Exception as e:
|
||||
logger.warning(f"Queue update callback failed: {e}")
|
||||
|
||||
async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
|
||||
"""Invoke node started callback if set."""
|
||||
if not self._node_started_callback:
|
||||
return
|
||||
try:
|
||||
await self._node_started_callback(tree, node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Node started callback failed: {e}")
|
||||
|
||||
async def process_node(
|
||||
self,
|
||||
tree: MessageTree,
|
||||
node: MessageNode,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process a single node and then check the queue."""
|
||||
# Skip if already in terminal state (e.g. from error propagation)
|
||||
if node.state.value == MessageState.ERROR.value:
|
||||
logger.info(
|
||||
f"Skipping node {node.node_id} as it is already in state {node.state}"
|
||||
)
|
||||
# Still need to check for next messages
|
||||
await self._process_next(tree, processor)
|
||||
return
|
||||
|
||||
try:
|
||||
await processor(node.node_id, node)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Task for node {node.node_id} was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing node {node.node_id}: {e}")
|
||||
await tree.update_state(
|
||||
node.node_id, MessageState.ERROR, error_message=str(e)
|
||||
)
|
||||
finally:
|
||||
tree.clear_current_node()
|
||||
# Check if there are more messages in the queue
|
||||
await self._process_next(tree, processor)
|
||||
|
||||
async def _process_next(
|
||||
self,
|
||||
tree: MessageTree,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process the next message in queue, if any."""
|
||||
next_node_id = None
|
||||
node = None
|
||||
async with tree.with_lock():
|
||||
next_node_id = await tree.dequeue()
|
||||
|
||||
if not next_node_id:
|
||||
tree.set_processing_state(None, False)
|
||||
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
|
||||
return
|
||||
|
||||
tree.set_processing_state(next_node_id, True)
|
||||
logger.info(f"Processing next queued node {next_node_id}")
|
||||
|
||||
# Process next node (outside lock)
|
||||
node = tree.get_node(next_node_id)
|
||||
if node:
|
||||
tree.set_current_task(
|
||||
asyncio.create_task(self.process_node(tree, node, processor))
|
||||
)
|
||||
|
||||
# Notify that this node has started processing and refresh queue positions.
|
||||
if next_node_id:
|
||||
await self._notify_node_started(tree, next_node_id)
|
||||
await self._notify_queue_updated(tree)
|
||||
|
||||
async def enqueue_and_start(
|
||||
self,
|
||||
tree: MessageTree,
|
||||
node_id: str,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> bool:
|
||||
"""
|
||||
Enqueue a node or start processing immediately.
|
||||
|
||||
Args:
|
||||
tree: The message tree
|
||||
node_id: Node to process
|
||||
processor: Async function to process the node
|
||||
|
||||
Returns:
|
||||
True if queued, False if processing immediately
|
||||
"""
|
||||
async with tree.with_lock():
|
||||
if tree.is_processing:
|
||||
tree.put_queue_unlocked(node_id)
|
||||
queue_size = tree.get_queue_size()
|
||||
logger.info(f"Queued node {node_id}, position {queue_size}")
|
||||
return True
|
||||
else:
|
||||
tree.set_processing_state(node_id, True)
|
||||
|
||||
# Process outside the lock
|
||||
node = tree.get_node(node_id)
|
||||
if node:
|
||||
tree.set_current_task(
|
||||
asyncio.create_task(self.process_node(tree, node, processor))
|
||||
)
|
||||
return False
|
||||
|
||||
def cancel_current(self, tree: MessageTree) -> bool:
|
||||
"""Cancel the currently running task in a tree."""
|
||||
return tree.cancel_current_task()
|
||||
__all__ = ["TreeQueueProcessor"]
|
||||
|
|
|
|||
|
|
@ -1,461 +1,10 @@
|
|||
"""Tree-Based Message Queue Manager - Refactored.
|
||||
"""Backward-compatible re-export. Use messaging.trees.queue_manager for new code."""
|
||||
|
||||
Coordinates data access, async processing, and error handling.
|
||||
Uses TreeRepository for data, TreeQueueProcessor for async logic.
|
||||
"""
|
||||
from .trees.queue_manager import (
|
||||
TreeQueueManager,
|
||||
MessageTree,
|
||||
MessageNode,
|
||||
MessageState,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Awaitable, List, Optional
|
||||
|
||||
from .models import IncomingMessage
|
||||
from .tree_data import MessageState, MessageNode, MessageTree
|
||||
from .tree_repository import TreeRepository
|
||||
from .tree_processor import TreeQueueProcessor
|
||||
from loguru import logger
|
||||
|
||||
# Backward compatibility: re-export moved classes
|
||||
__all__ = [
|
||||
"TreeQueueManager",
|
||||
"MessageState",
|
||||
"MessageNode",
|
||||
"MessageTree",
|
||||
]
|
||||
|
||||
|
||||
class TreeQueueManager:
|
||||
"""
|
||||
Manages multiple message trees. Facade that coordinates components.
|
||||
|
||||
Each new conversation creates a new tree.
|
||||
Replies to existing messages add nodes to existing trees.
|
||||
|
||||
Components:
|
||||
- TreeRepository: Data access layer
|
||||
- TreeQueueProcessor: Async queue processing
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_update_callback: Optional[
|
||||
Callable[[MessageTree], Awaitable[None]]
|
||||
] = None,
|
||||
node_started_callback: Optional[
|
||||
Callable[[MessageTree, str], Awaitable[None]]
|
||||
] = None,
|
||||
):
|
||||
self._repository = TreeRepository()
|
||||
self._processor = TreeQueueProcessor(
|
||||
queue_update_callback=queue_update_callback,
|
||||
node_started_callback=node_started_callback,
|
||||
)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info("TreeQueueManager initialized")
|
||||
|
||||
async def create_tree(
|
||||
self,
|
||||
node_id: str,
|
||||
incoming: IncomingMessage,
|
||||
status_message_id: str,
|
||||
) -> MessageTree:
|
||||
"""
|
||||
Create a new tree with a root node.
|
||||
|
||||
Args:
|
||||
node_id: ID for the root node
|
||||
incoming: The incoming message
|
||||
status_message_id: Bot's status message ID
|
||||
|
||||
Returns:
|
||||
The created MessageTree
|
||||
"""
|
||||
async with self._lock:
|
||||
root_node = MessageNode(
|
||||
node_id=node_id,
|
||||
incoming=incoming,
|
||||
status_message_id=status_message_id,
|
||||
state=MessageState.PENDING,
|
||||
)
|
||||
|
||||
tree = MessageTree(root_node)
|
||||
self._repository.add_tree(node_id, tree)
|
||||
|
||||
logger.info(f"Created new tree with root {node_id}")
|
||||
return tree
|
||||
|
||||
async def add_to_tree(
|
||||
self,
|
||||
parent_node_id: str,
|
||||
node_id: str,
|
||||
incoming: IncomingMessage,
|
||||
status_message_id: str,
|
||||
) -> tuple[MessageTree, MessageNode]:
|
||||
"""
|
||||
Add a reply as a child node to an existing tree.
|
||||
|
||||
Args:
|
||||
parent_node_id: ID of the parent message
|
||||
node_id: ID for the new node
|
||||
incoming: The incoming reply message
|
||||
status_message_id: Bot's status message ID
|
||||
|
||||
Returns:
|
||||
Tuple of (tree, new_node)
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._repository.has_node(parent_node_id):
|
||||
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
|
||||
|
||||
tree = self._repository.get_tree_for_node(parent_node_id)
|
||||
if not tree:
|
||||
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
|
||||
|
||||
# Add node (tree has its own lock) - outside manager lock to avoid deadlock
|
||||
node = await tree.add_node(
|
||||
node_id=node_id,
|
||||
incoming=incoming,
|
||||
status_message_id=status_message_id,
|
||||
parent_id=parent_node_id,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._repository.register_node(node_id, tree.root_id)
|
||||
|
||||
logger.info(f"Added node {node_id} to tree {tree.root_id}")
|
||||
return tree, node
|
||||
|
||||
def get_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""Get a tree by its root ID."""
|
||||
return self._repository.get_tree(root_id)
|
||||
|
||||
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
|
||||
"""Get the tree containing a given node."""
|
||||
return self._repository.get_tree_for_node(node_id)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get a node from any tree."""
|
||||
return self._repository.get_node(node_id)
|
||||
|
||||
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
|
||||
"""Resolve a message ID to the actual parent node ID."""
|
||||
return self._repository.resolve_parent_node_id(msg_id)
|
||||
|
||||
def is_tree_busy(self, root_id: str) -> bool:
|
||||
"""Check if a tree is currently processing."""
|
||||
return self._repository.is_tree_busy(root_id)
|
||||
|
||||
def is_node_tree_busy(self, node_id: str) -> bool:
|
||||
"""Check if the tree containing a node is busy."""
|
||||
return self._repository.is_node_tree_busy(node_id)
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
node_id: str,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> bool:
|
||||
"""
|
||||
Enqueue a node for processing.
|
||||
|
||||
If the tree is not busy, processing starts immediately.
|
||||
If busy, the message is queued.
|
||||
|
||||
Args:
|
||||
node_id: Node to process
|
||||
processor: Async function to process the node
|
||||
|
||||
Returns:
|
||||
True if queued, False if processing immediately
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
logger.error(f"No tree found for node {node_id}")
|
||||
return False
|
||||
|
||||
return await self._processor.enqueue_and_start(tree, node_id, processor)
|
||||
|
||||
def get_queue_size(self, node_id: str) -> int:
|
||||
"""Get queue size for the tree containing a node."""
|
||||
return self._repository.get_queue_size(node_id)
|
||||
|
||||
def get_pending_children(self, node_id: str) -> List[MessageNode]:
|
||||
"""Get all pending child nodes (recursively) of a given node."""
|
||||
return self._repository.get_pending_children(node_id)
|
||||
|
||||
async def mark_node_error(
|
||||
self,
|
||||
node_id: str,
|
||||
error_message: str,
|
||||
propagate_to_children: bool = True,
|
||||
) -> List[MessageNode]:
|
||||
"""
|
||||
Mark a node as ERROR and optionally propagate to pending children.
|
||||
|
||||
Args:
|
||||
node_id: The node to mark as error
|
||||
error_message: Error description
|
||||
propagate_to_children: If True, also mark pending children as error
|
||||
|
||||
Returns:
|
||||
List of all nodes marked as error (including children)
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
affected = []
|
||||
node = tree.get_node(node_id)
|
||||
if node:
|
||||
await tree.update_state(
|
||||
node_id, MessageState.ERROR, error_message=error_message
|
||||
)
|
||||
affected.append(node)
|
||||
|
||||
if propagate_to_children:
|
||||
pending_children = self._repository.get_pending_children(node_id)
|
||||
for child in pending_children:
|
||||
await tree.update_state(
|
||||
child.node_id,
|
||||
MessageState.ERROR,
|
||||
error_message=f"Parent failed: {error_message}",
|
||||
)
|
||||
affected.append(child)
|
||||
|
||||
return affected
|
||||
|
||||
def cancel_tree(self, root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all queued and in-progress messages in a tree.
|
||||
|
||||
Updates node states to ERROR and returns list of affected nodes
|
||||
that were actually active or in the current processing queue.
|
||||
"""
|
||||
tree = self._repository.get_tree(root_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
cancelled_nodes = []
|
||||
|
||||
# 1. Cancel running task
|
||||
if tree.cancel_current_task():
|
||||
current_id = tree.current_node_id
|
||||
if current_id:
|
||||
node = tree.get_node(current_id)
|
||||
if node and node.state not in (
|
||||
MessageState.COMPLETED,
|
||||
MessageState.ERROR,
|
||||
):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
cancelled_nodes.append(node)
|
||||
|
||||
# 2. Drain queue and mark nodes as cancelled
|
||||
queue_nodes = tree.drain_queue_and_mark_cancelled()
|
||||
cancelled_nodes.extend(queue_nodes)
|
||||
cancelled_ids = {n.node_id for n in cancelled_nodes}
|
||||
|
||||
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
|
||||
cleanup_count = 0
|
||||
for node in tree.all_nodes():
|
||||
if (
|
||||
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
|
||||
and node.node_id not in cancelled_ids
|
||||
):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Stale task cleaned up"
|
||||
cleanup_count += 1
|
||||
|
||||
tree.reset_processing_state()
|
||||
|
||||
if cancelled_nodes:
|
||||
logger.info(
|
||||
f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
|
||||
)
|
||||
if cleanup_count:
|
||||
logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")
|
||||
|
||||
return cancelled_nodes
|
||||
|
||||
async def cancel_node(self, node_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel a single node (queued or in-progress) without affecting other nodes.
|
||||
|
||||
- If the node is currently running, cancels the current asyncio task.
|
||||
- If the node is queued, removes it from the queue.
|
||||
- Marks the node as ERROR with "Cancelled by user".
|
||||
|
||||
Returns:
|
||||
List containing the cancelled node if it was cancellable, else empty list.
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
async with tree.with_lock():
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
||||
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
return []
|
||||
|
||||
if tree.is_current_node(node_id):
|
||||
self._processor.cancel_current(tree)
|
||||
|
||||
try:
|
||||
tree.remove_from_queue(node_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to remove node from queue; will rely on state=ERROR"
|
||||
)
|
||||
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
return [node]
|
||||
|
||||
async def cancel_all(self) -> List[MessageNode]:
|
||||
"""Cancel all messages in all trees (async wrapper)."""
|
||||
async with self._lock:
|
||||
return self.cancel_all_sync()
|
||||
|
||||
def cancel_all_sync(self) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all messages in all trees (synchronous/locked version).
|
||||
NOTE: Must be called with self._lock held.
|
||||
"""
|
||||
all_cancelled = []
|
||||
for root_id in list(self._repository.tree_ids()):
|
||||
all_cancelled.extend(self.cancel_tree(root_id))
|
||||
return all_cancelled
|
||||
|
||||
def cleanup_stale_nodes(self) -> int:
|
||||
"""
|
||||
Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.
|
||||
Used on startup to reconcile restored state.
|
||||
"""
|
||||
count = 0
|
||||
for tree in self._repository.all_trees():
|
||||
for node in tree.all_nodes():
|
||||
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Lost during server restart"
|
||||
count += 1
|
||||
if count:
|
||||
logger.info(f"Cleaned up {count} stale nodes during startup")
|
||||
return count
|
||||
|
||||
def get_tree_count(self) -> int:
|
||||
"""Get the number of active message trees."""
|
||||
return self._repository.tree_count()
|
||||
|
||||
def set_queue_update_callback(
|
||||
self,
|
||||
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Set callback for queue position updates."""
|
||||
self._processor.set_queue_update_callback(queue_update_callback)
|
||||
|
||||
def set_node_started_callback(
|
||||
self,
|
||||
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Set callback for when a queued node starts processing."""
|
||||
self._processor.set_node_started_callback(node_started_callback)
|
||||
|
||||
def register_node(self, node_id: str, root_id: str) -> None:
|
||||
"""Register a node ID to a tree (for external mapping)."""
|
||||
self._repository.register_node(node_id, root_id)
|
||||
|
||||
async def cancel_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
|
||||
|
||||
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
branch_ids = set(tree.get_descendants(branch_root_id))
|
||||
cancelled: List[MessageNode] = []
|
||||
|
||||
async with tree.with_lock():
|
||||
for nid in branch_ids:
|
||||
node = tree.get_node(nid)
|
||||
if not node or node.state in (
|
||||
MessageState.COMPLETED,
|
||||
MessageState.ERROR,
|
||||
):
|
||||
continue
|
||||
|
||||
if tree.is_current_node(nid):
|
||||
self._processor.cancel_current(tree)
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
cancelled.append(node)
|
||||
else:
|
||||
tree.remove_from_queue(nid)
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
cancelled.append(node)
|
||||
|
||||
if cancelled:
|
||||
logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
|
||||
return cancelled
|
||||
|
||||
async def remove_branch(
|
||||
self, branch_root_id: str
|
||||
) -> tuple[List[MessageNode], str, bool]:
|
||||
"""
|
||||
Remove a branch (subtree) from the tree.
|
||||
|
||||
If branch_root is the tree root, removes the entire tree.
|
||||
|
||||
Returns:
|
||||
(removed_nodes, root_id, removed_entire_tree)
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return ([], "", False)
|
||||
|
||||
root_id = tree.root_id
|
||||
|
||||
if branch_root_id == root_id:
|
||||
cancelled = self.cancel_tree(root_id)
|
||||
removed_tree = self._repository.remove_tree(root_id)
|
||||
if removed_tree:
|
||||
return (removed_tree.all_nodes(), root_id, True)
|
||||
return (cancelled, root_id, True)
|
||||
|
||||
async with tree.with_lock():
|
||||
removed = tree.remove_branch(branch_root_id)
|
||||
|
||||
self._repository.unregister_nodes([n.node_id for n in removed])
|
||||
return (removed, root_id, False)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize all trees."""
|
||||
return self._repository.to_dict()
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: dict,
|
||||
queue_update_callback: Optional[
|
||||
Callable[[MessageTree], Awaitable[None]]
|
||||
] = None,
|
||||
node_started_callback: Optional[
|
||||
Callable[[MessageTree, str], Awaitable[None]]
|
||||
] = None,
|
||||
) -> "TreeQueueManager":
|
||||
"""Deserialize from dictionary."""
|
||||
manager = cls(
|
||||
queue_update_callback=queue_update_callback,
|
||||
node_started_callback=node_started_callback,
|
||||
)
|
||||
manager._repository = TreeRepository.from_dict(data)
|
||||
return manager
|
||||
__all__ = ["TreeQueueManager", "MessageTree", "MessageNode", "MessageState"]
|
||||
|
|
|
|||
|
|
@ -1,168 +1,5 @@
|
|||
"""Repository for message tree data access.
|
||||
"""Backward-compatible re-export. Use messaging.trees.repository for new code."""
|
||||
|
||||
Provides data access layer for managing trees and node mappings.
|
||||
"""
|
||||
from .trees.repository import TreeRepository
|
||||
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .tree_data import MessageTree, MessageNode, MessageState
|
||||
|
||||
|
||||
class TreeRepository:
|
||||
"""
|
||||
Repository for message tree data access.
|
||||
|
||||
Manages the storage and lookup of trees and node-to-tree mappings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._trees: Dict[str, MessageTree] = {} # root_id -> tree
|
||||
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
|
||||
|
||||
def get_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""Get a tree by its root ID."""
|
||||
return self._trees.get(root_id)
|
||||
|
||||
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
|
||||
"""Get the tree containing a given node."""
|
||||
root_id = self._node_to_tree.get(node_id)
|
||||
if not root_id:
|
||||
return None
|
||||
return self._trees.get(root_id)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get a node from any tree."""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
return tree.get_node(node_id) if tree else None
|
||||
|
||||
def add_tree(self, root_id: str, tree: MessageTree) -> None:
|
||||
"""Add a new tree to the repository."""
|
||||
self._trees[root_id] = tree
|
||||
self._node_to_tree[root_id] = root_id
|
||||
logger.debug("TREE_REPO: add_tree root_id=%s", root_id)
|
||||
|
||||
def register_node(self, node_id: str, root_id: str) -> None:
|
||||
"""Register a node ID to a tree."""
|
||||
self._node_to_tree[node_id] = root_id
|
||||
logger.debug("TREE_REPO: register_node node_id=%s root_id=%s", node_id, root_id)
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node is registered in any tree."""
|
||||
return node_id in self._node_to_tree
|
||||
|
||||
def tree_count(self) -> int:
|
||||
"""Get the number of trees in the repository."""
|
||||
return len(self._trees)
|
||||
|
||||
def is_tree_busy(self, root_id: str) -> bool:
|
||||
"""Check if a tree is currently processing."""
|
||||
tree = self._trees.get(root_id)
|
||||
return tree.is_processing if tree else False
|
||||
|
||||
def is_node_tree_busy(self, node_id: str) -> bool:
|
||||
"""Check if the tree containing a node is busy."""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
return tree.is_processing if tree else False
|
||||
|
||||
def get_queue_size(self, node_id: str) -> int:
|
||||
"""Get queue size for the tree containing a node."""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
return tree.get_queue_size() if tree else 0
|
||||
|
||||
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a message ID to the actual parent node ID.
|
||||
|
||||
Handles the case where msg_id is a status message ID
|
||||
(which maps to the tree but isn't an actual node).
|
||||
|
||||
Returns:
|
||||
The node_id to use as parent, or None if not found
|
||||
"""
|
||||
tree = self.get_tree_for_node(msg_id)
|
||||
if not tree:
|
||||
return None
|
||||
|
||||
# Check if msg_id is an actual node
|
||||
if tree.has_node(msg_id):
|
||||
return msg_id
|
||||
|
||||
# Otherwise, it might be a status message - find the owning node
|
||||
node = tree.find_node_by_status_message(msg_id)
|
||||
if node:
|
||||
return node.node_id
|
||||
|
||||
return None
|
||||
|
||||
def get_pending_children(self, node_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Get all pending child nodes (recursively) of a given node.
|
||||
|
||||
Used for error propagation - when a node fails, its pending
|
||||
children should also be marked as failed.
|
||||
"""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
pending = []
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
||||
for child_id in node.children_ids:
|
||||
child = tree.get_node(child_id)
|
||||
if child and child.state == MessageState.PENDING:
|
||||
pending.append(child)
|
||||
# Recursively get children of pending children
|
||||
pending.extend(self.get_pending_children(child_id))
|
||||
|
||||
return pending
|
||||
|
||||
def all_trees(self) -> List[MessageTree]:
|
||||
"""Get all trees in the repository."""
|
||||
return list(self._trees.values())
|
||||
|
||||
def tree_ids(self) -> List[str]:
|
||||
"""Get all tree root IDs."""
|
||||
return list(self._trees.keys())
|
||||
|
||||
def unregister_nodes(self, node_ids: List[str]) -> None:
|
||||
"""Remove node IDs from the node-to-tree mapping."""
|
||||
for nid in node_ids:
|
||||
self._node_to_tree.pop(nid, None)
|
||||
|
||||
def remove_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""
|
||||
Remove a tree and all its node mappings from the repository.
|
||||
|
||||
Returns:
|
||||
The removed tree, or None if not found.
|
||||
"""
|
||||
tree = self._trees.pop(root_id, None)
|
||||
if not tree:
|
||||
return None
|
||||
for node in tree.all_nodes():
|
||||
self._node_to_tree.pop(node.node_id, None)
|
||||
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
|
||||
return tree
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize all trees."""
|
||||
return {
|
||||
"trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
|
||||
"node_to_tree": self._node_to_tree.copy(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "TreeRepository":
|
||||
"""Deserialize from dictionary."""
|
||||
from .tree_data import MessageTree
|
||||
|
||||
repo = cls()
|
||||
for root_id, tree_data in data.get("trees", {}).items():
|
||||
repo._trees[root_id] = MessageTree.from_dict(tree_data)
|
||||
repo._node_to_tree = data.get("node_to_tree", {})
|
||||
return repo
|
||||
__all__ = ["TreeRepository"]
|
||||
|
|
|
|||
11
messaging/trees/__init__.py
Normal file
11
messaging/trees/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Message tree data structures and queue management."""
|
||||
|
||||
from .data import MessageTree, MessageNode, MessageState
|
||||
from .queue_manager import TreeQueueManager
|
||||
|
||||
__all__ = [
|
||||
"TreeQueueManager",
|
||||
"MessageTree",
|
||||
"MessageNode",
|
||||
"MessageState",
|
||||
]
|
||||
441
messaging/trees/data.py
Normal file
441
messaging/trees/data.py
Normal file
|
|
@ -0,0 +1,441 @@
|
|||
"""Tree data structures for message queue.
|
||||
|
||||
Contains MessageState, MessageNode, and MessageTree classes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional, List, Any, cast
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..models import IncomingMessage
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageState(Enum):
|
||||
"""State of a message node in the tree."""
|
||||
|
||||
PENDING = "pending" # Queued, waiting to be processed
|
||||
IN_PROGRESS = "in_progress" # Currently being processed by Claude
|
||||
COMPLETED = "completed" # Processing finished successfully
|
||||
ERROR = "error" # Processing failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageNode:
|
||||
"""
|
||||
A node in the message tree.
|
||||
|
||||
Each node represents a single message and tracks:
|
||||
- Its relationship to parent/children
|
||||
- Its processing state
|
||||
- Claude session information
|
||||
"""
|
||||
|
||||
node_id: str # Unique ID (typically message_id)
|
||||
incoming: IncomingMessage # The original message
|
||||
status_message_id: str # Bot's status message ID
|
||||
state: MessageState = MessageState.PENDING
|
||||
parent_id: Optional[str] = None # Parent node ID (None for root)
|
||||
session_id: Optional[str] = None # Claude session ID (forked from parent)
|
||||
children_ids: List[str] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
context: Any = None # Additional context if needed
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
"incoming": {
|
||||
"text": self.incoming.text,
|
||||
"chat_id": self.incoming.chat_id,
|
||||
"user_id": self.incoming.user_id,
|
||||
"message_id": self.incoming.message_id,
|
||||
"platform": self.incoming.platform,
|
||||
"reply_to_message_id": self.incoming.reply_to_message_id,
|
||||
"username": self.incoming.username,
|
||||
},
|
||||
"status_message_id": self.status_message_id,
|
||||
"state": self.state.value,
|
||||
"parent_id": self.parent_id,
|
||||
"session_id": self.session_id,
|
||||
"children_ids": self.children_ids,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"error_message": self.error_message,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "MessageNode":
|
||||
"""Create from dictionary (JSON deserialization)."""
|
||||
incoming_data = data["incoming"]
|
||||
incoming = IncomingMessage(
|
||||
text=incoming_data["text"],
|
||||
chat_id=incoming_data["chat_id"],
|
||||
user_id=incoming_data["user_id"],
|
||||
message_id=incoming_data["message_id"],
|
||||
platform=incoming_data["platform"],
|
||||
reply_to_message_id=incoming_data.get("reply_to_message_id"),
|
||||
username=incoming_data.get("username"),
|
||||
)
|
||||
return cls(
|
||||
node_id=data["node_id"],
|
||||
incoming=incoming,
|
||||
status_message_id=data["status_message_id"],
|
||||
state=MessageState(data["state"]),
|
||||
parent_id=data.get("parent_id"),
|
||||
session_id=data.get("session_id"),
|
||||
children_ids=data.get("children_ids", []),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
completed_at=datetime.fromisoformat(data["completed_at"])
|
||||
if data.get("completed_at")
|
||||
else None,
|
||||
error_message=data.get("error_message"),
|
||||
)
|
||||
|
||||
|
||||
class MessageTree:
|
||||
"""
|
||||
A tree of message nodes with queue functionality.
|
||||
|
||||
Provides:
|
||||
- O(1) node lookup via hashmap
|
||||
- Per-tree message queue
|
||||
- Thread-safe operations via asyncio.Lock
|
||||
"""
|
||||
|
||||
def __init__(self, root_node: MessageNode):
|
||||
"""
|
||||
Initialize tree with a root node.
|
||||
|
||||
Args:
|
||||
root_node: The root message node
|
||||
"""
|
||||
self.root_id = root_node.node_id
|
||||
self._nodes: Dict[str, MessageNode] = {root_node.node_id: root_node}
|
||||
self._status_to_node: Dict[str, str] = {
|
||||
root_node.status_message_id: root_node.node_id
|
||||
}
|
||||
self._queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._lock = asyncio.Lock()
|
||||
self._is_processing = False
|
||||
self._current_node_id: Optional[str] = None
|
||||
self._current_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.debug(f"Created MessageTree with root {self.root_id}")
|
||||
|
||||
def set_current_task(self, task: Optional[asyncio.Task]) -> None:
|
||||
"""Set the current processing task. Caller must hold lock."""
|
||||
self._current_task = task
|
||||
|
||||
@property
|
||||
def is_processing(self) -> bool:
|
||||
"""Check if tree is currently processing a message."""
|
||||
return self._is_processing
|
||||
|
||||
async def add_node(
|
||||
self,
|
||||
node_id: str,
|
||||
incoming: IncomingMessage,
|
||||
status_message_id: str,
|
||||
parent_id: str,
|
||||
) -> MessageNode:
|
||||
"""
|
||||
Add a child node to the tree.
|
||||
|
||||
Args:
|
||||
node_id: Unique ID for the new node
|
||||
incoming: The incoming message
|
||||
status_message_id: Bot's status message ID
|
||||
parent_id: Parent node ID
|
||||
|
||||
Returns:
|
||||
The created MessageNode
|
||||
"""
|
||||
async with self._lock:
|
||||
if parent_id not in self._nodes:
|
||||
raise ValueError(f"Parent node {parent_id} not found in tree")
|
||||
|
||||
node = MessageNode(
|
||||
node_id=node_id,
|
||||
incoming=incoming,
|
||||
status_message_id=status_message_id,
|
||||
parent_id=parent_id,
|
||||
state=MessageState.PENDING,
|
||||
)
|
||||
|
||||
self._nodes[node_id] = node
|
||||
self._status_to_node[status_message_id] = node_id
|
||||
self._nodes[parent_id].children_ids.append(node_id)
|
||||
|
||||
logger.debug(f"Added node {node_id} as child of {parent_id}")
|
||||
return node
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get a node by ID (O(1) lookup)."""
|
||||
return self._nodes.get(node_id)
|
||||
|
||||
def get_root(self) -> MessageNode:
|
||||
"""Get the root node."""
|
||||
return self._nodes[self.root_id]
|
||||
|
||||
def get_children(self, node_id: str) -> List[MessageNode]:
|
||||
"""Get all child nodes of a given node."""
|
||||
node = self._nodes.get(node_id)
|
||||
if not node:
|
||||
return []
|
||||
return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]
|
||||
|
||||
def get_parent(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get the parent node."""
|
||||
node = self._nodes.get(node_id)
|
||||
if not node or not node.parent_id:
|
||||
return None
|
||||
return self._nodes.get(node.parent_id)
|
||||
|
||||
def get_parent_session_id(self, node_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the parent's session ID for forking.
|
||||
|
||||
Returns None for root nodes.
|
||||
"""
|
||||
parent = self.get_parent(node_id)
|
||||
return parent.session_id if parent else None
|
||||
|
||||
async def update_state(
|
||||
self,
|
||||
node_id: str,
|
||||
state: MessageState,
|
||||
session_id: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update a node's state."""
|
||||
async with self._lock:
|
||||
node = self._nodes.get(node_id)
|
||||
if not node:
|
||||
logger.warning(f"Node {node_id} not found for state update")
|
||||
return
|
||||
|
||||
node.state = state
|
||||
if session_id:
|
||||
node.session_id = session_id
|
||||
if error_message:
|
||||
node.error_message = error_message
|
||||
if state in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
logger.debug(f"Node {node_id} state -> {state.value}")
|
||||
|
||||
async def enqueue(self, node_id: str) -> int:
|
||||
"""
|
||||
Add a node to the processing queue.
|
||||
|
||||
Returns:
|
||||
Queue position (1-indexed)
|
||||
"""
|
||||
async with self._lock:
|
||||
await self._queue.put(node_id)
|
||||
position = self._queue.qsize()
|
||||
logger.debug(f"Enqueued node {node_id}, position {position}")
|
||||
return position
|
||||
|
||||
async def dequeue(self) -> Optional[str]:
|
||||
"""
|
||||
Get the next node ID from the queue.
|
||||
|
||||
Returns None if queue is empty.
|
||||
"""
|
||||
try:
|
||||
return self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
return None
|
||||
|
||||
async def get_queue_snapshot(self) -> List[str]:
|
||||
"""
|
||||
Get a snapshot of the current queue order.
|
||||
|
||||
Returns:
|
||||
List of node IDs in FIFO order.
|
||||
"""
|
||||
async with self._lock:
|
||||
# Read internal deque directly to avoid mutating queue state.
|
||||
# Drain/put approach would inflate _unfinished_tasks without task_done().
|
||||
queue_deque = cast(deque, getattr(self._queue, "_queue"))
|
||||
return list(queue_deque)
|
||||
|
||||
def get_queue_size(self) -> int:
|
||||
"""Get number of messages waiting in queue."""
|
||||
return self._queue.qsize()
|
||||
|
||||
def remove_from_queue(self, node_id: str) -> bool:
|
||||
"""
|
||||
Remove node_id from the internal queue if present.
|
||||
|
||||
Caller must hold the tree lock (e.g. via with_lock).
|
||||
Returns True if node was removed, False if not in queue.
|
||||
|
||||
Note: asyncio.Queue has no built-in remove; we filter via the internal
|
||||
deque. O(n) in queue size; acceptable for typical tree queue sizes.
|
||||
"""
|
||||
queue_deque = cast(deque, getattr(self._queue, "_queue"))
|
||||
if node_id not in queue_deque:
|
||||
return False
|
||||
object.__setattr__(
|
||||
self._queue, "_queue", deque(x for x in queue_deque if x != node_id)
|
||||
)
|
||||
return True
|
||||
|
||||
@asynccontextmanager
|
||||
async def with_lock(self):
|
||||
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
|
||||
async with self._lock:
|
||||
yield
|
||||
|
||||
def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None:
|
||||
"""Set processing state. Caller must hold lock for consistency with queue operations."""
|
||||
self._is_processing = is_processing
|
||||
self._current_node_id = node_id if is_processing else None
|
||||
|
||||
def clear_current_node(self) -> None:
|
||||
"""Clear the currently processing node ID. Caller must hold lock."""
|
||||
self._current_node_id = None
|
||||
|
||||
def is_current_node(self, node_id: str) -> bool:
|
||||
"""Check if node_id is the currently processing node."""
|
||||
return self._current_node_id == node_id
|
||||
|
||||
def put_queue_unlocked(self, node_id: str) -> None:
|
||||
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
|
||||
self._queue.put_nowait(node_id)
|
||||
|
||||
def cancel_current_task(self) -> bool:
|
||||
"""Cancel the currently running task. Returns True if a task was cancelled."""
|
||||
if self._current_task and not self._current_task.done():
|
||||
self._current_task.cancel()
|
||||
return True
|
||||
return False
|
||||
|
||||
def drain_queue_and_mark_cancelled(
|
||||
self, error_message: str = "Cancelled by user"
|
||||
) -> List["MessageNode"]:
|
||||
"""
|
||||
Drain the queue, mark each node as ERROR, and return affected nodes.
|
||||
Does not acquire lock; caller must ensure no concurrent queue access.
|
||||
"""
|
||||
nodes: List[MessageNode] = []
|
||||
while True:
|
||||
try:
|
||||
node_id = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
node = self._nodes.get(node_id)
|
||||
if node:
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = error_message
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def reset_processing_state(self) -> None:
|
||||
"""Reset processing flags after cancel/cleanup."""
|
||||
self._is_processing = False
|
||||
self._current_node_id = None
|
||||
|
||||
@property
|
||||
def current_node_id(self) -> Optional[str]:
|
||||
"""Get the ID of the node currently being processed."""
|
||||
return self._current_node_id
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize tree to dictionary."""
|
||||
return {
|
||||
"root_id": self.root_id,
|
||||
"nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "MessageTree":
|
||||
"""Deserialize tree from dictionary."""
|
||||
root_id = data["root_id"]
|
||||
nodes_data = data["nodes"]
|
||||
|
||||
# Create root node first
|
||||
root_node = MessageNode.from_dict(nodes_data[root_id])
|
||||
tree = cls(root_node)
|
||||
|
||||
# Add remaining nodes and build status->node index
|
||||
for node_id, node_data in nodes_data.items():
|
||||
if node_id != root_id:
|
||||
node = MessageNode.from_dict(node_data)
|
||||
tree._nodes[node_id] = node
|
||||
tree._status_to_node[node.status_message_id] = node_id
|
||||
|
||||
return tree
|
||||
|
||||
def all_nodes(self) -> List[MessageNode]:
|
||||
"""Get all nodes in the tree."""
|
||||
return list(self._nodes.values())
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node exists in this tree."""
|
||||
return node_id in self._nodes
|
||||
|
||||
def find_node_by_status_message(self, status_msg_id: str) -> Optional[MessageNode]:
|
||||
"""Find the node that has this status message ID (O(1) lookup)."""
|
||||
node_id = self._status_to_node.get(status_msg_id)
|
||||
return self._nodes.get(node_id) if node_id else None
|
||||
|
||||
def get_descendants(self, node_id: str) -> List[str]:
|
||||
"""
|
||||
Get node_id and all descendant IDs (subtree).
|
||||
|
||||
Returns:
|
||||
List of node IDs including the given node.
|
||||
"""
|
||||
if node_id not in self._nodes:
|
||||
return []
|
||||
result: List[str] = []
|
||||
stack = [node_id]
|
||||
while stack:
|
||||
nid = stack.pop()
|
||||
result.append(nid)
|
||||
node = self._nodes.get(nid)
|
||||
if node:
|
||||
stack.extend(node.children_ids)
|
||||
return result
|
||||
|
||||
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Remove a subtree (branch_root and all descendants) from the tree.
|
||||
|
||||
Updates parent's children_ids. Caller must hold lock for consistency.
|
||||
Does not acquire lock internally.
|
||||
|
||||
Returns:
|
||||
List of removed nodes.
|
||||
"""
|
||||
if branch_root_id not in self._nodes:
|
||||
return []
|
||||
|
||||
parent = self.get_parent(branch_root_id)
|
||||
removed = []
|
||||
for nid in self.get_descendants(branch_root_id):
|
||||
node = self._nodes.get(nid)
|
||||
if node:
|
||||
removed.append(node)
|
||||
del self._nodes[nid]
|
||||
del self._status_to_node[node.status_message_id]
|
||||
|
||||
if parent and branch_root_id in parent.children_ids:
|
||||
parent.children_ids = [
|
||||
c for c in parent.children_ids if c != branch_root_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
|
||||
return removed
|
||||
162
messaging/trees/processor.py
Normal file
162
messaging/trees/processor.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Async queue processor for message trees.
|
||||
|
||||
Handles the async processing lifecycle of tree nodes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Awaitable, Optional
|
||||
|
||||
from .data import MessageTree, MessageNode, MessageState
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TreeQueueProcessor:
|
||||
"""
|
||||
Handles async queue processing for a single tree.
|
||||
|
||||
Separates the async processing logic from the data management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_update_callback: Optional[
|
||||
Callable[[MessageTree], Awaitable[None]]
|
||||
] = None,
|
||||
node_started_callback: Optional[
|
||||
Callable[[MessageTree, str], Awaitable[None]]
|
||||
] = None,
|
||||
):
|
||||
self._queue_update_callback = queue_update_callback
|
||||
self._node_started_callback = node_started_callback
|
||||
|
||||
def set_queue_update_callback(
|
||||
self,
|
||||
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Update the callback used to refresh queue positions."""
|
||||
self._queue_update_callback = queue_update_callback
|
||||
|
||||
def set_node_started_callback(
|
||||
self,
|
||||
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Update the callback used when a queued node starts processing."""
|
||||
self._node_started_callback = node_started_callback
|
||||
|
||||
async def _notify_queue_updated(self, tree: MessageTree) -> None:
|
||||
"""Invoke queue update callback if set."""
|
||||
if not self._queue_update_callback:
|
||||
return
|
||||
try:
|
||||
await self._queue_update_callback(tree)
|
||||
except Exception as e:
|
||||
logger.warning(f"Queue update callback failed: {e}")
|
||||
|
||||
async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
|
||||
"""Invoke node started callback if set."""
|
||||
if not self._node_started_callback:
|
||||
return
|
||||
try:
|
||||
await self._node_started_callback(tree, node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Node started callback failed: {e}")
|
||||
|
||||
async def process_node(
|
||||
self,
|
||||
tree: MessageTree,
|
||||
node: MessageNode,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process a single node and then check the queue."""
|
||||
# Skip if already in terminal state (e.g. from error propagation)
|
||||
if node.state.value == MessageState.ERROR.value:
|
||||
logger.info(
|
||||
f"Skipping node {node.node_id} as it is already in state {node.state}"
|
||||
)
|
||||
# Still need to check for next messages
|
||||
await self._process_next(tree, processor)
|
||||
return
|
||||
|
||||
try:
|
||||
await processor(node.node_id, node)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Task for node {node.node_id} was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing node {node.node_id}: {e}")
|
||||
await tree.update_state(
|
||||
node.node_id, MessageState.ERROR, error_message=str(e)
|
||||
)
|
||||
finally:
|
||||
tree.clear_current_node()
|
||||
# Check if there are more messages in the queue
|
||||
await self._process_next(tree, processor)
|
||||
|
||||
async def _process_next(
|
||||
self,
|
||||
tree: MessageTree,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process the next message in queue, if any."""
|
||||
next_node_id = None
|
||||
node = None
|
||||
async with tree.with_lock():
|
||||
next_node_id = await tree.dequeue()
|
||||
|
||||
if not next_node_id:
|
||||
tree.set_processing_state(None, False)
|
||||
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
|
||||
return
|
||||
|
||||
tree.set_processing_state(next_node_id, True)
|
||||
logger.info(f"Processing next queued node {next_node_id}")
|
||||
|
||||
# Process next node (outside lock)
|
||||
node = tree.get_node(next_node_id)
|
||||
if node:
|
||||
tree.set_current_task(
|
||||
asyncio.create_task(self.process_node(tree, node, processor))
|
||||
)
|
||||
|
||||
# Notify that this node has started processing and refresh queue positions.
|
||||
if next_node_id:
|
||||
await self._notify_node_started(tree, next_node_id)
|
||||
await self._notify_queue_updated(tree)
|
||||
|
||||
async def enqueue_and_start(
|
||||
self,
|
||||
tree: MessageTree,
|
||||
node_id: str,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> bool:
|
||||
"""
|
||||
Enqueue a node or start processing immediately.
|
||||
|
||||
Args:
|
||||
tree: The message tree
|
||||
node_id: Node to process
|
||||
processor: Async function to process the node
|
||||
|
||||
Returns:
|
||||
True if queued, False if processing immediately
|
||||
"""
|
||||
async with tree.with_lock():
|
||||
if tree.is_processing:
|
||||
tree.put_queue_unlocked(node_id)
|
||||
queue_size = tree.get_queue_size()
|
||||
logger.info(f"Queued node {node_id}, position {queue_size}")
|
||||
return True
|
||||
else:
|
||||
tree.set_processing_state(node_id, True)
|
||||
|
||||
# Process outside the lock
|
||||
node = tree.get_node(node_id)
|
||||
if node:
|
||||
tree.set_current_task(
|
||||
asyncio.create_task(self.process_node(tree, node, processor))
|
||||
)
|
||||
return False
|
||||
|
||||
def cancel_current(self, tree: MessageTree) -> bool:
|
||||
"""Cancel the currently running task in a tree."""
|
||||
return tree.cancel_current_task()
|
||||
461
messaging/trees/queue_manager.py
Normal file
461
messaging/trees/queue_manager.py
Normal file
|
|
@ -0,0 +1,461 @@
|
|||
"""Tree-Based Message Queue Manager - Refactored.
|
||||
|
||||
Coordinates data access, async processing, and error handling.
|
||||
Uses TreeRepository for data, TreeQueueProcessor for async logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Awaitable, List, Optional
|
||||
|
||||
from ..models import IncomingMessage
|
||||
from .data import MessageState, MessageNode, MessageTree
|
||||
from .repository import TreeRepository
|
||||
from .processor import TreeQueueProcessor
|
||||
from loguru import logger
|
||||
|
||||
# Backward compatibility: re-export moved classes
|
||||
__all__ = [
|
||||
"TreeQueueManager",
|
||||
"MessageState",
|
||||
"MessageNode",
|
||||
"MessageTree",
|
||||
]
|
||||
|
||||
|
||||
class TreeQueueManager:
|
||||
"""
|
||||
Manages multiple message trees. Facade that coordinates components.
|
||||
|
||||
Each new conversation creates a new tree.
|
||||
Replies to existing messages add nodes to existing trees.
|
||||
|
||||
Components:
|
||||
- TreeRepository: Data access layer
|
||||
- TreeQueueProcessor: Async queue processing
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_update_callback: Optional[
|
||||
Callable[[MessageTree], Awaitable[None]]
|
||||
] = None,
|
||||
node_started_callback: Optional[
|
||||
Callable[[MessageTree, str], Awaitable[None]]
|
||||
] = None,
|
||||
):
|
||||
self._repository = TreeRepository()
|
||||
self._processor = TreeQueueProcessor(
|
||||
queue_update_callback=queue_update_callback,
|
||||
node_started_callback=node_started_callback,
|
||||
)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info("TreeQueueManager initialized")
|
||||
|
||||
async def create_tree(
|
||||
self,
|
||||
node_id: str,
|
||||
incoming: IncomingMessage,
|
||||
status_message_id: str,
|
||||
) -> MessageTree:
|
||||
"""
|
||||
Create a new tree with a root node.
|
||||
|
||||
Args:
|
||||
node_id: ID for the root node
|
||||
incoming: The incoming message
|
||||
status_message_id: Bot's status message ID
|
||||
|
||||
Returns:
|
||||
The created MessageTree
|
||||
"""
|
||||
async with self._lock:
|
||||
root_node = MessageNode(
|
||||
node_id=node_id,
|
||||
incoming=incoming,
|
||||
status_message_id=status_message_id,
|
||||
state=MessageState.PENDING,
|
||||
)
|
||||
|
||||
tree = MessageTree(root_node)
|
||||
self._repository.add_tree(node_id, tree)
|
||||
|
||||
logger.info(f"Created new tree with root {node_id}")
|
||||
return tree
|
||||
|
||||
async def add_to_tree(
|
||||
self,
|
||||
parent_node_id: str,
|
||||
node_id: str,
|
||||
incoming: IncomingMessage,
|
||||
status_message_id: str,
|
||||
) -> tuple[MessageTree, MessageNode]:
|
||||
"""
|
||||
Add a reply as a child node to an existing tree.
|
||||
|
||||
Args:
|
||||
parent_node_id: ID of the parent message
|
||||
node_id: ID for the new node
|
||||
incoming: The incoming reply message
|
||||
status_message_id: Bot's status message ID
|
||||
|
||||
Returns:
|
||||
Tuple of (tree, new_node)
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._repository.has_node(parent_node_id):
|
||||
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
|
||||
|
||||
tree = self._repository.get_tree_for_node(parent_node_id)
|
||||
if not tree:
|
||||
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
|
||||
|
||||
# Add node (tree has its own lock) - outside manager lock to avoid deadlock
|
||||
node = await tree.add_node(
|
||||
node_id=node_id,
|
||||
incoming=incoming,
|
||||
status_message_id=status_message_id,
|
||||
parent_id=parent_node_id,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._repository.register_node(node_id, tree.root_id)
|
||||
|
||||
logger.info(f"Added node {node_id} to tree {tree.root_id}")
|
||||
return tree, node
|
||||
|
||||
def get_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""Get a tree by its root ID."""
|
||||
return self._repository.get_tree(root_id)
|
||||
|
||||
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
|
||||
"""Get the tree containing a given node."""
|
||||
return self._repository.get_tree_for_node(node_id)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get a node from any tree."""
|
||||
return self._repository.get_node(node_id)
|
||||
|
||||
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
|
||||
"""Resolve a message ID to the actual parent node ID."""
|
||||
return self._repository.resolve_parent_node_id(msg_id)
|
||||
|
||||
def is_tree_busy(self, root_id: str) -> bool:
|
||||
"""Check if a tree is currently processing."""
|
||||
return self._repository.is_tree_busy(root_id)
|
||||
|
||||
def is_node_tree_busy(self, node_id: str) -> bool:
|
||||
"""Check if the tree containing a node is busy."""
|
||||
return self._repository.is_node_tree_busy(node_id)
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
node_id: str,
|
||||
processor: Callable[[str, MessageNode], Awaitable[None]],
|
||||
) -> bool:
|
||||
"""
|
||||
Enqueue a node for processing.
|
||||
|
||||
If the tree is not busy, processing starts immediately.
|
||||
If busy, the message is queued.
|
||||
|
||||
Args:
|
||||
node_id: Node to process
|
||||
processor: Async function to process the node
|
||||
|
||||
Returns:
|
||||
True if queued, False if processing immediately
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
logger.error(f"No tree found for node {node_id}")
|
||||
return False
|
||||
|
||||
return await self._processor.enqueue_and_start(tree, node_id, processor)
|
||||
|
||||
def get_queue_size(self, node_id: str) -> int:
|
||||
"""Get queue size for the tree containing a node."""
|
||||
return self._repository.get_queue_size(node_id)
|
||||
|
||||
def get_pending_children(self, node_id: str) -> List[MessageNode]:
|
||||
"""Get all pending child nodes (recursively) of a given node."""
|
||||
return self._repository.get_pending_children(node_id)
|
||||
|
||||
async def mark_node_error(
|
||||
self,
|
||||
node_id: str,
|
||||
error_message: str,
|
||||
propagate_to_children: bool = True,
|
||||
) -> List[MessageNode]:
|
||||
"""
|
||||
Mark a node as ERROR and optionally propagate to pending children.
|
||||
|
||||
Args:
|
||||
node_id: The node to mark as error
|
||||
error_message: Error description
|
||||
propagate_to_children: If True, also mark pending children as error
|
||||
|
||||
Returns:
|
||||
List of all nodes marked as error (including children)
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
affected = []
|
||||
node = tree.get_node(node_id)
|
||||
if node:
|
||||
await tree.update_state(
|
||||
node_id, MessageState.ERROR, error_message=error_message
|
||||
)
|
||||
affected.append(node)
|
||||
|
||||
if propagate_to_children:
|
||||
pending_children = self._repository.get_pending_children(node_id)
|
||||
for child in pending_children:
|
||||
await tree.update_state(
|
||||
child.node_id,
|
||||
MessageState.ERROR,
|
||||
error_message=f"Parent failed: {error_message}",
|
||||
)
|
||||
affected.append(child)
|
||||
|
||||
return affected
|
||||
|
||||
def cancel_tree(self, root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all queued and in-progress messages in a tree.
|
||||
|
||||
Updates node states to ERROR and returns list of affected nodes
|
||||
that were actually active or in the current processing queue.
|
||||
"""
|
||||
tree = self._repository.get_tree(root_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
cancelled_nodes = []
|
||||
|
||||
# 1. Cancel running task
|
||||
if tree.cancel_current_task():
|
||||
current_id = tree.current_node_id
|
||||
if current_id:
|
||||
node = tree.get_node(current_id)
|
||||
if node and node.state not in (
|
||||
MessageState.COMPLETED,
|
||||
MessageState.ERROR,
|
||||
):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
cancelled_nodes.append(node)
|
||||
|
||||
# 2. Drain queue and mark nodes as cancelled
|
||||
queue_nodes = tree.drain_queue_and_mark_cancelled()
|
||||
cancelled_nodes.extend(queue_nodes)
|
||||
cancelled_ids = {n.node_id for n in cancelled_nodes}
|
||||
|
||||
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
|
||||
cleanup_count = 0
|
||||
for node in tree.all_nodes():
|
||||
if (
|
||||
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
|
||||
and node.node_id not in cancelled_ids
|
||||
):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Stale task cleaned up"
|
||||
cleanup_count += 1
|
||||
|
||||
tree.reset_processing_state()
|
||||
|
||||
if cancelled_nodes:
|
||||
logger.info(
|
||||
f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
|
||||
)
|
||||
if cleanup_count:
|
||||
logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")
|
||||
|
||||
return cancelled_nodes
|
||||
|
||||
async def cancel_node(self, node_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel a single node (queued or in-progress) without affecting other nodes.
|
||||
|
||||
- If the node is currently running, cancels the current asyncio task.
|
||||
- If the node is queued, removes it from the queue.
|
||||
- Marks the node as ERROR with "Cancelled by user".
|
||||
|
||||
Returns:
|
||||
List containing the cancelled node if it was cancellable, else empty list.
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
async with tree.with_lock():
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
||||
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
return []
|
||||
|
||||
if tree.is_current_node(node_id):
|
||||
self._processor.cancel_current(tree)
|
||||
|
||||
try:
|
||||
tree.remove_from_queue(node_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to remove node from queue; will rely on state=ERROR"
|
||||
)
|
||||
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
return [node]
|
||||
|
||||
async def cancel_all(self) -> List[MessageNode]:
|
||||
"""Cancel all messages in all trees (async wrapper)."""
|
||||
async with self._lock:
|
||||
return self.cancel_all_sync()
|
||||
|
||||
def cancel_all_sync(self) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all messages in all trees (synchronous/locked version).
|
||||
NOTE: Must be called with self._lock held.
|
||||
"""
|
||||
all_cancelled = []
|
||||
for root_id in list(self._repository.tree_ids()):
|
||||
all_cancelled.extend(self.cancel_tree(root_id))
|
||||
return all_cancelled
|
||||
|
||||
def cleanup_stale_nodes(self) -> int:
|
||||
"""
|
||||
Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.
|
||||
Used on startup to reconcile restored state.
|
||||
"""
|
||||
count = 0
|
||||
for tree in self._repository.all_trees():
|
||||
for node in tree.all_nodes():
|
||||
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Lost during server restart"
|
||||
count += 1
|
||||
if count:
|
||||
logger.info(f"Cleaned up {count} stale nodes during startup")
|
||||
return count
|
||||
|
||||
def get_tree_count(self) -> int:
|
||||
"""Get the number of active message trees."""
|
||||
return self._repository.tree_count()
|
||||
|
||||
def set_queue_update_callback(
|
||||
self,
|
||||
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Set callback for queue position updates."""
|
||||
self._processor.set_queue_update_callback(queue_update_callback)
|
||||
|
||||
def set_node_started_callback(
|
||||
self,
|
||||
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
|
||||
) -> None:
|
||||
"""Set callback for when a queued node starts processing."""
|
||||
self._processor.set_node_started_callback(node_started_callback)
|
||||
|
||||
def register_node(self, node_id: str, root_id: str) -> None:
|
||||
"""Register a node ID to a tree (for external mapping)."""
|
||||
self._repository.register_node(node_id, root_id)
|
||||
|
||||
async def cancel_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
|
||||
|
||||
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
branch_ids = set(tree.get_descendants(branch_root_id))
|
||||
cancelled: List[MessageNode] = []
|
||||
|
||||
async with tree.with_lock():
|
||||
for nid in branch_ids:
|
||||
node = tree.get_node(nid)
|
||||
if not node or node.state in (
|
||||
MessageState.COMPLETED,
|
||||
MessageState.ERROR,
|
||||
):
|
||||
continue
|
||||
|
||||
if tree.is_current_node(nid):
|
||||
self._processor.cancel_current(tree)
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
cancelled.append(node)
|
||||
else:
|
||||
tree.remove_from_queue(nid)
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
cancelled.append(node)
|
||||
|
||||
if cancelled:
|
||||
logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
|
||||
return cancelled
|
||||
|
||||
async def remove_branch(
|
||||
self, branch_root_id: str
|
||||
) -> tuple[List[MessageNode], str, bool]:
|
||||
"""
|
||||
Remove a branch (subtree) from the tree.
|
||||
|
||||
If branch_root is the tree root, removes the entire tree.
|
||||
|
||||
Returns:
|
||||
(removed_nodes, root_id, removed_entire_tree)
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return ([], "", False)
|
||||
|
||||
root_id = tree.root_id
|
||||
|
||||
if branch_root_id == root_id:
|
||||
cancelled = self.cancel_tree(root_id)
|
||||
removed_tree = self._repository.remove_tree(root_id)
|
||||
if removed_tree:
|
||||
return (removed_tree.all_nodes(), root_id, True)
|
||||
return (cancelled, root_id, True)
|
||||
|
||||
async with tree.with_lock():
|
||||
removed = tree.remove_branch(branch_root_id)
|
||||
|
||||
self._repository.unregister_nodes([n.node_id for n in removed])
|
||||
return (removed, root_id, False)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize all trees."""
|
||||
return self._repository.to_dict()
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: dict,
|
||||
queue_update_callback: Optional[
|
||||
Callable[[MessageTree], Awaitable[None]]
|
||||
] = None,
|
||||
node_started_callback: Optional[
|
||||
Callable[[MessageTree, str], Awaitable[None]]
|
||||
] = None,
|
||||
) -> "TreeQueueManager":
|
||||
"""Deserialize from dictionary."""
|
||||
manager = cls(
|
||||
queue_update_callback=queue_update_callback,
|
||||
node_started_callback=node_started_callback,
|
||||
)
|
||||
manager._repository = TreeRepository.from_dict(data)
|
||||
return manager
|
||||
168
messaging/trees/repository.py
Normal file
168
messaging/trees/repository.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Repository for message tree data access.
|
||||
|
||||
Provides data access layer for managing trees and node mappings.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .data import MessageTree, MessageNode, MessageState
|
||||
|
||||
|
||||
class TreeRepository:
|
||||
"""
|
||||
Repository for message tree data access.
|
||||
|
||||
Manages the storage and lookup of trees and node-to-tree mappings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._trees: Dict[str, MessageTree] = {} # root_id -> tree
|
||||
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
|
||||
|
||||
def get_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""Get a tree by its root ID."""
|
||||
return self._trees.get(root_id)
|
||||
|
||||
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
|
||||
"""Get the tree containing a given node."""
|
||||
root_id = self._node_to_tree.get(node_id)
|
||||
if not root_id:
|
||||
return None
|
||||
return self._trees.get(root_id)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[MessageNode]:
|
||||
"""Get a node from any tree."""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
return tree.get_node(node_id) if tree else None
|
||||
|
||||
def add_tree(self, root_id: str, tree: MessageTree) -> None:
|
||||
"""Add a new tree to the repository."""
|
||||
self._trees[root_id] = tree
|
||||
self._node_to_tree[root_id] = root_id
|
||||
logger.debug("TREE_REPO: add_tree root_id=%s", root_id)
|
||||
|
||||
def register_node(self, node_id: str, root_id: str) -> None:
|
||||
"""Register a node ID to a tree."""
|
||||
self._node_to_tree[node_id] = root_id
|
||||
logger.debug("TREE_REPO: register_node node_id=%s root_id=%s", node_id, root_id)
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node is registered in any tree."""
|
||||
return node_id in self._node_to_tree
|
||||
|
||||
def tree_count(self) -> int:
|
||||
"""Get the number of trees in the repository."""
|
||||
return len(self._trees)
|
||||
|
||||
def is_tree_busy(self, root_id: str) -> bool:
|
||||
"""Check if a tree is currently processing."""
|
||||
tree = self._trees.get(root_id)
|
||||
return tree.is_processing if tree else False
|
||||
|
||||
def is_node_tree_busy(self, node_id: str) -> bool:
|
||||
"""Check if the tree containing a node is busy."""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
return tree.is_processing if tree else False
|
||||
|
||||
def get_queue_size(self, node_id: str) -> int:
|
||||
"""Get queue size for the tree containing a node."""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
return tree.get_queue_size() if tree else 0
|
||||
|
||||
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a message ID to the actual parent node ID.
|
||||
|
||||
Handles the case where msg_id is a status message ID
|
||||
(which maps to the tree but isn't an actual node).
|
||||
|
||||
Returns:
|
||||
The node_id to use as parent, or None if not found
|
||||
"""
|
||||
tree = self.get_tree_for_node(msg_id)
|
||||
if not tree:
|
||||
return None
|
||||
|
||||
# Check if msg_id is an actual node
|
||||
if tree.has_node(msg_id):
|
||||
return msg_id
|
||||
|
||||
# Otherwise, it might be a status message - find the owning node
|
||||
node = tree.find_node_by_status_message(msg_id)
|
||||
if node:
|
||||
return node.node_id
|
||||
|
||||
return None
|
||||
|
||||
def get_pending_children(self, node_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Get all pending child nodes (recursively) of a given node.
|
||||
|
||||
Used for error propagation - when a node fails, its pending
|
||||
children should also be marked as failed.
|
||||
"""
|
||||
tree = self.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
pending = []
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
||||
for child_id in node.children_ids:
|
||||
child = tree.get_node(child_id)
|
||||
if child and child.state == MessageState.PENDING:
|
||||
pending.append(child)
|
||||
# Recursively get children of pending children
|
||||
pending.extend(self.get_pending_children(child_id))
|
||||
|
||||
return pending
|
||||
|
||||
def all_trees(self) -> List[MessageTree]:
|
||||
"""Get all trees in the repository."""
|
||||
return list(self._trees.values())
|
||||
|
||||
def tree_ids(self) -> List[str]:
|
||||
"""Get all tree root IDs."""
|
||||
return list(self._trees.keys())
|
||||
|
||||
def unregister_nodes(self, node_ids: List[str]) -> None:
|
||||
"""Remove node IDs from the node-to-tree mapping."""
|
||||
for nid in node_ids:
|
||||
self._node_to_tree.pop(nid, None)
|
||||
|
||||
def remove_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""
|
||||
Remove a tree and all its node mappings from the repository.
|
||||
|
||||
Returns:
|
||||
The removed tree, or None if not found.
|
||||
"""
|
||||
tree = self._trees.pop(root_id, None)
|
||||
if not tree:
|
||||
return None
|
||||
for node in tree.all_nodes():
|
||||
self._node_to_tree.pop(node.node_id, None)
|
||||
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
|
||||
return tree
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize all trees."""
|
||||
return {
|
||||
"trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
|
||||
"node_to_tree": self._node_to_tree.copy(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "TreeRepository":
|
||||
"""Deserialize from dictionary."""
|
||||
from .data import MessageTree
|
||||
|
||||
repo = cls()
|
||||
for root_id, tree_data in data.get("trees", {}).items():
|
||||
repo._trees[root_id] = MessageTree.from_dict(tree_data)
|
||||
repo._node_to_tree = data.get("node_to_tree", {})
|
||||
return repo
|
||||
|
|
@ -11,9 +11,10 @@ class TestCreateMessagingPlatform:
|
|||
def test_telegram_with_token(self):
|
||||
"""Create Telegram platform when bot_token is provided."""
|
||||
mock_platform = MagicMock()
|
||||
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
with patch(
|
||||
"messaging.telegram.TelegramPlatform", return_value=mock_platform
|
||||
"messaging.platforms.telegram.TelegramPlatform",
|
||||
return_value=mock_platform,
|
||||
):
|
||||
result = create_messaging_platform(
|
||||
"telegram",
|
||||
|
|
@ -36,8 +37,11 @@ class TestCreateMessagingPlatform:
|
|||
def test_discord_with_token(self):
|
||||
"""Create Discord platform when discord_bot_token is provided."""
|
||||
mock_platform = MagicMock()
|
||||
with patch("messaging.discord.DISCORD_AVAILABLE", True):
|
||||
with patch("messaging.discord.DiscordPlatform", return_value=mock_platform):
|
||||
with patch("messaging.platforms.discord.DISCORD_AVAILABLE", True):
|
||||
with patch(
|
||||
"messaging.platforms.discord.DiscordPlatform",
|
||||
return_value=mock_platform,
|
||||
):
|
||||
result = create_messaging_platform(
|
||||
"discord",
|
||||
discord_bot_token="test_token",
|
||||
Loading…
Add table
Add a link
Reference in a new issue