"""test_discord_send.py — DiscordAdapter send-path tests.

Covers:
* retrying a send without its reply reference when Discord rejects the
  reference (system-message target, deleted target), and NOT retrying on
  unrelated errors;
* forum-channel detection and posting (thread creation, follow-up chunks,
  file attachments, failure reporting).

A minimal ``discord`` stand-in is registered in ``sys.modules`` before the
adapter is imported, unless a real file-backed ``discord`` module is already
imported.
"""

import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from gateway.config import PlatformConfig


def _ensure_discord_mock():
    """Register a minimal ``discord`` / ``discord.ext.commands`` stand-in.

    Returns immediately when ``sys.modules["discord"]`` is already a real,
    file-backed module. Otherwise builds a MagicMock-based module exposing only
    the attributes the adapter needs at import time. Registration uses
    ``sys.modules.setdefault``, so any module (or mock) that is already
    registered under one of these names is kept.
    """
    # A real package has ``__file__``; a MagicMock stand-in does not.
    if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
        return

    discord_mod = MagicMock()
    discord_mod.Intents.default.return_value = MagicMock()
    discord_mod.Client = MagicMock
    discord_mod.File = MagicMock
    # Plain classes (not MagicMock attributes) so they can be instantiated
    # with no arguments and used as ``isinstance`` targets.
    discord_mod.DMChannel = type("DMChannel", (), {})
    discord_mod.Thread = type("Thread", (), {})
    discord_mod.ForumChannel = type("ForumChannel", (), {})
    # Decorator stand-ins return the decorated function unchanged.
    discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
    discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
    discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
    discord_mod.Interaction = object
    discord_mod.Embed = MagicMock
    discord_mod.app_commands = SimpleNamespace(
        describe=lambda **kwargs: (lambda fn: fn),
        choices=lambda **kwargs: (lambda fn: fn),
        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
    )

    ext_mod = MagicMock()
    commands_mod = MagicMock()
    commands_mod.Bot = MagicMock
    ext_mod.commands = commands_mod

    sys.modules.setdefault("discord", discord_mod)
    sys.modules.setdefault("discord.ext", ext_mod)
    sys.modules.setdefault("discord.ext.commands", commands_mod)


_ensure_discord_mock()

# Both imports must run after _ensure_discord_mock(); ``discord`` resolves to
# whatever module it left registered in sys.modules.
import discord as _discord_mod  # noqa: E402
from gateway.platforms.discord import DiscordAdapter  # noqa: E402


