linear_probe.py
1 """ 2 Sample from a trained astropt model and finetune embeddings on linear probes 3 """ 4 import os 5 import pickle 6 from contextlib import nullcontext 7 import torch 8 from torch.utils.data import DataLoader 9 import tiktoken 10 from tqdm import tqdm, trange 11 import matplotlib.pyplot as plt 12 import numpy as np 13 from model import GPTConfig, GPT 14 from sklearn.linear_model import LogisticRegression, LinearRegression 15 from sklearn.neural_network import MLPRegressor 16 from datasets import load_dataset, concatenate_datasets 17 from train import GalaxyImageDataset 18 from torchvision import transforms 19 from torchvision.transforms import ToTensor 20 import functools 21 from einops import rearrange 22 import pandas as pd 23 24 # ----------------------------------------------------------------------------- 25 init_from = 'resume' 26 out_dir = './logs/9B_tokens/astropt001M/' # ignored if init_from is not 'resume' 27 refresh_cache = False # resample the embeddings 28 batch_size = 256 29 seed = 1337 30 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# NOTE(review): exec of a config file is the project's established override
# mechanism (see nanoGPT-style configurator) — do not run on untrusted paths.
exec(open('astroPT/configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    print("loading from checkpoint")
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, '030000_ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    # TODO remove this for latest models
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # strip the torch.compile wrapper prefix from parameter names, if present,
    # so the state dict matches the un-compiled module's parameter names
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

# set up HF galaxies in test set to be processed
def normalise(x):
    """Standardise x to zero mean / unit std along dim 1 (eps avoids div-by-0)."""
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean)/(std + 1e-8)

def data_transforms():
    """Return the torchvision transform pipeline applied to each galaxy."""
    transform = transforms.Compose([
        transforms.Lambda(normalise),
    ])
    return transform

def _process_galaxy_wrapper(gal, func):
    """Map one HF dataset row {'image': ...} through `func` as a bf16 tensor."""
    gal = ToTensor()(gal["image"]).to(torch.bfloat16)
    patch_galaxy = func(gal)
    return {"image": patch_galaxy}

galproc = GalaxyImageDataset(None, spiral=True, transform=data_transforms())
# test + validation splits are pooled so the probes see all held-out galaxies
ds = concatenate_datasets((
    load_dataset("Smith42/galaxies", split="test", streaming=True),
    load_dataset("Smith42/galaxies", split="validation", streaming=True),
))
ds = ds.map(
    functools.partial(_process_galaxy_wrapper, func=galproc.process_galaxy)
).with_format("torch")
dl = iter(DataLoader(
    ds, batch_size=batch_size, num_workers=2,
))

def train_probe(zs, ys):
    """Fit a linear regression probe mapping embeddings zs to targets ys."""
    probe = LinearRegression()
    probe.fit(zs, ys)
    return probe

def run_probes(xs, ids, metadata):
    """Train/evaluate one linear probe per (numeric) metadata column.

    Args:
        xs: embedding matrix, one row per galaxy (aligned with `metadata` rows).
        ids: galaxy ids (currently unused; kept for interface compatibility).
        metadata: DataFrame of per-galaxy properties to regress against.

    Returns:
        (fields, losses): probed column names and their median absolute test
        errors on robustly-normalised targets (np.nan when not probe-able).
    """
    mdata_fields = metadata.keys()

    def reject_outliers(data, m=3.):
        # boolean inlier mask: within m MAD-based "sigmas" of the median;
        # if MAD is zero every point is kept (s == 0 < m)
        d = np.abs(data - np.median(data))
        mdev = np.median(d)
        s = d/mdev if mdev else np.zeros(len(d))
        return (s < m)

    losses = []
    fields = []
    for metadatum in tqdm(mdata_fields):
        # BUGFIX: `ys` used to be assigned only in the two special-case colour
        # branches below, so every other column silently probed a stale target
        # (or raised NameError on the first iteration). Default to the column.
        ys = pd.Series(metadata[metadatum])
        if metadatum == "g_minus_r":
            ys = pd.Series(metadata["mag_g_desi"].to_numpy() - metadata["mag_r_desi"].to_numpy())
        if metadatum == "r_minus_z":
            ys = pd.Series(metadata["mag_r_desi"].to_numpy() - metadata["mag_z_desi"].to_numpy())
        if ys.dtype != "object":
            # drop rows where either target or any embedding dim is non-finite
            nonnans = (np.isfinite(ys) & np.all(np.isfinite(xs), axis=1))
            xs_ = xs[nonnans]
            ys_ = ys[nonnans]

            # remove outliers above 3 sigma
            inliers = reject_outliers(ys_)
            xs_ = xs_[inliers]
            ys_ = ys_[inliers]

            # robust normalisation (median / IQR instead of mean / std)
            ys_ = (ys_ - np.median(ys_))/(np.quantile(ys_, 0.75) - np.quantile(ys_, 0.25))
            if np.all(np.isfinite(ys_)):
                # first half trains the probe, second half evaluates it
                halfway = len(ys_)//2
                probe = train_probe(xs_[:halfway], ys_[:halfway])
                losses.append(
                    (np.abs(probe.predict(xs_[halfway:]) - ys_[halfway:])).median()
                )
            else:
                losses.append(
                    np.nan
                )
            fields.append(metadatum)

    return fields, losses

n_tokens = 64
norm = "mean"
# Embeddings + aligned metadata are cached on disk; regenerate only when the
# cache is incomplete or a refresh is explicitly requested.
if (not (
    os.path.isfile(os.path.join(out_dir, f"zss_{n_tokens}t_{norm}.npy")) and
    os.path.isfile(os.path.join(out_dir, f"idss_{n_tokens}t_{norm}.npy")) and
    os.path.isfile(os.path.join(out_dir, "metadata_processed.parquet"))
)) or refresh_cache:
    # run generation
    zss = []
    idss = []
    with torch.no_grad():
        with ctx:
            tt = tqdm(unit="galz", unit_scale=True, total=87000)
            for B in dl:
                xs = B["image"][:, :n_tokens]
                ids = B["dr8_id"]
                zs = model.generate_embeddings(xs.to(device), average_type=norm)
                zss.append(zs.detach().cpu().numpy())
                idss.append(ids)
                tt.update(batch_size)
    tt.close()

    zss = np.concatenate(zss, axis=0)
    idss = np.concatenate(idss, axis=0)
    np.save(os.path.join(out_dir, f"zss_{n_tokens}t_{norm}.npy"), zss)
    np.save(os.path.join(out_dir, f"idss_{n_tokens}t_{norm}.npy"), idss)

    print("processing metadata file")
    metadata = pd.read_parquet("/raid/data/metadata.parquet")
    # re-index metadata by galaxy id and align row order with the embeddings
    metadata = metadata.set_index(["dr8_id"])
    metadata = metadata.loc[list(idss)]
    metadata.to_parquet(os.path.join(out_dir, "metadata_processed.parquet"))
else:
    print("loading from cache")
    metadata = pd.read_parquet(os.path.join(out_dir, "metadata_processed.parquet"))
    zss = np.load(os.path.join(out_dir, f"zss_{n_tokens}t_{norm}.npy"))
    idss = np.load(os.path.join(out_dir, f"idss_{n_tokens}t_{norm}.npy"))
print("probing...")
labels, loss_zs = run_probes(zss, idss, metadata)
print(loss_zs)

# Append this run's probe losses as a single CSV row; write the header row
# only when the file is new or empty so repeated runs accumulate rows.
file_path = "probe_losses.txt"  # fixed: was an f-string with no placeholders
if (not os.path.exists(file_path)) or os.path.getsize(file_path) == 0:
    with open(file_path, 'w') as f:
        f.write(','.join(labels) + '\n')
with open(file_path, "a") as f:
    np.savetxt(f, np.array(loss_zs)[np.newaxis], delimiter=",")