"""
Full definition of a GPT-style Observation Model (a GPT language model adapted for
float inputs and a regression loss), all within this single file.
References:
0) the original nanoGPT code from Andrej Karpathy:
https://github.com/karpathy/nanoGPT
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
3) Aspia Space's earthPT code:
https://github.com/aspiaspace/earthpt
"""
 14  
 15  import math
 16  import inspect
 17  from dataclasses import dataclass
 18  
 19  import torch
 20  import torch.nn as nn
 21  from torch.nn import functional as F
 22  from einops.layers.torch import Rearrange
 23  
 24  # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
 25  def new_gelu(x):
 26      """
 27      Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
 28      Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
 29      """
 30      return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
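
# Note: for PyTorch >= 1.12 the tanh formulation above should match the built-in
# F.gelu(x, approximate="tanh") to within floating-point tolerance, e.g.
#   x = torch.randn(8)
#   torch.allclose(new_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6)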

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = new_gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    n_chan: int = 1
    dropout: float = 0.0
    patch_size: int = 16
    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

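# Shape note (illustrative, derived from the defaults above): each "token" is a
# flattened float patch of patch_size*patch_size*n_chan values, so with the default
# patch_size=16 and n_chan=1 the model consumes inputs of shape
# (batch, sequence <= block_size, 256) and regresses outputs with the same per-token size:
#   config = GPTConfig()
#   x = torch.randn(4, 32, config.patch_size * config.patch_size * config.n_chan)  # (4, 32, 256)
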
class GPT(nn.Module):

    def __init__(self, config, master_process=True):
        super().__init__()
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Sequential(
                nn.Linear(config.patch_size*config.patch_size*config.n_chan, config.n_embd),
                nn.ReLU(),
                nn.Linear(config.n_embd, config.n_embd),
                nn.ReLU(),
            ),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd*2),
            nn.ReLU(),
            nn.Linear(config.n_embd*2, config.n_embd),
            nn.ReLU(),
            nn.Linear(config.n_embd, config.patch_size*config.patch_size*config.n_chan),
        )
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        # TODO rethink weight tying
        #self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        self.master_process = master_process
        if self.master_process: print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For the non-embedding count (default), the position embedding parameters are
        subtracted. The wte patch-embedding MLP is kept in the count, since it acts
        as a learned input projection rather than a lookup table.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t, c = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.huber_loss(logits, targets)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

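    # A brief shape summary for forward() above: idx is a float tensor of shape
    # (b, t, patch_size*patch_size*n_chan); with targets given, logits has the same
    # shape as idx and loss is a scalar Huber loss, while at inference time only the
    # final position is projected, so logits is (b, 1, patch_size*patch_size*n_chan).
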
    def get_embeddings(self, idx):
        device = idx.device
        b, t, ch = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        embeddings = x

        return embeddings

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, all others will not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if self.master_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        if self.master_process: print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

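    # A rough worked example for estimate_mfu above, using the default GPTConfig
    # (n_layer=12, n_head=12, n_embd=768, block_size=1024): the attention term is
    # 12*L*H*Q*T = 12 * 12 * 12 * 64 * 1024 ≈ 1.1e8 FLOPs per token on top of the
    # 6*N parameter term, and MFU is the achieved FLOPs/s divided by the A100's
    # 312e12 bfloat16 peak.
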
    @torch.no_grad()
    def generate(self, idx, new_tokens, temperature=0.0):
        """
        Take a conditioning sequence idx (float tensor of shape (b, t, c)) and extend
        it new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the prediction for the final position in the sequence
            logits, _ = self(idx_cond)
            # perturb the prediction with Gaussian noise scaled by temperature
            idx_next = logits + torch.randn_like(logits) * temperature
            # append the predicted token to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def generate_embeddings(self, idx, average_type="mean"):
        """
        Take a conditioning sequence idx (float tensor of shape (b, t, c))
        and get the embedding from the transformer model for that series.

        Most likely you'll want to make sure to be in model.eval() mode of
        operation for this.
        """
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        if average_type == "mean":
            # We only care about the average embedding
            weighted_embeddings = torch.mean(self.get_embeddings(idx_cond), dim=1)
        elif average_type == "exp_decay":
            embeddings = self.get_embeddings(idx_cond)
            # exponentially decaying weights, from 1.0 at the first position down to 0.1 at the last
            weights = torch.logspace(0, -1, embeddings.shape[1], device=embeddings.device).unsqueeze(0).unsqueeze(-1)
            # weighted mean over the time dimension, normalised by the sum of the weights
            weighted_embeddings = torch.sum(weights*embeddings, dim=1)/torch.sum(weights, dim=1)
        elif average_type == "none":
            weighted_embeddings = self.get_embeddings(idx_cond)
        else:
            raise NotImplementedError

        return weighted_embeddings
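
# Minimal smoke test when this file is run directly. This is an illustrative usage
# sketch only: the tiny hyperparameters and dummy random data below are assumed
# values chosen so the script runs quickly, not settings from any training run.
if __name__ == "__main__":
    config = GPTConfig(block_size=64, n_layer=2, n_head=2, n_embd=64, patch_size=4, n_chan=1)
    model = GPT(config)
    model.eval()

    # a batch of 2 sequences, each with 8 flattened 4x4x1 float patches
    x = torch.randn(2, 8, config.patch_size * config.patch_size * config.n_chan)

    # training-style forward pass: predictions plus Huber loss against (dummy) targets
    preds, loss = model(x, targets=torch.randn_like(x))
    print("train forward:", tuple(preds.shape), float(loss))

    # autoregressively extend the sequence by 4 patches with a little noise
    out = model.generate(x, new_tokens=4, temperature=0.1)
    print("generated:", tuple(out.shape))  # (2, 12, 16)

    # pooled embedding for each series
    emb = model.generate_embeddings(x, average_type="mean")
    print("embeddings:", tuple(emb.shape))  # (2, 64)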