@pytest.mark.asyncio
async def test_send_retries_without_reference_when_reply_target_is_system_message():
    """A 50035 'Cannot reply to a system message' error is retried once with no reference."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    reference_obj = object()
    ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
    sent_msg = SimpleNamespace(id=1234)
    send_calls = []

    async def fake_send(*, content, reference=None):
        send_calls.append({"content": content, "reference": reference})
        # Only the first (referenced) attempt fails.
        if len(send_calls) == 1:
            raise RuntimeError(
                "400 Bad Request (error code: 50035): Invalid Form Body\n"
                "In message_reference: Cannot reply to a system message"
            )
        return sent_msg

    channel = SimpleNamespace(
        fetch_message=AsyncMock(return_value=ref_msg),
        send=AsyncMock(side_effect=fake_send),
    )
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("555", "hello", reply_to="99")

    assert result.success is True
    assert result.message_id == "1234"
    assert channel.fetch_message.await_count == 1
    assert channel.send.await_count == 2
    ref_msg.to_reference.assert_called_once_with(fail_if_not_exists=False)
    assert send_calls[0]["reference"] is reference_obj
    assert send_calls[1]["reference"] is None


@pytest.mark.asyncio
async def test_send_retries_without_reference_when_reply_target_is_deleted():
    """A 10008 'Unknown Message' error drops the reference for the retry and all later chunks.

    The message id reported is that of the first successfully delivered chunk.
    """
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    reference_obj = object()
    ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
    sent_msgs = [SimpleNamespace(id=1001), SimpleNamespace(id=1002)]
    send_calls = []

    async def fake_send(*, content, reference=None):
        send_calls.append({"content": content, "reference": reference})
        if len(send_calls) == 1:
            raise RuntimeError(
                "400 Bad Request (error code: 10008): Unknown Message"
            )
        # Call 2 -> sent_msgs[0], call 3 -> sent_msgs[1].
        return sent_msgs[len(send_calls) - 2]

    channel = SimpleNamespace(
        fetch_message=AsyncMock(return_value=ref_msg),
        send=AsyncMock(side_effect=fake_send),
    )
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: channel,
        fetch_channel=AsyncMock(),
    )

    # Just over the limit, so the text is split into two chunks.
    long_text = "A" * (adapter.MAX_MESSAGE_LENGTH + 50)
    result = await adapter.send("555", long_text, reply_to="99")

    assert result.success is True
    assert result.message_id == "1001"
    assert channel.fetch_message.await_count == 1
    assert channel.send.await_count == 3
    ref_msg.to_reference.assert_called_once_with(fail_if_not_exists=False)
    assert send_calls[0]["reference"] is reference_obj
    assert send_calls[1]["reference"] is None
    assert send_calls[2]["reference"] is None


@pytest.mark.asyncio
async def test_send_does_not_retry_on_unrelated_errors():
    """Regression guard: errors unrelated to the reply reference (e.g. 50013
    Missing Permissions) must NOT trigger the no-reference retry path — they
    should propagate out of the per-chunk loop and surface as a failed
    SendResult so the caller sees the real problem instead of a silent retry.
    """
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    reference_obj = object()
    ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
    send_calls = []

    async def fake_send(*, content, reference=None):
        send_calls.append({"content": content, "reference": reference})
        raise RuntimeError(
            "403 Forbidden (error code: 50013): Missing Permissions"
        )

    channel = SimpleNamespace(
        fetch_message=AsyncMock(return_value=ref_msg),
        send=AsyncMock(side_effect=fake_send),
    )
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("555", "hello", reply_to="99")

    # Outer except in adapter.send() wraps propagated errors as SendResult.
    assert result.success is False
    assert "50013" in (result.error or "")
    # Only the first attempt happens — no reference-retry replay.
    assert channel.send.await_count == 1
    assert send_calls[0]["reference"] is reference_obj


# ---------------------------------------------------------------------------
# Forum channel tests
# ---------------------------------------------------------------------------


class TestIsForumParent:
    """DiscordAdapter._is_forum_parent: class-based and ``type``-based detection."""

    def test_none_returns_false(self):
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        assert adapter._is_forum_parent(None) is False

    def test_forum_channel_class_instance(self, monkeypatch):
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        forum_cls = getattr(_discord_mod, "ForumChannel", None)
        if forum_cls is None:
            # Re-create a type for the mock; monkeypatch restores the shared
            # module after the test so no state leaks into other tests.
            forum_cls = type("ForumChannel", (), {})
            monkeypatch.setattr(_discord_mod, "ForumChannel", forum_cls, raising=False)
        ch = forum_cls()
        assert adapter._is_forum_parent(ch) is True

    def test_type_value_15(self):
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        ch = SimpleNamespace(type=15)  # forum channel type value
        assert adapter._is_forum_parent(ch) is True

    def test_regular_channel_returns_false(self):
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        ch = SimpleNamespace(type=0)
        assert adapter._is_forum_parent(ch) is False

    def test_thread_returns_false(self):
        adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
        ch = SimpleNamespace(type=11)  # public thread
        assert adapter._is_forum_parent(ch) is False


@pytest.mark.asyncio
async def test_send_to_forum_creates_thread_post():
    """Sending to a forum channel creates a thread and reports the starter message id."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    # thread object has no 'send' so _send_to_forum uses thread.thread
    thread_ch = SimpleNamespace(id=555, send=AsyncMock(return_value=SimpleNamespace(id=600)))
    thread = SimpleNamespace(
        id=555,
        message=SimpleNamespace(id=500),
        thread=thread_ch,
    )
    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(return_value=thread)
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum_channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("999", "Hello forum!")

    assert result.success is True
    assert result.message_id == "500"
    forum_channel.create_thread.assert_awaited_once()


