# test_update_streaming.py
"""Tests for /update live streaming, prompt forwarding, and gateway IPC.

Tests the new --gateway mode for hermes update, including:
- _gateway_prompt() file-based IPC
- _watch_update_progress() output streaming and prompt detection
- Message interception for update prompt responses
- _restore_stashed_changes() with input_fn parameter
"""

import json
import os
import threading
import time
import asyncio
from pathlib import Path
from unittest.mock import patch, MagicMock, AsyncMock

import pytest

from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource


def _make_event(text="/update", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent for testing.

    Args:
        text: Message text carried by the event.
        platform: Platform the message is attributed to.
        user_id: Sender id stored on the session source.
        chat_id: Chat id stored on the session source.

    Returns:
        A ``MessageEvent`` whose source has ``user_name="testuser"``.
    """
    source = SessionSource(
        platform=platform,
        user_id=user_id,
        chat_id=chat_id,
        user_name="testuser",
    )
    return MessageEvent(text=text, source=source)


def _make_runner(hermes_home=None):
    """Create a bare GatewayRunner without calling __init__.

    Only the attributes assigned below are initialised; tests attach
    anything else they need (adapters, mocks) afterwards.

    Args:
        hermes_home: Accepted for call-site compatibility; this helper does
            not read it (tests patch ``gateway.run._hermes_home`` instead).
    """
    from gateway.run import GatewayRunner
    runner = object.__new__(GatewayRunner)
    runner.adapters = {}
    runner._voice_mode = {}
    runner._update_prompt_pending = {}
    runner._running_agents = {}
    runner._running_agents_ts = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._failed_platforms = {}
    return runner


# ---------------------------------------------------------------------------
# _gateway_prompt (file-based IPC in main.py)
# ---------------------------------------------------------------------------


class TestGatewayPrompt:
    """Tests for _gateway_prompt() function."""

    def test_writes_prompt_file_and_reads_response(self, tmp_path):
        """Writes .update_prompt.json, reads .update_response, returns answer."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()

        # Simulate the response arriving after a short delay
        def write_response():
            time.sleep(0.3)
            (hermes_home / ".update_response").write_text("y")

        thread = threading.Thread(target=write_response)
        thread.start()

        try:
            with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
                from hermes_cli.main import _gateway_prompt
                result = _gateway_prompt("Restore? [Y/n]", "y", timeout=5.0)
        finally:
            # Always reap the helper thread, even when the call under test
            # raises, so it cannot keep writing into tmp_path afterwards.
            thread.join()

        assert result == "y"
        # Both files should be cleaned up
        assert not (hermes_home / ".update_prompt.json").exists()
        assert not (hermes_home / ".update_response").exists()

    def test_prompt_file_content(self, tmp_path):
        """Verifies the prompt JSON structure."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()

        prompt_data = None

        def capture_and_respond():
            nonlocal prompt_data
            prompt_path = hermes_home / ".update_prompt.json"
            # Poll for up to ~2 seconds (20 attempts, 0.1s apart).
            for _ in range(20):
                try:
                    parsed = json.loads(prompt_path.read_text())
                except (OSError, ValueError):
                    # Not written yet, or observed mid-write (partial JSON):
                    # retry on the next tick instead of killing the thread.
                    time.sleep(0.1)
                    continue
                prompt_data = parsed
                (hermes_home / ".update_response").write_text("n")
                return

        thread = threading.Thread(target=capture_and_respond)
        thread.start()

        try:
            with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
                from hermes_cli.main import _gateway_prompt
                _gateway_prompt("Configure now? [Y/n]", "n", timeout=5.0)
        finally:
            thread.join()

        assert prompt_data is not None
        assert prompt_data["prompt"] == "Configure now? [Y/n]"
        assert prompt_data["default"] == "n"
        assert "id" in prompt_data

    def test_timeout_returns_default(self, tmp_path):
        """Returns default when no response within timeout."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()

        with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
            from hermes_cli.main import _gateway_prompt
            result = _gateway_prompt("test?", "default_val", timeout=0.5)

        assert result == "default_val"

    def test_empty_response_returns_default(self, tmp_path):
        """Empty response file returns default."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        # Pre-create an empty response file before the prompt is issued.
        (hermes_home / ".update_response").write_text("")

        # NOTE(review): whether _gateway_prompt consumes this pre-existing
        # file or discards it and times out is not visible from this test;
        # the default is the expected result on either path.
        with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
            from hermes_cli.main import _gateway_prompt
            result = _gateway_prompt("test?", "default_val", timeout=2.0)

        assert result == "default_val"


# ---------------------------------------------------------------------------
# _restore_stashed_changes with input_fn
# ---------------------------------------------------------------------------


class TestRestoreStashWithInputFn:
    """Tests for _restore_stashed_changes with the input_fn parameter."""

    def test_uses_input_fn_when_provided(self, tmp_path):
        """When input_fn is provided, it's called instead of input()."""
        from hermes_cli.main import _restore_stashed_changes

        captured_args = []

        def fake_input_fn(prompt, default=""):
            captured_args.append((prompt, default))
            return "n"

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                returncode=0, stdout="", stderr=""
            )
            result = _restore_stashed_changes(
                ["git"], tmp_path, "abc123",
                prompt_user=True,
                input_fn=fake_input_fn,
            )

        assert len(captured_args) == 1
        assert "Restore" in captured_args[0][0]
        assert result is False  # user declined

    def test_input_fn_yes_proceeds_with_restore(self, tmp_path):
        """When input_fn returns 'y', stash apply is attempted."""
        from hermes_cli.main import _restore_stashed_changes

        # Single-element list so the nested function can mutate the counter.
        call_count = [0]

        def fake_run(*args, **kwargs):
            call_count[0] += 1
            mock = MagicMock()
            mock.returncode = 0
            mock.stdout = ""
            mock.stderr = ""
            return mock

        with patch("subprocess.run", side_effect=fake_run):
            _restore_stashed_changes(
                ["git"], tmp_path, "abc123",
                prompt_user=True,
                input_fn=lambda p, d="": "y",
            )

        # Should have called git stash apply + git diff --name-only
        assert call_count[0] >= 2


