#!/usr/bin/env python3
"""
Sample and Compress HuggingFace Datasets

Downloads trajectories from multiple HuggingFace datasets, randomly samples them,
and runs trajectory compression to fit within a target token budget.

Usage:
    python scripts/sample_and_compress.py

    # Custom sample size
    python scripts/sample_and_compress.py --total_samples=5000

    # Custom output name
    python scripts/sample_and_compress.py --output_name=compressed_16k
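
    # Re-run compression on previously sampled data (skips the download step)
    python scripts/sample_and_compress.py --skip_download

    # Raise the minimum-length filter and use more tokenization workers
    python scripts/sample_and_compress.py --min_tokens=32000 --num_proc=16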
"""

import json
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import fire

# Load environment variables
from dotenv import load_dotenv
load_dotenv()


# Default datasets to sample from
DEFAULT_DATASETS = [
    "NousResearch/swe-terminus-agent-glm-kimi-minimax",
    "NousResearch/hermes-agent-megascience-sft1",
    "NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT2",
    "NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT1",
    "NousResearch/terminal-tasks-glm-hermes-agent"
]


def load_dataset_from_hf(dataset_name: str) -> List[Dict[str, Any]]:
    """
    Load a dataset from HuggingFace.

    Args:
        dataset_name: HuggingFace dataset name (e.g., "NousResearch/dataset-name")

    Returns:
        List of trajectory entries
    """
    from datasets import load_dataset

    print(f"   Loading {dataset_name}...")

    try:
        # Try loading with default config
        ds = load_dataset(dataset_name, split="train")
    except Exception as e:
        print(f"   āš ļø  Error loading {dataset_name}: {e}")
        return []

    # Convert to list of dicts
    entries = []
    for item in ds:
        # Handle different possible formats
        if "conversations" in item:
            entries.append({"conversations": item["conversations"]})
        elif "messages" in item:
            # Convert messages format to conversations format if needed
            entries.append({"conversations": item["messages"]})
        else:
            # Assume the whole item is the entry
            entries.append(dict(item))

    print(f"   āœ… Loaded {len(entries):,} entries from {dataset_name}")
    return entries


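# Entries are expected to look like ShareGPT-style trajectories, e.g.
# (illustrative shape, not taken from any specific dataset):
#   {"conversations": [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]}
# The token counting below only reads each turn's "value" field.
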
# Global tokenizer for multiprocessing. Each worker loads its own copy in
# _init_tokenizer_worker, so the tokenizer never has to be pickled per task.
_TOKENIZER = None


def _init_tokenizer_worker(tokenizer_name: str):
    """Initialize the tokenizer in a worker process."""
    global _TOKENIZER
    from transformers import AutoTokenizer
    _TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)


def _count_tokens_for_entry(entry: Dict) -> Tuple[Dict, int]:
    """
    Count tokens for a single entry (used in parallel processing).

    Args:
        entry: Trajectory entry with 'conversations' field

    Returns:
        Tuple of (entry, token_count)
    """
    conversations = entry.get("conversations", [])
    if not conversations:
        return entry, 0

    total = 0
    for turn in conversations:
        value = turn.get("value", "")
        if value:
            try:
                total += len(_TOKENIZER.encode(value))
            except Exception:
                # Fallback: rough estimate of ~4 characters per token
                total += len(value) // 4

    return entry, total


def sample_from_datasets(
    datasets: List[str],
    total_samples: int,
    min_tokens: int = 16000,
    tokenizer_name: str = "moonshotai/Kimi-K2-Thinking",
    seed: int = 42,
    num_proc: int = 8
) -> List[Dict[str, Any]]:
    """
    Load all datasets, filter by token count, then randomly sample from the combined pool.

    Args:
        datasets: List of HuggingFace dataset names
        total_samples: Total number of samples to collect
        min_tokens: Minimum token count to include (only sample trajectories >= this)
        tokenizer_name: HuggingFace tokenizer for counting tokens
        seed: Random seed for reproducibility
        num_proc: Number of parallel processes for tokenization

    Returns:
        List of sampled trajectory entries
    """
    from multiprocessing import Pool

    random.seed(seed)

    print(f"\nšŸ“„ Loading {len(datasets)} datasets...")
    print(f"   Minimum tokens: {min_tokens:,} (filtering smaller trajectories)")
    print(f"   Parallel workers: {num_proc}")
    print()

    # Load ALL entries from all datasets into one pool
    all_entries = []

    for dataset_name in datasets:
        entries = load_dataset_from_hf(dataset_name)

        if not entries:
            print(f"   āš ļø  Skipping {dataset_name} (no entries loaded)")
            continue

        # Add source metadata to each entry
        for entry in entries:
            entry["_source_dataset"] = dataset_name

        all_entries.extend(entries)

    print(f"\nšŸ“Š Total entries loaded: {len(all_entries):,}")

    # Filter by token count using parallel processing
    print(f"\nšŸ” Filtering trajectories with >= {min_tokens:,} tokens (using {num_proc} workers)...")

    filtered_entries = []
    token_counts = []

    # Count tokens in parallel. imap (rather than imap_unordered) preserves input
    # order, so the seeded random.sample below stays reproducible across runs.
    with Pool(
        processes=num_proc,
        initializer=_init_tokenizer_worker,
        initargs=(tokenizer_name,)
    ) as pool:
        progress_interval = 1000
        processed = 0

        for entry, token_count in pool.imap(_count_tokens_for_entry, all_entries, chunksize=100):
            processed += 1

            if processed % progress_interval == 0:
                print(f"   Processed {processed:,}/{len(all_entries):,}...", end="\r")

            if token_count >= min_tokens:
                entry["_original_tokens"] = token_count
                filtered_entries.append(entry)
                token_counts.append(token_count)

    print(f"\n   āœ… Found {len(filtered_entries):,} trajectories >= {min_tokens:,} tokens")

    if token_counts:
        avg_tokens = sum(token_counts) / len(token_counts)
        print(f"   šŸ“ˆ Token stats: min={min(token_counts):,}, max={max(token_counts):,}, avg={avg_tokens:,.0f}")

    # Random sample from the filtered pool
    if len(filtered_entries) <= total_samples:
        print(f"\nāš ļø  Only {len(filtered_entries):,} trajectories available, using all of them")
        sampled = filtered_entries
    else:
        sampled = random.sample(filtered_entries, total_samples)
        print(f"\nāœ… Randomly sampled {len(sampled):,} trajectories from pool of {len(filtered_entries):,}")

    # Show source distribution
    source_counts = {}
    for entry in sampled:
        source = entry.get("_source_dataset", "unknown").split("/")[-1]
        source_counts[source] = source_counts.get(source, 0) + 1

    print("\nšŸ“Œ Sample distribution by source:")
    for source, count in sorted(source_counts.items()):
        print(f"      {source}: {count:,}")

    # Shuffle
    random.shuffle(sampled)

    return sampled


def save_samples_for_compression(
    samples: List[Dict[str, Any]],
    output_dir: Path,
    batch_size: int = 100
):
    """
    Save samples to JSONL files for trajectory compression.

    Args:
        samples: List of trajectory entries
        output_dir: Directory to save JSONL files
        batch_size: Number of entries per file
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Split into batches
    num_batches = (len(samples) + batch_size - 1) // batch_size

    print(f"\nšŸ’¾ Saving {len(samples)} samples to {output_dir}")
    print(f"   Batch size: {batch_size}, Total batches: {num_batches}")

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(samples))
        batch = samples[start_idx:end_idx]

        output_file = output_dir / f"batch_{i}.jsonl"
        with open(output_file, 'w', encoding='utf-8') as f:
            for entry in batch:
                f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"   āœ… Saved {num_batches} batch files")


def run_compression(input_dir: Path, output_dir: Path, config_path: str):
    """
    Run trajectory compression on the sampled data.

    Args:
        input_dir: Directory containing JSONL files to compress
        output_dir: Directory for compressed output
        config_path: Path to compression config YAML
    """
    # Import the repo-local compressor (lives one directory above scripts/)
    import sys
    sys.path.insert(0, str(Path(__file__).parent.parent))
    from trajectory_compressor import TrajectoryCompressor, CompressionConfig

    print("\nšŸ—œļø  Running trajectory compression...")
    print(f"   Input: {input_dir}")
    print(f"   Output: {output_dir}")
    print(f"   Config: {config_path}")

    # Load config
    config = CompressionConfig.from_yaml(config_path)

    # Initialize compressor
    compressor = TrajectoryCompressor(config)

    # Run compression
    compressor.process_directory(input_dir, output_dir)

def merge_output_to_single_jsonl(input_dir: Path, output_file: Path):
    """
    Merge all JSONL files in a directory into a single JSONL file.

    Args:
        input_dir: Directory containing JSONL files
        output_file: Output JSONL file path
    """
    print(f"\nšŸ“¦ Merging output files into {output_file.name}...")

    all_entries = []
    for jsonl_file in sorted(input_dir.glob("*.jsonl")):
        if jsonl_file.name == output_file.name:
            continue
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    all_entries.append(json.loads(line))

    # Write merged file
    with open(output_file, 'w', encoding='utf-8') as f:
        for entry in all_entries:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"   āœ… Merged {len(all_entries):,} entries into {output_file.name}")
    return output_file


def main(
    total_samples: int = 2500,
    output_name: str = "compressed_agentic",
    datasets: Optional[str] = None,
    config: str = "configs/trajectory_compression.yaml",
    seed: int = 42,
    batch_size: int = 100,
    min_tokens: int = 16000,
    num_proc: int = 8,
    skip_download: bool = False,
):
    """
    Sample trajectories from HuggingFace datasets and run compression.

    Args:
        total_samples: Total number of samples to collect (default: 2500)
        output_name: Name for output directory/file (default: "compressed_agentic")
        datasets: Comma-separated list of dataset names (uses defaults if not provided)
        config: Path to compression config YAML
        seed: Random seed for reproducibility
        batch_size: Number of entries per JSONL file during processing
        min_tokens: Minimum token count to filter trajectories (default: 16000)
        num_proc: Number of parallel workers for tokenization (default: 8)
        skip_download: Skip download and use existing sampled data
    """
    print("=" * 70)
    print("šŸ“Š TRAJECTORY SAMPLING AND COMPRESSION")
    print("=" * 70)

    # Parse datasets
    if datasets:
        dataset_list = [d.strip() for d in datasets.split(",")]
    else:
        dataset_list = DEFAULT_DATASETS

    print("\nšŸ“‹ Configuration:")
    print(f"   Total samples: {total_samples:,}")
    print(f"   Min tokens filter: {min_tokens:,}")
    print(f"   Parallel workers: {num_proc}")
    print(f"   Datasets: {len(dataset_list)}")
    for ds in dataset_list:
        print(f"      - {ds}")
    print(f"   Output name: {output_name}")
    print(f"   Config: {config}")
    print(f"   Seed: {seed}")

    # Setup paths
    base_dir = Path(__file__).parent.parent
    sampled_dir = base_dir / "data" / f"{output_name}_raw"
    compressed_dir = base_dir / "data" / f"{output_name}_batches"
    final_output = base_dir / "data" / f"{output_name}.jsonl"

    if not skip_download:
        # Step 1: Download, filter by token count, and sample from the combined pool
        samples = sample_from_datasets(
            dataset_list,
            total_samples,
            min_tokens=min_tokens,
            seed=seed,
            num_proc=num_proc
        )

        if not samples:
            print("āŒ No samples collected. Exiting.")
            return

        # Step 2: Save to JSONL files
        save_samples_for_compression(samples, sampled_dir, batch_size)
    else:
        print(f"\nā­ļø  Skipping download, using existing data in {sampled_dir}")

    # Step 3: Run compression
    config_path = base_dir / config
    if not config_path.exists():
        print(f"āŒ Config not found: {config_path}")
        return

    run_compression(sampled_dir, compressed_dir, str(config_path))

    # Step 4: Merge into a single JSONL file
    merge_output_to_single_jsonl(compressed_dir, final_output)

    print("\n" + "=" * 70)
    print("āœ… COMPLETE!")
    print("=" * 70)
    print(f"\nšŸ“ Raw samples:        {sampled_dir}")
    print(f"šŸ“ Compressed batches: {compressed_dir}")
    print(f"šŸ“ Final output:       {final_output}")
    print("\nTo upload to HuggingFace:")
    print(f"   huggingface-cli upload --repo-type dataset NousResearch/{output_name} {final_output}")


if __name__ == "__main__":
    fire.Fire(main)
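

# Illustrative invocation with a custom dataset list (example values; the
# --datasets flag is parsed as a comma-separated string in main() above):
#   python scripts/sample_and_compress.py \
#       --datasets="NousResearch/hermes-agent-megascience-sft1,NousResearch/terminal-tasks-glm-hermes-agent" \
#       --total_samples=1000 --output_name=compressed_subset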