/ examples / rl-training / main.py
main.py
  1  # Copyright 2025 Alibaba Group Holding Ltd.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  
 15  import asyncio
 16  import os
 17  import textwrap
 18  from datetime import timedelta
 19  from pathlib import Path
 20  
 21  from opensandbox import Sandbox
 22  from opensandbox.config import ConnectionConfig
 23  
 24  
 25  def _load_requirements() -> str:
 26      requirements_path = Path(__file__).with_name("requirements.txt")
 27      return requirements_path.read_text(encoding="utf-8")
 28  
 29  
 30  def _training_script() -> str:
 31      return textwrap.dedent(
 32          """
 33          import json
 34          import os
 35  
 36          import gymnasium as gym
 37          from stable_baselines3 import DQN
 38          from stable_baselines3.common.evaluation import evaluate_policy
 39  
 40          timesteps = int(os.getenv("RL_TIMESTEPS", "5000"))
 41          tensorboard_log = os.getenv("RL_TENSORBOARD_LOG", "runs")
 42  
 43          env = gym.make("CartPole-v1")
 44          model = DQN(
 45              "MlpPolicy",
 46              env,
 47              verbose=1,
 48              tensorboard_log=tensorboard_log,
 49              learning_rate=1e-3,
 50              buffer_size=10000,
 51              learning_starts=1000,
 52              batch_size=32,
 53              train_freq=4,
 54              gradient_steps=1,
 55          )
 56  
 57          model.learn(total_timesteps=timesteps)
 58  
 59          os.makedirs("checkpoints", exist_ok=True)
 60          checkpoint_path = "checkpoints/cartpole_dqn"
 61          model.save(checkpoint_path)
 62  
 63          mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
 64          summary = {
 65              "timesteps": timesteps,
 66              "mean_reward": float(mean_reward),
 67              "std_reward": float(std_reward),
 68              "checkpoint_path": f"{checkpoint_path}.zip",
 69          }
 70          with open("training_summary.json", "w", encoding="utf-8") as handle:
 71              json.dump(summary, handle, indent=2)
 72  
 73          print("Training summary:", summary)
 74          env.close()
 75          """
 76      ).lstrip()
 77  
 78  
 79  async def _print_execution_logs(execution) -> None:
 80      for msg in execution.logs.stdout:
 81          print(f"[stdout] {msg.text}")
 82      for msg in execution.logs.stderr:
 83          print(f"[stderr] {msg.text}")
 84      if execution.error:
 85          print(f"[error] {execution.error.name}: {execution.error.value}")
 86  
 87  
 88  def _execution_failed(execution) -> bool:
 89      return execution.error is not None
 90  
 91  
 92  async def _run_command(sandbox: Sandbox, command: str) -> bool:
 93      execution = await sandbox.commands.run(command)
 94      await _print_execution_logs(execution)
 95      return not _execution_failed(execution)
 96  
 97  
 98  def _with_python_env(command: str) -> str:
 99      return (
100          "bash -lc '"
101          "source /opt/opensandbox/code-interpreter-env.sh "
102          "python ${PYTHON_VERSION:-3.14} >/dev/null "
103          "&& "
104          f"{command}"
105          "'"
106      )
107  
108  
async def _ensure_pip(sandbox: Sandbox) -> bool:
    """Make sure pip is usable inside *sandbox*, trying progressively heavier fixes.

    Attempts, in order, stopping at the first that succeeds:

    1. ``python3 -m pip --version`` in the code-interpreter env (pip already present);
    2. ``python3 -m ensurepip --upgrade`` in the same env;
    3. installing ``python3-pip`` with ``apt-get``;
    4. installing ``py3-pip`` with ``apk``.

    Returns:
        ``True`` if any attempt succeeded, ``False`` if all of them failed.
    """
    probe = _with_python_env("python3 -m pip --version")
    bootstrap = _with_python_env("python3 -m ensurepip --upgrade")
    system_installs = (
        "apt-get update && apt-get install -y python3-pip",
        "apk add --no-cache py3-pip",
    )
    for attempt in (probe, bootstrap, *system_installs):
        if await _run_command(sandbox, attempt):
            return True
    return False