@pytest.mark.asyncio
async def test_send_to_forum_sends_remaining_chunks():
    """Chunks beyond the thread's starter message are sent into the created thread."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    # Force a small max message length so the message splits
    adapter.MAX_MESSAGE_LENGTH = 20

    chunk_msg_1 = SimpleNamespace(id=500)
    chunk_msg_2 = SimpleNamespace(id=501)
    thread_ch = SimpleNamespace(
        id=555,
        send=AsyncMock(return_value=chunk_msg_2),
    )
    # thread object has no 'send' so _send_to_forum uses thread.thread
    thread = SimpleNamespace(
        id=555,
        message=chunk_msg_1,
        thread=thread_ch,
    )
    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(return_value=thread)
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum_channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("999", "A" * 50)

    assert result.success is True
    assert result.message_id == "500"
    # Should have sent at least one follow-up chunk
    assert thread_ch.send.await_count >= 1


@pytest.mark.asyncio
async def test_send_to_forum_create_thread_failure():
    """A create_thread exception yields a failed SendResult carrying the error text."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(side_effect=Exception("rate limited"))
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum_channel,
        fetch_channel=AsyncMock(),
    )

    result = await adapter.send("999", "Hello forum!")

    assert result.success is False
    # Guard against error=None so a regression fails the assertion, not a TypeError.
    assert "rate limited" in (result.error or "")


# ---------------------------------------------------------------------------
# Forum follow-up chunk failure reporting + media on forum paths
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_send_to_forum_follow_up_chunk_failures_collected_as_warnings():
    """Partial-send chunk failures surface in raw_response['warnings']."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    adapter.MAX_MESSAGE_LENGTH = 20

    chunk_msg_1 = SimpleNamespace(id=500)
    # Every follow-up chunk fails — we should collect a warning per failure
    thread_ch = SimpleNamespace(
        id=555,
        send=AsyncMock(side_effect=Exception("rate limited")),
    )
    thread = SimpleNamespace(id=555, message=chunk_msg_1, thread=thread_ch)
    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(return_value=thread)
    adapter._client = SimpleNamespace(
        get_channel=lambda _chat_id: forum_channel,
        fetch_channel=AsyncMock(),
    )

    # Long enough to produce multiple chunks
    result = await adapter.send("999", "A" * 60)

    # Starter message (first chunk) was delivered via create_thread, so send is
    # successful overall — but follow-up chunks all failed and are reported.
    assert result.success is True
    assert result.message_id == "500"
    warnings = (result.raw_response or {}).get("warnings") or []
    assert len(warnings) >= 1
    assert all("rate limited" in w for w in warnings)


@pytest.mark.asyncio
async def test_forum_post_file_creates_thread_with_attachment():
    """_forum_post_file routes file-bearing sends to create_thread with file kwarg."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    thread_ch = SimpleNamespace(id=777, send=AsyncMock())
    thread = SimpleNamespace(id=777, message=SimpleNamespace(id=800), thread=thread_ch)
    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.name = "ideas"
    forum_channel.create_thread = AsyncMock(return_value=thread)

    # Stand-in for a discord.File: only the ``filename`` attribute is provided.
    fake_file = SimpleNamespace(filename="photo.png")

    result = await adapter._forum_post_file(
        forum_channel,
        content="here is a photo",
        file=fake_file,
    )

    assert result.success is True
    assert result.message_id == "800"
    forum_channel.create_thread.assert_awaited_once()
    call_kwargs = forum_channel.create_thread.await_args.kwargs
    assert call_kwargs["file"] is fake_file
    assert call_kwargs["content"] == "here is a photo"
    # Thread name derived from content's first line
    assert call_kwargs["name"] == "here is a photo"


@pytest.mark.asyncio
async def test_forum_post_file_uses_filename_when_no_content():
    """Thread name falls back to file.filename when no content is provided."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    thread = SimpleNamespace(id=1, message=SimpleNamespace(id=2), thread=SimpleNamespace(id=1, send=AsyncMock()))
    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 10
    forum_channel.name = "forum"
    forum_channel.create_thread = AsyncMock(return_value=thread)

    fake_file = SimpleNamespace(filename="voice-message.ogg")
    result = await adapter._forum_post_file(forum_channel, content="", file=fake_file)

    assert result.success is True
    call_kwargs = forum_channel.create_thread.await_args.kwargs
    # Content was empty → thread name derived from filename
    assert call_kwargs["name"] == "voice-message.ogg"


@pytest.mark.asyncio
async def test_forum_post_file_creation_failure():
    """_forum_post_file returns a failed SendResult when create_thread raises."""
    adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))

    forum_channel = _discord_mod.ForumChannel()
    forum_channel.id = 999
    forum_channel.create_thread = AsyncMock(side_effect=Exception("missing perms"))

    result = await adapter._forum_post_file(
        forum_channel,
        content="hi",
        file=SimpleNamespace(filename="x.png"),
    )

    assert result.success is False
    assert "missing perms" in (result.error or "")