# free-claude-code/tests/messaging/test_discord_platform.py
# 2026-02-27 19:50:21 -08:00
#
# 380 lines
# 14 KiB
# Python

"""Tests for Discord platform adapter."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from messaging.platforms.discord import (
DISCORD_AVAILABLE,
DiscordPlatform,
_get_discord,
_parse_allowed_channels,
)
class TestGetDiscord:
    """Behaviour of the lazy ``_get_discord`` accessor."""

    def test_raises_when_discord_not_available(self):
        """With discord.py marked unavailable, ``_get_discord`` raises ImportError."""
        import messaging.platforms.discord as discord_mod

        # Force the "not installed" state: availability flag off and no module
        # object stored in ``_discord_module``.
        flag_off = patch.object(discord_mod, "DISCORD_AVAILABLE", False)
        module_cleared = patch.object(discord_mod, "_discord_module", None)
        with flag_off, module_cleared:
            with pytest.raises(ImportError, match=r"discord\.py is required"):
                _get_discord()
class TestParseAllowedChannels:
    """Unit tests for the ``_parse_allowed_channels`` helper."""

    def test_empty_string_returns_empty_set(self):
        # Both an empty string and None yield no channel IDs.
        for blank in ("", None):
            assert _parse_allowed_channels(blank) == set()

    def test_whitespace_only_returns_empty_set(self):
        parsed = _parse_allowed_channels(" ")
        assert parsed == set()

    def test_single_channel(self):
        expected = {"123456789"}
        assert _parse_allowed_channels("123456789") == expected

    def test_comma_separated(self):
        expected = {"111", "222", "333"}
        assert _parse_allowed_channels("111,222,333") == expected

    def test_strips_whitespace(self):
        # Surrounding spaces around each ID are discarded.
        expected = {"111", "222"}
        assert _parse_allowed_channels(" 111 , 222 ") == expected

    def test_empty_parts_ignored(self):
        # Consecutive and trailing commas do not produce empty IDs.
        expected = {"111", "222"}
        assert _parse_allowed_channels("111,,222,") == expected
@pytest.mark.skipif(not DISCORD_AVAILABLE, reason="discord.py not installed")
class TestDiscordPlatform:
    """Tests for DiscordPlatform (requires discord.py)."""

    # --- construction / configuration -------------------------------------

    def test_init_with_token(self):
        platform = DiscordPlatform(
            bot_token="test_token",
            allowed_channel_ids="123,456",
        )
        assert platform.bot_token == "test_token"
        assert platform.allowed_channel_ids == {"123", "456"}

    def test_init_without_allowed_channels(self):
        # Blank ALLOWED_DISCORD_CHANNELS in case the constructor falls back to
        # it when the explicit argument is empty.
        with patch.dict("os.environ", {"ALLOWED_DISCORD_CHANNELS": ""}, clear=False):
            platform = DiscordPlatform(bot_token="token", allowed_channel_ids="")
            assert platform.allowed_channel_ids == set()

    @pytest.mark.asyncio
    async def test_empty_allowed_channels_rejects_all_messages(self):
        """When allowed_channel_ids is empty, no channels are allowed (secure default)."""
        with patch.dict("os.environ", {"ALLOWED_DISCORD_CHANNELS": ""}, clear=False):
            platform = DiscordPlatform(bot_token="token", allowed_channel_ids="")
            assert platform.allowed_channel_ids == set()

            handler = AsyncMock()
            platform.on_message(handler)
            # Same shape as the message accepted in
            # test_on_discord_message_valid_calls_handler, so the empty
            # allow-list is the only reason it can be dropped.
            msg = MagicMock()
            msg.author.bot = False
            msg.author.id = 456
            msg.author.display_name = "User"
            msg.content = "hello"
            msg.channel.id = 123
            msg.id = 789
            msg.reference = None
            await platform._on_discord_message(msg)
            handler.assert_not_called()

    # --- _truncate ----------------------------------------------------------

    def test_truncate_long_message(self):
        # Over-long text is cut to exactly 2000 characters ending in "...".
        platform = DiscordPlatform(bot_token="token")
        long_text = "x" * 2500
        truncated = platform._truncate(long_text)
        assert len(truncated) == 2000
        assert truncated.endswith("...")

    def test_truncate_short_message_unchanged(self):
        platform = DiscordPlatform(bot_token="token")
        short = "hello"
        assert platform._truncate(short) == short

    def test_truncate_exactly_at_limit_unchanged(self):
        platform = DiscordPlatform(bot_token="token")
        exact = "x" * 2000
        assert platform._truncate(exact) == exact

    def test_truncate_one_over_limit_truncates(self):
        platform = DiscordPlatform(bot_token="token")
        over = "x" * 2001
        result = platform._truncate(over)
        assert len(result) == 2000
        assert result.endswith("...")

    def test_truncate_empty_string(self):
        platform = DiscordPlatform(bot_token="token")
        assert platform._truncate("") == ""

    # --- send / edit / delete -------------------------------------------------

    @pytest.mark.asyncio
    async def test_send_message_returns_message_id(self):
        # The returned ID is the sent message's integer id rendered as str.
        platform = DiscordPlatform(bot_token="token")
        mock_msg = MagicMock()
        mock_msg.id = 999
        mock_channel = AsyncMock()
        mock_channel.send = AsyncMock(return_value=mock_msg)
        platform._connected = True
        with patch.object(
            platform._client, "get_channel", MagicMock(return_value=mock_channel)
        ):
            msg_id = await platform.send_message("123", "Hello")
        assert msg_id == "999"

    @pytest.mark.asyncio
    async def test_edit_message(self):
        platform = DiscordPlatform(bot_token="token")
        mock_msg = AsyncMock()
        mock_channel = AsyncMock()
        mock_channel.fetch_message = AsyncMock(return_value=mock_msg)
        platform._connected = True
        with patch.object(
            platform._client, "get_channel", MagicMock(return_value=mock_channel)
        ):
            await platform.edit_message("123", "456", "Updated text")
        mock_msg.edit.assert_called_once_with(content="Updated text")

    @pytest.mark.asyncio
    async def test_send_message_channel_not_found_raises(self):
        platform = DiscordPlatform(bot_token="token")
        platform._connected = True
        with (
            patch.object(platform._client, "get_channel", MagicMock(return_value=None)),
            pytest.raises(RuntimeError, match="Channel"),
        ):
            await platform.send_message("123", "Hello")

    @pytest.mark.asyncio
    async def test_send_message_channel_no_send_raises(self):
        platform = DiscordPlatform(bot_token="token")
        platform._connected = True
        mock_channel = MagicMock(spec=[])  # No send attr
        with (
            patch.object(
                platform._client, "get_channel", MagicMock(return_value=mock_channel)
            ),
            pytest.raises(RuntimeError, match="Channel"),
        ):
            await platform.send_message("123", "Hello")

    @pytest.mark.asyncio
    async def test_queue_send_message_without_limiter_calls_send_message(self):
        # With no rate limiter, queueing falls straight through to a send.
        platform = DiscordPlatform(bot_token="token")
        platform._limiter = None
        platform._connected = True
        mock_channel = AsyncMock()
        mock_msg = MagicMock()
        mock_msg.id = 42
        mock_channel.send = AsyncMock(return_value=mock_msg)
        with patch.object(
            platform._client, "get_channel", MagicMock(return_value=mock_channel)
        ):
            result = await platform.queue_send_message("123", "hi")
        assert result == "42"
        mock_channel.send.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_queue_edit_message_without_limiter_calls_edit_message(self):
        platform = DiscordPlatform(bot_token="token")
        platform._limiter = None
        platform._connected = True
        mock_msg = AsyncMock()
        mock_channel = AsyncMock()
        mock_channel.fetch_message = AsyncMock(return_value=mock_msg)
        with patch.object(
            platform._client, "get_channel", MagicMock(return_value=mock_channel)
        ):
            await platform.queue_edit_message("123", "456", "Updated")
        mock_msg.edit.assert_called_once_with(content="Updated")

    # --- inbound message filtering --------------------------------------------

    @pytest.mark.asyncio
    async def test_on_discord_message_bot_ignored(self):
        platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
        handler = AsyncMock()
        platform.on_message(handler)
        msg = MagicMock()
        msg.author.bot = True
        msg.content = "hello"
        msg.channel.id = 123
        await platform._on_discord_message(msg)
        handler.assert_not_called()

    @pytest.mark.asyncio
    async def test_on_discord_message_empty_content_ignored(self):
        platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
        handler = AsyncMock()
        platform.on_message(handler)
        msg = MagicMock()
        msg.author.bot = False
        msg.content = ""
        msg.channel.id = 123
        await platform._on_discord_message(msg)
        handler.assert_not_called()

    @pytest.mark.asyncio
    async def test_on_discord_message_channel_not_allowed_ignored(self):
        platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
        handler = AsyncMock()
        platform.on_message(handler)
        msg = MagicMock()
        msg.author.bot = False
        msg.content = "hello"
        msg.channel.id = 999
        await platform._on_discord_message(msg)
        handler.assert_not_called()

    @pytest.mark.asyncio
    async def test_on_discord_message_valid_calls_handler(self):
        # Integer Discord IDs are surfaced to the handler as strings.
        platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
        handler = AsyncMock()
        platform.on_message(handler)
        msg = MagicMock()
        msg.author.bot = False
        msg.author.id = 456
        msg.author.display_name = "User"
        msg.content = "hello"
        msg.channel.id = 123
        msg.id = 789
        msg.reference = None
        await platform._on_discord_message(msg)
        handler.assert_awaited_once()
        call = handler.call_args[0][0]
        assert call.text == "hello"
        assert call.chat_id == "123"
        assert call.user_id == "456"
        assert call.message_id == "789"
        assert call.platform == "discord"

    @pytest.mark.asyncio
    async def test_send_message_with_reply_to(self):
        # A reply_to ID must result in a non-None ``reference`` kwarg on send.
        platform = DiscordPlatform(bot_token="token")
        mock_msg = MagicMock()
        mock_msg.id = 999
        mock_channel = AsyncMock()
        mock_channel.send = AsyncMock(return_value=mock_msg)
        platform._connected = True
        with (
            patch.object(
                platform._client, "get_channel", MagicMock(return_value=mock_channel)
            ),
            patch("messaging.platforms.discord._get_discord") as mock_get,
        ):
            mock_discord = MagicMock()
            mock_get.return_value = mock_discord
            msg_id = await platform.send_message("123", "Hello", reply_to="456")
        assert msg_id == "999"
        mock_channel.send.assert_awaited_once()
        call_kw = mock_channel.send.call_args[1]
        assert call_kw.get("reference") is not None

    @pytest.mark.asyncio
    async def test_edit_message_not_found_returns_gracefully(self):
        import discord as discord_pkg

        platform = DiscordPlatform(bot_token="token")
        mock_channel = AsyncMock()
        mock_resp = MagicMock()
        mock_resp.status = 404
        mock_channel.fetch_message = AsyncMock(
            side_effect=discord_pkg.NotFound(mock_resp, "Not found")
        )
        platform._connected = True
        with patch.object(
            platform._client, "get_channel", MagicMock(return_value=mock_channel)
        ):
            # Should not raise - NotFound is caught and we return
            await platform.edit_message("123", "456", "Updated")
        # Prove the NotFound path was actually exercised rather than skipped.
        mock_channel.fetch_message.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_delete_message(self):
        platform = DiscordPlatform(bot_token="token")
        mock_msg = AsyncMock()
        mock_channel = AsyncMock()
        mock_channel.fetch_message = AsyncMock(return_value=mock_msg)
        platform._connected = True
        with (
            patch.object(
                platform._client, "get_channel", MagicMock(return_value=mock_channel)
            ),
            patch("messaging.platforms.discord._get_discord") as mock_get,
        ):
            mock_get.return_value = MagicMock()
            await platform.delete_message("123", "456")
        mock_msg.delete.assert_awaited_once()

    # --- lifecycle / misc -------------------------------------------------------

    @pytest.mark.asyncio
    async def test_fire_and_forget_with_coroutine(self):
        platform = DiscordPlatform(bot_token="token")

        async def _task():
            pass

        coro = _task()
        with patch("asyncio.create_task") as mock_create:

            def _run(c):
                # ensure_future schedules via the loop, not asyncio.create_task,
                # so this does not recurse into the mock.
                return asyncio.ensure_future(c)

            mock_create.side_effect = _run
            platform.fire_and_forget(coro)
            mock_create.assert_called_once()
            # Let the scheduled task run so the coroutine is not left unawaited.
            await asyncio.sleep(0)

    def test_on_message_registers_handler(self):
        platform = DiscordPlatform(bot_token="token")
        handler = AsyncMock()
        platform.on_message(handler)
        assert platform._message_handler is handler

    @pytest.mark.asyncio
    async def test_start_requires_token(self):
        with patch.dict("os.environ", {"DISCORD_BOT_TOKEN": ""}, clear=False):
            platform = DiscordPlatform(bot_token="")
            with pytest.raises(ValueError, match="DISCORD_BOT_TOKEN"):
                await platform.start()

    @pytest.mark.asyncio
    async def test_start_connects(self):
        platform = DiscordPlatform(bot_token="token")

        async def _fake_start(_token):
            # Stand-in for the real client start: just flip the flag.
            platform._connected = True

        with (
            patch.object(
                platform._client,
                "start",
                new_callable=AsyncMock,
                side_effect=_fake_start,
            ),
            patch(
                "messaging.limiter.MessagingRateLimiter.get_instance",
                new_callable=AsyncMock,
            ),
        ):
            await platform.start()
        assert platform.is_connected is True

    @pytest.mark.asyncio
    async def test_stop_when_already_closed(self):
        platform = DiscordPlatform(bot_token="token")
        platform._connected = True
        with patch.object(
            platform._client, "is_closed", new_callable=MagicMock, return_value=True
        ):
            await platform.stop()
        assert platform.is_connected is False

    @pytest.mark.asyncio
    async def test_stop_closes_client(self):
        platform = DiscordPlatform(bot_token="token")
        platform._connected = True
        platform._start_task = None
        mock_close = AsyncMock()
        with (
            patch.object(
                platform._client,
                "is_closed",
                new_callable=MagicMock,
                return_value=False,
            ),
            patch.object(platform._client, "close", mock_close),
        ):
            await platform.stop()
        mock_close.assert_awaited_once()
        assert platform.is_connected is False