  1  """
  2  One-time data preparation for autoresearch experiments.
  3  Downloads data shards and trains a BPE tokenizer.
  4  
  5  Usage:
  6      python prepare.py                  # full prep (download + tokenizer)
  7      python prepare.py --num-shards 8   # download only 8 shards (for testing)
  8  
  9  Data and tokenizer are stored in ~/.cache/autoresearch/.
 10  """

import os
import sys
import time
import math
import argparse
import pickle
from multiprocessing import Pool

import requests
import pyarrow.parquet as pq
import rustbpe
import tiktoken
import torch

# ---------------------------------------------------------------------------
# Constants (fixed, do not modify)
# ---------------------------------------------------------------------------

MAX_SEQ_LEN = 2048       # context length
TIME_BUDGET = 300        # training time budget in seconds (5 minutes)
EVAL_TOKENS = 40 * 524288  # number of tokens for val eval
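# (EVAL_TOKENS = 40 * 2**19 = 20,971,520 tokens, i.e. about 10,240 rows of MAX_SEQ_LEN.)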

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
DATA_DIR = os.path.join(CACHE_DIR, "data")
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542  # the last data shard is shard_06542.parquet
VAL_SHARD = MAX_SHARD  # pinned validation shard (shard_06542)
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
VOCAB_SIZE = 8192

# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
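# Illustrative consequence of the \p{N}{1,2} tweak: digit runs are pre-split into
# chunks of at most two characters before BPE merging, e.g. "2048" is pre-tokenized
# as "20" + "48" rather than as a single four-digit chunk.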

SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
BOS_TOKEN = "<|reserved_0|>"

# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------

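# Implementation note: each shard is fetched with up to 5 attempts and exponential
# backoff (2, 4, 8, 16 seconds between retries). Bytes are streamed to a ".tmp"
# file and renamed into place only on success, so a partially downloaded shard is
# never mistaken for a complete one.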
def download_single_shard(index):
    """Download one parquet shard with retries. Returns True on success."""
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True

    url = f"{BASE_URL}/{filename}"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            temp_path = filepath + ".tmp"
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            os.rename(temp_path, filepath)
            print(f"  Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f"  Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                time.sleep(2 ** attempt)
    return False


def download_data(num_shards, download_workers=8):
    """Download training shards + pinned validation shard."""
    os.makedirs(DATA_DIR, exist_ok=True)
    num_train = min(num_shards, MAX_SHARD)
    ids = list(range(num_train))
    if VAL_SHARD not in ids:
        ids.append(VAL_SHARD)

    # Count what's already downloaded
    existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
    if existing == len(ids):
        print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
        return

    needed = len(ids) - existing
    print(f"Data: downloading {needed} shards ({existing} already exist)...")

    workers = max(1, min(download_workers, needed))
    with Pool(processes=workers) as pool:
        results = pool.map(download_single_shard, ids)

    ok = sum(1 for r in results if r)
    print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")

# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------

def list_parquet_files():
    """Return sorted list of parquet file paths in the data directory."""
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
    return [os.path.join(DATA_DIR, f) for f in files]


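# The iterator below caps tokenizer training data at ~1 GB of characters and
# truncates each document to 10k characters, presumably so that a handful of very
# long documents cannot dominate the BPE merge statistics.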
def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Yield documents from training split (all shards except pinned val shard)."""
    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
    nchars = 0
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            rg = pf.read_row_group(rg_idx)
            for text in rg.column("text").to_pylist():
                doc = text[:doc_cap] if len(text) > doc_cap else text
                nchars += len(doc)
                yield doc
                if nchars >= max_chars:
                    return


def train_tokenizer():
    """Train BPE tokenizer using rustbpe, save as tiktoken pickle."""
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")

    if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return

    os.makedirs(TOKENIZER_DIR, exist_ok=True)

    parquet_files = list_parquet_files()
    if len(parquet_files) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- Train with rustbpe ---
    print("Tokenizer: training BPE tokenizer...")
    t0 = time.time()

    tokenizer = rustbpe.Tokenizer()
    vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
    tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)

    # Build tiktoken encoding from trained merges
    pattern = tokenizer.get_pattern()
    mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
    tokens_offset = len(mergeable_ranks)
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )
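    # Note on the resulting id layout (assuming rustbpe produces exactly
    # vocab_size_no_special mergeable ranks): BPE tokens occupy ids 0..8187 and the
    # special tokens occupy ids 8188..8191, so BOS_TOKEN maps to id 8188.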

    # Save tokenizer
    with open(tokenizer_pkl, "wb") as f:
        pickle.dump(enc, f)

    t1 = time.time()
    print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")

    # --- Build token_bytes lookup for BPB evaluation ---
    print("Tokenizer: building token_bytes lookup...")
    special_set = set(SPECIAL_TOKENS)
    token_bytes_list = []
    for token_id in range(enc.n_vocab):
        token_str = enc.decode([token_id])
        if token_str in special_set:
            token_bytes_list.append(0)
        else:
            token_bytes_list.append(len(token_str.encode("utf-8")))
    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
    torch.save(token_bytes_tensor, token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")

    # Sanity check
    test = "Hello world! Numbers: 123. Unicode: 你好"
    encoded = enc.encode_ordinary(test)
    decoded = enc.decode(encoded)
    assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")

# ---------------------------------------------------------------------------
# Runtime utilities (imported by train.py)
# ---------------------------------------------------------------------------

class Tokenizer:
    """Minimal tokenizer wrapper. Training is handled above."""

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        return cls(enc)

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for row in ids:
                    row.insert(0, prepend_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)
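
# Illustrative usage from a training script (a sketch; the real train.py may differ):
#
#     tokenizer = Tokenizer.from_directory()
#     bos = tokenizer.get_bos_token_id()
#     ids = tokenizer.encode("hello world", prepend=bos)  # [bos_id, ...]
#     text = tokenizer.decode(ids[1:])                    # lossless round trip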


def get_token_bytes(device="cpu"):
    path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(path, "rb") as f:
        return torch.load(f, map_location=device)


def _document_batches(split, tokenizer_batch_size=128):
    """Infinite iterator over document batches from parquet files."""
    parquet_paths = list_parquet_files()
    assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split == "train":
        parquet_paths = [p for p in parquet_paths if p != val_path]
    else:
        parquet_paths = [val_path]
    epoch = 1
    while True:
        for filepath in parquet_paths:
            pf = pq.ParquetFile(filepath)
            for rg_idx in range(pf.num_row_groups):
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], epoch
        epoch += 1


def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing.
    Every row starts with BOS. Documents packed using best-fit to minimize cropping.
    When no document fits remaining space, crops shortest doc to fill exactly.
    100% utilization (no padding).
    """
    assert split in ["train", "val"]
    row_capacity = T + 1
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1

    def refill_buffer():
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)

    # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits; crop shortest to fill remaining
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=True)
        yield inputs, targets, epoch
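
# Worked example of the packing above (illustrative): with T = 8 (row capacity 9)
# and buffered documents of token lengths [4, 7, 3, 12], a row first takes the
# 7-token doc (the largest that fits), leaving 2 slots; nothing fits, so the
# shortest doc (3 tokens) is cropped to 2 tokens and the row is exactly full.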

# ---------------------------------------------------------------------------
# Evaluation (DO NOT CHANGE: this is the fixed metric)
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): a vocab-size-independent evaluation metric.
    Sums per-token cross-entropy (in nats), sums target byte lengths,
    then converts nats/byte to bits/byte. Special tokens (byte length 0)
    are excluded from both sums.
    Uses the fixed MAX_SEQ_LEN so results are comparable across configs.
    """
    token_bytes = get_token_bytes(device="cuda")
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
    total_nats = 0.0
    total_bytes = 0
    for _ in range(steps):
        x, y, _ = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        y_flat = y.view(-1)
        nbytes = token_bytes[y_flat]
        mask = nbytes > 0
        total_nats += (loss_flat * mask).sum().item()
        total_bytes += nbytes.sum().item()
    return total_nats / (math.log(2) * total_bytes)
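
# Quick sanity arithmetic (illustrative): a mean loss of 0.9 nats/token on tokens
# that average 4 UTF-8 bytes corresponds to 0.9 / (math.log(2) * 4) ≈ 0.32 BPB.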

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
    parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
    args = parser.parse_args()

    num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards

    print(f"Cache directory: {CACHE_DIR}")
    print()

    # Step 1: Download data
    download_data(num_shards, download_workers=args.download_workers)
    print()

    # Step 2: Train tokenizer
    train_tokenizer()
    print()
    print("Done! Ready to train.")