# test_discord_channel_prompts.py
"""Tests for Discord channel_prompts resolution and injection."""

import sys
import threading
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

import gateway.run as gateway_run
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource


def _ensure_discord_mock():
    """Install a minimal stub ``discord`` package unless the real one is loaded.

    A ``discord`` entry in ``sys.modules`` that has ``__file__`` is treated as
    the real library and left untouched. Otherwise a stub module providing
    ``Intents``, ``DMChannel``, ``Thread``, ``ForumChannel``, ``Interaction``
    and ``ext.commands.Bot`` is registered under ``discord``, ``discord.ext``
    and ``discord.ext.commands``. ``setdefault`` is used, so an entry that is
    already present (e.g. a stub from another test module) is never replaced.

    The stubs are written directly into ``sys.modules`` (not via
    ``monkeypatch``), so they stay installed for the rest of the process.
    """
    if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
        return

    discord_mod = types.ModuleType("discord")
    discord_mod.Intents = MagicMock()
    discord_mod.Intents.default.return_value = MagicMock()
    # Real classes (not MagicMocks) so isinstance() checks against them work.
    discord_mod.DMChannel = type("DMChannel", (), {})
    discord_mod.Thread = type("Thread", (), {})
    discord_mod.ForumChannel = type("ForumChannel", (), {})
    discord_mod.Interaction = object

    ext_mod = MagicMock()
    commands_mod = MagicMock()
    # The MagicMock class itself, so it can be used as a base class.
    commands_mod.Bot = MagicMock
    ext_mod.commands = commands_mod

    sys.modules.setdefault("discord", discord_mod)
    sys.modules.setdefault("discord.ext", ext_mod)
    sys.modules.setdefault("discord.ext.commands", commands_mod)


class _CapturingAgent:
    """Fake ``AIAgent`` that records how it was constructed.

    ``last_init`` is a class attribute holding a copy of the keyword arguments
    from the most recent instantiation, so a test can inspect what was passed
    to the agent after the fact. Tests should reset it to ``None`` before
    running the code under test.
    """

    last_init = None

    def __init__(self, *args, **kwargs):
        type(self).last_init = dict(kwargs)
        self.tools = []

    def run_conversation(
        self,
        user_message,
        conversation_history=None,
        task_id=None,
        persist_user_message=None,
    ):
        """Return a canned, successful conversation result regardless of input."""
        return {
            "final_response": "ok",
            "messages": [],
            "api_calls": 1,
            "completed": True,
        }


def _install_fake_agent(monkeypatch):
    """Register a fake ``run_agent`` module whose ``AIAgent`` is ``_CapturingAgent``.

    ``monkeypatch.setitem`` restores the previous ``sys.modules`` entry when
    the test finishes.
    """
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = _CapturingAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)


def _make_adapter():
    """Return a ``DiscordAdapter`` built without ``__init__`` and an empty ``config.extra``."""
    _ensure_discord_mock()
    # Imported lazily so the discord stub above is in place first.
    from gateway.platforms.discord import DiscordAdapter

    adapter = object.__new__(DiscordAdapter)
    adapter.config = MagicMock()
    adapter.config.extra = {}
    return adapter


def _make_runner():
    """Return a ``GatewayRunner`` built without ``__init__``.

    Only the attributes assigned below are populated; presumably these are the
    ones the exercised code paths read (TODO confirm if ``_run_agent`` grows
    new dependencies). The global ephemeral system prompt is ``"Global prompt"``.
    """
    runner = object.__new__(gateway_run.GatewayRunner)
    runner.adapters = {}
    runner._ephemeral_system_prompt = "Global prompt"
    runner._prefill_messages = []
    runner._reasoning_config = None
    runner._service_tier = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._running_agents = {}
    runner._pending_model_notes = {}
    runner._session_db = None
    runner._agent_cache = {}
    runner._agent_cache_lock = threading.Lock()
    runner._session_model_overrides = {}
    runner.hooks = SimpleNamespace(loaded_hooks=False)
    runner.config = SimpleNamespace(streaming=None)
    runner.session_store = SimpleNamespace(
        get_or_create_session=lambda source: SimpleNamespace(session_id="session-1"),
        load_transcript=lambda session_id: [],
    )
    runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
    runner._enrich_message_with_vision = AsyncMock(return_value="ENRICHED")
    return runner


def _make_source() -> SessionSource:
    """Return a Discord thread ``SessionSource`` for chat ``12345`` / ``user-1``."""
    return SessionSource(
        platform=Platform.DISCORD,
        chat_id="12345",
        chat_type="thread",
        user_id="user-1",
    )


