train.py
"""
Autoresearch pretraining script. Single-GPU, single-file.
Cherry-picked and simplified from nanochat.
Usage: uv run train.py
"""

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import gc
import math
import time
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import torch.nn.functional as F

from kernels import get_kernel
cap = torch.cuda.get_device_capability()
# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
fa3 = get_kernel(repo).flash_attn_interface

from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb

# ---------------------------------------------------------------------------
# GPT Model
# ---------------------------------------------------------------------------

@dataclass
class GPTConfig:
    sequence_len: int = 2048
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6
    n_kv_head: int = 6
    n_embd: int = 768
    window_pattern: str = "SSSL"


def norm(x):
    return F.rms_norm(x, (x.size(-1),))


def has_ve(layer_idx, n_layer):
    """Returns True if layer should have Value Embedding (alternating, last always included)."""
    return layer_idx % 2 == (n_layer - 1) % 2


def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)


class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ve_gate_channels = 32
        self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

    def forward(self, x, ve, cos_sin, window_size):
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
            v = v + gate.unsqueeze(-1) * ve

        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)

        y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # squared ReLU activation (relu(x)**2)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size):
        x = x + self.attn(norm(x), ve, cos_sin, window_size)
        x = x + self.mlp(norm(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.window_sizes = self._compute_window_sizes(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
        # Value embeddings
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({
            str(i): nn.Embedding(config.vocab_size, kv_dim)
            for i in range(config.n_layer) if has_ve(i, config.n_layer)
        })
        # Rotary embeddings
        self.rotary_seq_len = config.sequence_len * 10
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)
        # Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.zeros_(block.attn.ve_gate.weight)
        # Rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast embeddings to bf16
        self.transformer.wte.to(dtype=torch.bfloat16)
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin
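
    # Shape note: cos/sin are (1, rotary_seq_len, 1, head_dim // 2), so they broadcast
    # over batch and heads inside apply_rotary_emb; for example, with head_dim=128 and
    # the GPTConfig default sequence_len=2048 (rotary_seq_len=20480) they come out as
    # (1, 20480, 1, 64).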

    def _compute_window_sizes(self, config):
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        short_window = long_window // 2
        char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def estimate_flops(self):
        """Estimated FLOPs per token (forward + backward)."""
        nparams = sum(p.numel() for p in self.parameters())
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h = self.config.n_head
        q = self.config.n_embd // self.config.n_head
        t = self.config.sequence_len
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0]
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        return 6 * (nparams - nparams_exclude) + attn_flops

    def num_scaling_params(self):
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        return {
            'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
            'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
        }

    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                        weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        model_dim = self.config.n_embd
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
                                                len(lm_head_params) + len(value_embeds_params) +
                                                len(resid_params) + len(x0_params))
        # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
        param_groups = [
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
        ]
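        # Matrix params are bucketed by shape so that _step_muon can torch.stack each
        # bucket and run a single batched muon_step_fused update per shape.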
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
            ))
        optimizer = MuonAdamW(param_groups)
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

    def forward(self, idx, targets=None, reduction='mean'):
        B, T = idx.size()
        assert T <= self.cos.size(1)
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i])
        x = norm(x)

        softcap = 15
        logits = self.lm_head(x)
        logits = logits.float()
        logits = softcap * torch.tanh(logits / softcap)  # tanh soft-cap keeps logits in (-softcap, softcap)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                                   ignore_index=-1, reduction=reduction)
            return loss
        return logits

# ---------------------------------------------------------------------------
# Optimizer (MuonAdamW, single GPU only)
# ---------------------------------------------------------------------------

polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]

@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    p.mul_(1 - lr_t * wd_t)
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)

@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    # Nesterov momentum
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
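

# Optional sanity check for the polar express iteration above: a rough sketch, never
# called during training. For a tall matrix the iteration should make the columns
# approximately orthonormal, i.e. X.T @ X close to the identity (some residual error
# remains from bf16 and the truncated iteration). `_check_polar_express` is a name
# made up here for illustration; it is not part of nanochat.
def _check_polar_express(rows=256, cols=128, device="cpu"):
    G = torch.randn(rows, cols, device=device)
    X = (G / (G.norm() * 1.02 + 1e-6)).bfloat16()  # same Frobenius-norm normalization as muon_step_fused
    for a, b, c in polar_express_coeffs:  # tall branch: A = X^T X, X <- a*X + X @ (b*A + c*A^2)
        A = X.mT @ X
        X = a * X + X @ (b * A + c * (A @ A))
    return (X.float().mT @ X.float() - torch.eye(cols, device=device)).abs().max()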


class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others."""

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group):
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        params = group['params']
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
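

# Minimal standalone usage sketch for MuonAdamW (illustrative only, never called here;
# the real param groups for this model are built in GPT.setup_optimizer above).
# `_muon_adamw_usage_example` is a throwaway name for this sketch, not part of nanochat.
def _muon_adamw_usage_example():
    layer = nn.Linear(64, 64, bias=False)  # 2D matrix -> 'muon' group
    scale = nn.Parameter(torch.ones(64))   # non-matrix param -> 'adamw' group
    opt = MuonAdamW([
        dict(kind='muon', params=[layer.weight], lr=0.02,
             momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0),
        dict(kind='adamw', params=[scale], lr=0.004,
             betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
    ])
    loss = (scale * layer(torch.randn(8, 64))).square().mean()
    loss.backward()
    opt.step()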

# ---------------------------------------------------------------------------
# Hyperparameters (edit these directly, no CLI flags needed)
# ---------------------------------------------------------------------------

# Model architecture
ASPECT_RATIO = 64  # model_dim = depth * ASPECT_RATIO
HEAD_DIM = 128  # target head dimension for attention
WINDOW_PATTERN = "SSSL"  # sliding window pattern: L=full, S=half context

# Optimization
TOTAL_BATCH_SIZE = 2**19  # ~524K tokens per optimizer step
EMBEDDING_LR = 0.6  # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004  # learning rate for lm_head (Adam)
MATRIX_LR = 0.04  # learning rate for matrix parameters (Muon)
SCALAR_LR = 0.5  # learning rate for per-layer scalars (Adam)
WEIGHT_DECAY = 0.2  # cautious weight decay for Muon
ADAM_BETAS = (0.8, 0.95)  # Adam beta1, beta2
WARMUP_RATIO = 0.0  # fraction of time budget for LR warmup
WARMDOWN_RATIO = 0.5  # fraction of time budget for LR warmdown
FINAL_LR_FRAC = 0.0  # final LR as fraction of initial

# Model size
DEPTH = 8  # number of transformer layers
DEVICE_BATCH_SIZE = 128  # per-device batch size (reduce if OOM)

# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader
# ---------------------------------------------------------------------------

t_start = time.time()
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
H100_BF16_PEAK_FLOPS = 989.5e12

tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size:,}")

def build_model_config(depth):
    base_dim = depth * ASPECT_RATIO
    model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM
    num_heads = model_dim // HEAD_DIM
    return GPTConfig(
        sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
        n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
        window_pattern=WINDOW_PATTERN,
    )

config = build_model_config(DEPTH)
print(f"Model config: {asdict(config)}")

with torch.device("meta"):
    model = GPT(config)
model.to_empty(device=device)
model.init_weights()

param_counts = model.num_scaling_params()
print("Parameter counts:")
for key, value in param_counts.items():
    print(f"  {key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd

optimizer = model.setup_optimizer(
    unembedding_lr=UNEMBEDDING_LR,
    embedding_lr=EMBEDDING_LR,
    scalar_lr=SCALAR_LR,
    adam_betas=ADAM_BETAS,
    matrix_lr=MATRIX_LR,
    weight_decay=WEIGHT_DECAY,
)

model = torch.compile(model, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
x, y, epoch = next(train_loader)  # prefetch first batch

print(f"Time budget: {TIME_BUDGET}s")
print(f"Gradient accumulation steps: {grad_accum_steps}")

# Schedules (all based on progress = training_time / TIME_BUDGET)

def get_lr_multiplier(progress):
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    elif progress < 1.0 - WARMDOWN_RATIO:
        return 1.0
    else:
        cooldown = (1.0 - progress) / WARMDOWN_RATIO
        return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC

def get_muon_momentum(step):
    frac = min(step / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

def get_weight_decay(progress):
    return WEIGHT_DECAY * (1 - progress)

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------

t_start_training = time.time()
smooth_train_loss = 0
total_training_time = 0
step = 0

while True:
    torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
        x, y, epoch = next(train_loader)

    # Progress and schedules
    progress = min(total_training_time / TIME_BUDGET, 1.0)
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(progress)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)

    train_loss_f = train_loss.item()

    # Fast fail: abort if loss is exploding or NaN
    if math.isnan(train_loss_f) or train_loss_f > 100:
        print("FAIL")
        exit(1)

    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    if step > 10:
        total_training_time += dt

    # Logging
    ema_beta = 0.9
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
    pct_done = 100 * progress
    tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
    mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
    remaining = max(0, TIME_BUDGET - total_training_time)

    print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s  ", end="", flush=True)

    # GC management (Python's GC causes ~500ms stalls)
    if step == 0:
        gc.collect()
        gc.freeze()
        gc.disable()
    elif (step + 1) % 5000 == 0:
        gc.collect()

    step += 1

    # Time's up — but only stop after warmup steps so we don't count compilation
    if step > 10 and total_training_time >= TIME_BUDGET:
        break

print()  # newline after \r training log

total_tokens = step * TOTAL_BATCH_SIZE

# Final eval
model.eval()
with autocast_ctx:
    val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)

# Final summary
t_end = time.time()
startup_time = t_start_training - t_start
steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024

print("---")
print(f"val_bpb: {val_bpb:.6f}")
print(f"training_seconds: {total_training_time:.1f}")
print(f"total_seconds: {t_end - t_start:.1f}")
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
print(f"mfu_percent: {steady_state_mfu:.2f}")
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")