# test_discord_reply_mode.py
"""Tests for Discord reply_to_mode functionality.

Covers the threading behavior control for multi-chunk replies:
- "off": Never reply-reference to original message
- "first": Only first chunk uses reply reference (default)
- "all": All chunks reply-reference the original message

Also covers reply_to_text extraction from incoming messages.
"""
import os
import sys
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock, AsyncMock, patch

import pytest

from gateway.config import PlatformConfig, GatewayConfig, Platform, _apply_env_overrides


def _ensure_discord_mock():
    """Install a mock discord module when discord.py isn't available.

    A real installed discord.py module has ``__file__``; in that case nothing
    is changed. Otherwise minimal stand-ins are registered via
    ``sys.modules.setdefault`` so an already-installed stub from another test
    module is kept rather than overwritten.
    """
    if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
        return

    discord_mod = MagicMock()
    discord_mod.Intents.default.return_value = MagicMock()
    discord_mod.Client = MagicMock
    discord_mod.File = MagicMock
    # Real (empty) classes so isinstance() checks in the adapter still work.
    discord_mod.DMChannel = type("DMChannel", (), {})
    discord_mod.Thread = type("Thread", (), {})
    discord_mod.ForumChannel = type("ForumChannel", (), {})
    # Decorator stand-ins return the decorated function unchanged.
    discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
    discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
    discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
    discord_mod.Interaction = object
    discord_mod.Embed = MagicMock
    discord_mod.app_commands = SimpleNamespace(
        describe=lambda **kwargs: (lambda fn: fn),
        choices=lambda **kwargs: (lambda fn: fn),
        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
    )

    ext_mod = MagicMock()
    commands_mod = MagicMock()
    commands_mod.Bot = MagicMock
    ext_mod.commands = commands_mod

    sys.modules.setdefault("discord", discord_mod)
    sys.modules.setdefault("discord.ext", ext_mod)
    sys.modules.setdefault("discord.ext.commands", commands_mod)


_ensure_discord_mock()

from gateway.platforms.discord import DiscordAdapter  # noqa: E402


@pytest.fixture()
def adapter_factory():
    """Factory to create DiscordAdapter with custom reply_to_mode."""
    def create(reply_to_mode: str = "first"):
        config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
        return DiscordAdapter(config)
    return create


class TestReplyToModeConfig:
    """Tests for reply_to_mode configuration loading."""

    def test_default_mode_is_first(self, adapter_factory):
        adapter = adapter_factory()
        assert adapter._reply_to_mode == "first"

    def test_off_mode(self, adapter_factory):
        adapter = adapter_factory(reply_to_mode="off")
        assert adapter._reply_to_mode == "off"

    def test_first_mode(self, adapter_factory):
        adapter = adapter_factory(reply_to_mode="first")
        assert adapter._reply_to_mode == "first"

    def test_all_mode(self, adapter_factory):
        adapter = adapter_factory(reply_to_mode="all")
        assert adapter._reply_to_mode == "all"

    def test_invalid_mode_stored_as_is(self, adapter_factory):
        """Invalid modes are stored but send() handles them gracefully."""
        adapter = adapter_factory(reply_to_mode="invalid")
        assert adapter._reply_to_mode == "invalid"

    def test_none_mode_defaults_to_first(self):
        """Omitting reply_to_mode entirely (PlatformConfig default) yields "first"."""
        config = PlatformConfig(enabled=True, token="test-token")
        adapter = DiscordAdapter(config)
        assert adapter._reply_to_mode == "first"

    def test_empty_string_mode_defaults_to_first(self):
        config = PlatformConfig(enabled=True, token="test-token", reply_to_mode="")
        adapter = DiscordAdapter(config)
        assert adapter._reply_to_mode == "first"


def _make_discord_adapter(reply_to_mode: str = "first"):
    """Create a DiscordAdapter with mocked client and channel for send() tests.

    Returns ``(adapter, channel, reference)``:
      - ``adapter``: DiscordAdapter whose ``_client.get_channel`` returns ``channel``.
      - ``channel``: AsyncMock channel; ``fetch_message`` returns a mock message
        and ``send`` returns a mock sent message with ``id == 42``.
      - ``reference``: the sentinel returned by the fetched message's
        ``to_reference()``, so tests can assert identity on ``send(reference=...)``.
    """
    config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
    adapter = DiscordAdapter(config)

    # Mock the Discord client and channel.
    # ref_message.to_reference() → a distinct sentinel: the adapter now wraps
    # the fetched Message via to_reference(fail_if_not_exists=False) so a
    # deleted target degrades to "send without reply chip" instead of a 400.
    mock_channel = AsyncMock()
    ref_message = MagicMock()
    ref_reference = MagicMock(name="MessageReference")
    ref_message.to_reference = MagicMock(return_value=ref_reference)
    mock_channel.fetch_message = AsyncMock(return_value=ref_message)

    sent_msg = MagicMock()
    sent_msg.id = 42
    mock_channel.send = AsyncMock(return_value=sent_msg)

    mock_client = MagicMock()
    mock_client.get_channel = MagicMock(return_value=mock_channel)

    adapter._client = mock_client
    return adapter, mock_channel, ref_reference


