# environment.py
import random
import uuid
from typing import Optional

from openenv.core.env_server import Environment

from grader import grader
from models import EngAction, EngObservation, EngState
from reward import calculate_reward
from task_loader import TASKS, TESTS


class EngEnv(Environment):
    """Environment in which an agent repairs a piece of broken code.

    ``reset`` selects a task from ``TASKS`` and returns its broken code;
    each ``step`` passes the submitted solution to ``grader`` and turns the
    result into a reward (via ``calculate_reward``) and an observation.
    An episode ends when the attempt budget is used up or the submission
    passes (see ``step`` for the exact condition).
    """

    # Difficulty buckets, in the order they are searched and listed.
    _LEVELS = ('EASY', 'MEDIUM', 'HARD')
    # Number of submissions allowed per episode.
    _MAX_ATTEMPTS = 10
    # Value shown in messages when the grader reports no efficiency ratio.
    _RATIO_FALLBACK = 0.01

    def __init__(self):
        """Create an environment with no task selected yet."""
        super().__init__()
        self._state = EngState()
        self.last_grader_result = None
        self.attempts_remaining = self._MAX_ATTEMPTS
        self.task_id = None
        self.task_name = None
        self.task_time_base = None
        self.task_memory_base = None
        self.difficulty = "EASY"
        self.task_description = ""
        self.task_domain = ""
        self.current_task = None

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> EngObservation:
        """Start a new episode and return the initial observation.

        Args:
            task_id: Id of the task to load. When falsy, a difficulty level
                and then a task within it are picked at random.
            seed: When given, the random task pick is drawn from a private
                ``random.Random(seed)`` so it is reproducible and does not
                disturb the global random state.
            episode_id: Id to record in the state. A fresh UUID4 string is
                generated when this is not supplied.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The first observation, carrying the task's broken code.

        Raises:
            ValueError: If ``task_id`` is given but matches no task.
        """
        if task_id:
            found = next(
                (
                    (lvl, candidate)
                    for lvl in self._LEVELS
                    for candidate in TASKS[lvl]
                    if candidate['id'] == task_id
                ),
                None,
            )
            if found is None:
                raise ValueError(f"Task '{task_id}' not found")
            level, task = found
        else:
            rng = random.Random(seed) if seed is not None else random
            level = rng.choice(self._LEVELS)
            task = rng.choice(TASKS[level])

        self.task_id = task['id']
        self.task_name = task['name']
        self.task_time_base = task['baseline_time_ms']
        self.task_memory_base = task['baseline_mem_mib']
        self.difficulty = level
        self.task_description = task["description"]
        self.task_domain = task["domain"]
        self.attempts_remaining = self._MAX_ATTEMPTS
        self.last_grader_result = None

        self._state = EngState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            total_steps=self.attempts_remaining,
            difficulty=level,
            task_domain=task["domain"],
        )

        return EngObservation(
            domain=task["domain"],
            difficulty=level,
            task=self.task_description,
            code=task["broken_code"],
            done=False,
            reward=0.5,
            output=None,
            tests_passed=None,
            num_tests=len(TESTS[task["id"]]["cases"]),
            time_taken=None,
            mem_taken=None,
            message=f"Fix the {level} {task['domain']} task. {self.attempts_remaining} attempts remaining.",
            num_steps_remain=self.attempts_remaining,
        )

    def step(self, action: EngAction) -> EngObservation:
        """Grade one submitted solution and return the resulting observation.

        The episode is done when no attempts remain, or when the submission
        passes its tests and, on HARD tasks only, at least one of the speed
        or memory ratios reported by the grader is >= 1.0.

        Args:
            action: The agent's action; ``action.sol`` holds the code.

        Returns:
            Observation with the reward, grader output and a status message.
        """
        self._state.step_count += 1
        self.attempts_remaining -= 1

        code = action.sol.strip()
        g_result = grader(code, self.task_id)
        self.last_grader_result = g_result
        reward = calculate_reward(g_result, self.difficulty)
        efficiency = g_result['efficiency']

        # Done logic
        out_of_attempts = self.attempts_remaining <= 0
        logic_passed = (
            g_result['status'] == 'success'
            or g_result['tests_passed'] == g_result['total_tests']
        )
        if self.difficulty == 'HARD':
            # A missing (None) ratio counts as 0, i.e. "not met".
            speed_ok = (efficiency.get('speed_ratio') or 0) >= 1.0
            memory_ok = (efficiency.get('memory_ratio') or 0) >= 1.0
            efficiency_met = speed_ok or memory_ok
        else:
            efficiency_met = True

        done = out_of_attempts or (logic_passed and efficiency_met)

        return EngObservation(
            domain=self.task_domain,
            difficulty=self.difficulty,
            task=self.task_description,
            code=code,
            done=done,
            reward=reward,
            output=g_result['error_message'],
            tests_passed=g_result['tests_passed'],
            num_tests=g_result['total_tests'],
            time_taken=efficiency.get('runtime_ms'),
            mem_taken=efficiency.get('memory_mib'),
            message=self._build_message(g_result, reward, done),
            num_steps_remain=self.attempts_remaining,
        )

    @property
    def state(self) -> EngState:
        """Return the current episode state object."""
        return self._state

    def list_tasks(self) -> list:
        """Return a summary dict for every task, ordered EASY, MEDIUM, HARD."""
        return [
            {
                "id": task["id"],
                "domain": task["domain"],
                "difficulty": level,
                "name": task["name"],
                "description": task["description"],
            }
            for level in self._LEVELS
            for task in TASKS[level]
        ]

    def get_grader_score(self) -> dict:
        """Summarise the most recent grader result.

        Returns:
            ``{"grader_score": 0.01, "status": "no_attempt"}`` when nothing
            has been graded since the last reset; otherwise the task id,
            grader score, status and test counts of the last submission.
        """
        last = self.last_grader_result
        if not last:
            return {"grader_score": 0.01, "status": "no_attempt"}
        return {
            "task_id": self.task_id,
            "grader_score": last.get("grader_score", 0.01),
            "status": last.get("status"),
            "tests_passed": last.get("tests_passed", 0),
            "total_tests": last.get("total_tests", 0),
        }

    def _build_message(self, g_result: dict, reward: float, done: bool) -> str:
        """Build the human-readable feedback line for one graded submission.

        Args:
            g_result: Result dict returned by ``grader``.
            reward: Reward computed for this step.
            done: Whether this step ended the episode.

        Returns:
            A message chosen by grader status first, then by ``done``.
        """
        status = g_result['status']
        passed = g_result['tests_passed']
        total = g_result['total_tests']
        eff = g_result['efficiency']

        if status == 'syntax_error':
            return f"Syntax error: {g_result['error_message']}"
        if status == 'runtime_error':
            return f"Runtime error: {g_result['error_message']}"
        if status == 'logic_error':
            return f"Code runs but fails all tests (0/{total}). Check your logic.Error: {g_result['error_message']}. {self.attempts_remaining} attempts left."
        if status == 'partial':
            error = g_result.get('error_message', '')
            return f"Partial: {passed}/{total} tests passed. Reward: {reward}. Error: {error}. {self.attempts_remaining} attempts left."

        if status == 'success':
            # Ratios may be absent (None); substitute a placeholder so the
            # ':.2f' formatting below cannot raise TypeError.
            sr = eff.get('speed_ratio') or self._RATIO_FALLBACK
            mr = eff.get('memory_ratio') or self._RATIO_FALLBACK
            if done:
                return f"Episode complete! Reward: {reward}. Speed: {sr:.2f}x. Memory: {mr:.2f}x."
            if self.difficulty == 'HARD':
                return f"All tests pass! Optimize further. Speed: {sr:.2f}x baseline. Memory: {mr:.2f}x baseline."
            return f"All tests passed! Reward: {reward}"

        if done:
            return f"Out of attempts. Best: {passed}/{total} tests."
        return f"Reward: {reward}. {self.attempts_remaining} attempts remaining."