# tests/gateway/test_discord_channel_controls.py
  1  """Tests for Discord ignored_channels and no_thread_channels config."""
  2  
  3  from types import SimpleNamespace
  4  from datetime import datetime, timezone
  5  from unittest.mock import AsyncMock, MagicMock
  6  import sys
  7  
  8  import pytest
  9  
 10  from gateway.config import PlatformConfig
 11  
 12  
 13  def _ensure_discord_mock():
 14      """Install a mock discord module when discord.py isn't available."""
 15      if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
 16          return
 17  
 18      discord_mod = MagicMock()
 19      discord_mod.Intents.default.return_value = MagicMock()
 20      discord_mod.Client = MagicMock
 21      discord_mod.File = MagicMock
 22      discord_mod.DMChannel = type("DMChannel", (), {})
 23      discord_mod.Thread = type("Thread", (), {})
 24      discord_mod.ForumChannel = type("ForumChannel", (), {})
 25      discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
 26      discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
 27      discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
 28      discord_mod.Interaction = object
 29      discord_mod.Embed = MagicMock
 30      discord_mod.app_commands = SimpleNamespace(
 31          describe=lambda **kwargs: (lambda fn: fn),
 32          choices=lambda **kwargs: (lambda fn: fn),
 33          Choice=lambda **kwargs: SimpleNamespace(**kwargs),
 34      )
 35  
 36      ext_mod = MagicMock()
 37      commands_mod = MagicMock()
 38      commands_mod.Bot = MagicMock
 39      ext_mod.commands = commands_mod
 40  
 41      sys.modules.setdefault("discord", discord_mod)
 42      sys.modules.setdefault("discord.ext", ext_mod)
 43      sys.modules.setdefault("discord.ext.commands", commands_mod)
 44  
 45  
 46  _ensure_discord_mock()
 47  
 48  import gateway.platforms.discord as discord_platform  # noqa: E402
 49  from gateway.platforms.discord import DiscordAdapter  # noqa: E402
 50  
 51  
 52  class FakeDMChannel:
 53      def __init__(self, channel_id: int = 1, name: str = "dm"):
 54          self.id = channel_id
 55          self.name = name
 56  
 57  
 58  class FakeTextChannel:
 59      def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"):
 60          self.id = channel_id
 61          self.name = name
 62          self.guild = SimpleNamespace(name=guild_name)
 63          self.topic = None
 64  
 65  
 66  class FakeThread:
 67      def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"):
 68          self.id = channel_id
 69          self.name = name
 70          self.parent = parent
 71          self.parent_id = getattr(parent, "id", None)
 72          self.guild = getattr(parent, "guild", None) or SimpleNamespace(name=guild_name)
 73          self.topic = None
 74  
 75  
 76  @pytest.fixture
 77  def adapter(monkeypatch):
 78      monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
 79      monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
 80  
 81      config = PlatformConfig(enabled=True, token="fake-token")
 82      adapter = DiscordAdapter(config)
 83      adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
 84      adapter._text_batch_delay_seconds = 0  # disable batching for tests
 85      adapter.handle_message = AsyncMock()
 86      return adapter
 87  
 88  
 89  def make_message(*, channel, content: str, mentions=None):
 90      author = SimpleNamespace(id=42, display_name="TestUser", name="TestUser")
 91      return SimpleNamespace(
 92          id=123,
 93          content=content,
 94          mentions=list(mentions or []),
 95          attachments=[],
 96          reference=None,
 97          created_at=datetime.now(timezone.utc),
 98          channel=channel,
 99          author=author,
100      )
101  
102  
# ── ignored_channels ─────────────────────────────────────────────────
104  
105  
@pytest.mark.asyncio
async def test_ignored_channel_blocks_message(adapter, monkeypatch):
    """A plain message in an ignored channel never reaches ``handle_message``."""
    env = {"DISCORD_REQUIRE_MENTION": "false", "DISCORD_IGNORED_CHANNELS": "500"}
    for key, value in env.items():
        monkeypatch.setenv(key, value)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    ignored = FakeTextChannel(channel_id=500)
    await adapter._handle_message(make_message(channel=ignored, content="hello"))

    adapter.handle_message.assert_not_awaited()
