/ tests / gateway / test_slack_approval_buttons.py
test_slack_approval_buttons.py
  1  """Tests for Slack Block Kit approval buttons and thread context fetching."""
  2  
  3  import asyncio
  4  import os
  5  import sys
  6  from pathlib import Path
  7  from unittest.mock import AsyncMock, MagicMock, patch
  8  
  9  import pytest
 10  
 11  # ---------------------------------------------------------------------------
 12  # Ensure the repo root is importable
 13  # ---------------------------------------------------------------------------
 14  _repo = str(Path(__file__).resolve().parents[2])
 15  if _repo not in sys.path:
 16      sys.path.insert(0, _repo)
 17  
 18  
 19  # ---------------------------------------------------------------------------
 20  # Minimal Slack SDK mock so SlackAdapter can be imported
 21  # ---------------------------------------------------------------------------
 22  def _ensure_slack_mock():
 23      """Wire up the minimal mocks required to import SlackAdapter."""
 24      if "slack_bolt" in sys.modules:
 25          return
 26      slack_bolt = MagicMock()
 27      slack_bolt.async_app.AsyncApp = MagicMock
 28      sys.modules["slack_bolt"] = slack_bolt
 29      sys.modules["slack_bolt.async_app"] = slack_bolt.async_app
 30      handler_mod = MagicMock()
 31      handler_mod.AsyncSocketModeHandler = MagicMock
 32      sys.modules["slack_bolt.adapter"] = MagicMock()
 33      sys.modules["slack_bolt.adapter.socket_mode"] = MagicMock()
 34      sys.modules["slack_bolt.adapter.socket_mode.async_handler"] = handler_mod
 35      sdk_mod = MagicMock()
 36      sdk_mod.web = MagicMock()
 37      sdk_mod.web.async_client = MagicMock()
 38      sdk_mod.web.async_client.AsyncWebClient = MagicMock
 39      sys.modules["slack_sdk"] = sdk_mod
 40      sys.modules["slack_sdk.web"] = sdk_mod.web
 41      sys.modules["slack_sdk.web.async_client"] = sdk_mod.web.async_client
 42  
 43  
# Must run before the SlackAdapter import below: importing the adapter needs
# the slack_bolt / slack_sdk modules to be importable (stubbed here unless
# slack_bolt is already imported).
_ensure_slack_mock()
 45  
 46  from gateway.platforms.slack import SlackAdapter
 47  from gateway.config import Platform, PlatformConfig
 48  
 49  
 50  def _make_adapter():
 51      """Create a SlackAdapter instance with mocked internals."""
 52      config = PlatformConfig(enabled=True, token="xoxb-test-token")
 53      adapter = SlackAdapter(config)
 54      adapter._app = MagicMock()
 55      adapter._bot_user_id = "U_BOT"
 56      adapter._team_clients = {"T1": AsyncMock()}
 57      adapter._team_bot_user_ids = {"T1": "U_BOT"}
 58      adapter._channel_team = {"C1": "T1"}
 59      return adapter
 60  
 61  
 62  # ===========================================================================
 63  # send_exec_approval — Block Kit buttons
 64  # ===========================================================================
 65  
 66  class TestSlackExecApproval:
 67      """Test the send_exec_approval method sends Block Kit buttons."""
 68  
 69      @pytest.mark.asyncio
 70      async def test_sends_blocks_with_buttons(self):
 71          adapter = _make_adapter()
 72          mock_client = adapter._team_clients["T1"]
 73          mock_client.chat_postMessage = AsyncMock(return_value={"ts": "1234.5678"})
 74  
 75          result = await adapter.send_exec_approval(
 76              chat_id="C1",
 77              command="rm -rf /important",
 78              session_key="agent:main:slack:group:C1:1111",
 79              description="dangerous deletion",
 80          )
 81  
 82          assert result.success is True
 83          assert result.message_id == "1234.5678"
 84  
 85          # Verify chat_postMessage was called with blocks
 86          mock_client.chat_postMessage.assert_called_once()
 87          kwargs = mock_client.chat_postMessage.call_args[1]
 88          assert "blocks" in kwargs
 89          blocks = kwargs["blocks"]
 90          assert len(blocks) == 2
 91          assert blocks[0]["type"] == "section"
 92          assert "rm -rf /important" in blocks[0]["text"]["text"]
 93          assert "dangerous deletion" in blocks[0]["text"]["text"]
 94          assert blocks[1]["type"] == "actions"
 95          elements = blocks[1]["elements"]
 96          assert len(elements) == 4
 97          action_ids = [e["action_id"] for e in elements]
 98          assert "hermes_approve_once" in action_ids
 99          assert "hermes_approve_session" in action_ids
