sample_and_compress.py
#!/usr/bin/env python3
"""
Sample and Compress HuggingFace Datasets

Downloads trajectories from multiple HuggingFace datasets, randomly samples them,
and runs trajectory compression to fit within a target token budget.

Usage:
    python scripts/sample_and_compress.py

    # Custom sample size
    python scripts/sample_and_compress.py --total_samples=5000

    # Custom output name
    python scripts/sample_and_compress.py --output_name=compressed_16k
"""

import json
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import fire

# Load environment variables (e.g., HF credentials) from a local .env file
from dotenv import load_dotenv
load_dotenv()


# Default datasets to sample from
DEFAULT_DATASETS = [
    "NousResearch/swe-terminus-agent-glm-kimi-minimax",
    "NousResearch/hermes-agent-megascience-sft1",
    "NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT2",
    "NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT1",
    "NousResearch/terminal-tasks-glm-hermes-agent",
]


def load_dataset_from_hf(dataset_name: str) -> List[Dict[str, Any]]:
    """
    Load a dataset from HuggingFace.

    Args:
        dataset_name: HuggingFace dataset name (e.g., "NousResearch/dataset-name")

    Returns:
        List of trajectory entries
    """
    from datasets import load_dataset

    print(f"  Loading {dataset_name}...")

    try:
        # Try loading with the default config
        ds = load_dataset(dataset_name, split="train")
    except Exception as e:
        print(f"  ⚠️ Error loading {dataset_name}: {e}")
        return []

    # Convert to a list of dicts
    entries = []
    for item in ds:
        # Handle different possible formats
        if "conversations" in item:
            entries.append({"conversations": item["conversations"]})
        elif "messages" in item:
            # Convert messages format to conversations format if needed
            entries.append({"conversations": item["messages"]})
        else:
            # Assume the whole item is the entry
            entries.append(dict(item))

    print(f"  ✅ Loaded {len(entries):,} entries from {dataset_name}")
    return entries


# Global tokenizer for multiprocessing (set in worker init)
_TOKENIZER = None


def _init_tokenizer_worker(tokenizer_name: str):
    """Initialize the tokenizer in a worker process."""
    global _TOKENIZER
    from transformers import AutoTokenizer
    _TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)


def _count_tokens_for_entry(entry: Dict) -> Tuple[Dict, int]:
    """
    Count tokens for a single entry (used in parallel processing).

    Args:
        entry: Trajectory entry with 'conversations' field

    Returns:
        Tuple of (entry, token_count)
    """
    conversations = entry.get("conversations", [])
    if not conversations:
        return entry, 0

    total = 0
    for turn in conversations:
        value = turn.get("value", "")
        if value:
            try:
                total += len(_TOKENIZER.encode(value))
            except Exception:
                # Fall back to a rough character-based estimate (~4 chars/token)
                total += len(value) // 4

    return entry, total
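

# Note: the token counter above only reads the "value" field of each turn.
# An entry is assumed to look roughly like the ShareGPT-style record below;
# the "from" labels are an assumption and are never inspected here.
#
#   {"conversations": [
#       {"from": "human", "value": "..."},
#       {"from": "gpt", "value": "..."}
#   ]}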


def sample_from_datasets(
    datasets: List[str],
    total_samples: int,
    min_tokens: int = 16000,
    tokenizer_name: str = "moonshotai/Kimi-K2-Thinking",
    seed: int = 42,
    num_proc: int = 8
) -> List[Dict[str, Any]]:
    """
    Load all datasets, filter by token count, then randomly sample from the combined pool.

    Args:
        datasets: List of HuggingFace dataset names
        total_samples: Total number of samples to collect
        min_tokens: Minimum token count to include (only sample trajectories >= this)
        tokenizer_name: HuggingFace tokenizer for counting tokens
        seed: Random seed for reproducibility
        num_proc: Number of parallel processes for tokenization

    Returns:
        List of sampled trajectory entries
    """
    from multiprocessing import Pool

    random.seed(seed)

    print(f"\n📥 Loading {len(datasets)} datasets...")
    print(f"   Minimum tokens: {min_tokens:,} (filtering smaller trajectories)")
    print(f"   Parallel workers: {num_proc}")
    print()

    # Load ALL entries from all datasets into one pool
    all_entries = []

    for dataset_name in datasets:
        entries = load_dataset_from_hf(dataset_name)

        if not entries:
            print(f"  ⚠️ Skipping {dataset_name} (no entries loaded)")
            continue

        # Add source metadata to each entry
        for entry in entries:
            entry["_source_dataset"] = dataset_name

        all_entries.extend(entries)

    print(f"\n📊 Total entries loaded: {len(all_entries):,}")

    # Filter by token count using parallel processing
    print(f"\n🔍 Filtering trajectories with >= {min_tokens:,} tokens (using {num_proc} workers)...")

    filtered_entries = []
    token_counts = []

    # Use multiprocessing for token counting
    with Pool(
        processes=num_proc,
        initializer=_init_tokenizer_worker,
        initargs=(tokenizer_name,)
    ) as pool:
        # Report progress every `progress_interval` entries
        progress_interval = 1000
        processed = 0

        for entry, token_count in pool.imap_unordered(_count_tokens_for_entry, all_entries, chunksize=100):
            processed += 1

            if processed % progress_interval == 0:
                print(f"   Processed {processed:,}/{len(all_entries):,}...", end="\r")

            if token_count >= min_tokens:
                entry["_original_tokens"] = token_count
                filtered_entries.append(entry)
                token_counts.append(token_count)

    print(f"\n  ✅ Found {len(filtered_entries):,} trajectories >= {min_tokens:,} tokens")

    if token_counts:
        avg_tokens = sum(token_counts) / len(token_counts)
        print(f"  📊 Token stats: min={min(token_counts):,}, max={max(token_counts):,}, avg={avg_tokens:,.0f}")

    # Randomly sample from the filtered pool
    if len(filtered_entries) <= total_samples:
        print(f"\n⚠️ Only {len(filtered_entries):,} trajectories available, using all of them")
        sampled = filtered_entries
    else:
        sampled = random.sample(filtered_entries, total_samples)
        print(f"\n✅ Randomly sampled {len(sampled):,} trajectories from pool of {len(filtered_entries):,}")

    # Show source distribution
    source_counts = {}
    for entry in sampled:
        source = entry.get("_source_dataset", "unknown").split("/")[-1]
        source_counts[source] = source_counts.get(source, 0) + 1

    print("\n📊 Sample distribution by source:")
    for source, count in sorted(source_counts.items()):
        print(f"   {source}: {count:,}")

    # Shuffle before returning
    random.shuffle(sampled)

    return sampled
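

# Example of calling the sampler directly (hypothetical values; the CLI entry
# point in main() below is the usual path):
#
#   samples = sample_from_datasets(
#       ["NousResearch/terminal-tasks-glm-hermes-agent"],
#       total_samples=100,
#       min_tokens=8000,
#       num_proc=4,
#   )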


def save_samples_for_compression(
    samples: List[Dict[str, Any]],
    output_dir: Path,
    batch_size: int = 100
):
    """
    Save samples to JSONL files for trajectory compression.

    Args:
        samples: List of trajectory entries
        output_dir: Directory to save JSONL files
        batch_size: Number of entries per file
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Split into batches
    num_batches = (len(samples) + batch_size - 1) // batch_size

    print(f"\n💾 Saving {len(samples)} samples to {output_dir}")
    print(f"   Batch size: {batch_size}, Total batches: {num_batches}")

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(samples))
        batch = samples[start_idx:end_idx]

        output_file = output_dir / f"batch_{i}.jsonl"
        with open(output_file, 'w', encoding='utf-8') as f:
            for entry in batch:
                f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"  ✅ Saved {num_batches} batch files")


def run_compression(input_dir: Path, output_dir: Path, config_path: str):
    """
    Run trajectory compression on the sampled data.

    Args:
        input_dir: Directory containing JSONL files to compress
        output_dir: Directory for compressed output
        config_path: Path to compression config YAML
    """
    # Import the compressor
    import sys
    sys.path.insert(0, str(Path(__file__).parent.parent))
    from trajectory_compressor import TrajectoryCompressor, CompressionConfig

    print("\n🗜️ Running trajectory compression...")
    print(f"   Input: {input_dir}")
    print(f"   Output: {output_dir}")
    print(f"   Config: {config_path}")

    # Load config
    config = CompressionConfig.from_yaml(config_path)

    # Initialize compressor
    compressor = TrajectoryCompressor(config)

    # Run compression
    compressor.process_directory(input_dir, output_dir)


def merge_output_to_single_jsonl(input_dir: Path, output_file: Path):
    """
    Merge all JSONL files in a directory into a single JSONL file.

    Args:
        input_dir: Directory containing JSONL files
        output_file: Output JSONL file path
    """
    print(f"\n📦 Merging output files into {output_file.name}...")

    all_entries = []
    for jsonl_file in sorted(input_dir.glob("*.jsonl")):
        if jsonl_file.name == output_file.name:
            continue
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    all_entries.append(json.loads(line))

    # Write merged file
    with open(output_file, 'w', encoding='utf-8') as f:
        for entry in all_entries:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"  ✅ Merged {len(all_entries):,} entries into {output_file.name}")
    return output_file
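

# Optional sanity check (a sketch, not wired into the pipeline): re-read the
# merged file and confirm every non-empty line parses as JSON before uploading.
#
#   with open(output_file, "r", encoding="utf-8") as f:
#       assert all(json.loads(line) is not None for line in f if line.strip())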


def main(
    total_samples: int = 2500,
    output_name: str = "compressed_agentic",
    datasets: Optional[str] = None,
    config: str = "configs/trajectory_compression.yaml",
    seed: int = 42,
    batch_size: int = 100,
    min_tokens: int = 16000,
    num_proc: int = 8,
    skip_download: bool = False,
):
    """
    Sample trajectories from HuggingFace datasets and run compression.

    Args:
        total_samples: Total number of samples to collect (default: 2500)
        output_name: Name for output directory/file (default: "compressed_agentic")
        datasets: Comma-separated list of dataset names (uses defaults if not provided)
        config: Path to compression config YAML
        seed: Random seed for reproducibility
        batch_size: Number of entries per JSONL file during processing
        min_tokens: Minimum token count to filter trajectories (default: 16000)
        num_proc: Number of parallel workers for tokenization (default: 8)
        skip_download: Skip download and use existing sampled data
    """
    print("=" * 70)
    print("🚀 TRAJECTORY SAMPLING AND COMPRESSION")
    print("=" * 70)

    # Parse datasets
    if datasets:
        dataset_list = [d.strip() for d in datasets.split(",")]
    else:
        dataset_list = DEFAULT_DATASETS

    print("\n📋 Configuration:")
    print(f"   Total samples: {total_samples:,}")
    print(f"   Min tokens filter: {min_tokens:,}")
    print(f"   Parallel workers: {num_proc}")
    print(f"   Datasets: {len(dataset_list)}")
    for ds in dataset_list:
        print(f"     - {ds}")
    print(f"   Output name: {output_name}")
    print(f"   Config: {config}")
    print(f"   Seed: {seed}")

    # Setup paths
    base_dir = Path(__file__).parent.parent
    sampled_dir = base_dir / "data" / f"{output_name}_raw"
    compressed_dir = base_dir / "data" / f"{output_name}_batches"
    final_output = base_dir / "data" / f"{output_name}.jsonl"

    if not skip_download:
        # Step 1: Download, filter by token count, and sample from the combined pool
        samples = sample_from_datasets(
            dataset_list,
            total_samples,
            min_tokens=min_tokens,
            seed=seed,
            num_proc=num_proc
        )

        if not samples:
            print("❌ No samples collected. Exiting.")
            return

        # Step 2: Save to JSONL files
        save_samples_for_compression(samples, sampled_dir, batch_size)
    else:
        print(f"\n⏭️ Skipping download, using existing data in {sampled_dir}")

    # Step 3: Run compression
    config_path = base_dir / config
    if not config_path.exists():
        print(f"❌ Config not found: {config_path}")
        return

    run_compression(sampled_dir, compressed_dir, str(config_path))

    # Step 4: Merge into a single JSONL file
    merge_output_to_single_jsonl(compressed_dir, final_output)

    print("\n" + "=" * 70)
    print("✅ COMPLETE!")
    print("=" * 70)
    print(f"\n📁 Raw samples: {sampled_dir}")
    print(f"📁 Compressed batches: {compressed_dir}")
    print(f"📄 Final output: {final_output}")
    print("\nTo upload to HuggingFace:")
    print(f"  huggingface-cli upload NousResearch/{output_name} {final_output}")


if __name__ == "__main__":
    fire.Fire(main)