# ---------------------------------------------------------------------------
# Update command spawns --gateway flag
# ---------------------------------------------------------------------------


class TestUpdateCommandGatewayFlag:
    """Verify the gateway spawns hermes update --gateway."""

    @pytest.mark.asyncio
    async def test_spawns_with_gateway_flag(self, tmp_path):
        """The spawned update command includes --gateway and PYTHONUNBUFFERED."""
        runner = _make_runner()
        event = _make_event()

        # Fake project checkout: a .git directory plus gateway/run.py.
        fake_root = tmp_path / "project"
        fake_root.mkdir()
        (fake_root / ".git").mkdir()
        (fake_root / "gateway").mkdir()
        (fake_root / "gateway" / "run.py").touch()
        fake_file = str(fake_root / "gateway" / "run.py")
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        mock_popen = MagicMock()
        with patch("gateway.run._hermes_home", hermes_home), \
             patch("gateway.run.__file__", fake_file), \
             patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"), \
             patch("subprocess.Popen", mock_popen):
            result = await runner._handle_update_command(event)

        # Check the bash command string contains --gateway and PYTHONUNBUFFERED
        call_args = mock_popen.call_args[0][0]
        cmd_string = call_args[-1] if isinstance(call_args, list) else str(call_args)
        assert "--gateway" in cmd_string
        assert "PYTHONUNBUFFERED" in cmd_string
        assert "stream progress" in result


# ---------------------------------------------------------------------------
# _watch_update_progress - output streaming
# ---------------------------------------------------------------------------


