# trajectory_compressor.py
   1  #!/usr/bin/env python3
   2  """
   3  Trajectory Compressor
   4  
   5  Post-processes completed agent trajectories to compress them within a target
   6  token budget while preserving training signal quality.
   7  
   8  Compression Strategy:
   9  1. Protect first turns (system, human, first gpt, first tool)
  10  2. Protect last N turns (final actions and conclusions)
  11  3. Compress MIDDLE turns only, starting from 2nd tool response
  12  4. Compress only as much as needed to fit under target
  13  5. Replace compressed region with a single human summary message
  14  6. Keep remaining tool calls intact (model continues working after summary)
  15  
  16  Usage:
  17      # Compress a directory of JSONL files
  18      python trajectory_compressor.py --input=data/my_run
  19      
  20      # Compress a single JSONL file
  21      python trajectory_compressor.py --input=data/trajectories.jsonl
  22      
  23      # Compress 15% sample of a file
  24      python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15
  25      
  26      # Compress with custom output and token target
  27      python trajectory_compressor.py --input=data/trajectories.jsonl --output=compressed.jsonl --target_max_tokens=16000
  28      
  29      # Compress 10% sample from a directory
  30      python trajectory_compressor.py --input=data/my_run --sample_percent=10
  31  """
  32  
  33  import json
  34  import os
  35  import time
  36  import yaml
  37  import logging
  38  import asyncio
  39  from pathlib import Path
  40  from typing import List, Dict, Any, Optional, Tuple
  41  from dataclasses import dataclass, field
  42  from datetime import datetime
  43  
  44  from utils import base_url_host_matches, base_url_hostname
  45  import fire
  46  from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn
  47  from rich.console import Console
  48  from hermes_constants import OPENROUTER_BASE_URL, get_hermes_home
  49  from agent.retry_utils import jittered_backoff
  50  
# Load .env from HERMES_HOME first, then project root as a dev fallback.
from hermes_cli.env_loader import load_hermes_dotenv

