/ grader.py
grader.py
  1  from task_loader import TESTS
  2  import math
  3  import ast
  4  from sandbox import sandbox
  5  
  6  
  7  def grader(code: str, task_id: str) -> dict:
  8      grader_score = 0.02
  9      task_tests = TESTS[task_id]
 10      baseline_time = task_tests["baseline_time"]
 11      baseline_mem = task_tests["baseline_mem"]
 12      total_tests = len(task_tests["cases"])
 13  
 14      # syntax check
 15      syntax_ok, syntax_error = check_syntax(code)
 16      if not syntax_ok:
 17          return {
 18              "status": "syntax_error",
 19              "grader_score": 0.02,
 20              "tests_passed": 0,
 21              "total_tests": total_tests,
 22              "efficiency": {
 23                  "runtime_ms": None,
 24                  "baseline_ms": baseline_time,
 25                  "memory_mib": None,
 26                  "baseline_mib": baseline_mem,
 27                  "speed_ratio": None,   # agent_time / baseline
 28                  "memory_ratio": None,  # agent_mem / baseline
 29              },
 30              "error_message": syntax_error
 31          }
 32  
 33      # sandbox execution and  logic check
 34      exec_ok, output, time_ms, mem_mib, tests_passed = sandbox(code, task_id)
 35  
 36      if not exec_ok:
 37          status = "runtime_error"
 38          error_message = output
 39      elif tests_passed == total_tests:
 40          status = "success"
 41          error_message = None
 42      elif tests_passed > 0:
 43          status = "partial"        
 44          error_message = output
 45      else:
 46          status = "logic_error"
 47          error_message = output
 48  
 49      # Speed and memory ratios — key for reward tradeoff
 50      speed_ratio = (baseline_time / time_ms) if time_ms and time_ms > 0 else None
 51      memory_ratio = (baseline_mem / mem_mib) if mem_mib and mem_mib > 0 else None
 52      # ratio > 1.0 = better than baseline
 53      # ratio < 1.0 = worse than baseline
 54  
 55      if status == "syntax_error" or status == "runtime_error":
 56          grader_score = 0.02
 57      elif status == "logic_error":
 58          grader_score = 0.02
 59      elif status == "partial":
 60          raw = 0.6 * (tests_passed / total_tests)
 61          grader_score = round(max(0.01, min(raw, 0.99)), 3)
 62      elif status == "success":
 63          base = 0.6
 64          if speed_ratio and speed_ratio >= 1.0:
 65              base += min(0.25, 0.25 * (speed_ratio - 1.0))
 66          if memory_ratio and memory_ratio >= 1.0:
 67              base += min(0.15, 0.15 * (memory_ratio - 1.0))
 68          grader_score = round(min(base, 0.98), 3)
 69  
 70      # At the very end, clamp everything:
 71      grader_score = max(0.02, min(grader_score, 0.98))
 72      grader_score = float(grader_score)
 73      if grader_score >= 1.0:
 74          grader_score = 0.99
 75      if grader_score <= 0.0:
 76          grader_score = 0.02
 77  
 78      grader_score = _safe_score(grader_score)
 79  
 80      return {
 81          "status": status,
 82          "grader_score":grader_score,
 83          "tests_passed": tests_passed,
 84          "total_tests": total_tests,
 85          "efficiency": {
 86              "runtime_ms": time_ms,
 87              "baseline_ms": baseline_time,
 88              "memory_mib": mem_mib,
 89              "baseline_mib": baseline_mem,
 90              "speed_ratio": speed_ratio,
 91              "memory_ratio": memory_ratio,
 92          },
 93          "error_message": error_message
 94      }
 95  
 96  def compare_results(actual, expected):
 97      import math
 98      
 99      # Float comparison
100      if isinstance(expected, float) and isinstance(actual, float):
101          return math.isclose(actual, expected, rel_tol=1e-3)
102  
103      # Int
104      if isinstance(expected, int) and isinstance(actual, int):
105          return actual == expected
106  
107      # String
108      if isinstance(expected, str) and isinstance(actual, str):
109          return actual == expected
110  
111      # List
112      if isinstance(expected, list) and isinstance(actual, list):
113          if len(expected) != len(actual):
114              return False
115          for a, e in zip(actual, expected):
116              if isinstance(e, float):
117                  if not math.isclose(a, e, rel_tol=1e-3):
118                      return False
119              else:
120                  if a != e:
121                      return False
122          return True
123  
124      # Dict
125      if isinstance(expected, dict) and isinstance(actual, dict):
126          return expected == actual
127  
128      # Numpy arrays
129      try:
130          import numpy as np
131          if isinstance(actual, np.ndarray) or isinstance(expected, np.ndarray):
132              return np.allclose(actual, expected, rtol=1e-3)
133      except:
134          pass
135  
136      # Pandas
137      try:
138          import pandas as pd
139          if isinstance(actual, pd.Series) or isinstance(expected, pd.Series):
140              return actual.equals(expected)
141          if isinstance(actual, pd.DataFrame) or isinstance(expected, pd.DataFrame):
142              return actual.equals(expected)
143      except:
144          pass
145  
146      # Torch tensors
147      try:
148          import torch
149          if isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
150              return torch.allclose(actual, expected, rtol=1e-3)
151      except:
152          pass
153  
154      # Fallback
155      return actual == expected
156  
157  
def check_syntax(code: str) -> tuple[bool, str]:
    """Parse *code* without running it and report whether it is valid Python.

    Returns:
        ``(True, "")`` when ``ast.parse`` accepts the source; otherwise
        ``(False, message)``. Syntax errors (IndentationError included, being
        a SyntaxError subclass) are reported as
        ``"SyntaxError at line <n>: <msg>"``; any other exception raised while
        parsing is reported as its ``str()``.
    """
    try:
        ast.parse(code)
    except SyntaxError as exc:
        return False, f"SyntaxError at line {exc.lineno}: {exc.msg}"
    except Exception as exc:  # any other parser failure is also a rejection
        return False, str(exc)
    return True, ""
166      
167  
def _safe_score(score) -> float:
    """Coerce *score* to a built-in float strictly between 0 and 1.

    Values already inside the open interval (0, 1) are returned unchanged.
    Values at or below 0 map to 0.02; values at or above 1 map to 0.99.
    NaN fails every ordered comparison and would otherwise pass straight
    through; it is mapped to the floor value 0.02.

    Any TypeError / ValueError raised by ``float(score)`` propagates.
    """
    s = float(score)
    if math.isnan(s):
        return 0.02
    if 0.0 < s < 1.0:
        return s
    return max(0.02, min(s, 0.99))