class TestWatchUpdateProgress:
    """Tests for _watch_update_progress() streaming output."""

    @pytest.mark.asyncio
    async def test_streams_output_to_adapter(self, tmp_path):
        """New output is sent to the adapter periodically."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        pending = {"platform": "telegram", "chat_id": "111", "user_id": "222",
                   "session_key": "agent:main:telegram:dm:111"}
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        # Write output
        (hermes_home / ".update_output.txt").write_text(
            "ā Fetching updates...\n", encoding="utf-8"
        )

        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        # Write exit code after a brief delay
        async def write_exit_code():
            await asyncio.sleep(0.3)
            (hermes_home / ".update_output.txt").write_text(
                "ā Fetching updates...\nā Code updated!\n",
                encoding="utf-8",
            )
            (hermes_home / ".update_exit_code").write_text("0")

        with patch("gateway.run._hermes_home", hermes_home):
            task = asyncio.create_task(write_exit_code())
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=5.0,
            )
            await task

        # Should have sent at least the output and a success message
        assert mock_adapter.send.call_count >= 1
        all_sent = " ".join(str(c) for c in mock_adapter.send.call_args_list)
        assert "update finished" in all_sent.lower()

    @pytest.mark.asyncio
    async def test_detects_and_forwards_prompt(self, tmp_path):
        """Detects .update_prompt.json and sends it to the user."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        pending = {"platform": "telegram", "chat_id": "111", "user_id": "222",
                   "session_key": "agent:main:telegram:dm:111"}
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        (hermes_home / ".update_output.txt").write_text("output\n")

        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        # Write a prompt, then respond and finish
        async def simulate_prompt_cycle():
            await asyncio.sleep(0.3)
            prompt = {"prompt": "Restore local changes? [Y/n]", "default": "y", "id": "test1"}
            (hermes_home / ".update_prompt.json").write_text(json.dumps(prompt))
            # Simulate user responding
            await asyncio.sleep(0.5)
            (hermes_home / ".update_response").write_text("y")
            (hermes_home / ".update_prompt.json").unlink(missing_ok=True)
            await asyncio.sleep(0.3)
            (hermes_home / ".update_exit_code").write_text("0")

        with patch("gateway.run._hermes_home", hermes_home):
            task = asyncio.create_task(simulate_prompt_cycle())
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=10.0,
            )
            await task

        # Check that the prompt was forwarded
        all_sent = [str(c) for c in mock_adapter.send.call_args_list]
        prompt_found = any("Restore local changes" in s for s in all_sent)
        assert prompt_found, f"Prompt not forwarded. Sent: {all_sent}"
        # The pending-prompt flag on the session is deliberately not asserted:
        # it may already be cleared by the time we check since update finished.

    @pytest.mark.asyncio
    async def test_prompt_forwarding_preserves_thread_metadata(self, tmp_path):
        """Forwarded update prompts keep the originating thread/topic metadata."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        pending = {
            "platform": "telegram",
            "chat_id": "111",
            "thread_id": "777",
            "user_id": "222",
            "session_key": "agent:main:telegram:group:111:777",
        }
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        (hermes_home / ".update_output.txt").write_text("")
        (hermes_home / ".update_prompt.json").write_text(json.dumps({
            "prompt": "Restore local changes? [Y/n]",
            "default": "y",
            "id": "threaded-prompt",
        }))

        class _PromptCapableAdapter:
            """Adapter stub exposing send_update_prompt and recording its kwargs."""

            def __init__(self):
                self.send = AsyncMock()
                self.prompt_calls = AsyncMock()

            async def send_update_prompt(self, **kwargs):
                return await self.prompt_calls(**kwargs)

        mock_adapter = _PromptCapableAdapter()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        async def finish_after_prompt():
            await asyncio.sleep(0.3)
            (hermes_home / ".update_response").write_text("y")
            await asyncio.sleep(0.2)
            (hermes_home / ".update_exit_code").write_text("0")

        with patch("gateway.run._hermes_home", hermes_home):
            task = asyncio.create_task(finish_after_prompt())
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=5.0,
            )
            await task

        # Fail with a readable message (not AttributeError on None) when the
        # prompt-capable path was never taken.
        assert mock_adapter.prompt_calls.call_args is not None, (
            "send_update_prompt was never called"
        )
        assert mock_adapter.prompt_calls.call_args.kwargs["metadata"] == {
            "thread_id": "777"
        }

    @pytest.mark.asyncio
    async def test_cleans_up_on_completion(self, tmp_path):
        """All marker files are cleaned up when update finishes."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        pending = {"platform": "telegram", "chat_id": "111", "user_id": "222",
                   "session_key": "agent:main:telegram:dm:111"}
        pending_path = hermes_home / ".update_pending.json"
        output_path = hermes_home / ".update_output.txt"
        exit_code_path = hermes_home / ".update_exit_code"
        pending_path.write_text(json.dumps(pending))
        output_path.write_text("done\n")
        exit_code_path.write_text("0")

        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=5.0,
            )

        assert not pending_path.exists()
        assert not output_path.exists()
        assert not exit_code_path.exists()

    @pytest.mark.asyncio
    async def test_failure_exit_code(self, tmp_path):
        """Non-zero exit code sends failure message."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        pending = {"platform": "telegram", "chat_id": "111", "user_id": "222",
                   "session_key": "agent:main:telegram:dm:111"}
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        (hermes_home / ".update_output.txt").write_text("error occurred\n")
        (hermes_home / ".update_exit_code").write_text("1")

        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=5.0,
            )

        all_sent = " ".join(str(c) for c in mock_adapter.send.call_args_list)
        assert "failed" in all_sent.lower()

    @pytest.mark.asyncio
    async def test_falls_back_when_adapter_unavailable(self, tmp_path):
        """Falls back to legacy notification when adapter can't be resolved."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        # Platform doesn't match any adapter
        pending = {"platform": "discord", "chat_id": "111", "user_id": "222"}
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        (hermes_home / ".update_output.txt").write_text("done\n")
        (hermes_home / ".update_exit_code").write_text("0")

        # Only telegram adapter available
        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        with patch("gateway.run._hermes_home", hermes_home):
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=5.0,
            )

        # Should not crash; legacy notification handles this case.
        # NOTE(review): this is a smoke test only - nothing is asserted about
        # what the fallback path sends or cleans up.

    @pytest.mark.asyncio
    async def test_prompt_forwarded_only_once(self, tmp_path):
        """Regression: prompt must not be re-sent on every poll cycle.

        Before the fix, the watcher never deleted .update_prompt.json after
        forwarding, causing the same prompt to be sent every poll_interval.
        """
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        pending = {"platform": "telegram", "chat_id": "111", "user_id": "222",
                   "session_key": "agent:main:telegram:dm:111"}
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        (hermes_home / ".update_output.txt").write_text("")

        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}

        # Write the prompt file up front (before the watcher starts).
        # The watcher should forward it exactly once, then delete it.
        prompt = {"prompt": "Would you like to configure new options now? Y/n",
                  "default": "n", "id": "dup-test"}
        (hermes_home / ".update_prompt.json").write_text(json.dumps(prompt))

        async def finish_after_polls():
            # Wait long enough for multiple poll cycles to occur, then
            # simulate a response + completion.
            await asyncio.sleep(1.0)
            (hermes_home / ".update_response").write_text("n")
            await asyncio.sleep(0.3)
            (hermes_home / ".update_exit_code").write_text("0")

        with patch("gateway.run._hermes_home", hermes_home):
            task = asyncio.create_task(finish_after_polls())
            await runner._watch_update_progress(
                poll_interval=0.1,
                stream_interval=0.2,
                timeout=10.0,
            )
            await task

        # Count how many times the prompt text was sent
        all_sent = [str(c) for c in mock_adapter.send.call_args_list]
        prompt_sends = [s for s in all_sent if "configure new options" in s]
        assert len(prompt_sends) == 1, (
            f"Prompt was sent {len(prompt_sends)} times (expected 1). "
            f"All sends: {all_sent}"
        )


