/ batch_runner.py
batch_runner.py
   1  #!/usr/bin/env python3
   2  """
   3  Batch Agent Runner
   4  
   5  This module provides parallel batch processing capabilities for running the agent
   6  across multiple prompts from a dataset. It includes:
   7  - Dataset loading and batching
   8  - Parallel batch processing with multiprocessing
   9  - Checkpointing for fault tolerance and resumption
  10  - Trajectory saving in the proper format (from/value pairs)
  11  - Tool usage statistics aggregation across all batches
  12  
  13  Usage:
  14      python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
  15      
  16      # Resume an interrupted run
  17      python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
  18      
  19      # Use a specific toolset distribution
  20      python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
  21  """
  22  
  23  import json
  24  import logging
  25  import os
  26  import time
  27  from pathlib import Path
  28  from typing import List, Dict, Any, Optional, Tuple
  29  from datetime import datetime
  30  from multiprocessing import Pool, Lock
  31  import traceback
  32  from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
  33  from rich.console import Console
  34  
  35  logger = logging.getLogger(__name__)
  36  import fire
  37  
  38  from run_agent import AIAgent
  39  from toolset_distributions import (
  40      list_distributions, 
  41      sample_toolsets_from_distribution,
  42      validate_distribution
  43  )
  44  from model_tools import TOOL_TO_TOOLSET_MAP
  45  
  46  
  47  # Global configuration for worker processes
  48  _WORKER_CONFIG = {}
  49  
  50  # All possible tools - auto-derived from the master mapping in model_tools.py.
  51  # This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP.
  52  # Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for
  53  # filtering corrupted entries during trajectory combination.
  54  ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())
  55  
  56  # Default stats for tools that weren't used
  57  DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
  58  
  59  
  60  def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
  61      """
  62      Normalize tool_stats to include all possible tools with consistent schema.
  63      
  64      This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
  65      Tools that weren't used get zero counts.
  66      
  67      Args:
  68          tool_stats (Dict): Raw tool statistics from extraction
  69          
  70      Returns:
  71          Dict: Normalized tool statistics with all tools present
  72      """
  73      normalized = {}
  74      
  75      # Add all possible tools with defaults
  76      for tool in ALL_POSSIBLE_TOOLS:
  77          if tool in tool_stats:
  78              normalized[tool] = tool_stats[tool].copy()
  79          else:
  80              normalized[tool] = DEFAULT_TOOL_STATS.copy()
  81      
  82      # Also include any unexpected tools (in case new tools are added)
  83      for tool, stats in tool_stats.items():
  84          if tool not in normalized:
  85              normalized[tool] = stats.copy()
  86      
  87      return normalized
  88  
  89  
  90  def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
  91      """
  92      Normalize tool_error_counts to include all possible tools.
  93      
  94      Args:
  95          tool_error_counts (Dict): Raw error counts mapping
  96          
  97      Returns:
  98          Dict: Normalized error counts with all tools present
  99      """
 100      normalized = {}
 101      
 102      # Add all possible tools with zero defaults
 103      for tool in ALL_POSSIBLE_TOOLS:
 104          normalized[tool] = tool_error_counts.get(tool, 0)
 105      
 106      # Also include any unexpected tools
 107      for tool, count in tool_error_counts.items():
 108          if tool not in normalized:
 109              normalized[tool] = count
 110      
 111      return normalized
 112  
 113  
 114  def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
 115      """
 116      Extract tool usage statistics from message history.
 117      
 118      Args:
 119          messages (List[Dict]): Message history
 120          
 121      Returns:
 122          Dict: Tool statistics with counts and success/failure rates
 123      """
 124      tool_stats = {}
 125      
 126      # Track tool calls and their results
 127      tool_calls_map = {}  # Map tool_call_id to tool name
 128      
 129      for msg in messages:
 130          # Track tool calls from assistant messages
 131          if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
 132              for tool_call in msg["tool_calls"]:
 133                  if not tool_call or not isinstance(tool_call, dict): continue
 134                  tool_name = tool_call["function"]["name"]
 135                  tool_call_id = tool_call["id"]
 136                  
 137                  # Initialize stats for this tool if not exists
 138                  if tool_name not in tool_stats:
 139                      tool_stats[tool_name] = {
 140                          "count": 0,
 141                          "success": 0,
 142                          "failure": 0
 143                      }
 144                  
 145                  tool_stats[tool_name]["count"] += 1
 146                  tool_calls_map[tool_call_id] = tool_name
 147          
 148          # Track tool responses
 149          elif msg["role"] == "tool":
 150              tool_call_id = msg.get("tool_call_id", "")
 151              content = msg.get("content", "")
 152              
 153              # Determine if tool call was successful
 154              is_success = True
 155              try:
 156                  # Try to parse as JSON and check for actual error values
 157                  content_json = json.loads(content) if isinstance(content, str) else content
 158                  
 159                  if isinstance(content_json, dict):
 160                      # Check if error field exists AND has a non-null value
 161                      if "error" in content_json and content_json["error"] is not None:
 162                          is_success = False
 163                      
 164                      # Special handling for terminal tool responses
 165                      # Terminal wraps its response in a "content" field
 166                      if "content" in content_json and isinstance(content_json["content"], dict):
 167                          inner_content = content_json["content"]
 168                          # Check for actual error (non-null error field)
 169                          # Note: non-zero exit codes are not failures - the model can self-correct
 170                          if inner_content.get("error") is not None:
 171                              is_success = False
 172                      
 173                      # Check for "success": false pattern used by some tools
 174                      if content_json.get("success") is False:
 175                          is_success = False
 176                          
 177              except (json.JSONDecodeError, ValueError, TypeError):
 178                  # If not JSON, check if content is empty or explicitly states an error
 179                  # Note: We avoid simple substring matching to prevent false positives
 180                  if not content:
 181                      is_success = False
 182                  # Only mark as failure if it explicitly starts with "Error:" or "ERROR:"
 183                  elif content.strip().lower().startswith("error:"):
 184                      is_success = False
 185              
 186              # Update success/failure count
 187              if tool_call_id in tool_calls_map:
 188                  tool_name = tool_calls_map[tool_call_id]
 189                  if is_success:
 190                      tool_stats[tool_name]["success"] += 1
 191                  else:
 192                      tool_stats[tool_name]["failure"] += 1
 193      
 194      return tool_stats
 195  
 196  
 197  def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
 198      """
 199      Count how many assistant turns have reasoning vs no reasoning.
 200      
 201      Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field
 202      (native thinking tokens). Returns counts for tracking reasoning coverage.
 203      
 204      Args:
 205          messages: Message history
 206          
 207      Returns:
 208          Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning'
 209      """
 210      total = 0
 211      with_reasoning = 0
 212      
 213      for msg in messages:
 214          if msg.get("role") != "assistant":
 215              continue
 216          total += 1
 217          
 218          content = msg.get("content", "") or ""
 219          has_scratchpad = "<REASONING_SCRATCHPAD>" in content
 220          has_native_reasoning = bool(msg.get("reasoning", "").strip()) if msg.get("reasoning") else False
 221          
 222          if has_scratchpad or has_native_reasoning:
 223              with_reasoning += 1
 224      
 225      return {
 226          "total_assistant_turns": total,
 227          "turns_with_reasoning": with_reasoning,
 228          "turns_without_reasoning": total - with_reasoning,
 229          "has_any_reasoning": with_reasoning > 0,
 230      }
 231  
 232  
 233  def _process_single_prompt(
 234      prompt_index: int,
 235      prompt_data: Dict[str, Any],
 236      batch_num: int,
 237      config: Dict[str, Any]
 238  ) -> Dict[str, Any]:
 239      """
 240      Process a single prompt with the agent.
 241      
 242      Args:
 243          prompt_index (int): Index of prompt in dataset
 244          prompt_data (Dict): Prompt data containing 'prompt' field and optional 'image' field
 245          batch_num (int): Batch number
 246          config (Dict): Configuration dict with agent parameters
 247          
 248      Returns:
 249          Dict: Result containing trajectory, stats, and metadata
 250      """
 251      prompt = prompt_data["prompt"]
 252      task_id = f"task_{prompt_index}"
 253      
 254      # Per-prompt container image override: if the dataset row has an 'image' field,
 255      # register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona.
 256      container_image = prompt_data.get("image") or prompt_data.get("docker_image")
 257      if container_image:
 258          # Verify the image is accessible before spending tokens on the agent loop.
 259          # For Docker: check local cache, then try pulling.
 260          # For Modal: skip local check (Modal pulls server-side).
 261          env_type = os.getenv("TERMINAL_ENV", "local")
 262          if env_type == "docker":
 263              import subprocess as _sp
 264              try:
 265                  probe = _sp.run(
 266                      ["docker", "image", "inspect", container_image],
 267                      capture_output=True, timeout=10,
 268                  )
 269                  if probe.returncode != 0:
 270                      if config.get("verbose"):
 271                          print(f"   Prompt {prompt_index}: Pulling docker image {container_image}...", flush=True)
 272                      pull = _sp.run(
 273                          ["docker", "pull", container_image],
 274                          capture_output=True, text=True, timeout=600,
 275                      )
 276                      if pull.returncode != 0:
 277                          return {
 278                              "success": False,
 279                              "prompt_index": prompt_index,
 280                              "error": f"Docker image not available: {container_image}\n{pull.stderr[:500]}",
 281                              "trajectory": None,
 282                              "tool_stats": {},
 283                              "toolsets_used": [],
 284                              "metadata": {"batch_num": batch_num, "timestamp": datetime.now().isoformat()},
 285                          }
 286              except FileNotFoundError:
 287                  pass  # Docker CLI not installed β€” skip check (e.g., Modal backend)
 288              except Exception as img_err:
 289                  if config.get("verbose"):
 290                      print(f"   Prompt {prompt_index}: Docker image check failed: {img_err}", flush=True)
 291  
 292          from tools.terminal_tool import register_task_env_overrides
 293          overrides = {
 294              "docker_image": container_image,
 295              "modal_image": container_image,
 296              "singularity_image": f"docker://{container_image}",
 297              "daytona_image": container_image,
 298          }
 299          if prompt_data.get("cwd"):
 300              overrides["cwd"] = prompt_data["cwd"]
 301          register_task_env_overrides(task_id, overrides)
 302          if config.get("verbose"):
 303              print(f"   Prompt {prompt_index}: Using container image {container_image}")
 304      
 305      try:
 306          # Sample toolsets from distribution for this prompt
 307          selected_toolsets = sample_toolsets_from_distribution(config["distribution"])
 308          
 309          if config.get("verbose"):
 310              print(f"   Prompt {prompt_index}: Using toolsets {selected_toolsets}")
 311          
 312          # Initialize agent with sampled toolsets and log prefix for identification
 313          log_prefix = f"[B{batch_num}:P{prompt_index}]"
 314          agent = AIAgent(
 315              base_url=config.get("base_url"),
 316              api_key=config.get("api_key"),
 317              model=config["model"],
 318              max_iterations=config["max_iterations"],
 319              enabled_toolsets=selected_toolsets,
 320              save_trajectories=False,  # We handle saving ourselves
 321              verbose_logging=config.get("verbose", False),
 322              ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
 323              log_prefix_chars=config.get("log_prefix_chars", 100),
 324              log_prefix=log_prefix,
 325              providers_allowed=config.get("providers_allowed"),
 326              providers_ignored=config.get("providers_ignored"),
 327              providers_order=config.get("providers_order"),
 328              provider_sort=config.get("provider_sort"),
 329              max_tokens=config.get("max_tokens"),
 330              reasoning_config=config.get("reasoning_config"),
 331              prefill_messages=config.get("prefill_messages"),
 332              skip_context_files=True,  # Don't pollute trajectories with SOUL.md/AGENTS.md
 333              skip_memory=True,  # Don't use persistent memory in batch runs
 334          )
 335  
 336          # Run the agent with task_id to ensure each task gets its own isolated VM
 337          result = agent.run_conversation(prompt, task_id=task_id)
 338          
 339          # Extract tool usage statistics
 340          tool_stats = _extract_tool_stats(result["messages"])
 341          
 342          # Extract reasoning coverage stats
 343          reasoning_stats = _extract_reasoning_stats(result["messages"])
 344          
 345          # Convert to trajectory format (using existing method)
 346          trajectory = agent._convert_to_trajectory_format(
 347              result["messages"],
 348              prompt,
 349              result["completed"]
 350          )
 351          
 352          return {
 353              "success": True,
 354              "prompt_index": prompt_index,
 355              "trajectory": trajectory,
 356              "tool_stats": tool_stats,
 357              "reasoning_stats": reasoning_stats,
 358              "completed": result["completed"],
 359              "partial": result.get("partial", False),
 360              "api_calls": result["api_calls"],
 361              "toolsets_used": selected_toolsets,
 362              "metadata": {
 363                  "batch_num": batch_num,
 364                  "timestamp": datetime.now().isoformat(),
 365                  "model": config["model"]
 366              }
 367          }
 368      
 369      except Exception as e:
 370          print(f"❌ Error processing prompt {prompt_index}: {e}")
 371          if config.get("verbose"):
 372              traceback.print_exc()
 373          
 374          return {
 375              "success": False,
 376              "prompt_index": prompt_index,
 377              "error": str(e),
 378              "trajectory": None,
 379              "tool_stats": {},
 380              "toolsets_used": [],
 381              "metadata": {
 382                  "batch_num": batch_num,
 383                  "timestamp": datetime.now().isoformat()
 384              }
 385          }
 386  
 387  
 388  def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
 389      """
 390      Worker function to process a single batch of prompts.
 391      
 392      Args:
 393          args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)
 394          
 395      Returns:
 396          Dict: Batch results with statistics
 397      """
 398      batch_num, batch_data, output_dir, completed_prompts_set, config = args
 399      
 400      output_dir = Path(output_dir)
 401      print(f"\nπŸ”„ Batch {batch_num}: Starting ({len(batch_data)} prompts)")
 402      
 403      # Output file for this batch
 404      batch_output_file = output_dir / f"batch_{batch_num}.jsonl"
 405      
 406      # Filter out already completed prompts
 407      prompts_to_process = [
 408          (idx, data) for idx, data in batch_data
 409          if idx not in completed_prompts_set
 410      ]
 411      
 412      if not prompts_to_process:
 413          print(f"βœ… Batch {batch_num}: Already completed (skipping)")
 414          return {
 415              "batch_num": batch_num,
 416              "processed": 0,
 417              "skipped": len(batch_data),
 418              "tool_stats": {},
 419              "completed_prompts": []
 420          }
 421      
 422      print(f"   Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")
 423      
 424      # Initialize aggregated stats for this batch
 425      batch_tool_stats = {}
 426      batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
 427      completed_in_batch = []
 428      discarded_no_reasoning = 0
 429      
 430      # Process each prompt sequentially in this batch
 431      for prompt_index, prompt_data in prompts_to_process:
 432          # Process the prompt
 433          result = _process_single_prompt(
 434              prompt_index,
 435              prompt_data,
 436              batch_num,
 437              config
 438          )
 439          
 440          # Save trajectory if successful
 441          if result["success"] and result["trajectory"]:
 442              # Discard samples with zero reasoning across all turns
 443              reasoning = result.get("reasoning_stats", {})
 444              if not reasoning.get("has_any_reasoning", True):
 445                  print(f"   🚫 Prompt {prompt_index} discarded (no reasoning in any turn)")
 446                  discarded_no_reasoning += 1
 447                  completed_in_batch.append(prompt_index)
 448                  continue
 449              
 450              # Get and normalize tool stats for consistent schema across all entries
 451              raw_tool_stats = result.get("tool_stats", {})
 452              tool_stats = _normalize_tool_stats(raw_tool_stats)
 453              
 454              # Create normalized tool_error_counts mapping tool names to their failure counts
 455              raw_error_counts = {
 456                  tool_name: stats.get("failure", 0) 
 457                  for tool_name, stats in raw_tool_stats.items()
 458              }
 459              tool_error_counts = _normalize_tool_error_counts(raw_error_counts)
 460              
 461              trajectory_entry = {
 462                  "prompt_index": prompt_index,
 463                  "conversations": result["trajectory"],
 464                  "metadata": result["metadata"],
 465                  "completed": result["completed"],
 466                  "partial": result.get("partial", False),  # True if stopped due to invalid tool calls
 467                  "api_calls": result["api_calls"],
 468                  "toolsets_used": result["toolsets_used"],
 469                  "tool_stats": tool_stats,  # Full stats: {tool: {count, success, failure}} - normalized
 470                  "tool_error_counts": tool_error_counts  # Simple: {tool: failure_count} - normalized
 471              }
 472              
 473              # Append to batch output file
 474              with open(batch_output_file, 'a', encoding='utf-8') as f:
 475                  f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
 476          
 477          # Aggregate tool statistics
 478          for tool_name, stats in result.get("tool_stats", {}).items():
 479              if tool_name not in batch_tool_stats:
 480                  batch_tool_stats[tool_name] = {
 481                      "count": 0,
 482                      "success": 0,
 483                      "failure": 0
 484                  }
 485              
 486              batch_tool_stats[tool_name]["count"] += stats["count"]
 487              batch_tool_stats[tool_name]["success"] += stats["success"]
 488              batch_tool_stats[tool_name]["failure"] += stats["failure"]
 489          
 490          # Aggregate reasoning stats
 491          for key in batch_reasoning_stats:
 492              batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)
 493          
 494          # Only mark as completed if successfully saved (failed prompts can be retried on resume)
 495          if result["success"] and result["trajectory"]:
 496              completed_in_batch.append(prompt_index)
 497              status = "⚠️  partial" if result.get("partial") else "βœ…"
 498              print(f"   {status} Prompt {prompt_index} completed")
 499          else:
 500              print(f"   ❌ Prompt {prompt_index} failed (will retry on resume)")
 501      
 502      print(f"βœ… Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
 503      
 504      return {
 505          "batch_num": batch_num,
 506          "processed": len(prompts_to_process),
 507          "skipped": len(batch_data) - len(prompts_to_process),
 508          "tool_stats": batch_tool_stats,
 509          "reasoning_stats": batch_reasoning_stats,
 510          "discarded_no_reasoning": discarded_no_reasoning,
 511          "completed_prompts": completed_in_batch
 512      }
 513  
 514  
 515  class BatchRunner:
 516      """
 517      Manages batch processing of agent prompts with checkpointing and statistics.
 518      """
 519      
 520      def __init__(
 521          self,
 522          dataset_file: str,
 523          batch_size: int,
 524          run_name: str,
 525          distribution: str = "default",
 526          max_iterations: int = 10,
 527          base_url: str = None,
 528          api_key: str = None,
 529          model: str = "claude-opus-4-20250514",
 530          num_workers: int = 4,
 531          verbose: bool = False,
 532          ephemeral_system_prompt: str = None,
 533          log_prefix_chars: int = 100,
 534          providers_allowed: List[str] = None,
 535          providers_ignored: List[str] = None,
 536          providers_order: List[str] = None,
 537          provider_sort: str = None,
 538          max_tokens: int = None,
 539          reasoning_config: Dict[str, Any] = None,
 540          prefill_messages: List[Dict[str, Any]] = None,
 541          max_samples: int = None,
 542      ):
 543          """
 544          Initialize the batch runner.
 545  
 546          Args:
 547              dataset_file (str): Path to the dataset JSONL file with 'prompt' field
 548              batch_size (int): Number of prompts per batch
 549              run_name (str): Name for this run (used for checkpointing and output)
 550              distribution (str): Toolset distribution to use (default: "default")
 551              max_iterations (int): Max iterations per agent run
 552              base_url (str): Base URL for model API
 553              api_key (str): API key for model
 554              model (str): Model name to use
 555              num_workers (int): Number of parallel workers
 556              verbose (bool): Enable verbose logging
 557              ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
 558              log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
 559              providers_allowed (List[str]): OpenRouter providers to allow (optional)
 560              providers_ignored (List[str]): OpenRouter providers to ignore (optional)
 561              providers_order (List[str]): OpenRouter providers to try in order (optional)
 562              provider_sort (str): Sort providers by price/throughput/latency (optional)
 563              max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
 564              reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
 565              prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming).
 566                  NOTE: Anthropic Sonnet 4.6+ and Opus 4.6+ reject a trailing assistant-role prefill
 567                  (400 error).  For those models use output_config.format or structured-output
 568                  schemas instead.  Safe here for user-role priming and for older Claude / non-Claude models.
 569              max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
 570          """
 571          self.dataset_file = Path(dataset_file)
 572          self.batch_size = batch_size
 573          self.run_name = run_name
 574          self.distribution = distribution
 575          self.max_iterations = max_iterations
 576          self.base_url = base_url
 577          self.api_key = api_key
 578          self.model = model
 579          self.num_workers = num_workers
 580          self.verbose = verbose
 581          self.ephemeral_system_prompt = ephemeral_system_prompt
 582          self.log_prefix_chars = log_prefix_chars
 583          self.providers_allowed = providers_allowed
 584          self.providers_ignored = providers_ignored
 585          self.providers_order = providers_order
 586          self.provider_sort = provider_sort
 587          self.max_tokens = max_tokens
 588          self.reasoning_config = reasoning_config
 589          self.prefill_messages = prefill_messages
 590          self.max_samples = max_samples
 591          
 592          # Validate distribution
 593          if not validate_distribution(distribution):
 594              raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}")
 595          
 596          # Setup output directory
 597          self.output_dir = Path("data") / run_name
 598          self.output_dir.mkdir(parents=True, exist_ok=True)
 599          
 600          # Checkpoint file
 601          self.checkpoint_file = self.output_dir / "checkpoint.json"
 602          
 603          # Statistics file
 604          self.stats_file = self.output_dir / "statistics.json"
 605          
 606          # Load dataset (and optionally truncate to max_samples)
 607          self.dataset = self._load_dataset()
 608          if self.max_samples and self.max_samples < len(self.dataset):
 609              full_count = len(self.dataset)
 610              self.dataset = self.dataset[:self.max_samples]
 611              print(f"βœ‚οΈ  Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")
 612          
 613          # Create batches
 614          self.batches = self._create_batches()
 615          
 616          print("πŸ“Š Batch Runner Initialized")
 617          print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
 618          print(f"   Batch size: {self.batch_size}")
 619          print(f"   Total batches: {len(self.batches)}")
 620          print(f"   Run name: {self.run_name}")
 621          print(f"   Distribution: {self.distribution}")
 622          print(f"   Output directory: {self.output_dir}")
 623          print(f"   Workers: {self.num_workers}")
 624          if self.ephemeral_system_prompt:
 625              prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
 626              print(f"   πŸ”’ Ephemeral system prompt: '{prompt_preview}'")
 627      
 628      def _load_dataset(self) -> List[Dict[str, Any]]:
 629          """
 630          Load dataset from JSONL file.
 631          
 632          Returns:
 633              List[Dict]: List of dataset entries
 634          """
 635          if not self.dataset_file.exists():
 636              raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}")
 637          
 638          dataset = []
 639          with open(self.dataset_file, 'r', encoding='utf-8') as f:
 640              for line_num, line in enumerate(f, 1):
 641                  line = line.strip()
 642                  if not line:
 643                      continue
 644                  
 645                  try:
 646                      entry = json.loads(line)
 647                      if 'prompt' not in entry:
 648                          print(f"⚠️  Warning: Line {line_num} missing 'prompt' field, skipping")
 649                          continue
 650                      dataset.append(entry)
 651                  except json.JSONDecodeError as e:
 652                      print(f"⚠️  Warning: Invalid JSON on line {line_num}: {e}")
 653                      continue
 654          
 655          if not dataset:
 656              raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}")
 657          
 658          return dataset
 659      
 660      def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
 661          """
 662          Split dataset into batches with indices.
 663          
 664          Returns:
 665              List of batches, where each batch is a list of (index, entry) tuples
 666          """
 667          batches = []
 668          for i in range(0, len(self.dataset), self.batch_size):
 669              batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
 670              batches.append(batch)
 671          
 672          return batches
 673      
 674      def _load_checkpoint(self) -> Dict[str, Any]:
 675          """
 676          Load checkpoint data if it exists.
 677          
 678          Returns:
 679              Dict: Checkpoint data with completed prompt indices
 680          """
 681          if not self.checkpoint_file.exists():
 682              return {
 683                  "run_name": self.run_name,
 684                  "completed_prompts": [],
 685                  "batch_stats": {},
 686                  "last_updated": None
 687              }
 688          
 689          try:
 690              with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
 691                  return json.load(f)
 692          except Exception as e:
 693              print(f"⚠️  Warning: Failed to load checkpoint: {e}")
 694              return {
 695                  "run_name": self.run_name,
 696                  "completed_prompts": [],
 697                  "batch_stats": {},
 698                  "last_updated": None
 699              }
 700      
 701      def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
 702          """
 703          Save checkpoint data.
 704          
 705          Args:
 706              checkpoint_data (Dict): Checkpoint data to save
 707              lock (Lock): Optional lock for thread-safe access
 708          """
 709          checkpoint_data["last_updated"] = datetime.now().isoformat()
 710  
 711          from utils import atomic_json_write
 712          if lock:
 713              with lock:
 714                  atomic_json_write(self.checkpoint_file, checkpoint_data)
 715          else:
 716              atomic_json_write(self.checkpoint_file, checkpoint_data)
 717      
 718      def _scan_completed_prompts_by_content(self) -> set:
 719          """
 720          Scan all batch files and extract completed prompts by their actual content.
 721          
 722          This provides a more robust resume mechanism that matches on prompt text
 723          rather than indices, allowing recovery even if indices don't match.
 724          
 725          Returns:
 726              set: Set of prompt texts that have been successfully processed
 727          """
 728          completed_prompts = set()
 729          batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
 730          
 731          if not batch_files:
 732              return completed_prompts
 733          
 734          print(f"πŸ“‚ Scanning {len(batch_files)} batch files for completed prompts...")
 735          
 736          for batch_file in batch_files:
 737              try:
 738                  with open(batch_file, 'r', encoding='utf-8') as f:
 739                      for line in f:
 740                          try:
 741                              entry = json.loads(line.strip())
 742                              
 743                              # Skip failed entries - we want to retry these
 744                              if entry.get("failed", False):
 745                                  continue
 746                              
 747                              # Extract the human/user prompt from conversations
 748                              conversations = entry.get("conversations", [])
 749                              for msg in conversations:
 750                                  if msg.get("from") == "human":
 751                                      prompt_text = msg.get("value", "").strip()
 752                                      if prompt_text:
 753                                          completed_prompts.add(prompt_text)
 754                                      break  # Only need the first human message
 755                          except json.JSONDecodeError:
 756                              continue
 757              except Exception as e:
 758                  print(f"  ⚠️  Warning: Error reading {batch_file.name}: {e}")
 759          
 760          return completed_prompts
 761      
 762      def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
 763          """
 764          Filter the dataset to exclude prompts that have already been completed.
 765          
 766          Args:
 767              completed_prompts: Set of prompt texts that have been completed
 768              
 769          Returns:
 770              Tuple of (filtered_dataset, skipped_indices)
 771          """
 772          filtered_dataset = []
 773          skipped_indices = []
 774          
 775          for idx, entry in enumerate(self.dataset):
 776              # Extract prompt from the dataset entry
 777              prompt_text = entry.get("prompt", "").strip()
 778              
 779              # Also check conversations format
 780              if not prompt_text:
 781                  conversations = entry.get("conversations", [])
 782                  for msg in conversations:
 783                      role = msg.get("role") or msg.get("from")
 784                      if role in ("user", "human"):
 785                          prompt_text = (msg.get("content") or msg.get("value", "")).strip()
 786                          break
 787              
 788              if prompt_text in completed_prompts:
 789                  skipped_indices.append(idx)
 790              else:
 791                  # Keep original index for tracking
 792                  filtered_dataset.append((idx, entry))
 793          
 794          return filtered_dataset, skipped_indices
 795      
 796      def run(self, resume: bool = False):
 797          """
 798          Run the batch processing pipeline.
 799          
 800          Args:
 801              resume (bool): Whether to resume from checkpoint
 802          """
 803          print("\n" + "=" * 70)
 804          print("πŸš€ Starting Batch Processing")
 805          print("=" * 70)
 806          
 807          # Smart resume: scan batch files by content to find completed prompts
 808          completed_prompt_texts = set()
 809          if resume:
 810              completed_prompt_texts = self._scan_completed_prompts_by_content()
 811              if completed_prompt_texts:
 812                  print(f"   Found {len(completed_prompt_texts)} already-completed prompts by content matching")
 813          
 814          # Filter dataset to only include unprocessed prompts
 815          if resume and completed_prompt_texts:
 816              filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
 817              
 818              if not filtered_entries:
 819                  print("\nβœ… All prompts have already been processed!")
 820                  return
 821              
 822              # Recreate batches from filtered entries (keeping original indices for tracking)
 823              batches_to_process = []
 824              for i in range(0, len(filtered_entries), self.batch_size):
 825                  batch = filtered_entries[i:i + self.batch_size]
 826                  batches_to_process.append(batch)
 827              
 828              self.batches = batches_to_process
 829              
 830              # Print prominent resume summary
 831              print("\n" + "=" * 70)
 832              print("πŸ“Š RESUME SUMMARY")
 833              print("=" * 70)
 834              print(f"   Original dataset size:     {len(self.dataset):,} prompts")
 835              print(f"   Already completed:         {len(skipped_indices):,} prompts")
 836              print("   ─────────────────────────────────────────")
 837              print(f"   🎯 RESUMING WITH:          {len(filtered_entries):,} prompts")
 838              print(f"   New batches created:       {len(batches_to_process)}")
 839              print("=" * 70 + "\n")
 840          
 841          # Load existing checkpoint (so resume doesn't clobber prior progress)
 842          checkpoint_data = self._load_checkpoint()
 843          if checkpoint_data.get("run_name") != self.run_name:
 844              checkpoint_data = {
 845                  "run_name": self.run_name,
 846                  "completed_prompts": [],
 847                  "batch_stats": {},
 848                  "last_updated": None
 849              }
 850          
 851          # Prepare configuration for workers
 852          config = {
 853              "distribution": self.distribution,
 854              "model": self.model,
 855              "max_iterations": self.max_iterations,
 856              "base_url": self.base_url,
 857              "api_key": self.api_key,
 858              "verbose": self.verbose,
 859              "ephemeral_system_prompt": self.ephemeral_system_prompt,
 860              "log_prefix_chars": self.log_prefix_chars,
 861              "providers_allowed": self.providers_allowed,
 862              "providers_ignored": self.providers_ignored,
 863              "providers_order": self.providers_order,
 864              "provider_sort": self.provider_sort,
 865              "max_tokens": self.max_tokens,
 866              "reasoning_config": self.reasoning_config,
 867              "prefill_messages": self.prefill_messages,
 868          }
 869          
 870          # For backward compatibility, still track by index (but this is secondary to content matching)
 871          completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
 872          
 873          # Aggregate statistics across all batches
 874          total_tool_stats = {}
 875          
 876          start_time = time.time()
 877          
 878          print(f"\nπŸ”§ Initializing {self.num_workers} worker processes...")
 879          
 880          # Checkpoint writes happen in the parent process; keep a lock for safety.
 881          checkpoint_lock = Lock()
 882  
 883          # Process batches in parallel
 884          with Pool(processes=self.num_workers) as pool:
 885              # Create tasks for each batch
 886              tasks = [
 887                  (
 888                      batch_num,
 889                      batch_data,
 890                      str(self.output_dir),  # Convert Path to string for pickling
 891                      completed_prompts_set,
 892                      config
 893                  )
 894                  for batch_num, batch_data in enumerate(self.batches)
 895              ]
 896              
 897              print(f"βœ… Created {len(tasks)} batch tasks")
 898              print("πŸš€ Starting parallel batch processing...\n")
 899              
 900              # Use rich Progress for better visual tracking with persistent bottom bar
 901              # redirect_stdout/stderr lets rich manage all output so progress bar stays clean
 902              results = []
 903              console = Console(force_terminal=True)
 904              with Progress(
 905                  SpinnerColumn(),
 906                  TextColumn("[bold blue]πŸ“¦ Batches"),
 907                  BarColumn(bar_width=40),
 908                  MofNCompleteColumn(),
 909                  TextColumn("β€’"),
 910                  TimeRemainingColumn(),
 911                  console=console,
 912                  refresh_per_second=2,
 913                  transient=False,
 914                  redirect_stdout=False,
 915                  redirect_stderr=False,
 916              ) as progress:
 917                  task = progress.add_task("Processing", total=len(tasks))
 918                  
 919                  # Temporarily suppress DEBUG logging to avoid bar interference
 920                  root_logger = logging.getLogger()
 921                  original_level = root_logger.level
 922                  root_logger.setLevel(logging.WARNING)
 923                  
 924                  try:
 925                      for result in pool.imap_unordered(_process_batch_worker, tasks):
 926                          results.append(result)
 927                          progress.update(task, advance=1)
 928  
 929                          # Incremental checkpoint update (so resume works after crash)
 930                          try:
 931                              batch_num = result.get('batch_num')
 932                              completed = result.get('completed_prompts', []) or []
 933                              completed_prompts_set.update(completed)
 934  
 935                              if isinstance(batch_num, int):
 936                                  checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
 937                                      'processed': result.get('processed', 0),
 938                                      'skipped': result.get('skipped', 0),
 939                                      'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
 940                                  }
 941  
 942                              checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
 943                              self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
 944                          except Exception as ckpt_err:
 945                              # Don't fail the run if checkpoint write fails
 946                              print(f"⚠️  Warning: Failed to save incremental checkpoint: {ckpt_err}")
 947                  except Exception as e:
 948                      logger.error("Batch worker failed: %s", e, exc_info=True)
 949                      raise
 950                  finally:
 951                      root_logger.setLevel(original_level)
 952          
 953          # Aggregate all batch statistics and update checkpoint
 954          total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
 955  
 956          for batch_result in results:
 957              # Aggregate tool stats
 958              for tool_name, stats in batch_result.get("tool_stats", {}).items():
 959                  if tool_name not in total_tool_stats:
 960                      total_tool_stats[tool_name] = {
 961                          "count": 0,
 962                          "success": 0,
 963                          "failure": 0
 964                      }
 965                  
 966                  total_tool_stats[tool_name]["count"] += stats["count"]
 967                  total_tool_stats[tool_name]["success"] += stats["success"]
 968                  total_tool_stats[tool_name]["failure"] += stats["failure"]
 969              
 970              # Aggregate reasoning stats
 971              for key in total_reasoning_stats:
 972                  total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
 973          
 974          # Save final checkpoint (best-effort; incremental writes already happened)
 975          try:
 976              checkpoint_data["completed_prompts"] = sorted(completed_prompts_set)
 977              self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
 978          except Exception as ckpt_err:
 979              print(f"Òő ï¸  Warning: Failed to save final checkpoint: {ckpt_err}")
 980          
 981          # Calculate success rates
 982          for tool_name in total_tool_stats:
 983              stats = total_tool_stats[tool_name]
 984              total_calls = stats["success"] + stats["failure"]
 985              if total_calls > 0:
 986                  stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
 987                  stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
 988              else:
 989                  stats["success_rate"] = 0.0
 990                  stats["failure_rate"] = 0.0
 991          
 992          # Combine ALL batch files in directory into a single trajectories.jsonl file
 993          # This includes both old batches (from previous runs) and new batches (from resume)
 994          # Also filter out corrupted entries (where model generated invalid tool names)
 995          combined_file = self.output_dir / "trajectories.jsonl"
 996          print(f"\nπŸ“¦ Combining ALL batch files into {combined_file.name}...")
 997          
 998          # Valid tools auto-derived from model_tools.py β€” no manual updates needed
 999          VALID_TOOLS = ALL_POSSIBLE_TOOLS
