model.py
1 """ 2 Full definition of a GPT s/Language/Observation Model adapted for float inputs 3 and with a regression loss, all within this single file. 4 References: 5 0) the original nanoGPT code from Andrej Karpathy: 6 https://github.com/karpathy/nanoGPT 7 1) the official GPT-2 TensorFlow implementation released by OpenAI: 8 https://github.com/openai/gpt-2/blob/master/src/model.py 9 2) huggingface/transformers PyTorch implementation: 10 https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 11 3) Aspia Space's earthPT code: 12 https://github.com/aspiaspace/earthpt 13 """ 14 15 import math 16 import inspect 17 from dataclasses import dataclass 18 19 import torch 20 import torch.nn as nn 21 from torch.nn import functional as F 22 from einops.layers.torch import Rearrange 23 24 # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) 25 def new_gelu(x): 26 """ 27 Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 28 Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 29 """ 30 return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 31 32 class LayerNorm(nn.Module): 33 """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 34 35 def __init__(self, ndim, bias): 36 super().__init__() 37 self.weight = nn.Parameter(torch.ones(ndim)) 38 self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 39 40 def forward(self, input): 41 return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 42 43 class CausalSelfAttention(nn.Module): 44 45 def __init__(self, config): 46 super().__init__() 47 assert config.n_embd % config.n_head == 0 48 # key, query, value projections for all heads, but in a batch 49 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 50 # output projection 51 self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 52 # regularization 53 self.attn_dropout = nn.Dropout(config.dropout) 54 self.resid_dropout = nn.Dropout(config.dropout) 55 self.n_head = config.n_head 56 self.n_embd = config.n_embd 57 self.dropout = config.dropout 58 # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 59 self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 60 if not self.flash: 61 print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") 62 # causal mask to ensure that attention is only applied to the left in the input sequence 63 self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 64 .view(1, 1, config.block_size, config.block_size)) 65 66 def forward(self, x): 67 B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 68 69 # calculate query, key, values for all heads in batch and move head forward to be the batch dim 70 q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 71 k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 72 q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 73 v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 74 75 # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 76 if self.flash: 77 # efficient attention using Flash Attention CUDA kernels 78 y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 79 else: 80 # manual implementation of attention 81 att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 82 att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 83 att = F.softmax(att, dim=-1) 84 att = self.attn_dropout(att) 85 y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 86 y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 87 88 # output projection 89 y = self.resid_dropout(self.c_proj(y)) 90 return y 91 92 class MLP(nn.Module): 93 94 def __init__(self, config): 95 super().__init__() 96 self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 97 self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 98 self.dropout = nn.Dropout(config.dropout) 99 100 def forward(self, x): 101 x = self.c_fc(x) 102 x = new_gelu(x) 103 x = self.c_proj(x) 104 x = self.dropout(x) 105 return x 106 107 class Block(nn.Module): 108 109 def __init__(self, config): 110 super().__init__() 111 self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 112 self.attn = CausalSelfAttention(config) 113 self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 114 self.mlp = MLP(config) 115 116 def forward(self, x): 117 x = x + self.attn(self.ln_1(x)) 118 x = x + self.mlp(self.ln_2(x)) 119 return x 120 121 @dataclass 122 class GPTConfig: 123 block_size: int = 1024 124 n_layer: int = 12 125 n_head: int = 12 126 n_embd: int = 768 127 n_chan: int = 1 128 dropout: float = 0.0 129 patch_size: int = 16 130 bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 131 132 class GPT(nn.Module): 133 134 def __init__(self, config, master_process=True): 135 super().__init__() 136 assert config.block_size is not None 137 self.config = config 138 139 self.transformer = nn.ModuleDict(dict( 140 wte = nn.Sequential( 141 nn.Linear(config.patch_size*config.patch_size*config.n_chan, config.n_embd), 142 nn.ReLU(), 143 nn.Linear(config.n_embd, config.n_embd), 144 nn.ReLU(), 145 ), 146 wpe = nn.Embedding(config.block_size, config.n_embd), 147 drop = nn.Dropout(config.dropout), 148 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 149 ln_f = LayerNorm(config.n_embd, bias=config.bias), 150 )) 151 self.lm_head = nn.Sequential( 152 nn.Linear(config.n_embd, config.n_embd*2), 153 nn.ReLU(), 154 nn.Linear(config.n_embd*2, config.n_embd), 155 nn.ReLU(), 156 nn.Linear(config.n_embd, config.patch_size*config.patch_size*config.n_chan), 157 ) 158 # with weight tying when using torch.compile() some warnings get generated: 159 # "UserWarning: functional_call was passed multiple values for tied weights. 160 # This behavior is deprecated and will be an error in future versions" 161 # not 100% sure what this is, so far seems to be harmless. TODO investigate 162 # TODO rethink weight tying 163 #self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 164 165 # init all weights 166 self.apply(self._init_weights) 167 # apply special scaled init to the residual projections, per GPT-2 paper 168 for pn, p in self.named_parameters(): 169 if pn.endswith('c_proj.weight'): 170 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 171 172 # report number of parameters 173 self.master_process = master_process 174 if self.master_process: print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 175 176 def get_num_params(self, non_embedding=True): 177 """ 178 Return the number of parameters in the model. 179 For non-embedding count (default), the position embeddings get subtracted. 180 The token embeddings would too, except due to the parameter sharing these 181 params are actually used as weights in the final layer, so we include them. 182 """ 183 n_params = sum(p.numel() for p in self.parameters()) 184 if non_embedding: 185 n_params -= self.transformer.wpe.weight.numel() 186 return n_params 187 188 def _init_weights(self, module): 189 if isinstance(module, nn.Linear): 190 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 191 if module.bias is not None: 192 torch.nn.init.zeros_(module.bias) 193 elif isinstance(module, nn.Embedding): 194 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 195 196 def forward(self, idx, targets=None): 197 device = idx.device 198 b, t, c = idx.size() 199 assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 200 pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 201 202 # forward the GPT model itself 203 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 204 pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 205 x = self.transformer.drop(tok_emb + pos_emb) 206 for block in self.transformer.h: 207 x = block(x) 208 x = self.transformer.ln_f(x) 209 210 if targets is not None: 211 # if we are given some desired targets also calculate the loss 212 logits = self.lm_head(x) 213 loss = F.huber_loss(logits, targets) 214 else: 215 # inference-time mini-optimization: only forward the lm_head on the very last position 216 logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 217 loss = None 218 219 return logits, loss 220 221 def get_embeddings(self, idx): 222 device = idx.device 223 b, t, ch = idx.size() 224 assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 225 pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 226 227 # forward the GPT model itself 228 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 229 pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 230 x = self.transformer.drop(tok_emb + pos_emb) 231 for block in self.transformer.h: 232 x = block(x) 233 embeddings = x 234 235 return embeddings 236 237 def crop_block_size(self, block_size): 238 # model surgery to decrease the block size if necessary 239 assert block_size <= self.config.block_size 240 self.config.block_size = block_size 241 self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 242 for block in self.transformer.h: 243 if hasattr(block.attn, 'bias'): 244 block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 245 246 def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 247 # start with all of the candidate parameters 248 param_dict = {pn: p for pn, p in self.named_parameters()} 249 # filter out those that do not require grad 250 param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 251 # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 252 # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 253 decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 254 nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 255 optim_groups = [ 256 {'params': decay_params, 'weight_decay': weight_decay}, 257 {'params': nodecay_params, 'weight_decay': 0.0} 258 ] 259 num_decay_params = sum(p.numel() for p in decay_params) 260 num_nodecay_params = sum(p.numel() for p in nodecay_params) 261 if self.master_process: 262 print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 263 print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 264 # Create AdamW optimizer and use the fused version if it is available 265 fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 266 use_fused = fused_available and device_type == 'cuda' 267 extra_args = dict(fused=True) if use_fused else dict() 268 optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 269 if self.master_process: print(f"using fused AdamW: {use_fused}") 270 271 return optimizer 272 273 def estimate_mfu(self, fwdbwd_per_iter, dt): 274 """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 275 # first estimate the number of flops we do per iteration. 276 # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 277 N = self.get_num_params() 278 cfg = self.config 279 L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 280 flops_per_token = 6*N + 12*L*H*Q*T 281 flops_per_fwdbwd = flops_per_token * T 282 flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 283 # express our flops throughput as ratio of A100 bfloat16 peak flops 284 flops_achieved = flops_per_iter * (1.0/dt) # per second 285 flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 286 mfu = flops_achieved / flops_promised 287 return mfu 288 289 @torch.no_grad() 290 def generate(self, idx, new_tokens, temperature=0.0): 291 """ 292 Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 293 the sequence new_tokens times, feeding the predictions back into the model each time. 294 Most likely you'll want to make sure to be in model.eval() mode of operation for this. 295 """ 296 for i in range(new_tokens): 297 # if the sequence context is growing too long we must crop it at block_size 298 idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 299 # forward the model to get the logits for the index in the sequence 300 logits, _ = self(idx_cond) 301 idx_next = logits + (torch.randn(logits.size())*temperature).to(logits.device) 302 # append sampled index to the running sequence and continue 303 idx = torch.cat((idx, idx_next), dim=1) 304 305 return idx 306 307 @torch.no_grad() 308 def generate_embeddings(self, idx, average_type="mean"): 309 """ 310 Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) 311 and get the embedding from the transformer model for that series. 312 313 Most likely you'll want to make sure to be in model.eval() mode of 314 operation for this. 315 """ 316 idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 317 if average_type == "mean": 318 # We only care about the average embedding 319 weighted_embeddings = torch.mean(self.get_embeddings(idx_cond), dim=1) 320 elif average_type == "exp_decay": 321 embeddings = self.get_embeddings(idx_cond) 322 weights = torch.logspace(0, -1, embeddings.shape[1], device=embeddings.device).unsqueeze(0).unsqueeze(-1) 323 weighted_embeddings = torch.sum(weights*embeddings, dim=1)/torch.sum(embeddings, dim=1) 324 elif average_type == "none": 325 weighted_embeddings = self.get_embeddings(idx_cond) 326 else: 327 raise NotImplementedError 328 329 return weighted_embeddings