/ restai / docker_manager.py
docker_manager.py
  1  """Docker container manager for sandboxed command execution.
  2  
  3  Manages per-chat containers that persist across tool calls within a
  4  conversation. Idle container cleanup is handled by an external cron script
  5  (crons/docker_cleanup.py), not by an internal thread.
  6  """
  7  from __future__ import annotations
  8  
  9  import logging
 10  import time
 11  import threading
 12  from dataclasses import dataclass, field
 13  
 14  from restai import config
 15  
 16  logger = logging.getLogger(__name__)
 17  
 18  
 19  @dataclass
 20  class ContainerInfo:
 21      container_id: str
 22      chat_id: str
 23      last_activity: float = field(default_factory=time.time)
 24  
 25  
 26  class DockerManager:
 27      """Manages Docker containers keyed by chat_id for sandboxed command execution."""
 28  
 29      def __init__(self, docker_url: str, docker_image: str = "python:3.12-slim",
 30                   container_timeout: int = 900, network_mode: str = "none",
 31                   read_only: bool = True):
 32          import docker as docker_sdk
 33          self._client = docker_sdk.DockerClient(base_url=docker_url)
 34          self._image = docker_image
 35          self._timeout = container_timeout
 36          self._network_mode = network_mode
 37          self._read_only = read_only
 38          self._containers: dict[str, ContainerInfo] = {}
 39          self._lock = threading.Lock()
 40  
 41          # Verify connectivity
 42          try:
 43              self._client.ping()
 44              logger.info("Docker manager connected to %s", docker_url)
 45          except Exception as e:
 46              logger.error("Docker manager failed to connect to %s: %s", docker_url, e)
 47              raise
 48  
 49      def exec_command(self, chat_id: str, command: str) -> str:
 50          """Execute a command in the container for this chat_id.
 51          Creates a new container if one doesn't exist."""
 52          if not chat_id:
 53              chat_id = "ephemeral"
 54  
 55          container = self._get_or_create_container(chat_id)
 56  
 57          # Update last activity (label on the container for the cron script)
 58          with self._lock:
 59              info = self._containers.get(chat_id)
 60              if info:
 61                  info.last_activity = time.time()
 62  
 63          try:
 64              exec_result = container.exec_run(
 65                  ["sh", "-c", command],
 66                  demux=True,
 67                  workdir="/home/user",
 68              )
 69              stdout = (exec_result.output[0] or b"").decode("utf-8", errors="replace")
 70              stderr = (exec_result.output[1] or b"").decode("utf-8", errors="replace")
 71              output = stdout + stderr
 72  
 73              # Truncate very large outputs
 74              if len(output) > 50000:
 75                  output = output[:50000] + "\n... (output truncated)"
 76  
 77              return output if output else "(no output)"
 78          except Exception as e:
 79              logger.exception("Docker exec failed for chat_id=%s: %s", chat_id, e)
 80              # Container may have died — remove it so next call creates a fresh one
 81              self._remove_container(chat_id)
 82              return f"ERROR: Command execution failed: {e}"
 83  
 84      def run_script(self, chat_id: str, script: str, stdin_data: str = "") -> str:
 85          """Execute a Python script in the container by piping code via python3 -c.
 86  
 87          No file writes needed — avoids read-only filesystem issues.
 88          Returns stdout. Stderr is appended if non-empty.
 89          """
 90          if not chat_id:
 91              chat_id = "ephemeral"
 92  
 93          container = self._get_or_create_container(chat_id)
 94  
 95          with self._lock:
 96              info = self._containers.get(chat_id)
 97              if info:
 98                  info.last_activity = time.time()
 99  
