/ environment.py
environment.py
  1  
  2  from openenv.core.env_server import Environment
  3  from models import EngAction, EngObservation, EngState
  4  import random
  5  from grader import grader
  6  from reward import calculate_reward
  7  import uuid
  8  from task_loader import TASKS, TESTS
  9  
 10  class EngEnv(Environment):
 11      def __init__(self):
 12          self._state = EngState()
 13          self.last_grader_result = None
 14          self.attempts_remaining = 10
 15          self.task_id = None
 16          self.task_name = None
 17          self.task_time_base = None
 18          self.task_memory_base = None
 19          self.difficulty = "EASY"
 20          self.task_description = ""
 21          self.task_domain = ""
 22          self.current_task = None
 23  
 24      def reset(self, task_id: str = None, seed: int = None, episode_id: str = None, **kwargs) -> EngObservation:
 25          levels = ['EASY', 'MEDIUM', 'HARD']
 26  
 27          if task_id:
 28              task = None
 29              level = None
 30              for lvl in levels:
 31                  for t in TASKS[lvl]:
 32                      if t['id'] == task_id:
 33                          task = t
 34                          level = lvl
 35                          break
 36                  if task:
 37                      break
 38              if task is None:
 39                  raise ValueError(f"Task '{task_id}' not found")
 40          else:
 41              level = random.choice(levels)
 42              task = random.choice(TASKS[level])
 43  
 44          self.task_id = task['id']
 45          self.task_name = task['name']
 46          self.task_time_base = task['baseline_time_ms']
 47          self.task_memory_base = task['baseline_mem_mib']
 48          self.difficulty = level
 49          self.task_description = task["description"]
 50          self.task_domain = task["domain"]
 51          self.attempts_remaining = 10
 52          self.last_grader_result = None
 53  
 54          self._state = EngState(
 55              episode_id=str(uuid.uuid4()),
 56              step_count=0,
 57              total_steps=self.attempts_remaining,
 58              difficulty=level,
 59              task_domain=task["domain"],
 60          )
 61  
 62          return EngObservation(
 63              domain=task["domain"],
 64              difficulty=level,
 65              task=self.task_description,
 66              code=task["broken_code"],
 67              done=False,
 68              reward=0.5,
 69              output=None,
 70              tests_passed=None,
 71              num_tests=len(TESTS[task["id"]]["cases"]),
 72              time_taken=None,
 73              mem_taken=None,
 74              message=f"Fix the {level} {task['domain']} task. {self.attempts_remaining} attempts remaining.",
 75              num_steps_remain=self.attempts_remaining
 76          )
 77      
 78      
 79  
 80      def step(self, action: EngAction) -> EngObservation:
 81          self._state.step_count += 1
 82          self.attempts_remaining -= 1
 83  
 84          code = action.sol.strip()
 85          g_result = grader(code, self.task_id)
 86          self.last_grader_result = g_result
 87          reward = calculate_reward(g_result, self.difficulty)
 88  
 89          # Done logic
 90          out_of_attempts = self.attempts_remaining <= 0
 91          logic_passed = (
 92              g_result['status'] == 'success' or
 93              g_result['tests_passed'] == g_result['total_tests']
 94          )
 95          if self.difficulty == 'HARD':
 96              speed_ok = (g_result['efficiency'].get('speed_ratio') or 0) >= 1.0
 97              memory_ok = (g_result['efficiency'].get('memory_ratio') or 0) >= 1.0
 98              efficiency_met = speed_ok or memory_ok
 99          else:
100              efficiency_met = True
101  
102          done = out_of_attempts or (logic_passed and efficiency_met)
103  
104          return EngObservation(
105              domain=self.task_domain,
106              difficulty=self.difficulty,
107              task=self.task_description,
108              code=code,
109              done=done,
110              reward=reward,
111              output=g_result['error_message'],
112              tests_passed=g_result['tests_passed'],
113              num_tests=g_result['total_tests'],
114              time_taken=g_result['efficiency'].get('runtime_ms'),
115              mem_taken=g_result['efficiency'].get('memory_mib'),
116              message=self._build_message(g_result, reward, done),
117              num_steps_remain=self.attempts_remaining
118          )
119  
120      @property
121      def state(self) -> EngState:
122          return self._state
123  
124      def list_tasks(self) -> list:
125          all_tasks = []
126          for level in ['EASY', 'MEDIUM', 'HARD']:
127              for task in TASKS[level]:
128                  all_tasks.append({
129                      "id": task["id"],
130                      "domain": task["domain"],
131                      "difficulty": level,
132                      "name": task["name"],
133                      "description": task["description"]
134                  })
135          return all_tasks
136  
137      def get_grader_score(self) -> dict:
138          if not self.last_grader_result:
139              return {"grader_score": 0.01, "status": "no_attempt"}
140          return {
141              "task_id": self.task_id,
142              "grader_score": self.last_grader_result.get("grader_score", 0.01),
143              "status": self.last_grader_result.get("status"),
144              "tests_passed": self.last_grader_result.get("tests_passed", 0),
145              "total_tests": self.last_grader_result.get("total_tests", 0),
146          }
147  
148      def _build_message(self, g_result: dict, reward: float, done: bool) -> str:
149          status = g_result['status']
150          passed = g_result['tests_passed']
151          total = g_result['total_tests']
152          eff = g_result['efficiency']
153  
154          if status == 'syntax_error':
155              return f"Syntax error: {g_result['error_message']}"
156          if status == 'runtime_error':
157              return f"Runtime error: {g_result['error_message']}"
158          if status == 'logic_error':
159              return f"Code runs but fails all tests (0/{total}). Check your logic.Error: {g_result['error_message']}. {self.attempts_remaining} attempts left."
160          if status == 'partial':
161              error = g_result.get('error_message', '')
162              return f"Partial: {passed}/{total} tests passed. Reward: {reward}. Error: {error}. {self.attempts_remaining} attempts left."
163          if status == 'success' and not done:
164              sr = eff.get('speed_ratio')
165              mr = eff.get('memory_ratio')
166              if self.difficulty == 'HARD':
167                  return f"All tests pass! Optimize further. Speed: {sr:.2f}x baseline. Memory: {mr:.2f}x baseline."
168              return f"All tests passed! Reward: {reward}"
169          if done and status == 'success':
170              sr = eff.get('speed_ratio') or 0.01
171              mr = eff.get('memory_ratio') or 0.01
172              return f"Episode complete! Reward: {reward}. Speed: {sr:.2f}x. Memory: {mr:.2f}x."
173          if done and status != 'success':
174              return f"Out of attempts. Best: {passed}/{total} tests."
175          return f"Reward: {reward}. {self.attempts_remaining} attempts remaining."