1000          
1001          total_entries = 0
1002          filtered_entries = 0
1003          batch_files_found = 0
1004          
1005          # Find ALL batch files in the output directory (handles resume merging old + new)
1006          all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
1007          
1008          with open(combined_file, 'w', encoding='utf-8') as outfile:
1009              for batch_file in all_batch_files:
1010                  batch_files_found += 1
1011                  batch_num = batch_file.stem.split("_")[1]  # Extract batch number for logging
1012                  
1013                  with open(batch_file, 'r', encoding='utf-8') as infile:
1014                      for line in infile:
1015                          total_entries += 1
1016                          try:
1017                              data = json.loads(line)
1018                              tool_stats = data.get('tool_stats', {})
1019                              
1020                              # Check for invalid tool names (model hallucinations)
1021                              invalid_tools = [k for k in tool_stats if k not in VALID_TOOLS]
1022                              
1023                              if invalid_tools:
1024                                  filtered_entries += 1
1025                                  invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
1026                                  print(f"   ⚠️  Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
1027                                  continue
1028                              
1029                              outfile.write(line)
1030                          except json.JSONDecodeError:
1031                              filtered_entries += 1
1032                              print(f"   ⚠️  Filtering invalid JSON entry (batch {batch_num})")
1033          
1034          if filtered_entries > 0:
1035              print(f"⚠️  Filtered {filtered_entries} corrupted entries out of {total_entries} total")
1036          print(f"βœ… Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
1037          
1038          # Save final statistics
1039          final_stats = {
1040              "run_name": self.run_name,
1041              "distribution": self.distribution,
1042              "total_prompts": len(self.dataset),
1043              "total_batches": len(self.batches),
1044              "batch_size": self.batch_size,
1045              "model": self.model,
1046              "completed_at": datetime.now().isoformat(),
1047              "duration_seconds": round(time.time() - start_time, 2),
1048              "tool_statistics": total_tool_stats,
1049              "reasoning_statistics": total_reasoning_stats,
1050          }
1051          
1052          with open(self.stats_file, 'w', encoding='utf-8') as f:
1053              json.dump(final_stats, f, indent=2, ensure_ascii=False)
1054          
1055          # Print summary
1056          print("\n" + "=" * 70)
1057          print("πŸ“Š BATCH PROCESSING COMPLETE")
1058          print("=" * 70)
1059          print(f"βœ… Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
1060          print(f"βœ… Total trajectories in merged file: {total_entries - filtered_entries}")
1061          print(f"βœ… Total batch files merged: {batch_files_found}")
1062          print(f"⏱️  Total duration: {round(time.time() - start_time, 2)}s")
1063          print("\nπŸ“ˆ Tool Usage Statistics:")
1064          print("-" * 70)
1065          
1066          if total_tool_stats:
1067              # Sort by count descending
1068              sorted_tools = sorted(
1069                  total_tool_stats.items(),
1070                  key=lambda x: x[1]["count"],
1071                  reverse=True
1072              )
1073              
1074              print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
1075              print("-" * 70)
1076              for tool_name, stats in sorted_tools:
1077                  print(
1078                      f"{tool_name:<25} "
1079                      f"{stats['count']:<10} "
1080                      f"{stats['success']:<10} "
1081                      f"{stats['failure']:<10} "
1082                      f"{stats['success_rate']:.1f}%"
1083                  )
1084          else:
1085              print("No tool calls were made during this run.")
1086          
1087          # Print reasoning coverage stats
1088          total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results)
1089          
1090          print("\n🧠 Reasoning Coverage:")
1091          print("-" * 70)
1092          total_turns = total_reasoning_stats["total_assistant_turns"]
1093          with_reasoning = total_reasoning_stats["turns_with_reasoning"]
1094          without_reasoning = total_reasoning_stats["turns_without_reasoning"]
1095          if total_turns > 0:
1096              pct_with = round(with_reasoning / total_turns * 100, 1)
1097              pct_without = round(without_reasoning / total_turns * 100, 1)
1098              print(f"   Total assistant turns:    {total_turns:,}")
1099              print(f"   With reasoning:           {with_reasoning:,} ({pct_with}%)")
1100              print(f"   Without reasoning:        {without_reasoning:,} ({pct_without}%)")
1101          else:
1102              print("   No assistant turns recorded.")
1103          if total_discarded > 0:
1104              print(f"   🚫 Samples discarded (zero reasoning): {total_discarded:,}")
1105          
1106          print(f"\nπŸ’Ύ Results saved to: {self.output_dir}")
1107          print("   - Trajectories: trajectories.jsonl (combined)")
1108          print("   - Individual batches: batch_*.jsonl (for debugging)")
1109          print(f"   - Statistics: {self.stats_file.name}")
1110          print(f"   - Checkpoint: {self.checkpoint_file.name}")
1111  
1112  
def main(
    dataset_file: str = None,
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
    model: str = "anthropic/claude-sonnet-4.6",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_turns: int = 10,
    num_workers: int = 4,
    resume: bool = False,
    verbose: bool = False,
    list_distributions: bool = False,
    ephemeral_system_prompt: str = None,
    log_prefix_chars: int = 100,
    providers_allowed: str = None,
    providers_ignored: str = None,
    providers_order: str = None,
    provider_sort: str = None,
    max_tokens: int = None,
    reasoning_effort: str = None,
    reasoning_disabled: bool = False,
    prefill_messages_file: str = None,
    max_samples: int = None,
):
    """
    Run batch processing of agent prompts from a dataset.

    Args:
        dataset_file (str): Path to JSONL file with 'prompt' field in each entry
        batch_size (int): Number of prompts per batch (must be >= 1)
        run_name (str): Name for this run (used for output and checkpointing)
        distribution (str): Toolset distribution to use (default: "default")
        model (str): Model name to use (default: "anthropic/claude-sonnet-4.6")
        api_key (str): API key for model authentication
        base_url (str): Base URL for model API (default: "https://openrouter.ai/api/v1")
        max_turns (int): Maximum number of tool calling iterations per prompt (default: 10)
        num_workers (int): Number of parallel worker processes (default: 4)
        resume (bool): Resume from checkpoint if run was interrupted (default: False)
        verbose (bool): Enable verbose logging; also prints the traceback on a fatal error (default: False)
        list_distributions (bool): List available toolset distributions and exit
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
        providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
        providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
        providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
            For the three provider options a list/tuple of names is also accepted
            (python-fire parses ``--opt=a,b`` as a tuple); blank names are dropped.
        provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
        reasoning_effort (str): OpenRouter reasoning effort level: "none", "minimal", "low", "medium", "high", "xhigh".
            Ignored when reasoning_disabled is set. If unset, no reasoning config is passed
            to the runner, so the agent's own default applies.
        reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
        prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)

    Returns:
        None on success, after printing a validation error, or after listing
        distributions; 1 if creating or running the BatchRunner raises.

    Examples:
        # Basic usage
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
        
        # Resume interrupted run
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
        
        # Use specific distribution
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
        
        # With disabled reasoning and max tokens
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
                               --reasoning_disabled --max_tokens=128000
        
        # With prefill messages from file
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
                               --prefill_messages_file=configs/prefill_opus.json
        
        # List available distributions
        python batch_runner.py --list_distributions
    """
    # Handle list distributions
    if list_distributions:
        # Inside main() the name ``list_distributions`` is the boolean CLI flag,
        # which shadows the module-level function of the same name; import the
        # function under a different local name so it is actually callable.
        from toolset_distributions import (
            list_distributions as get_all_distributions,
            print_distribution_info,
        )

        print("πŸ“Š Available Toolset Distributions")
        print("=" * 70)

        all_dists = get_all_distributions()
        for dist_name in sorted(all_dists.keys()):
            print_distribution_info(dist_name)
        
        print("\nπŸ’‘ Usage:")
        print("  python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
        print("                         --run_name=my_run --distribution=<name>")
        return
    
    # Validate required arguments
    if not dataset_file:
        print("❌ Error: --dataset_file is required")
        return
    
    if not batch_size or batch_size < 1:
        print("❌ Error: --batch_size must be a positive integer")
        return
    
    if not run_name:
        print("❌ Error: --run_name is required")
        return

    def _parse_provider_list(value):
        """Normalize a provider option to a list of names, or None if empty."""
        if not value:
            return None
        if isinstance(value, (list, tuple)):
            # python-fire turns "--opt=a,b" into the tuple ('a', 'b').
            parts = [str(p).strip() for p in value]
        else:
            parts = [p.strip() for p in str(value).split(",")]
        parts = [p for p in parts if p]
        return parts or None

    # Parse provider preferences (comma-separated strings to lists)
    providers_allowed_list = _parse_provider_list(providers_allowed)
    providers_ignored_list = _parse_provider_list(providers_ignored)
    providers_order_list = _parse_provider_list(providers_order)
    
    # Build reasoning_config from CLI flags.
    # --reasoning_disabled takes priority, then --reasoning_effort; when neither
    # is given, reasoning_config stays None and the downstream default applies.
    reasoning_config = None
    if reasoning_disabled:
        # Completely disable reasoning/thinking tokens
        reasoning_config = {"effort": "none"}
        print("🧠 Reasoning: DISABLED (effort=none)")
    elif reasoning_effort:
        # Use specified effort level
        valid_efforts = ["none", "minimal", "low", "medium", "high", "xhigh"]
        if reasoning_effort not in valid_efforts:
            print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
            return
        reasoning_config = {"enabled": True, "effort": reasoning_effort}
        print(f"🧠 Reasoning effort: {reasoning_effort}")
    
    # Load prefill messages from JSON file if provided
    prefill_messages = None
    if prefill_messages_file:
        try:
            with open(prefill_messages_file, 'r', encoding='utf-8') as f:
                prefill_messages = json.load(f)
            if not isinstance(prefill_messages, list):
                print("❌ Error: prefill_messages_file must contain a JSON array of messages")
                return
            print(f"πŸ’¬ Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")
        except Exception as e:
            print(f"❌ Error loading prefill messages: {e}")
            return
    
    # Initialize and run batch runner
    try:
        runner = BatchRunner(
            dataset_file=dataset_file,
            batch_size=batch_size,
            run_name=run_name,
            distribution=distribution,
            max_iterations=max_turns,
            base_url=base_url,
            api_key=api_key,
            model=model,
            num_workers=num_workers,
            verbose=verbose,
            ephemeral_system_prompt=ephemeral_system_prompt,
            log_prefix_chars=log_prefix_chars,
            providers_allowed=providers_allowed_list,
            providers_ignored=providers_ignored_list,
            providers_order=providers_order_list,
            provider_sort=provider_sort,
            max_tokens=max_tokens,
            reasoning_config=reasoning_config,
            prefill_messages=prefill_messages,
            max_samples=max_samples,
        )

        runner.run(resume=resume)
    
    except Exception as e:
        print(f"\n❌ Fatal error: {e}")
        if verbose:
            traceback.print_exc()
        return 1
1283  
1284  
if __name__ == "__main__":
    import sys

    # fire prints main()'s return value but always exits with status 0; forward
    # that value (None on success, 1 on a fatal error) as the process exit status.
    sys.exit(fire.Fire(main))
1287