#!/usr/bin/env python3
"""
Batch Agent Runner

This module provides parallel batch processing capabilities for running the agent
across multiple prompts from a dataset. It includes:
- Dataset loading and batching
- Parallel batch processing with multiprocessing
- Checkpointing for fault tolerance and resumption
- Trajectory saving in the proper format (from/value pairs)
- Tool usage statistics aggregation across all batches

Usage:
    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run

    # Resume an interrupted run
    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume

    # Use a specific toolset distribution
    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
"""

import json
import logging
import os
import time
import traceback
from datetime import datetime
from multiprocessing import Pool, Lock
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

import fire
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console

from run_agent import AIAgent
from toolset_distributions import (
    list_distributions,
    sample_toolsets_from_distribution,
    validate_distribution
)
from model_tools import TOOL_TO_TOOLSET_MAP

logger = logging.getLogger(__name__)


# Global configuration for worker processes
_WORKER_CONFIG = {}

# All possible tools - auto-derived from the master mapping in model_tools.py.
# This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP.
# Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for
# filtering corrupted entries during trajectory combination.
ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())

# Default stats for tools that weren't used
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}


def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
    """
    Normalize tool_stats to include all possible tools with consistent schema.

    This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
    Tools that weren't used get zero counts.

    Args:
        tool_stats (Dict): Raw tool statistics from extraction

    Returns:
        Dict: Normalized tool statistics with all tools present
    """
    normalized = {}

    # Add all possible tools with defaults
    for tool in ALL_POSSIBLE_TOOLS:
        if tool in tool_stats:
            normalized[tool] = tool_stats[tool].copy()
        else:
            normalized[tool] = DEFAULT_TOOL_STATS.copy()

    # Also include any unexpected tools (in case new tools are added)
    for tool, stats in tool_stats.items():
        if tool not in normalized:
            normalized[tool] = stats.copy()

    return normalized
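
# Illustrative sketch (comments only, not executed). Assuming "terminal" and
# "web_search" are keys in TOOL_TO_TOOLSET_MAP, normalization zero-fills every
# tool that was never called, so all rows share one schema:
#
#   raw = {"terminal": {"count": 3, "success": 2, "failure": 1}}
#   normalized = _normalize_tool_stats(raw)
#   # normalized["terminal"]   -> {"count": 3, "success": 2, "failure": 1}
#   # normalized["web_search"] -> {"count": 0, "success": 0, "failure": 0}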


def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
    """
    Normalize tool_error_counts to include all possible tools.

    Args:
        tool_error_counts (Dict): Raw error counts mapping

    Returns:
        Dict: Normalized error counts with all tools present
    """
    normalized = {}

    # Add all possible tools with zero defaults
    for tool in ALL_POSSIBLE_TOOLS:
        normalized[tool] = tool_error_counts.get(tool, 0)

    # Also include any unexpected tools
    for tool, count in tool_error_counts.items():
        if tool not in normalized:
            normalized[tool] = count

    return normalized


def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
    """
    Extract tool usage statistics from message history.

    Args:
        messages (List[Dict]): Message history

    Returns:
        Dict: Tool statistics with counts and success/failure rates
    """
    tool_stats = {}

    # Track tool calls and their results
    tool_calls_map = {}  # Map tool_call_id to tool name

    for msg in messages:
        # Track tool calls from assistant messages
        if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
            for tool_call in msg["tool_calls"]:
                if not tool_call or not isinstance(tool_call, dict):
                    continue
                tool_name = tool_call["function"]["name"]
                tool_call_id = tool_call["id"]

                # Initialize stats for this tool if not exists
                if tool_name not in tool_stats:
                    tool_stats[tool_name] = {
                        "count": 0,
                        "success": 0,
                        "failure": 0
                    }

                tool_stats[tool_name]["count"] += 1
                tool_calls_map[tool_call_id] = tool_name

        # Track tool responses
        elif msg["role"] == "tool":
            tool_call_id = msg.get("tool_call_id", "")
            content = msg.get("content", "")

            # Determine if tool call was successful
            is_success = True
            try:
                # Try to parse as JSON and check for actual error values
                content_json = json.loads(content) if isinstance(content, str) else content

                if isinstance(content_json, dict):
                    # Check if error field exists AND has a non-null value
                    if "error" in content_json and content_json["error"] is not None:
                        is_success = False

                    # Special handling for terminal tool responses
                    # Terminal wraps its response in a "content" field
                    if "content" in content_json and isinstance(content_json["content"], dict):
                        inner_content = content_json["content"]
                        # Check for actual error (non-null error field)
                        # Note: non-zero exit codes are not failures - the model can self-correct
                        if inner_content.get("error") is not None:
                            is_success = False

                    # Check for "success": false pattern used by some tools
                    if content_json.get("success") is False:
                        is_success = False

            except (json.JSONDecodeError, ValueError, TypeError):
                # If not JSON, check if content is empty or explicitly states an error
                # Note: We avoid simple substring matching to prevent false positives
                if not content:
                    is_success = False
                # Only mark as failure if it explicitly starts with "Error:" or "ERROR:"
                elif content.strip().lower().startswith("error:"):
                    is_success = False

            # Update success/failure count
            if tool_call_id in tool_calls_map:
                tool_name = tool_calls_map[tool_call_id]
                if is_success:
                    tool_stats[tool_name]["success"] += 1
                else:
                    tool_stats[tool_name]["failure"] += 1

    return tool_stats
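
# Illustrative sketch (comments only, not executed): one assistant tool call plus
# its tool response. The response parses as JSON with a null "error", so it counts
# as a success ("terminal" here assumes such a tool exists in the toolset):
#
#   messages = [
#       {"role": "assistant", "content": "", "tool_calls": [
#           {"id": "call_1", "function": {"name": "terminal", "arguments": "{}"}}]},
#       {"role": "tool", "tool_call_id": "call_1",
#        "content": '{"content": {"stdout": "ok", "error": null}}'},
#   ]
#   _extract_tool_stats(messages)
#   # -> {"terminal": {"count": 1, "success": 1, "failure": 0}}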


def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
    """
    Count how many assistant turns have reasoning vs no reasoning.

    Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field
    (native thinking tokens). Returns counts for tracking reasoning coverage.

    Args:
        messages: Message history

    Returns:
        Dict with 'total_assistant_turns', 'turns_with_reasoning',
        'turns_without_reasoning', and 'has_any_reasoning'
    """
    total = 0
    with_reasoning = 0

    for msg in messages:
        if msg.get("role") != "assistant":
            continue
        total += 1

        content = msg.get("content", "") or ""
        has_scratchpad = "<REASONING_SCRATCHPAD>" in content
        has_native_reasoning = bool((msg.get("reasoning") or "").strip())

        if has_scratchpad or has_native_reasoning:
            with_reasoning += 1

    return {
        "total_assistant_turns": total,
        "turns_with_reasoning": with_reasoning,
        "turns_without_reasoning": total - with_reasoning,
        "has_any_reasoning": with_reasoning > 0,
    }
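
# Illustrative sketch (comments only, not executed): two assistant turns, one with
# a scratchpad and one with neither a scratchpad nor native thinking tokens:
#
#   messages = [
#       {"role": "assistant", "content": "<REASONING_SCRATCHPAD>plan</REASONING_SCRATCHPAD> done"},
#       {"role": "user", "content": "next"},
#       {"role": "assistant", "content": "plain answer"},
#   ]
#   _extract_reasoning_stats(messages)
#   # -> {"total_assistant_turns": 2, "turns_with_reasoning": 1,
#   #     "turns_without_reasoning": 1, "has_any_reasoning": True}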


def _process_single_prompt(
    prompt_index: int,
    prompt_data: Dict[str, Any],
    batch_num: int,
    config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Process a single prompt with the agent.

    Args:
        prompt_index (int): Index of prompt in dataset
        prompt_data (Dict): Prompt data containing 'prompt' field and optional 'image' field
        batch_num (int): Batch number
        config (Dict): Configuration dict with agent parameters

    Returns:
        Dict: Result containing trajectory, stats, and metadata
    """
    prompt = prompt_data["prompt"]
    task_id = f"task_{prompt_index}"

    # Per-prompt container image override: if the dataset row has an 'image' field,
    # register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona.
    container_image = prompt_data.get("image") or prompt_data.get("docker_image")
    if container_image:
        # Verify the image is accessible before spending tokens on the agent loop.
        # For Docker: check local cache, then try pulling.
        # For Modal: skip local check (Modal pulls server-side).
        env_type = os.getenv("TERMINAL_ENV", "local")
        if env_type == "docker":
            import subprocess as _sp
            try:
                probe = _sp.run(
                    ["docker", "image", "inspect", container_image],
                    capture_output=True, timeout=10,
                )
                if probe.returncode != 0:
                    if config.get("verbose"):
                        print(f"   Prompt {prompt_index}: Pulling docker image {container_image}...", flush=True)
                    pull = _sp.run(
                        ["docker", "pull", container_image],
                        capture_output=True, text=True, timeout=600,
                    )
                    if pull.returncode != 0:
                        return {
                            "success": False,
                            "prompt_index": prompt_index,
                            "error": f"Docker image not available: {container_image}\n{pull.stderr[:500]}",
                            "trajectory": None,
                            "tool_stats": {},
                            "toolsets_used": [],
                            "metadata": {"batch_num": batch_num, "timestamp": datetime.now().isoformat()},
                        }
            except FileNotFoundError:
                pass  # Docker CLI not installed -> skip check (e.g., Modal backend)
            except Exception as img_err:
                if config.get("verbose"):
                    print(f"   Prompt {prompt_index}: Docker image check failed: {img_err}", flush=True)

        from tools.terminal_tool import register_task_env_overrides
        overrides = {
            "docker_image": container_image,
            "modal_image": container_image,
            "singularity_image": f"docker://{container_image}",
            "daytona_image": container_image,
        }
        if prompt_data.get("cwd"):
            overrides["cwd"] = prompt_data["cwd"]
        register_task_env_overrides(task_id, overrides)
        if config.get("verbose"):
            print(f"   Prompt {prompt_index}: Using container image {container_image}")

    try:
        # Sample toolsets from distribution for this prompt
        selected_toolsets = sample_toolsets_from_distribution(config["distribution"])

        if config.get("verbose"):
            print(f"   Prompt {prompt_index}: Using toolsets {selected_toolsets}")

        # Initialize agent with sampled toolsets and log prefix for identification
        log_prefix = f"[B{batch_num}:P{prompt_index}]"
        agent = AIAgent(
            base_url=config.get("base_url"),
            api_key=config.get("api_key"),
            model=config["model"],
            max_iterations=config["max_iterations"],
            enabled_toolsets=selected_toolsets,
            save_trajectories=False,  # We handle saving ourselves
            verbose_logging=config.get("verbose", False),
            ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
            log_prefix_chars=config.get("log_prefix_chars", 100),
            log_prefix=log_prefix,
            providers_allowed=config.get("providers_allowed"),
            providers_ignored=config.get("providers_ignored"),
            providers_order=config.get("providers_order"),
            provider_sort=config.get("provider_sort"),
            max_tokens=config.get("max_tokens"),
            reasoning_config=config.get("reasoning_config"),
            prefill_messages=config.get("prefill_messages"),
            skip_context_files=True,  # Don't pollute trajectories with SOUL.md/AGENTS.md
            skip_memory=True,  # Don't use persistent memory in batch runs
        )

        # Run the agent with task_id to ensure each task gets its own isolated VM
        result = agent.run_conversation(prompt, task_id=task_id)

        # Extract tool usage statistics
        tool_stats = _extract_tool_stats(result["messages"])

        # Extract reasoning coverage stats
        reasoning_stats = _extract_reasoning_stats(result["messages"])

        # Convert to trajectory format (using existing method)
        trajectory = agent._convert_to_trajectory_format(
            result["messages"],
            prompt,
            result["completed"]
        )

        return {
            "success": True,
            "prompt_index": prompt_index,
            "trajectory": trajectory,
            "tool_stats": tool_stats,
            "reasoning_stats": reasoning_stats,
            "completed": result["completed"],
            "partial": result.get("partial", False),
            "api_calls": result["api_calls"],
            "toolsets_used": selected_toolsets,
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat(),
                "model": config["model"]
            }
        }

    except Exception as e:
        print(f"❌ Error processing prompt {prompt_index}: {e}")
        if config.get("verbose"):
            traceback.print_exc()

        return {
            "success": False,
            "prompt_index": prompt_index,
            "error": str(e),
            "trajectory": None,
            "tool_stats": {},
            "toolsets_used": [],
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat()
            }
        }
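
# Illustrative dataset rows (one JSON object per line of the JSONL file). 'image'
# and 'cwd' are the optional per-prompt sandbox overrides handled above; the image
# name below is just an example:
#
#   {"prompt": "Run the test suite and fix any failures."}
#   {"prompt": "Build the project.", "image": "python:3.11-slim", "cwd": "/workspace"}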


def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
    """
    Worker function to process a single batch of prompts.

    Args:
        args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)

    Returns:
        Dict: Batch results with statistics
    """
    batch_num, batch_data, output_dir, completed_prompts_set, config = args

    output_dir = Path(output_dir)
    print(f"\n🚀 Batch {batch_num}: Starting ({len(batch_data)} prompts)")

    # Output file for this batch
    batch_output_file = output_dir / f"batch_{batch_num}.jsonl"

    # Filter out already completed prompts
    prompts_to_process = [
        (idx, data) for idx, data in batch_data
        if idx not in completed_prompts_set
    ]

    if not prompts_to_process:
        print(f"✅ Batch {batch_num}: Already completed (skipping)")
        return {
            "batch_num": batch_num,
            "processed": 0,
            "skipped": len(batch_data),
            "tool_stats": {},
            "completed_prompts": []
        }

    print(f"   Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")

    # Initialize aggregated stats for this batch
    batch_tool_stats = {}
    batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
    completed_in_batch = []
    discarded_no_reasoning = 0

    # Process each prompt sequentially in this batch
    for prompt_index, prompt_data in prompts_to_process:
        # Process the prompt
        result = _process_single_prompt(
            prompt_index,
            prompt_data,
            batch_num,
            config
        )

        # Save trajectory if successful
        if result["success"] and result["trajectory"]:
            # Discard samples with zero reasoning across all turns
            reasoning = result.get("reasoning_stats", {})
            if not reasoning.get("has_any_reasoning", True):
                print(f"   🚫 Prompt {prompt_index} discarded (no reasoning in any turn)")
                discarded_no_reasoning += 1
                completed_in_batch.append(prompt_index)
                continue

            # Get and normalize tool stats for consistent schema across all entries
            raw_tool_stats = result.get("tool_stats", {})
            tool_stats = _normalize_tool_stats(raw_tool_stats)

            # Create normalized tool_error_counts mapping tool names to their failure counts
            raw_error_counts = {
                tool_name: stats.get("failure", 0)
                for tool_name, stats in raw_tool_stats.items()
            }
            tool_error_counts = _normalize_tool_error_counts(raw_error_counts)

            trajectory_entry = {
                "prompt_index": prompt_index,
                "conversations": result["trajectory"],
                "metadata": result["metadata"],
                "completed": result["completed"],
                "partial": result.get("partial", False),  # True if stopped due to invalid tool calls
                "api_calls": result["api_calls"],
                "toolsets_used": result["toolsets_used"],
                "tool_stats": tool_stats,  # Full stats: {tool: {count, success, failure}} - normalized
                "tool_error_counts": tool_error_counts  # Simple: {tool: failure_count} - normalized
            }

            # Append to batch output file
            with open(batch_output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")

            # Aggregate tool statistics
            for tool_name, stats in result.get("tool_stats", {}).items():
                if tool_name not in batch_tool_stats:
                    batch_tool_stats[tool_name] = {
                        "count": 0,
                        "success": 0,
                        "failure": 0
                    }

                batch_tool_stats[tool_name]["count"] += stats["count"]
                batch_tool_stats[tool_name]["success"] += stats["success"]
                batch_tool_stats[tool_name]["failure"] += stats["failure"]

            # Aggregate reasoning stats
            for key in batch_reasoning_stats:
                batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)

        # Only mark as completed if successfully saved (failed prompts can be retried on resume)
        if result["success"] and result["trajectory"]:
            completed_in_batch.append(prompt_index)
            status = "⚠️ partial" if result.get("partial") else "✅"
            print(f"   {status} Prompt {prompt_index} completed")
        else:
            print(f"   ❌ Prompt {prompt_index} failed (will retry on resume)")

    print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")

    return {
        "batch_num": batch_num,
        "processed": len(prompts_to_process),
        "skipped": len(batch_data) - len(prompts_to_process),
        "tool_stats": batch_tool_stats,
        "reasoning_stats": batch_reasoning_stats,
        "discarded_no_reasoning": discarded_no_reasoning,
        "completed_prompts": completed_in_batch
    }
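
# Illustrative worker return value (keys as built above; the numbers are made up
# and "terminal" is an example tool name):
#
#   {
#       "batch_num": 0,
#       "processed": 10,
#       "skipped": 0,
#       "tool_stats": {"terminal": {"count": 7, "success": 6, "failure": 1}},
#       "reasoning_stats": {"total_assistant_turns": 42, "turns_with_reasoning": 40,
#                           "turns_without_reasoning": 2},
#       "discarded_no_reasoning": 1,
#       "completed_prompts": [0, 1, 2, 3, 4, 5, 6, 8, 9],
#   }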


class BatchRunner:
    """
    Manages batch processing of agent prompts with checkpointing and statistics.
    """

    def __init__(
        self,
        dataset_file: str,
        batch_size: int,
        run_name: str,
        distribution: str = "default",
        max_iterations: int = 10,
        base_url: str = None,
        api_key: str = None,
        model: str = "claude-opus-4-20250514",
        num_workers: int = 4,
        verbose: bool = False,
        ephemeral_system_prompt: str = None,
        log_prefix_chars: int = 100,
        providers_allowed: List[str] = None,
        providers_ignored: List[str] = None,
        providers_order: List[str] = None,
        provider_sort: str = None,
        max_tokens: int = None,
        reasoning_config: Dict[str, Any] = None,
        prefill_messages: List[Dict[str, Any]] = None,
        max_samples: int = None,
    ):
        """
        Initialize the batch runner.

        Args:
            dataset_file (str): Path to the dataset JSONL file with 'prompt' field
            batch_size (int): Number of prompts per batch
            run_name (str): Name for this run (used for checkpointing and output)
            distribution (str): Toolset distribution to use (default: "default")
            max_iterations (int): Max iterations per agent run
            base_url (str): Base URL for model API
            api_key (str): API key for model
            model (str): Model name to use
            num_workers (int): Number of parallel workers
            verbose (bool): Enable verbose logging
            ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
            log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
            providers_allowed (List[str]): OpenRouter providers to allow (optional)
            providers_ignored (List[str]): OpenRouter providers to ignore (optional)
            providers_order (List[str]): OpenRouter providers to try in order (optional)
            provider_sort (str): Sort providers by price/throughput/latency (optional)
            max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
            reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
            prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming).
                NOTE: Anthropic Sonnet 4.6+ and Opus 4.6+ reject a trailing assistant-role prefill
                (400 error). For those models use output_config.format or structured-output
                schemas instead. Safe here for user-role priming and for older Claude / non-Claude models.
            max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
        """
        self.dataset_file = Path(dataset_file)
        self.batch_size = batch_size
        self.run_name = run_name
        self.distribution = distribution
        self.max_iterations = max_iterations
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.num_workers = num_workers
        self.verbose = verbose
        self.ephemeral_system_prompt = ephemeral_system_prompt
        self.log_prefix_chars = log_prefix_chars
        self.providers_allowed = providers_allowed
        self.providers_ignored = providers_ignored
        self.providers_order = providers_order
        self.provider_sort = provider_sort
        self.max_tokens = max_tokens
        self.reasoning_config = reasoning_config
        self.prefill_messages = prefill_messages
        self.max_samples = max_samples

        # Validate distribution
        if not validate_distribution(distribution):
            raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}")

        # Setup output directory
        self.output_dir = Path("data") / run_name
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Checkpoint file
        self.checkpoint_file = self.output_dir / "checkpoint.json"

        # Statistics file
        self.stats_file = self.output_dir / "statistics.json"

        # Load dataset (and optionally truncate to max_samples)
        self.dataset = self._load_dataset()
        if self.max_samples and self.max_samples < len(self.dataset):
            full_count = len(self.dataset)
            self.dataset = self.dataset[:self.max_samples]
            print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")

        # Create batches
        self.batches = self._create_batches()

        print("🚀 Batch Runner Initialized")
        print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
        print(f"   Batch size: {self.batch_size}")
        print(f"   Total batches: {len(self.batches)}")
        print(f"   Run name: {self.run_name}")
        print(f"   Distribution: {self.distribution}")
        print(f"   Output directory: {self.output_dir}")
        print(f"   Workers: {self.num_workers}")
        if self.ephemeral_system_prompt:
            prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
            print(f"   📝 Ephemeral system prompt: '{prompt_preview}'")

    def _load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load dataset from JSONL file.

        Returns:
            List[Dict]: List of dataset entries
        """
        if not self.dataset_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}")

        dataset = []
        with open(self.dataset_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    entry = json.loads(line)
                    if 'prompt' not in entry:
                        print(f"⚠️ Warning: Line {line_num} missing 'prompt' field, skipping")
                        continue
                    dataset.append(entry)
                except json.JSONDecodeError as e:
                    print(f"⚠️ Warning: Invalid JSON on line {line_num}: {e}")
                    continue

        if not dataset:
            raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}")

        return dataset

    def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
        """
        Split dataset into batches with indices.

        Returns:
            List of batches, where each batch is a list of (index, entry) tuples
        """
        batches = []
        for i in range(0, len(self.dataset), self.batch_size):
            batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
            batches.append(batch)

        return batches
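
    # Illustrative sketch (comments only, not executed): with 25 dataset entries and
    # batch_size=10, _create_batches yields three batches whose (index, entry) tuples
    # keep global dataset indices, so completion can be tracked across batches:
    #
    #   [[(0, e0), ..., (9, e9)], [(10, e10), ..., (19, e19)], [(20, e20), ..., (24, e24)]]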

    def _load_checkpoint(self) -> Dict[str, Any]:
        """
        Load checkpoint data if it exists.

        Returns:
            Dict: Checkpoint data with completed prompt indices
        """
        if not self.checkpoint_file.exists():
            return {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }

        try:
            with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"⚠️ Warning: Failed to load checkpoint: {e}")
            return {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }

    def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
        """
        Save checkpoint data.

        Args:
            checkpoint_data (Dict): Checkpoint data to save
            lock (Lock): Optional lock for thread-safe access
        """
        checkpoint_data["last_updated"] = datetime.now().isoformat()

        from utils import atomic_json_write
        if lock:
            with lock:
                atomic_json_write(self.checkpoint_file, checkpoint_data)
        else:
            atomic_json_write(self.checkpoint_file, checkpoint_data)

    def _scan_completed_prompts_by_content(self) -> set:
        """
        Scan all batch files and extract completed prompts by their actual content.

        This provides a more robust resume mechanism that matches on prompt text
        rather than indices, allowing recovery even if indices don't match.

        Returns:
            set: Set of prompt texts that have been successfully processed
        """
        completed_prompts = set()
        batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))

        if not batch_files:
            return completed_prompts

        print(f"🔍 Scanning {len(batch_files)} batch files for completed prompts...")

        for batch_file in batch_files:
            try:
                with open(batch_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            entry = json.loads(line.strip())

                            # Skip failed entries - we want to retry these
                            if entry.get("failed", False):
                                continue

                            # Extract the human/user prompt from conversations
                            conversations = entry.get("conversations", [])
                            for msg in conversations:
                                if msg.get("from") == "human":
                                    prompt_text = msg.get("value", "").strip()
                                    if prompt_text:
                                        completed_prompts.add(prompt_text)
                                    break  # Only need the first human message
                        except json.JSONDecodeError:
                            continue
            except Exception as e:
                print(f"   ⚠️ Warning: Error reading {batch_file.name}: {e}")

        return completed_prompts
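
    # Illustrative checkpoint.json shape (keys as written by _save_checkpoint and the
    # incremental updates in run(); the values below are made up):
    #
    #   {
    #     "run_name": "my_run",
    #     "completed_prompts": [0, 1, 2, 5],
    #     "batch_stats": {"0": {"processed": 4, "skipped": 0, "discarded_no_reasoning": 1}},
    #     "last_updated": "2025-01-01T12:00:00"
    #   }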

    def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
        """
        Filter the dataset to exclude prompts that have already been completed.

        Args:
            completed_prompts: Set of prompt texts that have been completed

        Returns:
            Tuple of (filtered_dataset, skipped_indices)
        """
        filtered_dataset = []
        skipped_indices = []

        for idx, entry in enumerate(self.dataset):
            # Extract prompt from the dataset entry
            prompt_text = entry.get("prompt", "").strip()

            # Also check conversations format
            if not prompt_text:
                conversations = entry.get("conversations", [])
                for msg in conversations:
                    role = msg.get("role") or msg.get("from")
                    if role in ("user", "human"):
                        prompt_text = (msg.get("content") or msg.get("value", "")).strip()
                        break

            if prompt_text in completed_prompts:
                skipped_indices.append(idx)
            else:
                # Keep original index for tracking
                filtered_dataset.append((idx, entry))

        return filtered_dataset, skipped_indices

    def run(self, resume: bool = False):
        """
        Run the batch processing pipeline.

        Args:
            resume (bool): Whether to resume from checkpoint
        """
        print("\n" + "=" * 70)
        print("🚀 Starting Batch Processing")
        print("=" * 70)

        # Smart resume: scan batch files by content to find completed prompts
        completed_prompt_texts = set()
        if resume:
            completed_prompt_texts = self._scan_completed_prompts_by_content()
            if completed_prompt_texts:
                print(f"   Found {len(completed_prompt_texts)} already-completed prompts by content matching")

        # Filter dataset to only include unprocessed prompts
        if resume and completed_prompt_texts:
            filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)

            if not filtered_entries:
                print("\n✅ All prompts have already been processed!")
                return

            # Recreate batches from filtered entries (keeping original indices for tracking)
            batches_to_process = []
            for i in range(0, len(filtered_entries), self.batch_size):
                batch = filtered_entries[i:i + self.batch_size]
                batches_to_process.append(batch)

            self.batches = batches_to_process

            # Print prominent resume summary
            print("\n" + "=" * 70)
            print("📋 RESUME SUMMARY")
            print("=" * 70)
            print(f"   Original dataset size: {len(self.dataset):,} prompts")
            print(f"   Already completed: {len(skipped_indices):,} prompts")
            print("   " + "─" * 41)
            print(f"   🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
            print(f"   New batches created: {len(batches_to_process)}")
            print("=" * 70 + "\n")

        # Load existing checkpoint (so resume doesn't clobber prior progress)
        checkpoint_data = self._load_checkpoint()
        if checkpoint_data.get("run_name") != self.run_name:
            checkpoint_data = {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }

        # Prepare configuration for workers
        config = {
            "distribution": self.distribution,
            "model": self.model,
            "max_iterations": self.max_iterations,
            "base_url": self.base_url,
            "api_key": self.api_key,
            "verbose": self.verbose,
            "ephemeral_system_prompt": self.ephemeral_system_prompt,
            "log_prefix_chars": self.log_prefix_chars,
            "providers_allowed": self.providers_allowed,
            "providers_ignored": self.providers_ignored,
            "providers_order": self.providers_order,
            "provider_sort": self.provider_sort,
            "max_tokens": self.max_tokens,
            "reasoning_config": self.reasoning_config,
            "prefill_messages": self.prefill_messages,
        }

        # For backward compatibility, still track by index (but this is secondary to content matching)
        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))

        # Aggregate statistics across all batches
        total_tool_stats = {}

        start_time = time.time()

        print(f"\n🔧 Initializing {self.num_workers} worker processes...")

        # Checkpoint writes happen in the parent process; keep a lock for safety.
        checkpoint_lock = Lock()

        # Process batches in parallel
        with Pool(processes=self.num_workers) as pool:
            # Create tasks for each batch
            tasks = [
                (
                    batch_num,
                    batch_data,
                    str(self.output_dir),  # Convert Path to string for pickling
                    completed_prompts_set,
                    config
                )
                for batch_num, batch_data in enumerate(self.batches)
            ]

            print(f"✅ Created {len(tasks)} batch tasks")
            print("🚀 Starting parallel batch processing...\n")

            # Use rich Progress for better visual tracking with a persistent bottom bar.
            # redirect_stdout/stderr are disabled so worker-process prints pass
            # straight through instead of being captured by rich.
            results = []
            console = Console(force_terminal=True)
            with Progress(
                SpinnerColumn(),
                TextColumn("[bold blue]📦 Batches"),
                BarColumn(bar_width=40),
                MofNCompleteColumn(),
                TextColumn("•"),
                TimeRemainingColumn(),
                console=console,
                refresh_per_second=2,
                transient=False,
                redirect_stdout=False,
                redirect_stderr=False,
            ) as progress:
                task = progress.add_task("Processing", total=len(tasks))

                # Temporarily suppress DEBUG logging to avoid bar interference
                root_logger = logging.getLogger()
                original_level = root_logger.level
                root_logger.setLevel(logging.WARNING)

                try:
                    for result in pool.imap_unordered(_process_batch_worker, tasks):
                        results.append(result)
                        progress.update(task, advance=1)

                        # Incremental checkpoint update (so resume works after crash)
                        try:
                            batch_num = result.get('batch_num')
                            completed = result.get('completed_prompts', []) or []
                            completed_prompts_set.update(completed)

                            if isinstance(batch_num, int):
                                checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
                                    'processed': result.get('processed', 0),
                                    'skipped': result.get('skipped', 0),
                                    'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
                                }

                            checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
                            self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
                        except Exception as ckpt_err:
                            # Don't fail the run if checkpoint write fails
                            print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
                except Exception as e:
                    logger.error("Batch worker failed: %s", e, exc_info=True)
                    raise
                finally:
                    root_logger.setLevel(original_level)

        # Aggregate all batch statistics and update checkpoint
        total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}

        for batch_result in results:
            # Aggregate tool stats
            for tool_name, stats in batch_result.get("tool_stats", {}).items():
                if tool_name not in total_tool_stats:
                    total_tool_stats[tool_name] = {
                        "count": 0,
                        "success": 0,
                        "failure": 0
                    }

                total_tool_stats[tool_name]["count"] += stats["count"]
                total_tool_stats[tool_name]["success"] += stats["success"]
                total_tool_stats[tool_name]["failure"] += stats["failure"]

            # Aggregate reasoning stats
            for key in total_reasoning_stats:
                total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)

        # Save final checkpoint (best-effort; incremental writes already happened)
        try:
            checkpoint_data["completed_prompts"] = sorted(completed_prompts_set)
            self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
        except Exception as ckpt_err:
            print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")

        # Calculate success rates
        for tool_name in total_tool_stats:
            stats = total_tool_stats[tool_name]
            total_calls = stats["success"] + stats["failure"]
            if total_calls > 0:
                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
            else:
                stats["success_rate"] = 0.0
                stats["failure_rate"] = 0.0

        # Combine ALL batch files in directory into a single trajectories.jsonl file.
        # This includes both old batches (from previous runs) and new batches (from resume).
        # Also filter out corrupted entries (where model generated invalid tool names).
        combined_file = self.output_dir / "trajectories.jsonl"
        print(f"\n📦 Combining ALL batch files into {combined_file.name}...")

        # Valid tools auto-derived from model_tools.py, so no manual updates needed
        VALID_TOOLS = ALL_POSSIBLE_TOOLS

        total_entries = 0
        filtered_entries = 0
        batch_files_found = 0

        # Find ALL batch files in the output directory (handles resume merging old + new)
        all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))

        with open(combined_file, 'w', encoding='utf-8') as outfile:
            for batch_file in all_batch_files:
                batch_files_found += 1
                batch_num = batch_file.stem.split("_")[1]  # Extract batch number for logging

                with open(batch_file, 'r', encoding='utf-8') as infile:
                    for line in infile:
                        total_entries += 1
                        try:
                            data = json.loads(line)
                            tool_stats = data.get('tool_stats', {})

                            # Check for invalid tool names (model hallucinations)
                            invalid_tools = [k for k in tool_stats if k not in VALID_TOOLS]

                            if invalid_tools:
                                filtered_entries += 1
                                invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
                                print(f"   ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
                                continue

                            outfile.write(line)
                        except json.JSONDecodeError:
                            filtered_entries += 1
                            print(f"   ⚠️ Filtering invalid JSON entry (batch {batch_num})")

        if filtered_entries > 0:
            print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
        print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")

        # Save final statistics
        final_stats = {
            "run_name": self.run_name,
            "distribution": self.distribution,
            "total_prompts": len(self.dataset),
            "total_batches": len(self.batches),
            "batch_size": self.batch_size,
            "model": self.model,
            "completed_at": datetime.now().isoformat(),
            "duration_seconds": round(time.time() - start_time, 2),
            "tool_statistics": total_tool_stats,
            "reasoning_statistics": total_reasoning_stats,
        }

        with open(self.stats_file, 'w', encoding='utf-8') as f:
            json.dump(final_stats, f, indent=2, ensure_ascii=False)

        # Print summary
        print("\n" + "=" * 70)
        print("🎉 BATCH PROCESSING COMPLETE")
        print("=" * 70)
        print(f"✅ Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
        print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
        print(f"✅ Total batch files merged: {batch_files_found}")
        print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
        print("\n📊 Tool Usage Statistics:")
        print("-" * 70)

        if total_tool_stats:
            # Sort by count descending
            sorted_tools = sorted(
                total_tool_stats.items(),
                key=lambda x: x[1]["count"],
                reverse=True
            )

            print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
            print("-" * 70)
            for tool_name, stats in sorted_tools:
                print(
                    f"{tool_name:<25} "
                    f"{stats['count']:<10} "
                    f"{stats['success']:<10} "
                    f"{stats['failure']:<10} "
                    f"{stats['success_rate']:.1f}%"
                )
        else:
            print("No tool calls were made during this run.")

        # Print reasoning coverage stats
        total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results)

        print("\n🧠 Reasoning Coverage:")
        print("-" * 70)
        total_turns = total_reasoning_stats["total_assistant_turns"]
        with_reasoning = total_reasoning_stats["turns_with_reasoning"]
        without_reasoning = total_reasoning_stats["turns_without_reasoning"]
        if total_turns > 0:
            pct_with = round(with_reasoning / total_turns * 100, 1)
            pct_without = round(without_reasoning / total_turns * 100, 1)
            print(f"   Total assistant turns: {total_turns:,}")
            print(f"   With reasoning: {with_reasoning:,} ({pct_with}%)")
            print(f"   Without reasoning: {without_reasoning:,} ({pct_without}%)")
        else:
            print("   No assistant turns recorded.")
        if total_discarded > 0:
            print(f"   🚫 Samples discarded (zero reasoning): {total_discarded:,}")

        print(f"\n💾 Results saved to: {self.output_dir}")
        print("   - Trajectories: trajectories.jsonl (combined)")
        print("   - Individual batches: batch_*.jsonl (for debugging)")
        print(f"   - Statistics: {self.stats_file.name}")
        print(f"   - Checkpoint: {self.checkpoint_file.name}")
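
# Illustrative programmatic usage (a sketch; the CLI entry point below is the
# usual interface, and the argument values here are examples):
#
#   runner = BatchRunner(
#       dataset_file="data.jsonl",
#       batch_size=10,
#       run_name="my_run",
#       distribution="default",
#       num_workers=4,
#   )
#   runner.run(resume=False)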


def main(
    dataset_file: str = None,
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
    model: str = "anthropic/claude-sonnet-4.6",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_turns: int = 10,
    num_workers: int = 4,
    resume: bool = False,
    verbose: bool = False,
    list_distributions: bool = False,
    ephemeral_system_prompt: str = None,
    log_prefix_chars: int = 100,
    providers_allowed: str = None,
    providers_ignored: str = None,
    providers_order: str = None,
    provider_sort: str = None,
    max_tokens: int = None,
    reasoning_effort: str = None,
    reasoning_disabled: bool = False,
    prefill_messages_file: str = None,
    max_samples: int = None,
):
    """
    Run batch processing of agent prompts from a dataset.

    Args:
        dataset_file (str): Path to JSONL file with 'prompt' field in each entry
        batch_size (int): Number of prompts per batch
        run_name (str): Name for this run (used for output and checkpointing)
        distribution (str): Toolset distribution to use (default: "default")
        model (str): Model name to use (default: "anthropic/claude-sonnet-4.6")
        api_key (str): API key for model authentication
        base_url (str): Base URL for model API
        max_turns (int): Maximum number of tool calling iterations per prompt (default: 10)
        num_workers (int): Number of parallel worker processes (default: 4)
        resume (bool): Resume from checkpoint if run was interrupted (default: False)
        verbose (bool): Enable verbose logging (default: False)
        list_distributions (bool): List available toolset distributions and exit
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
        providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
        providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
        providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
        provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
        reasoning_effort (str): OpenRouter reasoning effort level: "none", "minimal", "low", "medium", "high", "xhigh" (if unset, the provider default applies)
        reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
        prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)

    Examples:
        # Basic usage
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run

        # Resume interrupted run
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume

        # Use specific distribution
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen

        # With disabled reasoning and max tokens
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --reasoning_disabled --max_tokens=128000

        # With prefill messages from file
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --prefill_messages_file=configs/prefill_opus.json

        # List available distributions
        python batch_runner.py --list_distributions
    """
    # Handle list distributions
    if list_distributions:
        # NOTE: the CLI flag shadows the imported list_distributions() function,
        # so re-import it under a local alias before calling it.
        from toolset_distributions import print_distribution_info
        from toolset_distributions import list_distributions as _list_distributions

        print("📋 Available Toolset Distributions")
        print("=" * 70)

        all_dists = _list_distributions()
        for dist_name in sorted(all_dists.keys()):
            print_distribution_info(dist_name)

        print("\n💡 Usage:")
        print("   python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
        print("       --run_name=my_run --distribution=<name>")
        return

    # Validate required arguments
    if not dataset_file:
        print("❌ Error: --dataset_file is required")
        return

    if not batch_size or batch_size < 1:
        print("❌ Error: --batch_size must be a positive integer")
        return

    if not run_name:
        print("❌ Error: --run_name is required")
        return

    # Parse provider preferences (comma-separated strings to lists)
    providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
    providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
    providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None

    # Build reasoning_config from CLI flags.
    # --reasoning_disabled takes priority, then --reasoning_effort; if neither is
    # set, reasoning_config stays None and the provider default applies.
    reasoning_config = None
    if reasoning_disabled:
        # Completely disable reasoning/thinking tokens
        reasoning_config = {"effort": "none"}
        print("🧠 Reasoning: DISABLED (effort=none)")
    elif reasoning_effort:
        # Use specified effort level
        valid_efforts = ["none", "minimal", "low", "medium", "high", "xhigh"]
        if reasoning_effort not in valid_efforts:
            print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
            return
        reasoning_config = {"enabled": True, "effort": reasoning_effort}
        print(f"🧠 Reasoning effort: {reasoning_effort}")

    # Load prefill messages from JSON file if provided
    prefill_messages = None
    if prefill_messages_file:
        try:
            with open(prefill_messages_file, 'r', encoding='utf-8') as f:
                prefill_messages = json.load(f)
            if not isinstance(prefill_messages, list):
                print("❌ Error: prefill_messages_file must contain a JSON array of messages")
                return
            print(f"💬 Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")
        except Exception as e:
            print(f"❌ Error loading prefill messages: {e}")
            return

    # Initialize and run batch runner
    try:
        runner = BatchRunner(
            dataset_file=dataset_file,
            batch_size=batch_size,
            run_name=run_name,
            distribution=distribution,
            max_iterations=max_turns,
            base_url=base_url,
            api_key=api_key,
            model=model,
            num_workers=num_workers,
            verbose=verbose,
            ephemeral_system_prompt=ephemeral_system_prompt,
            log_prefix_chars=log_prefix_chars,
            providers_allowed=providers_allowed_list,
            providers_ignored=providers_ignored_list,
            providers_order=providers_order_list,
            provider_sort=provider_sort,
            max_tokens=max_tokens,
            reasoning_config=reasoning_config,
            prefill_messages=prefill_messages,
            max_samples=max_samples,
        )

        runner.run(resume=resume)

    except Exception as e:
        print(f"\n❌ Fatal error: {e}")
        if verbose:
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    fire.Fire(main)