class TestSendWithReplyToMode:
    """Tests for send() method respecting reply_to_mode.

    ``truncate_message`` is replaced on the instance so send() splits its
    content into a fixed list of chunks, one ``channel.send`` call each.
    """

    @pytest.mark.asyncio
    async def test_off_mode_no_reply_reference(self):
        adapter, channel, _ = _make_discord_adapter("off")
        adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        # Should never try to fetch the reference message
        channel.fetch_message.assert_not_called()
        # All chunks sent without reference. Check the count first so the loop
        # cannot pass vacuously if send() never reached the channel.
        calls = channel.send.call_args_list
        assert len(calls) == 3
        for call in calls:
            assert call.kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_first_mode_only_first_chunk_references(self):
        adapter, channel, ref = _make_discord_adapter("first")
        adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        # Should fetch the reference message
        channel.fetch_message.assert_called_once_with(999)
        calls = channel.send.call_args_list
        assert len(calls) == 3
        assert calls[0].kwargs.get("reference") is ref
        assert calls[1].kwargs.get("reference") is None
        assert calls[2].kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_all_mode_all_chunks_reference(self):
        adapter, channel, ref = _make_discord_adapter("all")
        adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2", "chunk3"]

        await adapter.send("12345", "test content", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        calls = channel.send.call_args_list
        assert len(calls) == 3
        for call in calls:
            assert call.kwargs.get("reference") is ref

    @pytest.mark.asyncio
    async def test_no_reply_to_param_no_reference(self):
        adapter, channel, _ = _make_discord_adapter("all")
        adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]

        await adapter.send("12345", "test content", reply_to=None)

        channel.fetch_message.assert_not_called()
        # Guard against a vacuous pass: both chunks must actually be sent.
        calls = channel.send.call_args_list
        assert len(calls) == 2
        for call in calls:
            assert call.kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_single_chunk_respects_first_mode(self):
        adapter, channel, ref = _make_discord_adapter("first")
        adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]

        await adapter.send("12345", "test", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        calls = channel.send.call_args_list
        assert len(calls) == 1
        assert calls[0].kwargs.get("reference") is ref

    @pytest.mark.asyncio
    async def test_single_chunk_off_mode(self):
        adapter, channel, _ = _make_discord_adapter("off")
        adapter.truncate_message = lambda content, max_len, **kw: ["single chunk"]

        await adapter.send("12345", "test", reply_to="999")

        channel.fetch_message.assert_not_called()
        calls = channel.send.call_args_list
        assert len(calls) == 1
        assert calls[0].kwargs.get("reference") is None

    @pytest.mark.asyncio
    async def test_invalid_mode_falls_back_to_first_behavior(self):
        """Invalid mode behaves like 'first' — only first chunk gets reference."""
        adapter, channel, ref = _make_discord_adapter("banana")
        adapter.truncate_message = lambda content, max_len, **kw: ["chunk1", "chunk2"]

        await adapter.send("12345", "test", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        calls = channel.send.call_args_list
        assert len(calls) == 2
        assert calls[0].kwargs.get("reference") is ref
        assert calls[1].kwargs.get("reference") is None


class TestConfigSerialization:
    """Tests for reply_to_mode serialization (shared with Telegram)."""

    def test_to_dict_includes_reply_to_mode(self):
        config = PlatformConfig(enabled=True, token="test", reply_to_mode="all")
        result = config.to_dict()
        assert result["reply_to_mode"] == "all"

    def test_from_dict_loads_reply_to_mode(self):
        data = {"enabled": True, "token": "***", "reply_to_mode": "off"}
        config = PlatformConfig.from_dict(data)
        assert config.reply_to_mode == "off"

    def test_from_dict_defaults_to_first(self):
        data = {"enabled": True, "token": "***"}
        config = PlatformConfig.from_dict(data)
        assert config.reply_to_mode == "first"


class TestEnvVarOverride:
    """Tests for DISCORD_REPLY_TO_MODE environment variable override."""

    def _make_config(self):
        """Return a GatewayConfig with an enabled Discord PlatformConfig."""
        config = GatewayConfig()
        config.platforms[Platform.DISCORD] = PlatformConfig(enabled=True, token="test")
        return config

    def test_env_var_sets_off_mode(self):
        config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
            _apply_env_overrides(config)
        assert config.platforms[Platform.DISCORD].reply_to_mode == "off"

    def test_env_var_sets_all_mode(self):
        config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "all"}, clear=False):
            _apply_env_overrides(config)
        assert config.platforms[Platform.DISCORD].reply_to_mode == "all"

    def test_env_var_case_insensitive(self):
        config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "ALL"}, clear=False):
            _apply_env_overrides(config)
        assert config.platforms[Platform.DISCORD].reply_to_mode == "all"

    def test_env_var_invalid_value_ignored(self):
        config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "banana"}, clear=False):
            _apply_env_overrides(config)
        assert config.platforms[Platform.DISCORD].reply_to_mode == "first"

    def test_env_var_empty_value_ignored(self):
        config = self._make_config()
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": ""}, clear=False):
            _apply_env_overrides(config)
        assert config.platforms[Platform.DISCORD].reply_to_mode == "first"

    def test_env_var_creates_platform_config_if_missing(self):
        """DISCORD_REPLY_TO_MODE creates PlatformConfig even without DISCORD_BOT_TOKEN."""
        config = GatewayConfig()
        assert Platform.DISCORD not in config.platforms
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
            _apply_env_overrides(config)
        assert Platform.DISCORD in config.platforms
        assert config.platforms[Platform.DISCORD].reply_to_mode == "off"


