# main.py
# Copyright 2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import shlex
import textwrap
from datetime import timedelta
from pathlib import Path

from opensandbox import Sandbox
from opensandbox.config import ConnectionConfig


def _load_requirements() -> str:
    """Return the text of the ``requirements.txt`` that sits next to this file."""
    requirements_path = Path(__file__).with_name("requirements.txt")
    return requirements_path.read_text(encoding="utf-8")


def _training_script() -> str:
    """Return the source of the training program that is uploaded as ``train.py``.

    The script trains a Stable-Baselines3 DQN agent on ``CartPole-v1`` for
    ``RL_TIMESTEPS`` steps (default 5000), saves a checkpoint under
    ``checkpoints/``, evaluates the policy for 5 episodes and writes the
    results to ``training_summary.json``.
    """
    # dedent() strips the indentation used for readability here; lstrip()
    # drops the leading newline that follows the opening triple quote.
    return textwrap.dedent(
        """
        import json
        import os

        import gymnasium as gym
        from stable_baselines3 import DQN
        from stable_baselines3.common.evaluation import evaluate_policy

        timesteps = int(os.getenv("RL_TIMESTEPS", "5000"))
        tensorboard_log = os.getenv("RL_TENSORBOARD_LOG", "runs")

        env = gym.make("CartPole-v1")
        model = DQN(
            "MlpPolicy",
            env,
            verbose=1,
            tensorboard_log=tensorboard_log,
            learning_rate=1e-3,
            buffer_size=10000,
            learning_starts=1000,
            batch_size=32,
            train_freq=4,
            gradient_steps=1,
        )

        model.learn(total_timesteps=timesteps)

        os.makedirs("checkpoints", exist_ok=True)
        checkpoint_path = "checkpoints/cartpole_dqn"
        model.save(checkpoint_path)

        mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
        summary = {
            "timesteps": timesteps,
            "mean_reward": float(mean_reward),
            "std_reward": float(std_reward),
            "checkpoint_path": f"{checkpoint_path}.zip",
        }
        with open("training_summary.json", "w", encoding="utf-8") as handle:
            json.dump(summary, handle, indent=2)

        print("Training summary:", summary)
        env.close()
        """
    ).lstrip()


async def _print_execution_logs(execution) -> None:
    """Print the stdout, stderr and error (if any) of a finished execution.

    Declared ``async`` although it awaits nothing, because every caller in
    this file awaits it.
    """
    for msg in execution.logs.stdout:
        print(f"[stdout] {msg.text}")
    for msg in execution.logs.stderr:
        print(f"[stderr] {msg.text}")
    if execution.error:
        print(f"[error] {execution.error.name}: {execution.error.value}")


def _execution_failed(execution) -> bool:
    """Return True when the execution carries an ``error`` object."""
    return execution.error is not None


async def _run_command(sandbox: Sandbox, command: str) -> bool:
    """Run *command* in the sandbox, echo its logs, and report success.

    Returns:
        True when the resulting execution has no ``error`` set.
    """
    execution = await sandbox.commands.run(command)
    await _print_execution_logs(execution)
    return not _execution_failed(execution)


def _with_python_env(command: str) -> str:
    """Wrap *command* so it runs in a login bash after sourcing the env script.

    The generated line sources ``/opt/opensandbox/code-interpreter-env.sh``
    with the arguments ``python ${PYTHON_VERSION:-3.14}`` and, if that
    succeeds, runs *command* (presumably with the selected interpreter
    active -- the env script itself is not part of this file).

    The inner script is quoted with ``shlex.quote`` rather than by pasting it
    between literal single quotes, so a *command* that itself contains a
    single quote no longer produces a broken shell line. For commands without
    single quotes the result is identical to plain single-quote wrapping.
    """
    script = (
        "source /opt/opensandbox/code-interpreter-env.sh "
        "python ${PYTHON_VERSION:-3.14} >/dev/null "
        "&& "
        f"{command}"
    )
    # ${PYTHON_VERSION:-3.14} stays unexpanded here on purpose: the quoting
    # defers expansion to the inner ``bash -lc`` shell.
    return f"bash -lc {shlex.quote(script)}"


