/ tests / tools / test_ssh_bulk_upload.py
test_ssh_bulk_upload.py
  1  """Tests for SSH bulk upload via tar pipe."""
  2  
  3  import os
  4  import subprocess
  5  from pathlib import Path
  6  from unittest.mock import MagicMock, patch
  7  
  8  import pytest
  9  
 10  from tools.environments import ssh as ssh_env
 11  from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
 12  from tools.environments.ssh import SSHEnvironment
 13  
 14  
 15  def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
 16                 stderr_read=b""):
 17      """Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
 18      m = MagicMock()
 19      m.stdout = MagicMock()
 20      m.returncode = returncode
 21      m.poll.return_value = poll_return
 22      m.communicate.return_value = communicate_return
 23      m.stderr = MagicMock()
 24      m.stderr.read.return_value = stderr_read
 25      return m
 26  
 27  
 28  @pytest.fixture
 29  def mock_env(monkeypatch):
 30      """Create an SSHEnvironment with mocked connection/sync."""
 31      monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
 32      monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
 33      monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
 34      monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
 35      monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
 36      monkeypatch.setattr(
 37          ssh_env, "FileSyncManager",
 38          lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
 39      )
 40      return SSHEnvironment(host="example.com", user="testuser")
 41  
 42  
 43  class TestSSHBulkUpload:
 44      """Unit tests for _ssh_bulk_upload — tar pipe mechanics."""
 45  
 46      def test_empty_files_is_noop(self, mock_env):
 47          """Empty file list should not spawn any subprocesses."""
 48          with patch.object(subprocess, "run") as mock_run, \
 49               patch.object(subprocess, "Popen") as mock_popen:
 50              mock_env._ssh_bulk_upload([])
 51              mock_run.assert_not_called()
 52              mock_popen.assert_not_called()
 53  
 54      def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
 55          """All parent directories should be created in one SSH call."""
 56          # Create test files
 57          f1 = tmp_path / "a.txt"
 58          f1.write_text("aaa")
 59          f2 = tmp_path / "b.txt"
 60          f2.write_text("bbb")
 61  
 62          files = [
 63              (str(f1), "/home/testuser/.hermes/skills/a.txt"),
 64              (str(f2), "/home/testuser/.hermes/credentials/b.txt"),
 65          ]
 66  
 67          # Mock subprocess.run for mkdir and Popen for tar pipe
 68          mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))
 69  
 70          def make_proc(cmd, **kwargs):
 71              m = MagicMock()
 72              m.stdout = MagicMock()
 73              m.returncode = 0
 74              m.poll.return_value = 0
 75              m.communicate.return_value = (b"", b"")
 76              m.stderr = MagicMock()
 77              m.stderr.read.return_value = b""
 78              return m
 79  
 80          with patch.object(subprocess, "run", mock_run), \
 81               patch.object(subprocess, "Popen", side_effect=make_proc):
 82              mock_env._ssh_bulk_upload(files)
 83  
 84          # Exactly one subprocess.run call for mkdir
 85          assert mock_run.call_count == 1
 86          mkdir_cmd = mock_run.call_args[0][0]
 87          # Should contain mkdir -p with both parent dirs
 88          mkdir_str = " ".join(mkdir_cmd)
 89          assert "mkdir -p" in mkdir_str
 90          assert "/home/testuser/.hermes/skills" in mkdir_str
 91          assert "/home/testuser/.hermes/credentials" in mkdir_str
 92  
 93      def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
 94          """Symlinks in staging dir should mirror the remote path structure."""
 95          f1 = tmp_path / "local_a.txt"
 96          f1.write_text("content a")
 97  
 98          files = [
 99              (str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
100          ]
101  
102          staging_paths = []
103  
104          def capture_tar_cmd(cmd, **kwargs):
105              if cmd[0] == "tar":
106                  # Capture the staging dir from -C argument
107                  c_idx = cmd.index("-C")
108                  staging_dir = cmd[c_idx + 1]
109                  # Check the symlink exists
110                  expected = os.path.join(
111                      staging_dir, "home/testuser/.hermes/skills/my_skill.md"
112                  )
113                  staging_paths.append(expected)
114                  assert os.path.islink(expected), f"Expected symlink at {expected}"
115                  assert os.readlink(expected) == os.path.abspath(str(f1))
116  
117              mock = MagicMock()
118              mock.stdout = MagicMock()
119              mock.returncode = 0
120              mock.poll.return_value = 0
121              mock.communicate.return_value = (b"", b"")
122              mock.stderr = MagicMock()
123              mock.stderr.read.return_value = b""
124              return mock
125  
126          with patch.object(subprocess, "run",
127                            return_value=subprocess.CompletedProcess([], 0)), \
128               patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
129              mock_env._ssh_bulk_upload(files)
130  
131          assert len(staging_paths) == 1, "tar command should have been called"
132  
133      def test_tar_pipe_commands(self, mock_env, tmp_path):
134          """Verify tar and SSH commands are wired correctly."""
135          f1 = tmp_path / "x.txt"
136          f1.write_text("x")
137  
138          files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]
139  
140          popen_cmds = []
141  
142          def capture_popen(cmd, **kwargs):
143              popen_cmds.append(cmd)
144              mock = MagicMock()
145              mock.stdout = MagicMock()
146              mock.returncode = 0
147              mock.poll.return_value = 0
148              mock.communicate.return_value = (b"", b"")
149              mock.stderr = MagicMock()
150              mock.stderr.read.return_value = b""
151              return mock
152  
153          with patch.object(subprocess, "run",
154                            return_value=subprocess.CompletedProcess([], 0)), \
155               patch.object(subprocess, "Popen", side_effect=capture_popen):
156              mock_env._ssh_bulk_upload(files)
157  
158          assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"
159  
160          tar_cmd = popen_cmds[0]
161          ssh_cmd = popen_cmds[1]
162  
163          # tar: create, dereference symlinks, to stdout
164          assert tar_cmd[0] == "tar"
165          assert "-chf" in tar_cmd
166          assert "-" in tar_cmd  # stdout
167          assert "-C" in tar_cmd
168  
169          # ssh: extract from stdin at /, preserving existing dir modes (#17767)
170          ssh_str = " ".join(ssh_cmd)
171          assert "ssh" in ssh_str
172          assert "tar xf -" in ssh_str
173          assert "--no-overwrite-dir" in ssh_str
174          assert "-C /" in ssh_str
175          assert "testuser@example.com" in ssh_str
176  
177      def test_mkdir_failure_raises(self, mock_env, tmp_path):
178          """mkdir failure should raise RuntimeError before tar pipe."""
179          f1 = tmp_path / "y.txt"
180          f1.write_text("y")
181          files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]
182  
183          failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
184          with patch.object(subprocess, "run", return_value=failed_run):
185              with pytest.raises(RuntimeError, match="remote mkdir failed"):
186                  mock_env._ssh_bulk_upload(files)
187  
188      def test_tar_create_failure_raises(self, mock_env, tmp_path):
189          """tar create failure should raise RuntimeError."""
190          f1 = tmp_path / "z.txt"
191          f1.write_text("z")
192          files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]
193  
194          mock_tar = MagicMock()
195          mock_tar.stdout = MagicMock()
196          mock_tar.returncode = 1
197          mock_tar.poll.return_value = 1
198          mock_tar.communicate.return_value = (b"tar: error", b"")
199          mock_tar.stderr = MagicMock()
200          mock_tar.stderr.read.return_value = b"tar: error"
201  
202          mock_ssh = MagicMock()
203          mock_ssh.communicate.return_value = (b"", b"")
204          mock_ssh.returncode = 0
205  
206          def popen_side_effect(cmd, **kwargs):
207              if cmd[0] == "tar":
208                  return mock_tar
209              return mock_ssh
210  
211          with patch.object(subprocess, "run",
212                            return_value=subprocess.CompletedProcess([], 0)), \
213               patch.object(subprocess, "Popen", side_effect=popen_side_effect):
214              with pytest.raises(RuntimeError, match="tar create failed"):
215                  mock_env._ssh_bulk_upload(files)
216  
217      def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
218          """SSH tar extract failure should raise RuntimeError."""
219          f1 = tmp_path / "w.txt"
220          f1.write_text("w")
221          files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]
222  
223          mock_tar = MagicMock()
224          mock_tar.stdout = MagicMock()
225          mock_tar.returncode = 0
226          mock_tar.poll.return_value = 0
227          mock_tar.communicate.return_value = (b"", b"")
228          mock_tar.stderr = MagicMock()
229          mock_tar.stderr.read.return_value = b""
230  
231          mock_ssh = MagicMock()
232          mock_ssh.communicate.return_value = (b"", b"Permission denied")
233          mock_ssh.returncode = 1
234  
235          def popen_side_effect(cmd, **kwargs):
236              if cmd[0] == "tar":
237                  return mock_tar
238              return mock_ssh
239  
240          with patch.object(subprocess, "run",
241                            return_value=subprocess.CompletedProcess([], 0)), \
242               patch.object(subprocess, "Popen", side_effect=popen_side_effect):
243              with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
244                  mock_env._ssh_bulk_upload(files)
245  
246      def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
247          """SSH command for tar extract should reuse ControlMaster socket."""
248          f1 = tmp_path / "c.txt"
249          f1.write_text("c")
250          files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]
251  
252          popen_cmds = []
253  
254          def capture_popen(cmd, **kwargs):
255              popen_cmds.append(cmd)
256              mock = MagicMock()
257              mock.stdout = MagicMock()
258              mock.returncode = 0
259              mock.poll.return_value = 0
260              mock.communicate.return_value = (b"", b"")
261              mock.stderr = MagicMock()
262              mock.stderr.read.return_value = b""
263              return mock
264  
265          with patch.object(subprocess, "run",
266                            return_value=subprocess.CompletedProcess([], 0)), \
267               patch.object(subprocess, "Popen", side_effect=capture_popen):
268              mock_env._ssh_bulk_upload(files)
269  
270          # The SSH command (second Popen call) should include ControlPath
271          ssh_cmd = popen_cmds[1]
272          assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)
273  
274      def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
275          """Bulk upload SSH command should include custom port and key."""
276          monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
277          monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
278          monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
279          monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
280          monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
281          monkeypatch.setattr(
282              ssh_env, "FileSyncManager",
283              lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
284          )
285          env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")
286  
287          f1 = tmp_path / "d.txt"
288          f1.write_text("d")
289          files = [(str(f1), "/home/u/.hermes/skills/d.txt")]
290  
291          run_cmds = []
292          popen_cmds = []
293  
294          def capture_run(cmd, **kwargs):
295              run_cmds.append(cmd)
296              return subprocess.CompletedProcess([], 0)
297  
298          def capture_popen(cmd, **kwargs):
299              popen_cmds.append(cmd)
300              mock = MagicMock()
301              mock.stdout = MagicMock()
302              mock.returncode = 0
303              mock.poll.return_value = 0
304              mock.communicate.return_value = (b"", b"")
305              mock.stderr = MagicMock()
306              mock.stderr.read.return_value = b""
307              return mock
308  
309          with patch.object(subprocess, "run", side_effect=capture_run), \
310               patch.object(subprocess, "Popen", side_effect=capture_popen):
311              env._ssh_bulk_upload(files)
312  
313          # Check mkdir SSH call includes port and key
314          assert len(run_cmds) == 1
315          mkdir_cmd = run_cmds[0]
316          assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
317          assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd
318  
319          # Check tar extract SSH call includes port and key
320          ssh_cmd = popen_cmds[1]
321          assert "-p" in ssh_cmd and "2222" in ssh_cmd
322          assert "-i" in ssh_cmd and "/my/key" in ssh_cmd
323  
324      def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
325          """Multiple files in the same dir should produce one mkdir entry."""
326          f1 = tmp_path / "a.txt"
327          f1.write_text("a")
328          f2 = tmp_path / "b.txt"
329          f2.write_text("b")
330          f3 = tmp_path / "c.txt"
331          f3.write_text("c")
332  
333          files = [
334              (str(f1), "/home/testuser/.hermes/skills/a.txt"),
335              (str(f2), "/home/testuser/.hermes/skills/b.txt"),
336              (str(f3), "/home/testuser/.hermes/credentials/c.txt"),
337          ]
338  
339          run_cmds = []
340  
341          def capture_run(cmd, **kwargs):
342              run_cmds.append(cmd)
343              return subprocess.CompletedProcess([], 0)
344  
345          def make_mock_proc(cmd, **kwargs):
346              mock = MagicMock()
347              mock.stdout = MagicMock()
348              mock.returncode = 0
349              mock.poll.return_value = 0
350              mock.communicate.return_value = (b"", b"")
351              mock.stderr = MagicMock()
352              mock.stderr.read.return_value = b""
353              return mock
354  
355          with patch.object(subprocess, "run", side_effect=capture_run), \
356               patch.object(subprocess, "Popen", side_effect=make_mock_proc):
357              mock_env._ssh_bulk_upload(files)
358  
359          # Only one mkdir call
360          assert len(run_cmds) == 1
361          mkdir_str = " ".join(run_cmds[0])
362          # skills dir should appear exactly once despite two files
363          assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
364          assert "/home/testuser/.hermes/credentials" in mkdir_str
365  
366      def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
367          """tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
368          f1 = tmp_path / "s.txt"
369          f1.write_text("s")
370          files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]
371  
372          mock_tar_stdout = MagicMock()
373  
374          def make_proc(cmd, **kwargs):
375              mock = MagicMock()
376              if cmd[0] == "tar":
377                  mock.stdout = mock_tar_stdout
378              else:
379                  mock.stdout = MagicMock()
380              mock.returncode = 0
381              mock.poll.return_value = 0
382              mock.communicate.return_value = (b"", b"")
383              mock.stderr = MagicMock()
384              mock.stderr.read.return_value = b""
385              return mock
386  
387          with patch.object(subprocess, "run",
388                            return_value=subprocess.CompletedProcess([], 0)), \
389               patch.object(subprocess, "Popen", side_effect=make_proc):
390              mock_env._ssh_bulk_upload(files)
391  
392          mock_tar_stdout.close.assert_called_once()
393  
394      def test_timeout_kills_both_processes(self, mock_env, tmp_path):
395          """TimeoutExpired during communicate should kill both processes."""
396          f1 = tmp_path / "t.txt"
397          f1.write_text("t")
398          files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]
399  
400          mock_tar = MagicMock()
401          mock_tar.stdout = MagicMock()
402          mock_tar.returncode = None
403          mock_tar.poll.return_value = None
404  
405          mock_ssh = MagicMock()
406          mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
407          mock_ssh.returncode = None
408  
409          def make_proc(cmd, **kwargs):
410              if cmd[0] == "tar":
411                  return mock_tar
412              return mock_ssh
413  
414          with patch.object(subprocess, "run",
415                            return_value=subprocess.CompletedProcess([], 0)), \
416               patch.object(subprocess, "Popen", side_effect=make_proc):
417              with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
418                  mock_env._ssh_bulk_upload(files)
419  
420          mock_tar.kill.assert_called_once()
421          mock_ssh.kill.assert_called_once()
422  
423  
class TestSSHBulkUploadWiring:
    """Check that SSHEnvironment hands a bulk uploader to FileSyncManager."""

    def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
        """Constructing SSHEnvironment must pass ``bulk_upload_fn`` along.

        FileSyncManager is replaced by a recorder that stores the keyword
        arguments it was constructed with; the test then inspects them.
        """
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        method_stubs = (
            ("_establish_connection", lambda self: None),
            ("_detect_remote_home", lambda self: "/root"),
            ("_ensure_remote_dirs", lambda self: None),
            ("init_session", lambda self: None),
        )
        for method_name, stub in method_stubs:
            monkeypatch.setattr(ssh_env.SSHEnvironment, method_name, stub)

        seen_kwargs = {}

        class RecordingSyncManager:
            """Remembers constructor kwargs; ``sync`` is a no-op."""

            def __init__(self, **kwargs):
                seen_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", RecordingSyncManager)

        # Kept bound so the environment stays alive for the whole test.
        env = SSHEnvironment(host="h", user="u")

        assert "bulk_upload_fn" in seen_kwargs
        uploader = seen_kwargs["bulk_upload_fn"]
        assert uploader is not None
        # Only callability is checked, not identity with a specific method.
        assert callable(uploader)
