/ prepare.py
prepare.py
1 """ 2 One-time data preparation for autoresearch experiments. 3 Downloads data shards and trains a BPE tokenizer. 4 5 Usage: 6 python prepare.py # full prep (download + tokenizer) 7 python prepare.py --num-shards 8 # download only 8 shards (for testing) 8 9 Data and tokenizer are stored in ~/.cache/autoresearch/. 10 """ 11 12 import os 13 import sys 14 import time 15 import math 16 import argparse 17 import pickle 18 from multiprocessing import Pool 19 20 import requests 21 import pyarrow.parquet as pq 22 import rustbpe 23 import tiktoken 24 import torch 25 26 # --------------------------------------------------------------------------- 27 # Constants (fixed, do not modify) 28 # --------------------------------------------------------------------------- 29 30 MAX_SEQ_LEN = 2048 # context length 31 TIME_BUDGET = 300 # training time budget in seconds (5 minutes) 32 EVAL_TOKENS = 40 * 524288 # number of tokens for val eval 33 34 # --------------------------------------------------------------------------- 35 # Configuration 36 # --------------------------------------------------------------------------- 37 38 CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch") 39 DATA_DIR = os.path.join(CACHE_DIR, "data") 40 TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer") 41 BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main" 42 MAX_SHARD = 6542 # the last datashard is shard_06542.parquet 43 VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542) 44 VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet" 45 VOCAB_SIZE = 8192 46 47 # BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3}) 48 SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" 49 50 SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)] 51 BOS_TOKEN = "<|reserved_0|>" 52 53 # --------------------------------------------------------------------------- 54 # Data download 55 # --------------------------------------------------------------------------- 56 57 def download_single_shard(index): 58 """Download one parquet shard with retries. Returns True on success.""" 59 filename = f"shard_{index:05d}.parquet" 60 filepath = os.path.join(DATA_DIR, filename) 61 if os.path.exists(filepath): 62 return True 63 64 url = f"{BASE_URL}/{filename}" 65 max_attempts = 5 66 for attempt in range(1, max_attempts + 1): 67 try: 68 response = requests.get(url, stream=True, timeout=30) 69 response.raise_for_status() 70 temp_path = filepath + ".tmp" 71 with open(temp_path, "wb") as f: 72 for chunk in response.iter_content(chunk_size=1024 * 1024): 73 if chunk: 74 f.write(chunk) 75 os.rename(temp_path, filepath) 76 print(f" Downloaded {filename}") 77 return True 78 except (requests.RequestException, IOError) as e: 79 print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}") 80 for path in [filepath + ".tmp", filepath]: 81 if os.path.exists(path): 82 try: 83 os.remove(path) 84 except OSError: 85 pass 86 if attempt < max_attempts: 87 time.sleep(2 ** attempt) 88 return False 89 90 91 def download_data(num_shards, download_workers=8): 92 """Download training shards + pinned validation shard.""" 93 os.makedirs(DATA_DIR, exist_ok=True) 94 num_train = min(num_shards, MAX_SHARD) 95 ids = list(range(num_train)) 96 if VAL_SHARD not in ids: 97 ids.append(VAL_SHARD) 98 99 # Count what's already downloaded 100 existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet"))) 101 if existing == len(ids): 102 print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}") 103 return 104 105 needed = len(ids) - existing 106 print(f"Data: downloading {needed} shards ({existing} already exist)...") 107 108 workers = max(1, min(download_workers, needed)) 109 with Pool(processes=workers) as pool: 110 results = pool.map(download_single_shard, ids) 111 112 ok = sum(1 for r in results if r) 113 print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}") 114 115 # --------------------------------------------------------------------------- 116 # Tokenizer training 117 # --------------------------------------------------------------------------- 118 119 def list_parquet_files(): 120 """Return sorted list of parquet file paths in the data directory.""" 121 files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp")) 122 return [os.path.join(DATA_DIR, f) for f in files] 123 124 125 def text_iterator(max_chars=1_000_000_000, doc_cap=10_000): 126 """Yield documents from training split (all shards except pinned val shard).""" 127 parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)] 128 nchars = 0 129 for filepath in parquet_paths: 130 pf = pq.ParquetFile(filepath) 131 for rg_idx in range(pf.num_row_groups): 132 rg = pf.read_row_group(rg_idx) 133 for text in rg.column("text").to_pylist(): 134 doc = text[:doc_cap] if len(text) > doc_cap else text 135 nchars += len(doc) 136 yield doc 137 if nchars >= max_chars: 138 return 139 140 141 def train_tokenizer(): 142 """Train BPE tokenizer using rustbpe, save as tiktoken pickle.""" 143 tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl") 144 token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt") 145 146 if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path): 147 print(f"Tokenizer: already trained at {TOKENIZER_DIR}") 148 return 149 150 os.makedirs(TOKENIZER_DIR, exist_ok=True) 151 152 parquet_files = list_parquet_files() 153 if len(parquet_files) < 2: 154 print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.") 155 sys.exit(1) 156 157 # --- Train with rustbpe --- 158 print("Tokenizer: training BPE tokenizer...") 159 t0 = time.time() 160 161 tokenizer = rustbpe.Tokenizer() 162 vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS) 163 tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN) 164 165 # Build tiktoken encoding from trained merges 166 pattern = tokenizer.get_pattern() 167 mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()} 168 tokens_offset = len(mergeable_ranks) 169 special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)} 170 enc = tiktoken.Encoding( 171 name="rustbpe", 172 pat_str=pattern, 173 mergeable_ranks=mergeable_ranks, 174 special_tokens=special_tokens, 175 ) 176 177 # Save tokenizer 178 with open(tokenizer_pkl, "wb") as f: 179 pickle.dump(enc, f) 180 181 t1 = time.time() 182 print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}") 183 184 # --- Build token_bytes lookup for BPB evaluation --- 185 print("Tokenizer: building token_bytes lookup...") 186 special_set = set(SPECIAL_TOKENS) 187 token_bytes_list = [] 188 for token_id in range(enc.n_vocab): 189 token_str = enc.decode([token_id]) 190 if token_str in special_set: 191 token_bytes_list.append(0) 192 else: 193 token_bytes_list.append(len(token_str.encode("utf-8"))) 194 token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32) 195 torch.save(token_bytes_tensor, token_bytes_path) 196 print(f"Tokenizer: saved token_bytes to {token_bytes_path}") 197 198 # Sanity check 199 test = "Hello world! Numbers: 123. Unicode: 你好" 200 encoded = enc.encode_ordinary(test) 201 decoded = enc.decode(encoded) 202 assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}" 203 print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})") 204 205 # --------------------------------------------------------------------------- 206 # Runtime utilities (imported by train.py) 207 # --------------------------------------------------------------------------- 208 209 class Tokenizer: 210 """Minimal tokenizer wrapper. Training is handled above.""" 211 212 def __init__(self, enc): 213 self.enc = enc 214 self.bos_token_id = enc.encode_single_token(BOS_TOKEN) 215 216 @classmethod 217 def from_directory(cls, tokenizer_dir=TOKENIZER_DIR): 218 with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f: 219 enc = pickle.load(f) 220 return cls(enc) 221 222 def get_vocab_size(self): 223 return self.enc.n_vocab 224 225 def get_bos_token_id(self): 226 return self.bos_token_id 227 228 def encode(self, text, prepend=None, num_threads=8): 229 if prepend is not None: 230 prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend) 231 if isinstance(text, str): 232 ids = self.enc.encode_ordinary(text) 233 if prepend is not None: 234 ids.insert(0, prepend_id) 235 elif isinstance(text, list): 236 ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) 237 if prepend is not None: 238 for row in ids: 239 row.insert(0, prepend_id) 240 else: 241 raise ValueError(f"Invalid input type: {type(text)}") 242 return ids 243 244 def decode(self, ids): 245 return self.enc.decode(ids) 246 247 248 def get_token_bytes(device="cpu"): 249 path = os.path.join(TOKENIZER_DIR, "token_bytes.pt") 250 with open(path, "rb") as f: 251 return torch.load(f, map_location=device) 252 253 254 def _document_batches(split, tokenizer_batch_size=128): 255 """Infinite iterator over document batches from parquet files.""" 256 parquet_paths = list_parquet_files() 257 assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first." 258 val_path = os.path.join(DATA_DIR, VAL_FILENAME) 259 if split == "train": 260 parquet_paths = [p for p in parquet_paths if p != val_path] 261 else: 262 parquet_paths = [val_path] 263 epoch = 1 264 while True: 265 for filepath in parquet_paths: 266 pf = pq.ParquetFile(filepath) 267 for rg_idx in range(pf.num_row_groups): 268 rg = pf.read_row_group(rg_idx) 269 batch = rg.column('text').to_pylist() 270 for i in range(0, len(batch), tokenizer_batch_size): 271 yield batch[i:i+tokenizer_batch_size], epoch 272 epoch += 1 273 274 275 def make_dataloader(tokenizer, B, T, split, buffer_size=1000): 276 """ 277 BOS-aligned dataloader with best-fit packing. 278 Every row starts with BOS. Documents packed using best-fit to minimize cropping. 279 When no document fits remaining space, crops shortest doc to fill exactly. 280 100% utilization (no padding). 281 """ 282 assert split in ["train", "val"] 283 row_capacity = T + 1 284 batches = _document_batches(split) 285 bos_token = tokenizer.get_bos_token_id() 286 doc_buffer = [] 287 epoch = 1 288 289 def refill_buffer(): 290 nonlocal epoch 291 doc_batch, epoch = next(batches) 292 token_lists = tokenizer.encode(doc_batch, prepend=bos_token) 293 doc_buffer.extend(token_lists) 294 295 # Pre-allocate buffers: [inputs (B*T) | targets (B*T)] 296 row_buffer = torch.empty((B, row_capacity), dtype=torch.long) 297 cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True) 298 gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda") 299 cpu_inputs = cpu_buffer[:B * T].view(B, T) 300 cpu_targets = cpu_buffer[B * T:].view(B, T) 301 inputs = gpu_buffer[:B * T].view(B, T) 302 targets = gpu_buffer[B * T:].view(B, T) 303 304 while True: 305 for row_idx in range(B): 306 pos = 0 307 while pos < row_capacity: 308 while len(doc_buffer) < buffer_size: 309 refill_buffer() 310 311 remaining = row_capacity - pos 312 313 # Find largest doc that fits entirely 314 best_idx = -1 315 best_len = 0 316 for i, doc in enumerate(doc_buffer): 317 doc_len = len(doc) 318 if doc_len <= remaining and doc_len > best_len: 319 best_idx = i 320 best_len = doc_len 321 322 if best_idx >= 0: 323 doc = doc_buffer.pop(best_idx) 324 row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long) 325 pos += len(doc) 326 else: 327 # No doc fits — crop shortest to fill remaining 328 shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) 329 doc = doc_buffer.pop(shortest_idx) 330 row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long) 331 pos += remaining 332 333 cpu_inputs.copy_(row_buffer[:, :-1]) 334 cpu_targets.copy_(row_buffer[:, 1:]) 335 gpu_buffer.copy_(cpu_buffer, non_blocking=True) 336 yield inputs, targets, epoch 337 338 # --------------------------------------------------------------------------- 339 # Evaluation (DO NOT CHANGE — this is the fixed metric) 340 # --------------------------------------------------------------------------- 341 342 @torch.no_grad() 343 def evaluate_bpb(model, tokenizer, batch_size): 344 """ 345 Bits per byte (BPB): vocab size-independent evaluation metric. 346 Sums per-token cross-entropy (in nats), sums target byte lengths, 347 then converts nats/byte to bits/byte. Special tokens (byte length 0) 348 are excluded from both sums. 349 Uses fixed MAX_SEQ_LEN so results are comparable across configs. 350 """ 351 token_bytes = get_token_bytes(device="cuda") 352 val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val") 353 steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN) 354 total_nats = 0.0 355 total_bytes = 0 356 for _ in range(steps): 357 x, y, _ = next(val_loader) 358 loss_flat = model(x, y, reduction='none').view(-1) 359 y_flat = y.view(-1) 360 nbytes = token_bytes[y_flat] 361 mask = nbytes > 0 362 total_nats += (loss_flat * mask).sum().item() 363 total_bytes += nbytes.sum().item() 364 return total_nats / (math.log(2) * total_bytes) 365 366 # --------------------------------------------------------------------------- 367 # Main 368 # --------------------------------------------------------------------------- 369 370 if __name__ == "__main__": 371 parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch") 372 parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.") 373 parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers") 374 args = parser.parse_args() 375 376 num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards 377 378 print(f"Cache directory: {CACHE_DIR}") 379 print() 380 381 # Step 1: Download data 382 download_data(num_shards, download_workers=args.download_workers) 383 print() 384 385 # Step 2: Train tokenizer 386 train_tokenizer() 387 print() 388 print("Done! Ready to train.")