"""
One-time data preparation for autoresearch experiments.
Downloads data shards and trains a BPE tokenizer.

Usage:
    python prepare.py                  # full prep (download + tokenizer)
    python prepare.py --num-shards 8   # download only 8 shards (for testing)

Data and tokenizer are stored in ~/.cache/autoresearch/.
"""

import os
import sys
import time
import math
import argparse
import pickle
from multiprocessing import Pool

import requests
import pyarrow.parquet as pq
import rustbpe
import tiktoken
import torch

# ---------------------------------------------------------------------------
# Constants (fixed, do not modify)
# ---------------------------------------------------------------------------

MAX_SEQ_LEN = 2048  # context length
TIME_BUDGET = 300  # training time budget in seconds (5 minutes)
EVAL_TOKENS = 40 * 524288  # number of tokens for val eval

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
DATA_DIR = os.path.join(CACHE_DIR, "data")
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542  # the last datashard is shard_06542.parquet
VAL_SHARD = MAX_SHARD  # pinned validation shard (shard_06542)
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
VOCAB_SIZE = 8192

# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
BOS_TOKEN = "<|reserved_0|>"

# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------

def download_single_shard(index):
    """Download one parquet shard with retries. Returns True on success.

    Downloads to a .tmp file first and atomically renames on completion,
    so a partially-written shard is never mistaken for a finished one.
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True

    # BUG FIX: the URL and log messages previously interpolated a literal
    # "(unknown)" placeholder instead of the shard filename.
    url = f"{BASE_URL}/{filename}"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            temp_path = filepath + ".tmp"
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            os.rename(temp_path, filepath)
            print(f"  Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f"  Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial output before retrying
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                time.sleep(2 ** attempt)  # exponential backoff
    return False


def download_data(num_shards, download_workers=8):
    """Download training shards + pinned validation shard.

    Args:
        num_shards: number of training shards to fetch (capped at MAX_SHARD).
        download_workers: parallel download processes.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    num_train = min(num_shards, MAX_SHARD)
    ids = list(range(num_train))
    if VAL_SHARD not in ids:
        ids.append(VAL_SHARD)

    # Count what's already downloaded
    existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
    if existing == len(ids):
        print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
        return

    needed = len(ids) - existing
    print(f"Data: downloading {needed} shards ({existing} already exist)...")

    workers = max(1, min(download_workers, needed))
    with Pool(processes=workers) as pool:
        results = pool.map(download_single_shard, ids)

    ok = sum(1 for r in results if r)
    print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")

# ---------------------------------------------------------------------------
# Tokenizer training
# ---------------------------------------------------------------------------

def list_parquet_files():
    """Return sorted list of parquet file paths in the data directory."""
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
    return [os.path.join(DATA_DIR, f) for f in files]


