tests/gateway/test_discord_send.py
test_discord_send.py
  1  from types import SimpleNamespace
  2  from unittest.mock import AsyncMock, MagicMock
  3  import sys
  4  
  5  import pytest
  6  
  7  from gateway.config import PlatformConfig
  8  
  9  
 10  def _ensure_discord_mock():
 11      if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
 12          return
 13  
 14      discord_mod = MagicMock()
 15      discord_mod.Intents.default.return_value = MagicMock()
 16      discord_mod.Client = MagicMock
 17      discord_mod.File = MagicMock
 18      discord_mod.DMChannel = type("DMChannel", (), {})
 19      discord_mod.Thread = type("Thread", (), {})
 20      discord_mod.ForumChannel = type("ForumChannel", (), {})
 21      discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
 22      discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
 23      discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
 24      discord_mod.Interaction = object
 25      discord_mod.Embed = MagicMock
 26      discord_mod.app_commands = SimpleNamespace(
 27          describe=lambda **kwargs: (lambda fn: fn),
 28          choices=lambda **kwargs: (lambda fn: fn),
 29          Choice=lambda **kwargs: SimpleNamespace(**kwargs),
 30      )
 31  
 32      ext_mod = MagicMock()
 33      commands_mod = MagicMock()
 34      commands_mod.Bot = MagicMock
 35      ext_mod.commands = commands_mod
 36  
 37      sys.modules.setdefault("discord", discord_mod)
 38      sys.modules.setdefault("discord.ext", ext_mod)
 39      sys.modules.setdefault("discord.ext.commands", commands_mod)
 40  
 41  
# Must run before the adapter import below (hence its `noqa: E402`): when
# discord.py is not installed, this puts the stub into sys.modules so that the
# adapter module's own `discord` import presumably resolves to it.
_ensure_discord_mock()
 43  
 44  from gateway.platforms.discord import DiscordAdapter  # noqa: E402
 45  
 46  
 47  @pytest.mark.asyncio
 48  async def test_send_retries_without_reference_when_reply_target_is_system_message():
 49      adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
 50  
 51      reference_obj = object()
 52      ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
 53      sent_msg = SimpleNamespace(id=1234)
 54      send_calls = []
 55  
 56      async def fake_send(*, content, reference=None):
 57          send_calls.append({"content": content, "reference": reference})
 58          if len(send_calls) == 1:
 59              raise RuntimeError(
 60                  "400 Bad Request (error code: 50035): Invalid Form Body\n"
 61                  "In message_reference: Cannot reply to a system message"
 62              )
 63          return sent_msg
 64  
 65      channel = SimpleNamespace(
 66          fetch_message=AsyncMock(return_value=ref_msg),
 67          send=AsyncMock(side_effect=fake_send),
 68      )
 69      adapter._client = SimpleNamespace(
 70          get_channel=lambda _chat_id: channel,
 71          fetch_channel=AsyncMock(),
 72      )
 73  
 74      result = await adapter.send("555", "hello", reply_to="99")
 75  
 76      assert result.success is True
 77      assert result.message_id == "1234"
 78      assert channel.fetch_message.await_count == 1
 79      assert channel.send.await_count == 2
 80      ref_msg.to_reference.assert_called_once_with(fail_if_not_exists=False)
 81      assert send_calls[0]["reference"] is reference_obj
 82      assert send_calls[1]["reference"] is None
 83  
 84  
 85  @pytest.mark.asyncio
 86  async def test_send_retries_without_reference_when_reply_target_is_deleted():
 87      adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
 88  
 89      reference_obj = object()
 90      ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
 91      sent_msgs = [SimpleNamespace(id=1001), SimpleNamespace(id=1002)]
 92      send_calls = []
 93  
 94      async def fake_send(*, content, reference=None):
 95          send_calls.append({"content": content, "reference": reference})
 96          if len(send_calls) == 1:
 97              raise RuntimeError(
 98                  "400 Bad Request (error code: 10008): Unknown Message"
 99              )
