"""DeMo: Decoupled Momentum Optimization

This implements the DeMo fused optimizer and data parallel algorithm.
It is recommended to use DeMo as the base data parallelism.
In an existing codebase that uses PyTorch DDP, wrap your forward-backward in
`torch.distributed.DistributedDataParallel.no_sync` to disable external gradient synchronization.
See https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync
"""

import math
from typing import Callable, Optional

import torch
import torch.distributed as dist
import torch.fft
from einops import rearrange


class DeMo(torch.optim.SGD):
    """Sign-SGD optimizer that synchronizes DCT-compressed momentum deltas across ranks.

    The base :class:`torch.optim.SGD` is configured as a bare step (no momentum,
    no weight decay): DeMo maintains its own per-parameter momentum ("delta") and
    applies decoupled weight decay itself in :meth:`step`.
    """

    def __init__(
        self,
        params,
        compression_decay: float = 0.999,
        compression_topk: int = 32,
        compression_chunk: int = 64,
        weight_decay: float = 0.0,
        process_group: Optional[dist.ProcessGroup] = None,
        **kwargs,
    ):
        """Create the optimizer.

        :param params: iterable of parameters or parameter groups (as for SGD)
        :param compression_decay: momentum decay factor for the residual delta, in [0, 1)
        :param compression_topk: number of DCT coefficients kept per chunk
        :param compression_chunk: target DCT chunk size (actual size is the
            closest smaller divisor of each tensor dimension)
        :param weight_decay: decoupled weight decay applied directly to the weights
        :param process_group: process group used for the all-gather; ``None``
            means the default group
        :param kwargs: forwarded to ``torch.optim.SGD`` (e.g. ``lr``)
        :raises ValueError: on non-positive topk/chunk sizes or a decay outside [0, 1)
        """
        # Momentum/weight decay are handled by DeMo itself, so the base SGD is
        # deliberately configured as a plain step.
        super().__init__(
            params,
            foreach=False,
            momentum=0.0,
            dampening=0.0,
            nesterov=False,
            maximize=False,
            weight_decay=0.0,
            **kwargs,
        )

        self.compression_decay = compression_decay
        self.compression_chunk = compression_chunk
        self.compression_topk = compression_topk
        self.process_group = process_group
        self.weight_decay = weight_decay

        # Error messages reference the actual parameter names.
        if self.compression_topk <= 0:
            raise ValueError("compression_topk has to be positive")
        if self.compression_chunk <= 0:
            raise ValueError("compression_chunk has to be positive")
        if self.compression_decay < 0:
            raise ValueError("Negative compression_decay is currently not supported")
        if self.compression_decay >= 1:
            raise ValueError("Values of compression_decay bigger or equal to 1.0 is currently not supported")

        self.demo_state = {}
        self._init_demo_states()
        self._init_opt_parameters()

        self.default_dtype = self._find_dtype()
        self.transform = TransformDCT(self.param_groups, self.compression_chunk)
        self.compress = CompressDCT()
