# tests/gateway/test_discord_reply_mode.py
"""Tests for Discord reply_to_mode functionality.

Covers how reply_to_mode controls reply threading for multi-chunk replies:
- "off": Never reply-reference the original message
- "first": Only the first chunk uses a reply reference (default)
- "all": All chunks reply-reference the original message

Also covers reply_to_mode config serialization, the DISCORD_REPLY_TO_MODE
environment override, and reply_to_text extraction from incoming messages.
"""
 10  import os
 11  import sys
 12  from datetime import datetime, timezone
 13  from types import SimpleNamespace
 14  from unittest.mock import MagicMock, AsyncMock, patch
 15  
 16  import pytest
 17  
 18  from gateway.config import PlatformConfig, GatewayConfig, Platform, _apply_env_overrides
 19  
 20  
 21  def _ensure_discord_mock():
 22      """Install a mock discord module when discord.py isn't available."""
 23      if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
 24          return
 25  
 26      discord_mod = MagicMock()
 27      discord_mod.Intents.default.return_value = MagicMock()
 28      discord_mod.Client = MagicMock
 29      discord_mod.File = MagicMock
 30      discord_mod.DMChannel = type("DMChannel", (), {})
 31      discord_mod.Thread = type("Thread", (), {})
 32      discord_mod.ForumChannel = type("ForumChannel", (), {})
 33      discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
 34      discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
 35      discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
 36      discord_mod.Interaction = object
 37      discord_mod.Embed = MagicMock
 38      discord_mod.app_commands = SimpleNamespace(
 39          describe=lambda **kwargs: (lambda fn: fn),
 40          choices=lambda **kwargs: (lambda fn: fn),
 41          Choice=lambda **kwargs: SimpleNamespace(**kwargs),
 42      )
 43  
 44      ext_mod = MagicMock()
 45      commands_mod = MagicMock()
 46      commands_mod.Bot = MagicMock
 47      ext_mod.commands = commands_mod
 48  
 49      sys.modules.setdefault("discord", discord_mod)
 50      sys.modules.setdefault("discord.ext", ext_mod)
 51      sys.modules.setdefault("discord.ext.commands", commands_mod)
 52  
 53  
# Install the discord stub (when needed) before importing the adapter module.
# That module presumably imports ``discord`` at load time, which is why the
# import below sits after code and carries an E402 suppression.
_ensure_discord_mock()

from gateway.platforms.discord import DiscordAdapter  # noqa: E402
 57  
 58  
 59  @pytest.fixture()
 60  def adapter_factory():
 61      """Factory to create DiscordAdapter with custom reply_to_mode."""
 62      def create(reply_to_mode: str = "first"):
 63          config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
 64          return DiscordAdapter(config)
 65      return create
 66  
 67  
 68  class TestReplyToModeConfig:
 69      """Tests for reply_to_mode configuration loading."""
 70  
 71      def test_default_mode_is_first(self, adapter_factory):
 72          adapter = adapter_factory()
 73          assert adapter._reply_to_mode == "first"
 74  
 75      def test_off_mode(self, adapter_factory):
 76          adapter = adapter_factory(reply_to_mode="off")
 77          assert adapter._reply_to_mode == "off"
 78  
 79      def test_first_mode(self, adapter_factory):
 80          adapter = adapter_factory(reply_to_mode="first")
 81          assert adapter._reply_to_mode == "first"
 82  
 83      def test_all_mode(self, adapter_factory):
 84          adapter = adapter_factory(reply_to_mode="all")
 85          assert adapter._reply_to_mode == "all"
 86  
 87      def test_invalid_mode_stored_as_is(self, adapter_factory):
 88          """Invalid modes are stored but send() handles them gracefully."""
 89          adapter = adapter_factory(reply_to_mode="invalid")
 90          assert adapter._reply_to_mode == "invalid"
 91  
 92      def test_none_mode_defaults_to_first(self):
 93          config = PlatformConfig(enabled=True, token="test-token")
 94          adapter = DiscordAdapter(config)
 95          assert adapter._reply_to_mode == "first"
 96  
 97      def test_empty_string_mode_defaults_to_first(self):
 98          config = PlatformConfig(enabled=True, token="test-token", reply_to_mode="")
 99          adapter = DiscordAdapter(config)