def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Yield documents from training split (all shards except pinned val shard).

    Each document is truncated to doc_cap characters; iteration stops once
    max_chars total characters have been yielded.
    """
    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
    nchars = 0
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            rg = pf.read_row_group(rg_idx)
            for text in rg.column("text").to_pylist():
                doc = text[:doc_cap] if len(text) > doc_cap else text
                nchars += len(doc)
                yield doc
                if nchars >= max_chars:
                    return


def train_tokenizer():
    """Train BPE tokenizer using rustbpe, save as tiktoken pickle.

    Also writes a token_bytes lookup tensor (UTF-8 byte length per token id,
    0 for special tokens) used by the BPB evaluation metric.
    """
    tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")

    if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return

    os.makedirs(TOKENIZER_DIR, exist_ok=True)

    parquet_files = list_parquet_files()
    if len(parquet_files) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- Train with rustbpe ---
    print("Tokenizer: training BPE tokenizer...")
    t0 = time.time()

    tokenizer = rustbpe.Tokenizer()
    # Reserve room at the top of the vocab for the special tokens
    vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
    tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)

    # Build tiktoken encoding from trained merges
    pattern = tokenizer.get_pattern()
    mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
    tokens_offset = len(mergeable_ranks)
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )

    # Save tokenizer
    with open(tokenizer_pkl, "wb") as f:
        pickle.dump(enc, f)

    t1 = time.time()
    print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")

    # --- Build token_bytes lookup for BPB evaluation ---
    # BUG FIX: previously this measured len(enc.decode([id]).encode("utf-8")).
    # decode() replaces byte-level tokens that are not valid standalone UTF-8
    # with U+FFFD (3 bytes), inflating the byte counts and skewing BPB.
    # decode_single_token_bytes() returns the token's raw bytes exactly.
    # Special tokens are identified by id (they occupy ids >= tokens_offset).
    print("Tokenizer: building token_bytes lookup...")
    special_ids = set(special_tokens.values())
    token_bytes_list = []
    for token_id in range(enc.n_vocab):
        if token_id in special_ids:
            token_bytes_list.append(0)  # special tokens carry no text bytes
        else:
            token_bytes_list.append(len(enc.decode_single_token_bytes(token_id)))
    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
    torch.save(token_bytes_tensor, token_bytes_path)
    print(f"Tokenizer: saved token_bytes to {token_bytes_path}")

    # Sanity check: encode/decode roundtrip must be lossless
    test = "Hello world! Numbers: 123. Unicode: 你好"
    encoded = enc.encode_ordinary(test)
    decoded = enc.decode(encoded)
    assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")

# ---------------------------------------------------------------------------
# Runtime utilities (imported by train.py)
# ---------------------------------------------------------------------------

class Tokenizer:
    """Minimal tokenizer wrapper. Training is handled above."""

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        """Load a trained tokenizer pickle from tokenizer_dir."""
        with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        return cls(enc)

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        """Encode a string (or list of strings) to token ids.

        Args:
            text: str or list[str].
            prepend: optional token (id or special-token string) inserted at
                the front of every encoded sequence.
            num_threads: threads for batch encoding (list input only).
        """
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for row in ids:
                    row.insert(0, prepend_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)


def get_token_bytes(device="cpu"):
    """Load the token_bytes lookup tensor saved by train_tokenizer()."""
    path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(path, "rb") as f:
        return torch.load(f, map_location=device)


def _document_batches(split, tokenizer_batch_size=128):
    """Infinite iterator over document batches from parquet files.

    Yields (list_of_texts, epoch) tuples; epoch increments each time the
    shard list wraps around.
    """
    parquet_paths = list_parquet_files()
    assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split == "train":
        parquet_paths = [p for p in parquet_paths if p != val_path]
        assert len(parquet_paths) > 0, "No training shards found."
    else:
        parquet_paths = [val_path]
    epoch = 1
    while True:
        for filepath in parquet_paths:
            pf = pq.ParquetFile(filepath)
            for rg_idx in range(pf.num_row_groups):
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], epoch
        epoch += 1


def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
    """
    BOS-aligned dataloader with best-fit packing.
    Every row starts with BOS. Documents packed using best-fit to minimize cropping.
    When no document fits remaining space, crops shortest doc to fill exactly.
    100% utilization (no padding).
    """
    assert split in ["train", "val"]
    row_capacity = T + 1  # +1 so inputs/targets can be offset by one position
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1

    def refill_buffer():
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)

    # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No doc fits — crop shortest to fill remaining
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=True)
        yield inputs, targets, epoch

# ---------------------------------------------------------------------------
# Evaluation (DO NOT CHANGE — this is the fixed metric)
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB): vocab size-independent evaluation metric.
    Sums per-token cross-entropy (in nats), sums target byte lengths,
    then converts nats/byte to bits/byte. Special tokens (byte length 0)
    are excluded from both sums.
    Uses fixed MAX_SEQ_LEN so results are comparable across configs.
    """
    token_bytes = get_token_bytes(device="cuda")
    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
    total_nats = 0.0
    total_bytes = 0
    for _ in range(steps):
        x, y, _ = next(val_loader)
        loss_flat = model(x, y, reduction='none').view(-1)
        y_flat = y.view(-1)
        nbytes = token_bytes[y_flat]
        mask = nbytes > 0
        total_nats += (loss_flat * mask).sum().item()
        total_bytes += nbytes.sum().item()
    return total_nats / (math.log(2) * total_bytes)

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
    parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
    args = parser.parse_args()

    num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards

    print(f"Cache directory: {CACHE_DIR}")
    print()

    # Step 1: Download data
    download_data(num_shards, download_workers=args.download_workers)
    print()

    # Step 2: Train tokenizer
    train_tokenizer()
    print()
    print("Done! Ready to train.")