  1  """
  2  Full definition of a GPT Language Model, all of it in this single file.
  3  References:
  4  1) the official GPT-2 TensorFlow implementation released by OpenAI:
  5  https://github.com/openai/gpt-2/blob/master/src/model.py
  6  2) huggingface/transformers PyTorch implementation:
  7  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
  8  """
  9  
 10  import math
 11  import inspect
 12  from dataclasses import dataclass
 13  
 14  import torch
 15  import torch.nn as nn
 16  from torch.nn import functional as F
 17  
 18  class LayerNorm(nn.Module):
 19      """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
 20  
 21      def __init__(self, ndim, bias):
 22          super().__init__()
 23          self.weight = nn.Parameter(torch.ones(ndim))
 24          self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
 25  
 26      def forward(self, input):
 27          return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
 28  
 29  class CausalSelfAttention(nn.Module):
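    """
    Multi-head causal self-attention with a single fused projection for the
    query, key, and value tensors. Uses PyTorch 2.0's Flash Attention kernel
    when available, otherwise falls back to a manual implementation with an
    explicit causal mask.
    """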

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):
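    """ The feed-forward block: a linear expansion to 4*n_embd, a GELU, and a projection back down, as in GPT-2 """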

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
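    """ A Transformer block: pre-LayerNorm self-attention and MLP, each wrapped in a residual connection """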

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

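# A minimal configuration sketch (hypothetical values): a small character-level
# model might use something like
#   GPTConfig(block_size=256, vocab_size=65, n_layer=6, n_head=6, n_embd=384,
#             dropout=0.2, bias=False)
# while the defaults above match the 124M-parameter GPT-2 shape.
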
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
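        """
        Load GPT-2 checkpoint weights from huggingface/transformers into this class.
        Usage sketch: model = GPT.from_pretrained('gpt2', dict(dropout=0.0))
        """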
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden; see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
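        # usage sketch (hypothetical values, typical for GPT-2 scale training):
        # optimizer = model.configure_optimizers(weight_decay=1e-1, learning_rate=6e-4, betas=(0.9, 0.95), device_type='cuda')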
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
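        # 6*N counts the matmul FLOPs per token for a forward+backward pass
        # (2*N forward, 4*N backward); 12*L*H*Q*T adds the attention score and
        # value-aggregation matmuls, following the PaLM paper's accounting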
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
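
# A minimal smoke test, not part of the original training pipeline: build a
# tiny randomly-initialized model and sample a few tokens to verify the
# forward pass and generation loop run end to end.
if __name__ == "__main__":
    config = GPTConfig(block_size=64, vocab_size=128, n_layer=2, n_head=2, n_embd=64)
    model = GPT(config)
    model.eval()
    idx = torch.zeros((1, 1), dtype=torch.long) # start from token 0 as a dummy prompt
    out = model.generate(idx, max_new_tokens=8, temperature=1.0, top_k=10)
    print(out.shape) # expected: torch.Size([1, 9])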