100          return sent_msgs[len(send_calls) - 2]
101  
102      channel = SimpleNamespace(
103          fetch_message=AsyncMock(return_value=ref_msg),
104          send=AsyncMock(side_effect=fake_send),
105      )
106      adapter._client = SimpleNamespace(
107          get_channel=lambda _chat_id: channel,
108          fetch_channel=AsyncMock(),
109      )
110  
111      long_text = "A" * (adapter.MAX_MESSAGE_LENGTH + 50)
112      result = await adapter.send("555", long_text, reply_to="99")
113  
114      assert result.success is True
115      assert result.message_id == "1001"
116      assert channel.fetch_message.await_count == 1
117      assert channel.send.await_count == 3
118      ref_msg.to_reference.assert_called_once_with(fail_if_not_exists=False)
119      assert send_calls[0]["reference"] is reference_obj
120      assert send_calls[1]["reference"] is None
121      assert send_calls[2]["reference"] is None
122  
123  
@pytest.mark.asyncio
async def test_send_does_not_retry_on_unrelated_errors():
    """Errors unrelated to the reply reference must not trigger a retry.

    Regression guard.  A failure such as 50013 Missing Permissions has to
    escape the per-chunk loop and come back as a failed SendResult.  The
    caller then sees the real problem instead of a silent reference-less
    retry.
    """
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    ref_token = object()
    replied_to = SimpleNamespace(id=99, to_reference=MagicMock(return_value=ref_token))
    attempts = []

    async def fake_send(*, content, reference=None):
        # Record the attempt, then fail every time with a permissions error.
        attempts.append({"content": content, "reference": reference})
        raise RuntimeError("403 Forbidden (error code: 50013): Missing Permissions")

    channel = SimpleNamespace(
        send=AsyncMock(side_effect=fake_send),
        fetch_message=AsyncMock(return_value=replied_to),
    )
    adapter._client = SimpleNamespace(
        fetch_channel=AsyncMock(),
        get_channel=lambda _chat_id: channel,
    )

    result = await adapter.send("555", "hello", reply_to="99")

    # adapter.send() converts the propagated exception into a SendResult.
    assert result.success is False
    assert "50013" in (result.error or "")
    # Exactly one attempt, still carrying the reference: no replay.
    assert channel.send.await_count == 1
    assert attempts[0]["reference"] is ref_token
160  
161  
162  # ---------------------------------------------------------------------------
163  # Forum channel tests
164  # ---------------------------------------------------------------------------
165  
166  import discord as _discord_mod  # noqa: E402 — imported after _ensure_discord_mock
167  
168  
class TestIsForumParent:
    """Tests for ``DiscordAdapter._is_forum_parent`` channel classification."""

    @staticmethod
    def _adapter():
        """Return a fresh adapter; these checks need no client connection."""
        return DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    def test_none_returns_false(self):
        assert self._adapter()._is_forum_parent(None) is False

    def test_forum_channel_class_instance(self, monkeypatch):
        forum_cls = getattr(_discord_mod, "ForumChannel", None)
        if not isinstance(forum_cls, type):
            # The active discord stub has no usable ForumChannel class.  It may
            # be missing entirely, or be an auto-created MagicMock attribute if
            # another test module installed the stub first.  Patch in a plain
            # class via monkeypatch, so the shared module is restored after
            # this test instead of staying mutated for the whole session.
            forum_cls = type("ForumChannel", (), {})
            monkeypatch.setattr(_discord_mod, "ForumChannel", forum_cls, raising=False)
        # Instantiate without running __init__.  The stub class takes no
        # arguments, and a real discord.py ForumChannel constructor requires
        # connection state that a unit test cannot supply.
        ch = forum_cls.__new__(forum_cls)
        assert self._adapter()._is_forum_parent(ch) is True

    def test_type_value_15(self):
        # Discord channel type 15 is GUILD_FORUM.
        ch = SimpleNamespace(type=15)
        assert self._adapter()._is_forum_parent(ch) is True

    def test_regular_channel_returns_false(self):
        # Type 0 is an ordinary guild text channel.
        ch = SimpleNamespace(type=0)
        assert self._adapter()._is_forum_parent(ch) is False

    def test_thread_returns_false(self):
        # Type 11 is a public thread, which is not a forum parent.
        ch = SimpleNamespace(type=11)
        assert self._adapter()._is_forum_parent(ch) is False
198  
199  
@pytest.mark.asyncio
async def test_send_to_forum_creates_thread_post():
    """Sending to a forum channel opens a new post and reports its starter message id."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    # The create_thread result carries no ``send`` of its own, so any
    # follow-up is expected to go through its ``.thread`` channel.
    post_channel = SimpleNamespace(
        id=555,
        send=AsyncMock(return_value=SimpleNamespace(id=600)),
    )
    created = SimpleNamespace(id=555, message=SimpleNamespace(id=500), thread=post_channel)

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id, forum_channel.name = 999, "ideas"
    forum_channel.create_thread = AsyncMock(return_value=created)
    adapter._client = SimpleNamespace(
        fetch_channel=AsyncMock(),
        get_channel=lambda _chat_id: forum_channel,
    )

    result = await adapter.send("999", "Hello forum!")

    assert result.success is True
    assert result.message_id == "500"
    forum_channel.create_thread.assert_awaited_once()
225  
226  
@pytest.mark.asyncio
async def test_send_to_forum_sends_remaining_chunks():
    """Overflow chunks of a long forum message are posted into the new thread."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    # Shrink the per-message limit so that a 50-character body must be split.
    adapter.MAX_MESSAGE_LENGTH = 20

    starter = SimpleNamespace(id=500)
    follow_up = SimpleNamespace(id=501)
    # The create_thread result has no ``send``; follow-ups are expected to go
    # through its ``.thread`` channel.
    post_channel = SimpleNamespace(id=555, send=AsyncMock(return_value=follow_up))
    created = SimpleNamespace(id=555, message=starter, thread=post_channel)

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id, forum_channel.name = 999, "ideas"
    forum_channel.create_thread = AsyncMock(return_value=created)
    adapter._client = SimpleNamespace(
        fetch_channel=AsyncMock(),
        get_channel=lambda _chat_id: forum_channel,
    )

    result = await adapter.send("999", "A" * 50)

    assert result.success is True
    assert result.message_id == "500"
    # At least one follow-up chunk was posted into the thread.
    assert post_channel.send.await_count >= 1