117  
118  
@pytest.mark.asyncio
async def test_ignored_channel_blocks_even_with_mention(adapter, monkeypatch):
    """Ignored channels take priority — even @mentions are dropped.

    DISCORD_FREE_RESPONSE_CHANNELS is cleared, as in the sibling tests, so a
    value inherited from the developer's shell cannot influence the outcome.
    """
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    bot_user = adapter._client.user
    message = make_message(
        channel=FakeTextChannel(channel_id=500),
        content=f"<@{bot_user.id}> hello",
        mentions=[bot_user],
    )
    await adapter._handle_message(message)

    adapter.handle_message.assert_not_awaited()
134  
135  
@pytest.mark.asyncio
async def test_non_ignored_channel_processes_normally(adapter, monkeypatch):
    """A channel absent from DISCORD_IGNORED_CHANNELS is dispatched exactly once."""
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500,600")
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    unlisted = FakeTextChannel(channel_id=700)
    incoming = make_message(channel=unlisted, content="hello")
    await adapter._handle_message(incoming)

    adapter.handle_message.assert_awaited_once()
147  
148  
@pytest.mark.asyncio
async def test_ignored_channels_csv_parsing(adapter, monkeypatch):
    """Every ID in a whitespace-padded CSV is ignored; other channels still pass.

    The trailing control channel guards against a parser that accidentally
    ignores everything, which would otherwise satisfy the per-ID assertions.
    """
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500, 600 , 700")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    for ch_id in (500, 600, 700):
        adapter.handle_message.reset_mock()
        message = make_message(channel=FakeTextChannel(channel_id=ch_id), content="hello")
        await adapter._handle_message(message)
        adapter.handle_message.assert_not_awaited()

    # Control: a channel not in the list must still be processed.
    adapter.handle_message.reset_mock()
    control = make_message(channel=FakeTextChannel(channel_id=800), content="hello")
    # Distinct ID so any per-message-ID de-duplication in the adapter cannot
    # mistake it for one of the (identically numbered) messages above.
    control.id = 800
    await adapter._handle_message(control)
    adapter.handle_message.assert_awaited_once()
161  
162  
@pytest.mark.asyncio
async def test_ignored_channels_empty_string_ignores_nothing(adapter, monkeypatch):
    """An empty DISCORD_IGNORED_CHANNELS value leaves every channel active."""
    for name, value in (("DISCORD_REQUIRE_MENTION", "false"), ("DISCORD_IGNORED_CHANNELS", "")):
        monkeypatch.setenv(name, value)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    await adapter._handle_message(
        make_message(channel=FakeTextChannel(channel_id=500), content="hello")
    )

    adapter.handle_message.assert_awaited_once()
174  
175  
@pytest.mark.asyncio
async def test_ignored_channel_thread_parent_match(adapter, monkeypatch):
    """A thread whose parent channel is ignored is dropped as well."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    ignored_parent = FakeTextChannel(channel_id=500, name="ignored-channel")
    child = FakeThread(channel_id=501, name="thread-in-ignored", parent=ignored_parent)
    await adapter._handle_message(make_message(channel=child, content="hello from thread"))

    adapter.handle_message.assert_not_awaited()
189  
190  
@pytest.mark.asyncio
async def test_dms_unaffected_by_ignored_channels(adapter, monkeypatch):
    """DMs bypass DISCORD_IGNORED_CHANNELS, even when the channel ID is listed."""
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    dm = FakeDMChannel(channel_id=500)
    await adapter._handle_message(make_message(channel=dm, content="dm hello"))

    adapter.handle_message.assert_awaited_once()
201  
202  
# ── no_thread_channels ───────────────────────────────────────────────
204  
205  
@pytest.mark.asyncio
async def test_no_thread_channel_skips_auto_thread(adapter, monkeypatch):
    """A no_thread channel is dispatched as a group message without creating a thread."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    for var in ("DISCORD_AUTO_THREAD", "DISCORD_IGNORED_CHANNELS", "DISCORD_FREE_RESPONSE_CHANNELS"):
        monkeypatch.delenv(var, raising=False)

    create_thread = AsyncMock(return_value=FakeThread(channel_id=999))
    adapter._auto_create_thread = create_thread

    await adapter._handle_message(
        make_message(channel=FakeTextChannel(channel_id=800), content="hello")
    )

    create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.source.chat_type == "group"
