  1  """
  2  One-time data preparation for autoresearch experiments.
  3  Downloads data shards and trains a BPE tokenizer.
  4  
  5  Usage:
  6      python prepare.py                  # full prep (download + tokenizer)
  7      python prepare.py --num-shards 8   # download only 8 shards (for testing)
  8  
  9  Data and tokenizer are stored in ~/.cache/autoresearch/.
 10  """
 11  
 12  import os
 13  import sys
 14  import time
 15  import math
 16  import argparse
 17  import pickle
 18  from multiprocessing import Pool
 19  
 20  import requests
 21  import pyarrow.parquet as pq
 22  import rustbpe
 23  import tiktoken
 24  import torch
 25  
 26  # ---------------------------------------------------------------------------
 27  # Constants (fixed, do not modify)
 28  # ---------------------------------------------------------------------------
 29  
 30  MAX_SEQ_LEN = 2048       # context length
 31  TIME_BUDGET = 300        # training time budget in seconds (5 minutes)
 32  EVAL_TOKENS = 40 * 524288  # number of tokens for val eval
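# For reference: EVAL_TOKENS works out to 20,971,520 (~21M) tokens, since 524288 = 2**19.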

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
DATA_DIR = os.path.join(CACHE_DIR, "data")
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542  # the last data shard is shard_06542.parquet
VAL_SHARD = MAX_SHARD  # pinned validation shard (shard_06542)
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
VOCAB_SIZE = 8192

# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
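# With \p{N}{1,2}, digit runs are pre-split into chunks of at most two digits,
# e.g. "2048" pre-tokenizes as "20" + "48", so no merged token spans 3+ digits.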

SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
BOS_TOKEN = "<|reserved_0|>"
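# Special token ids follow immediately after the merged ranks (see
# train_tokenizer below); assuming the trained base vocab fills all
# VOCAB_SIZE - 4 ranks, they land at ids 8188-8191, with <|reserved_0|> (BOS) at 8188.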

# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------

def download_single_shard(index):
    """Download one parquet shard with retries. Returns True on success."""
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True

    url = f"{BASE_URL}/{filename}"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            temp_path = filepath + ".tmp"
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            os.rename(temp_path, filepath)
            print(f"  Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f"  Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                time.sleep(2 ** attempt)
    return False


def download_data(num_shards, download_workers=8):
    """Download training shards + pinned validation shard."""
    os.makedirs(DATA_DIR, exist_ok=True)
    num_train = min(num_shards, MAX_SHARD)
    ids = list(range(num_train))
    if VAL_SHARD not in ids:
        ids.append(VAL_SHARD)

    # Count what's already downloaded
    existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
    if existing == len(ids):
        print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
        return

    needed = len(ids) - existing
    print(f"Data: downloading {needed} shards ({existing} already exist)...")

    workers = max(1, min(download_workers, needed))
    with Pool(processes=workers) as pool:
        results = pool.map(download_single_shard, ids)

    ok = sum(1 for r in results if r)
    print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")

# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------

def list_parquet_files():
    """Return sorted list of parquet file paths in the data directory."""
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
    return [os.path.join(DATA_DIR, f) for f in files]


def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Yield documents from training split (all shards except pinned val shard)."""
    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
    nchars = 0
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            rg = pf.read_row_group(rg_idx)
            for text in rg.column("text").to_pylist():
                doc = text[:doc_cap] if len(text) > doc_cap else text
                nchars += len(doc)
                yield doc
                if nchars >= max_chars:
                    return


def train_tokenizer():
    """Train BPE tokenizer using rustbpe, save as tiktoken pickle."""
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")

    if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return

    os.makedirs(TOKENIZER_DIR, exist_ok=True)

    parquet_files = list_parquet_files()
    if len(parquet_files) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- Train with rustbpe ---
    print("Tokenizer: training BPE tokenizer...")
    t0 = time.time()

    tokenizer = rustbpe.Tokenizer()
    vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
    tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)

    # Build tiktoken encoding from trained merges
    pattern = tokenizer.get_pattern()
    mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
    tokens_offset = len(mergeable_ranks)
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )

    # Save tokenizer
    with open(tokenizer_pkl, "wb") as f:
        pickle.dump(enc, f)

    t1 = time.time()
    print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")

    # --- Build token_bytes lookup for BPB evaluation ---
    print("Tokenizer: building token_bytes lookup...")
    special_ids = set(special_tokens.values())
    token_bytes_list = []
    for token_id in range(enc.n_vocab):
        if token_id in special_ids:
            # special tokens contribute zero bytes so they are excluded from BPB
            token_bytes_list.append(0)
        else:
            # count raw token bytes; decoding to str would miscount tokens that
            # are not valid UTF-8 on their own (they decode to U+FFFD)
            token_bytes_list.append(len(enc.decode_single_token_bytes(token_id)))
    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
    torch.save(token_bytes_tensor, token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")

    # Sanity check
    test = "Hello world! Numbers: 123. Unicode: 你好"
    encoded = enc.encode_ordinary(test)
    decoded = enc.decode(encoded)
    assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")

# ---------------------------------------------------------------------------
# Runtime utilities (imported by train.py)
# ---------------------------------------------------------------------------

class Tokenizer:
    """Minimal tokenizer wrapper. Training is handled above."""

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        return cls(enc)

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for row in ids:
                    row.insert(0, prepend_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)

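# Example usage (illustrative sketch; assumes prepare.py has already been run
# so tokenizer.pkl exists under TOKENIZER_DIR):
#
#   tok = Tokenizer.from_directory()
#   ids = tok.encode("hello world", prepend=tok.get_bos_token_id())
#   text = tok.decode(ids[1:])   # drop the leading BOS id before decoding
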
def get_token_bytes(device="cpu"):
    path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(path, "rb") as f:
        return torch.load(f, map_location=device)


def _document_batches(split, tokenizer_batch_size=128):
    """Infinite iterator over document batches from parquet files."""
    parquet_paths = list_parquet_files()
    assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split == "train":
        parquet_paths = [p for p in parquet_paths if p != val_path]
        assert len(parquet_paths) > 0, "No training shards found."
    else:
        parquet_paths = [val_path]
    epoch = 1
    while True:
        for filepath in parquet_paths:
            pf = pq.ParquetFile(filepath)
            for rg_idx in range(pf.num_row_groups):
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], epoch
        epoch += 1

def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing.
    Every row starts with BOS. Documents are packed best-fit to minimize cropping.
    When no buffered document fits the remaining space, the shortest one is
    cropped to fill the row exactly, so utilization is 100% (no padding).
    """
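    # Toy example (illustrative, ignoring the 1000-doc refill): with 9 slots left
    # in a row and buffered docs of length [5, 7, 4, 20], best-fit places the
    # 7-token doc, leaving 2 slots; nothing fits, so the shortest remaining doc
    # (4 tokens) is cropped to its first 2 tokens and its tail is discarded.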
    assert split in ["train", "val"]
    row_capacity = T + 1
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1

    def refill_buffer():
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)

    # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits; crop shortest to fill remaining
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=True)
        yield inputs, targets, epoch

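# Illustrative usage (a sketch; B=8 is an arbitrary batch size, not prescribed here):
#
#   tok = Tokenizer.from_directory()
#   train_loader = make_dataloader(tok, B=8, T=MAX_SEQ_LEN, split="train")
#   x, y, epoch = next(train_loader)   # x, y: (8, MAX_SEQ_LEN) long tensors on cuda
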
# ---------------------------------------------------------------------------
# Evaluation (DO NOT CHANGE — this is the fixed metric)
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): vocab size-independent evaluation metric.
    Sums per-token cross-entropy (in nats), sums target byte lengths,
    then converts nats/byte to bits/byte. Special tokens (byte length 0)
    are excluded from both sums.
    Uses fixed MAX_SEQ_LEN so results are comparable across configs.
    """
    token_bytes = get_token_bytes(device="cuda")
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
    total_nats = 0.0
    total_bytes = 0
    for _ in range(steps):
        x, y, _ = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        y_flat = y.view(-1)
        nbytes = token_bytes[y_flat]
        mask = nbytes > 0
        total_nats += (loss_flat * mask).sum().item()
        total_bytes += nbytes.sum().item()
    return total_nats / (math.log(2) * total_bytes)

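# Worked numbers (illustrative): with batch_size=32, steps = 20,971,520 / (32 * 2048) = 320
# eval batches. If the mean loss is 2.77 nats/token and tokens average 4 bytes,
# the metric comes out to roughly 2.77 / (ln(2) * 4) ≈ 1.0 bits per byte.
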
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
    parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
    args = parser.parse_args()

    num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards

    print(f"Cache directory: {CACHE_DIR}")
    print()

    # Step 1: Download data
    download_data(num_shards, download_workers=args.download_workers)
    print()

    # Step 2: Train tokenizer
    train_tokenizer()
    print()
    print("Done! Ready to train.")