# test_ssh_bulk_upload.py
1 """Tests for SSH bulk upload via tar pipe.""" 2 3 import os 4 import subprocess 5 from pathlib import Path 6 from unittest.mock import MagicMock, patch 7 8 import pytest 9 10 from tools.environments import ssh as ssh_env 11 from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs 12 from tools.environments.ssh import SSHEnvironment 13 14 15 def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""), 16 stderr_read=b""): 17 """Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes.""" 18 m = MagicMock() 19 m.stdout = MagicMock() 20 m.returncode = returncode 21 m.poll.return_value = poll_return 22 m.communicate.return_value = communicate_return 23 m.stderr = MagicMock() 24 m.stderr.read.return_value = stderr_read 25 return m 26 27 28 @pytest.fixture 29 def mock_env(monkeypatch): 30 """Create an SSHEnvironment with mocked connection/sync.""" 31 monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") 32 monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) 33 monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser") 34 monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) 35 monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) 36 monkeypatch.setattr( 37 ssh_env, "FileSyncManager", 38 lambda **kw: type("M", (), {"sync": lambda self, **k: None})(), 39 ) 40 return SSHEnvironment(host="example.com", user="testuser") 41 42 43 class TestSSHBulkUpload: 44 """Unit tests for _ssh_bulk_upload — tar pipe mechanics.""" 45 46 def test_empty_files_is_noop(self, mock_env): 47 """Empty file list should not spawn any subprocesses.""" 48 with patch.object(subprocess, "run") as mock_run, \ 49 patch.object(subprocess, "Popen") as mock_popen: 50 mock_env._ssh_bulk_upload([]) 51 mock_run.assert_not_called() 52 mock_popen.assert_not_called() 53 54 def 
test_mkdir_batched_into_single_call(self, mock_env, tmp_path): 55 """All parent directories should be created in one SSH call.""" 56 # Create test files 57 f1 = tmp_path / "a.txt" 58 f1.write_text("aaa") 59 f2 = tmp_path / "b.txt" 60 f2.write_text("bbb") 61 62 files = [ 63 (str(f1), "/home/testuser/.hermes/skills/a.txt"), 64 (str(f2), "/home/testuser/.hermes/credentials/b.txt"), 65 ] 66 67 # Mock subprocess.run for mkdir and Popen for tar pipe 68 mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0)) 69 70 def make_proc(cmd, **kwargs): 71 m = MagicMock() 72 m.stdout = MagicMock() 73 m.returncode = 0 74 m.poll.return_value = 0 75 m.communicate.return_value = (b"", b"") 76 m.stderr = MagicMock() 77 m.stderr.read.return_value = b"" 78 return m 79 80 with patch.object(subprocess, "run", mock_run), \ 81 patch.object(subprocess, "Popen", side_effect=make_proc): 82 mock_env._ssh_bulk_upload(files) 83 84 # Exactly one subprocess.run call for mkdir 85 assert mock_run.call_count == 1 86 mkdir_cmd = mock_run.call_args[0][0] 87 # Should contain mkdir -p with both parent dirs 88 mkdir_str = " ".join(mkdir_cmd) 89 assert "mkdir -p" in mkdir_str 90 assert "/home/testuser/.hermes/skills" in mkdir_str 91 assert "/home/testuser/.hermes/credentials" in mkdir_str 92 93 def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path): 94 """Symlinks in staging dir should mirror the remote path structure.""" 95 f1 = tmp_path / "local_a.txt" 96 f1.write_text("content a") 97 98 files = [ 99 (str(f1), "/home/testuser/.hermes/skills/my_skill.md"), 100 ] 101 102 staging_paths = [] 103 104 def capture_tar_cmd(cmd, **kwargs): 105 if cmd[0] == "tar": 106 # Capture the staging dir from -C argument 107 c_idx = cmd.index("-C") 108 staging_dir = cmd[c_idx + 1] 109 # Check the symlink exists 110 expected = os.path.join( 111 staging_dir, "home/testuser/.hermes/skills/my_skill.md" 112 ) 113 staging_paths.append(expected) 114 assert os.path.islink(expected), f"Expected symlink at 
{expected}" 115 assert os.readlink(expected) == os.path.abspath(str(f1)) 116 117 mock = MagicMock() 118 mock.stdout = MagicMock() 119 mock.returncode = 0 120 mock.poll.return_value = 0 121 mock.communicate.return_value = (b"", b"") 122 mock.stderr = MagicMock() 123 mock.stderr.read.return_value = b"" 124 return mock 125 126 with patch.object(subprocess, "run", 127 return_value=subprocess.CompletedProcess([], 0)), \ 128 patch.object(subprocess, "Popen", side_effect=capture_tar_cmd): 129 mock_env._ssh_bulk_upload(files) 130 131 assert len(staging_paths) == 1, "tar command should have been called" 132 133 def test_tar_pipe_commands(self, mock_env, tmp_path): 134 """Verify tar and SSH commands are wired correctly.""" 135 f1 = tmp_path / "x.txt" 136 f1.write_text("x") 137 138 files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")] 139 140 popen_cmds = [] 141 142 def capture_popen(cmd, **kwargs): 143 popen_cmds.append(cmd) 144 mock = MagicMock() 145 mock.stdout = MagicMock() 146 mock.returncode = 0 147 mock.poll.return_value = 0 148 mock.communicate.return_value = (b"", b"") 149 mock.stderr = MagicMock() 150 mock.stderr.read.return_value = b"" 151 return mock 152 153 with patch.object(subprocess, "run", 154 return_value=subprocess.CompletedProcess([], 0)), \ 155 patch.object(subprocess, "Popen", side_effect=capture_popen): 156 mock_env._ssh_bulk_upload(files) 157 158 assert len(popen_cmds) == 2, "Should spawn tar + ssh processes" 159 160 tar_cmd = popen_cmds[0] 161 ssh_cmd = popen_cmds[1] 162 163 # tar: create, dereference symlinks, to stdout 164 assert tar_cmd[0] == "tar" 165 assert "-chf" in tar_cmd 166 assert "-" in tar_cmd # stdout 167 assert "-C" in tar_cmd 168 169 # ssh: extract from stdin at /, preserving existing dir modes (#17767) 170 ssh_str = " ".join(ssh_cmd) 171 assert "ssh" in ssh_str 172 assert "tar xf -" in ssh_str 173 assert "--no-overwrite-dir" in ssh_str 174 assert "-C /" in ssh_str 175 assert "testuser@example.com" in ssh_str 176 177 def 
test_mkdir_failure_raises(self, mock_env, tmp_path): 178 """mkdir failure should raise RuntimeError before tar pipe.""" 179 f1 = tmp_path / "y.txt" 180 f1.write_text("y") 181 files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")] 182 183 failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied") 184 with patch.object(subprocess, "run", return_value=failed_run): 185 with pytest.raises(RuntimeError, match="remote mkdir failed"): 186 mock_env._ssh_bulk_upload(files) 187 188 def test_tar_create_failure_raises(self, mock_env, tmp_path): 189 """tar create failure should raise RuntimeError.""" 190 f1 = tmp_path / "z.txt" 191 f1.write_text("z") 192 files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")] 193 194 mock_tar = MagicMock() 195 mock_tar.stdout = MagicMock() 196 mock_tar.returncode = 1 197 mock_tar.poll.return_value = 1 198 mock_tar.communicate.return_value = (b"tar: error", b"") 199 mock_tar.stderr = MagicMock() 200 mock_tar.stderr.read.return_value = b"tar: error" 201 202 mock_ssh = MagicMock() 203 mock_ssh.communicate.return_value = (b"", b"") 204 mock_ssh.returncode = 0 205 206 def popen_side_effect(cmd, **kwargs): 207 if cmd[0] == "tar": 208 return mock_tar 209 return mock_ssh 210 211 with patch.object(subprocess, "run", 212 return_value=subprocess.CompletedProcess([], 0)), \ 213 patch.object(subprocess, "Popen", side_effect=popen_side_effect): 214 with pytest.raises(RuntimeError, match="tar create failed"): 215 mock_env._ssh_bulk_upload(files) 216 217 def test_ssh_extract_failure_raises(self, mock_env, tmp_path): 218 """SSH tar extract failure should raise RuntimeError.""" 219 f1 = tmp_path / "w.txt" 220 f1.write_text("w") 221 files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")] 222 223 mock_tar = MagicMock() 224 mock_tar.stdout = MagicMock() 225 mock_tar.returncode = 0 226 mock_tar.poll.return_value = 0 227 mock_tar.communicate.return_value = (b"", b"") 228 mock_tar.stderr = MagicMock() 229 mock_tar.stderr.read.return_value = 
b"" 230 231 mock_ssh = MagicMock() 232 mock_ssh.communicate.return_value = (b"", b"Permission denied") 233 mock_ssh.returncode = 1 234 235 def popen_side_effect(cmd, **kwargs): 236 if cmd[0] == "tar": 237 return mock_tar 238 return mock_ssh 239 240 with patch.object(subprocess, "run", 241 return_value=subprocess.CompletedProcess([], 0)), \ 242 patch.object(subprocess, "Popen", side_effect=popen_side_effect): 243 with pytest.raises(RuntimeError, match="tar extract over SSH failed"): 244 mock_env._ssh_bulk_upload(files) 245 246 def test_ssh_command_uses_control_socket(self, mock_env, tmp_path): 247 """SSH command for tar extract should reuse ControlMaster socket.""" 248 f1 = tmp_path / "c.txt" 249 f1.write_text("c") 250 files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")] 251 252 popen_cmds = [] 253 254 def capture_popen(cmd, **kwargs): 255 popen_cmds.append(cmd) 256 mock = MagicMock() 257 mock.stdout = MagicMock() 258 mock.returncode = 0 259 mock.poll.return_value = 0 260 mock.communicate.return_value = (b"", b"") 261 mock.stderr = MagicMock() 262 mock.stderr.read.return_value = b"" 263 return mock 264 265 with patch.object(subprocess, "run", 266 return_value=subprocess.CompletedProcess([], 0)), \ 267 patch.object(subprocess, "Popen", side_effect=capture_popen): 268 mock_env._ssh_bulk_upload(files) 269 270 # The SSH command (second Popen call) should include ControlPath 271 ssh_cmd = popen_cmds[1] 272 assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd) 273 274 def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path): 275 """Bulk upload SSH command should include custom port and key.""" 276 monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") 277 monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) 278 monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u") 279 monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: 
None) 280 monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) 281 monkeypatch.setattr( 282 ssh_env, "FileSyncManager", 283 lambda **kw: type("M", (), {"sync": lambda self, **k: None})(), 284 ) 285 env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key") 286 287 f1 = tmp_path / "d.txt" 288 f1.write_text("d") 289 files = [(str(f1), "/home/u/.hermes/skills/d.txt")] 290 291 run_cmds = [] 292 popen_cmds = [] 293 294 def capture_run(cmd, **kwargs): 295 run_cmds.append(cmd) 296 return subprocess.CompletedProcess([], 0) 297 298 def capture_popen(cmd, **kwargs): 299 popen_cmds.append(cmd) 300 mock = MagicMock() 301 mock.stdout = MagicMock() 302 mock.returncode = 0 303 mock.poll.return_value = 0 304 mock.communicate.return_value = (b"", b"") 305 mock.stderr = MagicMock() 306 mock.stderr.read.return_value = b"" 307 return mock 308 309 with patch.object(subprocess, "run", side_effect=capture_run), \ 310 patch.object(subprocess, "Popen", side_effect=capture_popen): 311 env._ssh_bulk_upload(files) 312 313 # Check mkdir SSH call includes port and key 314 assert len(run_cmds) == 1 315 mkdir_cmd = run_cmds[0] 316 assert "-p" in mkdir_cmd and "2222" in mkdir_cmd 317 assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd 318 319 # Check tar extract SSH call includes port and key 320 ssh_cmd = popen_cmds[1] 321 assert "-p" in ssh_cmd and "2222" in ssh_cmd 322 assert "-i" in ssh_cmd and "/my/key" in ssh_cmd 323 324 def test_parent_dirs_deduplicated(self, mock_env, tmp_path): 325 """Multiple files in the same dir should produce one mkdir entry.""" 326 f1 = tmp_path / "a.txt" 327 f1.write_text("a") 328 f2 = tmp_path / "b.txt" 329 f2.write_text("b") 330 f3 = tmp_path / "c.txt" 331 f3.write_text("c") 332 333 files = [ 334 (str(f1), "/home/testuser/.hermes/skills/a.txt"), 335 (str(f2), "/home/testuser/.hermes/skills/b.txt"), 336 (str(f3), "/home/testuser/.hermes/credentials/c.txt"), 337 ] 338 339 run_cmds = [] 340 341 def capture_run(cmd, **kwargs): 
342 run_cmds.append(cmd) 343 return subprocess.CompletedProcess([], 0) 344 345 def make_mock_proc(cmd, **kwargs): 346 mock = MagicMock() 347 mock.stdout = MagicMock() 348 mock.returncode = 0 349 mock.poll.return_value = 0 350 mock.communicate.return_value = (b"", b"") 351 mock.stderr = MagicMock() 352 mock.stderr.read.return_value = b"" 353 return mock 354 355 with patch.object(subprocess, "run", side_effect=capture_run), \ 356 patch.object(subprocess, "Popen", side_effect=make_mock_proc): 357 mock_env._ssh_bulk_upload(files) 358 359 # Only one mkdir call 360 assert len(run_cmds) == 1 361 mkdir_str = " ".join(run_cmds[0]) 362 # skills dir should appear exactly once despite two files 363 assert mkdir_str.count("/home/testuser/.hermes/skills") == 1 364 assert "/home/testuser/.hermes/credentials" in mkdir_str 365 366 def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path): 367 """tar_proc.stdout must be closed so SIGPIPE propagates correctly.""" 368 f1 = tmp_path / "s.txt" 369 f1.write_text("s") 370 files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")] 371 372 mock_tar_stdout = MagicMock() 373 374 def make_proc(cmd, **kwargs): 375 mock = MagicMock() 376 if cmd[0] == "tar": 377 mock.stdout = mock_tar_stdout 378 else: 379 mock.stdout = MagicMock() 380 mock.returncode = 0 381 mock.poll.return_value = 0 382 mock.communicate.return_value = (b"", b"") 383 mock.stderr = MagicMock() 384 mock.stderr.read.return_value = b"" 385 return mock 386 387 with patch.object(subprocess, "run", 388 return_value=subprocess.CompletedProcess([], 0)), \ 389 patch.object(subprocess, "Popen", side_effect=make_proc): 390 mock_env._ssh_bulk_upload(files) 391 392 mock_tar_stdout.close.assert_called_once() 393 394 def test_timeout_kills_both_processes(self, mock_env, tmp_path): 395 """TimeoutExpired during communicate should kill both processes.""" 396 f1 = tmp_path / "t.txt" 397 f1.write_text("t") 398 files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")] 399 400 mock_tar = 
MagicMock() 401 mock_tar.stdout = MagicMock() 402 mock_tar.returncode = None 403 mock_tar.poll.return_value = None 404 405 mock_ssh = MagicMock() 406 mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120) 407 mock_ssh.returncode = None 408 409 def make_proc(cmd, **kwargs): 410 if cmd[0] == "tar": 411 return mock_tar 412 return mock_ssh 413 414 with patch.object(subprocess, "run", 415 return_value=subprocess.CompletedProcess([], 0)), \ 416 patch.object(subprocess, "Popen", side_effect=make_proc): 417 with pytest.raises(RuntimeError, match="SSH bulk upload timed out"): 418 mock_env._ssh_bulk_upload(files) 419 420 mock_tar.kill.assert_called_once() 421 mock_ssh.kill.assert_called_once() 422 423 424 class TestSSHBulkUploadWiring: 425 """Verify bulk_upload_fn is wired into FileSyncManager.""" 426 427 def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch): 428 """SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager.""" 429 monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") 430 monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) 431 monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root") 432 monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) 433 monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) 434 435 captured_kwargs = {} 436 437 class FakeSyncManager: 438 def __init__(self, **kwargs): 439 captured_kwargs.update(kwargs) 440 441 def sync(self, **kw): 442 pass 443 444 monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager) 445 446 env = SSHEnvironment(host="h", user="u") 447 448 assert "bulk_upload_fn" in captured_kwargs 449 assert captured_kwargs["bulk_upload_fn"] is not None 450 # Should be the bound method 451 assert callable(captured_kwargs["bulk_upload_fn"]) 452 453 454 class TestSharedHelpers: 455 """Direct unit tests for file_sync.py helpers.""" 456 457 def 
test_quoted_mkdir_command_basic(self): 458 result = quoted_mkdir_command(["/a", "/b/c"]) 459 assert result == "mkdir -p /a /b/c" 460 461 def test_quoted_mkdir_command_quotes_special_chars(self): 462 result = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"]) 463 assert "mkdir -p" in result 464 # shlex.quote wraps in single quotes 465 assert "'/path/with spaces'" in result 466 467 def test_quoted_mkdir_command_empty(self): 468 result = quoted_mkdir_command([]) 469 assert result == "mkdir -p " 470 471 def test_unique_parent_dirs_deduplicates(self): 472 files = [ 473 ("/local/a.txt", "/remote/dir/a.txt"), 474 ("/local/b.txt", "/remote/dir/b.txt"), 475 ("/local/c.txt", "/remote/other/c.txt"), 476 ] 477 result = unique_parent_dirs(files) 478 assert result == ["/remote/dir", "/remote/other"] 479 480 def test_unique_parent_dirs_sorted(self): 481 files = [ 482 ("/local/z.txt", "/z/file.txt"), 483 ("/local/a.txt", "/a/file.txt"), 484 ] 485 result = unique_parent_dirs(files) 486 assert result == ["/a", "/z"] 487 488 def test_unique_parent_dirs_empty(self): 489 assert unique_parent_dirs([]) == [] 490 491 492 class TestSSHBulkUploadEdgeCases: 493 """Edge cases for _ssh_bulk_upload.""" 494 495 def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path): 496 """If SSH Popen raises, tar process must be killed and cleaned up.""" 497 f1 = tmp_path / "e.txt" 498 f1.write_text("e") 499 files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")] 500 501 mock_tar = _mock_proc() 502 503 call_count = 0 504 505 def failing_ssh_popen(cmd, **kwargs): 506 nonlocal call_count 507 call_count += 1 508 if call_count == 1: 509 return mock_tar # tar Popen succeeds 510 raise OSError("SSH binary not found") 511 512 with patch.object(subprocess, "run", 513 return_value=subprocess.CompletedProcess([], 0)), \ 514 patch.object(subprocess, "Popen", side_effect=failing_ssh_popen): 515 with pytest.raises(OSError, match="SSH binary not found"): 516 mock_env._ssh_bulk_upload(files) 517 518 
mock_tar.kill.assert_called_once() 519 mock_tar.wait.assert_called_once()