# Runs at import time so API keys are present in os.environ before any
# client is constructed (e.g. _init_summarizer's os.getenv lookup).
_hermes_home = get_hermes_home()
_project_env = Path(__file__).parent / ".env"
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
  57  
  58  
  59  def _effective_temperature_for_model(
  60      model: str,
  61      requested_temperature: float,
  62      base_url: Optional[str] = None,
  63  ) -> Optional[float]:
  64      """Apply fixed model temperature contracts to direct client calls.
  65  
  66      Returns ``None`` when the model manages temperature server-side (Kimi);
  67      callers must omit the ``temperature`` kwarg entirely in that case.
  68      """
  69      try:
  70          from agent.auxiliary_client import _fixed_temperature_for_model, OMIT_TEMPERATURE
  71      except Exception:
  72          return requested_temperature
  73  
  74      fixed_temperature = _fixed_temperature_for_model(model, base_url)
  75      if fixed_temperature is OMIT_TEMPERATURE:
  76          return None  # caller must omit temperature
  77      if fixed_temperature is not None:
  78          return fixed_temperature
  79      return requested_temperature
  80  
  81  
  82  @dataclass
  83  class CompressionConfig:
  84      """Configuration for trajectory compression."""
  85      # Tokenizer
  86      tokenizer_name: str = "moonshotai/Kimi-K2-Thinking"
  87      trust_remote_code: bool = True
  88      
  89      # Compression targets
  90      target_max_tokens: int = 15250
  91      summary_target_tokens: int = 750
  92      
  93      # Protected turns
  94      protect_first_system: bool = True
  95      protect_first_human: bool = True
  96      protect_first_gpt: bool = True
  97      protect_first_tool: bool = True
  98      protect_last_n_turns: int = 4
  99      
 100      # Summarization (OpenRouter)
 101      summarization_model: str = "google/gemini-3-flash-preview"
 102      base_url: str = OPENROUTER_BASE_URL
 103      api_key_env: str = "OPENROUTER_API_KEY"
 104      temperature: float = 0.3
 105      max_retries: int = 3
 106      retry_delay: int = 2
 107      
 108      # Output
 109      add_summary_notice: bool = True
 110      summary_notice_text: str = "\n\nSome of your previous tool responses may be summarized to preserve context."
 111      output_suffix: str = "_compressed"
 112      
 113      # Processing
 114      num_workers: int = 4
 115      max_concurrent_requests: int = 50  # Max concurrent API calls for summarization
 116      skip_under_target: bool = True
 117      save_over_limit: bool = True
 118      per_trajectory_timeout: int = 300  # Timeout per trajectory in seconds (default: 5 min)
 119      
 120      # Metrics
 121      metrics_enabled: bool = True
 122      metrics_per_trajectory: bool = True
 123      metrics_output_file: str = "compression_metrics.json"
 124      
 125      @classmethod
 126      def from_yaml(cls, yaml_path: str) -> "CompressionConfig":
 127          """Load configuration from YAML file."""
 128          with open(yaml_path, 'r') as f:
 129              data = yaml.safe_load(f)
 130          
 131          config = cls()
 132          
 133          # Tokenizer
 134          if 'tokenizer' in data:
 135              config.tokenizer_name = data['tokenizer'].get('name', config.tokenizer_name)
 136              config.trust_remote_code = data['tokenizer'].get('trust_remote_code', config.trust_remote_code)
 137          
 138          # Compression
 139          if 'compression' in data:
 140              config.target_max_tokens = data['compression'].get('target_max_tokens', config.target_max_tokens)
 141              config.summary_target_tokens = data['compression'].get('summary_target_tokens', config.summary_target_tokens)
 142          
 143          # Protected turns
 144          if 'protected_turns' in data:
 145              config.protect_first_system = data['protected_turns'].get('first_system', config.protect_first_system)
 146              config.protect_first_human = data['protected_turns'].get('first_human', config.protect_first_human)
 147              config.protect_first_gpt = data['protected_turns'].get('first_gpt', config.protect_first_gpt)
 148              config.protect_first_tool = data['protected_turns'].get('first_tool', config.protect_first_tool)
 149              config.protect_last_n_turns = data['protected_turns'].get('last_n_turns', config.protect_last_n_turns)
 150          
 151          # Summarization
 152          if 'summarization' in data:
 153              config.summarization_model = data['summarization'].get('model', config.summarization_model)
 154              config.base_url = data['summarization'].get('base_url') or config.base_url
 155              config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env)
 156              config.temperature = data['summarization'].get('temperature', config.temperature)
 157              config.max_retries = data['summarization'].get('max_retries', config.max_retries)
 158              config.retry_delay = data['summarization'].get('retry_delay', config.retry_delay)
 159          
 160          # Output
 161          if 'output' in data:
 162              config.add_summary_notice = data['output'].get('add_summary_notice', config.add_summary_notice)
 163              config.summary_notice_text = data['output'].get('summary_notice_text', config.summary_notice_text)
 164              config.output_suffix = data['output'].get('output_suffix', config.output_suffix)
 165          
 166          # Processing
 167          if 'processing' in data:
 168              config.num_workers = data['processing'].get('num_workers', config.num_workers)
 169              config.max_concurrent_requests = data['processing'].get('max_concurrent_requests', config.max_concurrent_requests)
 170              config.skip_under_target = data['processing'].get('skip_under_target', config.skip_under_target)
 171              config.save_over_limit = data['processing'].get('save_over_limit', config.save_over_limit)
 172          
 173          # Metrics
 174          if 'metrics' in data:
 175              config.metrics_enabled = data['metrics'].get('enabled', config.metrics_enabled)
 176              config.metrics_per_trajectory = data['metrics'].get('per_trajectory', config.metrics_per_trajectory)
 177              config.metrics_output_file = data['metrics'].get('output_file', config.metrics_output_file)
 178          
 179          return config
 180  
 181  
 182  @dataclass
 183  class TrajectoryMetrics:
 184      """Metrics for a single trajectory compression."""
 185      original_tokens: int = 0
 186      compressed_tokens: int = 0
 187      tokens_saved: int = 0
 188      compression_ratio: float = 1.0
 189      
 190      original_turns: int = 0
 191      compressed_turns: int = 0
 192      turns_removed: int = 0
 193      
 194      turns_compressed_start_idx: int = -1
 195      turns_compressed_end_idx: int = -1
 196      turns_in_compressed_region: int = 0
 197      
 198      was_compressed: bool = False
 199      still_over_limit: bool = False
 200      skipped_under_target: bool = False
 201      
 202      summarization_api_calls: int = 0
 203      summarization_errors: int = 0
 204      
 205      def to_dict(self) -> Dict[str, Any]:
 206          return {
 207              "original_tokens": self.original_tokens,
 208              "compressed_tokens": self.compressed_tokens,
 209              "tokens_saved": self.tokens_saved,
 210              "compression_ratio": round(self.compression_ratio, 4),
 211              "original_turns": self.original_turns,
 212              "compressed_turns": self.compressed_turns,
 213              "turns_removed": self.turns_removed,
 214              "compression_region": {
 215                  "start_idx": self.turns_compressed_start_idx,
 216                  "end_idx": self.turns_compressed_end_idx,
 217                  "turns_count": self.turns_in_compressed_region,
 218              },
 219              "was_compressed": self.was_compressed,
 220              "still_over_limit": self.still_over_limit,
 221              "skipped_under_target": self.skipped_under_target,
 222              "summarization_api_calls": self.summarization_api_calls,
 223              "summarization_errors": self.summarization_errors,
 224          }
 225  
 226  
 227  @dataclass 
 228  class AggregateMetrics:
 229      """Aggregate metrics across all trajectories."""
 230      total_trajectories: int = 0
 231      trajectories_compressed: int = 0
 232      trajectories_skipped_under_target: int = 0
 233      trajectories_still_over_limit: int = 0
 234      trajectories_failed: int = 0
 235      
 236      total_tokens_before: int = 0
 237      total_tokens_after: int = 0
 238      total_tokens_saved: int = 0
 239      
 240      total_turns_before: int = 0
 241      total_turns_after: int = 0
 242      total_turns_removed: int = 0
 243      
 244      total_summarization_calls: int = 0
 245      total_summarization_errors: int = 0
 246      
 247      # Distribution stats
 248      compression_ratios: List[float] = field(default_factory=list)
 249      tokens_saved_list: List[int] = field(default_factory=list)
 250      turns_removed_list: List[int] = field(default_factory=list)
 251      
 252      processing_start_time: str = ""
 253      processing_end_time: str = ""
 254      processing_duration_seconds: float = 0.0
 255      
 256      def add_trajectory_metrics(self, metrics: TrajectoryMetrics):
 257          """Add a trajectory's metrics to the aggregate."""
 258          self.total_trajectories += 1
 259          self.total_tokens_before += metrics.original_tokens
 260          self.total_tokens_after += metrics.compressed_tokens
 261          self.total_tokens_saved += metrics.tokens_saved
 262          self.total_turns_before += metrics.original_turns
 263          self.total_turns_after += metrics.compressed_turns
 264          self.total_turns_removed += metrics.turns_removed
 265          self.total_summarization_calls += metrics.summarization_api_calls
 266          self.total_summarization_errors += metrics.summarization_errors
 267          
 268          if metrics.was_compressed:
 269              self.trajectories_compressed += 1
 270              self.compression_ratios.append(metrics.compression_ratio)
 271              self.tokens_saved_list.append(metrics.tokens_saved)
 272              self.turns_removed_list.append(metrics.turns_removed)
 273          
 274          if metrics.skipped_under_target:
 275              self.trajectories_skipped_under_target += 1
 276          
 277          if metrics.still_over_limit:
 278              self.trajectories_still_over_limit += 1
 279      
 280      def to_dict(self) -> Dict[str, Any]:
 281          avg_compression_ratio = (
 282              sum(self.compression_ratios) / len(self.compression_ratios) 
 283              if self.compression_ratios else 1.0
 284          )
 285          avg_tokens_saved = (
 286              sum(self.tokens_saved_list) / len(self.tokens_saved_list)
 287              if self.tokens_saved_list else 0
 288          )
 289          avg_turns_removed = (
 290              sum(self.turns_removed_list) / len(self.turns_removed_list)
 291              if self.turns_removed_list else 0
 292          )
 293          
 294          return {
 295              "summary": {
 296                  "total_trajectories": self.total_trajectories,
 297                  "trajectories_compressed": self.trajectories_compressed,
 298                  "trajectories_skipped_under_target": self.trajectories_skipped_under_target,
 299                  "trajectories_still_over_limit": self.trajectories_still_over_limit,
 300                  "trajectories_failed": self.trajectories_failed,
 301                  "compression_rate": round(self.trajectories_compressed / max(self.total_trajectories, 1), 4),
 302              },
 303              "tokens": {
 304                  "total_before": self.total_tokens_before,
 305                  "total_after": self.total_tokens_after,
 306                  "total_saved": self.total_tokens_saved,
 307                  "overall_compression_ratio": round(self.total_tokens_after / max(self.total_tokens_before, 1), 4),
 308              },
 309              "turns": {
 310                  "total_before": self.total_turns_before,
 311                  "total_after": self.total_turns_after,
 312                  "total_removed": self.total_turns_removed,
 313              },
 314              "averages": {
 315                  "avg_compression_ratio": round(avg_compression_ratio, 4),
 316                  "avg_tokens_saved_per_compressed": round(avg_tokens_saved, 1),
 317                  "avg_turns_removed_per_compressed": round(avg_turns_removed, 2),
 318              },
 319              "summarization": {
 320                  "total_api_calls": self.total_summarization_calls,
 321                  "total_errors": self.total_summarization_errors,
 322                  "success_rate": round(1 - (self.total_summarization_errors / max(self.total_summarization_calls, 1)), 4),
 323              },
 324              "processing": {
 325                  "start_time": self.processing_start_time,
 326                  "end_time": self.processing_end_time,
 327                  "duration_seconds": round(self.processing_duration_seconds, 2),
 328              },
 329          }
 330  
 331  
 332  class TrajectoryCompressor:
 333      """
 334      Compresses agent trajectories to fit within a target token budget.
 335      
 336      Compression strategy:
 337      1. Keep protected head turns (system, human, first gpt+tool)
 338      2. Keep protected tail turns (last N turns)
 339      3. From the compressible middle region, compress only as much as needed
 340      4. Replace compressed turns with a single human summary message
 341      5. Keep remaining middle turns intact (model continues with tools)
 342      """
 343      
    def __init__(self, config: CompressionConfig):
        """Initialize the compressor.

        Args:
            config: Compression settings (tokenizer, targets, summarizer).

        Raises:
            RuntimeError: if the tokenizer or the summarizer client cannot
                be initialized.
        """
        self.config = config
        self.aggregate_metrics = AggregateMetrics()
        
        # Initialize tokenizer (loads the HF tokenizer; may raise)
        self._init_tokenizer()
        
        # Initialize the summarization LLM client (provider-routed or raw)
        self._init_summarizer()
        
        # Root-logger setup happens after client init; _generate_summary*
        # relies on self.logger for retry warnings.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%H:%M:%S'
        )
        self.logger = logging.getLogger(__name__)
 361      
 362      def _init_tokenizer(self):
 363          """Initialize HuggingFace tokenizer for token counting."""
 364          try:
 365              from transformers import AutoTokenizer
 366              self.tokenizer = AutoTokenizer.from_pretrained(
 367                  self.config.tokenizer_name,
 368                  trust_remote_code=self.config.trust_remote_code
 369              )
 370              print(f"āœ… Loaded tokenizer: {self.config.tokenizer_name}")
 371          except Exception as e:
 372              raise RuntimeError(f"Failed to load tokenizer '{self.config.tokenizer_name}': {e}")
 373      
 374      def _init_summarizer(self):
 375          """Initialize LLM routing for summarization (sync and async).
 376  
 377          Uses call_llm/async_call_llm from the centralized provider router
 378          which handles auth, headers, and provider detection internally.
 379          For custom endpoints, falls back to raw client construction.
 380          """
 381  
 382          provider = self._detect_provider()
 383          if provider:
 384              # Store provider for use in _generate_summary calls
 385              self._llm_provider = provider
 386              self._use_call_llm = True
 387              # Verify the provider is available
 388              from agent.auxiliary_client import resolve_provider_client
 389              client, _ = resolve_provider_client(
 390                  provider, model=self.config.summarization_model)
 391              if client is None:
 392                  raise RuntimeError(
 393                      f"Provider '{provider}' is not configured. "
 394                      f"Check your API key or run: hermes setup")
 395              self.client = None  # Not used directly
 396              self.async_client = None  # Not used directly
 397          else:
 398              # Custom endpoint — use config's raw base_url + api_key_env
 399              self._use_call_llm = False
 400              api_key = os.getenv(self.config.api_key_env)
 401              if not api_key:
 402                  raise RuntimeError(
 403                      f"Missing API key. Set {self.config.api_key_env} "
 404                      f"environment variable.")
 405              from openai import OpenAI
 406              from agent.auxiliary_client import _to_openai_base_url
 407              self.client = OpenAI(
 408                  api_key=api_key, base_url=_to_openai_base_url(self.config.base_url))
 409              # AsyncOpenAI is created lazily in _get_async_client() so it
 410              # binds to the current event loop — avoids "Event loop is closed"
 411              # when process_directory() is called multiple times (each call
 412              # creates a new loop via asyncio.run()).
 413              self.async_client = None
 414              self._async_client_api_key = api_key
 415  
 416          print(f"āœ… Initialized summarizer client: {self.config.summarization_model}")
 417          print(f"   Max concurrent requests: {self.config.max_concurrent_requests}")
 418  
 419      def _get_async_client(self):
 420          """Return an AsyncOpenAI client bound to the current event loop.
 421  
 422          Created lazily so that each ``asyncio.run()`` call in
 423          ``process_directory()`` gets a client tied to its own loop,
 424          avoiding "Event loop is closed" errors on repeated calls.
 425          """
 426          from openai import AsyncOpenAI
 427          from agent.auxiliary_client import _to_openai_base_url
 428          # Always create a fresh client so it binds to the running loop.
 429          self.async_client = AsyncOpenAI(
 430              api_key=self._async_client_api_key,
 431              base_url=_to_openai_base_url(self.config.base_url),
 432          )
 433          return self.async_client
 434  
 435      def _detect_provider(self) -> str:
 436          """Detect the provider name from the configured base_url."""
 437          url = self.config.base_url or ""
 438          if base_url_host_matches(url, "openrouter.ai"):
 439              return "openrouter"
 440          if base_url_host_matches(url, "nousresearch.com"):
 441              return "nous"
 442          if (
 443              base_url_hostname(url) == "chatgpt.com"
 444              and "/backend-api/codex" in url.lower()
 445          ):
 446              return "codex"
 447          if base_url_host_matches(url, "z.ai"):
 448              return "zai"
 449          if (
 450              base_url_host_matches(url, "moonshot.ai")
 451              or base_url_host_matches(url, "moonshot.cn")
 452              or base_url_host_matches(url, "api.kimi.com")
 453          ):
 454              return "kimi-coding"
 455          if base_url_host_matches(url, "arcee.ai"):
 456              return "arcee"
 457          if base_url_host_matches(url, "minimaxi.com"):
 458              return "minimax-cn"
 459          if base_url_host_matches(url, "minimax.io"):
 460              return "minimax"
 461          # Unknown base_url — not a known provider
 462          return ""
 463      
 464      def count_tokens(self, text: str) -> int:
 465          """Count tokens in text using the configured tokenizer."""
 466          if not text:
 467              return 0
 468          try:
 469              return len(self.tokenizer.encode(text))
 470          except Exception:
 471              # Fallback to character estimate
 472              return len(text) // 4
 473      
 474      def count_trajectory_tokens(self, trajectory: List[Dict[str, str]]) -> int:
 475          """Count total tokens in a trajectory."""
 476          return sum(self.count_tokens(turn.get("value", "")) for turn in trajectory)
 477      
 478      def count_turn_tokens(self, trajectory: List[Dict[str, str]]) -> List[int]:
 479          """Count tokens for each turn in a trajectory."""
 480          return [self.count_tokens(turn.get("value", "")) for turn in trajectory]
 481      
 482      def _find_protected_indices(self, trajectory: List[Dict[str, str]]) -> Tuple[set, int, int]:
 483          """
 484          Find indices of protected turns.
 485          
 486          Returns:
 487              Tuple of (protected_set, compressible_start, compressible_end)
 488          """
 489          n = len(trajectory)
 490          protected = set()
 491          
 492          # Track first occurrences
 493          first_system = first_human = first_gpt = first_tool = None
 494          
 495          for i, turn in enumerate(trajectory):
 496              role = turn.get("from", "")
 497              if role == "system" and first_system is None:
 498                  first_system = i
 499              elif role == "human" and first_human is None:
 500                  first_human = i
 501              elif role == "gpt" and first_gpt is None:
 502                  first_gpt = i
 503              elif role == "tool" and first_tool is None:
 504                  first_tool = i
 505          
 506          # Protect first turns
 507          if self.config.protect_first_system and first_system is not None:
 508              protected.add(first_system)
 509          if self.config.protect_first_human and first_human is not None:
 510              protected.add(first_human)
 511          if self.config.protect_first_gpt and first_gpt is not None:
 512              protected.add(first_gpt)
 513          if self.config.protect_first_tool and first_tool is not None:
 514              protected.add(first_tool)
 515          
 516          # Protect last N turns
 517          for i in range(max(0, n - self.config.protect_last_n_turns), n):
 518              protected.add(i)
 519          
 520          # Determine compressible region
 521          # Start after the last protected head turn
 522          head_protected = [i for i in protected if i < n // 2]
 523          tail_protected = [i for i in protected if i >= n // 2]
 524          
 525          compressible_start = max(head_protected) + 1 if head_protected else 0
 526          compressible_end = min(tail_protected) if tail_protected else n
 527          
 528          return protected, compressible_start, compressible_end
 529      
 530      def _extract_turn_content_for_summary(self, trajectory: List[Dict[str, str]], start: int, end: int) -> str:
 531          """
 532          Extract content from turns to be summarized.
 533          
 534          Args:
 535              trajectory: Full trajectory
 536              start: Start index (inclusive)
 537              end: End index (exclusive)
 538              
 539          Returns:
 540              Formatted string of turn contents for summarization
 541          """
 542          parts = []
 543          for i in range(start, end):
 544              turn = trajectory[i]
 545              role = turn.get("from", "unknown")
 546              value = turn.get("value", "")
 547              
 548              # Truncate very long values for the summary prompt
 549              if len(value) > 3000:
 550                  value = value[:1500] + "\n...[truncated]...\n" + value[-500:]
 551              
 552              parts.append(f"[Turn {i} - {role.upper()}]:\n{value}")
 553          
 554          return "\n\n".join(parts)
 555  
 556      @staticmethod
 557      def _coerce_summary_content(content: Any) -> str:
 558          """Normalize summary-model output to a safe string."""
 559          if not isinstance(content, str):
 560              content = str(content) if content else ""
 561          return content.strip()
 562  
 563      @staticmethod
 564      def _ensure_summary_prefix(summary: str) -> str:
 565          """Normalize summary text to include the expected prefix exactly once."""
 566          text = (summary or "").strip()
 567          if text.startswith("[CONTEXT SUMMARY]:"):
 568              return text
 569          return "[CONTEXT SUMMARY]:" if not text else f"[CONTEXT SUMMARY]: {text}"
 570      
 571      def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str:
 572          """
 573          Generate a summary of the compressed turns using OpenRouter.
 574          
 575          Args:
 576              content: The content to summarize
 577              metrics: Metrics object to update
 578              
 579          Returns:
 580              Summary string
 581          """
 582          prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history.
 583  
 584  Write the summary from a neutral perspective describing what the assistant did and learned. Include:
 585  1. What actions the assistant took (tool calls, searches, file operations)
 586  2. Key information or results obtained
 587  3. Any important decisions or findings
 588  4. Relevant data, file names, values, or outputs
 589  
 590  Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens.
 591  
 592  ---
 593  TURNS TO SUMMARIZE:
 594  {content}
 595  ---
 596  
 597  Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
 598  
 599          for attempt in range(self.config.max_retries):
 600              try:
 601                  metrics.summarization_api_calls += 1
 602                  summary_temperature = _effective_temperature_for_model(
 603                      self.config.summarization_model,
 604                      self.config.temperature,
 605                      self.config.base_url,
 606                  )
 607                  
 608                  if getattr(self, '_use_call_llm', False):
 609                      from agent.auxiliary_client import call_llm
 610                      response = call_llm(
 611                          provider=self._llm_provider,
 612                          model=self.config.summarization_model,
 613                          messages=[{"role": "user", "content": prompt}],
 614                          temperature=summary_temperature,
 615                          max_tokens=self.config.summary_target_tokens * 2,
 616                      )
 617                  else:
 618                      _create_kwargs = {
 619                          "model": self.config.summarization_model,
 620                          "messages": [{"role": "user", "content": prompt}],
 621                          "max_tokens": self.config.summary_target_tokens * 2,
 622                      }
 623                      if summary_temperature is not None:
 624                          _create_kwargs["temperature"] = summary_temperature
 625                      response = self.client.chat.completions.create(**_create_kwargs)
 626                  
 627                  summary = self._coerce_summary_content(response.choices[0].message.content)
 628                  return self._ensure_summary_prefix(summary)
 629                  
 630              except Exception as e:
 631                  metrics.summarization_errors += 1
 632                  self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
 633                  
 634                  if attempt < self.config.max_retries - 1:
 635                      time.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
 636                  else:
 637                      # Fallback: create a basic summary
 638                      return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
 639      
 640      async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str:
 641          """
 642          Generate a summary of the compressed turns using OpenRouter (async version).
 643          
 644          Args:
 645              content: The content to summarize
 646              metrics: Metrics object to update
 647              
 648          Returns:
 649              Summary string
 650          """
 651          prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history.
 652  
 653  Write the summary from a neutral perspective describing what the assistant did and learned. Include:
 654  1. What actions the assistant took (tool calls, searches, file operations)
 655  2. Key information or results obtained
 656  3. Any important decisions or findings
 657  4. Relevant data, file names, values, or outputs
 658  
 659  Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens.
 660  
 661  ---
 662  TURNS TO SUMMARIZE:
 663  {content}
 664  ---
 665  
 666  Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
 667  
 668          for attempt in range(self.config.max_retries):
 669              try:
 670                  metrics.summarization_api_calls += 1
 671                  summary_temperature = _effective_temperature_for_model(
 672                      self.config.summarization_model,
 673                      self.config.temperature,
 674                      self.config.base_url,
 675                  )
 676                  
 677                  if getattr(self, '_use_call_llm', False):
 678                      from agent.auxiliary_client import async_call_llm
 679                      response = await async_call_llm(
 680                          provider=self._llm_provider,
 681                          model=self.config.summarization_model,
 682                          messages=[{"role": "user", "content": prompt}],
 683                          temperature=summary_temperature,
 684                          max_tokens=self.config.summary_target_tokens * 2,
 685                      )
 686                  else:
 687                      _create_kwargs = {
 688                          "model": self.config.summarization_model,
 689                          "messages": [{"role": "user", "content": prompt}],
 690                          "max_tokens": self.config.summary_target_tokens * 2,
 691                      }
 692                      if summary_temperature is not None:
 693                          _create_kwargs["temperature"] = summary_temperature
 694                      response = await self._get_async_client().chat.completions.create(**_create_kwargs)
 695                  
 696                  summary = self._coerce_summary_content(response.choices[0].message.content)
 697                  return self._ensure_summary_prefix(summary)
 698                  
 699              except Exception as e:
 700                  metrics.summarization_errors += 1
 701                  self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
 702                  
 703                  if attempt < self.config.max_retries - 1:
 704                      await asyncio.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
 705                  else:
 706                      # Fallback: create a basic summary
 707                      return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
 708      
 709      def compress_trajectory(
 710          self,
 711          trajectory: List[Dict[str, str]]
 712      ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]:
 713          """
 714          Compress a single trajectory to fit within target token budget.
 715          
 716          Algorithm:
 717          1. Count total tokens
 718          2. If under target, skip
 719          3. Find compressible region (between protected head and tail)
 720          4. Calculate how many tokens need to be saved
 721          5. Accumulate turns from start of compressible region until savings met
 722          6. Replace accumulated turns with single human summary
 723          7. Keep remaining turns intact
 724          
 725          Args:
 726              trajectory: List of conversation turns
 727              
 728          Returns:
 729              Tuple of (compressed_trajectory, metrics)
 730          """
 731          metrics = TrajectoryMetrics()
 732          metrics.original_turns = len(trajectory)
 733          
 734          # Count tokens per turn
 735          turn_tokens = self.count_turn_tokens(trajectory)
 736          total_tokens = sum(turn_tokens)
 737          metrics.original_tokens = total_tokens
 738          
 739          # Check if compression needed
 740          if total_tokens <= self.config.target_max_tokens:
 741              metrics.skipped_under_target = True
 742              metrics.compressed_tokens = total_tokens
 743              metrics.compressed_turns = len(trajectory)
 744              metrics.compression_ratio = 1.0
 745              return trajectory, metrics
 746          
 747          # Find protected regions
 748          protected, compress_start, compress_end = self._find_protected_indices(trajectory)
 749          
 750          # Check if there's anything to compress
 751          if compress_start >= compress_end:
 752              # Nothing to compress, return as-is
 753              metrics.compressed_tokens = total_tokens
 754              metrics.compressed_turns = len(trajectory)
 755              metrics.still_over_limit = total_tokens > self.config.target_max_tokens
 756              return trajectory, metrics
 757          
 758          # Calculate how much we need to save
 759          tokens_to_save = total_tokens - self.config.target_max_tokens
 760          
 761          # We'll replace N turns with 1 summary turn
 762          # Net savings = (sum of N turns' tokens) - summary_target_tokens
 763          # We need: net_savings >= tokens_to_save
 764          # So: sum of turns >= tokens_to_save + summary_target_tokens
 765          target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens
 766          
 767          # Accumulate turns from compress_start until we have enough savings
 768          accumulated_tokens = 0
 769          compress_until = compress_start
 770          
 771          for i in range(compress_start, compress_end):
 772              accumulated_tokens += turn_tokens[i]
 773              compress_until = i + 1  # Exclusive end
 774              
 775              # Check if we have enough savings
 776              if accumulated_tokens >= target_tokens_to_compress:
 777                  break
 778          
 779          # If we still don't have enough savings, compress the entire compressible region
 780          if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end:
 781              compress_until = compress_end
 782              accumulated_tokens = sum(turn_tokens[compress_start:compress_end])
 783          
 784          # Record compression region
 785          metrics.turns_compressed_start_idx = compress_start
 786          metrics.turns_compressed_end_idx = compress_until
 787          metrics.turns_in_compressed_region = compress_until - compress_start
 788          
 789          # Extract content for summary
 790          content_to_summarize = self._extract_turn_content_for_summary(
 791              trajectory, compress_start, compress_until
 792          )
 793          
 794          # Generate summary
 795          summary = self._generate_summary(content_to_summarize, metrics)
 796          
 797          # Build compressed trajectory
 798          compressed = []
 799          
 800          # Add head (turns before compression region)
 801          for i in range(compress_start):
 802              turn = trajectory[i].copy()
 803              # Add notice to system message
 804              if turn.get("from") == "system" and self.config.add_summary_notice:
 805                  turn["value"] = turn["value"] + self.config.summary_notice_text
 806              compressed.append(turn)
 807          
 808          # Add summary as human message
 809          compressed.append({
 810              "from": "human",
 811              "value": summary
 812          })
 813          
 814          # Add tail (turns after compression region)
 815          for i in range(compress_until, len(trajectory)):
 816              compressed.append(trajectory[i].copy())
 817          
 818          # Calculate final metrics
 819          metrics.compressed_turns = len(compressed)
 820          metrics.compressed_tokens = self.count_trajectory_tokens(compressed)
 821          metrics.turns_removed = metrics.original_turns - metrics.compressed_turns
 822          metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens
 823          metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1)
 824          metrics.was_compressed = True
 825          metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens
 826          
 827          return compressed, metrics
 828      
 829      async def compress_trajectory_async(
 830          self,
 831          trajectory: List[Dict[str, str]]
 832      ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]:
 833          """
 834          Compress a single trajectory to fit within target token budget (async version).
 835          
 836          Same algorithm as compress_trajectory but uses async API calls for summarization.
 837          """
 838          metrics = TrajectoryMetrics()
 839          metrics.original_turns = len(trajectory)
 840          
 841          # Count tokens per turn
 842          turn_tokens = self.count_turn_tokens(trajectory)
 843          total_tokens = sum(turn_tokens)
 844          metrics.original_tokens = total_tokens
 845          
 846          # Check if compression needed
 847          if total_tokens <= self.config.target_max_tokens:
 848              metrics.skipped_under_target = True
 849              metrics.compressed_tokens = total_tokens
 850              metrics.compressed_turns = len(trajectory)
 851              metrics.compression_ratio = 1.0
 852              return trajectory, metrics
 853          
 854          # Find protected regions
 855          protected, compress_start, compress_end = self._find_protected_indices(trajectory)
 856          
 857          # Check if there's anything to compress
 858          if compress_start >= compress_end:
 859              metrics.compressed_tokens = total_tokens
 860              metrics.compressed_turns = len(trajectory)
 861              metrics.still_over_limit = total_tokens > self.config.target_max_tokens
 862              return trajectory, metrics
 863          
 864          # Calculate how much we need to save
 865          tokens_to_save = total_tokens - self.config.target_max_tokens
 866          target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens
 867          
 868          # Accumulate turns from compress_start until we have enough savings
 869          accumulated_tokens = 0
 870          compress_until = compress_start
 871          
 872          for i in range(compress_start, compress_end):
 873              accumulated_tokens += turn_tokens[i]
 874              compress_until = i + 1
 875              if accumulated_tokens >= target_tokens_to_compress:
 876                  break
 877          
 878          # If we still don't have enough savings, compress the entire compressible region
 879          if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end:
 880              compress_until = compress_end
 881              accumulated_tokens = sum(turn_tokens[compress_start:compress_end])
 882          
 883          # Record compression region
 884          metrics.turns_compressed_start_idx = compress_start
 885          metrics.turns_compressed_end_idx = compress_until
 886          metrics.turns_in_compressed_region = compress_until - compress_start
 887          
 888          # Extract content for summary
 889          content_to_summarize = self._extract_turn_content_for_summary(
 890              trajectory, compress_start, compress_until
 891          )
 892          
 893          # Generate summary (ASYNC)
 894          summary = await self._generate_summary_async(content_to_summarize, metrics)
 895          
 896          # Build compressed trajectory
 897          compressed = []
 898          
 899          # Add head (turns before compression region)
 900          for i in range(compress_start):
 901              turn = trajectory[i].copy()
 902              if turn.get("from") == "system" and self.config.add_summary_notice:
 903                  turn["value"] = turn["value"] + self.config.summary_notice_text
 904              compressed.append(turn)
 905          
 906          # Add summary as human message
 907          compressed.append({
 908              "from": "human",
 909              "value": summary
 910          })
 911          
 912          # Add tail (turns after compression region)
 913          for i in range(compress_until, len(trajectory)):
 914              compressed.append(trajectory[i].copy())
 915          
 916          # Calculate final metrics
 917          metrics.compressed_turns = len(compressed)
 918          metrics.compressed_tokens = self.count_trajectory_tokens(compressed)
 919          metrics.turns_removed = metrics.original_turns - metrics.compressed_turns
 920          metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens
 921          metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1)
 922          metrics.was_compressed = True
 923          metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens
 924          
 925          return compressed, metrics
 926      
 927      async def process_entry_async(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]:
 928          """
 929          Process a single JSONL entry (async version).
 930          """
 931          if "conversations" not in entry:
 932              metrics = TrajectoryMetrics()
 933              return entry, metrics
 934          
 935          trajectory = entry["conversations"]
 936          compressed_trajectory, metrics = await self.compress_trajectory_async(trajectory)
 937          
 938          # Create new entry with compressed trajectory
 939          result = entry.copy()
 940          result["conversations"] = compressed_trajectory
 941          
 942          # Add compression metadata if enabled
 943          if self.config.metrics_per_trajectory and metrics.was_compressed:
 944              result["compression_metrics"] = metrics.to_dict()
 945          
 946          return result, metrics
 947      
 948      def process_entry(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]:
 949          """
 950          Process a single JSONL entry.
 951          
 952          Args:
 953              entry: JSONL entry containing 'conversations' field
 954              
 955          Returns:
 956              Tuple of (processed_entry, metrics)
 957          """
 958          if "conversations" not in entry:
 959              metrics = TrajectoryMetrics()
 960              return entry, metrics
 961          
 962          trajectory = entry["conversations"]
 963          compressed_trajectory, metrics = self.compress_trajectory(trajectory)
 964          
 965          # Create new entry with compressed trajectory
 966          result = entry.copy()
 967          result["conversations"] = compressed_trajectory
 968          
 969          # Add compression metadata if enabled
 970          if self.config.metrics_per_trajectory and metrics.was_compressed:
 971              result["compression_metrics"] = metrics.to_dict()
 972          
 973          return result, metrics
 974      
 975      def process_directory(self, input_dir: Path, output_dir: Path):
 976          """
 977          Process all JSONL files in a directory using async parallel processing.
 978          
 979          Args:
 980              input_dir: Input directory containing JSONL files
 981              output_dir: Output directory for compressed files
 982          """
 983          # Run the async version
 984          asyncio.run(self._process_directory_async(input_dir, output_dir))
 985      
    async def _process_directory_async(self, input_dir: Path, output_dir: Path):
        """
        Async implementation of directory processing with parallel API calls.

        Loads every entry from every ``*.jsonl`` file in ``input_dir`` up
        front, compresses them concurrently (bounded by
        ``config.max_concurrent_requests`` via a semaphore), then writes each
        file to ``output_dir`` preserving the original entry order.

        Failure handling:
          * entries exceeding ``config.per_trajectory_timeout`` are DROPPED
            from the output (marked ``None`` in ``results``);
          * entries failing for any other reason are written through
            unmodified with blank metrics.
        """
        console = Console()

        # Record start time
        self.aggregate_metrics.processing_start_time = datetime.now().isoformat()
        start_time = time.time()

        # Find all JSONL files
        jsonl_files = sorted(input_dir.glob("*.jsonl"))

        if not jsonl_files:
            self.logger.warning(f"No JSONL files found in {input_dir}")
            return

        # Load ALL entries from all files up front so we can schedule them
        # as one flat batch of concurrent tasks.
        console.print("\n[dim]Loading all entries...[/dim]")
        all_entries = []  # List of (file_path, entry_idx, entry)

        for file_path in jsonl_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f):
                    line = line.strip()
                    if line:
                        try:
                            entry = json.loads(line)
                            # entry_idx is the line number, unique per file.
                            all_entries.append((file_path, line_num, entry))
                        except json.JSONDecodeError as e:
                            self.logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e}")

        total_entries = len(all_entries)

        console.print(f"\n{'='*60}")
        console.print(f"šŸ“‚ Input: {input_dir}")
        console.print(f"šŸ“‚ Output: {output_dir}")
        console.print(f"šŸ“„ Files to process: {len(jsonl_files)}")
        console.print(f"šŸ“Š Total trajectories: {total_entries:,}")
        console.print(f"šŸŽÆ Target max tokens: {self.config.target_max_tokens:,}")
        console.print(f"šŸ“ Summary target tokens: {self.config.summary_target_tokens}")
        console.print(f"⚔ Max concurrent API calls: {self.config.max_concurrent_requests}")
        console.print(f"{'='*60}\n")

        # Create semaphore for rate limiting
        semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

        # Tracking for progress display (thread-safe with lock)
        # NOTE: this is an asyncio.Lock guarding interleaving at await points
        # within the single event loop, not OS-thread safety.
        progress_lock = asyncio.Lock()
        compressed_count = 0
        skipped_count = 0
        api_calls = 0
        in_flight = 0

        # Results storage: {file_path: {entry_idx: (processed_entry, metrics)}}
        # A value of None marks a timed-out entry, excluded from output.
        results = {f: {} for f in jsonl_files}

        # Track timeouts separately
        timeout_count = 0

        async def process_single(file_path: Path, entry_idx: int, entry: Dict, 
                                  progress, main_task, status_task):
            """Process a single entry with semaphore rate limiting and timeout."""
            nonlocal compressed_count, skipped_count, api_calls, in_flight, timeout_count

            async with semaphore:
                # Track in-flight
                async with progress_lock:
                    in_flight += 1

                try:
                    # Apply per-trajectory timeout
                    processed_entry, metrics = await asyncio.wait_for(
                        self.process_entry_async(entry),
                        timeout=self.config.per_trajectory_timeout
                    )
                    results[file_path][entry_idx] = (processed_entry, metrics)

                    # Update aggregate metrics (with lock for thread safety)
                    async with progress_lock:
                        self.aggregate_metrics.add_trajectory_metrics(metrics)

                        # Update counters
                        if metrics.was_compressed:
                            compressed_count += 1
                            api_calls += metrics.summarization_api_calls
                        if metrics.skipped_under_target:
                            skipped_count += 1

                        in_flight -= 1

                        # Update progress
                        progress.advance(main_task)
                        progress.update(
                            status_task,
                            description=f"[dim]āœ… {compressed_count} compressed | ā­ļø {skipped_count} skipped | ā±ļø {timeout_count} timeout | šŸ”„ {api_calls} API calls | ⚔ {in_flight} in-flight[/dim]"
                        )

                except asyncio.TimeoutError:
                    self.logger.warning(f"Timeout processing entry from {file_path}:{entry_idx} (>{self.config.per_trajectory_timeout}s)")

                    async with progress_lock:
                        self.aggregate_metrics.trajectories_failed += 1
                        timeout_count += 1
                        in_flight -= 1
                        progress.advance(main_task)
                        progress.update(
                            status_task,
                            description=f"[dim]āœ… {compressed_count} compressed | ā­ļø {skipped_count} skipped | ā±ļø {timeout_count} timeout | šŸ”„ {api_calls} API calls | ⚔ {in_flight} in-flight[/dim]"
                        )

                    # Skip this entry entirely (don't include in output)
                    results[file_path][entry_idx] = None

                except Exception as e:
                    self.logger.error(f"Error processing entry from {file_path}:{entry_idx}: {e}")

                    async with progress_lock:
                        self.aggregate_metrics.trajectories_failed += 1
                        in_flight -= 1
                        progress.advance(main_task)

                    # Keep original entry on error
                    results[file_path][entry_idx] = (entry, TrajectoryMetrics())

        # Create progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("•"),
            TimeElapsedColumn(),
            TextColumn("•"),
            TimeRemainingColumn(),
            console=console,
            refresh_per_second=10  # Higher refresh for async
        ) as progress:
            # Main task for overall progress
            main_task = progress.add_task(
                f"[cyan]Compressing {total_entries:,} trajectories",
                total=total_entries
            )

            # Status line task (total=None renders it as an indeterminate row)
            status_task = progress.add_task(
                "[dim]Starting...[/dim]",
                total=None
            )

            # Create all tasks
            tasks = [
                process_single(file_path, entry_idx, entry, progress, main_task, status_task)
                for file_path, entry_idx, entry in all_entries
            ]

            # Run all tasks concurrently (semaphore limits actual concurrency)
            await asyncio.gather(*tasks)

            # Remove status task
            progress.remove_task(status_task)

        # Write results to output files (preserving original order)
        console.print("\n[dim]Writing output files...[/dim]")
        output_dir.mkdir(parents=True, exist_ok=True)

        for file_path in jsonl_files:
            output_path = output_dir / file_path.name
            file_results = results[file_path]

            # Sort by original entry index to preserve order, skip None (timed out) entries
            sorted_entries = [
                file_results[idx][0] 
                for idx in sorted(file_results.keys()) 
                if file_results[idx] is not None
            ]

            with open(output_path, 'w', encoding='utf-8') as f:
                for entry in sorted_entries:
                    f.write(json.dumps(entry, ensure_ascii=False) + '\n')

        # Record end time
        self.aggregate_metrics.processing_end_time = datetime.now().isoformat()
        self.aggregate_metrics.processing_duration_seconds = time.time() - start_time

        # Print summary
        self._print_summary()

        # Save metrics
        if self.config.metrics_enabled:
            metrics_path = output_dir / self.config.metrics_output_file
            with open(metrics_path, 'w') as f:
                json.dump(self.aggregate_metrics.to_dict(), f, indent=2)
            console.print(f"\nšŸ’¾ Metrics saved to {metrics_path}")