# ---------------------------------------------------------------------------
# Message interception for update prompts
# ---------------------------------------------------------------------------


class TestUpdatePromptInterception:
    """Tests for update prompt response interception in _handle_message."""

    @pytest.mark.asyncio
    async def test_intercepts_response_when_prompt_pending(self, tmp_path):
        """When _update_prompt_pending is set, the next message writes .update_response."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        event = _make_event(text="y", chat_id="67890")
        # The session key uses the full format from build_session_key
        session_key = "agent:main:telegram:dm:67890"
        runner._update_prompt_pending[session_key] = True

        # Mock authorization and _session_key_for_source
        runner._is_user_authorized = MagicMock(return_value=True)
        runner._session_key_for_source = MagicMock(return_value=session_key)

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._handle_message(event)

        assert result is not None
        assert "Sent" in result
        response_path = hermes_home / ".update_response"
        assert response_path.exists()
        assert response_path.read_text() == "y"
        # Should clear the pending flag
        assert session_key not in runner._update_prompt_pending

    @pytest.mark.asyncio
    async def test_recognized_slash_command_bypasses_pending_update_prompt(self, tmp_path):
        """Known slash commands must dispatch normally instead of being consumed.

        The update subprocess is still blocked on stdin waiting for
        ``.update_response``, so the gateway writes a blank response to
        unblock it (``_gateway_prompt`` returns the prompt's default on
        empty) before falling through to normal command dispatch.
        """
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        event = _make_event(text="/new", chat_id="67890")
        session_key = "agent:main:telegram:dm:67890"
        runner._update_prompt_pending[session_key] = True
        runner._is_user_authorized = MagicMock(return_value=True)
        runner._session_key_for_source = MagicMock(return_value=session_key)
        runner._handle_reset_command = AsyncMock(return_value="reset ok")

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._handle_message(event)

        assert result == "reset ok"
        runner._handle_reset_command.assert_awaited_once_with(event)
        # .update_response was written (empty) to unblock the update
        # subprocess; _gateway_prompt will read "", strip to "", and
        # return the prompt's default.
        response_path = hermes_home / ".update_response"
        assert response_path.exists()
        assert response_path.read_text() == ""
        # Pending flag is cleared so stray future input won't be
        # re-intercepted for a prompt that is no longer outstanding.
        assert session_key not in runner._update_prompt_pending

    @pytest.mark.asyncio
    async def test_unrecognized_slash_command_still_consumed_as_response(self, tmp_path):
        """Unknown /foo is written verbatim to .update_response (legacy behavior)."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        event = _make_event(text="/foobarbaz", chat_id="67890")
        session_key = "agent:main:telegram:dm:67890"
        runner._update_prompt_pending[session_key] = True
        runner._is_user_authorized = MagicMock(return_value=True)
        runner._session_key_for_source = MagicMock(return_value=session_key)

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._handle_message(event)

        response_path = hermes_home / ".update_response"
        assert response_path.exists()
        assert response_path.read_text() == "/foobarbaz"
        assert "Sent" in (result or "")
        assert session_key not in runner._update_prompt_pending

    @pytest.mark.asyncio
    async def test_normal_message_when_no_prompt_pending(self, tmp_path):
        """Messages pass through normally when no prompt is pending.

        NOTE(review): this test never calls ``_handle_message``; it only
        checks that a freshly built runner has no pending-prompt entry for
        the session key. Exercising the real pass-through path would need
        more of the runner mocked than is set up here.
        """
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        event = _make_event(text="hello", chat_id="67890")

        # No pending prompt
        runner._is_user_authorized = MagicMock(return_value=True)

        # The message should flow through to normal processing;
        # we just verify it doesn't get intercepted
        session_key = "agent:main:telegram:dm:67890"
        assert session_key not in runner._update_prompt_pending


# ---------------------------------------------------------------------------
# cmd_update --gateway flag
# ---------------------------------------------------------------------------


class TestCmdUpdateGatewayMode:
    """Tests for cmd_update with --gateway flag."""

    def test_gateway_flag_enables_gateway_prompt_for_stash(self, tmp_path):
        """With --gateway, stash restore uses _gateway_prompt instead of input()."""
        from hermes_cli.main import _restore_stashed_changes

        # Use input_fn to verify the gateway path is taken
        calls = []

        def fake_input(prompt, default=""):
            calls.append(prompt)
            return "n"

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
            _restore_stashed_changes(
                ["git"], tmp_path, "abc123",
                prompt_user=True,
                input_fn=fake_input,
            )

        assert len(calls) == 1
        assert "Restore" in calls[0]

    def test_gateway_flag_parsed(self):
        """The --gateway flag is accepted by the update subparser.

        NOTE(review): as written this only checks a hand-built
        ``SimpleNamespace``; it does not invoke the real argument parser or
        ``cmd_update``, so it cannot detect a missing ``--gateway`` option.
        """
        from types import SimpleNamespace
        args = SimpleNamespace(gateway=True)
        assert args.gateway is True