# demo.py
  1  
  2  """DeMo: Decoupled Momentum Optimization
  3  
  4  This implements the DeMo fused optimizer and data parallel algorithm.
  5  It is recommended to use DeMo as the base data parallelism.
In an existing codebase that uses PyTorch DDP, wrap your forward-backward in
  7  `torch.distributed.DistributedDataParallel.no_sync` to disable external gradient synchronization.
  8  See https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync
  9  """
 10  
 11  import math
 12  import torch
 13  import torch.fft
 14  import torch.distributed as dist
 15  
 16  from einops import rearrange
 17  from typing import Optional, Callable
 18  
class DeMo(torch.optim.SGD):
    """Decoupled Momentum (DeMo) optimizer with built-in gradient compression.

    Subclasses ``torch.optim.SGD`` but manages its own per-parameter momentum
    residual ("delta").  Each step, the delta is chunk-DCT encoded and top-k
    compressed, the sparse components are all-gathered across data-parallel
    workers, and the decoded, aggregated gradient is applied with a sign-SGD
    update.  Weight decay and momentum decay are handled here, so the base SGD
    is configured as a plain step.
    """

    def __init__(
        self,
        params,
        compression_decay: float = 0.999,
        compression_topk: int = 32,
        compression_chunk: int = 64,
        weight_decay: float = 0.0,
        process_group: Optional[dist.ProcessGroup] = None,
        **kwargs,
    ):
        """Create the optimizer.

        Args:
            params: iterable of parameters or param groups (as for SGD).
            compression_decay: multiplicative decay applied to the residual
                delta each step; must lie in [0, 1).
            compression_topk: number of frequency components kept per chunk.
            compression_chunk: target DCT chunk size along each dimension.
            weight_decay: decoupled (step-wise) weight decay factor, applied
                directly to the parameters before the gradient update.
            process_group: optional process group used for the all-gather;
                ``None`` means the default (global) group.
            **kwargs: forwarded to ``torch.optim.SGD`` (e.g. ``lr``).
        """
        # The base SGD is deliberately inert: momentum, dampening, and weight
        # decay are all implemented by DeMo itself in step().
        super().__init__(
            params,
            foreach=False,
            momentum=0.0,
            dampening=0.0,
            nesterov=False,
            maximize=False,
            weight_decay=0.0,
            **kwargs,
        )

        self.compression_decay = compression_decay
        self.compression_chunk = compression_chunk
        self.compression_topk = compression_topk
        self.process_group = process_group
        self.weight_decay = weight_decay  # applied manually, not by base SGD

        if self.compression_topk <= 0:
            raise ValueError("topk_size has to be positive")
        if self.compression_chunk <= 0:
            raise ValueError("chunk_size has to be positive")
        if self.compression_decay < 0:
            raise ValueError("Negative compression_decay is currently not supported")
        if self.compression_decay >= 1:
            raise ValueError("Values of compression_decay bigger or equal to 1.0 is currently not supported")

        # Per-parameter DeMo state (step counter and momentum residual).
        self.demo_state = {}
        self._init_demo_states()
        self._init_opt_parameters()

        self.default_dtype = self._find_dtype()
        self.transform = TransformDCT(self.param_groups, self.compression_chunk)
        self.compress = CompressDCT()

    def _find_dtype(self):
        """Return the dtype of the first trainable parameter (float32 fallback)."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    return p.dtype
        return torch.float32

    def _init_demo_states(self):
        """Allocate an empty state dict for every trainable parameter."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    self.demo_state[p] = {}

    def _state_parameter(self, p):
        """Return (creating if needed) the DeMo state dict for parameter ``p``."""
        if p not in self.demo_state:
            self.demo_state[p] = {}
        return self.demo_state[p]

    def _init_opt_parameters(self):
        """Initialize step counter and zero momentum residual per parameter."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    state = self._state_parameter(p)

                    state["step"] = 0
                    state["delta"] = torch.zeros_like(p)

    def _demo_all_gather(self, sparse_idx, sparse_val):
        """All-gather the sparse indices and values from every rank.

        Both collectives are launched asynchronously and then awaited, so the
        two transfers can overlap.  Returns (list of idx tensors, list of val
        tensors), one entry per rank.
        """
        world_size = dist.get_world_size() if self.process_group is None else self.process_group.size()

        # Gather all the idx and vals
        sparse_idx_list = [torch.zeros_like(sparse_idx) for wi in range(world_size)]
        sparse_val_list = [torch.zeros_like(sparse_val) for wi in range(world_size)]

        sparse_idx_handle = dist.all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True)
        sparse_val_handle = dist.all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True)

        sparse_idx_handle.wait()
        sparse_val_handle.wait()

        return sparse_idx_list, sparse_val_list


    @torch.no_grad()
    def step(self, closure: Callable | None = None):
        """Perform one DeMo optimization step.

        Accumulates the local gradient into the decayed momentum residual,
        transmits only its top-k DCT components (subtracting the transmitted
        part from the residual), reconstructs the gradient from all ranks'
        components, and finally applies a sign-SGD update via the base class.
        Byte counters ``data_transmit`` / ``data_receive`` are refreshed each
        call for logging.
        """

        self.data_transmit = 0
        self.data_receive = 0

        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                state = self._state_parameter(p)

                # Update step
                state["step"] += 1

                # Step-Weight decay (decoupled: scales the weights directly)
                if self.weight_decay != 0.0:
                    p.data.mul_(1.0 - lr * self.weight_decay)

                # Decay delta (compression_decay < 1 is enforced in __init__)
                if self.compression_decay != 1:
                    state["delta"].mul_(self.compression_decay)

                # Add delta to new gradient, pre-scaled by the learning rate
                state["delta"].add_(p.grad, alpha=lr)

                # Compress delta: DCT-chunk encode, then keep top-k components
                sparse_idx, sparse_val, xshape, totalk = self.compress.compress(
                    self.transform.encode(state["delta"]), self.compression_topk
                )

                # Estimate transmitted delta by decoding what was just kept
                transmit_grad = self.transform.decode(
                    self.compress.decompress(p, sparse_idx, sparse_val, xshape, totalk)
                )

                # Remove transmitted from delta (only the residual is kept locally)
                state["delta"].sub_(transmit_grad)

                # All-gather the sparse components from every rank
                sparse_idx_gather, sparse_val_gather = self._demo_all_gather(sparse_idx, sparse_val)

                # Log I/O data size
                self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes
                for si, v in zip(sparse_idx_gather, sparse_val_gather):
                    self.data_receive += si.nbytes + v.nbytes

                # Decode grad from all nodes
                new_grad = self.transform.decode(
                    self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, xshape, totalk)
                )

                # Set grad to values (p.grad is normally non-None here, since
                # it was read above; the branch guards the degenerate case)
                if p.grad is None:
                    p.grad = new_grad
                else:
                    p.grad.copy_(new_grad)

                # Sign-SGD: only the sign of the aggregated gradient is used
                p.grad.sign_()

        # SGD step (base class applies lr * sign(grad))
        return super().step(closure)
171      
class TransformDCT:
    """Chunked (inverse) DCT codec for parameter tensors.

    At construction, every distinct dimension size of the trainable parameters
    is mapped to a chunk size (the largest divisor not exceeding
    ``target_chunk``), and forward/backward DCT basis matrices are precomputed
    for each chunk size.  ``encode`` splits a tensor into chunks and applies
    the forward DCT; ``decode`` inverts it.
    """

    @torch.no_grad()
    def __init__(self, param_groups, target_chunk, norm="ortho"):
        self.target_chunk = target_chunk

        self.shape_dict = dict()  # dimension size -> chosen chunk size
        self.f_dict = dict()      # chunk size -> forward DCT basis matrix
        self.b_dict = dict()      # chunk size -> inverse DCT basis matrix

        # Walk every tensor dimension so all needed chunk sizes and their
        # DCT bases exist before encode/decode are ever called.
        for group in param_groups:
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                for dim_size in p.shape:
                    # Largest divisor of the dimension not above target_chunk
                    chunk = _get_smaller_split(dim_size, self.target_chunk)
                    self.shape_dict[dim_size] = chunk

                    if chunk in self.f_dict:
                        continue
                    # DCT of the identity yields the transform matrix itself
                    eye = torch.eye(chunk)
                    self.f_dict[chunk] = _dct(eye, norm=norm).to(p.dtype).to(p.device)
                    self.b_dict[chunk] = _idct(eye, norm=norm).to(p.dtype).to(p.device)

    @torch.no_grad()
    def einsum_2d(self, x, b, d=None):
        """Apply basis ``b`` (and optionally ``d``) along the chunk axes of x."""
        if d is None:
            return torch.einsum("...ij, jb -> ...ib", x, b)
        # Note: b-c axis output is transposed to chunk DCT in 2D
        return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d)

    @torch.no_grad()
    def einsum_2d_t(self, x, b, d=None):
        """Transposed counterpart of einsum_2d, used on the decode path."""
        if d is None:
            return torch.einsum("...ij, jb -> ...ib", x, b)
        # Note: b-c axis output is transposed to chunk DCT in 2D
        return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d)

    @torch.no_grad()
    def encode(self, x):
        """Chunk ``x`` and apply the forward DCT; returns the chunked spectrum."""
        if len(x.shape) > 1:  # 2D weights
            rows = self.shape_dict[x.shape[0]]
            cols = self.shape_dict[x.shape[1]]
            row_basis = self.f_dict[rows].to(x.device)
            col_basis = self.f_dict[cols].to(x.device)
            # Cache the device-resident copies for subsequent calls
            self.f_dict[rows] = row_basis
            self.f_dict[cols] = col_basis

            chunked = rearrange(x, "(y h) (x w) -> y h x w", h=rows, w=cols)
            return self.einsum_2d(chunked, row_basis, col_basis)

        # 1D weights
        size = self.shape_dict[x.shape[0]]
        basis = self.f_dict[size].to(x.device)
        self.f_dict[size] = basis

        chunked = rearrange(x, "(x w) -> x w", w=size)
        return self.einsum_2d(chunked, basis)

    @torch.no_grad()
    def decode(self, x):
        """Invert encode: apply the inverse DCT and stitch chunks back together."""
        if len(x.shape) > 2:  # 2D weights
            rows = x.shape[2]
            cols = x.shape[3]
            row_basis = self.b_dict[rows].to(x.device)
            col_basis = self.b_dict[cols].to(x.device)
            # Cache the device-resident copies for subsequent calls
            self.b_dict[rows] = row_basis
            self.b_dict[cols] = col_basis

            spatial = self.einsum_2d_t(x, row_basis, col_basis)
            return rearrange(spatial, "y h x w -> (y h) (x w)")

        # 1D weights
        size = x.shape[1]
        basis = self.b_dict[size].to(x.device)
        self.b_dict[size] = basis

        spatial = self.einsum_2d_t(x, basis)
        return rearrange(spatial, "x w -> (x w)")
259  
260  
class CompressDCT:
    """Top-k sparsification and reconstruction of (DCT-chunked) tensors.

    ``compress`` keeps the ``topk`` largest-magnitude entries along the last
    dimension; ``decompress`` scatters them back into a dense zero tensor.
    Duplicate indices are averaged on decompression, which makes
    ``batch_decompress`` of several workers' components act as a mean.
    """

    @torch.no_grad()
    def __init__(self):
        pass

    def _clamp_topk(self, x, topk):
        """Clamp ``topk`` into the valid range [1, x.shape[-1]]."""
        if topk > x.shape[-1]:
            topk = x.shape[-1]
        if topk < 1:
            topk = 1
        return topk

    @torch.no_grad()
    def compress(self, x, topk):
        """Keep the top-k largest-magnitude entries along the last dimension.

        Returns:
            (idx, val, xshape, totalk): gather indices, gathered values, the
            original shape, and the size of the (flattened) last dimension.
        """
        xshape = x.shape
        if len(x.shape) > 2:  # 2D weights: flatten the chunk axes
            x = rearrange(x, "y x h w -> y x (h w)")

        # Limit topk to max size
        totalk = x.shape[-1]
        topk = self._clamp_topk(x, topk)

        idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False).indices
        val = torch.gather(x, dim=-1, index=idx)

        return idx, val, xshape, totalk

    @torch.no_grad()
    def decompress(self, p, idx, val, xshape, totalk):
        """Scatter sparse (idx, val) back into a dense tensor of shape xshape.

        ``p`` only supplies the target device and dtype.  Duplicate indices
        are averaged (``include_self=False``), so untouched positions stay 0.
        """
        x = torch.zeros(xshape, device=p.device, dtype=p.dtype)

        if len(xshape) > 2:  # 2D weights: flatten the chunk axes to scatter
            x = rearrange(x, "y x h w -> y x (h w)")

        # TODO: Careful, this is nondeterministic across different CUDA devices! might cause errors to accumulate between nodes!
        # Fix: the original trailing `.reshape(xshape)` was dead code — its
        # result was discarded (scatter_reduce_ already works in place), and
        # the rearrange below performs the actual un-flattening.
        x.scatter_reduce_(dim=-1, index=idx, src=val, reduce="mean", include_self=False)

        if len(x.shape) > 2:  # 2D weights: restore the chunk axes
            x = rearrange(x, "y x (h w) -> y x h w", h=xshape[2])

        return x

    @torch.no_grad()
    def batch_decompress(self, p, idx, val, xshape, totalk):
        """Decompress a list of per-worker (idx, val) pairs in one shot."""
        idx = torch.concatenate(idx, dim=-1).to(device=p.device)
        val = torch.concatenate(val, dim=-1).to(device=p.device)
        return self.decompress(p, idx, val, xshape, totalk)
308  
309  
310  # Code modified and sourced from https://github.com/zh217/torch-dct
311  def _dct_fft_impl(v):
312      return torch.view_as_real(torch.fft.fft(v, dim=1))
313  
314  
315  def _idct_irfft_impl(V):
316      return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
317  
318  
319  def _dct(x, norm=None):
320      """
321      Discrete Cosine Transform, Type II (a.k.a. the DCT)
322  
323      For the meaning of the parameter `norm`, see:
324      https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
325  
326      :param x: the input signal
327      :param norm: the normalization, None or 'ortho'
328      :return: the DCT-II of the signal over the last dimension
329      """
330      x_shape = x.shape
331      N = x_shape[-1]
332      x = x.contiguous().view(-1, N)
333  
334      v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
335  
336      Vc = _dct_fft_impl(v)
337  
338      k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N)
339      W_r = torch.cos(k)
340      W_i = torch.sin(k)
341  
342      V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
343  
344      if norm == "ortho":
345          V[:, 0] /= math.sqrt(N) * 2
346          V[:, 1:] /= math.sqrt(N / 2) * 2
347  
348      V = 2 * V.view(*x_shape)
349  
350      return V
351  
352  
353  def _idct(X, norm=None):
354      """
355      The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
356  
357      Our definition of idct is that idct(dct(x)) == x
358  
359      For the meaning of the parameter `norm`, see:
360      https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
361  
362      :param X: the input signal
363      :param norm: the normalization, None or 'ortho'
364      :return: the inverse DCT-II of the signal over the last dimension
365      """
366  
367      x_shape = X.shape
368      N = x_shape[-1]
369  
370      X_v = X.contiguous().view(-1, x_shape[-1]) / 2
371  
372      if norm == "ortho":
373          X_v[:, 0] *= math.sqrt(N) * 2
374          X_v[:, 1:] *= math.sqrt(N / 2) * 2
375  
376      k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * math.pi / (2 * N)
377      W_r = torch.cos(k)
378      W_i = torch.sin(k)
379  
380      V_t_r = X_v
381      V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
382  
383      V_r = V_t_r * W_r - V_t_i * W_i
384      V_i = V_t_r * W_i + V_t_i * W_r
385  
386      V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
387  
388      v = _idct_irfft_impl(V)
389      x = v.new_zeros(v.shape)
390      x[:, ::2] += v[:, : N - (N // 2)]
391      x[:, 1::2] += v.flip([1])[:, : N // 2]
392  
393      return x.view(*x_shape)
394  
395  
396  def _get_prime_divisors(n):
397      divisors = []
398      while n % 2 == 0:
399          divisors.append(2)
400          n //= 2
401      while n % 3 == 0:
402          divisors.append(3)
403          n //= 3
404      i = 5
405      while i * i <= n:
406          for k in (i, i + 2):
407              while n % k == 0:
408                  divisors.append(k)
409                  n //= k
410          i += 6
411      if n > 1:
412          divisors.append(n)
413      return divisors
414  
415  
def _get_divisors(n):
    """Return all positive divisors of ``n`` in ascending order.

    Divisors are built incrementally from the prime factorization: for each
    run of a repeated prime, previously found divisors are multiplied by the
    growing prime power.  Returns [] for n < 1.
    """
    if n == 1:
        return [1]
    if n < 1:
        return []

    divisors = [1]
    prev_prime = None
    power = 1     # current prime power being applied
    span = 0      # number of divisors that predate the current prime's run
    for prime in _get_prime_divisors(n):
        if prime != prev_prime:
            # New prime: snapshot the list length and restart the power.
            span = len(divisors)
            power = prime
        else:
            power *= prime
        divisors += [divisors[i] * power for i in range(span)]
        prev_prime = prime
    divisors.sort()
    return divisors
438  
439  
def _get_smaller_split(n, close_to):
    """Return the divisor of ``n`` closest to ``close_to`` without exceeding it.

    Falls back to the smallest divisor when every divisor exceeds
    ``close_to``, and to ``n`` itself when no divisor is larger.
    """
    previous = None
    for divisor in _get_divisors(n):
        if divisor == close_to:
            return divisor
        if divisor > close_to:
            # First overshoot: the previous divisor (if any) is the answer.
            return divisor if previous is None else previous
        previous = divisor
    return n