  1  """
  2  Sample from a trained astropt model and finetune embeddings on linear probes 
  3  """
import functools
import os
from contextlib import nullcontext

import numpy as np
import pandas as pd
import torch
from datasets import concatenate_datasets, load_dataset
from sklearn.linear_model import LinearRegression
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
from tqdm import tqdm

from model import GPTConfig, GPT
from train import GalaxyImageDataset

# -----------------------------------------------------------------------------
init_from = 'resume'
out_dir = './logs/9B_tokens/astropt001M/' # ignored if init_from is not 'resume'
refresh_cache = False # resample the embeddings
batch_size = 256
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
exec(open('astroPT/configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
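# note: autocast is skipped entirely on CPU; on CUDA the forward passes below
# run in the reduced-precision dtype chosen above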

# model
if init_from == 'resume':
    print("loading from checkpoint")
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, '030000_ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    # TODO remove this for latest models
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # checkpoints saved from a torch.compile'd model prefix every parameter
    # name with '_orig_mod.'; strip it so the keys match a freshly built model
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

# set up HF galaxies in test set to be processed
def normalise(x):
    # zero-mean, unit-variance normalisation; eps avoids division by zero
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean)/(std + 1e-8)
def data_transforms():
    transform = transforms.Compose([
        transforms.Lambda(normalise),
    ])
    return transform
def _process_galaxy_wrapper(gal, func):
    gal = ToTensor()(gal["image"]).to(torch.bfloat16)
    patch_galaxy = func(gal)
    return {"image": patch_galaxy}
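# stream the test and validation splits straight from the HF hub (no full
# download) and patchify each galaxy on the fly into model-ready tokens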
galproc = GalaxyImageDataset(None, spiral=True, transform=data_transforms())
ds = concatenate_datasets((
    load_dataset("Smith42/galaxies", split="test", streaming=True),
    load_dataset("Smith42/galaxies", split="validation", streaming=True),
))
ds = ds.map(
    functools.partial(_process_galaxy_wrapper, func=galproc.process_galaxy)
).with_format("torch")
dl = iter(DataLoader(
    ds, batch_size=batch_size, num_workers=2,
))

def train_probe(zs, ys):
    probe = LinearRegression()
    probe.fit(zs, ys)
    return probe

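# the probe is deliberately simple: ordinary least squares from frozen astroPT
# embeddings to each scalar property, so held-out probe error measures how
# linearly decodable that property is from the representation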
def run_probes(xs, ids, metadata):
    mdata_fields = metadata.keys()
    # keep only points within m median-absolute-deviations of the median
    def reject_outliers(data, m=3.):
        d = np.abs(data - np.median(data))
        mdev = np.median(d)
        s = d/mdev if mdev else np.zeros(len(d))
        return (s < m)

    losses = []
    fields = []
    for metadatum in tqdm(mdata_fields):
        # the two colours are derived from the raw magnitudes; everything else
        # is read directly from the metadata table
        if metadatum == "g_minus_r":
            ys = pd.Series(metadata["mag_g_desi"].to_numpy() - metadata["mag_r_desi"].to_numpy())
        elif metadatum == "r_minus_z":
            ys = pd.Series(metadata["mag_r_desi"].to_numpy() - metadata["mag_z_desi"].to_numpy())
        else:
            ys = metadata[metadatum]
        if ys.dtype != "object":

            nonnans = (np.isfinite(ys) & np.all(np.isfinite(xs), axis=1))
            xs_ = xs[nonnans]
            ys_ = ys[nonnans]

            # remove outliers more than 3 MADs from the median
            inliers = reject_outliers(ys_)
            xs_ = xs_[inliers]
            ys_ = ys_[inliers]

            # robust normalisation by the interquartile range
            ys_ = (ys_ - np.median(ys_))/(np.quantile(ys_, 0.75) - np.quantile(ys_, 0.25))
            if np.all(np.isfinite(ys_)):
                # fit on the first half, report median absolute error on the rest
                halfway = len(ys_)//2
                probe = train_probe(xs_[:halfway], ys_[:halfway])
                losses.append(
                    np.median(np.abs(probe.predict(xs_[halfway:]) - ys_[halfway:]))
                )
            else:
                losses.append(np.nan)
            fields.append(metadatum)

    return fields, losses

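# embed every galaxy once and cache the embeddings, ids, and matched metadata
# to disk so the (cheap) probing stage can be re-run without touching the GPU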
n_tokens = 64 # number of leading image-patch tokens fed to the model
norm = "mean" # pooling strategy passed to generate_embeddings as average_type
if (not (
        os.path.isfile(os.path.join(out_dir, f"zss_{n_tokens}t_{norm}.npy")) and
        os.path.isfile(os.path.join(out_dir, f"idss_{n_tokens}t_{norm}.npy")) and
        os.path.isfile(os.path.join(out_dir, "metadata_processed.parquet"))
   )) or refresh_cache:
    # run generation
    zss = []
    idss = []
    with torch.no_grad():
        with ctx:
            tt = tqdm(unit="galz", unit_scale=True, total=87000) # ~87k galaxies in test + validation
            for B in dl:
                xs = B["image"][:, :n_tokens]
                ids = B["dr8_id"]
                zs = model.generate_embeddings(xs.to(device), average_type=norm)
                zss.append(zs.detach().cpu().numpy())
                idss.append(ids)
                tt.update(batch_size)
            tt.close()

    zss = np.concatenate(zss, axis=0)
    idss = np.concatenate(idss, axis=0)
    np.save(os.path.join(out_dir, f"zss_{n_tokens}t_{norm}.npy"), zss)
    np.save(os.path.join(out_dir, f"idss_{n_tokens}t_{norm}.npy"), idss)

164      print("processing metadata file")
165      metadata = pd.read_parquet("/raid/data/metadata.parquet")
166      metadata = metadata.set_index(["dr8_id"])
167      metadata = metadata.loc[list(idss)]
168      metadata.to_parquet(os.path.join(out_dir, "metadata_processed.parquet"))
169  else:
170      print("loading from cache")
171      metadata = pd.read_parquet(os.path.join(out_dir, "metadata_processed.parquet"))
172      zss = np.load(os.path.join(out_dir, f"zss_{n_tokens}t_{norm}.npy"))
173      idss = np.load(os.path.join(out_dir, f"idss_{n_tokens}t_{norm}.npy"))
174  

print("probing...")
labels, loss_zs = run_probes(zss, idss, metadata)
print(loss_zs)
file_path = "probe_losses.txt"
# write a header row of field names on first touch, then append one row of
# probe losses per run
if (not os.path.exists(file_path)) or os.path.getsize(file_path) == 0:
    with open(file_path, 'w') as f:
        f.write(','.join(labels) + '\n')
with open(file_path, "a") as f:
    np.savetxt(f, np.array(loss_zs)[np.newaxis], delimiter=",")
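
# example invocation; any of the configuration variables at the top can be
# overridden on the command line via configurator.py (values here are
# illustrative, not prescribed):
#   $ python linear_probe.py --out_dir=./logs/9B_tokens/astropt001M/ --batch_size=128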