/ tests / gateway / test_retry_replacement.py
test_retry_replacement.py
 1  """Regression tests for /retry replacement semantics."""
 2  
 3  from unittest.mock import AsyncMock, MagicMock, patch
 4  
 5  import pytest
 6  
 7  from gateway.config import GatewayConfig
 8  from gateway.platforms.base import MessageEvent, MessageType
 9  from gateway.run import GatewayRunner
10  from gateway.session import SessionStore
11  
12  
@pytest.mark.asyncio
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path):
    """/retry must replace the last user turn rather than duplicate it.

    The transcript is seeded with two user/assistant exchanges. While the
    replayed message is being handled, only the first user turn may remain in
    the stored transcript. Afterwards the transcript must contain the replayed
    question exactly once, with the new answer in place of the old one.
    """

    def contents(transcript, role):
        # "content" values of every transcript entry with the given role, in order.
        return [entry.get("content") for entry in transcript if entry.get("role") == role]

    config = GatewayConfig()
    # `_ensure_loaded` is patched out only while the store is constructed; the
    # store is then marked as loaded by hand below.
    with patch("gateway.session.SessionStore._ensure_loaded"):
        store = SessionStore(sessions_dir=tmp_path, config=config)
    # NOTE(review): presumably disables the DB backend so transcripts go
    # through the non-DB path under tmp_path — confirm against SessionStore.
    store._db = None
    store._loaded = True

    session_id = "retry_session"
    seed = (
        {"role": "session_meta", "tools": []},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "retry me"},
        {"role": "assistant", "content": "old answer"},
    )
    for entry in seed:
        store.append_to_transcript(session_id, entry)

    # Bypass GatewayRunner.__init__ and wire in only what /retry touches.
    gw = GatewayRunner.__new__(GatewayRunner)
    gw.config = config
    gw.session_store = store

    session_entry = MagicMock(session_id=session_id, last_prompt_tokens=111)
    store.get_or_create_session = MagicMock(return_value=session_entry)

    async def fake_handle_message(event):
        # The replay must carry the original text, and the old turn must
        # already be gone from the stored transcript at this point.
        assert event.text == "retry me"
        assert contents(store.load_transcript(session_id), "user") == ["first question"]
        # Simulate the normal pipeline persisting the new exchange.
        store.append_to_transcript(session_id, {"role": "user", "content": event.text})
        store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"})
        return "new answer"

    gw._handle_message = AsyncMock(side_effect=fake_handle_message)

    retry_event = MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
    result = await gw._handle_retry_command(retry_event)

    assert result == "new answer"
    transcript_after = store.load_transcript(session_id)
    assert contents(transcript_after, "user") == ["first question", "retry me"]
    assert contents(transcript_after, "assistant") == ["first answer", "new answer"]
65  
66  
@pytest.mark.asyncio
async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path):
    """/retry must replay the previous user message, not the literal "/retry".

    The session store is fully mocked. The test checks two things: the message
    handler is awaited exactly once, and the single replayed event carries the
    text of the last user turn in the transcript.
    """
    config = MagicMock()
    config.sessions_dir = tmp_path
    config.max_context_messages = 20
    # Bypass GatewayRunner.__init__ and wire in only what /retry touches.
    gw = GatewayRunner.__new__(GatewayRunner)
    gw.config = config
    gw.session_store = MagicMock()

    session_entry = MagicMock(session_id="test-session")
    session_entry.last_prompt_tokens = 55
    gw.session_store.get_or_create_session.return_value = session_entry
    gw.session_store.load_transcript.return_value = [
        {"role": "user", "content": "real message"},
        {"role": "assistant", "content": "answer"},
    ]
    # NOTE(review): presumably the retry handler rewrites the truncated
    # transcript through this method; it is mocked explicitly for clarity.
    gw.session_store.rewrite_transcript = MagicMock()

    # Record every replayed text (not just the last one) so that a double
    # replay cannot go unnoticed.
    replayed_texts = []

    async def fake_handle_message(event):
        replayed_texts.append(event.text)
        return "ok"

    gw._handle_message = AsyncMock(side_effect=fake_handle_message)

    await gw._handle_retry_command(
        MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
    )

    # Exactly one replay, with a clear failure (not a KeyError) if none happened.
    gw._handle_message.assert_awaited_once()
    assert replayed_texts == ["real message"]