# tools/environments/file_sync.py
  1  """Shared file sync manager for remote execution backends.
  2  
  3  Tracks local file changes via mtime+size, detects deletions, and
  4  syncs to remote environments transactionally.  Used by SSH, Modal,
  5  and Daytona.  Docker and Singularity use bind mounts (live host FS
  6  view) and don't need this.
  7  """
  8  
  9  import hashlib
 10  import logging
 11  import os
 12  import shlex
 13  import shutil
 14  import signal
 15  import tarfile
 16  import tempfile
 17  import threading
 18  import time
 19  
 20  try:
 21      import fcntl
 22  except ImportError:
 23      fcntl = None  # Windows — file locking skipped
 24  from pathlib import Path
 25  from typing import Callable
 26  
 27  from hermes_constants import get_hermes_home
 28  from tools.environments.base import _file_mtime_key
 29  
logger = logging.getLogger(__name__)

# Keep retry sleeps patchable without mutating the shared stdlib ``time``
# module. Patching ``tools.environments.file_sync.time.sleep`` replaces
# ``time.sleep`` globally because ``time`` is the module object; under xdist
# that lets unrelated background threads inflate retry-test call counts.
_sleep = time.sleep

# Minimum seconds between sync cycles; sync() is a no-op within this window
# unless forced (see _FORCE_SYNC_ENV below).
_SYNC_INTERVAL_SECONDS = 5.0
# Env var: any non-empty value bypasses the sync-interval rate limit.
_FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC"

# Transport callbacks provided by each backend
UploadFn = Callable[[str, str], None]  # (host_path, remote_path) -> raises on failure
BulkUploadFn = Callable[[list[tuple[str, str]]], None]  # [(host_path, remote_path), ...] -> raises on failure
BulkDownloadFn = Callable[[Path], None]  # (dest_tar_path) -> writes tar archive, raises on failure
DeleteFn = Callable[[list[str]], None]  # (remote_paths) -> raises on failure
GetFilesFn = Callable[[], list[tuple[str, str]]]  # () -> [(host_path, remote_path), ...]
 47  
 48  
 49  def iter_sync_files(container_base: str = "/root/.hermes") -> list[tuple[str, str]]:
 50      """Enumerate all files that should be synced to a remote environment.
 51  
 52      Combines credentials, skills, and cache into a single flat list of
 53      (host_path, remote_path) pairs.  Credential paths are remapped from
 54      the hardcoded /root/.hermes to *container_base* because the remote
 55      user's home may differ (e.g. /home/daytona, /home/user).
 56      """
 57      # Late import: credential_files imports agent modules that create
 58      # circular dependencies if loaded at file_sync module level.
 59      from tools.credential_files import (
 60          get_credential_file_mounts,
 61          iter_cache_files,
 62          iter_skills_files,
 63      )
 64  
 65      files: list[tuple[str, str]] = []
 66      for entry in get_credential_file_mounts():
 67          remote = entry["container_path"].replace(
 68              "/root/.hermes", container_base, 1
 69          )
 70          files.append((entry["host_path"], remote))
 71      for entry in iter_skills_files(container_base=container_base):
 72          files.append((entry["host_path"], entry["container_path"]))
 73      for entry in iter_cache_files(container_base=container_base):
 74          files.append((entry["host_path"], entry["container_path"]))
 75      return files
 76  
 77  
 78  def quoted_rm_command(remote_paths: list[str]) -> str:
 79      """Build a shell ``rm -f`` command for a batch of remote paths."""
 80      return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths)
 81  
 82  
 83  def quoted_mkdir_command(dirs: list[str]) -> str:
 84      """Build a shell ``mkdir -p`` command for a batch of directories."""
 85      return "mkdir -p " + " ".join(shlex.quote(d) for d in dirs)
 86  
 87  
 88  def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
 89      """Extract sorted unique parent directories from (host, remote) pairs."""
 90      return sorted({str(Path(remote).parent) for _, remote in files})
 91  
 92  
 93  def _sha256_file(path: str) -> str:
 94      """Return hex SHA-256 digest of a file."""
 95      h = hashlib.sha256()
 96      with open(path, "rb") as f:
 97          for chunk in iter(lambda: f.read(65536), b""):
 98              h.update(chunk)
 99      return h.hexdigest()