224  
225  
@pytest.mark.asyncio
async def test_normal_channel_still_auto_threads(adapter, monkeypatch):
    """Channels outside no_thread_channels keep auto-threading; the event is a thread."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    for var in ("DISCORD_AUTO_THREAD", "DISCORD_IGNORED_CHANNELS", "DISCORD_FREE_RESPONSE_CHANNELS"):
        monkeypatch.delenv(var, raising=False)

    new_thread = FakeThread(channel_id=999, name="auto-thread")
    create_thread = AsyncMock(return_value=new_thread)
    adapter._auto_create_thread = create_thread

    regular = FakeTextChannel(channel_id=900)
    await adapter._handle_message(make_message(channel=regular, content="hello"))

    create_thread.assert_awaited_once()
    adapter.handle_message.assert_awaited_once()
    dispatched = adapter.handle_message.await_args.args[0]
    assert dispatched.source.chat_type == "thread"
245  
246  
@pytest.mark.asyncio
async def test_no_thread_channels_csv_parsing(adapter, monkeypatch):
    """Every ID in a CSV no_thread list skips auto-threading yet is still dispatched.

    Asserting the dispatch (as the single-channel test does) stops a message
    that was dropped outright from satisfying the "no thread created" check.
    """
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800, 900")
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
    monkeypatch.delenv("DISCORD_IGNORED_CHANNELS", raising=False)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    for ch_id in (800, 900):
        adapter._auto_create_thread.reset_mock()
        adapter.handle_message.reset_mock()
        message = make_message(channel=FakeTextChannel(channel_id=ch_id), content="hello")
        # Distinct IDs so any per-message-ID de-duplication in the adapter
        # cannot swallow the second message.
        message.id = ch_id
        await adapter._handle_message(message)
        adapter._auto_create_thread.assert_not_awaited()
        adapter.handle_message.assert_awaited_once()
264  
265  
@pytest.mark.asyncio
async def test_no_thread_with_auto_thread_disabled_is_noop(adapter, monkeypatch):
    """With DISCORD_AUTO_THREAD=false there is no threading for no_thread_channels to suppress."""
    settings = (
        ("DISCORD_REQUIRE_MENTION", "false"),
        ("DISCORD_AUTO_THREAD", "false"),
        ("DISCORD_NO_THREAD_CHANNELS", "800"),
    )
    for var, value in settings:
        monkeypatch.setenv(var, value)
    for var in ("DISCORD_IGNORED_CHANNELS", "DISCORD_FREE_RESPONSE_CHANNELS"):
        monkeypatch.delenv(var, raising=False)

    create_thread = AsyncMock()
    adapter._auto_create_thread = create_thread

    await adapter._handle_message(
        make_message(channel=FakeTextChannel(channel_id=800), content="hello")
    )

    create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
282  
283  
# ── config.py bridging ───────────────────────────────────────────────
285  
286  
def test_config_bridges_ignored_channels(monkeypatch, tmp_path):
    """``discord.ignored_channels`` in config.yaml is exported as DISCORD_IGNORED_CHANNELS."""
    import os

    import yaml

    from gateway.config import load_gateway_config

    settings = {"discord": {"ignored_channels": ["111", "222"]}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # setenv (rather than delenv) guarantees monkeypatch restores the variable
    # afterwards even if it did not exist; the empty value is expected to be
    # overwritten by load_gateway_config.
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")

    load_gateway_config()

    assert os.environ.get("DISCORD_IGNORED_CHANNELS") == "111,222"
306  
307  
def test_config_bridges_no_thread_channels(monkeypatch, tmp_path):
    """``discord.no_thread_channels`` in config.yaml is exported as DISCORD_NO_THREAD_CHANNELS."""
    import os

    import yaml

    from gateway.config import load_gateway_config

    settings = {"discord": {"no_thread_channels": ["333"]}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # Registered via setenv so monkeypatch cleans it up after the test.
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "")

    load_gateway_config()

    assert os.environ.get("DISCORD_NO_THREAD_CHANNELS") == "333"
325  
326  
def test_config_env_var_takes_precedence(monkeypatch, tmp_path):
    """A non-empty env var already set wins over the config.yaml value."""
    import os

    import yaml

    from gateway.config import load_gateway_config

    yaml_text = yaml.dump({"discord": {"ignored_channels": ["111"]}})
    (tmp_path / "config.yaml").write_text(yaml_text)
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "999")

    load_gateway_config()

    # The pre-existing value must survive untouched.
    assert os.environ.get("DISCORD_IGNORED_CHANNELS") == "999"
343      # Env var should NOT be overwritten
344      assert os.getenv("DISCORD_IGNORED_CHANNELS") == "999"