1180      
    def _print_summary(self):
        """Print comprehensive compression summary statistics.

        Renders a fixed-width (70-char interior) box-drawing report to
        stdout from ``self.aggregate_metrics``: trajectory counts, token
        totals, turn totals, per-compression averages, summarization API
        stats, wall-clock timing, and a distribution footer when
        per-trajectory data was collected.

        NOTE(review): the trailing space padding in each row is hand-tuned
        to the numeric field widths; changing any format spec will misalign
        the right-hand box border.
        """
        m = self.aggregate_metrics.to_dict()

        # Calculate some additional stats
        total = m['summary']['total_trajectories']
        compressed = m['summary']['trajectories_compressed']
        skipped = m['summary']['trajectories_skipped_under_target']
        over_limit = m['summary']['trajectories_still_over_limit']
        failed = m['summary']['trajectories_failed']

        # Token stats
        tokens_before = m['tokens']['total_before']
        tokens_after = m['tokens']['total_after']
        tokens_saved = m['tokens']['total_saved']

        # Calculate percentages (max(total, 1) guards against divide-by-zero)
        compressed_pct = (compressed / max(total, 1)) * 100
        skipped_pct = (skipped / max(total, 1)) * 100
        over_limit_pct = (over_limit / max(total, 1)) * 100

        print(f"\n")
        print(f"ā•”{'═'*70}ā•—")
        print(f"ā•‘{'TRAJECTORY COMPRESSION REPORT':^70}ā•‘")
        print(f"ā• {'═'*70}ā•£")

        # Trajectories section
        print(f"ā•‘{'':2}šŸ“ TRAJECTORIES{' '*54}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Total Processed:        {total:>10,}{' '*32}ā•‘")
        print(f"ā•‘{'':4}ā”œā”€ Compressed:          {compressed:>10,}  ({compressed_pct:>5.1f}%){' '*18}ā•‘")
        print(f"ā•‘{'':4}ā”œā”€ Skipped (under limit):{skipped:>9,}  ({skipped_pct:>5.1f}%){' '*18}ā•‘")
        print(f"ā•‘{'':4}ā”œā”€ Still over limit:    {over_limit:>10,}  ({over_limit_pct:>5.1f}%){' '*18}ā•‘")
        print(f"ā•‘{'':4}└─ Failed:              {failed:>10,}{' '*32}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Tokens section
        print(f"ā•‘{'':2}šŸ”¢ TOKENS{' '*60}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Before Compression:     {tokens_before:>15,} tokens{' '*21}ā•‘")
        print(f"ā•‘{'':4}After Compression:      {tokens_after:>15,} tokens{' '*21}ā•‘")
        print(f"ā•‘{'':4}Total Saved:            {tokens_saved:>15,} tokens{' '*21}ā•‘")
        print(f"ā•‘{'':4}Overall Compression:    {m['tokens']['overall_compression_ratio']:>14.1%}{' '*28}ā•‘")

        # Space-savings row only makes sense when some tokens were counted
        if tokens_before > 0:
            savings_pct = (tokens_saved / tokens_before) * 100
            print(f"ā•‘{'':4}Space Savings:          {savings_pct:>14.1f}%{' '*28}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Turns section
        print(f"ā•‘{'':2}šŸ’¬ CONVERSATION TURNS{' '*48}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Before Compression:     {m['turns']['total_before']:>15,} turns{' '*22}ā•‘")
        print(f"ā•‘{'':4}After Compression:      {m['turns']['total_after']:>15,} turns{' '*22}ā•‘")
        print(f"ā•‘{'':4}Total Removed:          {m['turns']['total_removed']:>15,} turns{' '*22}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Averages section (for compressed trajectories only)
        print(f"ā•‘{'':2}šŸ“ˆ AVERAGES (Compressed Trajectories Only){' '*27}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        if compressed > 0:
            print(f"ā•‘{'':4}Avg Compression Ratio:  {m['averages']['avg_compression_ratio']:>14.1%}{' '*28}ā•‘")
            print(f"ā•‘{'':4}Avg Tokens Saved:       {m['averages']['avg_tokens_saved_per_compressed']:>14,.0f}{' '*28}ā•‘")
            print(f"ā•‘{'':4}Avg Turns Removed:      {m['averages']['avg_turns_removed_per_compressed']:>14.1f}{' '*28}ā•‘")
        else:
            print(f"ā•‘{'':4}No trajectories were compressed{' '*38}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Summarization API section
        print(f"ā•‘{'':2}šŸ¤– SUMMARIZATION API{' '*49}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}API Calls Made:         {m['summarization']['total_api_calls']:>15,}{' '*27}ā•‘")
        print(f"ā•‘{'':4}Errors:                 {m['summarization']['total_errors']:>15,}{' '*27}ā•‘")
        print(f"ā•‘{'':4}Success Rate:           {m['summarization']['success_rate']:>14.1%}{' '*28}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Processing time section
        duration = m['processing']['duration_seconds']
        if duration > 60:
            time_str = f"{duration/60:.1f} minutes"
        else:
            time_str = f"{duration:.1f} seconds"

        throughput = total / max(duration, 0.001)

        print(f"ā•‘{'':2}ā±ļø  PROCESSING TIME{' '*51}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Duration:               {time_str:>20}{' '*22}ā•‘")
        print(f"ā•‘{'':4}Throughput:             {throughput:>15.1f} traj/sec{' '*18}ā•‘")
        # [:19] trims ISO timestamps to whole seconds (YYYY-MM-DDTHH:MM:SS)
        print(f"ā•‘{'':4}Started:                {m['processing']['start_time'][:19]:>20}{' '*22}ā•‘")
        print(f"ā•‘{'':4}Finished:               {m['processing']['end_time'][:19]:>20}{' '*22}ā•‘")

        print(f"ā•š{'═'*70}ā•")

        # Distribution summary if we have data
        if self.aggregate_metrics.compression_ratios:
            ratios = self.aggregate_metrics.compression_ratios
            tokens_saved_list = self.aggregate_metrics.tokens_saved_list

            print(f"\nšŸ“Š Distribution Summary:")
            print(f"   Compression ratios: min={min(ratios):.2%}, max={max(ratios):.2%}, median={sorted(ratios)[len(ratios)//2]:.2%}")
            print(f"   Tokens saved:       min={min(tokens_saved_list):,}, max={max(tokens_saved_list):,}, median={sorted(tokens_saved_list)[len(tokens_saved_list)//2]:,}")
1288  
1289  
def _load_jsonl_entries(path: Path, warn_invalid: bool = False) -> List[Dict[str, Any]]:
    """Read one JSON object per non-blank line from *path*.

    Invalid JSON lines are skipped. When ``warn_invalid`` is True, a warning
    with the 1-based line number is printed for each skipped line (file-input
    mode); directory mode skips silently, matching the original behavior.
    """
    entries: List[Dict[str, Any]] = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError as e:
                if warn_invalid:
                    print(f"āš ļø  Skipping invalid JSON at line {line_num}: {e}")
    return entries


def _write_jsonl_entries(path: Path, entries: List[Dict[str, Any]]) -> None:
    """Write *entries* to *path*, one compact JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')


def main(
    input: str,
    output: Optional[str] = None,
    config: str = "configs/trajectory_compression.yaml",
    target_max_tokens: Optional[int] = None,
    tokenizer: Optional[str] = None,
    sample_percent: Optional[float] = None,
    seed: int = 42,
    dry_run: bool = False,
):
    """
    Compress agent trajectories to fit within a target token budget.
    
    Supports both single JSONL files and directories containing multiple JSONL files.
    Optionally sample a percentage of trajectories before compression.
    
    Args:
        input: Path to JSONL file or directory containing JSONL files
        output: Output path (file for file input, directory for dir input)
                Default: adds "_compressed" suffix to input name
        config: Path to YAML configuration file
        target_max_tokens: Override target token count from config
        tokenizer: Override tokenizer name from config
        sample_percent: Sample this percentage of trajectories (0 < p <= 100) before compression
        seed: Random seed for sampling reproducibility (default: 42)
        dry_run: Analyze without compressing (just show what would happen)
    
    Examples:
        # Compress a directory (original behavior)
        python trajectory_compressor.py --input=data/my_run
        
        # Compress a single file
        python trajectory_compressor.py --input=data/trajectories.jsonl
        
        # Compress 15% sample of a file
        python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15
        
        # Compress 10% sample with custom output
        python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=10 --output=data/sampled_compressed.jsonl
    """
    import random
    import tempfile
    import shutil
    
    print("šŸ—œļø  Trajectory Compressor")
    print("=" * 60)
    
    # Load configuration; fall back to defaults if the YAML file is absent.
    config_path = Path(config)
    if config_path.exists():
        print(f"šŸ“‹ Loading config from {config}")
        compression_config = CompressionConfig.from_yaml(config)
    else:
        print(f"āš ļø  Config not found at {config}, using defaults")
        compression_config = CompressionConfig()
    
    # CLI overrides take precedence over config-file values.
    if target_max_tokens:
        compression_config.target_max_tokens = target_max_tokens
    if tokenizer:
        compression_config.tokenizer_name = tokenizer
    
    # Validate sample_percent: must be in (0, 100] (fractional values allowed).
    if sample_percent is not None:
        if sample_percent <= 0 or sample_percent > 100:
            print(f"āŒ sample_percent must be in (0, 100], got {sample_percent}")
            return
        print(f"šŸŽ² Will sample {sample_percent}% of trajectories (seed={seed})")
    
    # Setup paths and determine input type
    input_path = Path(input)
    if not input_path.exists():
        print(f"āŒ Input not found: {input}")
        return
    
    is_file_input = input_path.is_file()
    
    if is_file_input:
        print(f"šŸ“„ Input mode: Single JSONL file")
        
        # For file input, default output is file with _compressed suffix
        if output:
            output_path = Path(output)
        else:
            output_path = input_path.parent / (input_path.stem + compression_config.output_suffix + ".jsonl")
        
        # Load entries from the single file (warn on malformed lines)
        entries = _load_jsonl_entries(input_path, warn_invalid=True)
        
        total_entries = len(entries)
        print(f"   Loaded {total_entries:,} trajectories from {input_path.name}")
        
        # Sample if requested
        if sample_percent is not None:
            random.seed(seed)
            # Clamp to the available count so an empty/tiny file cannot make
            # random.sample() raise ValueError (directory mode already guards this).
            sample_size = min(max(1, int(total_entries * sample_percent / 100)), total_entries)
            entries = random.sample(entries, sample_size)
            print(f"   Sampled {len(entries):,} trajectories ({sample_percent}% of {total_entries:,})")
        
        if dry_run:
            print(f"\nšŸ” DRY RUN MODE - analyzing without writing")
            print(f"šŸ“„ Would process: {len(entries):,} trajectories")
            print(f"šŸ“„ Would output to: {output_path}")
            return
        
        # Stage the (possibly sampled) entries in a temp dir so the
        # directory-oriented compressor can process them unchanged.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_input_dir = Path(temp_dir) / "input"
            temp_output_dir = Path(temp_dir) / "output"
            temp_input_dir.mkdir()
            
            _write_jsonl_entries(temp_input_dir / "trajectories.jsonl", entries)
            
            # Initialize compressor and process
            compressor = TrajectoryCompressor(compression_config)
            compressor.process_directory(temp_input_dir, temp_output_dir)
            
            # Merge all produced JSONL files into the single output file
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as out_f:
                for jsonl_file in sorted(temp_output_dir.glob("*.jsonl")):
                    with open(jsonl_file, 'r', encoding='utf-8') as in_f:
                        for line in in_f:
                            out_f.write(line)
            
            # Copy metrics file if it exists
            metrics_file = temp_output_dir / compression_config.metrics_output_file
            if metrics_file.exists():
                metrics_output = output_path.parent / (output_path.stem + "_metrics.json")
                shutil.copy(metrics_file, metrics_output)
                print(f"šŸ’¾ Metrics saved to {metrics_output}")
        
        print(f"\nāœ… Compression complete!")
        print(f"šŸ“„ Output: {output_path}")
        
    else:
        # Directory input - original behavior
        print(f"šŸ“ Input mode: Directory of JSONL files")
        
        if output:
            output_path = Path(output)
        else:
            output_path = input_path.parent / (input_path.name + compression_config.output_suffix)
        
        # If sampling is requested for directory mode, we need to handle it differently
        if sample_percent is not None:
            print(f"\nāš ļø  Sampling from directory: will sample {sample_percent}% from each file")
            
            # Create a temp directory with sampled files
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_input_dir = Path(temp_dir) / "input"
                temp_input_dir.mkdir()
                
                random.seed(seed)
                total_original = 0
                total_sampled = 0
                
                # Sample from each JSONL file (silently skipping invalid lines)
                for jsonl_file in sorted(input_path.glob("*.jsonl")):
                    entries = _load_jsonl_entries(jsonl_file)
                    
                    total_original += len(entries)
                    sample_size = max(1, int(len(entries) * sample_percent / 100))
                    sampled_entries = random.sample(entries, min(sample_size, len(entries)))
                    total_sampled += len(sampled_entries)
                    
                    # Write sampled entries under the same file name
                    _write_jsonl_entries(temp_input_dir / jsonl_file.name, sampled_entries)
                
                print(f"   Sampled {total_sampled:,} from {total_original:,} total trajectories")
                
                if dry_run:
                    print(f"\nšŸ” DRY RUN MODE - analyzing without writing")
                    print(f"šŸ“ Would process: {temp_input_dir}")
                    print(f"šŸ“ Would output to: {output_path}")
                    return
                
                # Initialize compressor and process the sampled data
                compressor = TrajectoryCompressor(compression_config)
                compressor.process_directory(temp_input_dir, output_path)
        else:
            if dry_run:
                print(f"\nšŸ” DRY RUN MODE - analyzing without writing")
                print(f"šŸ“ Would process: {input_path}")
                print(f"šŸ“ Would output to: {output_path}")
                return
            
            # Initialize compressor and process directly
            compressor = TrajectoryCompressor(compression_config)
            compressor.process_directory(input_path, output_path)
        
        print("\nāœ… Compression complete!")
1505  
1506  
if __name__ == "__main__":
    # CLI entry point: Fire maps main()'s signature to --flag arguments
    # (see the usage examples in the module docstring).
    fire.Fire(main)