260  
261  
@pytest.mark.asyncio
async def test_send_to_forum_create_thread_failure():
    """A failing create_thread makes send() return a failed SendResult.

    The returned error text must include the underlying exception message.
    """
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(side_effect=Exception("rate limited"))
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum_channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("999", "Hello forum!")

    assert result.success is False
    # Guard against error=None so a regression fails on this assertion rather
    # than with a TypeError from `in None`.  This matches the other tests here.
    assert "rate limited" in (result.error or "")
279  
280  
281  
282  # ---------------------------------------------------------------------------
283  # Forum follow-up chunk failure reporting + media on forum paths
284  # ---------------------------------------------------------------------------
285  
286  
@pytest.mark.asyncio
async def test_send_to_forum_follow_up_chunk_failures_collected_as_warnings():
    """Failed follow-up chunks are reported in raw_response['warnings'], not raised."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    adapter.MAX_MESSAGE_LENGTH = 20

    # Every follow-up send into the thread fails; the starter message itself
    # is delivered through create_thread.
    post_channel = SimpleNamespace(
        id=555,
        send=AsyncMock(side_effect=Exception("rate limited")),
    )
    created = SimpleNamespace(id=555, message=SimpleNamespace(id=500), thread=post_channel)

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id, forum_channel.name = 999, "ideas"
    forum_channel.create_thread = AsyncMock(return_value=created)
    adapter._client = SimpleNamespace(
        fetch_channel=AsyncMock(),
        get_channel=lambda _chat_id: forum_channel,
    )

    # 60 characters against a 20-character limit -> several chunks.
    result = await adapter.send("999", "A" * 60)

    # The starter chunk went out, so the send as a whole succeeds; the failed
    # follow-ups surface as warnings.
    assert result.success is True
    assert result.message_id == "500"
    raw = result.raw_response or {}
    reported = raw.get("warnings") or []
    assert len(reported) >= 1
    assert all("rate limited" in entry for entry in reported)
319  
320  
@pytest.mark.asyncio
async def test_forum_post_file_creates_thread_with_attachment():
    """_forum_post_file routes file-bearing sends to create_thread with file kwarg."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    thread_ch = SimpleNamespace(id=777, send=AsyncMock())
    thread = SimpleNamespace(id=777, message=SimpleNamespace(id=800), thread=thread_ch)
    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(return_value=thread)

    # Minimal stand-in for a ``discord.File``: only ``filename`` is provided.
    # That is enough here because the object must be handed to create_thread
    # unchanged (checked by identity below).
    fake_file = SimpleNamespace(filename="photo.png")

    result = await adapter._forum_post_file(
        forum_channel,
        content="here is a photo",
        file=fake_file,
    )

    assert result.success is True
    assert result.message_id == "800"
    forum_channel.create_thread.assert_awaited_once()
    call_kwargs = forum_channel.create_thread.await_args.kwargs
    assert call_kwargs["file"] is fake_file
    assert call_kwargs["content"] == "here is a photo"
    # Thread name derived from content's first line
    assert call_kwargs["name"] == "here is a photo"
350  
351  
@pytest.mark.asyncio
async def test_forum_post_file_uses_filename_when_no_content():
    """With empty content, the new post is titled after the attachment's filename."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    post_channel = SimpleNamespace(id=1, send=AsyncMock())
    created = SimpleNamespace(id=1, message=SimpleNamespace(id=2), thread=post_channel)

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id, forum_channel.name = 10, "forum"
    forum_channel.create_thread = AsyncMock(return_value=created)

    attachment = SimpleNamespace(filename="voice-message.ogg")
    result = await adapter._forum_post_file(forum_channel, content="", file=attachment)

    assert result.success is True
    # No content to derive a title from, so the filename is used instead.
    sent_kwargs = forum_channel.create_thread.await_args.kwargs
    assert sent_kwargs["name"] == "voice-message.ogg"
370  
371  
@pytest.mark.asyncio
async def test_forum_post_file_creation_failure():
    """When create_thread raises, _forum_post_file reports a failed SendResult."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.create_thread = AsyncMock(side_effect=Exception("missing perms"))

    attachment = SimpleNamespace(filename="x.png")
    result = await adapter._forum_post_file(forum_channel, content="hi", file=attachment)

    assert result.success is False
    assert "missing perms" in (result.error or "")