1 """ 2 Autoresearch pretraining script. Single-GPU, single-file. 3 Cherry-picked and simplified from nanochat. 4 Usage: uv run train.py 5 """ 6 7 import os 8 os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" 9 os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" 10 11 import gc 12 import time 13 from dataclasses import dataclass, asdict 14 15 import torch 16 import torch.nn as nn 17 import torch.nn.functional as F 18 19 from kernels import get_kernel 20 cap = torch.cuda.get_device_capability() 21 # varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs 22 repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" 23 fa3 = get_kernel(repo).flash_attn_interface 24 25 from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb 26 27 # --------------------------------------------------------------------------- 28 # GPT Model 29 # --------------------------------------------------------------------------- 30 31 @dataclass 32 class GPTConfig: 33 sequence_len: int = 2048 34 vocab_size: int = 32768 35 n_layer: int = 12 36 n_head: int = 6 37 n_kv_head: int = 6 38 n_embd: int = 768 39 window_pattern: str = "SSSL" 40 41 42 def norm(x): 43 return F.rms_norm(x, (x.size(-1),)) 44 45 46 def has_ve(layer_idx, n_layer): 47 """Returns True if layer should have Value Embedding (alternating, last always included).""" 48 return layer_idx % 2 == (n_layer - 1) % 2 49 50 51 def apply_rotary_emb(x, cos, sin): 52 assert x.ndim == 4 53 d = x.shape[3] // 2 54 x1, x2 = x[..., :d], x[..., d:] 55 y1 = x1 * cos + x2 * sin 56 y2 = x1 * (-sin) + x2 * cos 57 return torch.cat([y1, y2], 3) 58 59 60 class CausalSelfAttention(nn.Module): 61 def __init__(self, config, layer_idx): 62 super().__init__() 63 self.n_head = config.n_head 64 self.n_kv_head = config.n_kv_head 65 self.n_embd = config.n_embd 66 self.head_dim = self.n_embd // self.n_head 67 assert self.n_embd % self.n_head == 0 68 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 69 self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) 70 self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) 71 self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) 72 self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) 73 self.ve_gate_channels = 32 74 self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None 75 76 def forward(self, x, ve, cos_sin, window_size): 77 B, T, C = x.size() 78 q = self.c_q(x).view(B, T, self.n_head, self.head_dim) 79 k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) 80 v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) 81 82 # Value residual (ResFormer): mix in value embedding with input-dependent gate per head 83 if ve is not None: 84 ve = ve.view(B, T, self.n_kv_head, self.head_dim) 85 gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) 86 v = v + gate.unsqueeze(-1) * ve 87 88 cos, sin = cos_sin 89 q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) 90 q, k = norm(q), norm(k) 91 92 y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size) 93 y = y.contiguous().view(B, T, -1) 94 y = self.c_proj(y) 95 return y 96 97 98 class MLP(nn.Module): 99 def __init__(self, config): 100 super().__init__() 101 self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) 102 self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) 103 104 def 
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size):
        x = x + self.attn(norm(x), ve, cos_sin, window_size)
        x = x + self.mlp(norm(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.window_sizes = self._compute_window_sizes(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
        # Value embeddings
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({
            str(i): nn.Embedding(config.vocab_size, kv_dim)
            for i in range(config.n_layer) if has_ve(i, config.n_layer)
        })
        # Rotary embeddings
        self.rotary_seq_len = config.sequence_len * 10
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)
        # Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.zeros_(block.attn.ve_gate.weight)
        # Rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast embeddings to bf16
        self.transformer.wte.to(dtype=torch.bfloat16)
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin
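
    # Sliding-window pattern: each character of window_pattern maps one layer to a short
    # window (S, half the context) or a long window (L, full context), repeating the pattern
    # across layers; the final layer is always forced to a long window. For example, with
    # n_layer=8 and pattern "SSSL" the per-layer windows come out S,S,S,L,S,S,S,L.
    # FA3 takes window_size=(left, right); right=0 means no attention to future positions.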
    def _compute_window_sizes(self, config):
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        short_window = long_window // 2
        char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def estimate_flops(self):
        """Estimated FLOPs per token (forward + backward)."""
        nparams = sum(p.numel() for p in self.parameters())
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h = self.config.n_head
        q = self.config.n_embd // self.config.n_head
        t = self.config.sequence_len
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0]
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        return 6 * (nparams - nparams_exclude) + attn_flops

    def num_scaling_params(self):
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        return {
            'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
            'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
        }
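
    # Optimizer split: the 2-D transformer matrices are trained with Muon (grouped by shape so
    # each group can be stacked into one fused update), while the token embedding, value
    # embeddings, lm_head, and per-layer scalars use the fused AdamW path. The AdamW learning
    # rates are scaled by 1/sqrt(n_embd/768), i.e. the defaults were tuned at n_embd=768.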
    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                        weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        model_dim = self.config.n_embd
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
            len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
        # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
        param_groups = [
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
        ]
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
            ))
        optimizer = MuonAdamW(param_groups)
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

    def forward(self, idx, targets=None, reduction='mean'):
        B, T = idx.size()
        assert T <= self.cos.size(1)
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i])
        x = norm(x)

        softcap = 15
        logits = self.lm_head(x)
        logits = logits.float()
        logits = softcap * torch.tanh(logits / softcap)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=-1, reduction=reduction)
            return loss
        return logits
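
# Minimal usage sketch of the GPT class (illustrative comments only, not executed by this
# script; it mirrors the meta-device construction and autocast used further down, and it
# requires a CUDA GPU with the FA3 kernel available):
#
#   cfg = GPTConfig(sequence_len=256, vocab_size=32768, n_layer=4, n_head=2, n_kv_head=2,
#                   n_embd=256, window_pattern="SL")
#   with torch.device("meta"):
#       m = GPT(cfg)
#   m.to_empty(device="cuda")
#   m.init_weights()
#   idx = torch.randint(0, cfg.vocab_size, (2, 256), device="cuda")
#   with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
#       loss = m(idx, targets=idx)   # scalar cross-entropy loss
#       logits = m(idx)              # (2, 256, vocab_size) soft-capped logits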
# ---------------------------------------------------------------------------
# Optimizer (MuonAdamW, single GPU only)
# ---------------------------------------------------------------------------

polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]

@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    p.mul_(1 - lr_t * wd_t)
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)

@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    # Nesterov momentum
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)


class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others."""

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        params = group['params']
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
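
# Note on the Muon path above: each same-shape group of matrices is stacked and updated in a
# single compiled call per step (Nesterov momentum -> polar-express orthogonalization of the
# update -> NorMuon-style per-row/column variance normalization -> cautious weight decay that
# only decays entries where the update and the weight have the same sign).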
# ---------------------------------------------------------------------------
# Hyperparameters (edit these directly, no CLI flags needed)
# ---------------------------------------------------------------------------

# Model architecture
ASPECT_RATIO = 64        # model_dim = depth * ASPECT_RATIO
HEAD_DIM = 128           # target head dimension for attention
WINDOW_PATTERN = "SSSL"  # sliding window pattern: L=full, S=half context

# Optimization
TOTAL_BATCH_SIZE = 2**19  # ~524K tokens per optimizer step
EMBEDDING_LR = 0.6        # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004    # learning rate for lm_head (Adam)
MATRIX_LR = 0.04          # learning rate for matrix parameters (Muon)
SCALAR_LR = 0.5           # learning rate for per-layer scalars (Adam)
WEIGHT_DECAY = 0.2        # cautious weight decay for Muon
ADAM_BETAS = (0.8, 0.95)  # Adam beta1, beta2
WARMUP_RATIO = 0.0        # fraction of time budget for LR warmup
WARMDOWN_RATIO = 0.5      # fraction of time budget for LR warmdown
FINAL_LR_FRAC = 0.0       # final LR as fraction of initial

# Model size
DEPTH = 8                 # number of transformer layers
DEVICE_BATCH_SIZE = 128   # per-device batch size (reduce if OOM)

# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader
# ---------------------------------------------------------------------------

t_start = time.time()
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
H100_BF16_PEAK_FLOPS = 989.5e12

tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size:,}")

def build_model_config(depth):
    base_dim = depth * ASPECT_RATIO
    model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM
    num_heads = model_dim // HEAD_DIM
    return GPTConfig(
        sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
        n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
        window_pattern=WINDOW_PATTERN,
    )

config = build_model_config(DEPTH)
print(f"Model config: {asdict(config)}")

with torch.device("meta"):
    model = GPT(config)
model.to_empty(device=device)
model.init_weights()

param_counts = model.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
    print(f" {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd

optimizer = model.setup_optimizer(
    unembedding_lr=UNEMBEDDING_LR,
    embedding_lr=EMBEDDING_LR,
    scalar_lr=SCALAR_LR,
    adam_betas=ADAM_BETAS,
    matrix_lr=MATRIX_LR,
    weight_decay=WEIGHT_DECAY,
)

model = torch.compile(model, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader)  # prefetch first batch

print(f"Time budget: {TIME_BUDGET}s")
print(f"Gradient accumulation steps: {grad_accum_steps}")

# Schedules (all based on progress = training_time / TIME_BUDGET)
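
# With the defaults above (WARMUP_RATIO=0.0, WARMDOWN_RATIO=0.5, FINAL_LR_FRAC=0.0) the LR
# multiplier is 1.0 for the first half of the time budget, then decays linearly to 0 over the
# second half. Muon momentum ramps from 0.85 to 0.95 over the first 300 steps, and Muon weight
# decay decays linearly from WEIGHT_DECAY to 0 over the run.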
def get_lr_multiplier(progress):
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    elif progress < 1.0 - WARMDOWN_RATIO:
        return 1.0
    else:
        cooldown = (1.0 - progress) / WARMDOWN_RATIO
        return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC

def get_muon_momentum(step):
    frac = min(step / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

def get_weight_decay(progress):
    return WEIGHT_DECAY * (1 - progress)

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------

t_start_training = time.time()
smooth_train_loss = 0
total_training_time = 0
step = 0

while True:
    torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y, epoch = next(train_loader)

    # Progress and schedules
    progress = min(total_training_time / TIME_BUDGET, 1.0)
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(progress)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)

    train_loss_f = train_loss.item()

    # Fast fail: abort if loss is exploding
    if train_loss_f > 100:
        print("FAIL")
        exit(1)

    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    if step > 10:
        total_training_time += dt

    # Logging
    ema_beta = 0.9
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
    mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
    remaining = max(0, TIME_BUDGET - total_training_time)

    print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)

    # GC management (Python's GC causes ~500ms stalls)
    if step == 0:
        gc.collect()
        gc.freeze()
        gc.disable()
    elif (step + 1) % 5000 == 0:
        gc.collect()

    step += 1

    # Time's up — but only stop after warmup steps so we don't count compilation
    if step > 10 and total_training_time >= TIME_BUDGET:
        break

print()  # newline after \r training log

total_tokens = step * TOTAL_BATCH_SIZE

# Final eval
model.eval()
with autocast_ctx:
    val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
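
# Steady-state MFU below excludes the first ~10 steps (which are also excluded from the time
# budget) so torch.compile warmup does not skew throughput; the reference peak is H100 dense
# bf16 (989.5 TFLOP/s).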
print(f"peak_vram_mb: {peak_vram_mb:.1f}") 625 print(f"mfu_percent: {steady_state_mfu:.2f}") 626 print(f"total_tokens_M: {total_tokens / 1e6:.1f}") 627 print(f"num_steps: {step}") 628 print(f"num_params_M: {num_params / 1e6:.1f}") 629 print(f"depth: {DEPTH}")