# tests/gateway/test_discord_channel_prompts.py
  1  """Tests for Discord channel_prompts resolution and injection."""
  2  
  3  import sys
  4  import threading
  5  import types
  6  from types import SimpleNamespace
  7  from unittest.mock import AsyncMock, MagicMock
  8  
  9  import pytest
 10  
 11  
 12  def _ensure_discord_mock():
 13      if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
 14          return
 15      discord_mod = types.ModuleType("discord")
 16      discord_mod.Intents = MagicMock()
 17      discord_mod.Intents.default.return_value = MagicMock()
 18      discord_mod.DMChannel = type("DMChannel", (), {})
 19      discord_mod.Thread = type("Thread", (), {})
 20      discord_mod.ForumChannel = type("ForumChannel", (), {})
 21      discord_mod.Interaction = object
 22      ext_mod = MagicMock()
 23      commands_mod = MagicMock()
 24      commands_mod.Bot = MagicMock
 25      ext_mod.commands = commands_mod
 26      sys.modules.setdefault("discord", discord_mod)
 27      sys.modules.setdefault("discord.ext", ext_mod)
 28      sys.modules.setdefault("discord.ext.commands", commands_mod)
 29  
 30  
 31  import gateway.run as gateway_run
 32  from gateway.config import Platform
 33  from gateway.platforms.base import MessageEvent
 34  from gateway.session import SessionSource
 35  
 36  
 37  class _CapturingAgent:
 38      last_init = None
 39  
 40      def __init__(self, *args, **kwargs):
 41          type(self).last_init = dict(kwargs)
 42          self.tools = []
 43  
 44      def run_conversation(self, user_message, conversation_history=None, task_id=None, persist_user_message=None):
 45          return {
 46              "final_response": "ok",
 47              "messages": [],
 48              "api_calls": 1,
 49              "completed": True,
 50          }
 51  
 52  
 53  def _install_fake_agent(monkeypatch):
 54      fake_run_agent = types.ModuleType("run_agent")
 55      fake_run_agent.AIAgent = _CapturingAgent
 56      monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
 57  
 58  
 59  def _make_adapter():
 60      _ensure_discord_mock()
 61      from gateway.platforms.discord import DiscordAdapter
 62  
 63      adapter = object.__new__(DiscordAdapter)
 64      adapter.config = MagicMock()
 65      adapter.config.extra = {}
 66      return adapter
 67  
 68  
 69  def _make_runner():
 70      runner = object.__new__(gateway_run.GatewayRunner)
 71      runner.adapters = {}
 72      runner._ephemeral_system_prompt = "Global prompt"
 73      runner._prefill_messages = []
 74      runner._reasoning_config = None
 75      runner._service_tier = None
 76      runner._provider_routing = {}
 77      runner._fallback_model = None
 78      runner._running_agents = {}
 79      runner._pending_model_notes = {}
 80      runner._session_db = None
 81      runner._agent_cache = {}
 82      runner._agent_cache_lock = threading.Lock()
 83      runner._session_model_overrides = {}
 84      runner.hooks = SimpleNamespace(loaded_hooks=False)
 85      runner.config = SimpleNamespace(streaming=None)
 86      runner.session_store = SimpleNamespace(
 87          get_or_create_session=lambda source: SimpleNamespace(session_id="session-1"),
 88          load_transcript=lambda session_id: [],
 89      )
 90      runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
 91      runner._enrich_message_with_vision = AsyncMock(return_value="ENRICHED")
 92      return runner
 93  
 94  
 95  def _make_source() -> SessionSource:
 96      return SessionSource(
 97          platform=Platform.DISCORD,
 98          chat_id="12345",
 99          chat_type="thread",
100          user_id="user-1",
101      )
102  
103  
class TestResolveChannelPrompts:
    """Checks for ``_resolve_channel_prompt`` and the events that carry its result."""

    @staticmethod
    def _adapter_with_prompts(channel_prompts):
        """Return a bare adapter whose ``config.extra`` holds ``channel_prompts``."""
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": channel_prompts}
        return adapter

    def test_no_prompt_returns_none(self):
        # config.extra is empty, so no prompt is configured at all.
        resolved = _make_adapter()._resolve_channel_prompt("123")
        assert resolved is None

    def test_match_by_channel_id(self):
        adapter = self._adapter_with_prompts({"100": "Research mode"})
        resolved = adapter._resolve_channel_prompt("100")
        assert resolved == "Research mode"

    def test_numeric_yaml_keys_normalized_at_config_load(self):
        """Config bridging turns numeric YAML keys into strings.

        The resolver only expects string keys (normalization is config.py's
        job), so a raw numeric key intentionally does not match.
        """
        # State after bridging: the key is already a string.
        adapter = self._adapter_with_prompts({"100": "Research mode"})
        assert adapter._resolve_channel_prompt("100") == "Research mode"

        # State before bridging: the numeric key is not matched by the resolver.
        adapter.config.extra = {"channel_prompts": {100: "Research mode"}}
        assert adapter._resolve_channel_prompt("100") is None

    def test_match_by_parent_id(self):
        adapter = self._adapter_with_prompts({"200": "Forum prompt"})
        resolved = adapter._resolve_channel_prompt("999", parent_id="200")
        assert resolved == "Forum prompt"

    def test_exact_channel_overrides_parent(self):
        prompts = {"999": "Thread override", "200": "Forum prompt"}
        adapter = self._adapter_with_prompts(prompts)
        resolved = adapter._resolve_channel_prompt("999", parent_id="200")
        assert resolved == "Thread override"

    def test_build_message_event_sets_channel_prompt(self):
        adapter = self._adapter_with_prompts({"321": "Command prompt"})
        adapter.build_source = MagicMock(return_value=SimpleNamespace())
        adapter._get_effective_topic = MagicMock(return_value=None)

        channel = SimpleNamespace(name="general", guild=None, parent_id=None)
        invoker = SimpleNamespace(id=1, display_name="Brenner")
        interaction = SimpleNamespace(channel_id=321, channel=channel, user=invoker)

        slash_event = adapter._build_slash_event(interaction, "/retry")

        assert slash_event.channel_prompt == "Command prompt"

    @pytest.mark.asyncio
    async def test_dispatch_thread_session_inherits_parent_channel_prompt(self):
        adapter = self._adapter_with_prompts({"200": "Parent prompt"})
        adapter.build_source = MagicMock(return_value=SimpleNamespace())
        adapter._get_effective_topic = MagicMock(return_value=None)
        adapter.handle_message = AsyncMock()

        # The prompt is configured for channel 200; the new thread is "999".
        interaction = SimpleNamespace(
            guild=SimpleNamespace(name="Wetlands"),
            channel=SimpleNamespace(id=200, parent=None),
            user=SimpleNamespace(id=1, display_name="Brenner"),
        )

        await adapter._dispatch_thread_session(interaction, "999", "new-thread", "hello")

        recorded_call = adapter.handle_message.await_args
        dispatched_event = recorded_call.args[0]
        assert dispatched_event.channel_prompt == "Parent prompt"

    def test_blank_prompts_are_ignored(self):
        adapter = self._adapter_with_prompts({"100": "   "})
        resolved = adapter._resolve_channel_prompt("100")
        assert resolved is None