100          assert adapter._reply_to_mode == "first"
101  
102  
def _make_discord_adapter(reply_to_mode: str = "first"):
    """Build a DiscordAdapter backed by a mocked client and channel for send() tests.

    Returns ``(adapter, channel, reference)``:

    * ``channel`` is the AsyncMock that ``adapter._client.get_channel`` hands
      back. Its ``send`` resolves to a message whose ``id`` is 42.
    * ``reference`` is the sentinel returned by
      ``(await channel.fetch_message(...)).to_reference()``. The adapter wraps
      the fetched Message via ``to_reference(fail_if_not_exists=False)`` so a
      deleted target degrades to "send without reply chip" instead of a 400.
      Tests assert that exactly this object is passed as ``reference=``.
    """
    adapter = DiscordAdapter(
        PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
    )

    reference = MagicMock(name="MessageReference")
    fetched = MagicMock()
    fetched.to_reference = MagicMock(return_value=reference)

    delivered = MagicMock()
    delivered.id = 42

    channel = AsyncMock()
    channel.fetch_message = AsyncMock(return_value=fetched)
    channel.send = AsyncMock(return_value=delivered)

    client = MagicMock()
    client.get_channel = MagicMock(return_value=channel)
    adapter._client = client

    # The sentinel is also stashed on the adapter.
    # NOTE(review): no test visible here reads this attribute; kept as-is.
    adapter._test_expected_reference = reference
    return adapter, channel, reference
