Phase 7: Directory restructuring (messaging/ and tests/)

- Create messaging/platforms/ (base, discord, telegram, factory)
- Create messaging/rendering/ (discord_markdown, telegram_markdown)
- Create messaging/trees/ (data, repository, processor, queue_manager)
- Organize tests/ into api/, providers/, messaging/, cli/, config/
- Add backward-compatible re-exports at old locations
- Update handler.py and test_messaging_factory.py imports
- Fix Telegram type hints for TELEGRAM_AVAILABLE=False case
- Fix Python 3 except syntax in discord_markdown

Co-authored-by: Ali Khokhar <alishahryar2@gmail.com>
This commit is contained in:
Cursor Agent 2026-02-17 02:25:42 +00:00
parent 38a7980546
commit 4b4f87515d
76 changed files with 3294 additions and 3124 deletions

View file

@ -1,219 +1,9 @@
"""Abstract base class for messaging platforms."""
"""Backward-compatible re-export. Use messaging.platforms.base for new code."""
from abc import ABC, abstractmethod
from typing import (
Callable,
Awaitable,
Optional,
Protocol,
Tuple,
runtime_checkable,
AsyncGenerator,
Any,
Dict,
from .platforms.base import (
MessagingPlatform,
SessionManagerInterface,
CLISession,
)
from .models import IncomingMessage
@runtime_checkable
class CLISession(Protocol):
"""Protocol for CLI session - avoid circular import from cli package."""
def start_task(
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
) -> AsyncGenerator[Dict, Any]:
"""Start a task in the CLI session."""
...
@property
@abstractmethod
def is_busy(self) -> bool:
"""Check if session is busy."""
pass
@runtime_checkable
class SessionManagerInterface(Protocol):
"""
Protocol for session managers to avoid tight coupling with cli package.
Implementations: CLISessionManager
"""
async def get_or_create_session(
self, session_id: Optional[str] = None
) -> Tuple[CLISession, str, bool]:
"""
Get an existing session or create a new one.
Returns: Tuple of (session, session_id, is_new_session)
"""
...
async def register_real_session_id(
self, temp_id: str, real_session_id: str
) -> bool:
"""Register the real session ID from CLI output."""
...
async def stop_all(self) -> None:
"""Stop all sessions."""
...
async def remove_session(self, session_id: str) -> bool:
"""Remove a session from the manager."""
...
def get_stats(self) -> dict:
"""Get session statistics."""
...
class MessagingPlatform(ABC):
"""
Base class for all messaging platform adapters.
Implement this to add support for Telegram, Discord, Slack, etc.
"""
name: str = "base"
@abstractmethod
async def start(self) -> None:
"""Initialize and connect to the messaging platform."""
pass
@abstractmethod
async def stop(self) -> None:
"""Disconnect and cleanup resources."""
pass
@abstractmethod
async def send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
) -> str:
"""
Send a message to a chat.
Args:
chat_id: The chat/channel ID to send to
text: Message content
reply_to: Optional message ID to reply to
parse_mode: Optional formatting mode ("markdown", "html")
Returns:
The message ID of the sent message
"""
pass
@abstractmethod
async def edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
) -> None:
"""
Edit an existing message.
Args:
chat_id: The chat/channel ID
message_id: The message ID to edit
text: New message content
parse_mode: Optional formatting mode
"""
pass
@abstractmethod
async def delete_message(
self,
chat_id: str,
message_id: str,
) -> None:
"""
Delete a message from a chat.
Args:
chat_id: The chat/channel ID
message_id: The message ID to delete
"""
pass
@abstractmethod
async def queue_send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> Optional[str]:
"""
Enqueue a message to be sent.
If fire_and_forget is True, returns None immediately.
Otherwise, waits for the rate limiter and returns message ID.
"""
pass
@abstractmethod
async def queue_edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> None:
"""
Enqueue a message edit.
If fire_and_forget is True, returns immediately.
Otherwise, waits for the rate limiter.
"""
pass
@abstractmethod
async def queue_delete_message(
self,
chat_id: str,
message_id: str,
fire_and_forget: bool = True,
) -> None:
"""
Enqueue a message deletion.
If fire_and_forget is True, returns immediately.
Otherwise, waits for the rate limiter.
"""
pass
@abstractmethod
def on_message(
self,
handler: Callable[[IncomingMessage], Awaitable[None]],
) -> None:
"""
Register a message handler callback.
The handler will be called for each incoming message.
Args:
handler: Async function that processes incoming messages
"""
pass
@abstractmethod
def fire_and_forget(self, task: Awaitable[Any]) -> None:
"""Execute a coroutine without awaiting it."""
pass
@property
def is_connected(self) -> bool:
"""Check if the platform is connected."""
return False
__all__ = ["MessagingPlatform", "SessionManagerInterface", "CLISession"]

View file

@ -1,394 +1,9 @@
"""
Discord Platform Adapter
"""Backward-compatible re-export. Use messaging.platforms.discord for new code."""
Implements MessagingPlatform for Discord using discord.py.
"""
from .platforms.discord import (
DiscordPlatform,
DISCORD_AVAILABLE,
DISCORD_MESSAGE_LIMIT,
)
import asyncio
import os
from typing import Callable, Awaitable, Optional, Any, Set, cast
from loguru import logger
from .base import MessagingPlatform
from .models import IncomingMessage
from .discord_markdown import format_status_discord
_discord_module: Any = None
try:
import discord as _discord_import
_discord_module = _discord_import
DISCORD_AVAILABLE = True
except ImportError:
DISCORD_AVAILABLE = False
DISCORD_MESSAGE_LIMIT = 2000
def _get_discord() -> Any:
"""Return the discord module. Raises if not available."""
if not DISCORD_AVAILABLE or _discord_module is None:
raise ImportError(
"discord.py is required. Install with: pip install discord.py"
)
return _discord_module
def _parse_allowed_channels(raw: Optional[str]) -> Set[str]:
"""Parse comma-separated channel IDs into a set of strings."""
if not raw or not raw.strip():
return set()
return {s.strip() for s in raw.split(",") if s.strip()}
if DISCORD_AVAILABLE and _discord_module is not None:
_discord = _discord_module
class _DiscordClient(_discord.Client):
"""Internal Discord client that forwards events to DiscordPlatform."""
def __init__(
self,
platform: "DiscordPlatform",
intents: _discord.Intents,
) -> None:
super().__init__(intents=intents)
self._platform = platform
async def on_ready(self) -> None:
"""Called when the bot is ready."""
self._platform._connected = True
logger.info("Discord platform connected")
async def on_message(self, message: Any) -> None:
"""Handle incoming Discord messages."""
await self._platform._on_discord_message(message)
else:
_DiscordClient = None
class DiscordPlatform(MessagingPlatform):
"""
Discord messaging platform adapter.
Uses discord.py for Discord access.
Requires a Bot Token from Discord Developer Portal and message_content intent.
"""
name = "discord"
def __init__(
self,
bot_token: Optional[str] = None,
allowed_channel_ids: Optional[str] = None,
):
if not DISCORD_AVAILABLE:
raise ImportError(
"discord.py is required. Install with: pip install discord.py"
)
self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
if not self.bot_token:
logger.warning("DISCORD_BOT_TOKEN not set")
discord = _get_discord()
intents = discord.Intents.default()
intents.message_content = True
assert _DiscordClient is not None
self._client = _DiscordClient(self, intents)
self._message_handler: Optional[
Callable[[IncomingMessage], Awaitable[None]]
] = None
self._connected = False
self._limiter: Optional[Any] = None
self._start_task: Optional[asyncio.Task] = None
async def _on_discord_message(self, message: Any) -> None:
"""Handle incoming Discord messages."""
if message.author.bot:
return
if not message.content:
return
channel_id = str(message.channel.id)
if not self.allowed_channel_ids or channel_id not in self.allowed_channel_ids:
return
user_id = str(message.author.id)
message_id = str(message.id)
reply_to = (
str(message.reference.message_id)
if message.reference and message.reference.message_id
else None
)
text_preview = (message.content or "")[:80]
if len(message.content or "") > 80:
text_preview += "..."
logger.info(
"DISCORD_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
channel_id,
message_id,
reply_to,
text_preview,
)
if not self._message_handler:
return
incoming = IncomingMessage(
text=message.content,
chat_id=channel_id,
user_id=user_id,
message_id=message_id,
platform="discord",
reply_to_message_id=reply_to,
username=message.author.display_name,
raw_event=message,
)
try:
await self._message_handler(incoming)
except Exception as e:
logger.error(f"Error handling message: {e}")
try:
await self.send_message(
channel_id,
format_status_discord("Error:", str(e)[:200]),
reply_to=message_id,
)
except Exception:
pass
def _truncate(self, text: str, limit: int = DISCORD_MESSAGE_LIMIT) -> str:
"""Truncate text to Discord's message limit."""
if len(text) <= limit:
return text
return text[: limit - 3] + "..."
async def start(self) -> None:
"""Initialize and connect to Discord."""
if not self.bot_token:
raise ValueError("DISCORD_BOT_TOKEN is required")
from .limiter import MessagingRateLimiter
self._limiter = await MessagingRateLimiter.get_instance()
self._start_task = asyncio.create_task(
self._client.start(self.bot_token),
name="discord-client-start",
)
max_wait = 30
waited = 0
while not self._connected and waited < max_wait:
await asyncio.sleep(0.5)
waited += 0.5
if not self._connected:
raise RuntimeError("Discord client failed to connect within timeout")
logger.info("Discord platform started")
async def stop(self) -> None:
"""Stop the bot."""
if self._client.is_closed():
self._connected = False
return
await self._client.close()
if self._start_task and not self._start_task.done():
try:
await asyncio.wait_for(self._start_task, timeout=5.0)
except asyncio.TimeoutError, asyncio.CancelledError:
self._start_task.cancel()
try:
await self._start_task
except asyncio.CancelledError:
pass
self._connected = False
logger.info("Discord platform stopped")
async def send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
) -> str:
"""Send a message to a channel."""
channel = self._client.get_channel(int(chat_id))
if not channel or not hasattr(channel, "send"):
raise RuntimeError(f"Channel {chat_id} not found")
text = self._truncate(text)
channel = cast(Any, channel)
discord = _get_discord()
if reply_to:
ref = discord.MessageReference(
message_id=int(reply_to),
channel_id=int(chat_id),
)
msg = await channel.send(content=text, reference=ref)
else:
msg = await channel.send(content=text)
return str(msg.id)
async def edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
) -> None:
"""Edit an existing message."""
channel = self._client.get_channel(int(chat_id))
if not channel or not hasattr(channel, "fetch_message"):
raise RuntimeError(f"Channel {chat_id} not found")
discord = _get_discord()
channel = cast(Any, channel)
try:
msg = await channel.fetch_message(int(message_id))
except discord.NotFound:
return
text = self._truncate(text)
await msg.edit(content=text)
async def delete_message(
self,
chat_id: str,
message_id: str,
) -> None:
"""Delete a message from a channel."""
channel = self._client.get_channel(int(chat_id))
if not channel or not hasattr(channel, "fetch_message"):
return
discord = _get_discord()
channel = cast(Any, channel)
try:
msg = await channel.fetch_message(int(message_id))
await msg.delete()
except discord.NotFound, discord.Forbidden:
pass
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
"""Delete multiple messages (best-effort)."""
for mid in message_ids:
await self.delete_message(chat_id, mid)
async def queue_send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> Optional[str]:
"""Enqueue a message to be sent."""
if not self._limiter:
return await self.send_message(chat_id, text, reply_to, parse_mode)
async def _send():
return await self.send_message(chat_id, text, reply_to, parse_mode)
if fire_and_forget:
self._limiter.fire_and_forget(_send)
return None
return await self._limiter.enqueue(_send)
async def queue_edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message edit."""
if not self._limiter:
await self.edit_message(chat_id, message_id, text, parse_mode)
return
async def _edit():
await self.edit_message(chat_id, message_id, text, parse_mode)
dedup_key = f"edit:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
async def queue_delete_message(
self,
chat_id: str,
message_id: str,
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message delete."""
if not self._limiter:
await self.delete_message(chat_id, message_id)
return
async def _delete():
await self.delete_message(chat_id, message_id)
dedup_key = f"del:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
async def queue_delete_messages(
self,
chat_id: str,
message_ids: list[str],
fire_and_forget: bool = True,
) -> None:
"""Enqueue a bulk delete."""
if not message_ids:
return
if not self._limiter:
await self.delete_messages(chat_id, message_ids)
return
async def _bulk():
await self.delete_messages(chat_id, message_ids)
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
if fire_and_forget:
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
def fire_and_forget(self, task: Awaitable[Any]) -> None:
"""Execute a coroutine without awaiting it."""
if asyncio.iscoroutine(task):
asyncio.create_task(task)
else:
asyncio.ensure_future(task)
def on_message(
self,
handler: Callable[[IncomingMessage], Awaitable[None]],
) -> None:
"""Register a message handler callback."""
self._message_handler = handler
@property
def is_connected(self) -> bool:
"""Check if connected."""
return self._connected
__all__ = ["DiscordPlatform", "DISCORD_AVAILABLE", "DISCORD_MESSAGE_LIMIT"]

View file