120  
121  
async def _install_requirements(sandbox: Sandbox) -> bool:
    """Install ``requirements.txt`` inside *sandbox*, falling back across pip front ends.

    Attempts, in order, stopping at the first that succeeds:

    1. ``python3 -m pip`` from the code-interpreter env, with ``--break-system-packages``;
    2. a bare ``pip3``;
    3. a bare ``pip``.

    Returns:
        ``True`` if any attempt succeeded, ``False`` if all of them failed.
    """
    env_pip = _with_python_env(
        "python3 -m pip install --no-cache-dir --break-system-packages -r requirements.txt"
    )
    fallbacks = (
        "pip3 install --no-cache-dir -r requirements.txt",
        "pip install --no-cache-dir -r requirements.txt",
    )
    for attempt in (env_pip, *fallbacks):
        succeeded = await _run_command(sandbox, attempt)
        if succeeded:
            return True
    return False
134  
135  
async def main() -> int:
    """Train a DQN agent on CartPole inside an OpenSandbox sandbox and print its summary.

    Settings are read from environment variables:

    * ``SANDBOX_DOMAIN``: sandbox server address (default ``localhost:8080``).
    * ``SANDBOX_API_KEY``: API key passed to ``ConnectionConfig`` (optional).
    * ``SANDBOX_IMAGE``: image to launch (defaults to the code-interpreter image).
    * ``RL_TIMESTEPS``: training steps, forwarded into the sandbox (default ``5000``).

    Returns:
        A process exit status. ``0`` means training finished and its summary
        was read back. ``1`` means input validation, pip bootstrap, dependency
        installation, training, or reading the summary failed; the reason is
        printed.
    """
    domain = os.getenv("SANDBOX_DOMAIN", "localhost:8080")
    api_key = os.getenv("SANDBOX_API_KEY")
    image = os.getenv(
        "SANDBOX_IMAGE",
        "sandbox-registry.cn-zhangjiakou.cr.aliyuncs.com/opensandbox/code-interpreter:v1.0.2",
    )
    timesteps = os.getenv("RL_TIMESTEPS", "5000")

    # train.py parses RL_TIMESTEPS with int(). Check it locally so a bad value
    # fails fast instead of after a sandbox was provisioned and packages installed.
    try:
        int(timesteps)
    except ValueError:
        print(f"RL_TIMESTEPS must be an integer, got {timesteps!r}.")
        return 1

    # Read local inputs before provisioning, so a missing file costs no sandbox.
    requirements = _load_requirements()

    config = ConnectionConfig(
        domain=domain,
        api_key=api_key,
        # Generous per-request timeout, presumably to cover pip installs and training.
        request_timeout=timedelta(minutes=10),
    )

    sandbox = await Sandbox.create(
        image,
        connection_config=config,
        env={"RL_TIMESTEPS": timesteps},
    )

    async with sandbox:
        try:
            await sandbox.files.write_file("requirements.txt", requirements)
            if not await _ensure_pip(sandbox):
                print("Failed to bootstrap pip inside the sandbox.")
                return 1

            if not await _install_requirements(sandbox):
                print("Failed to install RL dependencies inside the sandbox.")
                return 1

            await sandbox.files.write_file("train.py", _training_script())
            if not await _run_command(sandbox, _with_python_env("python3 train.py")):
                print("Training failed inside the sandbox.")
                return 1

            try:
                summary = await sandbox.files.read_file("training_summary.json")
            except Exception as exc:  # top-level boundary: report any SDK/transport error
                print(f"\nFailed to read training summary: {exc}")
                return 1

            print("\n=== Training summary ===")
            print(summary)
            return 0
        finally:
            # Kill the sandbox on every exit path, including early returns and exceptions.
            # NOTE(review): assumes `async with` alone does not terminate the remote
            # sandbox; confirm against the opensandbox SDK.
            await sandbox.kill()
181  
182  
if __name__ == "__main__":
    # Forward main()'s return value as the process exit status (None means success),
    # so shells and CI can detect a failed run.
    raise SystemExit(asyncio.run(main()))