# test_retry_replacement.py
"""Regression tests for /retry replacement semantics."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from gateway.config import GatewayConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner
from gateway.session import SessionStore


def _contents_by_role(transcript, role):
    """Return the ``content`` of every transcript message whose ``role`` matches."""
    return [m.get("content") for m in transcript if m.get("role") == role]


@pytest.mark.asyncio
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path):
    """/retry drops the last user turn (and its reply) before replaying it.

    The replayed handler must see a transcript that no longer contains the
    retried question. Afterwards the transcript must hold that question exactly
    once, followed by the new answer instead of the old one.
    """
    config = GatewayConfig()
    # Skip _ensure_loaded during construction, then mark the store as loaded
    # so the test controls the transcript contents from scratch.
    with patch("gateway.session.SessionStore._ensure_loaded"):
        store = SessionStore(sessions_dir=tmp_path, config=config)
    store._db = None  # NOTE(review): presumably forces the non-DB transcript path.
    store._loaded = True

    session_id = "retry_session"
    for msg in [
        {"role": "session_meta", "tools": []},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "retry me"},
        {"role": "assistant", "content": "old answer"},
    ]:
        store.append_to_transcript(session_id, msg)

    # Bypass __init__: only the attributes assigned below exist on the runner.
    gw = GatewayRunner.__new__(GatewayRunner)
    gw.config = config
    gw.session_store = store

    session_entry = MagicMock(session_id=session_id)
    session_entry.last_prompt_tokens = 111
    gw.session_store.get_or_create_session = MagicMock(return_value=session_entry)

    async def fake_handle_message(event):
        # The replay must carry the original text, and by the time it runs the
        # retried turn must already have been removed from the transcript.
        assert event.text == "retry me"
        transcript_before = store.load_transcript(session_id)
        assert _contents_by_role(transcript_before, "user") == ["first question"]
        # Simulate the normal message path persisting the new exchange.
        store.append_to_transcript(session_id, {"role": "user", "content": event.text})
        store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"})
        return "new answer"

    gw._handle_message = AsyncMock(side_effect=fake_handle_message)

    result = await gw._handle_retry_command(
        MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
    )

    gw._handle_message.assert_awaited_once()
    assert result == "new answer"
    transcript_after = store.load_transcript(session_id)
    assert _contents_by_role(transcript_after, "user") == [
        "first question",
        "retry me",
    ]
    assert _contents_by_role(transcript_after, "assistant") == [
        "first answer",
        "new answer",
    ]


@pytest.mark.asyncio
async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path):
    """/retry must replay the last user message's text, not the literal "/retry"."""
    config = MagicMock()
    config.sessions_dir = tmp_path
    config.max_context_messages = 20
    # Bypass __init__: only the attributes assigned below exist on the runner.
    gw = GatewayRunner.__new__(GatewayRunner)
    gw.config = config
    gw.session_store = MagicMock()

    session_entry = MagicMock(session_id="test-session")
    session_entry.last_prompt_tokens = 55
    gw.session_store.get_or_create_session.return_value = session_entry
    gw.session_store.load_transcript.return_value = [
        {"role": "user", "content": "real message"},
        {"role": "assistant", "content": "answer"},
    ]
    gw.session_store.rewrite_transcript = MagicMock()

    captured = {}

    async def fake_handle_message(event):
        captured["text"] = event.text
        return "ok"

    gw._handle_message = AsyncMock(side_effect=fake_handle_message)

    await gw._handle_retry_command(
        MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
    )

    # Fail with a clear assertion (not a KeyError on `captured`) if no replay happened.
    gw._handle_message.assert_awaited_once()
    assert captured["text"] == "real message"