1 """ 2 This training script can be run both on a single gpu in debug mode, 3 and also in a larger training run with distributed data parallel (ddp). 4 5 To run on a single GPU, example: 6 $ python train.py --batch_size=32 --compile=False 7 8 To run with DDP on 4 gpus on 1 node, example: 9 $ torchrun --standalone --nproc_per_node=4 train.py 10 11 To run with DDP on 4 gpus across 2 nodes, example: 12 - Run on the first (master) node with example IP 123.456.123.456: 13 $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 - Run on the worker node: 15 $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 """ 18 19 import os 20 import time 21 import math 22 import pickle 23 from contextlib import nullcontext 24 25 import numpy as np 26 import torch 27 from torch.nn.parallel import DistributedDataParallel as DDP 28 from torch.distributed import init_process_group, destroy_process_group 29 30 from model import GPTConfig, GPT 31 32 # ----------------------------------------------------------------------------- 33 # default config values designed to train a gpt2 (124M) on OpenWebText 34 # I/O 35 out_dir = 'out' 36 eval_interval = 2000 37 log_interval = 1 38 eval_iters = 200 39 eval_only = False # if True, script exits right after the first eval 40 always_save_checkpoint = True # if True, always save a checkpoint after each eval 41 init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 42 # wandb logging 43 wandb_log = False # disabled by default 44 wandb_project = 'owt' 45 wandb_run_name = 'gpt2' # 'run' + str(time.time()) 46 # data 47 dataset = 'openwebtext' 48 gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 49 batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 50 block_size = 1024 51 # model 52 n_layer = 12 53 n_head = 12 54 n_embd = 768 55 dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 56 bias = False # do we use bias inside LayerNorm and Linear layers? 57 # adamw optimizer 58 learning_rate = 6e-4 # max learning rate 59 max_iters = 600000 # total number of training iterations 60 weight_decay = 1e-1 61 beta1 = 0.9 62 beta2 = 0.95 63 grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 64 # learning rate decay settings 65 decay_lr = True # whether to decay the learning rate 66 warmup_iters = 2000 # how many steps to warm up for 67 lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 68 min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 69 # DDP settings 70 backend = 'nccl' # 'nccl', 'gloo', etc. 
# model
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
learning_rate = 6e-4 # max learning rate
max_iters = 600000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------

# various inits, derived attributes, I/O setup
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank # each process gets a different seed
    # world_size number of processes will be training simultaneously, so we can scale
    # down the desired gradient accumulation iterations per process proportionally
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join('data', dataset)
def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
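# x and y come back as (batch_size, block_size) int64 tensors drawn from the same token
# stream; y is x shifted left by one position, so y[:, t] holds the next-token target
# for the prefix x[:, :t+1].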

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
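# float16 has a narrow exponent range, so small gradient values can underflow to zero;
# the GradScaler above multiplies the loss before backward() and the gradients are
# un-scaled again (scaler.unscale_ / scaler.step in the training loop below) before the
# optimizer update. With bfloat16 or float32 the scaler is created with enabled=False
# and all of its calls simply pass through.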
# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# wrap model into DDP container
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)
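# with the defaults above: get_lr(0) is roughly 3e-7, warmup reaches the full 6e-4 at
# iteration 2000, the cosine then decays the rate to min_lr = 6e-5 at iteration 600000,
# and it stays at 6e-5 for any iterations beyond lr_decay_iters.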
# logging
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()
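# example configurator.py overrides beyond the ones shown in the module docstring
# (assuming it accepts --key=value for the config globals above; values are illustrative):
#   $ python train.py --init_from=resume --out_dir=out          # resume from out/ckpt.pt
#   $ python train.py --init_from=resume --eval_only=True       # run a single eval and exit
#   $ python train.py --init_from=gpt2 --dataset=openwebtext    # start from OpenAI GPT-2 weights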