100          assert "hermes_approve_always" in action_ids
101          assert "hermes_deny" in action_ids
102          # Each button carries the session key as value
103          for e in elements:
104              assert e["value"] == "agent:main:slack:group:C1:1111"
105  
106      @pytest.mark.asyncio
107      async def test_sends_in_thread(self):
108          adapter = _make_adapter()
109          mock_client = adapter._team_clients["T1"]
110          mock_client.chat_postMessage = AsyncMock(return_value={"ts": "1234.5678"})
111  
112          await adapter.send_exec_approval(
113              chat_id="C1",
114              command="echo test",
115              session_key="test-session",
116              metadata={"thread_id": "9999.0000"},
117          )
118  
119          kwargs = mock_client.chat_postMessage.call_args[1]
120          assert kwargs.get("thread_ts") == "9999.0000"
121  
122      @pytest.mark.asyncio
123      async def test_not_connected(self):
124          adapter = _make_adapter()
125          adapter._app = None
126          result = await adapter.send_exec_approval(
127              chat_id="C1", command="ls", session_key="s"
128          )
129          assert result.success is False
130  
131      @pytest.mark.asyncio
132      async def test_truncates_long_command(self):
133          adapter = _make_adapter()
134          mock_client = adapter._team_clients["T1"]
135          mock_client.chat_postMessage = AsyncMock(return_value={"ts": "1.2"})
136  
137          long_cmd = "x" * 5000
138          await adapter.send_exec_approval(
139              chat_id="C1", command=long_cmd, session_key="s"
140          )
141  
142          kwargs = mock_client.chat_postMessage.call_args[1]
143          section_text = kwargs["blocks"][0]["text"]["text"]
144          assert "..." in section_text
145          assert len(section_text) < 5000
146  
147  
148  # ===========================================================================
149  # _handle_approval_action — button click handler
150  # ===========================================================================
151  
class TestSlackApprovalAction:
    """Test ``_handle_approval_action``, the approval button click handler."""

    @staticmethod
    def _click_body(ts, blocks, user_name):
        """Build the button-click ``body`` dict these tests pass to the handler.

        The click always happens in channel ``C1``, on the message with ``ts``.
        """
        return {
            "message": {"ts": ts, "blocks": blocks},
            "channel": {"id": "C1"},
            "user": {"name": user_name},
        }

    @pytest.mark.asyncio
    async def test_resolves_approval(self):
        adapter = _make_adapter()
        adapter._approval_resolved["1234.5678"] = False

        ack = AsyncMock()
        body = self._click_body(
            "1234.5678",
            [
                {"type": "section", "text": {"type": "mrkdwn", "text": "original text"}},
                {"type": "actions", "elements": []},
            ],
            "norbert",
        )
        action = {
            "action_id": "hermes_approve_once",
            "value": "agent:main:slack:group:C1:1111",
        }

        mock_client = adapter._team_clients["T1"]
        mock_client.chat_update = AsyncMock()

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
            await adapter._handle_approval_action(ack, body, action)

        ack.assert_called_once()
        mock_resolve.assert_called_once_with("agent:main:slack:group:C1:1111", "once")

        # The message is updated to record the decision.
        mock_client.chat_update.assert_called_once()
        update_kwargs = mock_client.chat_update.call_args[1]
        assert "Approved once by norbert" in update_kwargs["text"]

    @pytest.mark.asyncio
    async def test_prevents_double_click(self):
        adapter = _make_adapter()
        adapter._approval_resolved["1234.5678"] = True  # Already resolved

        ack = AsyncMock()
        body = self._click_body("1234.5678", [], "norbert")
        action = {
            "action_id": "hermes_approve_once",
            "value": "some-session",
        }

        with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
            await adapter._handle_approval_action(ack, body, action)

        # The click is still acked, but the approval is NOT resolved a second time.
        ack.assert_called_once()
        mock_resolve.assert_not_called()

    @pytest.mark.asyncio
    async def test_deny_action(self):
        adapter = _make_adapter()
        adapter._approval_resolved["1.2"] = False

        ack = AsyncMock()
        body = self._click_body(
            "1.2",
            [{"type": "section", "text": {"type": "mrkdwn", "text": "cmd"}}],
            "alice",
        )
        action = {"action_id": "hermes_deny", "value": "session-key"}

        mock_client = adapter._team_clients["T1"]
        mock_client.chat_update = AsyncMock()

        with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
            await adapter._handle_approval_action(ack, body, action)

        ack.assert_called_once()
        mock_resolve.assert_called_once_with("session-key", "deny")
        # Assert the update happened before reading call_args; otherwise a
        # missing update surfaces as an opaque TypeError on ``None[1]``.
        mock_client.chat_update.assert_called_once()
        update_kwargs = mock_client.chat_update.call_args[1]
        assert "Denied by alice" in update_kwargs["text"]