100  
101  
# sync_back retry policy: up to 3 attempts, with a fixed backoff schedule
# between them (indexed by attempt number).
_SYNC_BACK_MAX_RETRIES = 3
_SYNC_BACK_BACKOFF = (2, 4, 8)  # seconds between retries
_SYNC_BACK_MAX_BYTES = 2 * 1024 * 1024 * 1024  # 2 GiB — refuse to extract larger tars
105  
106  
class FileSyncManager:
    """Tracks local file changes and syncs to a remote environment.

    Backends instantiate this with transport callbacks (upload, delete)
    and a file-source callable.  The manager handles mtime-based change
    detection, deletion tracking, rate limiting, and transactional state.

    Not used by bind-mount backends (Docker, Singularity) — those get
    live host FS views and don't need file sync.
    """

    def __init__(
        self,
        get_files_fn: GetFilesFn,
        upload_fn: UploadFn,
        delete_fn: DeleteFn,
        sync_interval: float = _SYNC_INTERVAL_SECONDS,
        bulk_upload_fn: BulkUploadFn | None = None,
        bulk_download_fn: BulkDownloadFn | None = None,
    ):
        self._get_files_fn = get_files_fn
        self._upload_fn = upload_fn
        self._bulk_upload_fn = bulk_upload_fn
        self._bulk_download_fn = bulk_download_fn
        self._delete_fn = delete_fn
        # remote_path -> (mtime, size) of the last successfully synced version
        self._synced_files: dict[str, tuple[float, int]] = {}
        # remote_path -> sha256 of the pushed content; lets sync_back tell
        # "remote edited this" apart from "we pushed it that way"
        self._pushed_hashes: dict[str, str] = {}
        self._last_sync_time: float = 0.0  # monotonic; 0 ensures first sync runs
        self._sync_interval = sync_interval

    def sync(self, *, force: bool = False) -> None:
        """Run a sync cycle: upload changed files, delete removed files.

        Rate-limited to once per ``sync_interval`` unless *force* is True
        or ``HERMES_FORCE_FILE_SYNC=1`` is set.

        Transactional: state only committed if ALL operations succeed.
        On failure, state rolls back so the next cycle retries everything.
        """
        if not force and not os.environ.get(_FORCE_SYNC_ENV):
            now = time.monotonic()
            if now - self._last_sync_time < self._sync_interval:
                return

        current_files = self._get_files_fn()
        current_remote_paths = {remote for _, remote in current_files}

        # --- Uploads: new or changed files ---
        to_upload: list[tuple[str, str]] = []
        new_files = dict(self._synced_files)
        for host_path, remote_path in current_files:
            file_key = _file_mtime_key(host_path)
            if file_key is None:
                # Host file vanished between enumeration and stat — skip it.
                continue
            if self._synced_files.get(remote_path) == file_key:
                continue  # unchanged since the last successful push
            to_upload.append((host_path, remote_path))
            new_files[remote_path] = file_key

        # --- Deletes: synced paths no longer in current set ---
        to_delete = [p for p in self._synced_files if p not in current_remote_paths]

        if not to_upload and not to_delete:
            self._last_sync_time = time.monotonic()
            return

        # Snapshot for rollback (only when there's work to do)
        prev_files = dict(self._synced_files)
        prev_hashes = dict(self._pushed_hashes)

        if to_upload:
            logger.debug("file_sync: uploading %d file(s)", len(to_upload))
        if to_delete:
            logger.debug("file_sync: deleting %d stale remote file(s)", len(to_delete))

        try:
            if to_upload and self._bulk_upload_fn is not None:
                self._bulk_upload_fn(to_upload)
                logger.debug("file_sync: bulk-uploaded %d file(s)", len(to_upload))
            else:
                for host_path, remote_path in to_upload:
                    self._upload_fn(host_path, remote_path)
                    logger.debug("file_sync: uploaded %s -> %s", host_path, remote_path)

            if to_delete:
                self._delete_fn(to_delete)
                logger.debug("file_sync: deleted %s", to_delete)

            # --- Commit (all succeeded) ---
            # Record content hashes so sync_back can detect remote edits.
            for host_path, remote_path in to_upload:
                self._pushed_hashes[remote_path] = _sha256_file(host_path)

            for p in to_delete:
                new_files.pop(p, None)
                self._pushed_hashes.pop(p, None)

            self._synced_files = new_files
            self._last_sync_time = time.monotonic()

        except Exception as exc:
            # Roll back so the next cycle retries the whole batch; still bump
            # the timestamp so repeated failures stay rate-limited.
            self._synced_files = prev_files
            self._pushed_hashes = prev_hashes
            self._last_sync_time = time.monotonic()
            logger.warning("file_sync: sync failed, rolled back state: %s", exc)

    # ------------------------------------------------------------------
    # Sync-back: pull remote changes to host on teardown
    # ------------------------------------------------------------------

    def sync_back(self, hermes_home: Path | None = None) -> None:
        """Pull remote changes back to the host filesystem.

        Downloads the remote ``.hermes/`` directory as a tar archive,
        unpacks it, and applies only files that differ from what was
        originally pushed (based on SHA-256 content hashes).

        Protected against SIGINT (defers the signal until complete) and
        serialized across concurrent gateway sandboxes via file lock.
        Retries up to ``_SYNC_BACK_MAX_RETRIES`` times with backoff;
        failures are logged, never raised.
        """
        if self._bulk_download_fn is None:
            return

        # Nothing was ever committed through this manager — the initial
        # push failed or never ran. Skip sync_back to avoid retry storms
        # against an uninitialized remote .hermes/ directory.
        if not self._pushed_hashes and not self._synced_files:
            logger.debug("sync_back: no prior push state — skipping")
            return

        lock_path = (hermes_home or get_hermes_home()) / ".sync.lock"
        lock_path.parent.mkdir(parents=True, exist_ok=True)

        last_exc: Exception | None = None
        for attempt in range(_SYNC_BACK_MAX_RETRIES):
            try:
                self._sync_back_once(lock_path)
                return
            except Exception as exc:
                last_exc = exc
                if attempt < _SYNC_BACK_MAX_RETRIES - 1:
                    delay = _SYNC_BACK_BACKOFF[attempt]
                    logger.warning(
                        "sync_back: attempt %d failed (%s), retrying in %ds",
                        attempt + 1, exc, delay,
                    )
                    _sleep(delay)

        logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, last_exc)

    def _sync_back_once(self, lock_path: Path) -> None:
        """Single sync-back attempt with SIGINT protection and file lock."""
        # signal.signal() only works from the main thread. In gateway
        # contexts cleanup() may run from a worker thread — skip SIGINT
        # deferral there rather than crashing.
        on_main_thread = threading.current_thread() is threading.main_thread()

        deferred_sigint: list[object] = []
        original_handler = None
        handler_installed = False
        if on_main_thread:
            original_handler = signal.getsignal(signal.SIGINT)
            # getsignal() returns None when the handler was installed from C
            # (not from Python). signal.signal() cannot restore such a
            # handler, so don't install our deferral at all in that case —
            # previously we installed it anyway and then skipped the restore,
            # leaving SIGINT deferred (and dropped) forever.
            if original_handler is not None:

                def _defer_sigint(signum, frame):
                    deferred_sigint.append((signum, frame))
                    logger.debug("sync_back: SIGINT deferred until sync completes")

                signal.signal(signal.SIGINT, _defer_sigint)
                handler_installed = True
        try:
            self._sync_back_locked(lock_path)
        finally:
            if handler_installed:
                signal.signal(signal.SIGINT, original_handler)
                if deferred_sigint:
                    # Re-deliver the deferred SIGINT now that we're done.
                    os.kill(os.getpid(), signal.SIGINT)

    def _sync_back_locked(self, lock_path: Path) -> None:
        """Sync-back under file lock (serializes concurrent gateways)."""
        if fcntl is None:
            # Windows: no flock — run without serialization
            self._sync_back_impl()
            return
        # Context manager guarantees the fd closes even if flock() raises.
        with open(lock_path, "w") as lock_fd:
            fcntl.flock(lock_fd, fcntl.LOCK_EX)
            try:
                self._sync_back_impl()
            finally:
                fcntl.flock(lock_fd, fcntl.LOCK_UN)

    def _sync_back_impl(self) -> None:
        """Download, diff, and apply remote changes to host."""
        if self._bulk_download_fn is None:
            raise RuntimeError("_sync_back_impl called without bulk_download_fn")

        # Cache file mapping once to avoid O(n*m) from repeated iteration
        try:
            file_mapping = list(self._get_files_fn())
        except Exception:
            file_mapping = []

        with tempfile.NamedTemporaryFile(suffix=".tar") as tf:
            self._bulk_download_fn(Path(tf.name))

            # Defensive size cap: a misbehaving sandbox could produce an
            # arbitrarily large tar. Refuse to extract if it exceeds the cap.
            try:
                tar_size = os.path.getsize(tf.name)
            except OSError:
                tar_size = 0
            if tar_size > _SYNC_BACK_MAX_BYTES:
                logger.warning(
                    "sync_back: remote tar is %d bytes (cap %d) — skipping extraction",
                    tar_size, _SYNC_BACK_MAX_BYTES,
                )
                return

            with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging:
                with tarfile.open(tf.name) as tar:
                    # "data" filter rejects absolute paths, traversal, and
                    # special members (Python 3.12 tarfile extraction filters).
                    tar.extractall(staging, filter="data")

                applied = 0
                for dirpath, _dirnames, filenames in os.walk(staging):
                    for fname in filenames:
                        staged_file = os.path.join(dirpath, fname)
                        rel = os.path.relpath(staged_file, staging)
                        # Remote paths are POSIX; normalize os.sep so the
                        # hash/mapping lookups also work on Windows hosts,
                        # where relpath yields backslashes.
                        remote_path = "/" + rel.replace(os.sep, "/")

                        pushed_hash = self._pushed_hashes.get(remote_path)

                        # Skip hashing for files unchanged from push
                        if pushed_hash is not None:
                            remote_hash = _sha256_file(staged_file)
                            if remote_hash == pushed_hash:
                                continue
                        else:
                            remote_hash = None  # new remote file

                        # Resolve host path from cached mapping
                        host_path = self._resolve_host_path(remote_path, file_mapping)
                        if host_path is None:
                            host_path = self._infer_host_path(remote_path, file_mapping)
                            if host_path is None:
                                logger.debug(
                                    "sync_back: skipping %s (no host mapping)",
                                    remote_path,
                                )
                                continue

                        if os.path.exists(host_path) and pushed_hash is not None:
                            host_hash = _sha256_file(host_path)
                            if host_hash != pushed_hash:
                                logger.warning(
                                    "sync_back: conflict on %s — host modified "
                                    "since push, remote also changed. Applying "
                                    "remote version (last-write-wins).",
                                    remote_path,
                                )

                        os.makedirs(os.path.dirname(host_path), exist_ok=True)
                        shutil.copy2(staged_file, host_path)
                        applied += 1

                if applied:
                    logger.info("sync_back: applied %d changed file(s)", applied)
                else:
                    logger.debug("sync_back: no remote changes detected")

    def _resolve_host_path(self, remote_path: str,
                           file_mapping: list[tuple[str, str]] | None = None) -> str | None:
        """Find the host path for a known remote path from the file mapping.

        Returns None when *remote_path* has no exact entry in the mapping.
        """
        mapping = file_mapping if file_mapping is not None else []
        for host, remote in mapping:
            if remote == remote_path:
                return host
        return None

    def _infer_host_path(self, remote_path: str,
                         file_mapping: list[tuple[str, str]] | None = None) -> str | None:
        """Infer a host path for a new remote file by matching path prefixes.

        Uses the existing file mapping to find a remote->host directory
        pair, then applies the same prefix substitution to the new file.
        For example, if the mapping has ``/root/.hermes/skills/a.md`` →
        ``~/.hermes/skills/a.md``, a new remote file at
        ``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
        Returns None when no mapped directory is a prefix of *remote_path*.
        """
        mapping = file_mapping if file_mapping is not None else []
        for host, remote in mapping:
            remote_dir = str(Path(remote).parent)
            if remote_path.startswith(remote_dir + "/"):
                host_dir = str(Path(host).parent)
                suffix = remote_path[len(remote_dir):]
                return host_dir + suffix
        return None