182  
183  
@pytest.mark.asyncio
async def test_retry_preserves_channel_prompt(monkeypatch):
    """The event handed to ``_handle_message`` by ``/retry`` keeps ``channel_prompt``."""

    def _stored_session(source):
        return SimpleNamespace(session_id="session-1", last_prompt_tokens=10)

    def _stored_transcript(session_id):
        # Built fresh on every call so callers never share one list.
        return [
            {"role": "user", "content": "original message"},
            {"role": "assistant", "content": "old reply"},
        ]

    runner = _make_runner()
    runner.session_store = SimpleNamespace(
        get_or_create_session=_stored_session,
        load_transcript=_stored_transcript,
        rewrite_transcript=MagicMock(),
    )
    runner._handle_message = AsyncMock(return_value="ok")

    retry_event = MessageEvent(
        text="/retry",
        message_type=gateway_run.MessageType.COMMAND,
        source=_make_source(),
        raw_message=SimpleNamespace(),
        channel_prompt="Channel prompt",
    )

    outcome = await runner._handle_retry_command(retry_event)

    assert outcome == "ok"
    forwarded_event = runner._handle_message.await_args.args[0]
    assert forwarded_event.channel_prompt == "Channel prompt"
210  
211  
@pytest.mark.asyncio
async def test_run_agent_appends_channel_prompt_to_ephemeral_system_prompt(monkeypatch, tmp_path):
    """``_run_agent`` builds the agent with context, channel and global prompts joined in that order."""
    _install_fake_agent(monkeypatch)
    runner = _make_runner()

    config_file = tmp_path / "config.yaml"
    config_file.write_text("agent:\n  system_prompt: Global prompt\n", encoding="utf-8")

    def _runtime_agent_kwargs():
        return {
            "provider": "openrouter",
            "api_mode": "chat_completions",
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "***",
        }

    # Module-level names on gateway.run replaced for the duration of the test.
    gateway_patches = {
        "_hermes_home": tmp_path,
        "_env_path": tmp_path / ".env",
        "load_dotenv": lambda *args, **kwargs: None,
        "_load_gateway_config": lambda: {},
        "_resolve_gateway_model": lambda config=None: "gpt-5.4",
        "_resolve_runtime_agent_kwargs": _runtime_agent_kwargs,
    }
    for attr_name, replacement in gateway_patches.items():
        monkeypatch.setattr(gateway_run, attr_name, replacement)

    import hermes_cli.tools_config as tools_config

    monkeypatch.setattr(tools_config, "_get_platform_tools", lambda user_config, platform_key: {"core"})

    # Clear any capture left behind by an earlier construction.
    _CapturingAgent.last_init = None
    incoming = MessageEvent(
        text="hi",
        source=_make_source(),
        message_id="m1",
        channel_prompt="Channel prompt",
    )
    outcome = await runner._run_agent(
        message="hi",
        context_prompt="Context prompt",
        history=[],
        source=_make_source(),
        session_id="session-1",
        session_key="agent:main:discord:thread:12345",
        channel_prompt=incoming.channel_prompt,
    )

    assert outcome["final_response"] == "ok"
    captured_prompt = _CapturingAgent.last_init["ephemeral_system_prompt"]
    assert captured_prompt == "Context prompt\n\nChannel prompt\n\nGlobal prompt"