238  
239  
240  # ===========================================================================
241  # _fetch_thread_context
242  # ===========================================================================
243  
class TestSlackThreadContext:
    """Test thread context fetching (``_fetch_thread_context`` and
    ``_fetch_thread_parent_text``)."""

    @staticmethod
    def _stub_replies(adapter, messages, team_id="T1"):
        """Make ``conversations_replies`` on ``team_id``'s client return ``messages``.

        Returns the mocked client so tests can inspect its call counts.
        """
        client = adapter._team_clients[team_id]
        client.conversations_replies = AsyncMock(return_value={"messages": messages})
        return client

    @pytest.mark.asyncio
    async def test_fetches_and_formats_context(self):
        adapter = _make_adapter()
        self._stub_replies(adapter, [
            {"ts": "1000.0", "user": "U1", "text": "This is the parent message"},
            {"ts": "1000.1", "user": "U2", "text": "I think we should refactor"},
            {"ts": "1000.2", "user": "U1", "text": "Good idea, <@U_BOT> what do you think?"},
        ])

        # Pre-seed the user-name cache so U1/U2 render as Alice/Bob.
        adapter._user_name_cache = {"U1": "Alice", "U2": "Bob"}

        context = await adapter._fetch_thread_context(
            channel_id="C1",
            thread_ts="1000.0",
            current_ts="1000.2",  # The message that triggered the fetch
            team_id="T1",
        )

        assert "[Thread context" in context
        assert "[thread parent] Alice: This is the parent message" in context
        assert "Bob: I think we should refactor" in context
        # Current message should be excluded
        assert "what do you think" not in context
        # The only message mentioning the bot is the excluded trigger, so the raw
        # mention must not leak into the context either. (This alone does not
        # prove that mentions inside *other* messages are stripped.)
        assert "<@U_BOT>" not in context

    @pytest.mark.asyncio
    async def test_skips_bot_messages(self):
        """Self-bot child replies are skipped to avoid circular context,
        but non-self bots (e.g. cron posts, third-party integrations) are kept.

        Regression guard for the fix in _fetch_thread_context: previously ALL
        bot messages were dropped, which lost context when the bot was replying
        to a cron-posted thread parent."""
        adapter = _make_adapter()
        self._stub_replies(adapter, [
            {"ts": "1000.0", "user": "U1", "text": "Parent"},
            # Self-bot reply -> must be skipped (circular)
            {
                "ts": "1000.1",
                "bot_id": "B_SELF",
                "user": "U_BOT",
                "text": "Previous bot self-reply (should be skipped)",
            },
            # Third-party bot child -> kept (useful context)
            {
                "ts": "1000.15",
                "bot_id": "B_OTHER",
                "user": "U_OTHER_BOT",
                "text": "Deploy succeeded",
            },
            {"ts": "1000.2", "user": "U1", "text": "Current"},
        ])
        adapter._user_name_cache = {"U1": "Alice", "U_OTHER_BOT": "DeployBot"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.2", team_id="T1"
        )

        assert "Previous bot self-reply" not in context
        assert "Alice: Parent" in context
        # Third-party bot message must now be included
        assert "Deploy succeeded" in context

    @pytest.mark.asyncio
    async def test_empty_thread(self):
        adapter = _make_adapter()
        self._stub_replies(adapter, [])

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )
        assert context == ""

    @pytest.mark.asyncio
    async def test_api_failure_returns_empty(self):
        adapter = _make_adapter()
        mock_client = adapter._team_clients["T1"]
        mock_client.conversations_replies = AsyncMock(side_effect=Exception("API error"))

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )
        assert context == ""

    @pytest.mark.asyncio
    async def test_fetch_thread_context_includes_bot_parent(self):
        """The thread parent posted by a bot (e.g. a cron summary) must be
        included in the context, prefixed with ``[thread parent]``."""
        adapter = _make_adapter()
        self._stub_replies(adapter, [
            # Bot-posted parent (cron job)
            {
                "ts": "1000.0",
                "bot_id": "B123",
                "subtype": "bot_message",
                "username": "cron",
                "text": "メール要約: 本日の新着3件",
            },
            # User reply that triggered the fetch
            {"ts": "1000.1", "user": "U1", "text": "詳細を教えて"},
        ])
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1",
            thread_ts="1000.0",
            current_ts="1000.1",  # exclude the trigger message itself
            team_id="T1",
        )

        assert "[thread parent]" in context
        assert "メール要約: 本日の新着3件" in context

    @pytest.mark.asyncio
    async def test_fetch_thread_context_excludes_self_bot_replies(self):
        """Parent (non-self bot) is kept, self-bot child replies are dropped,
        user replies are kept."""
        adapter = _make_adapter()
        self._stub_replies(adapter, [
            {"ts": "1000.0", "bot_id": "B_CRON", "text": "Cron summary"},
            # Self-bot child reply -> excluded
            {
                "ts": "1000.1",
                "bot_id": "B_SELF",
                "user": "U_BOT",  # matches adapter._bot_user_id
                "text": "Previous self reply",
            },
            # User reply -> kept
            {"ts": "1000.2", "user": "U1", "text": "Follow-up question"},
            # Current trigger (excluded by current_ts match). A distinctive
            # marker avoids a false match against header or other text.
            {"ts": "1000.3", "user": "U1", "text": "CURRENT_TRIGGER_MARKER"},
        ])
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.3", team_id="T1"
        )

        assert "Cron summary" in context
        assert "[thread parent]" in context
        assert "Previous self reply" not in context
        assert "Follow-up question" in context
        assert "CURRENT_TRIGGER_MARKER" not in context

    @pytest.mark.asyncio
    async def test_fetch_thread_context_multi_workspace(self):
        """Self-bot filtering must use the per-workspace bot user id so a
        self-bot id that belongs to a different workspace does not accidentally
        filter out a legitimate message in the current workspace."""
        adapter = _make_adapter()
        # Add a second workspace with a different bot user id
        adapter._team_clients["T2"] = AsyncMock()
        adapter._team_bot_user_ids = {"T1": "U_BOT_T1", "T2": "U_BOT_T2"}
        adapter._bot_user_id = "U_BOT_T1"
        adapter._channel_team["C2"] = "T2"

        self._stub_replies(adapter, [
            {"ts": "2000.0", "user": "U2", "text": "Parent T2"},
            # This has the *T1* bot's user id — from T2's perspective this
            # is a third-party bot, so it must be kept.
            {
                "ts": "2000.1",
                "bot_id": "B_FOREIGN",
                "user": "U_BOT_T1",
                "team": "T2",
                "text": "Cross-workspace bot reply",
            },
            # Self-bot for T2 — must be skipped
            {
                "ts": "2000.2",
                "bot_id": "B_SELF_T2",
                "user": "U_BOT_T2",
                "team": "T2",
                "text": "Own T2 bot reply",
            },
            {"ts": "2000.3", "user": "U2", "text": "Current"},
        ], team_id="T2")
        adapter._user_name_cache = {"U2": "Bob"}

        context = await adapter._fetch_thread_context(
            channel_id="C2", thread_ts="2000.0", current_ts="2000.3", team_id="T2"
        )

        assert "Parent T2" in context
        assert "Cross-workspace bot reply" in context
        assert "Own T2 bot reply" not in context

    @pytest.mark.asyncio
    async def test_fetch_thread_context_current_ts_excluded(self):
        """Regression guard: the message whose ts == current_ts must never
        appear in the context output (it will be delivered as the user
        message itself)."""
        adapter = _make_adapter()
        self._stub_replies(adapter, [
            {"ts": "1000.0", "user": "U1", "text": "Parent"},
            {"ts": "1000.1", "user": "U1", "text": "DO NOT INCLUDE THIS"},
        ])
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )

        assert "Parent" in context
        assert "DO NOT INCLUDE THIS" not in context

    @pytest.mark.asyncio
    async def test_fetch_thread_parent_text_from_cache(self):
        """_fetch_thread_parent_text should reuse the thread-context cache
        when it is warm, avoiding an extra conversations.replies call."""
        adapter = _make_adapter()
        mock_client = self._stub_replies(adapter, [
            {"ts": "1000.0", "bot_id": "B123", "text": "Parent summary"},
            {"ts": "1000.1", "user": "U1", "text": "reply"},
        ])

        # Warm the cache via _fetch_thread_context
        await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )
        assert mock_client.conversations_replies.await_count == 1

        parent = await adapter._fetch_thread_parent_text(
            channel_id="C1", thread_ts="1000.0", team_id="T1"
        )
        assert parent == "Parent summary"
        # No additional API call
        assert mock_client.conversations_replies.await_count == 1