# ------------------------------------------------------------------
# Tests for reply_to_text extraction in _handle_message
# ------------------------------------------------------------------

# Build FakeDMChannel as a subclass of the real discord.DMChannel when the
# library is installed — this guarantees isinstance() checks pass in
# production code regardless of test ordering or monkeypatch state.
try:
    import discord as _discord_lib
    _DMChannelBase = _discord_lib.DMChannel
except (ImportError, AttributeError):
    _DMChannelBase = object


class FakeDMChannel(_DMChannelBase):
    """Minimal DM channel stub (skips mention / channel-allow checks)."""

    def __init__(self, channel_id: int = 100, name: str = "dm"):
        # Do NOT call super().__init__() — real DMChannel requires State
        self.id = channel_id
        self.name = name


def _make_message(*, content: str = "hi", reference=None):
    """Build a mock Discord message for _handle_message tests.

    The message is authored by user id 42 in a FakeDMChannel, has no mentions
    or attachments, and carries the given ``content`` and ``reference``.
    """
    author = SimpleNamespace(id=42, display_name="TestUser", name="TestUser")
    return SimpleNamespace(
        id=999,
        content=content,
        mentions=[],
        attachments=[],
        reference=reference,
        created_at=datetime.now(timezone.utc),
        channel=FakeDMChannel(),
        author=author,
    )


@pytest.fixture
def reply_text_adapter():
    """DiscordAdapter wired for _handle_message → handle_message capture.

    Text batching delay is set to 0, and ``handle_message`` is replaced with an
    AsyncMock so tests can inspect the event built by ``_handle_message``.
    """
    config = PlatformConfig(enabled=True, token="fake-token")
    adapter = DiscordAdapter(config)
    adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
    adapter._text_batch_delay_seconds = 0
    adapter.handle_message = AsyncMock()
    return adapter


def _captured_event(adapter):
    """Return the single event passed to ``adapter.handle_message``.

    Asserts the mock was awaited exactly once first. A clear assertion failure
    here is better than an opaque AttributeError from ``None.args``.
    """
    adapter.handle_message.assert_awaited_once()
    return adapter.handle_message.await_args.args[0]


class TestReplyToText:
    """Tests for reply_to_text populated by _handle_message."""

    @pytest.mark.asyncio
    async def test_no_reference_both_none(self, reply_text_adapter):
        message = _make_message(reference=None)

        await reply_text_adapter._handle_message(message)

        event = _captured_event(reply_text_adapter)
        assert event.reply_to_message_id is None
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_reference_without_resolved(self, reply_text_adapter):
        ref = SimpleNamespace(message_id=555, resolved=None)
        message = _make_message(reference=ref)

        await reply_text_adapter._handle_message(message)

        event = _captured_event(reply_text_adapter)
        assert event.reply_to_message_id == "555"
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_reference_with_resolved_content(self, reply_text_adapter):
        resolved_msg = SimpleNamespace(content="original message text")
        ref = SimpleNamespace(message_id=555, resolved=resolved_msg)
        message = _make_message(reference=ref)

        await reply_text_adapter._handle_message(message)

        event = _captured_event(reply_text_adapter)
        assert event.reply_to_message_id == "555"
        assert event.reply_to_text == "original message text"

    @pytest.mark.asyncio
    async def test_reference_with_empty_resolved_content(self, reply_text_adapter):
        """Empty string content should become None, not leak as empty string."""
        resolved_msg = SimpleNamespace(content="")
        ref = SimpleNamespace(message_id=555, resolved=resolved_msg)
        message = _make_message(reference=ref)

        await reply_text_adapter._handle_message(message)

        event = _captured_event(reply_text_adapter)
        assert event.reply_to_message_id == "555"
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_reference_with_deleted_message(self, reply_text_adapter):
        """Deleted messages lack .content — getattr guard should return None."""
        resolved_deleted = SimpleNamespace(id=555)
        ref = SimpleNamespace(message_id=555, resolved=resolved_deleted)
        message = _make_message(reference=ref)

        await reply_text_adapter._handle_message(message)

        event = _captured_event(reply_text_adapter)
        assert event.reply_to_message_id == "555"
        assert event.reply_to_text is None