# grader.py — scoring pipeline for sandboxed code submissions
from task_loader import TESTS
import math
import ast
from sandbox import sandbox


def grader(code: str, task_id: str) -> dict:
    """Grade a code submission against one task.

    Pipeline: syntax check -> sandboxed execution against the task's test
    cases -> efficiency-weighted score.  The result dict always has the same
    shape regardless of outcome.

    Args:
        code: Source text of the submission.
        task_id: Key into TESTS selecting the task's cases and baselines.

    Returns:
        dict with keys: status ("syntax_error" | "runtime_error" |
        "logic_error" | "partial" | "success"), grader_score (float in
        [0.02, 0.98]), tests_passed, total_tests, efficiency (runtime/memory
        plus baseline-relative ratios), error_message (None on success).
    """
    task_tests = TESTS[task_id]
    baseline_time = task_tests["baseline_time"]
    baseline_mem = task_tests["baseline_mem"]
    total_tests = len(task_tests["cases"])

    # Syntax check first so we never spend sandbox time on unparseable code.
    syntax_ok, syntax_error = check_syntax(code)
    if not syntax_ok:
        return {
            "status": "syntax_error",
            "grader_score": 0.02,
            "tests_passed": 0,
            "total_tests": total_tests,
            "efficiency": {
                "runtime_ms": None,
                "baseline_ms": baseline_time,
                "memory_mib": None,
                "baseline_mib": baseline_mem,
                "speed_ratio": None,   # baseline_ms / runtime_ms when available
                "memory_ratio": None,  # baseline_mib / memory_mib when available
            },
            "error_message": syntax_error,
        }

    # Sandbox execution and logic check.
    exec_ok, output, time_ms, mem_mib, tests_passed = sandbox(code, task_id)

    if not exec_ok:
        status = "runtime_error"
        error_message = output
    elif tests_passed == total_tests:
        status = "success"
        error_message = None
    elif tests_passed > 0:
        status = "partial"
        error_message = output
    else:
        status = "logic_error"
        error_message = output

    # Efficiency ratios — key for the reward tradeoff.  Both are
    # baseline / agent, so ratio > 1.0 means better than baseline and
    # ratio < 1.0 means worse.  (The original inline comments said
    # "agent / baseline", contradicting the formulas below; the formulas
    # are what the success-scoring branch relies on, so they win.)
    speed_ratio = (baseline_time / time_ms) if time_ms and time_ms > 0 else None
    memory_ratio = (baseline_mem / mem_mib) if mem_mib and mem_mib > 0 else None

    # Score by status.  "syntax_error" cannot reach this point (early return
    # above), so the original's dead branch for it is folded away.
    if status in ("runtime_error", "logic_error"):
        grader_score = 0.02
    elif status == "partial":
        # Proportional credit, capped below the full-success base of 0.6.
        raw = 0.6 * (tests_passed / total_tests)
        grader_score = round(max(0.01, min(raw, 0.99)), 3)
    else:  # success: 0.6 base, up to +0.25 for speed and +0.15 for memory
        base = 0.6
        if speed_ratio and speed_ratio >= 1.0:
            base += min(0.25, 0.25 * (speed_ratio - 1.0))
        if memory_ratio and memory_ratio >= 1.0:
            base += min(0.15, 0.15 * (memory_ratio - 1.0))
        grader_score = round(min(base, 0.98), 3)

    # Final clamp to [0.02, 0.98] plus the shared safety net.  (The
    # original's follow-up `>= 1.0` / `<= 0.0` checks were unreachable
    # after this clamp and have been removed.)
    grader_score = _safe_score(float(max(0.02, min(grader_score, 0.98))))

    return {
        "status": status,
        "grader_score": grader_score,
        "tests_passed": tests_passed,
        "total_tests": total_tests,
        "efficiency": {
            "runtime_ms": time_ms,
            "baseline_ms": baseline_time,
            "memory_mib": mem_mib,
            "baseline_mib": baseline_mem,
            "speed_ratio": speed_ratio,
            "memory_ratio": memory_ratio,
        },
        "error_message": error_message,
    }


def compare_results(actual, expected):
    """Return True when *actual* matches *expected*.

    Numeric comparisons use a 1e-3 relative tolerance whenever a float is
    involved; lists and dicts are compared element-wise recursively so
    nested floats also get the tolerance.  numpy / pandas / torch values
    are handled only when those libraries are importable.
    """
    # bool is a subclass of int: compare bools strictly before the numeric
    # branch so the intent stays explicit.
    if isinstance(expected, bool) and isinstance(actual, bool):
        return actual == expected

    # Numbers.  The original only applied the tolerance when BOTH sides
    # were float, so mixed int/float pairs fell through to exact equality;
    # apply the tolerance whenever either side is a float.
    if isinstance(expected, (int, float)) and isinstance(actual, (int, float)):
        if isinstance(expected, float) or isinstance(actual, float):
            return math.isclose(actual, expected, rel_tol=1e-3)
        return actual == expected

    if isinstance(expected, str) and isinstance(actual, str):
        return actual == expected

    # Lists: element-wise recursion (the original only tolerated floats at
    # the top level; nested containers were compared exactly).
    if isinstance(expected, list) and isinstance(actual, list):
        return len(expected) == len(actual) and all(
            compare_results(a, e) for a, e in zip(actual, expected)
        )

    # Dicts: same keys, values compared recursively (the original used
    # plain ==, losing the float tolerance for dict values).
    if isinstance(expected, dict) and isinstance(actual, dict):
        return expected.keys() == actual.keys() and all(
            compare_results(actual[k], expected[k]) for k in expected
        )

    # Optional third-party types.  ImportError is the only thing silently
    # skipped (the originals used bare `except:`); comparison failures for
    # incompatible operands now return False instead of falling through to
    # a fallback `==` that could yield a non-bool (e.g. list == Series).
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and (
        isinstance(actual, np.ndarray) or isinstance(expected, np.ndarray)
    ):
        try:
            return bool(np.allclose(actual, expected, rtol=1e-3))
        except (TypeError, ValueError):
            return False  # incomparable shapes or dtypes

    try:
        import pandas as pd
    except ImportError:
        pd = None
    if pd is not None and (
        isinstance(actual, (pd.Series, pd.DataFrame))
        or isinstance(expected, (pd.Series, pd.DataFrame))
    ):
        # The original called actual.equals(...) even when actual was not a
        # pandas object, raising AttributeError into a bare except; require
        # matching types instead.
        return type(actual) is type(expected) and actual.equals(expected)

    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and (
        isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor)
    ):
        try:
            return bool(torch.allclose(actual, expected, rtol=1e-3))
        except (TypeError, RuntimeError):
            return False  # non-tensor operand or shape/dtype mismatch

    # Fallback for anything else.
    return actual == expected


def check_syntax(code: str) -> tuple[bool, str]:
    """Parse *code* with ast; return (ok, error_message)."""
    try:
        ast.parse(code)
        return True, ""
    except SyntaxError as e:
        return False, f"SyntaxError at line {e.lineno}: {e.msg}"
    except Exception as e:
        # ast.parse can also raise e.g. ValueError on null bytes.
        return False, str(e)


def _safe_score(score) -> float:
    """Ensure score is strictly between 0 and 1, pure Python float.

    Equivalent to the original: values in (0, 1) pass through unchanged
    (including sub-0.02 values), while out-of-range values are pinned to
    0.02 / 0.99.  The original's post-clamp `s == 0.0` / `s == 1.0`
    checks were unreachable and have been removed.
    """
    s = float(score)
    if s <= 0.0:
        return 0.02
    if s >= 1.0:
        return 0.99
    return s