129  
130  
class TestSendWithReplyToMode:
    """Tests for send() honoring reply_to_mode on multi-chunk replies.

    Each test stubs ``truncate_message`` to force a known chunk list, calls
    ``send()`` and inspects the ``reference=`` kwarg of every ``channel.send``
    call. ``ref`` is the MessageReference sentinel returned by
    ``_make_discord_adapter`` and is compared by identity.
    """

    @staticmethod
    def _stub_chunks(adapter, chunks):
        """Make ``adapter.truncate_message`` always split content into *chunks*."""
        adapter.truncate_message = lambda content, max_len, **kw: list(chunks)

    @staticmethod
    def _sent_references(channel):
        """Return the ``reference=`` kwarg of each ``channel.send`` call, in call order."""
        return [c.kwargs.get("reference") for c in channel.send.call_args_list]

    @pytest.mark.asyncio
    async def test_off_mode_no_reply_reference(self):
        adapter, channel, _ = _make_discord_adapter("off")
        self._stub_chunks(adapter, ["chunk1", "chunk2", "chunk3"])

        await adapter.send("12345", "test content", reply_to="999")

        # "off" must never even look up the message being replied to.
        channel.fetch_message.assert_not_called()
        # Comparing the whole list (rather than looping over it) also fails
        # when no chunk was sent, instead of passing vacuously.
        assert self._sent_references(channel) == [None, None, None]

    @pytest.mark.asyncio
    async def test_first_mode_only_first_chunk_references(self):
        adapter, channel, ref = _make_discord_adapter("first")
        self._stub_chunks(adapter, ["chunk1", "chunk2", "chunk3"])

        await adapter.send("12345", "test content", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        refs = self._sent_references(channel)
        assert len(refs) == 3
        assert refs[0] is ref
        assert refs[1:] == [None, None]

    @pytest.mark.asyncio
    async def test_all_mode_all_chunks_reference(self):
        adapter, channel, ref = _make_discord_adapter("all")
        self._stub_chunks(adapter, ["chunk1", "chunk2", "chunk3"])

        await adapter.send("12345", "test content", reply_to="999")

        channel.fetch_message.assert_called_once_with(999)
        refs = self._sent_references(channel)
        assert len(refs) == 3
        assert all(r is ref for r in refs)

    @pytest.mark.asyncio
    async def test_no_reply_to_param_no_reference(self):
        adapter, channel, _ = _make_discord_adapter("all")
        self._stub_chunks(adapter, ["chunk1", "chunk2"])

        await adapter.send("12345", "test content", reply_to=None)

        channel.fetch_message.assert_not_called()
        # Length is pinned so an empty send() cannot pass vacuously.
        assert self._sent_references(channel) == [None, None]

    @pytest.mark.asyncio
    async def test_single_chunk_respects_first_mode(self):
        adapter, channel, ref = _make_discord_adapter("first")
        self._stub_chunks(adapter, ["single chunk"])

        await adapter.send("12345", "test", reply_to="999")

        refs = self._sent_references(channel)
        assert len(refs) == 1
        assert refs[0] is ref

    @pytest.mark.asyncio
    async def test_single_chunk_off_mode(self):
        adapter, channel, _ = _make_discord_adapter("off")
        self._stub_chunks(adapter, ["single chunk"])

        await adapter.send("12345", "test", reply_to="999")

        channel.fetch_message.assert_not_called()
        assert self._sent_references(channel) == [None]

    @pytest.mark.asyncio
    async def test_invalid_mode_falls_back_to_first_behavior(self):
        """Invalid mode behaves like 'first' — only first chunk gets reference."""
        adapter, channel, ref = _make_discord_adapter("banana")
        self._stub_chunks(adapter, ["chunk1", "chunk2"])

        await adapter.send("12345", "test", reply_to="999")

        refs = self._sent_references(channel)
        assert len(refs) == 2
        assert refs[0] is ref
        assert refs[1] is None
221  
222  
class TestConfigSerialization:
    """PlatformConfig to_dict/from_dict handling of reply_to_mode (field shared with Telegram)."""

    def test_to_dict_includes_reply_to_mode(self):
        serialized = PlatformConfig(enabled=True, token="test", reply_to_mode="all").to_dict()
        assert serialized["reply_to_mode"] == "all"

    def test_from_dict_loads_reply_to_mode(self):
        loaded = PlatformConfig.from_dict(
            {"enabled": True, "token": "***", "reply_to_mode": "off"}
        )
        assert loaded.reply_to_mode == "off"

    def test_from_dict_defaults_to_first(self):
        # No reply_to_mode key at all: from_dict must fill in the default.
        loaded = PlatformConfig.from_dict({"enabled": True, "token": "***"})
        assert loaded.reply_to_mode == "first"
240  
241  
class TestEnvVarOverride:
    """Tests for the DISCORD_REPLY_TO_MODE environment variable override."""

    def _make_config(self):
        """Return a GatewayConfig that already contains an enabled Discord platform."""
        config = GatewayConfig()
        config.platforms[Platform.DISCORD] = PlatformConfig(enabled=True, token="test")
        return config

    @staticmethod
    def _apply_with_reply_mode(config, value):
        """Run _apply_env_overrides on *config* with DISCORD_REPLY_TO_MODE=*value*.

        Other environment variables are left untouched. patch.dict restores
        the environment on exit.
        """
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": value}, clear=False):
            _apply_env_overrides(config)

    def test_env_var_sets_off_mode(self):
        config = self._make_config()
        self._apply_with_reply_mode(config, "off")
        assert config.platforms[Platform.DISCORD].reply_to_mode == "off"

    def test_env_var_sets_all_mode(self):
        config = self._make_config()
        self._apply_with_reply_mode(config, "all")
        assert config.platforms[Platform.DISCORD].reply_to_mode == "all"

    def test_env_var_case_insensitive(self):
        config = self._make_config()
        self._apply_with_reply_mode(config, "ALL")
        assert config.platforms[Platform.DISCORD].reply_to_mode == "all"

    def test_env_var_invalid_value_ignored(self):
        config = self._make_config()
        self._apply_with_reply_mode(config, "banana")
        assert config.platforms[Platform.DISCORD].reply_to_mode == "first"

    def test_env_var_empty_value_ignored(self):
        config = self._make_config()
        self._apply_with_reply_mode(config, "")
        assert config.platforms[Platform.DISCORD].reply_to_mode == "first"

    def test_env_var_creates_platform_config_if_missing(self):
        """DISCORD_REPLY_TO_MODE creates PlatformConfig even without DISCORD_BOT_TOKEN."""
        config = GatewayConfig()
        assert Platform.DISCORD not in config.platforms
        with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
            # Drop any token inherited from the shell/CI environment.
            # Otherwise the platform entry could come from the token and the
            # "without DISCORD_BOT_TOKEN" claim would go unverified.
            # patch.dict restores the variable on exit.
            os.environ.pop("DISCORD_BOT_TOKEN", None)
            _apply_env_overrides(config)
        assert Platform.DISCORD in config.platforms
        assert config.platforms[Platform.DISCORD].reply_to_mode == "off"
288  
289  
# ------------------------------------------------------------------
# Tests for reply_to_text extraction in _handle_message
# ------------------------------------------------------------------
293  
294  # Build FakeDMChannel as a subclass of the real discord.DMChannel when the
295  # library is installed — this guarantees isinstance() checks pass in
296  # production code regardless of test ordering or monkeypatch state.
297  try:
298      import discord as _discord_lib
299      _DMChannelBase = _discord_lib.DMChannel
300  except (ImportError, AttributeError):
301      _DMChannelBase = object
302  
303  
class FakeDMChannel(_DMChannelBase):
    """Bare-bones DM channel stand-in (skips mention / channel-allow checks)."""

    def __init__(self, channel_id: int = 100, name: str = "dm"):
        # super().__init__() is deliberately not called: the real DMChannel
        # constructor requires connection State, which tests don't have.
        self.id, self.name = channel_id, name
311  
def _make_message(*, content: str = "hi", reference=None):
    """Return a stand-in Discord message for _handle_message tests.

    The message (id 999) sits in a fresh FakeDMChannel and is authored by a
    user with id 42. Only *content* and *reference* vary between tests.
    """
    sender = SimpleNamespace(id=42, display_name="TestUser", name="TestUser")
    fields = {
        "id": 999,
        "content": content,
        "mentions": [],
        "attachments": [],
        "reference": reference,
        "created_at": datetime.now(timezone.utc),
        "channel": FakeDMChannel(),
        "author": sender,
    }
    return SimpleNamespace(**fields)
325  
326  
@pytest.fixture
def reply_text_adapter(monkeypatch):
    """DiscordAdapter whose handle_message is an AsyncMock capturing dispatched events.

    NOTE(review): ``monkeypatch`` is requested but not used here.
    """
    bot_user = SimpleNamespace(id=999)
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
    adapter._client = SimpleNamespace(user=bot_user)
    # Zero delay; presumably disables text batching so dispatch happens
    # before _handle_message returns — confirm against the adapter.
    adapter._text_batch_delay_seconds = 0
    adapter.handle_message = AsyncMock()
    return adapter
336  
337  
class TestReplyToText:
    """Tests for reply_to_message_id / reply_to_text populated by _handle_message."""

    @staticmethod
    async def _dispatch(adapter, reference):
        """Feed a DM message carrying *reference* through _handle_message.

        Returns the single event the adapter passed to ``handle_message``.
        """
        await adapter._handle_message(_make_message(reference=reference))
        # Fail with a clear assertion if no event (or more than one) was
        # dispatched, instead of an AttributeError on a None ``await_args``.
        adapter.handle_message.assert_awaited_once()
        return adapter.handle_message.await_args.args[0]

    @pytest.mark.asyncio
    async def test_no_reference_both_none(self, reply_text_adapter):
        event = await self._dispatch(reply_text_adapter, None)

        assert event.reply_to_message_id is None
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_reference_without_resolved(self, reply_text_adapter):
        ref = SimpleNamespace(message_id=555, resolved=None)

        event = await self._dispatch(reply_text_adapter, ref)

        assert event.reply_to_message_id == "555"
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_reference_with_resolved_content(self, reply_text_adapter):
        resolved_msg = SimpleNamespace(content="original message text")
        ref = SimpleNamespace(message_id=555, resolved=resolved_msg)

        event = await self._dispatch(reply_text_adapter, ref)

        assert event.reply_to_message_id == "555"
        assert event.reply_to_text == "original message text"

    @pytest.mark.asyncio
    async def test_reference_with_empty_resolved_content(self, reply_text_adapter):
        """Empty string content should become None, not leak as empty string."""
        resolved_msg = SimpleNamespace(content="")
        ref = SimpleNamespace(message_id=555, resolved=resolved_msg)

        event = await self._dispatch(reply_text_adapter, ref)

        assert event.reply_to_message_id == "555"
        assert event.reply_to_text is None

    @pytest.mark.asyncio
    async def test_reference_with_deleted_message(self, reply_text_adapter):
        """Deleted messages lack .content — getattr guard should return None."""
        resolved_deleted = SimpleNamespace(id=555)
        ref = SimpleNamespace(message_id=555, resolved=resolved_deleted)

        event = await self._dispatch(reply_text_adapter, ref)

        assert event.reply_to_message_id == "555"
        assert event.reply_to_text is None