def _find_dtype(self): 65 for group in self.param_groups: 66 for p in group["params"]: 67 if p.requires_grad: 68 return p.dtype 69 return torch.float32 70 71 def _init_demo_states(self): 72 for group in self.param_groups: 73 for p in group["params"]: 74 if p.requires_grad: 75 self.demo_state[p] = {} 76 77 def _state_parameter(self, p): 78 if p not in self.demo_state: 79 self.demo_state[p] = {} 80 return self.demo_state[p] 81 82 def _init_opt_parameters(self): 83 for group in self.param_groups: 84 for p in group["params"]: 85 if p.requires_grad: 86 state = self._state_parameter(p) 87 88 state["step"] = 0 89 state["delta"] = torch.zeros_like(p) 90 91 def _demo_all_gather(self, sparse_idx, sparse_val): 92 world_size = dist.get_world_size() if self.process_group is None else self.process_group.size() 93 94 # Gather all the idx and vals 95 sparse_idx_list = [torch.zeros_like(sparse_idx) for wi in range(world_size)] 96 sparse_val_list = [torch.zeros_like(sparse_val) for wi in range(world_size)] 97 98 sparse_idx_handle = dist.all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True) 99 sparse_val_handle = dist.all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True) 100 101 sparse_idx_handle.wait() 102 sparse_val_handle.wait() 103 104 return sparse_idx_list, sparse_val_list 105 106 107 @torch.no_grad() 108 def step(self, closure: Callable | None = None): 109 110 self.data_transmit = 0 111 self.data_receive = 0 112 113 for group in self.param_groups: 114 lr = group["lr"] 115 for p in group["params"]: 116 if not p.requires_grad: 117 continue 118 state = self._state_parameter(p) 119 120 # Update step 121 state["step"] += 1 122 123 # Step-Weight decay 124 if self.weight_decay != 0.0: 125 p.data.mul_(1.0 - lr * self.weight_decay) 126 127 # Decay delta 128 if self.compression_decay != 1: 129 state["delta"].mul_(self.compression_decay) 130 131 # Add delta to new gradient 132 state["delta"].add_(p.grad, alpha=lr) 133 134 # Compress 
delta 135 sparse_idx, sparse_val, xshape, totalk = self.compress.compress( 136 self.transform.encode(state["delta"]), self.compression_topk 137 ) 138 139 # Estimate transmitted delta 140 transmit_grad = self.transform.decode( 141 self.compress.decompress(p, sparse_idx, sparse_val, xshape, totalk) 142 ) 143 144 # Remove transmitted from delta 145 state["delta"].sub_(transmit_grad) 146 147 # All-gather 148 sparse_idx_gather, sparse_val_gather = self._demo_all_gather(sparse_idx, sparse_val) 149 150 # Log I/O data size 151 self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes 152 for si, v in zip(sparse_idx_gather, sparse_val_gather): 153 self.data_receive += si.nbytes + v.nbytes 154 155 # Decode grad from all nodes 156 new_grad = self.transform.decode( 157 self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, xshape, totalk) 158 ) 159 160 # Set grad to values 161 if p.grad is None: 162 p.grad = new_grad 163 else: 164 p.grad.copy_(new_grad) 165 166 # Sign-SGD 167 p.grad.sign_() 168 169 # SGD step 170 return super().step(closure) 171 172 class TransformDCT: 173 @torch.no_grad() 174 def __init__(self, param_groups, target_chunk, norm="ortho"): 175 self.target_chunk = target_chunk 176 177 self.shape_dict = dict() 178 self.f_dict = dict() 179 self.b_dict = dict() 180 181 # Get all variants of model tensor sizes 182 # Generate all possible valid DCT sizes for model tensors 183 for group in param_groups: 184 for p in group["params"]: 185 if not p.requires_grad: 186 continue 187 for s in p.shape: 188 # Get the closest smallest divisor to the targeted DCT size 189 sc = _get_smaller_split(s, self.target_chunk) 190 self.shape_dict[s] = sc 191 192 # Pregenerate DCT basis matrices 193 if sc not in self.f_dict: 194 I = torch.eye(sc) 195 self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device) 196 self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device) 197 198 @torch.no_grad() 199 def einsum_2d(self, x, b, d=None): 200 if d is None: 201 return 
torch.einsum("...ij, jb -> ...ib", x, b) 202 else: 203 # Note: b-c axis output is transposed to chunk DCT in 2D 204 return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d) 205 206 @torch.no_grad() 207 def einsum_2d_t(self, x, b, d=None): 208 if d is None: 209 return torch.einsum("...ij, jb -> ...ib", x, b) 210 else: 211 # Note: b-c axis output is transposed to chunk DCT in 2D 212 return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d) 213 214 @torch.no_grad() 215 def encode(self, x): 216 if len(x.shape) > 1: # 2D weights 217 n1 = self.shape_dict[x.shape[0]] 218 n2 = self.shape_dict[x.shape[1]] 219 n1w = self.f_dict[n1].to(x.device) 220 n2w = self.f_dict[n2].to(x.device) 221 self.f_dict[n1] = n1w 222 self.f_dict[n2] = n2w 223 224 x = rearrange(x, "(y h) (x w) -> y h x w", h=n1, w=n2) 225 x = self.einsum_2d(x, n1w, n2w) 226 227 else: # 1D weights 228 n1 = self.shape_dict[x.shape[0]] 229 n1w = self.f_dict[n1].to(x.device) 230 self.f_dict[n1] = n1w 231 232 x = rearrange(x, "(x w) -> x w", w=n1) 233 x = self.einsum_2d(x, n1w) 234 235 return x 236 237 @torch.no_grad() 238 def decode(self, x): 239 if len(x.shape) > 2: # 2D weights 240 n1 = x.shape[2] 241 n2 = x.shape[3] 242 n1w = self.b_dict[n1].to(x.device) 243 n2w = self.b_dict[n2].to(x.device) 244 self.b_dict[n1] = n1w 245 self.b_dict[n2] = n2w 246 247 x = self.einsum_2d_t(x, n1w, n2w) 248 x = rearrange(x, "y h x w -> (y h) (x w)") 249 250 else: # 1D weights 251 n1 = x.shape[1] 252 n1w = self.b_dict[n1].to(x.device) 253 self.b_dict[n1] = n1w 254 255 x = self.einsum_2d_t(x, n1w) 256 x = rearrange(x, "x w -> (x w)") 257 258 return x 259 260 261 class CompressDCT: 262 @torch.no_grad() 263 def __init__(self): 264 pass 265 266 def _clamp_topk(self, x, topk): 267 if topk > x.shape[-1]: 268 topk = x.shape[-1] 269 if topk < 1: 270 topk = 1 271 return topk 272 273 @torch.no_grad() 274 def compress(self, x, topk): 275 xshape = x.shape 276 if len(x.shape) > 2: # 2D weights 277 x = rearrange(x, "y x h w -> y x (h w)") 278 
279 # Limit topk to max size 280 totalk = x.shape[-1] 281 topk = self._clamp_topk(x, topk) 282 283 idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False).indices 284 val = torch.gather(x, dim=-1, index=idx) 285 286 return idx, val, xshape, totalk 287 288 @torch.no_grad() 289 def decompress(self, p, idx, val, xshape, totalk): 290 x = torch.zeros(xshape, device=p.device, dtype=p.dtype) 291 292 if len(xshape) > 2: # 2D weights 293 x = rearrange(x, "y x h w -> y x (h w)") 294 295 # TODO: Careful, this is nondeterministic across different CUDA devices! might cause errors to accumulate between nodes! 296 x.scatter_reduce_(dim=-1, index=idx, src=val, reduce="mean", include_self=False).reshape(xshape) 297 298 if len(x.shape) > 2: # 2D weights 299 x = rearrange(x, "y x (h w) -> y x h w", h=xshape[2]) 300 301 return x 302 303 @torch.no_grad() 304 def batch_decompress(self, p, idx, val, xshape, totalk): 305 idx = torch.concatenate(idx, dim=-1).to(device=p.device) 306 val = torch.concatenate(val, dim=-1).to(device=p.device) 307 return self.decompress(p, idx, val, xshape, totalk) 308 309 310 # Code modified and sourced from https://github.com/zh217/torch-dct 311 def _dct_fft_impl(v): 312 return torch.view_as_real(torch.fft.fft(v, dim=1)) 313 314 315 def _idct_irfft_impl(V): 316 return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 317 318 319 def _dct(x, norm=None): 320 """ 321 Discrete Cosine Transform, Type II (a.k.a. 
the DCT) 322 323 For the meaning of the parameter `norm`, see: 324 https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 325 326 :param x: the input signal 327 :param norm: the normalization, None or 'ortho' 328 :return: the DCT-II of the signal over the last dimension 329 """ 330 x_shape = x.shape 331 N = x_shape[-1] 332 x = x.contiguous().view(-1, N) 333 334 v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 335 336 Vc = _dct_fft_impl(v) 337 338 k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N) 339 W_r = torch.cos(k) 340 W_i = torch.sin(k) 341 342 V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 343 344 if norm == "ortho": 345 V[:, 0] /= math.sqrt(N) * 2 346 V[:, 1:] /= math.sqrt(N / 2) * 2 347 348 V = 2 * V.view(*x_shape) 349 350 return V 351 352 353 def _idct(X, norm=None): 354 """ 355 The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 356 357 Our definition of idct is that idct(dct(x)) == x 358 359 For the meaning of the parameter `norm`, see: 360 https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 361 362 :param X: the input signal 363 :param norm: the normalization, None or 'ortho' 364 :return: the inverse DCT-II of the signal over the last dimension 365 """ 366 367 x_shape = X.shape 368 N = x_shape[-1] 369 370 X_v = X.contiguous().view(-1, x_shape[-1]) / 2 371 372 if norm == "ortho": 373 X_v[:, 0] *= math.sqrt(N) * 2 374 X_v[:, 1:] *= math.sqrt(N / 2) * 2 375 376 k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * math.pi / (2 * N) 377 W_r = torch.cos(k) 378 W_i = torch.sin(k) 379 380 V_t_r = X_v 381 V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 382 383 V_r = V_t_r * W_r - V_t_i * W_i 384 V_i = V_t_r * W_i + V_t_i * W_r 385 386 V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 387 388 v = _idct_irfft_impl(V) 389 x = v.new_zeros(v.shape) 390 x[:, ::2] += v[:, : N - (N // 2)] 391 x[:, 1::2] += 
v.flip([1])[:, : N // 2] 392 393 return x.view(*x_shape) 394 395 396 def _get_prime_divisors(n): 397 divisors = [] 398 while n % 2 == 0: 399 divisors.append(2) 400 n //= 2 401 while n % 3 == 0: 402 divisors.append(3) 403 n //= 3 404 i = 5 405 while i * i <= n: 406 for k in (i, i + 2): 407 while n % k == 0: 408 divisors.append(k) 409 n //= k 410 i += 6 411 if n > 1: 412 divisors.append(n) 413 return divisors 414 415 416 def _get_divisors(n): 417 divisors = [] 418 if n == 1: 419 divisors.append(1) 420 elif n > 1: 421 prime_factors = _get_prime_divisors(n) 422 divisors = [1] 423 last_prime = 0 424 factor = 0 425 slice_len = 0 426 # Find all the products that are divisors of n 427 for prime in prime_factors: 428 if last_prime != prime: 429 slice_len = len(divisors) 430 factor = prime 431 else: 432 factor *= prime 433 for i in range(slice_len): 434 divisors.append(divisors[i] * factor) 435 last_prime = prime 436 divisors.sort() 437 return divisors 438 439 440 def _get_smaller_split(n, close_to): 441 all_divisors = _get_divisors(n) 442 for ix, val in enumerate(all_divisors): 443 if val == close_to: 444 return val 445 if val > close_to: 446 if ix == 0: 447 return val 448 return all_divisors[ix - 1] 449 return n