@ -1,367 +1,14 @@
"""Discord markdown utilities.
Discord uses standard markdown: **bold**, *italic*, `code`, ```code block```.
Used by the message handler and Discord platform adapter.
"""
import re
from typing import List, Optional
from markdown_it import MarkdownIt
# Discord escapes: \ * _ ` ~ | >
DISCORD_SPECIAL = set("\\*_`~|>")
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
_MD.enable("strikethrough")
_MD.enable("table")
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
_FENCE_RE = re.compile(r"^\s*```")
def _is_gfm_table_header_line(line: str) -> bool:
"""Check if line is a GFM table header."""
if "|" not in line:
return False
if _TABLE_SEP_RE.match(line):
return False
stripped = line.strip()
parts = [p.strip() for p in stripped.strip("|").split("|")]
parts = [p for p in parts if p != ""]
return len(parts) >= 2
def _normalize_gfm_tables(text: str) -> str:
"""Insert blank line before detected tables outside code blocks."""
lines = text.splitlines()
if len(lines) < 2:
return text
out_lines: List[str] = []
in_fence = False
for idx, line in enumerate(lines):
if _FENCE_RE.match(line):
in_fence = not in_fence
out_lines.append(line)
continue
if (
not in_fence
and idx + 1 < len(lines)
and _is_gfm_table_header_line(line)
and _TABLE_SEP_RE.match(lines[idx + 1])
):
if out_lines and out_lines[-1].strip() != "":
m = re.match(r"^(\s*)", line)
indent = m.group(1) if m else ""
out_lines.append(indent)
out_lines.append(line)
return "\n".join(out_lines)
def escape_discord(text: str) -> str:
"""Escape text for Discord markdown (bold, italic, etc.)."""
return "".join(f"\\{ch}" if ch in DISCORD_SPECIAL else ch for ch in text)
def escape_discord_code(text: str) -> str:
"""Escape text for Discord code spans/blocks."""
return text.replace("\\", "\\\\").replace("`", "\\`")
def discord_bold(text: str) -> str:
"""Format text as bold in Discord (uses **)."""
return f"**{escape_discord(text)}**"
def discord_code_inline(text: str) -> str:
"""Format text as inline code in Discord."""
return f"`{escape_discord_code(text)}`"
def format_status_discord(label: str, suffix: Optional[str] = None) -> str:
"""Format a status message for Discord (label in bold, optional suffix)."""
base = discord_bold(label)
if suffix:
return f"{base} {escape_discord(suffix)}"
return base
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
"""Format a status message with emoji for Discord (matches Telegram API)."""
base = f"{emoji} {discord_bold(label)}"
if suffix:
return f"{base} {escape_discord(suffix)}"
return base
def render_markdown_to_discord(text: str) -> str:
"""Render common Markdown into Discord-compatible format."""
if not text:
return ""
text = _normalize_gfm_tables(text)
tokens = _MD.parse(text)
def render_inline_table_plain(children) -> str:
out: List[str] = []
for tok in children:
if tok.type == "text":
out.append(tok.content)
elif tok.type == "code_inline":
out.append(tok.content)
elif tok.type in {"softbreak", "hardbreak"}:
out.append(" ")
elif tok.type == "image":
if tok.content:
out.append(tok.content)
return "".join(out)
def render_inline(children) -> str:
out: List[str] = []
i = 0
while i < len(children):
tok = children[i]
t = tok.type
if t == "text":
out.append(escape_discord(tok.content))
elif t in {"softbreak", "hardbreak"}:
out.append("\n")
elif t == "em_open":
out.append("*")
elif t == "em_close":
out.append("*")
elif t == "strong_open":
out.append("**")
elif t == "strong_close":
out.append("**")
elif t == "s_open":
out.append("~~")
elif t == "s_close":
out.append("~~")
elif t == "code_inline":
out.append(f"`{escape_discord_code(tok.content)}`")
elif t == "link_open":
href = ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("href", "")
else:
for key, val in tok.attrs:
if key == "href":
href = val
break
inner_tokens = []
i += 1
while i < len(children) and children[i].type != "link_close":
inner_tokens.append(children[i])
i += 1
link_text = ""
for child in inner_tokens:
if child.type == "text":
link_text += child.content
elif child.type == "code_inline":
link_text += child.content
out.append(f"[{escape_discord(link_text)}]({href})")
elif t == "image":
href = ""
alt = tok.content or ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("src", "")
else:
for key, val in tok.attrs:
if key == "src":
href = val
break
if alt:
out.append(f"{escape_discord(alt)} ({href})")
else:
out.append(href)
else:
out.append(escape_discord(tok.content or ""))
i += 1
return "".join(out)
out: List[str] = []
list_stack: List[dict] = []
pending_prefix: Optional[str] = None
blockquote_level = 0
in_heading = False
def apply_blockquote(val: str) -> str:
if blockquote_level <= 0:
return val
prefix = "> " * blockquote_level
return prefix + val.replace("\n", "\n" + prefix)
i = 0
while i < len(tokens):
tok = tokens[i]
t = tok.type
if t == "paragraph_open":
pass
elif t == "paragraph_close":
out.append("\n")
elif t == "heading_open":
in_heading = True
elif t == "heading_close":
in_heading = False
out.append("\n")
elif t == "bullet_list_open":
list_stack.append({"type": "bullet", "index": 1})
elif t == "bullet_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "ordered_list_open":
start = 1
if tok.attrs:
if isinstance(tok.attrs, dict):
val = tok.attrs.get("start")
if val is not None:
try:
start = int(val)
except TypeError, ValueError:
start = 1
else:
for key, val in tok.attrs:
if key == "start":
try:
start = int(val)
except TypeError, ValueError:
start = 1
break
list_stack.append({"type": "ordered", "index": start})
elif t == "ordered_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "list_item_open":
if list_stack:
top = list_stack[-1]
if top["type"] == "bullet":
pending_prefix = "- "
else:
pending_prefix = f"{top['index']}. "
top["index"] += 1
elif t == "list_item_close":
out.append("\n")
elif t == "blockquote_open":
blockquote_level += 1
elif t == "blockquote_close":
blockquote_level = max(0, blockquote_level - 1)
out.append("\n")
elif t == "table_open":
if pending_prefix:
out.append(apply_blockquote(pending_prefix.rstrip()))
out.append("\n")
pending_prefix = None
rows: List[List[str]] = []
row_is_header: List[bool] = []
j = i + 1
in_thead = False
in_row = False
current_row: List[str] = []
current_row_header = False
in_cell = False
cell_parts: List[str] = []
while j < len(tokens):
tt = tokens[j].type
if tt == "thead_open":
in_thead = True
elif tt == "thead_close":
in_thead = False
elif tt == "tr_open":
in_row = True
current_row = []
current_row_header = in_thead
elif tt in {"th_open", "td_open"}:
in_cell = True
cell_parts = []
elif tt == "inline" and in_cell:
cell_parts.append(
render_inline_table_plain(tokens[j].children or [])
)
elif tt in {"th_close", "td_close"} and in_cell:
cell = " ".join(cell_parts).strip()
current_row.append(cell)
in_cell = False
cell_parts = []
elif tt == "tr_close" and in_row:
rows.append(current_row)
row_is_header.append(bool(current_row_header))
in_row = False
elif tt == "table_close":
break
j += 1
if rows:
col_count = max((len(r) for r in rows), default=0)
norm_rows: List[List[str]] = []
for r in rows:
if len(r) < col_count:
r = r + [""] * (col_count - len(r))
norm_rows.append(r)
widths: List[int] = []
for c in range(col_count):
w = max((len(r[c]) for r in norm_rows), default=0)
widths.append(max(w, 3))
def fmt_row(r: List[str]) -> str:
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
def fmt_sep() -> str:
cells = ["-" * widths[c] for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
last_header_idx = -1
for idx, is_h in enumerate(row_is_header):
if is_h:
last_header_idx = idx
lines: List[str] = []
for idx, r in enumerate(norm_rows):
lines.append(fmt_row(r))
if idx == last_header_idx:
lines.append(fmt_sep())
table_text = "\n".join(lines).rstrip()
out.append(f"```\n{escape_discord_code(table_text)}\n```")
out.append("\n")
i = j + 1
continue
elif t in {"code_block", "fence"}:
code = escape_discord_code(tok.content.rstrip("\n"))
out.append(f"```\n{code}\n```")
out.append("\n")
elif t == "inline":
rendered = render_inline(tok.children or [])
if in_heading:
rendered = f"**{render_inline(tok.children or [])}**"
if pending_prefix:
rendered = pending_prefix + rendered
pending_prefix = None
rendered = apply_blockquote(rendered)
out.append(rendered)
else:
if tok.content:
out.append(escape_discord(tok.content))
i += 1
return "".join(out).rstrip()
"""Backward-compatible re-export. Use messaging.rendering.discord_markdown for new code."""
from .rendering.discord_markdown import (
escape_discord,
escape_discord_code,
discord_bold,
discord_code_inline,
format_status,
format_status_discord,
render_markdown_to_discord,
)
__all__ = [
"escape_discord",

View file

@ -1,58 +1,5 @@
"""Messaging platform factory.
"""Backward-compatible re-export. Use messaging.platforms.factory for new code."""
Creates the appropriate messaging platform adapter based on configuration.
To add a new platform (e.g. Discord, Slack):
1. Create a new class implementing MessagingPlatform in messaging/
2. Add a case to create_messaging_platform() below
"""
from .platforms.factory import create_messaging_platform
from typing import Optional
from loguru import logger
from .base import MessagingPlatform
def create_messaging_platform(
platform_type: str,
**kwargs,
) -> Optional[MessagingPlatform]:
"""Create a messaging platform instance based on type.
Args:
platform_type: Platform identifier ("telegram", "discord", etc.)
**kwargs: Platform-specific configuration passed to the constructor.
Returns:
Configured MessagingPlatform instance, or None if not configured.
"""
if platform_type == "telegram":
bot_token = kwargs.get("bot_token")
if not bot_token:
logger.info("No Telegram bot token configured, skipping platform setup")
return None
from .telegram import TelegramPlatform
return TelegramPlatform(
bot_token=bot_token,
allowed_user_id=kwargs.get("allowed_user_id"),
)
if platform_type == "discord":
bot_token = kwargs.get("discord_bot_token")
if not bot_token:
logger.info("No Discord bot token configured, skipping platform setup")
return None
from .discord import DiscordPlatform
return DiscordPlatform(
bot_token=bot_token,
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
)
logger.warning(
f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram', 'discord'"
)
return None
__all__ = ["create_messaging_platform"]

View file

@ -11,13 +11,18 @@ import asyncio
import os
from typing import List, Optional, Tuple
from .base import MessagingPlatform, SessionManagerInterface
from .platforms.base import MessagingPlatform, SessionManagerInterface
from .models import IncomingMessage
from .session import SessionStore
from .tree_queue import TreeQueueManager, MessageNode, MessageState, MessageTree
from .trees.queue_manager import (
TreeQueueManager,
MessageNode,
MessageState,
MessageTree,
)
from .event_parser import parse_cli_event
from .transcript import TranscriptBuffer, RenderCtx
from .telegram_markdown import (
from .rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
mdv2_bold,
@ -25,7 +30,7 @@ from .telegram_markdown import (
format_status as format_status_telegram,
render_markdown_to_mdv2,
)
from .discord_markdown import (
from .rendering.discord_markdown import (
escape_discord,
escape_discord_code,
discord_bold,

View file

@ -0,0 +1,11 @@
"""Messaging platform adapters (Telegram, Discord, etc.)."""
from .base import MessagingPlatform, SessionManagerInterface, CLISession
from .factory import create_messaging_platform
__all__ = [
"MessagingPlatform",
"SessionManagerInterface",
"CLISession",
"create_messaging_platform",
]

219
messaging/platforms/base.py Normal file
View file

@ -0,0 +1,219 @@
"""Abstract base class for messaging platforms."""
from abc import ABC, abstractmethod
from typing import (
Callable,
Awaitable,
Optional,
Protocol,
Tuple,
runtime_checkable,
AsyncGenerator,
Any,
Dict,
)
from ..models import IncomingMessage
@runtime_checkable
class CLISession(Protocol):
"""Protocol for CLI session - avoid circular import from cli package."""
def start_task(
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
) -> AsyncGenerator[Dict, Any]:
"""Start a task in the CLI session."""
...
@property
@abstractmethod
def is_busy(self) -> bool:
"""Check if session is busy."""
pass
@runtime_checkable
class SessionManagerInterface(Protocol):
"""
Protocol for session managers to avoid tight coupling with cli package.
Implementations: CLISessionManager
"""
async def get_or_create_session(
self, session_id: Optional[str] = None
) -> Tuple[CLISession, str, bool]:
"""
Get an existing session or create a new one.
Returns: Tuple of (session, session_id, is_new_session)
"""
...
async def register_real_session_id(
self, temp_id: str, real_session_id: str
) -> bool:
"""Register the real session ID from CLI output."""
...
async def stop_all(self) -> None:
"""Stop all sessions."""
...
async def remove_session(self, session_id: str) -> bool:
"""Remove a session from the manager."""
...
def get_stats(self) -> dict:
"""Get session statistics."""
...
class MessagingPlatform(ABC):
"""
Base class for all messaging platform adapters.
Implement this to add support for Telegram, Discord, Slack, etc.
"""
name: str = "base"
@abstractmethod
async def start(self) -> None:
"""Initialize and connect to the messaging platform."""
pass
@abstractmethod
async def stop(self) -> None:
"""Disconnect and cleanup resources."""
pass
@abstractmethod
async def send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
) -> str:
"""
Send a message to a chat.
Args:
chat_id: The chat/channel ID to send to
text: Message content
reply_to: Optional message ID to reply to
parse_mode: Optional formatting mode ("markdown", "html")
Returns:
The message ID of the sent message
"""
pass
@abstractmethod
async def edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
) -> None:
"""
Edit an existing message.
Args:
chat_id: The chat/channel ID
message_id: The message ID to edit
text: New message content
parse_mode: Optional formatting mode
"""
pass
@abstractmethod
async def delete_message(
self,
chat_id: str,
message_id: str,
) -> None:
"""
Delete a message from a chat.
Args:
chat_id: The chat/channel ID
message_id: The message ID to delete
"""
pass
@abstractmethod
async def queue_send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> Optional[str]:
"""
Enqueue a message to be sent.
If fire_and_forget is True, returns None immediately.
Otherwise, waits for the rate limiter and returns message ID.
"""
pass
@abstractmethod
async def queue_edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> None:
"""
Enqueue a message edit.
If fire_and_forget is True, returns immediately.
Otherwise, waits for the rate limiter.
"""
pass
@abstractmethod
async def queue_delete_message(
self,
chat_id: str,
message_id: str,
fire_and_forget: bool = True,
) -> None:
"""
Enqueue a message deletion.
If fire_and_forget is True, returns immediately.
Otherwise, waits for the rate limiter.
"""
pass
@abstractmethod
def on_message(
self,
handler: Callable[[IncomingMessage], Awaitable[None]],
) -> None:
"""
Register a message handler callback.
The handler will be called for each incoming message.
Args:
handler: Async function that processes incoming messages
"""
pass
@abstractmethod
def fire_and_forget(self, task: Awaitable[Any]) -> None:
"""Execute a coroutine without awaiting it."""
pass
@property
def is_connected(self) -> bool:
"""Check if the platform is connected."""
return False

View file

@ -0,0 +1,394 @@
"""
Discord Platform Adapter
Implements MessagingPlatform for Discord using discord.py.
"""
import asyncio
import os
from typing import Callable, Awaitable, Optional, Any, Set, cast
from loguru import logger
from .base import MessagingPlatform
from ..models import IncomingMessage
from ..rendering.discord_markdown import format_status_discord
_discord_module: Any = None
try:
import discord as _discord_import
_discord_module = _discord_import
DISCORD_AVAILABLE = True
except ImportError:
DISCORD_AVAILABLE = False
DISCORD_MESSAGE_LIMIT = 2000
def _get_discord() -> Any:
"""Return the discord module. Raises if not available."""
if not DISCORD_AVAILABLE or _discord_module is None:
raise ImportError(
"discord.py is required. Install with: pip install discord.py"
)
return _discord_module
def _parse_allowed_channels(raw: Optional[str]) -> Set[str]:
"""Parse comma-separated channel IDs into a set of strings."""
if not raw or not raw.strip():
return set()
return {s.strip() for s in raw.split(",") if s.strip()}
if DISCORD_AVAILABLE and _discord_module is not None:
_discord = _discord_module
class _DiscordClient(_discord.Client):
"""Internal Discord client that forwards events to DiscordPlatform."""
def __init__(
self,
platform: "DiscordPlatform",
intents: _discord.Intents,
) -> None:
super().__init__(intents=intents)
self._platform = platform
async def on_ready(self) -> None:
"""Called when the bot is ready."""
self._platform._connected = True
logger.info("Discord platform connected")
async def on_message(self, message: Any) -> None:
"""Handle incoming Discord messages."""
await self._platform._on_discord_message(message)
else:
_DiscordClient = None
class DiscordPlatform(MessagingPlatform):
"""
Discord messaging platform adapter.
Uses discord.py for Discord access.
Requires a Bot Token from Discord Developer Portal and message_content intent.
"""
name = "discord"
def __init__(
self,
bot_token: Optional[str] = None,
allowed_channel_ids: Optional[str] = None,
):
if not DISCORD_AVAILABLE:
raise ImportError(
"discord.py is required. Install with: pip install discord.py"
)
self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
if not self.bot_token:
logger.warning("DISCORD_BOT_TOKEN not set")
discord = _get_discord()
intents = discord.Intents.default()
intents.message_content = True
assert _DiscordClient is not None
self._client = _DiscordClient(self, intents)
self._message_handler: Optional[
Callable[[IncomingMessage], Awaitable[None]]
] = None
self._connected = False
self._limiter: Optional[Any] = None
self._start_task: Optional[asyncio.Task] = None
async def _on_discord_message(self, message: Any) -> None:
"""Handle incoming Discord messages."""
if message.author.bot:
return
if not message.content:
return
channel_id = str(message.channel.id)
if not self.allowed_channel_ids or channel_id not in self.allowed_channel_ids:
return
user_id = str(message.author.id)
message_id = str(message.id)
reply_to = (
str(message.reference.message_id)
if message.reference and message.reference.message_id
else None
)
text_preview = (message.content or "")[:80]
if len(message.content or "") > 80:
text_preview += "..."
logger.info(
"DISCORD_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
channel_id,
message_id,
reply_to,
text_preview,
)
if not self._message_handler:
return
incoming = IncomingMessage(
text=message.content,
chat_id=channel_id,
user_id=user_id,
message_id=message_id,
platform="discord",
reply_to_message_id=reply_to,
username=message.author.display_name,
raw_event=message,
)
try:
await self._message_handler(incoming)
except Exception as e:
logger.error(f"Error handling message: {e}")
try:
await self.send_message(
channel_id,
format_status_discord("Error:", str(e)[:200]),
reply_to=message_id,
)
except Exception:
pass
def _truncate(self, text: str, limit: int = DISCORD_MESSAGE_LIMIT) -> str:
"""Truncate text to Discord's message limit."""
if len(text) <= limit:
return text
return text[: limit - 3] + "..."
async def start(self) -> None:
"""Initialize and connect to Discord."""
if not self.bot_token:
raise ValueError("DISCORD_BOT_TOKEN is required")
from ..limiter import MessagingRateLimiter
self._limiter = await MessagingRateLimiter.get_instance()
self._start_task = asyncio.create_task(
self._client.start(self.bot_token),
name="discord-client-start",
)
max_wait = 30
waited = 0
while not self._connected and waited < max_wait:
await asyncio.sleep(0.5)
waited += 0.5
if not self._connected:
raise RuntimeError("Discord client failed to connect within timeout")
logger.info("Discord platform started")
async def stop(self) -> None:
"""Stop the bot."""
if self._client.is_closed():
self._connected = False
return
await self._client.close()
if self._start_task and not self._start_task.done():
try:
await asyncio.wait_for(self._start_task, timeout=5.0)
except asyncio.TimeoutError, asyncio.CancelledError:
self._start_task.cancel()
try:
await self._start_task
except asyncio.CancelledError:
pass
self._connected = False
logger.info("Discord platform stopped")
async def send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
) -> str:
"""Send a message to a channel."""
channel = self._client.get_channel(int(chat_id))
if not channel or not hasattr(channel, "send"):
raise RuntimeError(f"Channel {chat_id} not found")
text = self._truncate(text)
channel = cast(Any, channel)
discord = _get_discord()
if reply_to:
ref = discord.MessageReference(
message_id=int(reply_to),
channel_id=int(chat_id),
)
msg = await channel.send(content=text, reference=ref)
else:
msg = await channel.send(content=text)
return str(msg.id)
async def edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
) -> None:
"""Edit an existing message."""
channel = self._client.get_channel(int(chat_id))
if not channel or not hasattr(channel, "fetch_message"):
raise RuntimeError(f"Channel {chat_id} not found")
discord = _get_discord()
channel = cast(Any, channel)
try:
msg = await channel.fetch_message(int(message_id))
except discord.NotFound:
return
text = self._truncate(text)
await msg.edit(content=text)
async def delete_message(
self,
chat_id: str,
message_id: str,
) -> None:
"""Delete a message from a channel."""
channel = self._client.get_channel(int(chat_id))
if not channel or not hasattr(channel, "fetch_message"):
return
discord = _get_discord()
channel = cast(Any, channel)
try:
msg = await channel.fetch_message(int(message_id))
await msg.delete()
except discord.NotFound, discord.Forbidden:
pass
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
"""Delete multiple messages (best-effort)."""
for mid in message_ids:
await self.delete_message(chat_id, mid)
async def queue_send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> Optional[str]:
"""Enqueue a message to be sent."""
if not self._limiter:
return await self.send_message(chat_id, text, reply_to, parse_mode)
async def _send():
return await self.send_message(chat_id, text, reply_to, parse_mode)
if fire_and_forget:
self._limiter.fire_and_forget(_send)
return None
return await self._limiter.enqueue(_send)
async def queue_edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = None,
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message edit."""
if not self._limiter:
await self.edit_message(chat_id, message_id, text, parse_mode)
return
async def _edit():
await self.edit_message(chat_id, message_id, text, parse_mode)
dedup_key = f"edit:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
async def queue_delete_message(
self,
chat_id: str,
message_id: str,
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message delete."""
if not self._limiter:
await self.delete_message(chat_id, message_id)
return
async def _delete():
await self.delete_message(chat_id, message_id)
dedup_key = f"del:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
async def queue_delete_messages(
self,
chat_id: str,
message_ids: list[str],
fire_and_forget: bool = True,
) -> None:
"""Enqueue a bulk delete."""
if not message_ids:
return
if not self._limiter:
await self.delete_messages(chat_id, message_ids)
return
async def _bulk():
await self.delete_messages(chat_id, message_ids)
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
if fire_and_forget:
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
def fire_and_forget(self, task: Awaitable[Any]) -> None:
"""Execute a coroutine without awaiting it."""
if asyncio.iscoroutine(task):
asyncio.create_task(task)
else:
asyncio.ensure_future(task)
def on_message(
self,
handler: Callable[[IncomingMessage], Awaitable[None]],
) -> None:
"""Register a message handler callback."""
self._message_handler = handler
@property
def is_connected(self) -> bool:
"""Check if connected."""
return self._connected

View file

@ -0,0 +1,58 @@
"""Messaging platform factory.
Creates the appropriate messaging platform adapter based on configuration.
To add a new platform (e.g. Discord, Slack):
1. Create a new class implementing MessagingPlatform in messaging/platforms/
2. Add a case to create_messaging_platform() below
"""
from typing import Optional
from loguru import logger
from .base import MessagingPlatform
def create_messaging_platform(
platform_type: str,
**kwargs,
) -> Optional[MessagingPlatform]:
"""Create a messaging platform instance based on type.
Args:
platform_type: Platform identifier ("telegram", "discord", etc.)
**kwargs: Platform-specific configuration passed to the constructor.
Returns:
Configured MessagingPlatform instance, or None if not configured.
"""
if platform_type == "telegram":
bot_token = kwargs.get("bot_token")
if not bot_token:
logger.info("No Telegram bot token configured, skipping platform setup")
return None
from .telegram import TelegramPlatform
return TelegramPlatform(
bot_token=bot_token,
allowed_user_id=kwargs.get("allowed_user_id"),
)
if platform_type == "discord":
bot_token = kwargs.get("discord_bot_token")
if not bot_token:
logger.info("No Discord bot token configured, skipping platform setup")
return None
from .discord import DiscordPlatform
return DiscordPlatform(
bot_token=bot_token,
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
)
logger.warning(
f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram', 'discord'"
)
return None

View file

@ -0,0 +1,491 @@
"""
Telegram Platform Adapter
Implements MessagingPlatform for Telegram using python-telegram-bot.
"""
import asyncio
import os
# Opt-in to future behavior for python-telegram-bot (retry_after as timedelta)
# This must be set BEFORE importing telegram.error
os.environ["PTB_TIMEDELTA"] = "1"
from typing import Callable, Awaitable, Optional, Any, TYPE_CHECKING
from loguru import logger
if TYPE_CHECKING:
from telegram import Update
from telegram.ext import ContextTypes
from .base import MessagingPlatform
from ..models import IncomingMessage
from ..rendering.telegram_markdown import escape_md_v2
# Optional import - python-telegram-bot may not be installed
try:
from telegram import Update
from telegram.ext import (
Application,
CommandHandler,
MessageHandler,
ContextTypes,
filters,
)
from telegram.error import TelegramError, RetryAfter, NetworkError
from telegram.request import HTTPXRequest
TELEGRAM_AVAILABLE = True
except ImportError:
TELEGRAM_AVAILABLE = False
class TelegramPlatform(MessagingPlatform):
"""
Telegram messaging platform adapter.
Uses python-telegram-bot (BoT API) for Telegram access.
Requires a Bot Token from @BotFather.
"""
name = "telegram"
def __init__(
self,
bot_token: Optional[str] = None,
allowed_user_id: Optional[str] = None,
):
if not TELEGRAM_AVAILABLE:
raise ImportError(
"python-telegram-bot is required. Install with: pip install python-telegram-bot"
)
self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
if not self.bot_token:
# We don't raise here to allow instantiation for testing/conditional logic,
# but start() will fail.
logger.warning("TELEGRAM_BOT_TOKEN not set")
self._application: Optional[Application] = None
self._message_handler: Optional[
Callable[[IncomingMessage], Awaitable[None]]
] = None
self._connected = False
self._limiter: Optional[Any] = None # Will be MessagingRateLimiter
async def start(self) -> None:
"""Initialize and connect to Telegram."""
if not self.bot_token:
raise ValueError("TELEGRAM_BOT_TOKEN is required")
# Configure request with longer timeouts
request = HTTPXRequest(
connection_pool_size=8, connect_timeout=30.0, read_timeout=30.0
)
# Build Application
builder = Application.builder().token(self.bot_token).request(request)
self._application = builder.build()
# Register Internal Handlers
# We catch ALL text messages and commands to forward them
self._application.add_handler(
MessageHandler(filters.TEXT & (~filters.COMMAND), self._on_telegram_message)
)
self._application.add_handler(CommandHandler("start", self._on_start_command))
# Catch-all for other commands if needed, or let them fall through
self._application.add_handler(
MessageHandler(filters.COMMAND, self._on_telegram_message)
)
# Initialize internal components with retry logic
max_retries = 3
for attempt in range(max_retries):
try:
await self._application.initialize()
await self._application.start()
# Start polling (non-blocking way for integration)
if self._application.updater:
await self._application.updater.start_polling(
drop_pending_updates=False
)
self._connected = True
break
except (NetworkError, Exception) as e:
if attempt < max_retries - 1:
wait_time = 2 * (attempt + 1)
logger.warning(
f"Connection failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
else:
logger.error(f"Failed to connect after {max_retries} attempts")
raise
# Initialize rate limiter
from ..limiter import MessagingRateLimiter
self._limiter = await MessagingRateLimiter.get_instance()
# Send startup notification
try:
target = self.allowed_user_id
if target:
startup_text = (
f"🚀 *{escape_md_v2('Claude Code Proxy is online!')}* "
f"{escape_md_v2('(Bot API)')}"
)
await self.send_message(
target,
startup_text,
)
except Exception as e:
logger.warning(f"Could not send startup message: {e}")
logger.info("Telegram platform started (Bot API)")
async def stop(self) -> None:
"""Stop the bot."""
if self._application and self._application.updater:
await self._application.updater.stop()
await self._application.stop()
await self._application.shutdown()
self._connected = False
logger.info("Telegram platform stopped")
async def _with_retry(
self, func: Callable[..., Awaitable[Any]], *args, **kwargs
) -> Any:
"""Helper to execute a function with exponential backoff on network errors."""
max_retries = 3
for attempt in range(max_retries):
try:
return await func(*args, **kwargs)
except (NetworkError, asyncio.TimeoutError) as e:
if "Message is not modified" in str(e):
return None
if attempt < max_retries - 1:
wait_time = 2**attempt # 1s, 2s, 4s
logger.warning(
f"Telegram API network error (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Telegram API failed after {max_retries} attempts: {e}"
)
raise
except RetryAfter as e:
# Telegram explicitly tells us to wait (PTB_TIMEDELTA: retry_after is timedelta)
from datetime import timedelta
retry_after = e.retry_after
if isinstance(retry_after, timedelta):
wait_secs = retry_after.total_seconds()
else:
wait_secs = float(retry_after)
logger.warning(f"Rate limited by Telegram, waiting {wait_secs}s...")
await asyncio.sleep(wait_secs)
# We don't increment attempt here, as this is a specific instruction
return await func(*args, **kwargs)
except TelegramError as e:
# Non-network Telegram errors
err_lower = str(e).lower()
if "message is not modified" in err_lower:
return None
# Best-effort no-op cases (common during chat cleanup / /clear).
if any(
x in err_lower
for x in [
"message to edit not found",
"message to delete not found",
"message can't be deleted",
"message can't be edited",
"not enough rights to delete",
]
):
return None
if "Can't parse entities" in str(e) and kwargs.get("parse_mode"):
logger.warning("Markdown failed, retrying without parse_mode")
kwargs["parse_mode"] = None
return await func(*args, **kwargs)
raise
async def send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = "MarkdownV2",
) -> str:
"""Send a message to a chat."""
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
async def _do_send(parse_mode=parse_mode):
bot = app.bot
msg = await bot.send_message(
chat_id=chat_id,
text=text,
reply_to_message_id=int(reply_to) if reply_to else None,
parse_mode=parse_mode,
)
return str(msg.message_id)
return await self._with_retry(_do_send, parse_mode=parse_mode)
async def edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = "MarkdownV2",
) -> None:
"""Edit an existing message."""
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
async def _do_edit(parse_mode=parse_mode):
bot = app.bot
await bot.edit_message_text(
chat_id=chat_id,
message_id=int(message_id),
text=text,
parse_mode=parse_mode,
)
await self._with_retry(_do_edit, parse_mode=parse_mode)
async def delete_message(
self,
chat_id: str,
message_id: str,
) -> None:
"""Delete a message from a chat."""
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
async def _do_delete():
bot = app.bot
await bot.delete_message(chat_id=chat_id, message_id=int(message_id))
await self._with_retry(_do_delete)
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
"""Delete multiple messages (best-effort)."""
if not message_ids:
return
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
# PTB supports bulk deletion via delete_messages; fall back to per-message.
bot = app.bot
if hasattr(bot, "delete_messages"):
async def _do_bulk():
mids = []
for mid in message_ids:
try:
mids.append(int(mid))
except Exception:
continue
if not mids:
return None
# delete_messages accepts a sequence of ints (up to 100).
await bot.delete_messages(chat_id=chat_id, message_ids=mids)
await self._with_retry(_do_bulk)
return
for mid in message_ids:
await self.delete_message(chat_id, mid)
async def queue_send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = "MarkdownV2",
fire_and_forget: bool = True,
) -> Optional[str]:
"""Enqueue a message to be sent (using limiter)."""
# Note: Bot API handles limits better, but we still use our limiter for nice queuing
if not self._limiter:
return await self.send_message(chat_id, text, reply_to, parse_mode)
async def _send():
return await self.send_message(chat_id, text, reply_to, parse_mode)
if fire_and_forget:
self._limiter.fire_and_forget(_send)
return None
else:
return await self._limiter.enqueue(_send)
async def queue_edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = "MarkdownV2",
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message edit."""
if not self._limiter:
return await self.edit_message(chat_id, message_id, text, parse_mode)
async def _edit():
return await self.edit_message(chat_id, message_id, text, parse_mode)
dedup_key = f"edit:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
async def queue_delete_message(
self,
chat_id: str,
message_id: str,
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message delete."""
if not self._limiter:
return await self.delete_message(chat_id, message_id)
async def _delete():
return await self.delete_message(chat_id, message_id)
dedup_key = f"del:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
async def queue_delete_messages(
self,
chat_id: str,
message_ids: list[str],
fire_and_forget: bool = True,
) -> None:
"""Enqueue a bulk delete (if supported) or a sequence of deletes."""
if not message_ids:
return
if not self._limiter:
return await self.delete_messages(chat_id, message_ids)
async def _bulk():
return await self.delete_messages(chat_id, message_ids)
# Dedup by the chunk content; okay to be coarse here.
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
if fire_and_forget:
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
def fire_and_forget(self, task: Awaitable[Any]) -> None:
"""Execute a coroutine without awaiting it."""
if asyncio.iscoroutine(task):
asyncio.create_task(task)
else:
asyncio.ensure_future(task)
def on_message(
self,
handler: Callable[[IncomingMessage], Awaitable[None]],
) -> None:
"""Register a message handler callback."""
self._message_handler = handler
@property
def is_connected(self) -> bool:
"""Check if connected."""
return self._connected
async def _on_start_command(
self, update: "Update", context: "ContextTypes.DEFAULT_TYPE"
) -> None:
"""Handle /start command."""
if update.message:
await update.message.reply_text("👋 Hello! I am the Claude Code Proxy Bot.")
# We can also treat this as a message if we want it to trigger something
await self._on_telegram_message(update, context)
async def _on_telegram_message(
self, update: "Update", context: "ContextTypes.DEFAULT_TYPE"
) -> None:
"""Handle incoming updates."""
if (
not update.message
or not update.message.text
or not update.effective_user
or not update.effective_chat
):
return
user_id = str(update.effective_user.id)
chat_id = str(update.effective_chat.id)
# Security check
if self.allowed_user_id:
if user_id != str(self.allowed_user_id).strip():
logger.warning(f"Unauthorized access attempt from {user_id}")
return
message_id = str(update.message.message_id)
reply_to = (
str(update.message.reply_to_message.message_id)
if update.message.reply_to_message
else None
)
text_preview = (update.message.text or "")[:80]
if len(update.message.text or "") > 80:
text_preview += "..."
logger.info(
"TELEGRAM_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
chat_id,
message_id,
reply_to,
text_preview,
)
if not self._message_handler:
return
incoming = IncomingMessage(
text=update.message.text,
chat_id=chat_id,
user_id=user_id,
message_id=message_id,
platform="telegram",
reply_to_message_id=reply_to,
raw_event=update,
)
try:
await self._message_handler(incoming)
except Exception as e:
logger.error(f"Error handling message: {e}")
try:
await self.send_message(
chat_id,
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(str(e)[:200])}",
reply_to=incoming.message_id,
parse_mode="MarkdownV2",
)
except Exception:
pass