499  
500  
501  # ===========================================================================
502  # _has_active_session_for_thread — session key fix (#5833)
503  # ===========================================================================
504  
class TestSessionKeyFix:
    """_has_active_session_for_thread must derive its lookup key via
    build_session_key rather than assembling it by hand (#5833)."""

    @staticmethod
    def _store(entries, *, group_per_user):
        """Return a MagicMock session store holding ``entries``.

        ``thread_sessions_per_user`` is always disabled; ``group_per_user``
        controls ``config.group_sessions_per_user``.
        """
        store = MagicMock()
        store._entries = entries
        store._ensure_loaded = MagicMock()
        store.config = MagicMock()
        store.config.group_sessions_per_user = group_per_user
        store.config.thread_sessions_per_user = False
        return store

    def test_uses_build_session_key(self):
        """The stored thread key (no user suffix) is found when per-user group
        sessions are disabled, which only works if build_session_key is used."""
        adapter = _make_adapter()
        adapter._session_store = self._store(
            {"agent:main:slack:group:C1:1000.0": MagicMock()},
            group_per_user=False,
        )

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        # group_sessions_per_user=False -> no user_id appended -> key matches.
        assert found is True

    def test_no_session_returns_false(self):
        adapter = _make_adapter()
        adapter._session_store = self._store({}, group_per_user=True)

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is False

    def test_no_session_store(self):
        # _make_adapter never assigns _session_store.
        adapter = _make_adapter()

        found = adapter._has_active_session_for_thread(
            channel_id="C1", thread_ts="1000.0", user_id="U123"
        )

        assert found is False
