/ model.py
model.py
1 """ 2 Full definition of a GPT Language Model, all of it in this single file. 3 References: 4 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 https://github.com/openai/gpt-2/blob/master/src/model.py 6 2) huggingface/transformers PyTorch implementation: 7 https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 """ 9 10 import math 11 import inspect 12 from dataclasses import dataclass 13 14 import torch 15 import torch.nn as nn 16 from torch.nn import functional as F 17 18 class LayerNorm(nn.Module): 19 """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 20 21 def __init__(self, ndim, bias): 22 super().__init__() 23 self.weight = nn.Parameter(torch.ones(ndim)) 24 self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 26 def forward(self, input): 27 return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 29 class CausalSelfAttention(nn.Module): 30 31 def __init__(self, config): 32 super().__init__() 33 assert config.n_embd % config.n_head == 0 34 # key, query, value projections for all heads, but in a batch 35 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 # output projection 37 self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 # regularization 39 self.attn_dropout = nn.Dropout(config.dropout) 40 self.resid_dropout = nn.Dropout(config.dropout) 41 self.n_head = config.n_head 42 self.n_embd = config.n_embd 43 self.dropout = config.dropout 44 # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 46 if not self.flash: 47 print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") 48 # causal mask to ensure that attention is only applied to the left in the input sequence 49 self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 50 .view(1, 1, config.block_size, config.block_size)) 51 52 def forward(self, x): 53 B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 54 55 # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 57 k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 61 # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 62 if self.flash: 63 # efficient attention using Flash Attention CUDA kernels 64 y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 65 else: 66 # manual implementation of attention 67 att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 68 att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 69 att = F.softmax(att, dim=-1) 70 att = self.attn_dropout(att) 71 y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 72 y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 73 74 # output projection 75 y = self.resid_dropout(self.c_proj(y)) 76 return y 77 78 class MLP(nn.Module): 79 80 def __init__(self, config): 81 super().__init__() 82 self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 83 self.gelu = nn.GELU() 84 self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 85 self.dropout = nn.Dropout(config.dropout) 86 87 def forward(self, x): 88 x = self.c_fc(x) 89 x = self.gelu(x) 90 x = self.c_proj(x) 91 x = self.dropout(x) 92 return x 93 94 class Block(nn.Module): 95 96 def __init__(self, config): 97 super().__init__() 98 self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 99 self.attn = CausalSelfAttention(config) 100 self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 101 self.mlp = MLP(config) 102 103 def forward(self, x): 104 x = x + self.attn(self.ln_1(x)) 105 x = x + self.mlp(self.ln_2(x)) 106 return x 107 108 @dataclass 109 class GPTConfig: 110 block_size: int = 1024 111 vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 112 n_layer: int = 12 113 n_head: int = 12 114 n_embd: int = 768 115 dropout: float = 0.0 116 bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 117 118 class GPT(nn.Module): 119 120 def __init__(self, config): 121 super().__init__() 122 assert config.vocab_size is not None 123 assert config.block_size is not None 124 self.config = config 125 126 self.transformer = nn.ModuleDict(dict( 127 wte = nn.Embedding(config.vocab_size, config.n_embd), 128 wpe = nn.Embedding(config.block_size, config.n_embd), 129 drop = nn.Dropout(config.dropout), 130 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 131 ln_f = LayerNorm(config.n_embd, bias=config.bias), 132 )) 133 self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 134 # with weight tying when using torch.compile() some warnings get generated: 135 # "UserWarning: functional_call was passed multiple values for tied weights. 136 # This behavior is deprecated and will be an error in future versions" 137 # not 100% sure what this is, so far seems to be harmless. TODO investigate 138 self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 139 140 # init all weights 141 self.apply(self._init_weights) 142 # apply special scaled init to the residual projections, per GPT-2 paper 143 for pn, p in self.named_parameters(): 144 if pn.endswith('c_proj.weight'): 145 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 146 147 # report number of parameters 148 print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 149 150 def get_num_params(self, non_embedding=True): 151 """ 152 Return the number of parameters in the model. 153 For non-embedding count (default), the position embeddings get subtracted. 154 The token embeddings would too, except due to the parameter sharing these 155 params are actually used as weights in the final layer, so we include them. 156 """ 157 n_params = sum(p.numel() for p in self.parameters()) 158 if non_embedding: 159 n_params -= self.transformer.wpe.weight.numel() 160 return n_params 161 162 def _init_weights(self, module): 163 if isinstance(module, nn.Linear): 164 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 165 if module.bias is not None: 166 torch.nn.init.zeros_(module.bias) 167 elif isinstance(module, nn.Embedding): 168 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 169 170 def forward(self, idx, targets=None): 171 device = idx.device 172 b, t = idx.size() 173 assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 174 pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 175 176 # forward the GPT model itself 177 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 178 pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 179 x = self.transformer.drop(tok_emb + pos_emb) 180 for block in self.transformer.h: 181 x = block(x) 182 x = self.transformer.ln_f(x) 183 184 if targets is not None: 185 # if we are given some desired targets also calculate the loss 186 logits = self.lm_head(x) 187 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 188 else: 189 # inference-time mini-optimization: only forward the lm_head on the very last position 190 logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 191 loss = None 192 193 return logits, loss 194 195 def crop_block_size(self, block_size): 196 # model surgery to decrease the block size if necessary 197 # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 198 # but want to use a smaller block size for some smaller, simpler model 199 assert block_size <= self.config.block_size 200 self.config.block_size = block_size 201 self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 202 for block in self.transformer.h: 203 if hasattr(block.attn, 'bias'): 204 block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 205 206 @classmethod 207 def from_pretrained(cls, model_type, override_args=None): 208 assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 209 override_args = override_args or {} # default to empty dict 210 # only dropout can be overridden see more notes below 211 assert all(k == 'dropout' for k in override_args) 212 from transformers import GPT2LMHeadModel 213 print("loading weights from pretrained gpt: %s" % model_type) 214 215 # n_layer, n_head and n_embd are determined from model_type 216 config_args = { 217 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 218 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 219 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 220 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 221 }[model_type] 222 print("forcing vocab_size=50257, block_size=1024, bias=True") 223 config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 224 config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 225 config_args['bias'] = True # always True for GPT model checkpoints 226 # we can override the dropout rate, if desired 227 if 'dropout' in override_args: 228 print(f"overriding dropout rate to {override_args['dropout']}") 229 config_args['dropout'] = override_args['dropout'] 230 # create a from-scratch initialized minGPT model 231 config = GPTConfig(**config_args) 232 model = GPT(config) 233 sd = model.state_dict() 234 sd_keys = sd.keys() 235 sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 236 237 # init a huggingface/transformers model 238 model_hf = GPT2LMHeadModel.from_pretrained(model_type) 239 sd_hf = model_hf.state_dict() 240 241 # copy while ensuring all of the parameters are aligned and match in names and shapes 242 sd_keys_hf = sd_hf.keys() 243 sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 244 sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 245 transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 246 # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 247 # this means that we have to transpose these weights when we import them 248 assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 249 for k in sd_keys_hf: 250 if any(k.endswith(w) for w in transposed): 251 # special treatment for the Conv1D weights we need to transpose 252 assert sd_hf[k].shape[::-1] == sd[k].shape 253 with torch.no_grad(): 254 sd[k].copy_(sd_hf[k].t()) 255 else: 256 # vanilla copy over the other parameters 257 assert sd_hf[k].shape == sd[k].shape 258 with torch.no_grad(): 259 sd[k].copy_(sd_hf[k]) 260 261 return model 262 263 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 264 # start with all of the candidate parameters 265 param_dict = {pn: p for pn, p in self.named_parameters()} 266 # filter out those that do not require grad 267 param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 268 # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 269 # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 270 decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 271 nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 272 optim_groups = [ 273 {'params': decay_params, 'weight_decay': weight_decay}, 274 {'params': nodecay_params, 'weight_decay': 0.0} 275 ] 276 num_decay_params = sum(p.numel() for p in decay_params) 277 num_nodecay_params = sum(p.numel() for p in nodecay_params) 278 print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 279 print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 280 # Create AdamW optimizer and use the fused version if it is available 281 fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 282 use_fused = fused_available and device_type == 'cuda' 283 extra_args = dict(fused=True) if use_fused else dict() 284 optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 285 print(f"using fused AdamW: {use_fused}") 286 287 return optimizer 288 289 def estimate_mfu(self, fwdbwd_per_iter, dt): 290 """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 291 # first estimate the number of flops we do per iteration. 292 # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 293 N = self.get_num_params() 294 cfg = self.config 295 L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 296 flops_per_token = 6*N + 12*L*H*Q*T 297 flops_per_fwdbwd = flops_per_token * T 298 flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 299 # express our flops throughput as ratio of A100 bfloat16 peak flops 300 flops_achieved = flops_per_iter * (1.0/dt) # per second 301 flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 302 mfu = flops_achieved / flops_promised 303 return mfu 304 305 @torch.no_grad() 306 def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 307 """ 308 Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 309 the sequence max_new_tokens times, feeding the predictions back into the model each time. 310 Most likely you'll want to make sure to be in model.eval() mode of operation for this. 311 """ 312 for _ in range(max_new_tokens): 313 # if the sequence context is growing too long we must crop it at block_size 314 idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 315 # forward the model to get the logits for the index in the sequence 316 logits, _ = self(idx_cond) 317 # pluck the logits at the final step and scale by desired temperature 318 logits = logits[:, -1, :] / temperature 319 # optionally crop the logits to only the top k options 320 if top_k is not None: 321 v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 322 logits[logits < v[:, [-1]]] = -float('Inf') 323 # apply softmax to convert logits to (normalized) probabilities 324 probs = F.softmax(logits, dim=-1) 325 # sample from the distribution 326 idx_next = torch.multinomial(probs, num_samples=1) 327 # append sampled index to the running sequence and continue 328 idx = torch.cat((idx, idx_next), dim=1) 329 330 return idx