100          try:
101              import base64
102              # Encode script as base64 to avoid shell quoting issues
103              b64_script = base64.b64encode(script.encode("utf-8")).decode("ascii")
104              b64_stdin = base64.b64encode(stdin_data.encode("utf-8")).decode("ascii") if stdin_data else ""
105  
106              if b64_stdin:
107                  cmd = f'echo "{b64_stdin}" | base64 -d | python3 -c "$(echo {b64_script} | base64 -d)"'
108              else:
109                  cmd = f'python3 -c "$(echo {b64_script} | base64 -d)"'
110  
111              exec_result = container.exec_run(
112                  ["sh", "-c", cmd],
113                  demux=True,
114                  workdir="/home/user",
115              )
116              stdout = (exec_result.output[0] or b"").decode("utf-8", errors="replace")
117              stderr = (exec_result.output[1] or b"").decode("utf-8", errors="replace")
118  
119              if stderr.strip():
120                  return stdout + "\nSTDERR: " + stderr if stdout else "ERROR: " + stderr
121              return stdout.strip() if stdout.strip() else "(no output)"
122          except Exception as e:
123              logger.exception("Docker run_script failed for chat_id=%s: %s", chat_id, e)
124              self._remove_container(chat_id)
125              return f"ERROR: Script execution failed: {e}"
126  
127      def put_files(self, chat_id: str, files: list[tuple[str, bytes]],
128                    extract_to: str = "/home/user", subdir: str = "uploads") -> list[dict]:
129          """Copy a batch of files into the container for this chat_id.
130  
131          ``files`` is a list of ``(filename, raw_bytes)`` tuples. We build a
132          single tarball that contains a ``{subdir}/`` directory with every
133          file inside it, and extract it to ``{extract_to}``. ``{extract_to}``
134          must exist (it's a tmpfs mount, ``/home/user``, which is always
135          present). ``{subdir}`` is created by tar extraction — no shell call
136          needed, so this works on read-only root filesystems.
137  
138          Returns a manifest suitable for embedding into the LLM prompt:
139          ``[{name, path, size}, ...]``.
140          """
141          if not files:
142              return []
143          if not chat_id:
144              chat_id = "ephemeral"
145  
146          container = self._get_or_create_container(chat_id)
147  
148          import io
149          import os
150          import tarfile
151          import time as _time
152  
153          # Always use the chunked-exec path: Docker's put_archive API is
154          # unreliable for tmpfs-mounted targets (silent extraction into the
155          # underlying rootfs layer that the tmpfs mount shadows, or 404s on
156          # runtime-created subdirs). Writing via `sh -c` inside the container
157          # always sees the live mount namespace, so files end up visible to
158          # every subsequent `exec_run`.
159          target_dir = f"{extract_to}/{subdir}"
160          buf = io.BytesIO()
161          manifest: list[dict] = []
162          now = int(_time.time())
163  
164          tar = tarfile.open(fileobj=buf, mode="w", format=tarfile.USTAR_FORMAT)
165          try:
166              seen: set[str] = set()
167              for name, data in files:
168                  safe = os.path.basename(name).replace("\x00", "") or "file"
169                  base = safe
170                  counter = 1
171                  while safe in seen:
172                      stem, _, ext = base.rpartition(".")
173                      safe = f"{stem}_{counter}.{ext}" if stem else f"{base}_{counter}"
174                      counter += 1
175                  seen.add(safe)
176  
177                  info = tarfile.TarInfo(name=safe)
178                  info.size = len(data)
179                  info.mtime = now
180                  info.mode = 0o644
181                  tar.addfile(info, io.BytesIO(data))
182                  manifest.append({
183                      "name": safe,
184                      "path": f"{target_dir}/{safe}",
185                      "size": len(data),
186                  })
187          finally:
188              tar.close()
189  
190          tar_bytes = buf.getvalue()
191  
192          # Stream the tar into the container in small base64 chunks appended
193          # to a staging file, then extract. Chunk size has to stay under the
194          # per-argv limit (Linux MAX_ARG_STRLEN = 128 KB) because the base64
195          # blob rides as a single `sh -c` argument. 64 KB raw → ~87 KB base64
196          # → safely under the cap.
197          import base64 as _b64
198          CHUNK = 64 * 1024
199          tmp_path = f"{extract_to}/_restai_upload.tar"
200  
201          try:
202              res = container.exec_run(["sh", "-c", f"mkdir -p {target_dir} && : > {tmp_path}"])
203              if res.exit_code != 0:
204                  raise RuntimeError(f"tar staging failed (exit {res.exit_code})")
205  
206              for offset in range(0, len(tar_bytes), CHUNK):
207                  chunk = tar_bytes[offset:offset + CHUNK]
208                  chunk_b64 = _b64.b64encode(chunk).decode("ascii")
209                  cmd = f"printf '%s' {chunk_b64} | base64 -d >> {tmp_path}"
210                  res = container.exec_run(["sh", "-c", cmd])
211                  if res.exit_code != 0:
212                      err_out = (res.output or b"").decode("utf-8", errors="replace")
213                      raise RuntimeError(
214                          f"tar chunk write failed (exit {res.exit_code}): {err_out.strip()}"
215                      )
216  
217              cmd = f"tar xf {tmp_path} -C {target_dir} && rm -f {tmp_path}"
218              res = container.exec_run(["sh", "-c", cmd])
219              if res.exit_code != 0:
220                  err_out = (res.output or b"").decode("utf-8", errors="replace")
221                  raise RuntimeError(f"tar extract failed (exit {res.exit_code}): {err_out.strip()}")
222          except Exception as e:
223              logger.exception("Failed to put files into container for chat_id=%s: %s", chat_id, e)
224              raise RuntimeError(f"Failed to upload files to sandbox: {e}")
225  
226          # Sanity check: stat each file we claim to have uploaded. Cheap, and
227          # catches any surprise where the tar silently extracted to nowhere.
228          expected_paths = " ".join(f"'{entry['path']}'" for entry in manifest)
229          check = container.exec_run(
230              ["sh", "-c", f"for p in {expected_paths}; do [ -f \"$p\" ] || {{ echo MISSING:$p; exit 1; }}; done"]
231          )
232          if check.exit_code != 0:
233              missing = (check.output or b"").decode("utf-8", errors="replace").strip()
234              logger.error(
235                  "Upload verification failed for chat_id=%s: %s (tar=%d bytes, target=%s)",
236                  chat_id, missing, len(tar_bytes), target_dir,
237              )
238              raise RuntimeError(f"Files not present after upload: {missing}")
239  
240          with self._lock:
241              info = self._containers.get(chat_id)
242              if info:
243                  info.last_activity = time.time()
244  
245          logger.info("Uploaded %d file(s) to chat_id=%s at %s",
246                      len(manifest), chat_id, target_dir)
247          return manifest
248  
249      def _get_or_create_container(self, chat_id: str):
250          """Return existing container or create a new one."""
251          import docker as docker_sdk
252  
253          with self._lock:
254              info = self._containers.get(chat_id)
255              if info:
256                  try:
257                      container = self._client.containers.get(info.container_id)
258                      if container.status == "running":
259                          return container
260                  except docker_sdk.errors.NotFound:
261                      pass
262                  del self._containers[chat_id]
263  
264          # Also check for orphaned containers from a previous process
265          try:
266              existing = self._client.containers.list(
267                  filters={"label": [f"restai.chat_id={chat_id}"]},
268                  limit=1,
269              )
270              if existing and existing[0].status == "running":
271                  container = existing[0]
272                  with self._lock:
273                      self._containers[chat_id] = ContainerInfo(
274                          container_id=container.id,
275                          chat_id=chat_id,
276                      )
277                  return container
278          except Exception:
279              pass
280  
281          # Create new container
282          try:
283              container = self._client.containers.run(
284                  self._image,
285                  command="tail -f /dev/null",
286                  detach=True,
287                  labels={
288                      "restai.managed": "true",
289                      "restai.chat_id": chat_id,
290                      "restai.created_at": str(int(time.time())),
291                  },
292                  mem_limit="512m",
293                  cpu_period=100000,
294                  cpu_quota=50000,
295                  network_mode=self._network_mode,
296                  # Roomy tmpfs so the LLM can `pip install` modest packages
297                  # (pandas wheel ~60MB + build/temp space) and drop result
298                  # files without hitting ENOSPC.
299                  tmpfs={"/tmp": "size=1G", "/home/user": "size=1G"},
300                  # Rootfs read-only by default for sandbox hardening. Toggled
301                  # via the `docker_read_only` admin setting — admins can flip
302                  # to false when they need `pip install` inside the sandbox.
303                  read_only=self._read_only,
304                  remove=True,
305              )
306              with self._lock:
307                  self._containers[chat_id] = ContainerInfo(
308                      container_id=container.id,
309                      chat_id=chat_id,
310                  )
311              logger.info("Created container %s for chat_id=%s", container.short_id, chat_id)
312              return container
313          except Exception as e:
314              logger.exception("Failed to create container for chat_id=%s: %s", chat_id, e)
315              raise RuntimeError(f"Failed to create sandbox container: {e}")
316  
317      def _remove_container(self, chat_id: str):
318          """Stop and remove a container by chat_id."""
319          with self._lock:
320              info = self._containers.pop(chat_id, None)
321          if not info:
322              return
323          try:
324              container = self._client.containers.get(info.container_id)
325              container.stop(timeout=5)
326              logger.info("Removed container %s for chat_id=%s", info.container_id[:12], chat_id)
327          except Exception:
328              pass
329  
330      def shutdown(self):
331          """Remove all managed containers. Called on app shutdown."""
332          with self._lock:
333              chat_ids = list(self._containers.keys())
334  
335          for chat_id in chat_ids:
336              self._remove_container(chat_id)
337  
338          logger.info("Docker manager shut down, removed %d containers", len(chat_ids))
339  
340      @property
341      def active_container_count(self) -> int:
342          with self._lock:
343              return len(self._containers)