452  
453  
class TestSharedHelpers:
    """Direct unit tests for file_sync.py helpers."""

    def test_quoted_mkdir_command_basic(self):
        """Plain paths are joined after ``mkdir -p`` without extra quoting."""
        result = quoted_mkdir_command(["/a", "/b/c"])
        assert result == "mkdir -p /a /b/c"

    def test_quoted_mkdir_command_quotes_special_chars(self):
        """Paths with spaces or quotes must survive a shell-style parse."""
        import shlex  # local import: only this test needs it

        dirs = ["/path/with spaces", "/path/'quotes'"]
        result = quoted_mkdir_command(dirs)
        assert "mkdir -p" in result
        # shlex.quote wraps in single quotes
        assert "'/path/with spaces'" in result
        # Round-trip check covers the embedded-quote path as well: splitting
        # the command the way a POSIX shell would must give back the exact
        # directory names, whatever quoting form was used.
        assert shlex.split(result) == ["mkdir", "-p", *dirs]

    def test_quoted_mkdir_command_empty(self):
        """An empty directory list yields the bare command with a trailing space."""
        result = quoted_mkdir_command([])
        assert result == "mkdir -p "

    def test_unique_parent_dirs_deduplicates(self):
        """Two files sharing a remote directory contribute it only once."""
        files = [
            ("/local/a.txt", "/remote/dir/a.txt"),
            ("/local/b.txt", "/remote/dir/b.txt"),
            ("/local/c.txt", "/remote/other/c.txt"),
        ]
        result = unique_parent_dirs(files)
        assert result == ["/remote/dir", "/remote/other"]

    def test_unique_parent_dirs_sorted(self):
        """Parent directories come back sorted, not in input order."""
        files = [
            ("/local/z.txt", "/z/file.txt"),
            ("/local/a.txt", "/a/file.txt"),
        ]
        result = unique_parent_dirs(files)
        assert result == ["/a", "/z"]

    def test_unique_parent_dirs_empty(self):
        """No files means no parent directories."""
        assert unique_parent_dirs([]) == []
490  
491  
class TestSSHBulkUploadEdgeCases:
    """Edge cases for _ssh_bulk_upload."""

    def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
        """A failing second Popen must leave the first process killed and reaped."""
        local_file = tmp_path / "e.txt"
        local_file.write_text("e")
        files = [(str(local_file), "/home/testuser/.hermes/skills/e.txt")]

        tar_proc = _mock_proc()
        spawned = []

        def popen_then_fail(cmd, **kwargs):
            # First spawn (tar) succeeds; every later spawn raises.
            spawned.append(cmd)
            if len(spawned) == 1:
                return tar_proc
            raise OSError("SSH binary not found")

        mkdir_ok = subprocess.CompletedProcess([], 0)
        with patch.object(subprocess, "run", return_value=mkdir_ok), \
             patch.object(subprocess, "Popen", side_effect=popen_then_fail):
            with pytest.raises(OSError, match="SSH binary not found"):
                mock_env._ssh_bulk_upload(files)

        tar_proc.kill.assert_called_once()
        tar_proc.wait.assert_called_once()