View file

@ -0,0 +1,37 @@
"""Markdown rendering utilities for messaging platforms."""
from .discord_markdown import (
escape_discord,
escape_discord_code,
discord_bold,
discord_code_inline,
format_status as format_status_discord_fn,
format_status_discord,
render_markdown_to_discord,
)
from .telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
escape_md_v2_link_url,
mdv2_bold,
mdv2_code_inline,
format_status as format_status_telegram_fn,
render_markdown_to_mdv2,
)
__all__ = [
"escape_discord",
"escape_discord_code",
"discord_bold",
"discord_code_inline",
"format_status_discord_fn",
"format_status_discord",
"render_markdown_to_discord",
"escape_md_v2",
"escape_md_v2_code",
"escape_md_v2_link_url",
"mdv2_bold",
"mdv2_code_inline",
"format_status_telegram_fn",
"render_markdown_to_mdv2",
]

View file

@ -0,0 +1,374 @@
"""Discord markdown utilities.
Discord uses standard markdown: **bold**, *italic*, `code`, ```code block```.
Used by the message handler and Discord platform adapter.
"""
import re
from typing import List, Optional
from markdown_it import MarkdownIt
# Discord escapes: \ * _ ` ~ | >
DISCORD_SPECIAL = set("\\*_`~|>")
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
_MD.enable("strikethrough")
_MD.enable("table")
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
_FENCE_RE = re.compile(r"^\s*```")
def _is_gfm_table_header_line(line: str) -> bool:
"""Check if line is a GFM table header."""
if "|" not in line:
return False
if _TABLE_SEP_RE.match(line):
return False
stripped = line.strip()
parts = [p.strip() for p in stripped.strip("|").split("|")]
parts = [p for p in parts if p != ""]
return len(parts) >= 2
def _normalize_gfm_tables(text: str) -> str:
"""Insert blank line before detected tables outside code blocks."""
lines = text.splitlines()
if len(lines) < 2:
return text
out_lines: List[str] = []
in_fence = False
for idx, line in enumerate(lines):
if _FENCE_RE.match(line):
in_fence = not in_fence
out_lines.append(line)
continue
if (
not in_fence
and idx + 1 < len(lines)
and _is_gfm_table_header_line(line)
and _TABLE_SEP_RE.match(lines[idx + 1])
):
if out_lines and out_lines[-1].strip() != "":
m = re.match(r"^(\s*)", line)
indent = m.group(1) if m else ""
out_lines.append(indent)
out_lines.append(line)
return "\n".join(out_lines)
def escape_discord(text: str) -> str:
"""Escape text for Discord markdown (bold, italic, etc.)."""
return "".join(f"\\{ch}" if ch in DISCORD_SPECIAL else ch for ch in text)
def escape_discord_code(text: str) -> str:
"""Escape text for Discord code spans/blocks."""
return text.replace("\\", "\\\\").replace("`", "\\`")
def discord_bold(text: str) -> str:
"""Format text as bold in Discord (uses **)."""
return f"**{escape_discord(text)}**"
def discord_code_inline(text: str) -> str:
"""Format text as inline code in Discord."""
return f"`{escape_discord_code(text)}`"
def format_status_discord(label: str, suffix: Optional[str] = None) -> str:
"""Format a status message for Discord (label in bold, optional suffix)."""
base = discord_bold(label)
if suffix:
return f"{base} {escape_discord(suffix)}"
return base
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
"""Format a status message with emoji for Discord (matches Telegram API)."""
base = f"{emoji} {discord_bold(label)}"
if suffix:
return f"{base} {escape_discord(suffix)}"
return base
def render_markdown_to_discord(text: str) -> str:
"""Render common Markdown into Discord-compatible format."""
if not text:
return ""
text = _normalize_gfm_tables(text)
tokens = _MD.parse(text)
def render_inline_table_plain(children) -> str:
out: List[str] = []
for tok in children:
if tok.type == "text":
out.append(tok.content)
elif tok.type == "code_inline":
out.append(tok.content)
elif tok.type in {"softbreak", "hardbreak"}:
out.append(" ")
elif tok.type == "image":
if tok.content:
out.append(tok.content)
return "".join(out)
def render_inline(children) -> str:
out: List[str] = []
i = 0
while i < len(children):
tok = children[i]
t = tok.type
if t == "text":
out.append(escape_discord(tok.content))
elif t in {"softbreak", "hardbreak"}:
out.append("\n")
elif t == "em_open":
out.append("*")
elif t == "em_close":
out.append("*")
elif t == "strong_open":
out.append("**")
elif t == "strong_close":
out.append("**")
elif t == "s_open":
out.append("~~")
elif t == "s_close":
out.append("~~")
elif t == "code_inline":
out.append(f"`{escape_discord_code(tok.content)}`")
elif t == "link_open":
href = ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("href", "")
else:
for key, val in tok.attrs:
if key == "href":
href = val
break
inner_tokens = []
i += 1
while i < len(children) and children[i].type != "link_close":
inner_tokens.append(children[i])
i += 1
link_text = ""
for child in inner_tokens:
if child.type == "text":
link_text += child.content
elif child.type == "code_inline":
link_text += child.content
out.append(f"[{escape_discord(link_text)}]({href})")
elif t == "image":
href = ""
alt = tok.content or ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("src", "")
else:
for key, val in tok.attrs:
if key == "src":
href = val
break
if alt:
out.append(f"{escape_discord(alt)} ({href})")
else:
out.append(href)
else:
out.append(escape_discord(tok.content or ""))
i += 1
return "".join(out)
out: List[str] = []
list_stack: List[dict] = []
pending_prefix: Optional[str] = None
blockquote_level = 0
in_heading = False
def apply_blockquote(val: str) -> str:
if blockquote_level <= 0:
return val
prefix = "> " * blockquote_level
return prefix + val.replace("\n", "\n" + prefix)
i = 0
while i < len(tokens):
tok = tokens[i]
t = tok.type
if t == "paragraph_open":
pass
elif t == "paragraph_close":
out.append("\n")
elif t == "heading_open":
in_heading = True
elif t == "heading_close":
in_heading = False
out.append("\n")
elif t == "bullet_list_open":
list_stack.append({"type": "bullet", "index": 1})
elif t == "bullet_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "ordered_list_open":
start = 1
if tok.attrs:
if isinstance(tok.attrs, dict):
val = tok.attrs.get("start")
if val is not None:
try:
start = int(val)
except TypeError, ValueError:
start = 1
else:
for key, val in tok.attrs:
if key == "start":
try:
start = int(val)
except TypeError, ValueError:
start = 1
break
list_stack.append({"type": "ordered", "index": start})
elif t == "ordered_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "list_item_open":
if list_stack:
top = list_stack[-1]
if top["type"] == "bullet":
pending_prefix = "- "
else:
pending_prefix = f"{top['index']}. "
top["index"] += 1
elif t == "list_item_close":
out.append("\n")
elif t == "blockquote_open":
blockquote_level += 1
elif t == "blockquote_close":
blockquote_level = max(0, blockquote_level - 1)
out.append("\n")
elif t == "table_open":
if pending_prefix:
out.append(apply_blockquote(pending_prefix.rstrip()))
out.append("\n")
pending_prefix = None
rows: List[List[str]] = []
row_is_header: List[bool] = []
j = i + 1
in_thead = False
in_row = False
current_row: List[str] = []
current_row_header = False
in_cell = False
cell_parts: List[str] = []
while j < len(tokens):
tt = tokens[j].type
if tt == "thead_open":
in_thead = True
elif tt == "thead_close":
in_thead = False
elif tt == "tr_open":
in_row = True
current_row = []
current_row_header = in_thead
elif tt in {"th_open", "td_open"}:
in_cell = True
cell_parts = []
elif tt == "inline" and in_cell:
cell_parts.append(
render_inline_table_plain(tokens[j].children or [])
)
elif tt in {"th_close", "td_close"} and in_cell:
cell = " ".join(cell_parts).strip()
current_row.append(cell)
in_cell = False
cell_parts = []
elif tt == "tr_close" and in_row:
rows.append(current_row)
row_is_header.append(bool(current_row_header))
in_row = False
elif tt == "table_close":
break
j += 1
if rows:
col_count = max((len(r) for r in rows), default=0)
norm_rows: List[List[str]] = []
for r in rows:
if len(r) < col_count:
r = r + [""] * (col_count - len(r))
norm_rows.append(r)
widths: List[int] = []
for c in range(col_count):
w = max((len(r[c]) for r in norm_rows), default=0)
widths.append(max(w, 3))
def fmt_row(r: List[str]) -> str:
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
def fmt_sep() -> str:
cells = ["-" * widths[c] for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
last_header_idx = -1
for idx, is_h in enumerate(row_is_header):
if is_h:
last_header_idx = idx
lines: List[str] = []
for idx, r in enumerate(norm_rows):
lines.append(fmt_row(r))
if idx == last_header_idx:
lines.append(fmt_sep())
table_text = "\n".join(lines).rstrip()
out.append(f"```\n{escape_discord_code(table_text)}\n```")
out.append("\n")
i = j + 1
continue
elif t in {"code_block", "fence"}:
code = escape_discord_code(tok.content.rstrip("\n"))
out.append(f"```\n{code}\n```")
out.append("\n")
elif t == "inline":
rendered = render_inline(tok.children or [])
if in_heading:
rendered = f"**{render_inline(tok.children or [])}**"
if pending_prefix:
rendered = pending_prefix + rendered
pending_prefix = None
rendered = apply_blockquote(rendered)
out.append(rendered)
else:
if tok.content:
out.append(escape_discord(tok.content))
i += 1
return "".join(out).rstrip()
__all__ = [
"escape_discord",
"escape_discord_code",
"discord_bold",
"discord_code_inline",
"format_status",
"format_status_discord",
"render_markdown_to_discord",
]

View file

@ -0,0 +1,391 @@
"""Telegram MarkdownV2 utilities.
Renders common Markdown into Telegram MarkdownV2 format.
Used by the message handler and Telegram platform adapter.
"""
import re
from typing import List, Optional
from markdown_it import MarkdownIt
MDV2_SPECIAL_CHARS = set("\\_*[]()~`>#+-=|{}.!")
MDV2_LINK_ESCAPE = set("\\)")
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
_MD.enable("strikethrough")
_MD.enable("table")
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
_FENCE_RE = re.compile(r"^\s*```")
def _is_gfm_table_header_line(line: str) -> bool:
"""Check if line is a GFM table header (pipe-delimited, not separator)."""
if "|" not in line:
return False
if _TABLE_SEP_RE.match(line):
return False
stripped = line.strip()
parts = [p.strip() for p in stripped.strip("|").split("|")]
parts = [p for p in parts if p != ""]
return len(parts) >= 2
def _normalize_gfm_tables(text: str) -> str:
"""
Many LLMs emit tables immediately after a paragraph line (no blank line).
Markdown-it will treat that as a softbreak within the paragraph, so the
table extension won't trigger. Insert a blank line before detected tables.
We only do this outside fenced code blocks.
"""
lines = text.splitlines()
if len(lines) < 2:
return text
out_lines: List[str] = []
in_fence = False
for idx, line in enumerate(lines):
if _FENCE_RE.match(line):
in_fence = not in_fence
out_lines.append(line)
continue
if (
not in_fence
and idx + 1 < len(lines)
and _is_gfm_table_header_line(line)
and _TABLE_SEP_RE.match(lines[idx + 1])
):
if out_lines and out_lines[-1].strip() != "":
m = re.match(r"^(\s*)", line)
indent = m.group(1) if m else ""
out_lines.append(indent)
out_lines.append(line)
return "\n".join(out_lines)
def escape_md_v2(text: str) -> str:
"""Escape text for Telegram MarkdownV2."""
return "".join(f"\\{ch}" if ch in MDV2_SPECIAL_CHARS else ch for ch in text)
def escape_md_v2_code(text: str) -> str:
"""Escape text for Telegram MarkdownV2 code spans/blocks."""
return text.replace("\\", "\\\\").replace("`", "\\`")
def escape_md_v2_link_url(text: str) -> str:
"""Escape URL for Telegram MarkdownV2 link destination."""
return "".join(f"\\{ch}" if ch in MDV2_LINK_ESCAPE else ch for ch in text)
def mdv2_bold(text: str) -> str:
"""Format text as bold in MarkdownV2."""
return f"*{escape_md_v2(text)}*"
def mdv2_code_inline(text: str) -> str:
"""Format text as inline code in MarkdownV2."""
return f"`{escape_md_v2_code(text)}`"
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
"""Format a status message with emoji and optional suffix."""
base = f"{emoji} {mdv2_bold(label)}"
if suffix:
return f"{base} {escape_md_v2(suffix)}"
return base
def render_markdown_to_mdv2(text: str) -> str:
"""Render common Markdown into Telegram MarkdownV2."""
if not text:
return ""
text = _normalize_gfm_tables(text)
tokens = _MD.parse(text)
def render_inline_table_plain(children) -> str:
out: List[str] = []
for tok in children:
if tok.type == "text":
out.append(tok.content)
elif tok.type == "code_inline":
out.append(tok.content)
elif tok.type in {"softbreak", "hardbreak"}:
out.append(" ")
elif tok.type == "image":
if tok.content:
out.append(tok.content)
return "".join(out)
def render_inline_plain(children) -> str:
out: List[str] = []
for tok in children:
if tok.type == "text":
out.append(escape_md_v2(tok.content))
elif tok.type == "code_inline":
out.append(escape_md_v2(tok.content))
elif tok.type in {"softbreak", "hardbreak"}:
out.append("\n")
return "".join(out)
def render_inline(children) -> str:
out: List[str] = []
i = 0
while i < len(children):
tok = children[i]
t = tok.type
if t == "text":
out.append(escape_md_v2(tok.content))
elif t in {"softbreak", "hardbreak"}:
out.append("\n")
elif t == "em_open":
out.append("_")
elif t == "em_close":
out.append("_")
elif t == "strong_open":
out.append("*")
elif t == "strong_close":
out.append("*")
elif t == "s_open":
out.append("~")
elif t == "s_close":
out.append("~")
elif t == "code_inline":
out.append(f"`{escape_md_v2_code(tok.content)}`")
elif t == "link_open":
href = ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("href", "")
else:
for key, val in tok.attrs:
if key == "href":
href = val
break
inner_tokens = []
i += 1
while i < len(children) and children[i].type != "link_close":
inner_tokens.append(children[i])
i += 1
link_text = ""
for child in inner_tokens:
if child.type == "text":
link_text += child.content
elif child.type == "code_inline":
link_text += child.content
out.append(
f"[{escape_md_v2(link_text)}]({escape_md_v2_link_url(href)})"
)
elif t == "image":
href = ""
alt = tok.content or ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("src", "")
else:
for key, val in tok.attrs:
if key == "src":
href = val
break
if alt:
out.append(f"{escape_md_v2(alt)} ({escape_md_v2_link_url(href)})")
else:
out.append(escape_md_v2_link_url(href))
else:
out.append(escape_md_v2(tok.content or ""))
i += 1
return "".join(out)
out: List[str] = []
list_stack: List[dict] = []
pending_prefix: Optional[str] = None
blockquote_level = 0
in_heading = False
def apply_blockquote(val: str) -> str:
if blockquote_level <= 0:
return val
prefix = "> " * blockquote_level
return prefix + val.replace("\n", "\n" + prefix)
i = 0
while i < len(tokens):
tok = tokens[i]
t = tok.type
if t == "paragraph_open":
pass
elif t == "paragraph_close":
out.append("\n")
elif t == "heading_open":
in_heading = True
elif t == "heading_close":
in_heading = False
out.append("\n")
elif t == "bullet_list_open":
list_stack.append({"type": "bullet", "index": 1})
elif t == "bullet_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "ordered_list_open":
start = 1
if tok.attrs:
if isinstance(tok.attrs, dict):
val = tok.attrs.get("start")
if val is not None:
try:
start = int(val)
except TypeError, ValueError:
start = 1
else:
for key, val in tok.attrs:
if key == "start":
try:
start = int(val)
except TypeError, ValueError:
start = 1
break
list_stack.append({"type": "ordered", "index": start})
elif t == "ordered_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "list_item_open":
if list_stack:
top = list_stack[-1]
if top["type"] == "bullet":
pending_prefix = "\\- "
else:
pending_prefix = f"{top['index']}\\."
top["index"] += 1
pending_prefix += " "
elif t == "list_item_close":
out.append("\n")
elif t == "blockquote_open":
blockquote_level += 1
elif t == "blockquote_close":
blockquote_level = max(0, blockquote_level - 1)
out.append("\n")
elif t == "table_open":
if pending_prefix:
out.append(apply_blockquote(pending_prefix.rstrip()))
out.append("\n")
pending_prefix = None
rows: List[List[str]] = []
row_is_header: List[bool] = []
j = i + 1
in_thead = False
in_row = False
current_row: List[str] = []
current_row_header = False
in_cell = False
cell_parts: List[str] = []
while j < len(tokens):
tt = tokens[j].type
if tt == "thead_open":
in_thead = True
elif tt == "thead_close":
in_thead = False
elif tt == "tr_open":
in_row = True
current_row = []
current_row_header = in_thead
elif tt in {"th_open", "td_open"}:
in_cell = True
cell_parts = []
elif tt == "inline" and in_cell:
cell_parts.append(
render_inline_table_plain(tokens[j].children or [])
)
elif tt in {"th_close", "td_close"} and in_cell:
cell = " ".join(cell_parts).strip()
current_row.append(cell)
in_cell = False
cell_parts = []
elif tt == "tr_close" and in_row:
rows.append(current_row)
row_is_header.append(bool(current_row_header))
in_row = False
elif tt == "table_close":
break
j += 1
if rows:
col_count = max((len(r) for r in rows), default=0)
norm_rows: List[List[str]] = []
for r in rows:
if len(r) < col_count:
r = r + [""] * (col_count - len(r))
norm_rows.append(r)
widths: List[int] = []
for c in range(col_count):
w = max((len(r[c]) for r in norm_rows), default=0)
widths.append(max(w, 3))
def fmt_row(r: List[str]) -> str:
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
def fmt_sep() -> str:
cells = ["-" * widths[c] for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
last_header_idx = -1
for idx, is_h in enumerate(row_is_header):
if is_h:
last_header_idx = idx
lines: List[str] = []
for idx, r in enumerate(norm_rows):
lines.append(fmt_row(r))
if idx == last_header_idx:
lines.append(fmt_sep())
table_text = "\n".join(lines).rstrip()
out.append(f"```\n{escape_md_v2_code(table_text)}\n```")
out.append("\n")
i = j + 1
continue
elif t in {"code_block", "fence"}:
code = escape_md_v2_code(tok.content.rstrip("\n"))
out.append(f"```\n{code}\n```")
out.append("\n")
elif t == "inline":
rendered = render_inline(tok.children or [])
if in_heading:
rendered = f"*{render_inline_plain(tok.children or [])}*"
if pending_prefix:
rendered = pending_prefix + rendered
pending_prefix = None
rendered = apply_blockquote(rendered)
out.append(rendered)
else:
if tok.content:
out.append(escape_md_v2(tok.content))
i += 1
return "".join(out).rstrip()
__all__ = [
"escape_md_v2",
"escape_md_v2_code",
"escape_md_v2_link_url",
"mdv2_bold",
"mdv2_code_inline",
"format_status",
"render_markdown_to_mdv2",
]

View file

@ -1,487 +1,8 @@
"""
Telegram Platform Adapter
"""Backward-compatible re-export. Use messaging.platforms.telegram for new code."""
Implements MessagingPlatform for Telegram using python-telegram-bot.
"""
from .platforms.telegram import (
TelegramPlatform,
TELEGRAM_AVAILABLE,
)
import asyncio
import os
# Opt-in to future behavior for python-telegram-bot (retry_after as timedelta)
# This must be set BEFORE importing telegram.error
os.environ["PTB_TIMEDELTA"] = "1"
from typing import Callable, Awaitable, Optional, Any
from loguru import logger
from .base import MessagingPlatform
from .models import IncomingMessage
from .telegram_markdown import escape_md_v2
# Optional import - python-telegram-bot may not be installed
try:
from telegram import Update
from telegram.ext import (
Application,
CommandHandler,
MessageHandler,
ContextTypes,
filters,
)
from telegram.error import TelegramError, RetryAfter, NetworkError
from telegram.request import HTTPXRequest
TELEGRAM_AVAILABLE = True
except ImportError:
TELEGRAM_AVAILABLE = False
class TelegramPlatform(MessagingPlatform):
"""
Telegram messaging platform adapter.
Uses python-telegram-bot (BoT API) for Telegram access.
Requires a Bot Token from @BotFather.
"""
name = "telegram"
def __init__(
self,
bot_token: Optional[str] = None,
allowed_user_id: Optional[str] = None,
):
if not TELEGRAM_AVAILABLE:
raise ImportError(
"python-telegram-bot is required. Install with: pip install python-telegram-bot"
)
self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
if not self.bot_token:
# We don't raise here to allow instantiation for testing/conditional logic,
# but start() will fail.
logger.warning("TELEGRAM_BOT_TOKEN not set")
self._application: Optional[Application] = None
self._message_handler: Optional[
Callable[[IncomingMessage], Awaitable[None]]
] = None
self._connected = False
self._limiter: Optional[Any] = None # Will be MessagingRateLimiter
async def start(self) -> None:
"""Initialize and connect to Telegram."""
if not self.bot_token:
raise ValueError("TELEGRAM_BOT_TOKEN is required")
# Configure request with longer timeouts
request = HTTPXRequest(
connection_pool_size=8, connect_timeout=30.0, read_timeout=30.0
)
# Build Application
builder = Application.builder().token(self.bot_token).request(request)
self._application = builder.build()
# Register Internal Handlers
# We catch ALL text messages and commands to forward them
self._application.add_handler(
MessageHandler(filters.TEXT & (~filters.COMMAND), self._on_telegram_message)
)
self._application.add_handler(CommandHandler("start", self._on_start_command))
# Catch-all for other commands if needed, or let them fall through
self._application.add_handler(
MessageHandler(filters.COMMAND, self._on_telegram_message)
)
# Initialize internal components with retry logic
max_retries = 3
for attempt in range(max_retries):
try:
await self._application.initialize()
await self._application.start()
# Start polling (non-blocking way for integration)
if self._application.updater:
await self._application.updater.start_polling(
drop_pending_updates=False
)
self._connected = True
break
except (NetworkError, Exception) as e:
if attempt < max_retries - 1:
wait_time = 2 * (attempt + 1)
logger.warning(
f"Connection failed (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
else:
logger.error(f"Failed to connect after {max_retries} attempts")
raise
# Initialize rate limiter
from .limiter import MessagingRateLimiter
self._limiter = await MessagingRateLimiter.get_instance()
# Send startup notification
try:
target = self.allowed_user_id
if target:
startup_text = (
f"🚀 *{escape_md_v2('Claude Code Proxy is online!')}* "
f"{escape_md_v2('(Bot API)')}"
)
await self.send_message(
target,
startup_text,
)
except Exception as e:
logger.warning(f"Could not send startup message: {e}")
logger.info("Telegram platform started (Bot API)")
async def stop(self) -> None:
"""Stop the bot."""
if self._application and self._application.updater:
await self._application.updater.stop()
await self._application.stop()
await self._application.shutdown()
self._connected = False
logger.info("Telegram platform stopped")
async def _with_retry(
self, func: Callable[..., Awaitable[Any]], *args, **kwargs
) -> Any:
"""Helper to execute a function with exponential backoff on network errors."""
max_retries = 3
for attempt in range(max_retries):
try:
return await func(*args, **kwargs)
except (NetworkError, asyncio.TimeoutError) as e:
if "Message is not modified" in str(e):
return None
if attempt < max_retries - 1:
wait_time = 2**attempt # 1s, 2s, 4s
logger.warning(
f"Telegram API network error (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Telegram API failed after {max_retries} attempts: {e}"
)
raise
except RetryAfter as e:
# Telegram explicitly tells us to wait (PTB_TIMEDELTA: retry_after is timedelta)
from datetime import timedelta
retry_after = e.retry_after
if isinstance(retry_after, timedelta):
wait_secs = retry_after.total_seconds()
else:
wait_secs = float(retry_after)
logger.warning(f"Rate limited by Telegram, waiting {wait_secs}s...")
await asyncio.sleep(wait_secs)
# We don't increment attempt here, as this is a specific instruction
return await func(*args, **kwargs)
except TelegramError as e:
# Non-network Telegram errors
err_lower = str(e).lower()
if "message is not modified" in err_lower:
return None
# Best-effort no-op cases (common during chat cleanup / /clear).
if any(
x in err_lower
for x in [
"message to edit not found",
"message to delete not found",
"message can't be deleted",
"message can't be edited",
"not enough rights to delete",
]
):
return None
if "Can't parse entities" in str(e) and kwargs.get("parse_mode"):
logger.warning("Markdown failed, retrying without parse_mode")
kwargs["parse_mode"] = None
return await func(*args, **kwargs)
raise
async def send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = "MarkdownV2",
) -> str:
"""Send a message to a chat."""
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
async def _do_send(parse_mode=parse_mode):
bot = app.bot
msg = await bot.send_message(
chat_id=chat_id,
text=text,
reply_to_message_id=int(reply_to) if reply_to else None,
parse_mode=parse_mode,
)
return str(msg.message_id)
return await self._with_retry(_do_send, parse_mode=parse_mode)
async def edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = "MarkdownV2",
) -> None:
"""Edit an existing message."""
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
async def _do_edit(parse_mode=parse_mode):
bot = app.bot
await bot.edit_message_text(
chat_id=chat_id,
message_id=int(message_id),
text=text,
parse_mode=parse_mode,
)
await self._with_retry(_do_edit, parse_mode=parse_mode)
async def delete_message(
self,
chat_id: str,
message_id: str,
) -> None:
"""Delete a message from a chat."""
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
async def _do_delete():
bot = app.bot
await bot.delete_message(chat_id=chat_id, message_id=int(message_id))
await self._with_retry(_do_delete)
async def delete_messages(self, chat_id: str, message_ids: list[str]) -> None:
"""Delete multiple messages (best-effort)."""
if not message_ids:
return
app = self._application
if not app or not app.bot:
raise RuntimeError("Telegram application or bot not initialized")
# PTB supports bulk deletion via delete_messages; fall back to per-message.
bot = app.bot
if hasattr(bot, "delete_messages"):
async def _do_bulk():
mids = []
for mid in message_ids:
try:
mids.append(int(mid))
except Exception:
continue
if not mids:
return None
# delete_messages accepts a sequence of ints (up to 100).
await bot.delete_messages(chat_id=chat_id, message_ids=mids)
await self._with_retry(_do_bulk)
return
for mid in message_ids:
await self.delete_message(chat_id, mid)
async def queue_send_message(
self,
chat_id: str,
text: str,
reply_to: Optional[str] = None,
parse_mode: Optional[str] = "MarkdownV2",
fire_and_forget: bool = True,
) -> Optional[str]:
"""Enqueue a message to be sent (using limiter)."""
# Note: Bot API handles limits better, but we still use our limiter for nice queuing
if not self._limiter:
return await self.send_message(chat_id, text, reply_to, parse_mode)
async def _send():
return await self.send_message(chat_id, text, reply_to, parse_mode)
if fire_and_forget:
self._limiter.fire_and_forget(_send)
return None
else:
return await self._limiter.enqueue(_send)
async def queue_edit_message(
self,
chat_id: str,
message_id: str,
text: str,
parse_mode: Optional[str] = "MarkdownV2",
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message edit."""
if not self._limiter:
return await self.edit_message(chat_id, message_id, text, parse_mode)
async def _edit():
return await self.edit_message(chat_id, message_id, text, parse_mode)
dedup_key = f"edit:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_edit, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_edit, dedup_key=dedup_key)
async def queue_delete_message(
self,
chat_id: str,
message_id: str,
fire_and_forget: bool = True,
) -> None:
"""Enqueue a message delete."""
if not self._limiter:
return await self.delete_message(chat_id, message_id)
async def _delete():
return await self.delete_message(chat_id, message_id)
dedup_key = f"del:{chat_id}:{message_id}"
if fire_and_forget:
self._limiter.fire_and_forget(_delete, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_delete, dedup_key=dedup_key)
async def queue_delete_messages(
self,
chat_id: str,
message_ids: list[str],
fire_and_forget: bool = True,
) -> None:
"""Enqueue a bulk delete (if supported) or a sequence of deletes."""
if not message_ids:
return
if not self._limiter:
return await self.delete_messages(chat_id, message_ids)
async def _bulk():
return await self.delete_messages(chat_id, message_ids)
# Dedup by the chunk content; okay to be coarse here.
dedup_key = f"del_bulk:{chat_id}:{hash(tuple(message_ids))}"
if fire_and_forget:
self._limiter.fire_and_forget(_bulk, dedup_key=dedup_key)
else:
await self._limiter.enqueue(_bulk, dedup_key=dedup_key)
def fire_and_forget(self, task: Awaitable[Any]) -> None:
"""Execute a coroutine without awaiting it."""
if asyncio.iscoroutine(task):
asyncio.create_task(task)
else:
asyncio.ensure_future(task)
def on_message(
self,
handler: Callable[[IncomingMessage], Awaitable[None]],
) -> None:
"""Register a message handler callback."""
self._message_handler = handler
@property
def is_connected(self) -> bool:
"""Check if connected."""
return self._connected
async def _on_start_command(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Handle /start command."""
if update.message:
await update.message.reply_text("👋 Hello! I am the Claude Code Proxy Bot.")
# We can also treat this as a message if we want it to trigger something
await self._on_telegram_message(update, context)
async def _on_telegram_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Handle incoming updates."""
if (
not update.message
or not update.message.text
or not update.effective_user
or not update.effective_chat
):
return
user_id = str(update.effective_user.id)
chat_id = str(update.effective_chat.id)
# Security check
if self.allowed_user_id:
if user_id != str(self.allowed_user_id).strip():
logger.warning(f"Unauthorized access attempt from {user_id}")
return
message_id = str(update.message.message_id)
reply_to = (
str(update.message.reply_to_message.message_id)
if update.message.reply_to_message
else None
)
text_preview = (update.message.text or "")[:80]
if len(update.message.text or "") > 80:
text_preview += "..."
logger.info(
"TELEGRAM_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
chat_id,
message_id,
reply_to,
text_preview,
)
if not self._message_handler:
return
incoming = IncomingMessage(
text=update.message.text,
chat_id=chat_id,
user_id=user_id,
message_id=message_id,
platform="telegram",
reply_to_message_id=reply_to,
raw_event=update,
)
try:
await self._message_handler(incoming)
except Exception as e:
logger.error(f"Error handling message: {e}")
try:
await self.send_message(
chat_id,
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(str(e)[:200])}",
reply_to=incoming.message_id,
parse_mode="MarkdownV2",
)
except Exception:
pass
__all__ = ["TelegramPlatform", "TELEGRAM_AVAILABLE"]

View file

@ -1,384 +1,14 @@
"""Telegram MarkdownV2 utilities.
Renders common Markdown into Telegram MarkdownV2 format.
Used by the message handler and Telegram platform adapter.
"""
import re
from typing import List, Optional
from markdown_it import MarkdownIt
MDV2_SPECIAL_CHARS = set("\\_*[]()~`>#+-=|{}.!")
MDV2_LINK_ESCAPE = set("\\)")
_MD = MarkdownIt("commonmark", {"html": False, "breaks": False})
_MD.enable("strikethrough")
_MD.enable("table")
_TABLE_SEP_RE = re.compile(r"^\s*\|?\s*:?-{3,}:?\s*(\|\s*:?-{3,}:?\s*)+\|?\s*$")
_FENCE_RE = re.compile(r"^\s*```")
def _is_gfm_table_header_line(line: str) -> bool:
"""Check if line is a GFM table header (pipe-delimited, not separator)."""
if "|" not in line:
return False
if _TABLE_SEP_RE.match(line):
return False
stripped = line.strip()
parts = [p.strip() for p in stripped.strip("|").split("|")]
parts = [p for p in parts if p != ""]
return len(parts) >= 2
def _normalize_gfm_tables(text: str) -> str:
"""
Many LLMs emit tables immediately after a paragraph line (no blank line).
Markdown-it will treat that as a softbreak within the paragraph, so the
table extension won't trigger. Insert a blank line before detected tables.
We only do this outside fenced code blocks.
"""
lines = text.splitlines()
if len(lines) < 2:
return text
out_lines: List[str] = []
in_fence = False
for idx, line in enumerate(lines):
if _FENCE_RE.match(line):
in_fence = not in_fence
out_lines.append(line)
continue
if (
not in_fence
and idx + 1 < len(lines)
and _is_gfm_table_header_line(line)
and _TABLE_SEP_RE.match(lines[idx + 1])
):
if out_lines and out_lines[-1].strip() != "":
m = re.match(r"^(\s*)", line)
indent = m.group(1) if m else ""
out_lines.append(indent)
out_lines.append(line)
return "\n".join(out_lines)
def escape_md_v2(text: str) -> str:
"""Escape text for Telegram MarkdownV2."""
return "".join(f"\\{ch}" if ch in MDV2_SPECIAL_CHARS else ch for ch in text)
def escape_md_v2_code(text: str) -> str:
"""Escape text for Telegram MarkdownV2 code spans/blocks."""
return text.replace("\\", "\\\\").replace("`", "\\`")
def escape_md_v2_link_url(text: str) -> str:
"""Escape URL for Telegram MarkdownV2 link destination."""
return "".join(f"\\{ch}" if ch in MDV2_LINK_ESCAPE else ch for ch in text)
def mdv2_bold(text: str) -> str:
"""Format text as bold in MarkdownV2."""
return f"*{escape_md_v2(text)}*"
def mdv2_code_inline(text: str) -> str:
"""Format text as inline code in MarkdownV2."""
return f"`{escape_md_v2_code(text)}`"
def format_status(emoji: str, label: str, suffix: Optional[str] = None) -> str:
"""Format a status message with emoji and optional suffix."""
base = f"{emoji} {mdv2_bold(label)}"
if suffix:
return f"{base} {escape_md_v2(suffix)}"
return base
def render_markdown_to_mdv2(text: str) -> str:
"""Render common Markdown into Telegram MarkdownV2."""
if not text:
return ""
text = _normalize_gfm_tables(text)
tokens = _MD.parse(text)
def render_inline_table_plain(children) -> str:
out: List[str] = []
for tok in children:
if tok.type == "text":
out.append(tok.content)
elif tok.type == "code_inline":
out.append(tok.content)
elif tok.type in {"softbreak", "hardbreak"}:
out.append(" ")
elif tok.type == "image":
if tok.content:
out.append(tok.content)
return "".join(out)
def render_inline_plain(children) -> str:
out: List[str] = []
for tok in children:
if tok.type == "text":
out.append(escape_md_v2(tok.content))
elif tok.type == "code_inline":
out.append(escape_md_v2(tok.content))
elif tok.type in {"softbreak", "hardbreak"}:
out.append("\n")
return "".join(out)
def render_inline(children) -> str:
out: List[str] = []
i = 0
while i < len(children):
tok = children[i]
t = tok.type
if t == "text":
out.append(escape_md_v2(tok.content))
elif t in {"softbreak", "hardbreak"}:
out.append("\n")
elif t == "em_open":
out.append("_")
elif t == "em_close":
out.append("_")
elif t == "strong_open":
out.append("*")
elif t == "strong_close":
out.append("*")
elif t == "s_open":
out.append("~")
elif t == "s_close":
out.append("~")
elif t == "code_inline":
out.append(f"`{escape_md_v2_code(tok.content)}`")
elif t == "link_open":
href = ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("href", "")
else:
for key, val in tok.attrs:
if key == "href":
href = val
break
inner_tokens = []
i += 1
while i < len(children) and children[i].type != "link_close":
inner_tokens.append(children[i])
i += 1
link_text = ""
for child in inner_tokens:
if child.type == "text":
link_text += child.content
elif child.type == "code_inline":
link_text += child.content
out.append(
f"[{escape_md_v2(link_text)}]({escape_md_v2_link_url(href)})"
)
elif t == "image":
href = ""
alt = tok.content or ""
if tok.attrs:
if isinstance(tok.attrs, dict):
href = tok.attrs.get("src", "")
else:
for key, val in tok.attrs:
if key == "src":
href = val
break
if alt:
out.append(f"{escape_md_v2(alt)} ({escape_md_v2_link_url(href)})")
else:
out.append(escape_md_v2_link_url(href))
else:
out.append(escape_md_v2(tok.content or ""))
i += 1
return "".join(out)
out: List[str] = []
list_stack: List[dict] = []
pending_prefix: Optional[str] = None
blockquote_level = 0
in_heading = False
def apply_blockquote(val: str) -> str:
if blockquote_level <= 0:
return val
prefix = "> " * blockquote_level
return prefix + val.replace("\n", "\n" + prefix)
i = 0
while i < len(tokens):
tok = tokens[i]
t = tok.type
if t == "paragraph_open":
pass
elif t == "paragraph_close":
out.append("\n")
elif t == "heading_open":
in_heading = True
elif t == "heading_close":
in_heading = False
out.append("\n")
elif t == "bullet_list_open":
list_stack.append({"type": "bullet", "index": 1})
elif t == "bullet_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "ordered_list_open":
start = 1
if tok.attrs:
if isinstance(tok.attrs, dict):
val = tok.attrs.get("start")
if val is not None:
try:
start = int(val)
except TypeError, ValueError:
start = 1
else:
for key, val in tok.attrs:
if key == "start":
try:
start = int(val)
except TypeError, ValueError:
start = 1
break
list_stack.append({"type": "ordered", "index": start})
elif t == "ordered_list_close":
if list_stack:
list_stack.pop()
out.append("\n")
elif t == "list_item_open":
if list_stack:
top = list_stack[-1]
if top["type"] == "bullet":
pending_prefix = "\\- "
else:
pending_prefix = f"{top['index']}\\."
top["index"] += 1
pending_prefix += " "
elif t == "list_item_close":
out.append("\n")
elif t == "blockquote_open":
blockquote_level += 1
elif t == "blockquote_close":
blockquote_level = max(0, blockquote_level - 1)
out.append("\n")
elif t == "table_open":
if pending_prefix:
out.append(apply_blockquote(pending_prefix.rstrip()))
out.append("\n")
pending_prefix = None
rows: List[List[str]] = []
row_is_header: List[bool] = []
j = i + 1
in_thead = False
in_row = False
current_row: List[str] = []
current_row_header = False
in_cell = False
cell_parts: List[str] = []
while j < len(tokens):
tt = tokens[j].type
if tt == "thead_open":
in_thead = True
elif tt == "thead_close":
in_thead = False
elif tt == "tr_open":
in_row = True
current_row = []
current_row_header = in_thead
elif tt in {"th_open", "td_open"}:
in_cell = True
cell_parts = []
elif tt == "inline" and in_cell:
cell_parts.append(
render_inline_table_plain(tokens[j].children or [])
)
elif tt in {"th_close", "td_close"} and in_cell:
cell = " ".join(cell_parts).strip()
current_row.append(cell)
in_cell = False
cell_parts = []
elif tt == "tr_close" and in_row:
rows.append(current_row)
row_is_header.append(bool(current_row_header))
in_row = False
elif tt == "table_close":
break
j += 1
if rows:
col_count = max((len(r) for r in rows), default=0)
norm_rows: List[List[str]] = []
for r in rows:
if len(r) < col_count:
r = r + [""] * (col_count - len(r))
norm_rows.append(r)
widths: List[int] = []
for c in range(col_count):
w = max((len(r[c]) for r in norm_rows), default=0)
widths.append(max(w, 3))
def fmt_row(r: List[str]) -> str:
cells = [r[c].ljust(widths[c]) for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
def fmt_sep() -> str:
cells = ["-" * widths[c] for c in range(col_count)]
return "| " + " | ".join(cells) + " |"
last_header_idx = -1
for idx, is_h in enumerate(row_is_header):
if is_h:
last_header_idx = idx
lines: List[str] = []
for idx, r in enumerate(norm_rows):
lines.append(fmt_row(r))
if idx == last_header_idx:
lines.append(fmt_sep())
table_text = "\n".join(lines).rstrip()
out.append(f"```\n{escape_md_v2_code(table_text)}\n```")
out.append("\n")
i = j + 1
continue
elif t in {"code_block", "fence"}:
code = escape_md_v2_code(tok.content.rstrip("\n"))
out.append(f"```\n{code}\n```")
out.append("\n")
elif t == "inline":
rendered = render_inline(tok.children or [])
if in_heading:
rendered = f"*{render_inline_plain(tok.children or [])}*"
if pending_prefix:
rendered = pending_prefix + rendered
pending_prefix = None
rendered = apply_blockquote(rendered)
out.append(rendered)
else:
if tok.content:
out.append(escape_md_v2(tok.content))
i += 1
return "".join(out).rstrip()
"""Backward-compatible re-export. Use messaging.rendering.telegram_markdown for new code."""
from .rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
escape_md_v2_link_url,
mdv2_bold,
mdv2_code_inline,
format_status,
render_markdown_to_mdv2,
)
__all__ = [
"escape_md_v2",

View file

@ -1,441 +1,5 @@
"""Tree data structures for message queue.
"""Backward-compatible re-export. Use messaging.trees.data for new code."""
Contains MessageState, MessageNode, and MessageTree classes.
"""
from .trees.data import MessageTree, MessageNode, MessageState
import asyncio
from collections import deque
from contextlib import asynccontextmanager
from enum import Enum
from datetime import datetime, timezone
from typing import Dict, Optional, List, Any, cast
from dataclasses import dataclass, field
from .models import IncomingMessage
from loguru import logger
class MessageState(Enum):
"""State of a message node in the tree."""
PENDING = "pending" # Queued, waiting to be processed
IN_PROGRESS = "in_progress" # Currently being processed by Claude
COMPLETED = "completed" # Processing finished successfully
ERROR = "error" # Processing failed
@dataclass
class MessageNode:
"""
A node in the message tree.
Each node represents a single message and tracks:
- Its relationship to parent/children
- Its processing state
- Claude session information
"""
node_id: str # Unique ID (typically message_id)
incoming: IncomingMessage # The original message
status_message_id: str # Bot's status message ID
state: MessageState = MessageState.PENDING
parent_id: Optional[str] = None # Parent node ID (None for root)
session_id: Optional[str] = None # Claude session ID (forked from parent)
children_ids: List[str] = field(default_factory=list)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
error_message: Optional[str] = None
context: Any = None # Additional context if needed
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"node_id": self.node_id,
"incoming": {
"text": self.incoming.text,
"chat_id": self.incoming.chat_id,
"user_id": self.incoming.user_id,
"message_id": self.incoming.message_id,
"platform": self.incoming.platform,
"reply_to_message_id": self.incoming.reply_to_message_id,
"username": self.incoming.username,
},
"status_message_id": self.status_message_id,
"state": self.state.value,
"parent_id": self.parent_id,
"session_id": self.session_id,
"children_ids": self.children_ids,
"created_at": self.created_at.isoformat(),
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
"error_message": self.error_message,
}
@classmethod
def from_dict(cls, data: dict) -> "MessageNode":
"""Create from dictionary (JSON deserialization)."""
incoming_data = data["incoming"]
incoming = IncomingMessage(
text=incoming_data["text"],
chat_id=incoming_data["chat_id"],
user_id=incoming_data["user_id"],
message_id=incoming_data["message_id"],
platform=incoming_data["platform"],
reply_to_message_id=incoming_data.get("reply_to_message_id"),
username=incoming_data.get("username"),
)
return cls(
node_id=data["node_id"],
incoming=incoming,
status_message_id=data["status_message_id"],
state=MessageState(data["state"]),
parent_id=data.get("parent_id"),
session_id=data.get("session_id"),
children_ids=data.get("children_ids", []),
created_at=datetime.fromisoformat(data["created_at"]),
completed_at=datetime.fromisoformat(data["completed_at"])
if data.get("completed_at")
else None,
error_message=data.get("error_message"),
)
class MessageTree:
"""
A tree of message nodes with queue functionality.
Provides:
- O(1) node lookup via hashmap
- Per-tree message queue
- Thread-safe operations via asyncio.Lock
"""
def __init__(self, root_node: MessageNode):
"""
Initialize tree with a root node.
Args:
root_node: The root message node
"""
self.root_id = root_node.node_id
self._nodes: Dict[str, MessageNode] = {root_node.node_id: root_node}
self._status_to_node: Dict[str, str] = {
root_node.status_message_id: root_node.node_id
}
self._queue: asyncio.Queue[str] = asyncio.Queue()
self._lock = asyncio.Lock()
self._is_processing = False
self._current_node_id: Optional[str] = None
self._current_task: Optional[asyncio.Task] = None
logger.debug(f"Created MessageTree with root {self.root_id}")
def set_current_task(self, task: Optional[asyncio.Task]) -> None:
"""Set the current processing task. Caller must hold lock."""
self._current_task = task
@property
def is_processing(self) -> bool:
"""Check if tree is currently processing a message."""
return self._is_processing
async def add_node(
self,
node_id: str,
incoming: IncomingMessage,
status_message_id: str,
parent_id: str,
) -> MessageNode:
"""
Add a child node to the tree.
Args:
node_id: Unique ID for the new node
incoming: The incoming message
status_message_id: Bot's status message ID
parent_id: Parent node ID
Returns:
The created MessageNode
"""
async with self._lock:
if parent_id not in self._nodes:
raise ValueError(f"Parent node {parent_id} not found in tree")
node = MessageNode(
node_id=node_id,
incoming=incoming,
status_message_id=status_message_id,
parent_id=parent_id,
state=MessageState.PENDING,
)
self._nodes[node_id] = node
self._status_to_node[status_message_id] = node_id
self._nodes[parent_id].children_ids.append(node_id)
logger.debug(f"Added node {node_id} as child of {parent_id}")
return node
def get_node(self, node_id: str) -> Optional[MessageNode]:
"""Get a node by ID (O(1) lookup)."""
return self._nodes.get(node_id)
def get_root(self) -> MessageNode:
"""Get the root node."""
return self._nodes[self.root_id]
def get_children(self, node_id: str) -> List[MessageNode]:
"""Get all child nodes of a given node."""
node = self._nodes.get(node_id)
if not node:
return []
return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]
def get_parent(self, node_id: str) -> Optional[MessageNode]:
"""Get the parent node."""
node = self._nodes.get(node_id)
if not node or not node.parent_id:
return None
return self._nodes.get(node.parent_id)
def get_parent_session_id(self, node_id: str) -> Optional[str]:
"""
Get the parent's session ID for forking.
Returns None for root nodes.
"""
parent = self.get_parent(node_id)
return parent.session_id if parent else None
async def update_state(
self,
node_id: str,
state: MessageState,
session_id: Optional[str] = None,
error_message: Optional[str] = None,
) -> None:
"""Update a node's state."""
async with self._lock:
node = self._nodes.get(node_id)
if not node:
logger.warning(f"Node {node_id} not found for state update")
return
node.state = state
if session_id:
node.session_id = session_id
if error_message:
node.error_message = error_message
if state in (MessageState.COMPLETED, MessageState.ERROR):
node.completed_at = datetime.now(timezone.utc)
logger.debug(f"Node {node_id} state -> {state.value}")
async def enqueue(self, node_id: str) -> int:
"""
Add a node to the processing queue.
Returns:
Queue position (1-indexed)
"""
async with self._lock:
await self._queue.put(node_id)
position = self._queue.qsize()
logger.debug(f"Enqueued node {node_id}, position {position}")
return position
async def dequeue(self) -> Optional[str]:
"""
Get the next node ID from the queue.
Returns None if queue is empty.
"""
try:
return self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
async def get_queue_snapshot(self) -> List[str]:
"""
Get a snapshot of the current queue order.
Returns:
List of node IDs in FIFO order.
"""
async with self._lock:
# Read internal deque directly to avoid mutating queue state.
# Drain/put approach would inflate _unfinished_tasks without task_done().
queue_deque = cast(deque, getattr(self._queue, "_queue"))
return list(queue_deque)
def get_queue_size(self) -> int:
"""Get number of messages waiting in queue."""
return self._queue.qsize()
def remove_from_queue(self, node_id: str) -> bool:
"""
Remove node_id from the internal queue if present.
Caller must hold the tree lock (e.g. via with_lock).
Returns True if node was removed, False if not in queue.
Note: asyncio.Queue has no built-in remove; we filter via the internal
deque. O(n) in queue size; acceptable for typical tree queue sizes.
"""
queue_deque = cast(deque, getattr(self._queue, "_queue"))
if node_id not in queue_deque:
return False
object.__setattr__(
self._queue, "_queue", deque(x for x in queue_deque if x != node_id)
)
return True
@asynccontextmanager
async def with_lock(self):
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
async with self._lock:
yield
def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None:
"""Set processing state. Caller must hold lock for consistency with queue operations."""
self._is_processing = is_processing
self._current_node_id = node_id if is_processing else None
def clear_current_node(self) -> None:
"""Clear the currently processing node ID. Caller must hold lock."""
self._current_node_id = None
def is_current_node(self, node_id: str) -> bool:
"""Check if node_id is the currently processing node."""
return self._current_node_id == node_id
def put_queue_unlocked(self, node_id: str) -> None:
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
self._queue.put_nowait(node_id)
def cancel_current_task(self) -> bool:
"""Cancel the currently running task. Returns True if a task was cancelled."""
if self._current_task and not self._current_task.done():
self._current_task.cancel()
return True
return False
def drain_queue_and_mark_cancelled(
self, error_message: str = "Cancelled by user"
) -> List["MessageNode"]:
"""
Drain the queue, mark each node as ERROR, and return affected nodes.
Does not acquire lock; caller must ensure no concurrent queue access.
"""
nodes: List[MessageNode] = []
while True:
try:
node_id = self._queue.get_nowait()
except asyncio.QueueEmpty:
break
node = self._nodes.get(node_id)
if node:
node.state = MessageState.ERROR
node.error_message = error_message
nodes.append(node)
return nodes
def reset_processing_state(self) -> None:
"""Reset processing flags after cancel/cleanup."""
self._is_processing = False
self._current_node_id = None
@property
def current_node_id(self) -> Optional[str]:
"""Get the ID of the node currently being processed."""
return self._current_node_id
def to_dict(self) -> dict:
"""Serialize tree to dictionary."""
return {
"root_id": self.root_id,
"nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
}
@classmethod
def from_dict(cls, data: dict) -> "MessageTree":
"""Deserialize tree from dictionary."""
root_id = data["root_id"]
nodes_data = data["nodes"]
# Create root node first
root_node = MessageNode.from_dict(nodes_data[root_id])
tree = cls(root_node)
# Add remaining nodes and build status->node index
for node_id, node_data in nodes_data.items():
if node_id != root_id:
node = MessageNode.from_dict(node_data)
tree._nodes[node_id] = node
tree._status_to_node[node.status_message_id] = node_id
return tree
def all_nodes(self) -> List[MessageNode]:
"""Get all nodes in the tree."""
return list(self._nodes.values())
def has_node(self, node_id: str) -> bool:
"""Check if a node exists in this tree."""
return node_id in self._nodes
def find_node_by_status_message(self, status_msg_id: str) -> Optional[MessageNode]:
"""Find the node that has this status message ID (O(1) lookup)."""
node_id = self._status_to_node.get(status_msg_id)
return self._nodes.get(node_id) if node_id else None
def get_descendants(self, node_id: str) -> List[str]:
"""
Get node_id and all descendant IDs (subtree).
Returns:
List of node IDs including the given node.
"""
if node_id not in self._nodes:
return []
result: List[str] = []
stack = [node_id]
while stack:
nid = stack.pop()
result.append(nid)
node = self._nodes.get(nid)
if node:
stack.extend(node.children_ids)
return result
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
"""
Remove a subtree (branch_root and all descendants) from the tree.
Updates parent's children_ids. Caller must hold lock for consistency.
Does not acquire lock internally.
Returns:
List of removed nodes.
"""
if branch_root_id not in self._nodes:
return []
parent = self.get_parent(branch_root_id)
removed = []
for nid in self.get_descendants(branch_root_id):
node = self._nodes.get(nid)
if node:
removed.append(node)
del self._nodes[nid]
del self._status_to_node[node.status_message_id]
if parent and branch_root_id in parent.children_ids:
parent.children_ids = [
c for c in parent.children_ids if c != branch_root_id
]
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
return removed
__all__ = ["MessageTree", "MessageNode", "MessageState"]

View file

@ -1,162 +1,5 @@
"""Async queue processor for message trees.
"""Backward-compatible re-export. Use messaging.trees.processor for new code."""
Handles the async processing lifecycle of tree nodes.
"""
from .trees.processor import TreeQueueProcessor
import asyncio
from typing import Callable, Awaitable, Optional
from .tree_data import MessageTree, MessageNode, MessageState
from loguru import logger
class TreeQueueProcessor:
"""
Handles async queue processing for a single tree.
Separates the async processing logic from the data management.
"""
def __init__(
self,
queue_update_callback: Optional[
Callable[[MessageTree], Awaitable[None]]
] = None,
node_started_callback: Optional[
Callable[[MessageTree, str], Awaitable[None]]
] = None,
):
self._queue_update_callback = queue_update_callback
self._node_started_callback = node_started_callback
def set_queue_update_callback(
self,
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
) -> None:
"""Update the callback used to refresh queue positions."""
self._queue_update_callback = queue_update_callback
def set_node_started_callback(
self,
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
) -> None:
"""Update the callback used when a queued node starts processing."""
self._node_started_callback = node_started_callback
async def _notify_queue_updated(self, tree: MessageTree) -> None:
"""Invoke queue update callback if set."""
if not self._queue_update_callback:
return
try:
await self._queue_update_callback(tree)
except Exception as e:
logger.warning(f"Queue update callback failed: {e}")
async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
"""Invoke node started callback if set."""
if not self._node_started_callback:
return
try:
await self._node_started_callback(tree, node_id)
except Exception as e:
logger.warning(f"Node started callback failed: {e}")
async def process_node(
self,
tree: MessageTree,
node: MessageNode,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process a single node and then check the queue."""
# Skip if already in terminal state (e.g. from error propagation)
if node.state.value == MessageState.ERROR.value:
logger.info(
f"Skipping node {node.node_id} as it is already in state {node.state}"
)
# Still need to check for next messages
await self._process_next(tree, processor)
return
try:
await processor(node.node_id, node)
except asyncio.CancelledError:
logger.info(f"Task for node {node.node_id} was cancelled")
raise
except Exception as e:
logger.error(f"Error processing node {node.node_id}: {e}")
await tree.update_state(
node.node_id, MessageState.ERROR, error_message=str(e)
)
finally:
tree.clear_current_node()
# Check if there are more messages in the queue
await self._process_next(tree, processor)
async def _process_next(
self,
tree: MessageTree,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process the next message in queue, if any."""
next_node_id = None
node = None
async with tree.with_lock():
next_node_id = await tree.dequeue()
if not next_node_id:
tree.set_processing_state(None, False)
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
return
tree.set_processing_state(next_node_id, True)
logger.info(f"Processing next queued node {next_node_id}")
# Process next node (outside lock)
node = tree.get_node(next_node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
# Notify that this node has started processing and refresh queue positions.
if next_node_id:
await self._notify_node_started(tree, next_node_id)
await self._notify_queue_updated(tree)
async def enqueue_and_start(
self,
tree: MessageTree,
node_id: str,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> bool:
"""
Enqueue a node or start processing immediately.
Args:
tree: The message tree
node_id: Node to process
processor: Async function to process the node
Returns:
True if queued, False if processing immediately
"""
async with tree.with_lock():
if tree.is_processing:
tree.put_queue_unlocked(node_id)
queue_size = tree.get_queue_size()
logger.info(f"Queued node {node_id}, position {queue_size}")
return True
else:
tree.set_processing_state(node_id, True)
# Process outside the lock
node = tree.get_node(node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
return False
def cancel_current(self, tree: MessageTree) -> bool:
"""Cancel the currently running task in a tree."""
return tree.cancel_current_task()
__all__ = ["TreeQueueProcessor"]

View file

@ -1,461 +1,10 @@
"""Tree-Based Message Queue Manager - Refactored.
"""Backward-compatible re-export. Use messaging.trees.queue_manager for new code."""
Coordinates data access, async processing, and error handling.
Uses TreeRepository for data, TreeQueueProcessor for async logic.
"""
from .trees.queue_manager import (
TreeQueueManager,
MessageTree,
MessageNode,
MessageState,
)
import asyncio
from datetime import datetime, timezone
from typing import Callable, Awaitable, List, Optional
from .models import IncomingMessage
from .tree_data import MessageState, MessageNode, MessageTree
from .tree_repository import TreeRepository
from .tree_processor import TreeQueueProcessor
from loguru import logger
# Backward compatibility: re-export moved classes
__all__ = [
"TreeQueueManager",
"MessageState",
"MessageNode",
"MessageTree",
]
class TreeQueueManager:
"""
Manages multiple message trees. Facade that coordinates components.
Each new conversation creates a new tree.
Replies to existing messages add nodes to existing trees.
Components:
- TreeRepository: Data access layer
- TreeQueueProcessor: Async queue processing
"""
def __init__(
self,
queue_update_callback: Optional[
Callable[[MessageTree], Awaitable[None]]
] = None,
node_started_callback: Optional[
Callable[[MessageTree, str], Awaitable[None]]
] = None,
):
self._repository = TreeRepository()
self._processor = TreeQueueProcessor(
queue_update_callback=queue_update_callback,
node_started_callback=node_started_callback,
)
self._lock = asyncio.Lock()
logger.info("TreeQueueManager initialized")
async def create_tree(
self,
node_id: str,
incoming: IncomingMessage,
status_message_id: str,
) -> MessageTree:
"""
Create a new tree with a root node.
Args:
node_id: ID for the root node
incoming: The incoming message
status_message_id: Bot's status message ID
Returns:
The created MessageTree
"""
async with self._lock:
root_node = MessageNode(
node_id=node_id,
incoming=incoming,
status_message_id=status_message_id,
state=MessageState.PENDING,
)
tree = MessageTree(root_node)
self._repository.add_tree(node_id, tree)
logger.info(f"Created new tree with root {node_id}")
return tree
async def add_to_tree(
self,
parent_node_id: str,
node_id: str,
incoming: IncomingMessage,
status_message_id: str,
) -> tuple[MessageTree, MessageNode]:
"""
Add a reply as a child node to an existing tree.
Args:
parent_node_id: ID of the parent message
node_id: ID for the new node
incoming: The incoming reply message
status_message_id: Bot's status message ID
Returns:
Tuple of (tree, new_node)
"""
async with self._lock:
if not self._repository.has_node(parent_node_id):
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
tree = self._repository.get_tree_for_node(parent_node_id)
if not tree:
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
# Add node (tree has its own lock) - outside manager lock to avoid deadlock
node = await tree.add_node(
node_id=node_id,
incoming=incoming,
status_message_id=status_message_id,
parent_id=parent_node_id,
)
async with self._lock:
self._repository.register_node(node_id, tree.root_id)
logger.info(f"Added node {node_id} to tree {tree.root_id}")
return tree, node
def get_tree(self, root_id: str) -> Optional[MessageTree]:
"""Get a tree by its root ID."""
return self._repository.get_tree(root_id)
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
"""Get the tree containing a given node."""
return self._repository.get_tree_for_node(node_id)
def get_node(self, node_id: str) -> Optional[MessageNode]:
"""Get a node from any tree."""
return self._repository.get_node(node_id)
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
"""Resolve a message ID to the actual parent node ID."""
return self._repository.resolve_parent_node_id(msg_id)
def is_tree_busy(self, root_id: str) -> bool:
"""Check if a tree is currently processing."""
return self._repository.is_tree_busy(root_id)
def is_node_tree_busy(self, node_id: str) -> bool:
"""Check if the tree containing a node is busy."""
return self._repository.is_node_tree_busy(node_id)
async def enqueue(
self,
node_id: str,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> bool:
"""
Enqueue a node for processing.
If the tree is not busy, processing starts immediately.
If busy, the message is queued.
Args:
node_id: Node to process
processor: Async function to process the node
Returns:
True if queued, False if processing immediately
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
logger.error(f"No tree found for node {node_id}")
return False
return await self._processor.enqueue_and_start(tree, node_id, processor)
def get_queue_size(self, node_id: str) -> int:
"""Get queue size for the tree containing a node."""
return self._repository.get_queue_size(node_id)
def get_pending_children(self, node_id: str) -> List[MessageNode]:
"""Get all pending child nodes (recursively) of a given node."""
return self._repository.get_pending_children(node_id)
async def mark_node_error(
self,
node_id: str,
error_message: str,
propagate_to_children: bool = True,
) -> List[MessageNode]:
"""
Mark a node as ERROR and optionally propagate to pending children.
Args:
node_id: The node to mark as error
error_message: Error description
propagate_to_children: If True, also mark pending children as error
Returns:
List of all nodes marked as error (including children)
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
return []
affected = []
node = tree.get_node(node_id)
if node:
await tree.update_state(
node_id, MessageState.ERROR, error_message=error_message
)
affected.append(node)
if propagate_to_children:
pending_children = self._repository.get_pending_children(node_id)
for child in pending_children:
await tree.update_state(
child.node_id,
MessageState.ERROR,
error_message=f"Parent failed: {error_message}",
)
affected.append(child)
return affected
def cancel_tree(self, root_id: str) -> List[MessageNode]:
"""
Cancel all queued and in-progress messages in a tree.
Updates node states to ERROR and returns list of affected nodes
that were actually active or in the current processing queue.
"""
tree = self._repository.get_tree(root_id)
if not tree:
return []
cancelled_nodes = []
# 1. Cancel running task
if tree.cancel_current_task():
current_id = tree.current_node_id
if current_id:
node = tree.get_node(current_id)
if node and node.state not in (
MessageState.COMPLETED,
MessageState.ERROR,
):
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
cancelled_nodes.append(node)
# 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
cleanup_count = 0
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node.node_id not in cancelled_ids
):
node.state = MessageState.ERROR
node.error_message = "Stale task cleaned up"
cleanup_count += 1
tree.reset_processing_state()
if cancelled_nodes:
logger.info(
f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
)
if cleanup_count:
logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")
return cancelled_nodes
async def cancel_node(self, node_id: str) -> List[MessageNode]:
"""
Cancel a single node (queued or in-progress) without affecting other nodes.
- If the node is currently running, cancels the current asyncio task.
- If the node is queued, removes it from the queue.
- Marks the node as ERROR with "Cancelled by user".
Returns:
List containing the cancelled node if it was cancellable, else empty list.
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
return []
async with tree.with_lock():
node = tree.get_node(node_id)
if not node:
return []
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
return []
if tree.is_current_node(node_id):
self._processor.cancel_current(tree)
try:
tree.remove_from_queue(node_id)
except Exception:
logger.debug(
"Failed to remove node from queue; will rely on state=ERROR"
)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
return [node]
async def cancel_all(self) -> List[MessageNode]:
"""Cancel all messages in all trees (async wrapper)."""
async with self._lock:
return self.cancel_all_sync()
def cancel_all_sync(self) -> List[MessageNode]:
"""
Cancel all messages in all trees (synchronous/locked version).
NOTE: Must be called with self._lock held.
"""
all_cancelled = []
for root_id in list(self._repository.tree_ids()):
all_cancelled.extend(self.cancel_tree(root_id))
return all_cancelled
def cleanup_stale_nodes(self) -> int:
"""
Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.
Used on startup to reconcile restored state.
"""
count = 0
for tree in self._repository.all_trees():
for node in tree.all_nodes():
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
node.state = MessageState.ERROR
node.error_message = "Lost during server restart"
count += 1
if count:
logger.info(f"Cleaned up {count} stale nodes during startup")
return count
def get_tree_count(self) -> int:
"""Get the number of active message trees."""
return self._repository.tree_count()
def set_queue_update_callback(
self,
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
) -> None:
"""Set callback for queue position updates."""
self._processor.set_queue_update_callback(queue_update_callback)
def set_node_started_callback(
self,
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
) -> None:
"""Set callback for when a queued node starts processing."""
self._processor.set_node_started_callback(node_started_callback)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree (for external mapping)."""
self._repository.register_node(node_id, root_id)
async def cancel_branch(self, branch_root_id: str) -> List[MessageNode]:
"""
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
return []
branch_ids = set(tree.get_descendants(branch_root_id))
cancelled: List[MessageNode] = []
async with tree.with_lock():
for nid in branch_ids:
node = tree.get_node(nid)
if not node or node.state in (
MessageState.COMPLETED,
MessageState.ERROR,
):
continue
if tree.is_current_node(nid):
self._processor.cancel_current(tree)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
cancelled.append(node)
else:
tree.remove_from_queue(nid)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
cancelled.append(node)
if cancelled:
logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
return cancelled
async def remove_branch(
self, branch_root_id: str
) -> tuple[List[MessageNode], str, bool]:
"""
Remove a branch (subtree) from the tree.
If branch_root is the tree root, removes the entire tree.
Returns:
(removed_nodes, root_id, removed_entire_tree)
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
return ([], "", False)
root_id = tree.root_id
if branch_root_id == root_id:
cancelled = self.cancel_tree(root_id)
removed_tree = self._repository.remove_tree(root_id)
if removed_tree:
return (removed_tree.all_nodes(), root_id, True)
return (cancelled, root_id, True)
async with tree.with_lock():
removed = tree.remove_branch(branch_root_id)
self._repository.unregister_nodes([n.node_id for n in removed])
return (removed, root_id, False)
def to_dict(self) -> dict:
"""Serialize all trees."""
return self._repository.to_dict()
@classmethod
def from_dict(
cls,
data: dict,
queue_update_callback: Optional[
Callable[[MessageTree], Awaitable[None]]
] = None,
node_started_callback: Optional[
Callable[[MessageTree, str], Awaitable[None]]
] = None,
) -> "TreeQueueManager":
"""Deserialize from dictionary."""
manager = cls(
queue_update_callback=queue_update_callback,
node_started_callback=node_started_callback,
)
manager._repository = TreeRepository.from_dict(data)
return manager
__all__ = ["TreeQueueManager", "MessageTree", "MessageNode", "MessageState"]

View file

@ -1,168 +1,5 @@
"""Repository for message tree data access.
"""Backward-compatible re-export. Use messaging.trees.repository for new code."""
Provides data access layer for managing trees and node mappings.
"""
from .trees.repository import TreeRepository
from typing import Dict, Optional, List
from loguru import logger
from .tree_data import MessageTree, MessageNode, MessageState
class TreeRepository:
"""
Repository for message tree data access.
Manages the storage and lookup of trees and node-to-tree mappings.
"""
def __init__(self):
self._trees: Dict[str, MessageTree] = {} # root_id -> tree
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
def get_tree(self, root_id: str) -> Optional[MessageTree]:
"""Get a tree by its root ID."""
return self._trees.get(root_id)
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
"""Get the tree containing a given node."""
root_id = self._node_to_tree.get(node_id)
if not root_id:
return None
return self._trees.get(root_id)
def get_node(self, node_id: str) -> Optional[MessageNode]:
"""Get a node from any tree."""
tree = self.get_tree_for_node(node_id)
return tree.get_node(node_id) if tree else None
def add_tree(self, root_id: str, tree: MessageTree) -> None:
"""Add a new tree to the repository."""
self._trees[root_id] = tree
self._node_to_tree[root_id] = root_id
logger.debug("TREE_REPO: add_tree root_id=%s", root_id)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree."""
self._node_to_tree[node_id] = root_id
logger.debug("TREE_REPO: register_node node_id=%s root_id=%s", node_id, root_id)
def has_node(self, node_id: str) -> bool:
"""Check if a node is registered in any tree."""
return node_id in self._node_to_tree
def tree_count(self) -> int:
"""Get the number of trees in the repository."""
return len(self._trees)
def is_tree_busy(self, root_id: str) -> bool:
"""Check if a tree is currently processing."""
tree = self._trees.get(root_id)
return tree.is_processing if tree else False
def is_node_tree_busy(self, node_id: str) -> bool:
"""Check if the tree containing a node is busy."""
tree = self.get_tree_for_node(node_id)
return tree.is_processing if tree else False
def get_queue_size(self, node_id: str) -> int:
"""Get queue size for the tree containing a node."""
tree = self.get_tree_for_node(node_id)
return tree.get_queue_size() if tree else 0
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
"""
Resolve a message ID to the actual parent node ID.
Handles the case where msg_id is a status message ID
(which maps to the tree but isn't an actual node).
Returns:
The node_id to use as parent, or None if not found
"""
tree = self.get_tree_for_node(msg_id)
if not tree:
return None
# Check if msg_id is an actual node
if tree.has_node(msg_id):
return msg_id
# Otherwise, it might be a status message - find the owning node
node = tree.find_node_by_status_message(msg_id)
if node:
return node.node_id
return None
def get_pending_children(self, node_id: str) -> List[MessageNode]:
"""
Get all pending child nodes (recursively) of a given node.
Used for error propagation - when a node fails, its pending
children should also be marked as failed.
"""
tree = self.get_tree_for_node(node_id)
if not tree:
return []
pending = []
node = tree.get_node(node_id)
if not node:
return []
for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
# Recursively get children of pending children
pending.extend(self.get_pending_children(child_id))
return pending
def all_trees(self) -> List[MessageTree]:
"""Get all trees in the repository."""
return list(self._trees.values())
def tree_ids(self) -> List[str]:
"""Get all tree root IDs."""
return list(self._trees.keys())
def unregister_nodes(self, node_ids: List[str]) -> None:
"""Remove node IDs from the node-to-tree mapping."""
for nid in node_ids:
self._node_to_tree.pop(nid, None)
def remove_tree(self, root_id: str) -> Optional[MessageTree]:
"""
Remove a tree and all its node mappings from the repository.
Returns:
The removed tree, or None if not found.
"""
tree = self._trees.pop(root_id, None)
if not tree:
return None
for node in tree.all_nodes():
self._node_to_tree.pop(node.node_id, None)
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
return tree
def to_dict(self) -> dict:
"""Serialize all trees."""
return {
"trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
"node_to_tree": self._node_to_tree.copy(),
}
@classmethod
def from_dict(cls, data: dict) -> "TreeRepository":
"""Deserialize from dictionary."""
from .tree_data import MessageTree
repo = cls()
for root_id, tree_data in data.get("trees", {}).items():
repo._trees[root_id] = MessageTree.from_dict(tree_data)
repo._node_to_tree = data.get("node_to_tree", {})
return repo
__all__ = ["TreeRepository"]

View file

@ -0,0 +1,11 @@
"""Message tree data structures and queue management."""
from .data import MessageTree, MessageNode, MessageState
from .queue_manager import TreeQueueManager
__all__ = [
"TreeQueueManager",
"MessageTree",
"MessageNode",
"MessageState",
]

441
messaging/trees/data.py Normal file
View file

@ -0,0 +1,441 @@
"""Tree data structures for message queue.
Contains MessageState, MessageNode, and MessageTree classes.
"""
import asyncio
from collections import deque
from contextlib import asynccontextmanager
from enum import Enum
from datetime import datetime, timezone
from typing import Dict, Optional, List, Any, cast
from dataclasses import dataclass, field
from ..models import IncomingMessage
from loguru import logger
class MessageState(Enum):
"""State of a message node in the tree."""
PENDING = "pending" # Queued, waiting to be processed
IN_PROGRESS = "in_progress" # Currently being processed by Claude
COMPLETED = "completed" # Processing finished successfully
ERROR = "error" # Processing failed
@dataclass
class MessageNode:
"""
A node in the message tree.
Each node represents a single message and tracks:
- Its relationship to parent/children
- Its processing state
- Claude session information
"""
node_id: str # Unique ID (typically message_id)
incoming: IncomingMessage # The original message
status_message_id: str # Bot's status message ID
state: MessageState = MessageState.PENDING
parent_id: Optional[str] = None # Parent node ID (None for root)
session_id: Optional[str] = None # Claude session ID (forked from parent)
children_ids: List[str] = field(default_factory=list)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
error_message: Optional[str] = None
context: Any = None # Additional context if needed
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"node_id": self.node_id,
"incoming": {
"text": self.incoming.text,
"chat_id": self.incoming.chat_id,
"user_id": self.incoming.user_id,
"message_id": self.incoming.message_id,
"platform": self.incoming.platform,
"reply_to_message_id": self.incoming.reply_to_message_id,
"username": self.incoming.username,
},
"status_message_id": self.status_message_id,
"state": self.state.value,
"parent_id": self.parent_id,
"session_id": self.session_id,
"children_ids": self.children_ids,
"created_at": self.created_at.isoformat(),
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
"error_message": self.error_message,
}
@classmethod
def from_dict(cls, data: dict) -> "MessageNode":
"""Create from dictionary (JSON deserialization)."""
incoming_data = data["incoming"]
incoming = IncomingMessage(
text=incoming_data["text"],
chat_id=incoming_data["chat_id"],
user_id=incoming_data["user_id"],
message_id=incoming_data["message_id"],
platform=incoming_data["platform"],
reply_to_message_id=incoming_data.get("reply_to_message_id"),
username=incoming_data.get("username"),
)
return cls(
node_id=data["node_id"],
incoming=incoming,
status_message_id=data["status_message_id"],
state=MessageState(data["state"]),
parent_id=data.get("parent_id"),
session_id=data.get("session_id"),
children_ids=data.get("children_ids", []),
created_at=datetime.fromisoformat(data["created_at"]),
completed_at=datetime.fromisoformat(data["completed_at"])
if data.get("completed_at")
else None,
error_message=data.get("error_message"),
)
class MessageTree:
"""
A tree of message nodes with queue functionality.
Provides:
- O(1) node lookup via hashmap
- Per-tree message queue
- Thread-safe operations via asyncio.Lock
"""
def __init__(self, root_node: MessageNode):
"""
Initialize tree with a root node.
Args:
root_node: The root message node
"""
self.root_id = root_node.node_id
self._nodes: Dict[str, MessageNode] = {root_node.node_id: root_node}
self._status_to_node: Dict[str, str] = {
root_node.status_message_id: root_node.node_id
}
self._queue: asyncio.Queue[str] = asyncio.Queue()
self._lock = asyncio.Lock()
self._is_processing = False
self._current_node_id: Optional[str] = None
self._current_task: Optional[asyncio.Task] = None
logger.debug(f"Created MessageTree with root {self.root_id}")
def set_current_task(self, task: Optional[asyncio.Task]) -> None:
"""Set the current processing task. Caller must hold lock."""
self._current_task = task
@property
def is_processing(self) -> bool:
"""Check if tree is currently processing a message."""
return self._is_processing
async def add_node(
self,
node_id: str,
incoming: IncomingMessage,
status_message_id: str,
parent_id: str,
) -> MessageNode:
"""
Add a child node to the tree.
Args:
node_id: Unique ID for the new node
incoming: The incoming message
status_message_id: Bot's status message ID
parent_id: Parent node ID
Returns:
The created MessageNode
"""
async with self._lock:
if parent_id not in self._nodes:
raise ValueError(f"Parent node {parent_id} not found in tree")
node = MessageNode(
node_id=node_id,
incoming=incoming,
status_message_id=status_message_id,
parent_id=parent_id,
state=MessageState.PENDING,
)
self._nodes[node_id] = node
self._status_to_node[status_message_id] = node_id
self._nodes[parent_id].children_ids.append(node_id)
logger.debug(f"Added node {node_id} as child of {parent_id}")
return node
def get_node(self, node_id: str) -> Optional[MessageNode]:
"""Get a node by ID (O(1) lookup)."""
return self._nodes.get(node_id)
def get_root(self) -> MessageNode:
"""Get the root node."""
return self._nodes[self.root_id]
def get_children(self, node_id: str) -> List[MessageNode]:
"""Get all child nodes of a given node."""
node = self._nodes.get(node_id)
if not node:
return []
return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]
def get_parent(self, node_id: str) -> Optional[MessageNode]:
"""Get the parent node."""
node = self._nodes.get(node_id)
if not node or not node.parent_id:
return None
return self._nodes.get(node.parent_id)
def get_parent_session_id(self, node_id: str) -> Optional[str]:
"""
Get the parent's session ID for forking.
Returns None for root nodes.
"""
parent = self.get_parent(node_id)
return parent.session_id if parent else None
async def update_state(
self,
node_id: str,
state: MessageState,
session_id: Optional[str] = None,
error_message: Optional[str] = None,
) -> None:
"""Update a node's state."""
async with self._lock:
node = self._nodes.get(node_id)
if not node:
logger.warning(f"Node {node_id} not found for state update")
return
node.state = state
if session_id:
node.session_id = session_id
if error_message:
node.error_message = error_message
if state in (MessageState.COMPLETED, MessageState.ERROR):
node.completed_at = datetime.now(timezone.utc)
logger.debug(f"Node {node_id} state -> {state.value}")
async def enqueue(self, node_id: str) -> int:
"""
Add a node to the processing queue.
Returns:
Queue position (1-indexed)
"""
async with self._lock:
await self._queue.put(node_id)
position = self._queue.qsize()
logger.debug(f"Enqueued node {node_id}, position {position}")
return position
async def dequeue(self) -> Optional[str]:
"""
Get the next node ID from the queue.
Returns None if queue is empty.
"""
try:
return self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
async def get_queue_snapshot(self) -> List[str]:
"""
Get a snapshot of the current queue order.
Returns:
List of node IDs in FIFO order.
"""
async with self._lock:
# Read internal deque directly to avoid mutating queue state.
# Drain/put approach would inflate _unfinished_tasks without task_done().
queue_deque = cast(deque, getattr(self._queue, "_queue"))
return list(queue_deque)
def get_queue_size(self) -> int:
"""Get number of messages waiting in queue."""
return self._queue.qsize()
def remove_from_queue(self, node_id: str) -> bool:
"""
Remove node_id from the internal queue if present.
Caller must hold the tree lock (e.g. via with_lock).
Returns True if node was removed, False if not in queue.
Note: asyncio.Queue has no built-in remove; we filter via the internal
deque. O(n) in queue size; acceptable for typical tree queue sizes.
"""
queue_deque = cast(deque, getattr(self._queue, "_queue"))
if node_id not in queue_deque:
return False
object.__setattr__(
self._queue, "_queue", deque(x for x in queue_deque if x != node_id)
)
return True
@asynccontextmanager
async def with_lock(self):
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
async with self._lock:
yield
def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None:
"""Set processing state. Caller must hold lock for consistency with queue operations."""
self._is_processing = is_processing
self._current_node_id = node_id if is_processing else None
def clear_current_node(self) -> None:
"""Clear the currently processing node ID. Caller must hold lock."""
self._current_node_id = None
def is_current_node(self, node_id: str) -> bool:
"""Check if node_id is the currently processing node."""
return self._current_node_id == node_id
def put_queue_unlocked(self, node_id: str) -> None:
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
self._queue.put_nowait(node_id)
def cancel_current_task(self) -> bool:
"""Cancel the currently running task. Returns True if a task was cancelled."""
if self._current_task and not self._current_task.done():
self._current_task.cancel()
return True
return False
def drain_queue_and_mark_cancelled(
self, error_message: str = "Cancelled by user"
) -> List["MessageNode"]:
"""
Drain the queue, mark each node as ERROR, and return affected nodes.
Does not acquire lock; caller must ensure no concurrent queue access.
"""
nodes: List[MessageNode] = []
while True:
try:
node_id = self._queue.get_nowait()
except asyncio.QueueEmpty:
break
node = self._nodes.get(node_id)
if node:
node.state = MessageState.ERROR
node.error_message = error_message
nodes.append(node)
return nodes
def reset_processing_state(self) -> None:
"""Reset processing flags after cancel/cleanup."""
self._is_processing = False
self._current_node_id = None
@property
def current_node_id(self) -> Optional[str]:
"""Get the ID of the node currently being processed."""
return self._current_node_id
def to_dict(self) -> dict:
"""Serialize tree to dictionary."""
return {
"root_id": self.root_id,
"nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
}
@classmethod
def from_dict(cls, data: dict) -> "MessageTree":
"""Deserialize tree from dictionary."""
root_id = data["root_id"]
nodes_data = data["nodes"]
# Create root node first
root_node = MessageNode.from_dict(nodes_data[root_id])
tree = cls(root_node)
# Add remaining nodes and build status->node index
for node_id, node_data in nodes_data.items():
if node_id != root_id:
node = MessageNode.from_dict(node_data)
tree._nodes[node_id] = node
tree._status_to_node[node.status_message_id] = node_id
return tree
def all_nodes(self) -> List[MessageNode]:
"""Get all nodes in the tree."""
return list(self._nodes.values())
def has_node(self, node_id: str) -> bool:
"""Check if a node exists in this tree."""
return node_id in self._nodes
def find_node_by_status_message(self, status_msg_id: str) -> Optional[MessageNode]:
"""Find the node that has this status message ID (O(1) lookup)."""
node_id = self._status_to_node.get(status_msg_id)
return self._nodes.get(node_id) if node_id else None
def get_descendants(self, node_id: str) -> List[str]:
"""
Get node_id and all descendant IDs (subtree).
Returns:
List of node IDs including the given node.
"""
if node_id not in self._nodes:
return []
result: List[str] = []
stack = [node_id]
while stack:
nid = stack.pop()
result.append(nid)
node = self._nodes.get(nid)
if node:
stack.extend(node.children_ids)
return result
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
"""
Remove a subtree (branch_root and all descendants) from the tree.
Updates parent's children_ids. Caller must hold lock for consistency.
Does not acquire lock internally.
Returns:
List of removed nodes.
"""
if branch_root_id not in self._nodes:
return []
parent = self.get_parent(branch_root_id)
removed = []
for nid in self.get_descendants(branch_root_id):
node = self._nodes.get(nid)
if node:
removed.append(node)
del self._nodes[nid]
del self._status_to_node[node.status_message_id]
if parent and branch_root_id in parent.children_ids:
parent.children_ids = [
c for c in parent.children_ids if c != branch_root_id
]
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
return removed

View file

@ -0,0 +1,162 @@
"""Async queue processor for message trees.
Handles the async processing lifecycle of tree nodes.
"""
import asyncio
from typing import Callable, Awaitable, Optional
from .data import MessageTree, MessageNode, MessageState
from loguru import logger
class TreeQueueProcessor:
"""
Handles async queue processing for a single tree.
Separates the async processing logic from the data management.
"""
def __init__(
self,
queue_update_callback: Optional[
Callable[[MessageTree], Awaitable[None]]
] = None,
node_started_callback: Optional[
Callable[[MessageTree, str], Awaitable[None]]
] = None,
):
self._queue_update_callback = queue_update_callback
self._node_started_callback = node_started_callback
def set_queue_update_callback(
self,
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
) -> None:
"""Update the callback used to refresh queue positions."""
self._queue_update_callback = queue_update_callback
def set_node_started_callback(
self,
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
) -> None:
"""Update the callback used when a queued node starts processing."""
self._node_started_callback = node_started_callback
async def _notify_queue_updated(self, tree: MessageTree) -> None:
"""Invoke queue update callback if set."""
if not self._queue_update_callback:
return
try:
await self._queue_update_callback(tree)
except Exception as e:
logger.warning(f"Queue update callback failed: {e}")
async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
"""Invoke node started callback if set."""
if not self._node_started_callback:
return
try:
await self._node_started_callback(tree, node_id)
except Exception as e:
logger.warning(f"Node started callback failed: {e}")
async def process_node(
self,
tree: MessageTree,
node: MessageNode,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process a single node and then check the queue."""
# Skip if already in terminal state (e.g. from error propagation)
if node.state.value == MessageState.ERROR.value:
logger.info(
f"Skipping node {node.node_id} as it is already in state {node.state}"
)
# Still need to check for next messages
await self._process_next(tree, processor)
return
try:
await processor(node.node_id, node)
except asyncio.CancelledError:
logger.info(f"Task for node {node.node_id} was cancelled")
raise
except Exception as e:
logger.error(f"Error processing node {node.node_id}: {e}")
await tree.update_state(
node.node_id, MessageState.ERROR, error_message=str(e)
)
finally:
tree.clear_current_node()
# Check if there are more messages in the queue
await self._process_next(tree, processor)
async def _process_next(
self,
tree: MessageTree,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process the next message in queue, if any."""
next_node_id = None
node = None
async with tree.with_lock():
next_node_id = await tree.dequeue()
if not next_node_id:
tree.set_processing_state(None, False)
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
return
tree.set_processing_state(next_node_id, True)
logger.info(f"Processing next queued node {next_node_id}")
# Process next node (outside lock)
node = tree.get_node(next_node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
# Notify that this node has started processing and refresh queue positions.
if next_node_id:
await self._notify_node_started(tree, next_node_id)
await self._notify_queue_updated(tree)
async def enqueue_and_start(
self,
tree: MessageTree,
node_id: str,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> bool:
"""
Enqueue a node or start processing immediately.
Args:
tree: The message tree
node_id: Node to process
processor: Async function to process the node
Returns:
True if queued, False if processing immediately
"""
async with tree.with_lock():
if tree.is_processing:
tree.put_queue_unlocked(node_id)
queue_size = tree.get_queue_size()
logger.info(f"Queued node {node_id}, position {queue_size}")
return True
else:
tree.set_processing_state(node_id, True)
# Process outside the lock
node = tree.get_node(node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
return False
def cancel_current(self, tree: MessageTree) -> bool:
"""Cancel the currently running task in a tree."""
return tree.cancel_current_task()

View file

@ -0,0 +1,461 @@
"""Tree-Based Message Queue Manager - Refactored.
Coordinates data access, async processing, and error handling.
Uses TreeRepository for data, TreeQueueProcessor for async logic.
"""
import asyncio
from datetime import datetime, timezone
from typing import Callable, Awaitable, List, Optional
from ..models import IncomingMessage
from .data import MessageState, MessageNode, MessageTree
from .repository import TreeRepository
from .processor import TreeQueueProcessor
from loguru import logger
# Backward compatibility: re-export moved classes
__all__ = [
"TreeQueueManager",
"MessageState",
"MessageNode",
"MessageTree",
]
class TreeQueueManager:
"""
Manages multiple message trees. Facade that coordinates components.
Each new conversation creates a new tree.
Replies to existing messages add nodes to existing trees.
Components:
- TreeRepository: Data access layer
- TreeQueueProcessor: Async queue processing
"""
def __init__(
self,
queue_update_callback: Optional[
Callable[[MessageTree], Awaitable[None]]
] = None,
node_started_callback: Optional[
Callable[[MessageTree, str], Awaitable[None]]
] = None,
):
self._repository = TreeRepository()
self._processor = TreeQueueProcessor(
queue_update_callback=queue_update_callback,
node_started_callback=node_started_callback,
)
self._lock = asyncio.Lock()
logger.info("TreeQueueManager initialized")
async def create_tree(
self,
node_id: str,
incoming: IncomingMessage,
status_message_id: str,
) -> MessageTree:
"""
Create a new tree with a root node.
Args:
node_id: ID for the root node
incoming: The incoming message
status_message_id: Bot's status message ID
Returns:
The created MessageTree
"""
async with self._lock:
root_node = MessageNode(
node_id=node_id,
incoming=incoming,
status_message_id=status_message_id,
state=MessageState.PENDING,
)
tree = MessageTree(root_node)
self._repository.add_tree(node_id, tree)
logger.info(f"Created new tree with root {node_id}")
return tree
async def add_to_tree(
self,
parent_node_id: str,
node_id: str,
incoming: IncomingMessage,
status_message_id: str,
) -> tuple[MessageTree, MessageNode]:
"""
Add a reply as a child node to an existing tree.
Args:
parent_node_id: ID of the parent message
node_id: ID for the new node
incoming: The incoming reply message
status_message_id: Bot's status message ID
Returns:
Tuple of (tree, new_node)
"""
async with self._lock:
if not self._repository.has_node(parent_node_id):
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
tree = self._repository.get_tree_for_node(parent_node_id)
if not tree:
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
# Add node (tree has its own lock) - outside manager lock to avoid deadlock
node = await tree.add_node(
node_id=node_id,
incoming=incoming,
status_message_id=status_message_id,
parent_id=parent_node_id,
)
async with self._lock:
self._repository.register_node(node_id, tree.root_id)
logger.info(f"Added node {node_id} to tree {tree.root_id}")
return tree, node
def get_tree(self, root_id: str) -> Optional[MessageTree]:
"""Get a tree by its root ID."""
return self._repository.get_tree(root_id)
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
"""Get the tree containing a given node."""
return self._repository.get_tree_for_node(node_id)
def get_node(self, node_id: str) -> Optional[MessageNode]:
"""Get a node from any tree."""
return self._repository.get_node(node_id)
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
"""Resolve a message ID to the actual parent node ID."""
return self._repository.resolve_parent_node_id(msg_id)
def is_tree_busy(self, root_id: str) -> bool:
"""Check if a tree is currently processing."""
return self._repository.is_tree_busy(root_id)
def is_node_tree_busy(self, node_id: str) -> bool:
"""Check if the tree containing a node is busy."""
return self._repository.is_node_tree_busy(node_id)
async def enqueue(
self,
node_id: str,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> bool:
"""
Enqueue a node for processing.
If the tree is not busy, processing starts immediately.
If busy, the message is queued.
Args:
node_id: Node to process
processor: Async function to process the node
Returns:
True if queued, False if processing immediately
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
logger.error(f"No tree found for node {node_id}")
return False
return await self._processor.enqueue_and_start(tree, node_id, processor)
def get_queue_size(self, node_id: str) -> int:
"""Get queue size for the tree containing a node."""
return self._repository.get_queue_size(node_id)
def get_pending_children(self, node_id: str) -> List[MessageNode]:
"""Get all pending child nodes (recursively) of a given node."""
return self._repository.get_pending_children(node_id)
async def mark_node_error(
self,
node_id: str,
error_message: str,
propagate_to_children: bool = True,
) -> List[MessageNode]:
"""
Mark a node as ERROR and optionally propagate to pending children.
Args:
node_id: The node to mark as error
error_message: Error description
propagate_to_children: If True, also mark pending children as error
Returns:
List of all nodes marked as error (including children)
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
return []
affected = []
node = tree.get_node(node_id)
if node:
await tree.update_state(
node_id, MessageState.ERROR, error_message=error_message
)
affected.append(node)
if propagate_to_children:
pending_children = self._repository.get_pending_children(node_id)
for child in pending_children:
await tree.update_state(
child.node_id,
MessageState.ERROR,
error_message=f"Parent failed: {error_message}",
)
affected.append(child)
return affected
def cancel_tree(self, root_id: str) -> List[MessageNode]:
"""
Cancel all queued and in-progress messages in a tree.
Updates node states to ERROR and returns list of affected nodes
that were actually active or in the current processing queue.
"""
tree = self._repository.get_tree(root_id)
if not tree:
return []
cancelled_nodes = []
# 1. Cancel running task
if tree.cancel_current_task():
current_id = tree.current_node_id
if current_id:
node = tree.get_node(current_id)
if node and node.state not in (
MessageState.COMPLETED,
MessageState.ERROR,
):
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
cancelled_nodes.append(node)
# 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
cleanup_count = 0
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node.node_id not in cancelled_ids
):
node.state = MessageState.ERROR
node.error_message = "Stale task cleaned up"
cleanup_count += 1
tree.reset_processing_state()
if cancelled_nodes:
logger.info(
f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
)
if cleanup_count:
logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")
return cancelled_nodes
async def cancel_node(self, node_id: str) -> List[MessageNode]:
"""
Cancel a single node (queued or in-progress) without affecting other nodes.
- If the node is currently running, cancels the current asyncio task.
- If the node is queued, removes it from the queue.
- Marks the node as ERROR with "Cancelled by user".
Returns:
List containing the cancelled node if it was cancellable, else empty list.
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
return []
async with tree.with_lock():
node = tree.get_node(node_id)
if not node:
return []
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
return []
if tree.is_current_node(node_id):
self._processor.cancel_current(tree)
try:
tree.remove_from_queue(node_id)
except Exception:
logger.debug(
"Failed to remove node from queue; will rely on state=ERROR"
)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
return [node]
async def cancel_all(self) -> List[MessageNode]:
"""Cancel all messages in all trees (async wrapper)."""
async with self._lock:
return self.cancel_all_sync()
def cancel_all_sync(self) -> List[MessageNode]:
"""
Cancel all messages in all trees (synchronous/locked version).
NOTE: Must be called with self._lock held.
"""
all_cancelled = []
for root_id in list(self._repository.tree_ids()):
all_cancelled.extend(self.cancel_tree(root_id))
return all_cancelled
def cleanup_stale_nodes(self) -> int:
"""
Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.
Used on startup to reconcile restored state.
"""
count = 0
for tree in self._repository.all_trees():
for node in tree.all_nodes():
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
node.state = MessageState.ERROR
node.error_message = "Lost during server restart"
count += 1
if count:
logger.info(f"Cleaned up {count} stale nodes during startup")
return count
def get_tree_count(self) -> int:
"""Get the number of active message trees."""
return self._repository.tree_count()
def set_queue_update_callback(
self,
queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
) -> None:
"""Set callback for queue position updates."""
self._processor.set_queue_update_callback(queue_update_callback)
def set_node_started_callback(
self,
node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
) -> None:
"""Set callback for when a queued node starts processing."""
self._processor.set_node_started_callback(node_started_callback)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree (for external mapping)."""
self._repository.register_node(node_id, root_id)
async def cancel_branch(self, branch_root_id: str) -> List[MessageNode]:
"""
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
return []
branch_ids = set(tree.get_descendants(branch_root_id))
cancelled: List[MessageNode] = []
async with tree.with_lock():
for nid in branch_ids:
node = tree.get_node(nid)
if not node or node.state in (
MessageState.COMPLETED,
MessageState.ERROR,
):
continue
if tree.is_current_node(nid):
self._processor.cancel_current(tree)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
cancelled.append(node)
else:
tree.remove_from_queue(nid)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
cancelled.append(node)
if cancelled:
logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
return cancelled
async def remove_branch(
self, branch_root_id: str
) -> tuple[List[MessageNode], str, bool]:
"""
Remove a branch (subtree) from the tree.
If branch_root is the tree root, removes the entire tree.
Returns:
(removed_nodes, root_id, removed_entire_tree)
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
return ([], "", False)
root_id = tree.root_id
if branch_root_id == root_id:
cancelled = self.cancel_tree(root_id)
removed_tree = self._repository.remove_tree(root_id)
if removed_tree:
return (removed_tree.all_nodes(), root_id, True)
return (cancelled, root_id, True)
async with tree.with_lock():
removed = tree.remove_branch(branch_root_id)
self._repository.unregister_nodes([n.node_id for n in removed])
return (removed, root_id, False)
def to_dict(self) -> dict:
"""Serialize all trees."""
return self._repository.to_dict()
@classmethod
def from_dict(
cls,
data: dict,
queue_update_callback: Optional[
Callable[[MessageTree], Awaitable[None]]
] = None,
node_started_callback: Optional[
Callable[[MessageTree, str], Awaitable[None]]
] = None,
) -> "TreeQueueManager":
"""Deserialize from dictionary."""
manager = cls(
queue_update_callback=queue_update_callback,
node_started_callback=node_started_callback,
)
manager._repository = TreeRepository.from_dict(data)
return manager

View file

@ -0,0 +1,168 @@
"""Repository for message tree data access.
Provides data access layer for managing trees and node mappings.
"""
from typing import Dict, Optional, List
from loguru import logger
from .data import MessageTree, MessageNode, MessageState
class TreeRepository:
"""
Repository for message tree data access.
Manages the storage and lookup of trees and node-to-tree mappings.
"""
def __init__(self):
self._trees: Dict[str, MessageTree] = {} # root_id -> tree
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
def get_tree(self, root_id: str) -> Optional[MessageTree]:
"""Get a tree by its root ID."""
return self._trees.get(root_id)
def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
"""Get the tree containing a given node."""
root_id = self._node_to_tree.get(node_id)
if not root_id:
return None
return self._trees.get(root_id)
def get_node(self, node_id: str) -> Optional[MessageNode]:
"""Get a node from any tree."""
tree = self.get_tree_for_node(node_id)
return tree.get_node(node_id) if tree else None
def add_tree(self, root_id: str, tree: MessageTree) -> None:
"""Add a new tree to the repository."""
self._trees[root_id] = tree
self._node_to_tree[root_id] = root_id
logger.debug("TREE_REPO: add_tree root_id=%s", root_id)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree."""
self._node_to_tree[node_id] = root_id
logger.debug("TREE_REPO: register_node node_id=%s root_id=%s", node_id, root_id)
def has_node(self, node_id: str) -> bool:
"""Check if a node is registered in any tree."""
return node_id in self._node_to_tree
def tree_count(self) -> int:
"""Get the number of trees in the repository."""
return len(self._trees)
def is_tree_busy(self, root_id: str) -> bool:
"""Check if a tree is currently processing."""
tree = self._trees.get(root_id)
return tree.is_processing if tree else False
def is_node_tree_busy(self, node_id: str) -> bool:
"""Check if the tree containing a node is busy."""
tree = self.get_tree_for_node(node_id)
return tree.is_processing if tree else False
def get_queue_size(self, node_id: str) -> int:
"""Get queue size for the tree containing a node."""
tree = self.get_tree_for_node(node_id)
return tree.get_queue_size() if tree else 0
def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
"""
Resolve a message ID to the actual parent node ID.
Handles the case where msg_id is a status message ID
(which maps to the tree but isn't an actual node).
Returns:
The node_id to use as parent, or None if not found
"""
tree = self.get_tree_for_node(msg_id)
if not tree:
return None
# Check if msg_id is an actual node
if tree.has_node(msg_id):
return msg_id
# Otherwise, it might be a status message - find the owning node
node = tree.find_node_by_status_message(msg_id)
if node:
return node.node_id
return None
def get_pending_children(self, node_id: str) -> List[MessageNode]:
"""
Get all pending child nodes (recursively) of a given node.
Used for error propagation - when a node fails, its pending
children should also be marked as failed.
"""
tree = self.get_tree_for_node(node_id)
if not tree:
return []
pending = []
node = tree.get_node(node_id)
if not node:
return []
for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
# Recursively get children of pending children
pending.extend(self.get_pending_children(child_id))
return pending
def all_trees(self) -> List[MessageTree]:
"""Get all trees in the repository."""
return list(self._trees.values())
def tree_ids(self) -> List[str]:
"""Get all tree root IDs."""
return list(self._trees.keys())
def unregister_nodes(self, node_ids: List[str]) -> None:
"""Remove node IDs from the node-to-tree mapping."""
for nid in node_ids:
self._node_to_tree.pop(nid, None)
def remove_tree(self, root_id: str) -> Optional[MessageTree]:
"""
Remove a tree and all its node mappings from the repository.
Returns:
The removed tree, or None if not found.
"""
tree = self._trees.pop(root_id, None)
if not tree:
return None
for node in tree.all_nodes():
self._node_to_tree.pop(node.node_id, None)
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
return tree
def to_dict(self) -> dict:
"""Serialize all trees."""
return {
"trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
"node_to_tree": self._node_to_tree.copy(),
}
@classmethod
def from_dict(cls, data: dict) -> "TreeRepository":
"""Deserialize from dictionary."""
from .data import MessageTree
repo = cls()
for root_id, tree_data in data.get("trees", {}).items():
repo._trees[root_id] = MessageTree.from_dict(tree_data)
repo._node_to_tree = data.get("node_to_tree", {})
return repo

View file

@ -11,9 +11,10 @@ class TestCreateMessagingPlatform:
def test_telegram_with_token(self):
"""Create Telegram platform when bot_token is provided."""
mock_platform = MagicMock()
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
with patch(
"messaging.telegram.TelegramPlatform", return_value=mock_platform
"messaging.platforms.telegram.TelegramPlatform",
return_value=mock_platform,
):
result = create_messaging_platform(
"telegram",
@ -36,8 +37,11 @@ class TestCreateMessagingPlatform:
def test_discord_with_token(self):
"""Create Discord platform when discord_bot_token is provided."""
mock_platform = MagicMock()
with patch("messaging.discord.DISCORD_AVAILABLE", True):
with patch("messaging.discord.DiscordPlatform", return_value=mock_platform):
with patch("messaging.platforms.discord.DISCORD_AVAILABLE", True):
with patch(
"messaging.platforms.discord.DiscordPlatform",
return_value=mock_platform,
):
result = create_messaging_platform(
"discord",
discord_bot_token="test_token",