class TestResolveChannelPrompts:
    def test_no_prompt_returns_none(self):
        adapter = _make_adapter()
        assert adapter._resolve_channel_prompt("123") is None

    def test_match_by_channel_id(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"100": "Research mode"}}
        assert adapter._resolve_channel_prompt("100") == "Research mode"

    def test_numeric_yaml_keys_normalized_at_config_load(self):
        """Numeric YAML keys are normalized to strings by config bridging.

        The resolver itself expects string keys (config.py handles
        normalization), so raw numeric keys will not match; this is
        intentional.
        """
        adapter = _make_adapter()
        # Simulates post-bridging state: keys are already strings.
        adapter.config.extra = {"channel_prompts": {"100": "Research mode"}}
        assert adapter._resolve_channel_prompt("100") == "Research mode"
        # A pre-bridging numeric key would not match (bridging is responsible).
        adapter.config.extra = {"channel_prompts": {100: "Research mode"}}
        assert adapter._resolve_channel_prompt("100") is None

    def test_match_by_parent_id(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"200": "Forum prompt"}}
        assert adapter._resolve_channel_prompt("999", parent_id="200") == "Forum prompt"

    def test_exact_channel_overrides_parent(self):
        adapter = _make_adapter()
        adapter.config.extra = {
            "channel_prompts": {
                "999": "Thread override",
                "200": "Forum prompt",
            }
        }
        assert adapter._resolve_channel_prompt("999", parent_id="200") == "Thread override"

    def test_build_message_event_sets_channel_prompt(self):
        """Slash-command events from ``_build_slash_event`` carry the channel prompt."""
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"321": "Command prompt"}}
        adapter.build_source = MagicMock(return_value=SimpleNamespace())

        interaction = SimpleNamespace(
            channel_id=321,
            channel=SimpleNamespace(name="general", guild=None, parent_id=None),
            user=SimpleNamespace(id=1, display_name="Brenner"),
        )
        adapter._get_effective_topic = MagicMock(return_value=None)

        event = adapter._build_slash_event(interaction, "/retry")

        assert event.channel_prompt == "Command prompt"

    @pytest.mark.asyncio
    async def test_dispatch_thread_session_inherits_parent_channel_prompt(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"200": "Parent prompt"}}
        adapter.build_source = MagicMock(return_value=SimpleNamespace())
        adapter._get_effective_topic = MagicMock(return_value=None)
        adapter.handle_message = AsyncMock()

        interaction = SimpleNamespace(
            guild=SimpleNamespace(name="Wetlands"),
            channel=SimpleNamespace(id=200, parent=None),
            user=SimpleNamespace(id=1, display_name="Brenner"),
        )

        await adapter._dispatch_thread_session(interaction, "999", "new-thread", "hello")

        # Fail with a clear message (instead of AttributeError on a None
        # ``await_args``) if nothing was dispatched.
        adapter.handle_message.assert_awaited()
        dispatched_event = adapter.handle_message.await_args.args[0]
        assert dispatched_event.channel_prompt == "Parent prompt"

    def test_blank_prompts_are_ignored(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"100": "   "}}
        assert adapter._resolve_channel_prompt("100") is None


@pytest.mark.asyncio
async def test_retry_preserves_channel_prompt(monkeypatch):
    """The event re-dispatched by ``/retry`` keeps the original channel prompt."""
    runner = _make_runner()
    runner.session_store = SimpleNamespace(
        get_or_create_session=lambda source: SimpleNamespace(
            session_id="session-1", last_prompt_tokens=10
        ),
        load_transcript=lambda session_id: [
            {"role": "user", "content": "original message"},
            {"role": "assistant", "content": "old reply"},
        ],
        rewrite_transcript=MagicMock(),
    )
    runner._handle_message = AsyncMock(return_value="ok")

    event = MessageEvent(
        text="/retry",
        message_type=gateway_run.MessageType.COMMAND,
        source=_make_source(),
        raw_message=SimpleNamespace(),
        channel_prompt="Channel prompt",
    )

    result = await runner._handle_retry_command(event)

    assert result == "ok"
    # Check the retry was actually dispatched before reading ``await_args``.
    runner._handle_message.assert_awaited()
    retried_event = runner._handle_message.await_args.args[0]
    assert retried_event.channel_prompt == "Channel prompt"


@pytest.mark.asyncio
async def test_run_agent_appends_channel_prompt_to_ephemeral_system_prompt(monkeypatch, tmp_path):
    """The channel prompt is merged into the agent's ephemeral system prompt.

    Despite the test name, the expected order is context prompt, then channel
    prompt, then global prompt, each separated by a blank line.
    """
    _install_fake_agent(monkeypatch)
    runner = _make_runner()

    (tmp_path / "config.yaml").write_text("agent:\n system_prompt: Global prompt\n", encoding="utf-8")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env")
    monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
    monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
    monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
    monkeypatch.setattr(
        gateway_run,
        "_resolve_runtime_agent_kwargs",
        lambda: {
            "provider": "openrouter",
            "api_mode": "chat_completions",
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "***",
        },
    )

    import hermes_cli.tools_config as tools_config

    monkeypatch.setattr(tools_config, "_get_platform_tools", lambda user_config, platform_key: {"core"})

    _CapturingAgent.last_init = None
    event = MessageEvent(
        text="hi",
        source=_make_source(),
        message_id="m1",
        channel_prompt="Channel prompt",
    )
    result = await runner._run_agent(
        message="hi",
        context_prompt="Context prompt",
        history=[],
        source=_make_source(),
        session_id="session-1",
        session_key="agent:main:discord:thread:12345",
        channel_prompt=event.channel_prompt,
    )

    assert result["final_response"] == "ok"
    # Explicit check so a missing construction fails clearly rather than with
    # "'NoneType' object is not subscriptable".
    assert _CapturingAgent.last_init is not None, "AIAgent was never constructed"
    assert _CapturingAgent.last_init["ephemeral_system_prompt"] == (
        "Context prompt\n\nChannel prompt\n\nGlobal prompt"
    )