555  
556  
557  # ===========================================================================
558  # Thread engagement — bot-started threads & mentioned threads
559  # ===========================================================================
560  
class TestThreadEngagement:
    """Tracking of bot-authored message timestamps (_bot_message_ts) and of
    threads in which the bot was @mentioned (_mentioned_threads)."""

    @pytest.mark.asyncio
    async def test_send_tracks_bot_message_ts(self):
        """Both the posted message and its thread root are remembered, so later
        replies in that thread are handled without an @mention."""
        adapter = _make_adapter()
        adapter._team_clients["T1"].chat_postMessage = AsyncMock(return_value={"ts": "9000.1"})

        await adapter.send(chat_id="C1", content="Hello!", metadata={"thread_id": "8000.0"})

        tracked = adapter._bot_message_ts
        assert "9000.1" in tracked  # the message just posted
        assert "8000.0" in tracked  # the thread root it was posted under

    @pytest.mark.asyncio
    async def test_bot_message_ts_cap(self):
        """Sending more messages than _BOT_TS_MAX keeps the tracker bounded."""
        adapter = _make_adapter()
        adapter._BOT_TS_MAX = 10  # small cap so the test stays fast
        client = adapter._team_clients["T1"]

        for n in range(20):
            client.chat_postMessage = AsyncMock(return_value={"ts": f"{n}.0"})
            await adapter.send(chat_id="C1", content=f"msg {n}")

        assert len(adapter._bot_message_ts) <= 10

    def test_mentioned_threads_populated_on_mention(self):
        """A thread the bot is @mentioned in is recorded in _mentioned_threads."""
        adapter = _make_adapter()
        # Mimic the bookkeeping _handle_slack_message performs on a mention.
        adapter._mentioned_threads.add("1000.0")
        assert "1000.0" in adapter._mentioned_threads

    def test_mentioned_threads_cap(self):
        """_mentioned_threads stays within _MENTIONED_THREADS_MAX.

        NOTE(review): the eviction below is written inline in the test rather
        than invoked on the adapter, so this only checks that the attribute is
        a mutable set and the cap can be overridden — consider driving it
        through the real message handler instead.
        """
        adapter = _make_adapter()
        adapter._MENTIONED_THREADS_MAX = 10
        threads = adapter._mentioned_threads
        cap = adapter._MENTIONED_THREADS_MAX

        for n in range(15):
            threads.add(f"{n}.0")
            if len(threads) <= cap:
                continue
            # Over the cap: drop half the cap's worth (arbitrary set order).
            for stale in list(threads)[: cap // 2]:
                threads.discard(stale)

        assert len(adapter._mentioned_threads) <= 10