async def _ensure_pip(sandbox: Sandbox) -> bool:
    """Make sure pip is usable inside the sandbox.

    Tries, in order: probing the existing pip, ``ensurepip``, ``apt-get`` and
    ``apk``. Stops at the first command that completes without an error.

    Returns:
        True if any of the commands succeeded, False if all of them failed.
    """
    bootstrap_commands = [
        _with_python_env("python3 -m pip --version"),
        _with_python_env("python3 -m ensurepip --upgrade"),
        "apt-get update && apt-get install -y python3-pip",
        "apk add --no-cache py3-pip",
    ]
    for command in bootstrap_commands:
        if await _run_command(sandbox, command):
            return True
    return False


async def _install_requirements(sandbox: Sandbox) -> bool:
    """Install the uploaded ``requirements.txt`` inside the sandbox.

    Tries ``python3 -m pip`` in the sourced Python environment first, then
    falls back to bare ``pip3`` and ``pip``. Stops at the first success.

    Returns:
        True if any install command succeeded, False otherwise.
    """
    install_commands = [
        _with_python_env(
            "python3 -m pip install --no-cache-dir --break-system-packages -r requirements.txt"
        ),
        "pip3 install --no-cache-dir -r requirements.txt",
        "pip install --no-cache-dir -r requirements.txt",
    ]
    for command in install_commands:
        if await _run_command(sandbox, command):
            return True
    return False


async def main() -> None:
    """Train the CartPole DQN agent inside a sandbox and print the summary.

    Configuration is read from the environment: ``SANDBOX_DOMAIN``,
    ``SANDBOX_API_KEY``, ``SANDBOX_IMAGE`` and ``RL_TIMESTEPS``. Each failure
    stage prints a message and returns early; the sandbox is killed in every
    case once it has been created.
    """
    domain = os.getenv("SANDBOX_DOMAIN", "localhost:8080")
    api_key = os.getenv("SANDBOX_API_KEY")
    image = os.getenv(
        "SANDBOX_IMAGE",
        "sandbox-registry.cn-zhangjiakou.cr.aliyuncs.com/opensandbox/code-interpreter:v1.0.2",
    )
    timesteps = os.getenv("RL_TIMESTEPS", "5000")

    config = ConnectionConfig(
        domain=domain,
        api_key=api_key,
        request_timeout=timedelta(minutes=10),
    )

    # RL_TIMESTEPS is forwarded so the training script inside the sandbox
    # sees the same value as this driver.
    sandbox = await Sandbox.create(
        image,
        connection_config=config,
        env={"RL_TIMESTEPS": timesteps},
    )

    async with sandbox:
        try:
            await sandbox.files.write_file("requirements.txt", _load_requirements())
            if not await _ensure_pip(sandbox):
                print("Failed to bootstrap pip inside the sandbox.")
                return

            if not await _install_requirements(sandbox):
                print("Failed to install RL dependencies inside the sandbox.")
                return

            await sandbox.files.write_file("train.py", _training_script())
            train_exec = await sandbox.commands.run(_with_python_env("python3 train.py"))
            await _print_execution_logs(train_exec)
            if _execution_failed(train_exec):
                print("Training failed inside the sandbox.")
                return

            # Reading the summary is best-effort: training already succeeded,
            # so a read failure is reported instead of raised.
            try:
                summary = await sandbox.files.read_file("training_summary.json")
            except Exception as exc:
                print(f"\nFailed to read training summary: {exc}")
            else:
                print("\n=== Training summary ===")
                print(summary)
        finally:
            # Runs on the early returns above as well as on exceptions.
            await sandbox.kill()


if __name__ == "__main__":
    asyncio.run(main())