/ trajectory_compressor.py
trajectory_compressor.py
1 #!/usr/bin/env python3 2 """ 3 Trajectory Compressor 4 5 Post-processes completed agent trajectories to compress them within a target 6 token budget while preserving training signal quality. 7 8 Compression Strategy: 9 1. Protect first turns (system, human, first gpt, first tool) 10 2. Protect last N turns (final actions and conclusions) 11 3. Compress MIDDLE turns only, starting from 2nd tool response 12 4. Compress only as much as needed to fit under target 13 5. Replace compressed region with a single human summary message 14 6. Keep remaining tool calls intact (model continues working after summary) 15 16 Usage: 17 # Compress a directory of JSONL files 18 python trajectory_compressor.py --input=data/my_run 19 20 # Compress a single JSONL file 21 python trajectory_compressor.py --input=data/trajectories.jsonl 22 23 # Compress 15% sample of a file 24 python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15 25 26 # Compress with custom output and token target 27 python trajectory_compressor.py --input=data/trajectories.jsonl --output=compressed.jsonl --target_max_tokens=16000 28 29 # Compress 10% sample from a directory 30 python trajectory_compressor.py --input=data/my_run --sample_percent=10 31 """ 32 33 import json 34 import os 35 import time 36 import yaml 37 import logging 38 import asyncio 39 from pathlib import Path 40 from typing import List, Dict, Any, Optional, Tuple 41 from dataclasses import dataclass, field 42 from datetime import datetime 43 44 from utils import base_url_host_matches, base_url_hostname 45 import fire 46 from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn 47 from rich.console import Console 48 from hermes_constants import OPENROUTER_BASE_URL, get_hermes_home 49 from agent.retry_utils import jittered_backoff 50 51 # Load .env from HERMES_HOME first, then project root as a dev fallback. 52 from hermes_cli.env_loader import load_hermes_dotenv 53 54 _hermes_home = get_hermes_home() 55 _project_env = Path(__file__).parent / ".env" 56 load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env) 57 58 59 def _effective_temperature_for_model( 60 model: str, 61 requested_temperature: float, 62 base_url: Optional[str] = None, 63 ) -> Optional[float]: 64 """Apply fixed model temperature contracts to direct client calls. 65 66 Returns ``None`` when the model manages temperature server-side (Kimi); 67 callers must omit the ``temperature`` kwarg entirely in that case. 68 """ 69 try: 70 from agent.auxiliary_client import _fixed_temperature_for_model, OMIT_TEMPERATURE 71 except Exception: 72 return requested_temperature 73 74 fixed_temperature = _fixed_temperature_for_model(model, base_url) 75 if fixed_temperature is OMIT_TEMPERATURE: 76 return None # caller must omit temperature 77 if fixed_temperature is not None: 78 return fixed_temperature 79 return requested_temperature 80 81 82 @dataclass 83 class CompressionConfig: 84 """Configuration for trajectory compression.""" 85 # Tokenizer 86 tokenizer_name: str = "moonshotai/Kimi-K2-Thinking" 87 trust_remote_code: bool = True 88 89 # Compression targets 90 target_max_tokens: int = 15250 91 summary_target_tokens: int = 750 92 93 # Protected turns 94 protect_first_system: bool = True 95 protect_first_human: bool = True 96 protect_first_gpt: bool = True 97 protect_first_tool: bool = True 98 protect_last_n_turns: int = 4 99 100 # Summarization (OpenRouter) 101 summarization_model: str = "google/gemini-3-flash-preview" 102 base_url: str = OPENROUTER_BASE_URL 103 api_key_env: str = "OPENROUTER_API_KEY" 104 temperature: float = 0.3 105 max_retries: int = 3 106 retry_delay: int = 2 107 108 # Output 109 add_summary_notice: bool = True 110 summary_notice_text: str = "\n\nSome of your previous tool responses may be summarized to preserve context." 111 output_suffix: str = "_compressed" 112 113 # Processing 114 num_workers: int = 4 115 max_concurrent_requests: int = 50 # Max concurrent API calls for summarization 116 skip_under_target: bool = True 117 save_over_limit: bool = True 118 per_trajectory_timeout: int = 300 # Timeout per trajectory in seconds (default: 5 min) 119 120 # Metrics 121 metrics_enabled: bool = True 122 metrics_per_trajectory: bool = True 123 metrics_output_file: str = "compression_metrics.json" 124 125 @classmethod 126 def from_yaml(cls, yaml_path: str) -> "CompressionConfig": 127 """Load configuration from YAML file.""" 128 with open(yaml_path, 'r') as f: 129 data = yaml.safe_load(f) 130 131 config = cls() 132 133 # Tokenizer 134 if 'tokenizer' in data: 135 config.tokenizer_name = data['tokenizer'].get('name', config.tokenizer_name) 136 config.trust_remote_code = data['tokenizer'].get('trust_remote_code', config.trust_remote_code) 137 138 # Compression 139 if 'compression' in data: 140 config.target_max_tokens = data['compression'].get('target_max_tokens', config.target_max_tokens) 141 config.summary_target_tokens = data['compression'].get('summary_target_tokens', config.summary_target_tokens) 142 143 # Protected turns 144 if 'protected_turns' in data: 145 config.protect_first_system = data['protected_turns'].get('first_system', config.protect_first_system) 146 config.protect_first_human = data['protected_turns'].get('first_human', config.protect_first_human) 147 config.protect_first_gpt = data['protected_turns'].get('first_gpt', config.protect_first_gpt) 148 config.protect_first_tool = data['protected_turns'].get('first_tool', config.protect_first_tool) 149 config.protect_last_n_turns = data['protected_turns'].get('last_n_turns', config.protect_last_n_turns) 150 151 # Summarization 152 if 'summarization' in data: 153 config.summarization_model = data['summarization'].get('model', config.summarization_model) 154 config.base_url = data['summarization'].get('base_url') or config.base_url 155 config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env) 156 config.temperature = data['summarization'].get('temperature', config.temperature) 157 config.max_retries = data['summarization'].get('max_retries', config.max_retries) 158 config.retry_delay = data['summarization'].get('retry_delay', config.retry_delay) 159 160 # Output 161 if 'output' in data: 162 config.add_summary_notice = data['output'].get('add_summary_notice', config.add_summary_notice) 163 config.summary_notice_text = data['output'].get('summary_notice_text', config.summary_notice_text) 164 config.output_suffix = data['output'].get('output_suffix', config.output_suffix) 165 166 # Processing 167 if 'processing' in data: 168 config.num_workers = data['processing'].get('num_workers', config.num_workers) 169 config.max_concurrent_requests = data['processing'].get('max_concurrent_requests', config.max_concurrent_requests) 170 config.skip_under_target = data['processing'].get('skip_under_target', config.skip_under_target) 171 config.save_over_limit = data['processing'].get('save_over_limit', config.save_over_limit) 172 173 # Metrics 174 if 'metrics' in data: 175 config.metrics_enabled = data['metrics'].get('enabled', config.metrics_enabled) 176 config.metrics_per_trajectory = data['metrics'].get('per_trajectory', config.metrics_per_trajectory) 177 config.metrics_output_file = data['metrics'].get('output_file', config.metrics_output_file) 178 179 return config 180 181 182 @dataclass 183 class TrajectoryMetrics: 184 """Metrics for a single trajectory compression.""" 185 original_tokens: int = 0 186 compressed_tokens: int = 0 187 tokens_saved: int = 0 188 compression_ratio: float = 1.0 189 190 original_turns: int = 0 191 compressed_turns: int = 0 192 turns_removed: int = 0 193 194 turns_compressed_start_idx: int = -1 195 turns_compressed_end_idx: int = -1 196 turns_in_compressed_region: int = 0 197 198 was_compressed: bool = False 199 still_over_limit: bool = False 200 skipped_under_target: bool = False 201 202 summarization_api_calls: int = 0 203 summarization_errors: int = 0 204 205 def to_dict(self) -> Dict[str, Any]: 206 return { 207 "original_tokens": self.original_tokens, 208 "compressed_tokens": self.compressed_tokens, 209 "tokens_saved": self.tokens_saved, 210 "compression_ratio": round(self.compression_ratio, 4), 211 "original_turns": self.original_turns, 212 "compressed_turns": self.compressed_turns, 213 "turns_removed": self.turns_removed, 214 "compression_region": { 215 "start_idx": self.turns_compressed_start_idx, 216 "end_idx": self.turns_compressed_end_idx, 217 "turns_count": self.turns_in_compressed_region, 218 }, 219 "was_compressed": self.was_compressed, 220 "still_over_limit": self.still_over_limit, 221 "skipped_under_target": self.skipped_under_target, 222 "summarization_api_calls": self.summarization_api_calls, 223 "summarization_errors": self.summarization_errors, 224 } 225 226 227 @dataclass 228 class AggregateMetrics: 229 """Aggregate metrics across all trajectories.""" 230 total_trajectories: int = 0 231 trajectories_compressed: int = 0 232 trajectories_skipped_under_target: int = 0 233 trajectories_still_over_limit: int = 0 234 trajectories_failed: int = 0 235 236 total_tokens_before: int = 0 237 total_tokens_after: int = 0 238 total_tokens_saved: int = 0 239 240 total_turns_before: int = 0 241 total_turns_after: int = 0 242 total_turns_removed: int = 0 243 244 total_summarization_calls: int = 0 245 total_summarization_errors: int = 0 246 247 # Distribution stats 248 compression_ratios: List[float] = field(default_factory=list) 249 tokens_saved_list: List[int] = field(default_factory=list) 250 turns_removed_list: List[int] = field(default_factory=list) 251 252 processing_start_time: str = "" 253 processing_end_time: str = "" 254 processing_duration_seconds: float = 0.0 255 256 def add_trajectory_metrics(self, metrics: TrajectoryMetrics): 257 """Add a trajectory's metrics to the aggregate.""" 258 self.total_trajectories += 1 259 self.total_tokens_before += metrics.original_tokens 260 self.total_tokens_after += metrics.compressed_tokens 261 self.total_tokens_saved += metrics.tokens_saved 262 self.total_turns_before += metrics.original_turns 263 self.total_turns_after += metrics.compressed_turns 264 self.total_turns_removed += metrics.turns_removed 265 self.total_summarization_calls += metrics.summarization_api_calls 266 self.total_summarization_errors += metrics.summarization_errors 267 268 if metrics.was_compressed: 269 self.trajectories_compressed += 1 270 self.compression_ratios.append(metrics.compression_ratio) 271 self.tokens_saved_list.append(metrics.tokens_saved) 272 self.turns_removed_list.append(metrics.turns_removed) 273 274 if metrics.skipped_under_target: 275 self.trajectories_skipped_under_target += 1 276 277 if metrics.still_over_limit: 278 self.trajectories_still_over_limit += 1 279 280 def to_dict(self) -> Dict[str, Any]: 281 avg_compression_ratio = ( 282 sum(self.compression_ratios) / len(self.compression_ratios) 283 if self.compression_ratios else 1.0 284 ) 285 avg_tokens_saved = ( 286 sum(self.tokens_saved_list) / len(self.tokens_saved_list) 287 if self.tokens_saved_list else 0 288 ) 289 avg_turns_removed = ( 290 sum(self.turns_removed_list) / len(self.turns_removed_list) 291 if self.turns_removed_list else 0 292 ) 293 294 return { 295 "summary": { 296 "total_trajectories": self.total_trajectories, 297 "trajectories_compressed": self.trajectories_compressed, 298 "trajectories_skipped_under_target": self.trajectories_skipped_under_target, 299 "trajectories_still_over_limit": self.trajectories_still_over_limit, 300 "trajectories_failed": self.trajectories_failed, 301 "compression_rate": round(self.trajectories_compressed / max(self.total_trajectories, 1), 4), 302 }, 303 "tokens": { 304 "total_before": self.total_tokens_before, 305 "total_after": self.total_tokens_after, 306 "total_saved": self.total_tokens_saved, 307 "overall_compression_ratio": round(self.total_tokens_after / max(self.total_tokens_before, 1), 4), 308 }, 309 "turns": { 310 "total_before": self.total_turns_before, 311 "total_after": self.total_turns_after, 312 "total_removed": self.total_turns_removed, 313 }, 314 "averages": { 315 "avg_compression_ratio": round(avg_compression_ratio, 4), 316 "avg_tokens_saved_per_compressed": round(avg_tokens_saved, 1), 317 "avg_turns_removed_per_compressed": round(avg_turns_removed, 2), 318 }, 319 "summarization": { 320 "total_api_calls": self.total_summarization_calls, 321 "total_errors": self.total_summarization_errors, 322 "success_rate": round(1 - (self.total_summarization_errors / max(self.total_summarization_calls, 1)), 4), 323 }, 324 "processing": { 325 "start_time": self.processing_start_time, 326 "end_time": self.processing_end_time, 327 "duration_seconds": round(self.processing_duration_seconds, 2), 328 }, 329 } 330 331 332 class TrajectoryCompressor: 333 """ 334 Compresses agent trajectories to fit within a target token budget. 335 336 Compression strategy: 337 1. Keep protected head turns (system, human, first gpt+tool) 338 2. Keep protected tail turns (last N turns) 339 3. From the compressible middle region, compress only as much as needed 340 4. Replace compressed turns with a single human summary message 341 5. Keep remaining middle turns intact (model continues with tools) 342 """ 343 344 def __init__(self, config: CompressionConfig): 345 """Initialize the compressor.""" 346 self.config = config 347 self.aggregate_metrics = AggregateMetrics() 348 349 # Initialize tokenizer 350 self._init_tokenizer() 351 352 # Initialize OpenRouter client 353 self._init_summarizer() 354 355 logging.basicConfig( 356 level=logging.INFO, 357 format='%(asctime)s - %(levelname)s - %(message)s', 358 datefmt='%H:%M:%S' 359 ) 360 self.logger = logging.getLogger(__name__) 361 362 def _init_tokenizer(self): 363 """Initialize HuggingFace tokenizer for token counting.""" 364 try: 365 from transformers import AutoTokenizer 366 self.tokenizer = AutoTokenizer.from_pretrained( 367 self.config.tokenizer_name, 368 trust_remote_code=self.config.trust_remote_code 369 ) 370 print(f"ā Loaded tokenizer: {self.config.tokenizer_name}") 371 except Exception as e: 372 raise RuntimeError(f"Failed to load tokenizer '{self.config.tokenizer_name}': {e}") 373 374 def _init_summarizer(self): 375 """Initialize LLM routing for summarization (sync and async). 376 377 Uses call_llm/async_call_llm from the centralized provider router 378 which handles auth, headers, and provider detection internally. 379 For custom endpoints, falls back to raw client construction. 380 """ 381 382 provider = self._detect_provider() 383 if provider: 384 # Store provider for use in _generate_summary calls 385 self._llm_provider = provider 386 self._use_call_llm = True 387 # Verify the provider is available 388 from agent.auxiliary_client import resolve_provider_client 389 client, _ = resolve_provider_client( 390 provider, model=self.config.summarization_model) 391 if client is None: 392 raise RuntimeError( 393 f"Provider '{provider}' is not configured. " 394 f"Check your API key or run: hermes setup") 395 self.client = None # Not used directly 396 self.async_client = None # Not used directly 397 else: 398 # Custom endpoint ā use config's raw base_url + api_key_env 399 self._use_call_llm = False 400 api_key = os.getenv(self.config.api_key_env) 401 if not api_key: 402 raise RuntimeError( 403 f"Missing API key. Set {self.config.api_key_env} " 404 f"environment variable.") 405 from openai import OpenAI 406 from agent.auxiliary_client import _to_openai_base_url 407 self.client = OpenAI( 408 api_key=api_key, base_url=_to_openai_base_url(self.config.base_url)) 409 # AsyncOpenAI is created lazily in _get_async_client() so it 410 # binds to the current event loop ā avoids "Event loop is closed" 411 # when process_directory() is called multiple times (each call 412 # creates a new loop via asyncio.run()). 413 self.async_client = None 414 self._async_client_api_key = api_key 415 416 print(f"ā Initialized summarizer client: {self.config.summarization_model}") 417 print(f" Max concurrent requests: {self.config.max_concurrent_requests}") 418 419 def _get_async_client(self): 420 """Return an AsyncOpenAI client bound to the current event loop. 421 422 Created lazily so that each ``asyncio.run()`` call in 423 ``process_directory()`` gets a client tied to its own loop, 424 avoiding "Event loop is closed" errors on repeated calls. 425 """ 426 from openai import AsyncOpenAI 427 from agent.auxiliary_client import _to_openai_base_url 428 # Always create a fresh client so it binds to the running loop. 429 self.async_client = AsyncOpenAI( 430 api_key=self._async_client_api_key, 431 base_url=_to_openai_base_url(self.config.base_url), 432 ) 433 return self.async_client 434 435 def _detect_provider(self) -> str: 436 """Detect the provider name from the configured base_url.""" 437 url = self.config.base_url or "" 438 if base_url_host_matches(url, "openrouter.ai"): 439 return "openrouter" 440 if base_url_host_matches(url, "nousresearch.com"): 441 return "nous" 442 if ( 443 base_url_hostname(url) == "chatgpt.com" 444 and "/backend-api/codex" in url.lower() 445 ): 446 return "codex" 447 if base_url_host_matches(url, "z.ai"): 448 return "zai" 449 if ( 450 base_url_host_matches(url, "moonshot.ai") 451 or base_url_host_matches(url, "moonshot.cn") 452 or base_url_host_matches(url, "api.kimi.com") 453 ): 454 return "kimi-coding" 455 if base_url_host_matches(url, "arcee.ai"): 456 return "arcee" 457 if base_url_host_matches(url, "minimaxi.com"): 458 return "minimax-cn" 459 if base_url_host_matches(url, "minimax.io"): 460 return "minimax" 461 # Unknown base_url ā not a known provider 462 return "" 463 464 def count_tokens(self, text: str) -> int: 465 """Count tokens in text using the configured tokenizer.""" 466 if not text: 467 return 0 468 try: 469 return len(self.tokenizer.encode(text)) 470 except Exception: 471 # Fallback to character estimate 472 return len(text) // 4 473 474 def count_trajectory_tokens(self, trajectory: List[Dict[str, str]]) -> int: 475 """Count total tokens in a trajectory.""" 476 return sum(self.count_tokens(turn.get("value", "")) for turn in trajectory) 477 478 def count_turn_tokens(self, trajectory: List[Dict[str, str]]) -> List[int]: 479 """Count tokens for each turn in a trajectory.""" 480 return [self.count_tokens(turn.get("value", "")) for turn in trajectory] 481 482 def _find_protected_indices(self, trajectory: List[Dict[str, str]]) -> Tuple[set, int, int]: 483 """ 484 Find indices of protected turns. 485 486 Returns: 487 Tuple of (protected_set, compressible_start, compressible_end) 488 """ 489 n = len(trajectory) 490 protected = set() 491 492 # Track first occurrences 493 first_system = first_human = first_gpt = first_tool = None 494 495 for i, turn in enumerate(trajectory): 496 role = turn.get("from", "") 497 if role == "system" and first_system is None: 498 first_system = i 499 elif role == "human" and first_human is None: 500 first_human = i 501 elif role == "gpt" and first_gpt is None: 502 first_gpt = i 503 elif role == "tool" and first_tool is None: 504 first_tool = i 505 506 # Protect first turns 507 if self.config.protect_first_system and first_system is not None: 508 protected.add(first_system) 509 if self.config.protect_first_human and first_human is not None: 510 protected.add(first_human) 511 if self.config.protect_first_gpt and first_gpt is not None: 512 protected.add(first_gpt) 513 if self.config.protect_first_tool and first_tool is not None: 514 protected.add(first_tool) 515 516 # Protect last N turns 517 for i in range(max(0, n - self.config.protect_last_n_turns), n): 518 protected.add(i) 519 520 # Determine compressible region 521 # Start after the last protected head turn 522 head_protected = [i for i in protected if i < n // 2] 523 tail_protected = [i for i in protected if i >= n // 2] 524 525 compressible_start = max(head_protected) + 1 if head_protected else 0 526 compressible_end = min(tail_protected) if tail_protected else n 527 528 return protected, compressible_start, compressible_end 529 530 def _extract_turn_content_for_summary(self, trajectory: List[Dict[str, str]], start: int, end: int) -> str: 531 """ 532 Extract content from turns to be summarized. 533 534 Args: 535 trajectory: Full trajectory 536 start: Start index (inclusive) 537 end: End index (exclusive) 538 539 Returns: 540 Formatted string of turn contents for summarization 541 """ 542 parts = [] 543 for i in range(start, end): 544 turn = trajectory[i] 545 role = turn.get("from", "unknown") 546 value = turn.get("value", "") 547 548 # Truncate very long values for the summary prompt 549 if len(value) > 3000: 550 value = value[:1500] + "\n...[truncated]...\n" + value[-500:] 551 552 parts.append(f"[Turn {i} - {role.upper()}]:\n{value}") 553 554 return "\n\n".join(parts) 555 556 @staticmethod 557 def _coerce_summary_content(content: Any) -> str: 558 """Normalize summary-model output to a safe string.""" 559 if not isinstance(content, str): 560 content = str(content) if content else "" 561 return content.strip() 562 563 @staticmethod 564 def _ensure_summary_prefix(summary: str) -> str: 565 """Normalize summary text to include the expected prefix exactly once.""" 566 text = (summary or "").strip() 567 if text.startswith("[CONTEXT SUMMARY]:"): 568 return text 569 return "[CONTEXT SUMMARY]:" if not text else f"[CONTEXT SUMMARY]: {text}" 570 571 def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str: 572 """ 573 Generate a summary of the compressed turns using OpenRouter. 574 575 Args: 576 content: The content to summarize 577 metrics: Metrics object to update 578 579 Returns: 580 Summary string 581 """ 582 prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history. 583 584 Write the summary from a neutral perspective describing what the assistant did and learned. Include: 585 1. What actions the assistant took (tool calls, searches, file operations) 586 2. Key information or results obtained 587 3. Any important decisions or findings 588 4. Relevant data, file names, values, or outputs 589 590 Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens. 591 592 --- 593 TURNS TO SUMMARIZE: 594 {content} 595 --- 596 597 Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" 598 599 for attempt in range(self.config.max_retries): 600 try: 601 metrics.summarization_api_calls += 1 602 summary_temperature = _effective_temperature_for_model( 603 self.config.summarization_model, 604 self.config.temperature, 605 self.config.base_url, 606 ) 607 608 if getattr(self, '_use_call_llm', False): 609 from agent.auxiliary_client import call_llm 610 response = call_llm( 611 provider=self._llm_provider, 612 model=self.config.summarization_model, 613 messages=[{"role": "user", "content": prompt}], 614 temperature=summary_temperature, 615 max_tokens=self.config.summary_target_tokens * 2, 616 ) 617 else: 618 _create_kwargs = { 619 "model": self.config.summarization_model, 620 "messages": [{"role": "user", "content": prompt}], 621 "max_tokens": self.config.summary_target_tokens * 2, 622 } 623 if summary_temperature is not None: 624 _create_kwargs["temperature"] = summary_temperature 625 response = self.client.chat.completions.create(**_create_kwargs) 626 627 summary = self._coerce_summary_content(response.choices[0].message.content) 628 return self._ensure_summary_prefix(summary) 629 630 except Exception as e: 631 metrics.summarization_errors += 1 632 self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}") 633 634 if attempt < self.config.max_retries - 1: 635 time.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0)) 636 else: 637 # Fallback: create a basic summary 638 return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]" 639 640 async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str: 641 """ 642 Generate a summary of the compressed turns using OpenRouter (async version). 643 644 Args: 645 content: The content to summarize 646 metrics: Metrics object to update 647 648 Returns: 649 Summary string 650 """ 651 prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history. 652 653 Write the summary from a neutral perspective describing what the assistant did and learned. Include: 654 1. What actions the assistant took (tool calls, searches, file operations) 655 2. Key information or results obtained 656 3. Any important decisions or findings 657 4. Relevant data, file names, values, or outputs 658 659 Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens. 660 661 --- 662 TURNS TO SUMMARIZE: 663 {content} 664 --- 665 666 Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" 667 668 for attempt in range(self.config.max_retries): 669 try: 670 metrics.summarization_api_calls += 1 671 summary_temperature = _effective_temperature_for_model( 672 self.config.summarization_model, 673 self.config.temperature, 674 self.config.base_url, 675 ) 676 677 if getattr(self, '_use_call_llm', False): 678 from agent.auxiliary_client import async_call_llm 679 response = await async_call_llm( 680 provider=self._llm_provider, 681 model=self.config.summarization_model, 682 messages=[{"role": "user", "content": prompt}], 683 temperature=summary_temperature, 684 max_tokens=self.config.summary_target_tokens * 2, 685 ) 686 else: 687 _create_kwargs = { 688 "model": self.config.summarization_model, 689 "messages": [{"role": "user", "content": prompt}], 690 "max_tokens": self.config.summary_target_tokens * 2, 691 } 692 if summary_temperature is not None: 693 _create_kwargs["temperature"] = summary_temperature 694 response = await self._get_async_client().chat.completions.create(**_create_kwargs) 695 696 summary = self._coerce_summary_content(response.choices[0].message.content) 697 return self._ensure_summary_prefix(summary) 698 699 except Exception as e: 700 metrics.summarization_errors += 1 701 self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}") 702 703 if attempt < self.config.max_retries - 1: 704 await asyncio.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0)) 705 else: 706 # Fallback: create a basic summary 707 return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]" 708 709 def compress_trajectory( 710 self, 711 trajectory: List[Dict[str, str]] 712 ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]: 713 """ 714 Compress a single trajectory to fit within target token budget. 715 716 Algorithm: 717 1. Count total tokens 718 2. If under target, skip 719 3. Find compressible region (between protected head and tail) 720 4. Calculate how many tokens need to be saved 721 5. Accumulate turns from start of compressible region until savings met 722 6. Replace accumulated turns with single human summary 723 7. Keep remaining turns intact 724 725 Args: 726 trajectory: List of conversation turns 727 728 Returns: 729 Tuple of (compressed_trajectory, metrics) 730 """ 731 metrics = TrajectoryMetrics() 732 metrics.original_turns = len(trajectory) 733 734 # Count tokens per turn 735 turn_tokens = self.count_turn_tokens(trajectory) 736 total_tokens = sum(turn_tokens) 737 metrics.original_tokens = total_tokens 738 739 # Check if compression needed 740 if total_tokens <= self.config.target_max_tokens: 741 metrics.skipped_under_target = True 742 metrics.compressed_tokens = total_tokens 743 metrics.compressed_turns = len(trajectory) 744 metrics.compression_ratio = 1.0 745 return trajectory, metrics 746 747 # Find protected regions 748 protected, compress_start, compress_end = self._find_protected_indices(trajectory) 749 750 # Check if there's anything to compress 751 if compress_start >= compress_end: 752 # Nothing to compress, return as-is 753 metrics.compressed_tokens = total_tokens 754 metrics.compressed_turns = len(trajectory) 755 metrics.still_over_limit = total_tokens > self.config.target_max_tokens 756 return trajectory, metrics 757 758 # Calculate how much we need to save 759 tokens_to_save = total_tokens - self.config.target_max_tokens 760 761 # We'll replace N turns with 1 summary turn 762 # Net savings = (sum of N turns' tokens) - summary_target_tokens 763 # We need: net_savings >= tokens_to_save 764 # So: sum of turns >= tokens_to_save + summary_target_tokens 765 target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens 766 767 # Accumulate turns from compress_start until we have enough savings 768 accumulated_tokens = 0 769 compress_until = compress_start 770 771 for i in range(compress_start, compress_end): 772 accumulated_tokens += turn_tokens[i] 773 compress_until = i + 1 # Exclusive end 774 775 # Check if we have enough savings 776 if accumulated_tokens >= target_tokens_to_compress: 777 break 778 779 # If we still don't have enough savings, compress the entire compressible region 780 if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end: 781 compress_until = compress_end 782 accumulated_tokens = sum(turn_tokens[compress_start:compress_end]) 783 784 # Record compression region 785 metrics.turns_compressed_start_idx = compress_start 786 metrics.turns_compressed_end_idx = compress_until 787 metrics.turns_in_compressed_region = compress_until - compress_start 788 789 # Extract content for summary 790 content_to_summarize = self._extract_turn_content_for_summary( 791 trajectory, compress_start, compress_until 792 ) 793 794 # Generate summary 795 summary = self._generate_summary(content_to_summarize, metrics) 796 797 # Build compressed trajectory 798 compressed = [] 799 800 # Add head (turns before compression region) 801 for i in range(compress_start): 802 turn = trajectory[i].copy() 803 # Add notice to system message 804 if turn.get("from") == "system" and self.config.add_summary_notice: 805 turn["value"] = turn["value"] + self.config.summary_notice_text 806 compressed.append(turn) 807 808 # Add summary as human message 809 compressed.append({ 810 "from": "human", 811 "value": summary 812 }) 813 814 # Add tail (turns after compression region) 815 for i in range(compress_until, len(trajectory)): 816 compressed.append(trajectory[i].copy()) 817 818 # Calculate final metrics 819 metrics.compressed_turns = len(compressed) 820 metrics.compressed_tokens = self.count_trajectory_tokens(compressed) 821 metrics.turns_removed = metrics.original_turns - metrics.compressed_turns 822 metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens 823 metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1) 824 metrics.was_compressed = True 825 metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens 826 827 return compressed, metrics 828 829 async def compress_trajectory_async( 830 self, 831 trajectory: List[Dict[str, str]] 832 ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]: 833 """ 834 Compress a single trajectory to fit within target token budget (async version). 835 836 Same algorithm as compress_trajectory but uses async API calls for summarization. 837 """ 838 metrics = TrajectoryMetrics() 839 metrics.original_turns = len(trajectory) 840 841 # Count tokens per turn 842 turn_tokens = self.count_turn_tokens(trajectory) 843 total_tokens = sum(turn_tokens) 844 metrics.original_tokens = total_tokens 845 846 # Check if compression needed 847 if total_tokens <= self.config.target_max_tokens: 848 metrics.skipped_under_target = True 849 metrics.compressed_tokens = total_tokens 850 metrics.compressed_turns = len(trajectory) 851 metrics.compression_ratio = 1.0 852 return trajectory, metrics 853 854 # Find protected regions 855 protected, compress_start, compress_end = self._find_protected_indices(trajectory) 856 857 # Check if there's anything to compress 858 if compress_start >= compress_end: 859 metrics.compressed_tokens = total_tokens 860 metrics.compressed_turns = len(trajectory) 861 metrics.still_over_limit = total_tokens > self.config.target_max_tokens 862 return trajectory, metrics 863 864 # Calculate how much we need to save 865 tokens_to_save = total_tokens - self.config.target_max_tokens 866 target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens 867 868 # Accumulate turns from compress_start until we have enough savings 869 accumulated_tokens = 0 870 compress_until = compress_start 871 872 for i in range(compress_start, compress_end): 873 accumulated_tokens += turn_tokens[i] 874 compress_until = i + 1 875 if accumulated_tokens >= target_tokens_to_compress: 876 break 877 878 # If we still don't have enough savings, compress the entire compressible region 879 if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end: 880 compress_until = compress_end 881 accumulated_tokens = sum(turn_tokens[compress_start:compress_end]) 882 883 # Record compression region 884 metrics.turns_compressed_start_idx = compress_start 885 metrics.turns_compressed_end_idx = compress_until 886 metrics.turns_in_compressed_region = compress_until - compress_start 887 888 # Extract content for summary 889 content_to_summarize = self._extract_turn_content_for_summary( 890 trajectory, compress_start, compress_until 891 ) 892 893 # Generate summary (ASYNC) 894 summary = await self._generate_summary_async(content_to_summarize, metrics) 895 896 # Build compressed trajectory 897 compressed = [] 898 899 # Add head (turns before compression region) 900 for i in range(compress_start): 901 turn = trajectory[i].copy() 902 if turn.get("from") == "system" and self.config.add_summary_notice: 903 turn["value"] = turn["value"] + self.config.summary_notice_text 904 compressed.append(turn) 905 906 # Add summary as human message 907 compressed.append({ 908 "from": "human", 909 "value": summary 910 }) 911 912 # Add tail (turns after compression region) 913 for i in range(compress_until, len(trajectory)): 914 compressed.append(trajectory[i].copy()) 915 916 # Calculate final metrics 917 metrics.compressed_turns = len(compressed) 918 metrics.compressed_tokens = self.count_trajectory_tokens(compressed) 919 metrics.turns_removed = metrics.original_turns - metrics.compressed_turns 920 metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens 921 metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1) 922 metrics.was_compressed = True 923 metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens 924 925 return compressed, metrics 926 927 async def process_entry_async(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]: 928 """ 929 Process a single JSONL entry (async version). 930 """ 931 if "conversations" not in entry: 932 metrics = TrajectoryMetrics() 933 return entry, metrics 934 935 trajectory = entry["conversations"] 936 compressed_trajectory, metrics = await self.compress_trajectory_async(trajectory) 937 938 # Create new entry with compressed trajectory 939 result = entry.copy() 940 result["conversations"] = compressed_trajectory 941 942 # Add compression metadata if enabled 943 if self.config.metrics_per_trajectory and metrics.was_compressed: 944 result["compression_metrics"] = metrics.to_dict() 945 946 return result, metrics 947 948 def process_entry(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]: 949 """ 950 Process a single JSONL entry. 951 952 Args: 953 entry: JSONL entry containing 'conversations' field 954 955 Returns: 956 Tuple of (processed_entry, metrics) 957 """ 958 if "conversations" not in entry: 959 metrics = TrajectoryMetrics() 960 return entry, metrics 961 962 trajectory = entry["conversations"] 963 compressed_trajectory, metrics = self.compress_trajectory(trajectory) 964 965 # Create new entry with compressed trajectory 966 result = entry.copy() 967 result["conversations"] = compressed_trajectory 968 969 # Add compression metadata if enabled 970 if self.config.metrics_per_trajectory and metrics.was_compressed: 971 result["compression_metrics"] = metrics.to_dict() 972 973 return result, metrics 974 975 def process_directory(self, input_dir: Path, output_dir: Path): 976 """ 977 Process all JSONL files in a directory using async parallel processing. 978 979 Args: 980 input_dir: Input directory containing JSONL files 981 output_dir: Output directory for compressed files 982 """ 983 # Run the async version 984 asyncio.run(self._process_directory_async(input_dir, output_dir)) 985 986 async def _process_directory_async(self, input_dir: Path, output_dir: Path): 987 """ 988 Async implementation of directory processing with parallel API calls. 989 """ 990 console = Console() 991 992 # Record start time 993 self.aggregate_metrics.processing_start_time = datetime.now().isoformat() 994 start_time = time.time() 995 996 # Find all JSONL files 997 jsonl_files = sorted(input_dir.glob("*.jsonl")) 998 999 if not jsonl_files: 1000 self.logger.warning(f"No JSONL files found in {input_dir}") 1001 return 1002 1003 # Load ALL entries from all files 1004 console.print("\n[dim]Loading all entries...[/dim]") 1005 all_entries = [] # List of (file_path, entry_idx, entry) 1006 1007 for file_path in jsonl_files: 1008 with open(file_path, 'r', encoding='utf-8') as f: 1009 for line_num, line in enumerate(f): 1010 line = line.strip() 1011 if line: 1012 try: 1013 entry = json.loads(line) 1014 all_entries.append((file_path, line_num, entry)) 1015 except json.JSONDecodeError as e: 1016 self.logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e}") 1017 1018 total_entries = len(all_entries) 1019 1020 console.print(f"\n{'='*60}") 1021 console.print(f"š Input: {input_dir}") 1022 console.print(f"š Output: {output_dir}") 1023 console.print(f"š Files to process: {len(jsonl_files)}") 1024 console.print(f"š Total trajectories: {total_entries:,}") 1025 console.print(f"šÆ Target max tokens: {self.config.target_max_tokens:,}") 1026 console.print(f"š Summary target tokens: {self.config.summary_target_tokens}") 1027 console.print(f"ā” Max concurrent API calls: {self.config.max_concurrent_requests}") 1028 console.print(f"{'='*60}\n") 1029 1030 # Create semaphore for rate limiting 1031 semaphore = asyncio.Semaphore(self.config.max_concurrent_requests) 1032 1033 # Tracking for progress display (thread-safe with lock) 1034 progress_lock = asyncio.Lock() 1035 compressed_count = 0 1036 skipped_count = 0 1037 api_calls = 0 1038 in_flight = 0 1039 1040 # Results storage: {file_path: {entry_idx: (processed_entry, metrics)}} 1041 results = {f: {} for f in jsonl_files} 1042 1043 # Track timeouts separately 1044 timeout_count = 0 1045 1046 async def process_single(file_path: Path, entry_idx: int, entry: Dict, 1047 progress, main_task, status_task): 1048 """Process a single entry with semaphore rate limiting and timeout.""" 1049 nonlocal compressed_count, skipped_count, api_calls, in_flight, timeout_count 1050 1051 async with semaphore: 1052 # Track in-flight 1053 async with progress_lock: 1054 in_flight += 1 1055 1056 try: 1057 # Apply per-trajectory timeout 1058 processed_entry, metrics = await asyncio.wait_for( 1059 self.process_entry_async(entry), 1060 timeout=self.config.per_trajectory_timeout 1061 ) 1062 results[file_path][entry_idx] = (processed_entry, metrics) 1063 1064 # Update aggregate metrics (with lock for thread safety) 1065 async with progress_lock: 1066 self.aggregate_metrics.add_trajectory_metrics(metrics) 1067 1068 # Update counters 1069 if metrics.was_compressed: 1070 compressed_count += 1 1071 api_calls += metrics.summarization_api_calls 1072 if metrics.skipped_under_target: 1073 skipped_count += 1 1074 1075 in_flight -= 1 1076 1077 # Update progress 1078 progress.advance(main_task) 1079 progress.update( 1080 status_task, 1081 description=f"[dim]ā {compressed_count} compressed | āļø {skipped_count} skipped | ā±ļø {timeout_count} timeout | š {api_calls} API calls | ā” {in_flight} in-flight[/dim]" 1082 ) 1083 1084 except asyncio.TimeoutError: 1085 self.logger.warning(f"Timeout processing entry from {file_path}:{entry_idx} (>{self.config.per_trajectory_timeout}s)") 1086 1087 async with progress_lock: 1088 self.aggregate_metrics.trajectories_failed += 1 1089 timeout_count += 1 1090 in_flight -= 1 1091 progress.advance(main_task) 1092 progress.update( 1093 status_task, 1094 description=f"[dim]ā {compressed_count} compressed | āļø {skipped_count} skipped | ā±ļø {timeout_count} timeout | š {api_calls} API calls | ā” {in_flight} in-flight[/dim]" 1095 ) 1096 1097 # Skip this entry entirely (don't include in output) 1098 results[file_path][entry_idx] = None 1099 1100 except Exception as e: 1101 self.logger.error(f"Error processing entry from {file_path}:{entry_idx}: {e}") 1102 1103 async with progress_lock: 1104 self.aggregate_metrics.trajectories_failed += 1 1105 in_flight -= 1 1106 progress.advance(main_task) 1107 1108 # Keep original entry on error 1109 results[file_path][entry_idx] = (entry, TrajectoryMetrics()) 1110 1111 # Create progress bar 1112 with Progress( 1113 SpinnerColumn(), 1114 TextColumn("[progress.description]{task.description}"), 1115 BarColumn(), 1116 TaskProgressColumn(), 1117 TextColumn("ā¢"), 1118 TimeElapsedColumn(), 1119 TextColumn("ā¢"), 1120 TimeRemainingColumn(), 1121 console=console, 1122 refresh_per_second=10 # Higher refresh for async 1123 ) as progress: 1124 # Main task for overall progress 1125 main_task = progress.add_task( 1126 f"[cyan]Compressing {total_entries:,} trajectories", 1127 total=total_entries 1128 ) 1129 1130 # Status line task 1131 status_task = progress.add_task( 1132 "[dim]Starting...[/dim]", 1133 total=None 1134 ) 1135 1136 # Create all tasks 1137 tasks = [ 1138 process_single(file_path, entry_idx, entry, progress, main_task, status_task) 1139 for file_path, entry_idx, entry in all_entries 1140 ] 1141 1142 # Run all tasks concurrently (semaphore limits actual concurrency) 1143 await asyncio.gather(*tasks) 1144 1145 # Remove status task 1146 progress.remove_task(status_task) 1147 1148 # Write results to output files (preserving original order) 1149 console.print("\n[dim]Writing output files...[/dim]") 1150 output_dir.mkdir(parents=True, exist_ok=True) 1151 1152 for file_path in jsonl_files: 1153 output_path = output_dir / file_path.name 1154 file_results = results[file_path] 1155 1156 # Sort by original entry index to preserve order, skip None (timed out) entries 1157 sorted_entries = [ 1158 file_results[idx][0] 1159 for idx in sorted(file_results.keys()) 1160 if file_results[idx] is not None 1161 ] 1162 1163 with open(output_path, 'w', encoding='utf-8') as f: 1164 for entry in sorted_entries: 1165 f.write(json.dumps(entry, ensure_ascii=False) + '\n') 1166 1167 # Record end time 1168 self.aggregate_metrics.processing_end_time = datetime.now().isoformat() 1169 self.aggregate_metrics.processing_duration_seconds = time.time() - start_time 1170 1171 # Print summary 1172 self._print_summary() 1173 1174 # Save metrics 1175 if self.config.metrics_enabled: 1176 metrics_path = output_dir / self.config.metrics_output_file 1177 with open(metrics_path, 'w') as f: 1178 json.dump(self.aggregate_metrics.to_dict(), f, indent=2) 1179 console.print(f"\nš¾ Metrics saved to {metrics_path}") 1180 1181 def _print_summary(self): 1182 """Print comprehensive compression summary statistics.""" 1183 m = self.aggregate_metrics.to_dict() 1184 1185 # Calculate some additional stats 1186 total = m['summary']['total_trajectories'] 1187 compressed = m['summary']['trajectories_compressed'] 1188 skipped = m['summary']['trajectories_skipped_under_target'] 1189 over_limit = m['summary']['trajectories_still_over_limit'] 1190 failed = m['summary']['trajectories_failed'] 1191 1192 # Token stats 1193 tokens_before = m['tokens']['total_before'] 1194 tokens_after = m['tokens']['total_after'] 1195 tokens_saved = m['tokens']['total_saved'] 1196 1197 # Calculate percentages 1198 compressed_pct = (compressed / max(total, 1)) * 100 1199 skipped_pct = (skipped / max(total, 1)) * 100 1200 over_limit_pct = (over_limit / max(total, 1)) * 100 1201 1202 print(f"\n") 1203 print(f"ā{'ā'*70}ā") 1204 print(f"ā{'TRAJECTORY COMPRESSION REPORT':^70}ā") 1205 print(f"ā {'ā'*70}ā£") 1206 1207 # Trajectories section 1208 print(f"ā{'':2}š TRAJECTORIES{' '*54}ā") 1209 print(f"ā{'ā'*70}ā") 1210 print(f"ā{'':4}Total Processed: {total:>10,}{' '*32}ā") 1211 print(f"ā{'':4}āā Compressed: {compressed:>10,} ({compressed_pct:>5.1f}%){' '*18}ā") 1212 print(f"ā{'':4}āā Skipped (under limit):{skipped:>9,} ({skipped_pct:>5.1f}%){' '*18}ā") 1213 print(f"ā{'':4}āā Still over limit: {over_limit:>10,} ({over_limit_pct:>5.1f}%){' '*18}ā") 1214 print(f"ā{'':4}āā Failed: {failed:>10,}{' '*32}ā") 1215 1216 print(f"ā {'ā'*70}ā£") 1217 1218 # Tokens section 1219 print(f"ā{'':2}š¢ TOKENS{' '*60}ā") 1220 print(f"ā{'ā'*70}ā") 1221 print(f"ā{'':4}Before Compression: {tokens_before:>15,} tokens{' '*21}ā") 1222 print(f"ā{'':4}After Compression: {tokens_after:>15,} tokens{' '*21}ā") 1223 print(f"ā{'':4}Total Saved: {tokens_saved:>15,} tokens{' '*21}ā") 1224 print(f"ā{'':4}Overall Compression: {m['tokens']['overall_compression_ratio']:>14.1%}{' '*28}ā") 1225 1226 if tokens_before > 0: 1227 savings_pct = (tokens_saved / tokens_before) * 100 1228 print(f"ā{'':4}Space Savings: {savings_pct:>14.1f}%{' '*28}ā") 1229 1230 print(f"ā {'ā'*70}ā£") 1231 1232 # Turns section 1233 print(f"ā{'':2}š¬ CONVERSATION TURNS{' '*48}ā") 1234 print(f"ā{'ā'*70}ā") 1235 print(f"ā{'':4}Before Compression: {m['turns']['total_before']:>15,} turns{' '*22}ā") 1236 print(f"ā{'':4}After Compression: {m['turns']['total_after']:>15,} turns{' '*22}ā") 1237 print(f"ā{'':4}Total Removed: {m['turns']['total_removed']:>15,} turns{' '*22}ā") 1238 1239 print(f"ā {'ā'*70}ā£") 1240 1241 # Averages section (for compressed trajectories only) 1242 print(f"ā{'':2}š AVERAGES (Compressed Trajectories Only){' '*27}ā") 1243 print(f"ā{'ā'*70}ā") 1244 if compressed > 0: 1245 print(f"ā{'':4}Avg Compression Ratio: {m['averages']['avg_compression_ratio']:>14.1%}{' '*28}ā") 1246 print(f"ā{'':4}Avg Tokens Saved: {m['averages']['avg_tokens_saved_per_compressed']:>14,.0f}{' '*28}ā") 1247 print(f"ā{'':4}Avg Turns Removed: {m['averages']['avg_turns_removed_per_compressed']:>14.1f}{' '*28}ā") 1248 else: 1249 print(f"ā{'':4}No trajectories were compressed{' '*38}ā") 1250 1251 print(f"ā {'ā'*70}ā£") 1252 1253 # Summarization API section 1254 print(f"ā{'':2}š¤ SUMMARIZATION API{' '*49}ā") 1255 print(f"ā{'ā'*70}ā") 1256 print(f"ā{'':4}API Calls Made: {m['summarization']['total_api_calls']:>15,}{' '*27}ā") 1257 print(f"ā{'':4}Errors: {m['summarization']['total_errors']:>15,}{' '*27}ā") 1258 print(f"ā{'':4}Success Rate: {m['summarization']['success_rate']:>14.1%}{' '*28}ā") 1259 1260 print(f"ā {'ā'*70}ā£") 1261 1262 # Processing time section 1263 duration = m['processing']['duration_seconds'] 1264 if duration > 60: 1265 time_str = f"{duration/60:.1f} minutes" 1266 else: 1267 time_str = f"{duration:.1f} seconds" 1268 1269 throughput = total / max(duration, 0.001) 1270 1271 print(f"ā{'':2}ā±ļø PROCESSING TIME{' '*51}ā") 1272 print(f"ā{'ā'*70}ā") 1273 print(f"ā{'':4}Duration: {time_str:>20}{' '*22}ā") 1274 print(f"ā{'':4}Throughput: {throughput:>15.1f} traj/sec{' '*18}ā") 1275 print(f"ā{'':4}Started: {m['processing']['start_time'][:19]:>20}{' '*22}ā") 1276 print(f"ā{'':4}Finished: {m['processing']['end_time'][:19]:>20}{' '*22}ā") 1277 1278 print(f"ā{'ā'*70}ā") 1279 1280 # Distribution summary if we have data 1281 if self.aggregate_metrics.compression_ratios: 1282 ratios = self.aggregate_metrics.compression_ratios 1283 tokens_saved_list = self.aggregate_metrics.tokens_saved_list 1284 1285 print(f"\nš Distribution Summary:") 1286 print(f" Compression ratios: min={min(ratios):.2%}, max={max(ratios):.2%}, median={sorted(ratios)[len(ratios)//2]:.2%}") 1287 print(f" Tokens saved: min={min(tokens_saved_list):,}, max={max(tokens_saved_list):,}, median={sorted(tokens_saved_list)[len(tokens_saved_list)//2]:,}") 1288 1289 1290 def main( 1291 input: str, 1292 output: str = None, 1293 config: str = "configs/trajectory_compression.yaml", 1294 target_max_tokens: int = None, 1295 tokenizer: str = None, 1296 sample_percent: float = None, 1297 seed: int = 42, 1298 dry_run: bool = False, 1299 ): 1300 """ 1301 Compress agent trajectories to fit within a target token budget. 1302 1303 Supports both single JSONL files and directories containing multiple JSONL files. 1304 Optionally sample a percentage of trajectories before compression. 1305 1306 Args: 1307 input: Path to JSONL file or directory containing JSONL files 1308 output: Output path (file for file input, directory for dir input) 1309 Default: adds "_compressed" suffix to input name 1310 config: Path to YAML configuration file 1311 target_max_tokens: Override target token count from config 1312 tokenizer: Override tokenizer name from config 1313 sample_percent: Sample this percentage of trajectories (1-100) before compression 1314 seed: Random seed for sampling reproducibility (default: 42) 1315 dry_run: Analyze without compressing (just show what would happen) 1316 1317 Examples: 1318 # Compress a directory (original behavior) 1319 python trajectory_compressor.py --input=data/my_run 1320 1321 # Compress a single file 1322 python trajectory_compressor.py --input=data/trajectories.jsonl 1323 1324 # Compress 15% sample of a file 1325 python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15 1326 1327 # Compress 10% sample with custom output 1328 python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=10 --output=data/sampled_compressed.jsonl 1329 """ 1330 import random 1331 import tempfile 1332 import shutil 1333 1334 print("šļø Trajectory Compressor") 1335 print("=" * 60) 1336 1337 # Load configuration 1338 config_path = Path(config) 1339 if config_path.exists(): 1340 print(f"š Loading config from {config}") 1341 compression_config = CompressionConfig.from_yaml(config) 1342 else: 1343 print(f"ā ļø Config not found at {config}, using defaults") 1344 compression_config = CompressionConfig() 1345 1346 # Apply CLI overrides 1347 if target_max_tokens: 1348 compression_config.target_max_tokens = target_max_tokens 1349 if tokenizer: 1350 compression_config.tokenizer_name = tokenizer 1351 1352 # Validate sample_percent 1353 if sample_percent is not None: 1354 if sample_percent <= 0 or sample_percent > 100: 1355 print(f"ā sample_percent must be between 1 and 100, got {sample_percent}") 1356 return 1357 print(f"š² Will sample {sample_percent}% of trajectories (seed={seed})") 1358 1359 # Setup paths and determine input type 1360 input_path = Path(input) 1361 if not input_path.exists(): 1362 print(f"ā Input not found: {input}") 1363 return 1364 1365 is_file_input = input_path.is_file() 1366 1367 if is_file_input: 1368 print(f"š Input mode: Single JSONL file") 1369 1370 # For file input, default output is file with _compressed suffix 1371 if output: 1372 output_path = Path(output) 1373 else: 1374 output_path = input_path.parent / (input_path.stem + compression_config.output_suffix + ".jsonl") 1375 1376 # Load entries from the single file 1377 entries = [] 1378 with open(input_path, 'r', encoding='utf-8') as f: 1379 for line_num, line in enumerate(f, 1): 1380 line = line.strip() 1381 if line: 1382 try: 1383 entries.append(json.loads(line)) 1384 except json.JSONDecodeError as e: 1385 print(f"ā ļø Skipping invalid JSON at line {line_num}: {e}") 1386 1387 total_entries = len(entries) 1388 print(f" Loaded {total_entries:,} trajectories from {input_path.name}") 1389 1390 # Sample if requested 1391 if sample_percent is not None: 1392 random.seed(seed) 1393 sample_size = max(1, int(total_entries * sample_percent / 100)) 1394 entries = random.sample(entries, sample_size) 1395 print(f" Sampled {len(entries):,} trajectories ({sample_percent}% of {total_entries:,})") 1396 1397 if dry_run: 1398 print(f"\nš DRY RUN MODE - analyzing without writing") 1399 print(f"š Would process: {len(entries):,} trajectories") 1400 print(f"š Would output to: {output_path}") 1401 return 1402 1403 # Create a temporary directory for processing 1404 with tempfile.TemporaryDirectory() as temp_dir: 1405 temp_input_dir = Path(temp_dir) / "input" 1406 temp_output_dir = Path(temp_dir) / "output" 1407 temp_input_dir.mkdir() 1408 1409 # Write entries to temp file 1410 temp_input_file = temp_input_dir / "trajectories.jsonl" 1411 with open(temp_input_file, 'w', encoding='utf-8') as f: 1412 for entry in entries: 1413 f.write(json.dumps(entry, ensure_ascii=False) + '\n') 1414 1415 # Initialize compressor and process 1416 compressor = TrajectoryCompressor(compression_config) 1417 compressor.process_directory(temp_input_dir, temp_output_dir) 1418 1419 # Copy result to output path (merge all files in temp_output_dir) 1420 output_path.parent.mkdir(parents=True, exist_ok=True) 1421 with open(output_path, 'w', encoding='utf-8') as out_f: 1422 for jsonl_file in sorted(temp_output_dir.glob("*.jsonl")): 1423 with open(jsonl_file, 'r', encoding='utf-8') as in_f: 1424 for line in in_f: 1425 out_f.write(line) 1426 1427 # Copy metrics file if it exists 1428 metrics_file = temp_output_dir / compression_config.metrics_output_file 1429 if metrics_file.exists(): 1430 metrics_output = output_path.parent / (output_path.stem + "_metrics.json") 1431 shutil.copy(metrics_file, metrics_output) 1432 print(f"š¾ Metrics saved to {metrics_output}") 1433 1434 print(f"\nā Compression complete!") 1435 print(f"š Output: {output_path}") 1436 1437 else: 1438 # Directory input - original behavior 1439 print(f"š Input mode: Directory of JSONL files") 1440 1441 if output: 1442 output_path = Path(output) 1443 else: 1444 output_path = input_path.parent / (input_path.name + compression_config.output_suffix) 1445 1446 # If sampling is requested for directory mode, we need to handle it differently 1447 if sample_percent is not None: 1448 print(f"\nā ļø Sampling from directory: will sample {sample_percent}% from each file") 1449 1450 # Create a temp directory with sampled files 1451 with tempfile.TemporaryDirectory() as temp_dir: 1452 temp_input_dir = Path(temp_dir) / "input" 1453 temp_input_dir.mkdir() 1454 1455 random.seed(seed) 1456 total_original = 0 1457 total_sampled = 0 1458 1459 # Sample from each JSONL file 1460 for jsonl_file in sorted(input_path.glob("*.jsonl")): 1461 entries = [] 1462 with open(jsonl_file, 'r', encoding='utf-8') as f: 1463 for line in f: 1464 line = line.strip() 1465 if line: 1466 try: 1467 entries.append(json.loads(line)) 1468 except json.JSONDecodeError: 1469 pass 1470 1471 total_original += len(entries) 1472 sample_size = max(1, int(len(entries) * sample_percent / 100)) 1473 sampled_entries = random.sample(entries, min(sample_size, len(entries))) 1474 total_sampled += len(sampled_entries) 1475 1476 # Write sampled entries 1477 temp_file = temp_input_dir / jsonl_file.name 1478 with open(temp_file, 'w', encoding='utf-8') as f: 1479 for entry in sampled_entries: 1480 f.write(json.dumps(entry, ensure_ascii=False) + '\n') 1481 1482 print(f" Sampled {total_sampled:,} from {total_original:,} total trajectories") 1483 1484 if dry_run: 1485 print(f"\nš DRY RUN MODE - analyzing without writing") 1486 print(f"š Would process: {temp_input_dir}") 1487 print(f"š Would output to: {output_path}") 1488 return 1489 1490 # Initialize compressor and process the sampled data 1491 compressor = TrajectoryCompressor(compression_config) 1492 compressor.process_directory(temp_input_dir, output_path) 1493 else: 1494 if dry_run: 1495 print(f"\nš DRY RUN MODE - analyzing without writing") 1496 print(f"š Would process: {input_path}") 1497 print(f"š Would output to: {output_path}") 1498 return 1499 1500 # Initialize compressor and process directly 1501 compressor = TrajectoryCompressor(compression_config) 1502 compressor.process_directory(input_path, output_path) 1503 1504 print("\nā Compression complete!") 1505 1506 1507 if __name__ == "__main__": 1508 fire.Fire(main)