  1  """
  2  This training script can be run both on a single gpu in debug mode,
  3  and also in a larger training run with distributed data parallel (ddp).
  4  
  5  To run on a single GPU, example:
  6  $ python train.py --batch_size=32 --compile=False
  7  
  8  To run with DDP on 4 gpus on 1 node, example:
  9  $ torchrun --standalone --nproc_per_node=4 train.py
 10  
 11  To run with DDP on 4 gpus across 2 nodes, example:
 12  - Run on the first (master) node with example IP 123.456.123.456:
 13  $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
 14  - Run on the worker node:
 15  $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
 16  (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
 17  """

import os
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = True # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False # disabled by default
wandb_project = 'owt'
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
dataset = 'openwebtext'
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024
# model
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
learning_rate = 6e-4 # max learning rate
max_iters = 600000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16'; float16 will automatically use a GradScaler
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
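# example override via configurator (assuming a config file such as config/train_gpt2.py exists in this repo):
# $ python train.py config/train_gpt2.py --batch_size=16 --compile=False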

# various inits, derived attributes, I/O setup
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank # each process gets a different seed
    # world_size number of processes will be training simultaneously, so we can scale
    # down the desired gradient accumulation iterations per process proportionally
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")
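# e.g. with the defaults above and no DDP: 40 grad-accum steps * 12 batch_size * 1024 block_size = 491,520 tokens per iteration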

if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

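# for illustration: on a GPU with bfloat16 support, dtype resolves to 'bfloat16' and ctx becomes
# torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16); on CPU, ctx is a no-op nullcontext()
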
# poor man's data loader
data_dir = os.path.join('data', dataset)
def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

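# a minimal usage sketch: x and y are (batch_size, block_size) int64 tensors on `device`,
# with y shifted one token relative to x (next-token targets), e.g.:
#   xb, yb = get_batch('train')  # xb.shape == yb.shape == (batch_size, block_size)
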
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal, otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False, the scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

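# what the scaler does in float16 (a sketch of standard GradScaler behavior, not new logic here):
# the loss is multiplied by a scale factor before backward() to avoid fp16 gradient underflow, and
# gradients are unscaled before the optimizer step; with bfloat16 or float32 (enabled=False) every
# scaler call below is a pass-through
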
# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# wrap model into DDP container
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

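# usage sketch: estimate_loss() returns a dict like {'train': <mean loss>, 'val': <mean loss>},
# each value a scalar tensor averaged over eval_iters random batches from that split
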
# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

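# worked example with the defaults above (learning_rate=6e-4, min_lr=6e-5, warmup_iters=2000, lr_decay_iters=600000):
#   get_lr(0)      ~= 3.0e-7  (start of linear warmup)
#   get_lr(2000)    = 6.0e-4  (end of warmup, cosine decay begins)
#   get_lr(301000)  = 3.3e-4  (halfway through decay: coeff = 0.5)
#   get_lr(600000)  = 6.0e-5  (decay finished, clamped at min_lr)
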
# logging
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()