1 """ 2 inference.py — AE² Baseline Inference Script 3 ============================================ 4 Required env vars: 5 API_BASE_URL LLM API endpoint 6 MODEL_NAME Model identifier 7 HF_TOKEN HuggingFace / API key 8 AE2_URL Your HF Space URL (e.g. https://username-ae2-env.hf.space) 9 """ 10 11 import sys 12 import os 13 import warnings 14 import logging 15 import re 16 import json 17 18 # Silence all warnings to stderr only 19 warnings.filterwarnings("ignore") 20 logging.getLogger("websockets").setLevel(logging.CRITICAL) 21 logging.getLogger("websockets.legacy").setLevel(logging.CRITICAL) 22 logging.basicConfig(stream=sys.stderr, level=logging.CRITICAL) 23 24 from openai import OpenAI 25 from models import EngAction 26 from client import AE2Env 27 import requests 28 import time as _time 29 30 31 # ── Config ────────────────────────────────────────────────────────────────── 32 API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") 33 API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") 34 MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") 35 AE2_URL = os.getenv("AE2_URL", "http://localhost:7860") 36 ENV_NAME = "ae2-applied-ai-engineering" 37 38 MAX_STEPS = 10 39 TEMPERATURE = 0.1 40 MAX_TOKENS = 1024 41 42 43 44 def log_start(task: str, model: str) -> None: 45 # Ensure ONLY the [START] line goes to stdout 46 sys.stdout.write(f"[START] task={task} env={ENV_NAME} model={model}\n") 47 sys.stdout.flush() 48 49 def log_step(step: int, action: str, reward: float, done: bool, error=None) -> None: 50 # FORCE the reward to be valid, no matter what the environment says 51 safe_reward = max(0.05, min(float(reward), 0.95)) 52 53 d_str = "true" if done else "false" 54 r_str = f"{safe_reward:.2f}" 55 a_clean = str(action).replace('\n', ' ').strip()[:50] 56 e_clean = "null" if not error else str(error).replace('\n', ' ').strip() 57 58 sys.stdout.write(f"[STEP] step={step} action={a_clean} reward={r_str} done={d_str} error={e_clean}\n") 59 sys.stdout.flush() 60 61 def log_end(success: bool, steps: int, rewards: list) -> None: 62 success_str = "true" if success else "false" 63 rewards_formatted = ",".join([f"{float(r):.2f}" for r in rewards]) 64 score = sum(rewards) / len(rewards) if rewards else 0.5 65 score = max(0.01, min(score, 0.99)) 66 sys.stdout.write(f"[END] success={success_str} steps={steps} score={score:.2f} rewards={rewards_formatted}\n") 67 sys.stdout.flush() 68 69 70 71 # ── System prompt ──────────────────────────────────────────────────────────── 72 SYSTEM_PROMPT = """ 73 You are an expert AI/ML Software Engineer. 74 You will be given a broken Python function and a task description. 75 Your job is to fix or optimize the code. 76 77 RULES: 78 1. Always return ONLY the complete fixed Python function — no explanation, no markdown,only python function. 79 2. The function MUST be named 'solution'. 80 3. Do not add any imports outside the function body unless they were already present. 81 4. Do not use os, subprocess, sys, socket, or any file I/O. 82 5. Keep the same function signature as the broken code. 83 6. 
6. Keep helper functions.

Example response format:
def solution(x, y):
    return x + y
"""


# ── Helpers ──────────────────────────────────────────────────────────────────
def extract_code(response_text: str) -> str:
    """Extract Python code from an LLM response."""
    # Try a ```python fenced block first
    match = re.search(r"```python\n(.*?)```", response_text, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Then try a plain fenced block
    match = re.search(r"```\n(.*?)```", response_text, re.DOTALL)
    if match:
        return match.group(1).strip()

    # No code block: assume the entire response is code
    return response_text.strip()


def print_summary(results, scores_by_difficulty):
    print(f"\n{'='*60}")
    print("BASELINE RESULTS SUMMARY")
    print(f"{'='*60}")
    print(f"\n{'Domain':<20} {'Diff':<8} {'Score':<8} {'Tests':<10} {'Steps'}")
    print("-" * 60)
    for r in results:
        if "error" not in r:
            print(
                f"{r.get('domain', '?'):<20} "
                f"{r.get('difficulty', '?'):<8} "
                f"{r.get('grader_score', 0.02):<8.3f} "
                f"{str(r.get('tests_passed', 0)) + '/' + str(r.get('total_tests', 0)):<10} "
                f"{r.get('steps_taken', 0)}"
            )
    print(f"\n{'Difficulty':<12} {'Avg Score':<12} {'Tasks'}")
    print("-" * 35)
    for diff, scores in scores_by_difficulty.items():
        if scores:
            avg = sum(scores) / len(scores)
            print(f"{diff:<12} {avg:<12.3f} {len(scores)}")
    overall = [r["grader_score"] for r in results if "error" not in r]
    if overall:
        print(f"\nOVERALL BASELINE SCORE: {sum(overall)/len(overall):.3f}")
    with open("baseline_results.json", "w") as f:
        json.dump({
            "model": MODEL_NAME,
            "results": results,
            "summary": {d: sum(s) / len(s) if s else 0.01 for d, s in scores_by_difficulty.items()},
            "overall": sum(overall) / len(overall) if overall else 0.02
        }, f, indent=2)
    print("\nSaved to baseline_results.json")


def build_user_prompt(observation, step: int, history: list) -> str:
    history_text = "\n".join(history[-3:]) if history else "None"

    return f"""
TASK: {observation.task}

CURRENT CODE (fix or optimize this):
```python
{observation.code}
```

FEEDBACK FROM LAST ATTEMPT:
{observation.message or 'No previous attempt.'}

OUTPUT/ERROR FROM LAST RUN:
{observation.output or 'No output yet.'}

TESTS PASSED: {observation.tests_passed}/{observation.num_tests if observation.num_tests else '?'}
EXECUTION TIME: {f"{observation.time_taken:.2f}ms" if observation.time_taken else "N/A"}
MEMORY USED: {f"{observation.mem_taken:.2f}MiB" if observation.mem_taken else "N/A"}
ATTEMPTS REMAINING: {observation.num_steps_remain}

STEP HISTORY:
{history_text}

Return ONLY the fixed Python function named 'solution'. No explanation. No markdown.
""".strip()
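
# A tiny self-check sketch for extract_code (not part of the benchmark run; the
# AE2_SELFTEST flag and the sample strings below are made up for illustration).
# It exercises the three paths: ```python fence, bare fence, and raw text.
if os.getenv("AE2_SELFTEST"):
    _fenced = "```python\ndef solution(x):\n    return x\n```"
    _bare = "```\ndef solution(x):\n    return x\n```"
    assert extract_code(_fenced).startswith("def solution")
    assert extract_code(_bare).startswith("def solution")
    assert extract_code("def solution(x): return x") == "def solution(x): return x"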
166 """.strip() 167 168 169 def run_episode(client, env, task_id: str = None) -> dict: 170 result = env.reset(task_id=task_id) 171 observation = result.observation 172 173 # print(f"\n{'='*60}") 174 # print(f"Task: {observation.task[:80]}...") 175 # print(f"Domain: {observation.domain} | Difficulty: {observation.difficulty}") 176 # print(f"{'='*60}") 177 178 # Mandatory [START] log 179 log_start(task=task_id or observation.domain, model=MODEL_NAME) 180 181 history = [] 182 final_reward = 0.01 183 steps_taken = 0 184 rewards = [] 185 186 for step in range(1, MAX_STEPS + 1): 187 if result.done: 188 # print(f"Episode done at step {step-1}.") 189 break 190 191 steps_taken = step 192 user_prompt = build_user_prompt(observation, step, history) 193 messages = [ 194 {"role": "system", "content": SYSTEM_PROMPT}, 195 {"role": "user", "content": user_prompt} 196 ] 197 198 try: 199 completion = client.chat.completions.create( 200 model=MODEL_NAME, 201 messages=messages, 202 temperature=TEMPERATURE, 203 max_tokens=MAX_TOKENS, 204 stream=False, 205 ) 206 response_text = completion.choices[0].message.content or "" 207 except Exception as exc: 208 if "429" in str(exc): 209 print(f" Rate limited, waiting 20s...") 210 _time.sleep(20) 211 try: 212 completion = client.chat.completions.create( 213 model=MODEL_NAME, 214 messages=messages, 215 temperature=TEMPERATURE, 216 max_tokens=MAX_TOKENS, 217 ) 218 response_text = completion.choices[0].message.content or "" 219 except Exception: 220 response_text = observation.code 221 else: 222 print(f" LLM call failed: {exc}") 223 response_text = observation.code 224 225 code = extract_code(response_text) 226 # print(f"\nStep {step}: {code[:60].replace(chr(10), ' ')}...") 227 228 action = EngAction(sol=code) 229 result = env.step(action) 230 observation = result.observation 231 232 reward = result.reward or 0.01 233 final_reward = reward 234 rewards.append(reward) 235 236 error_msg = observation.output if not result.done and observation.tests_passed == 0 else None 237 238 # Mandatory [STEP] log 239 log_step(step=step, action=code[:50], reward=reward, done=result.done, error=error_msg) 240 241 if observation.time_taken: 242 history_line = f"Step {step}: tests={observation.tests_passed}/{observation.num_tests} reward={reward:+.3f} time={observation.time_taken:.1f}ms" 243 else: 244 history_line = f"Step {step}: reward={reward:+.3f}" 245 history.append(history_line) 246 247 # print(f" Reward: {reward:+.3f} | Tests: {observation.tests_passed}/{observation.num_tests} | Done: {result.done}") 248 249 if result.done: 250 # print(f" {observation.message}") 251 break 252 253 success = result.done and (observation.tests_passed == observation.num_tests) 254 255 # Mandatory [END] log 256 log_end(success=success, steps=steps_taken, rewards=rewards) 257 258 # Get grader score via HTTP 259 # Inside run_episode() 260 try: 261 # Use requests.post and 'json=' parameter 262 resp = requests.post( 263 f"{AE2_URL}/grader", 264 json={"task_id": task_id, "code": observation.code} 265 ) 266 267 if resp.status_code == 200: 268 data = resp.json() 269 final_grader_score = data.get("grader_score", 0.02) 270 else: 271 # This is where your 405 error is currently being caught 272 print(f" Grader error: {resp.status_code}") 273 final_grader_score = 0.02 274 except Exception as e: 275 print(f" Grader call failed: {e}") 276 final_grader_score = 0.02 277 278 return { 279 "task": observation.task, 280 "domain": observation.domain, 281 "difficulty": observation.difficulty, 282 "steps_taken": steps_taken, 283 
"final_reward": final_reward, 284 "grader_score": final_grader_score, 285 "tests_passed": observation.tests_passed, 286 "total_tests": observation.num_tests, 287 "success": success 288 } 289 290 291 def main(): 292 client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) 293 warnings.filterwarnings("ignore") 294 os.environ["PYTHONWARNINGS"] = "ignore" 295 296 # print("AE² — Applied AI Engineering Environment") 297 # print(f"Model: {MODEL_NAME}") 298 # print(f"Environment: {AE2_URL}") 299 300 results = [] 301 scores_by_difficulty = {"EASY": [], "MEDIUM": [], "HARD": []} 302 303 # Get task list via HTTP (this is fine, stateless is ok for /tasks) 304 tasks_resp = requests.get(f"{AE2_URL}/tasks") 305 tasks = tasks_resp.json().get("tasks", []) 306 307 # Each task gets its OWN WebSocket connection = its own stateful session 308 for task_info in tasks: 309 task_id = task_info["id"] 310 311 try: 312 # New WebSocket connection per task = clean stateful episode 313 with AE2Env(base_url=AE2_URL).sync() as env: 314 episode_result = run_episode(client, env, task_id=task_id) 315 results.append(episode_result) 316 difficulty = episode_result["difficulty"] 317 scores_by_difficulty[difficulty].append(episode_result["grader_score"]) 318 319 except Exception as e: 320 print(f"Episode failed for {task_id}: {e}") 321 _time.sleep(3) # give server time to release connection 322 323 results.append({ 324 "task_id": task_id, 325 "grader_score": 0.02, 326 "difficulty": task_info.get("difficulty", "EASY"), 327 "domain": task_info.get("domain", ""), 328 "error": str(e) 329 }) 330 331 # Print and save results (same as before) 332 # print_summary(results, scores_by_difficulty) 333 334 335 if __name__ == "__main__": 336 main()