0001-DeMo.patch
From 45d0a1c7ad6286078b2fbd274b31cfe05b990b9f Mon Sep 17 00:00:00 2001
From: redacted <redacted@redacted.com>
Date: Wed, 2 Oct 2024 01:26:11 +0000
Subject: [PATCH] DeMo

---
 .gitignore         |   2 +
 olmo/config.py     |  23 ++++
 olmo/demo_utils.py | 286 +++++++++++++++++++++++++++++++++++++++++++++
 olmo/optim.py      | 187 ++++++++++++++++++++++++++++-
 olmo/train.py      |  10 +-
 scripts/train.py   |   5 +
 6 files changed, 509 insertions(+), 4 deletions(-)
 create mode 100644 olmo/demo_utils.py

diff --git a/.gitignore b/.gitignore
index 9b1e9978..68739e81 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,3 +56,5 @@ site/
 /wandb/
 /scratch/
 core
+slurm*
+checkpoints/
\ No newline at end of file
diff --git a/olmo/config.py b/olmo/config.py
index ae454bb1..64abeb0b 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -498,6 +498,7 @@ class ModelConfig(BaseConfig):
 class OptimizerType(StrEnum):
     lionw = "lionw"
     adamw = "adamw"
+    demo = "demo"
 
 
 @dataclass
@@ -533,6 +534,20 @@ class OptimizerConfig(BaseConfig):
     of the update with AdamW.
     """
 
+    ### DeMo parameters
+    compression_decay: float = 0.999
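+    """
+    Decay factor applied to DeMo's error-feedback residual (``delta``) at each step.
+    """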
+
+    compression_topk: int = 32
+    """
+    How many top-k coefficients to transmit per chunk; if dynamic top-k is enabled, this is the initial top-k.
+    """
+
+    compression_chunk: int = 64
+    """
+    Chunk size used when chunking gradients. Note that 2D gradients are chunked in 2D, so the effective top-k sparsity is squared compared to the 1D case.
+    """
+
+
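+    # A minimal sketch of enabling DeMo from a training config, assuming the usual
+    # OLMo YAML layout; only `name` is required, the other fields show the defaults:
+    #
+    #   optimizer:
+    #     name: demo
+    #     compression_decay: 0.999
+    #     compression_topk: 32
+    #     compression_chunk: 64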
     def __post_init__(self):
         self.betas = tuple(self.betas)  # type: ignore[assignment]
 
@@ -724,6 +739,12 @@ class DDPGradSyncMode(StrEnum):
     set to True, to prevent errors.
     """
 
+    none = "none"
+    """
+    Totally disable gradient synchronization within the distributed model.
+    Should only be used with some explicit external synchronization (e.g. DeMo), or if you just like spinning your wheels.
+    """
+
 
 @dataclass
 class DDPConfig(BaseConfig):
@@ -818,6 +839,8 @@ class FSDPConfig(BaseConfig):
     PyTorch's default HSDP behavior matches this default behavior.
     """
 
+    disable_grad_sync: bool = False
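+    """
+    Skip FSDP gradient synchronization entirely by running under ``no_sync()``.
+    Only meaningful when the optimizer performs its own communication, as DeMo does.
+    """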
+
 
 class CheckpointType(StrEnum):
     sharded = "sharded"
diff --git a/olmo/demo_utils.py b/olmo/demo_utils.py
new file mode 100644
index 00000000..316586ca
--- /dev/null
+++ b/olmo/demo_utils.py
@@ -0,0 +1,286 @@
+import math
+import torch
+import torch.fft
+import torch.distributed as dist
+
+from einops import rearrange
+
+
+class TransformDCT:
+    @torch.no_grad()
+    def __init__(self, param_groups, target_chunk, norm="ortho"):
+        self.target_chunk = target_chunk
+
+        self.shape_dict = dict()
+        self.f_dict = dict()
+        self.b_dict = dict()
+
+        # Get all variants of model tensor sizes
+        # Generate all possible valid DCT sizes for model tensors
+        for group in param_groups:
+            for p in group["params"]:
+                if not p.requires_grad:
+                    continue
+                for s in p.shape:
+                    # Get the closest smallest divisor to the targeted DCT size
+                    sc = _get_smaller_split(s, self.target_chunk)
+                    self.shape_dict[s] = sc
+
+                    # Pregenerate DCT basis matrices
+                    if sc not in self.f_dict:
+                        I = torch.eye(sc)
+                        self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device)
+                        self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device)
+
+    @torch.no_grad()
+    def einsum_2d(self, x, b, d=None):
+        if d is None:
+            return torch.einsum("...ij, jb -> ...ib", x, b)
+        else:
+            # Note: b-c axis output is transposed to chunk DCT in 2D
+            return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d)
+
+    @torch.no_grad()
+    def einsum_2d_t(self, x, b, d=None):
+        if d is None:
+            return torch.einsum("...ij, jb -> ...ib", x, b)
+        else:
+            # Note: b-c axis output is transposed to chunk DCT in 2D
+            return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d)
+
+    @torch.no_grad()
+    def encode(self, x):
+        if len(x.shape) > 1:  # 2D weights
+            n1 = self.shape_dict[x.shape[0]]
+            n2 = self.shape_dict[x.shape[1]]
+            n1w = self.f_dict[n1].to(x.device)
+            n2w = self.f_dict[n2].to(x.device)
+            self.f_dict[n1] = n1w
+            self.f_dict[n2] = n2w
+
+            x = rearrange(x, "(y h) (x w) -> y h x w", h=n1, w=n2)
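+            # Illustrative shape walkthrough: a (768, 512) weight with chunk size 64
+            # becomes (12, 64, 8, 64) here, and the einsum below DCTs each 64x64
+            # chunk, yielding (12, 8, 64, 64).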
+            x = self.einsum_2d(x, n1w, n2w)
+
+        else:  # 1D weights
+            n1 = self.shape_dict[x.shape[0]]
+            n1w = self.f_dict[n1].to(x.device)
+            self.f_dict[n1] = n1w
+
+            x = rearrange(x, "(x w) -> x w", w=n1)
+            x = self.einsum_2d(x, n1w)
+
+        return x
+
+    @torch.no_grad()
+    def decode(self, x):
+        if len(x.shape) > 2:  # 2D weights
+            n1 = x.shape[2]
+            n2 = x.shape[3]
+            n1w = self.b_dict[n1].to(x.device)
+            n2w = self.b_dict[n2].to(x.device)
+            self.b_dict[n1] = n1w
+            self.b_dict[n2] = n2w
+
+            x = self.einsum_2d_t(x, n1w, n2w)
+            x = rearrange(x, "y h x w -> (y h) (x w)")
+
+        else:  # 1D weights
+            n1 = x.shape[1]
+            n1w = self.b_dict[n1].to(x.device)
+            self.b_dict[n1] = n1w
+
+            x = self.einsum_2d_t(x, n1w)
+            x = rearrange(x, "x w -> (x w)")
+
+        return x
+
+
+class CompressDCT:
+    @torch.no_grad()
+    def __init__(self):
+        pass
+
+    def _clamp_topk(self, x, topk):
+        if topk > x.shape[-1]:
+            topk = x.shape[-1]
+        if topk < 1:
+            topk = 1
+        return topk
+
+    @torch.no_grad()
+    def compress(self, x, topk):
+        xshape = x.shape
+        if len(x.shape) > 2:  # 2D weights
+            x = rearrange(x, "y x h w -> y x (h w)")
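+            # e.g. (12, 8, 64, 64) -> (12, 8, 4096): top-k is then taken over each
+            # flattened 64x64 chunk, so 32 of 4096 coefficients survive per chunk
+            # (vs 32 of 64 for a 1D chunk). Illustrative numbers only.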
+
+        # Limit topk to max size
+        totalk = x.shape[-1]
+        topk = self._clamp_topk(x, topk)
+
+        idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False).indices
+        val = torch.gather(x, dim=-1, index=idx)
+
+        return idx, val, xshape, totalk
+
+    @torch.no_grad()
+    def decompress(self, p, idx, val, xshape, totalk):
+        x = torch.zeros(xshape, device=p.device, dtype=p.dtype)
+
+        if len(xshape) > 2:  # 2D weights
+            x = rearrange(x, "y x h w -> y x (h w)")
+
+        # TODO: Careful, this is nondeterministic across different CUDA devices! might cause errors to accumulate between nodes!
+        x.scatter_reduce_(dim=-1, index=idx, src=val, reduce="mean", include_self=False)
+
+        if len(x.shape) > 2:  # 2D weights
+            x = rearrange(x, "y x (h w) -> y x h w", h=xshape[2])
+
+        return x
+
+    @torch.no_grad()
+    def batch_decompress(self, p, idx, val, xshape, totalk):
+        idx = torch.concatenate(idx, dim=-1).to(device=p.device)
+        val = torch.concatenate(val, dim=-1).to(device=p.device)
+        return self.decompress(p, idx, val, xshape, totalk)
+
+
+# Code modified and sourced from https://github.com/zh217/torch-dct
+def _dct_fft_impl(v):
+    return torch.view_as_real(torch.fft.fft(v, dim=1))
+
+
+def _idct_irfft_impl(V):
+    return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
+
+
+def _dct(x, norm=None):
+    """
+    Discrete Cosine Transform, Type II (a.k.a. the DCT)
+
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+    :param x: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the DCT-II of the signal over the last dimension
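+
+    With ``norm=None`` this matches SciPy's unnormalized DCT-II convention:
+
+        y[k] = 2 * sum_{n=0..N-1} x[n] * cos(pi * k * (2*n + 1) / (2*N))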
+    """
+    x_shape = x.shape
+    N = x_shape[-1]
+    x = x.contiguous().view(-1, N)
+
+    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+
+    Vc = _dct_fft_impl(v)
+
+    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
+
+    if norm == "ortho":
+        V[:, 0] /= math.sqrt(N) * 2
+        V[:, 1:] /= math.sqrt(N / 2) * 2
+
+    V = 2 * V.view(*x_shape)
+
+    return V
+
+
+def _idct(X, norm=None):
+    """
+    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
+
+    Our definition of idct is that idct(dct(x)) == x
+
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+    :param X: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the inverse DCT-II of the signal over the last dimension
+    """
+
+    x_shape = X.shape
+    N = x_shape[-1]
+
+    X_v = X.contiguous().view(-1, x_shape[-1]) / 2
+
+    if norm == "ortho":
+        X_v[:, 0] *= math.sqrt(N) * 2
+        X_v[:, 1:] *= math.sqrt(N / 2) * 2
+
+    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * math.pi / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V_t_r = X_v
+    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
+
+    V_r = V_t_r * W_r - V_t_i * W_i
+    V_i = V_t_r * W_i + V_t_i * W_r
+
+    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
+
+    v = _idct_irfft_impl(V)
+    x = v.new_zeros(v.shape)
+    x[:, ::2] += v[:, : N - (N // 2)]
+    x[:, 1::2] += v.flip([1])[:, : N // 2]
+
+    return x.view(*x_shape)
+
+
+def _get_prime_divisors(n):
+    divisors = []
+    while n % 2 == 0:
+        divisors.append(2)
+        n //= 2
+    while n % 3 == 0:
+        divisors.append(3)
+        n //= 3
+    i = 5
+    while i * i <= n:
+        for k in (i, i + 2):
+            while n % k == 0:
+                divisors.append(k)
+                n //= k
+        i += 6
+    if n > 1:
+        divisors.append(n)
+    return divisors
+
+
+def _get_divisors(n):
+    divisors = []
+    if n == 1:
+        divisors.append(1)
+    elif n > 1:
+        prime_factors = _get_prime_divisors(n)
+        divisors = [1]
+        last_prime = 0
+        factor = 0
+        slice_len = 0
+        # Find all the products that are divisors of n
+        for prime in prime_factors:
+            if last_prime != prime:
+                slice_len = len(divisors)
+                factor = prime
+            else:
+                factor *= prime
+            for i in range(slice_len):
+                divisors.append(divisors[i] * factor)
+            last_prime = prime
+        divisors.sort()
+    return divisors
+
+
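+# Worked examples: _get_smaller_split(768, 64) == 64 since 64 divides 768, while
+# _get_smaller_split(100, 64) == 50, the largest divisor of 100 not exceeding 64.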
+def _get_smaller_split(n, close_to):
+    all_divisors = _get_divisors(n)
+    for ix, val in enumerate(all_divisors):
+        if val == close_to:
+            return val
+        if val > close_to:
+            if ix == 0:
+                return val
+            return all_divisors[ix - 1]
+    return n
diff --git a/olmo/optim.py b/olmo/optim.py
index 5460ccee..bbbb102a 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -1,8 +1,9 @@
+import math
 import logging
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass, replace
 from math import cos, pi, sqrt
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
 
 import torch
 import torch.distributed as dist
@@ -14,11 +15,13 @@ from torch.optim.optimizer import Optimizer as OptimizerBase
 from . import LayerNormBase
 from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
 from .torch_util import get_default_device, is_distributed
+from .demo_utils import TransformDCT, CompressDCT
 
 __all__ = [
     "Optimizer",
     "LionW",
     "AdamW",
+    "DeMo",
     "Scheduler",
     "CosWithWarmup",
     "LinearWithWarmup",
@@ -647,6 +650,177 @@ class AdamW(torch.optim.AdamW, Optimizer):
             return metrics
 
 
+class DeMo(torch.optim.SGD, Optimizer):
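+    """
+    Distributed optimizer that avoids full gradient synchronization: each step folds
+    the local gradient into an error-feedback residual (``delta``), all-gathers only
+    the top-k DCT coefficients of that residual, and applies a sign-SGD update.
+    """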
+    def __init__(
+        self,
+        params,
+        compression_decay: float = 0.999,
+        compression_topk: int = 32,
+        compression_chunk: int = 64,
+        weight_decay: float = 0.0,
+        process_group: Optional[dist.ProcessGroup] = None,
+        record_update_metrics: bool = False,
+        selective_updates: bool = False,
+        **kwargs,
+    ):
+        super().__init__(
+            params,
+            foreach=False,
+            momentum=0.0,
+            dampening=0.0,
+            nesterov=False,
+            maximize=False,
+            weight_decay=0.0,
+            **kwargs,
+        )
+
+        # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
+        # won't be called.
+        self._record_update_metrics = record_update_metrics
+        self._collecting_metrics = False
+        self._selective_updates = selective_updates
+
+        self.compression_decay = compression_decay
+        self.compression_chunk = compression_chunk
+        self.compression_topk = compression_topk
+        self.process_group = process_group
+        self.weight_decay = weight_decay
+
+        if self.compression_topk <= 0:
+            raise ValueError("compression_topk has to be positive")
+        if self.compression_chunk <= 0:
+            raise ValueError("compression_chunk has to be positive")
+        if self.compression_decay < 0:
+            raise ValueError("Negative compression_decay is currently not supported")
+        if self.compression_decay >= 1:
+            raise ValueError("compression_decay values greater than or equal to 1.0 are currently not supported")
+
+        self.demo_state = {}
+        self._init_demo_states()
+        self._init_opt_parameters()
+
+        self.default_dtype = self._find_dtype()
+        self.transform = TransformDCT(self.param_groups, self.compression_chunk)
+        self.compress = CompressDCT()
+
+    def _find_dtype(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.requires_grad:
+                    return p.dtype
+        return torch.float32
+
+    def _init_demo_states(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.requires_grad:
+                    self.demo_state[p] = {}
+
+    def _state_parameter(self, p):
+        if p not in self.demo_state:
+            self.demo_state[p] = {}
+        return self.demo_state[p]
+
+    def _init_opt_parameters(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.requires_grad:
+                    state = self._state_parameter(p)
+
+                    state["step"] = 0
+                    state["delta"] = torch.zeros_like(p)
+
+    def _demo_all_gather(self, sparse_idx, sparse_val):
+        world_size = dist.get_world_size() if self.process_group is None else self.process_group.size()
+
+        # Gather all the idx and vals
+        sparse_idx_list = [torch.zeros_like(sparse_idx) for wi in range(world_size)]
+        sparse_val_list = [torch.zeros_like(sparse_val) for wi in range(world_size)]
+
+        sparse_idx_handle = dist.all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True)
+        sparse_val_handle = dist.all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True)
+
+        sparse_idx_handle.wait()
+        sparse_val_handle.wait()
+
+        return sparse_idx_list, sparse_val_list
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+
+        self.data_transmit = 0
+        self.data_receive = 0
+
+        for group in self.param_groups:
+            lr = group["lr"]
+            for p in group["params"]:
+                if not p.requires_grad:
+                    continue
+                state = self._state_parameter(p)
+
+                # Update step
+                state["step"] += 1
+
+                # Step-Weight decay
+                if self.weight_decay != 0.0:
+                    p.data.mul_(1.0 - lr * self.weight_decay)
+
+                # Decay delta
+                if self.compression_decay != 1:
+                    state["delta"].mul_(self.compression_decay)
+
+                # Add delta to new gradient
+                state["delta"].add_(p.grad, alpha=lr)
+
+                # Compress delta
+                sparse_idx, sparse_val, xshape, totalk = self.compress.compress(
+                    self.transform.encode(state["delta"]), self.compression_topk
+                )
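+                # Illustrative arithmetic: a 768x512 weight with chunk 64 yields 96
+                # chunks of 4096 coefficients; with topk=32 only 96 * 32 = 3072 of the
+                # 393,216 coefficients (plus their indices) are gathered per step.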
+
+                # Estimate transmitted delta
+                transmit_grad = self.transform.decode(
+                    self.compress.decompress(p, sparse_idx, sparse_val, xshape, totalk)
+                )
+
+                # Remove transmitted from delta
+                state["delta"].sub_(transmit_grad)
+
+                # All-gather
+                sparse_idx_gather, sparse_val_gather = self._demo_all_gather(sparse_idx, sparse_val)
+
+                # Log I/O data size
+                self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes
+                for si, v in zip(sparse_idx_gather, sparse_val_gather):
+                    self.data_receive += si.nbytes + v.nbytes
+
+                # Decode grad from all nodes
+                new_grad = self.transform.decode(
+                    self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, xshape, totalk)
+                )
+
+                # Set grad to values
+                if p.grad is None:
+                    p.grad = new_grad
+                else:
+                    p.grad.copy_(new_grad)
+
+                # Sign-SGD
+                p.grad.sign_()
+
+        # SGD step
+        return super().step(closure)
+
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
+        return {
+            "data_receive": torch.tensor(self.data_receive, device=get_default_device()),
+            "data_transmit": torch.tensor(self.data_transmit, device=get_default_device()),
+        }
+
+
 @dataclass
 class Scheduler(metaclass=ABCMeta):
     # NOTE: these fields are not given default values because otherwise dataclasses complains
@@ -950,6 +1124,17 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
             selective_updates=cfg.optimizer.selective_updates,
             eps=cfg.optimizer.eps,
         )
+    elif cfg.optimizer.name == OptimizerType.demo:
+        return DeMo(
+            param_groups,
+            compression_decay=cfg.optimizer.compression_decay,
+            compression_topk=cfg.optimizer.compression_topk,
+            compression_chunk=cfg.optimizer.compression_chunk,
+            weight_decay=cfg.optimizer.weight_decay,
+            process_group=None,  # TODO: fix for hybrid sharding
+            record_update_metrics=cfg.optimizer.record_update_metrics,
+            selective_updates=cfg.optimizer.selective_updates,
+        )
     else:
         raise NotImplementedError
 
diff --git a/olmo/train.py b/olmo/train.py
index 34105500..77e758b9 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -35,6 +35,7 @@ from .config import (
     CheckpointType,
     DDPGradSyncMode,
     DistributedStrategy,
+    OptimizerType,
     SchedulerUnits,
     ShardedCheckpointerType,
     SpeedMonitorConfig,
@@ -44,7 +45,7 @@ from .data import IterableDataset
 from .eval import Evaluator
 from .exceptions import OLMoConfigurationError
 from .model import OLMo
-from .optim import Optimizer, Scheduler
+from .optim import DeMo, Optimizer, Scheduler
 from .torch_util import (
     barrier,
     gc_cuda,
@@ -785,10 +786,13 @@ class Trainer:
             if (
                 self.cfg.distributed_strategy == DistributedStrategy.ddp
                 and self.cfg.ddp is not None
-                and self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch
+                and self.cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch
             ):
-                if micro_batch_idx != num_micro_batches - 1:
+                if (
+                    self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch and micro_batch_idx != num_micro_batches - 1
+                ) or self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.none:
                     grad_sync_context = self.dist_model.no_sync
+            elif (
+                self.cfg.distributed_strategy == DistributedStrategy.fsdp
+                and self.cfg.fsdp is not None
+                and self.cfg.fsdp.disable_grad_sync
+            ):
+                grad_sync_context = self.dist_model.no_sync
 
             # Register output hooks
             output_hooks: List[torch.utils.hooks.RemovableHandle] = []
diff --git a/scripts/train.py b/scripts/train.py
index 1f735309..d20d0092 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -20,6 +20,7 @@ from olmo.config import (
     CheckpointType,
     DDPGradSyncMode,
     DistributedStrategy,
+    OptimizerType,
     TrainConfig,
 )
 from olmo.data import build_train_dataloader
@@ -138,6 +139,8 @@ def main(cfg: TrainConfig) -> None:
     if cfg.distributed_strategy == DistributedStrategy.ddp:
         log.info("Wrapping model with DDP...")
         assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!"
+        if cfg.optimizer.name == OptimizerType.demo and cfg.ddp.grad_sync_mode != DDPGradSyncMode.none:
+            raise OLMoConfigurationError("DeMo requires that `ddp.grad_sync_mode` be set to `none`.")
 
         if cfg.model.init_device != "cuda":
             raise OLMoConfigurationError("DDP does not work with init_device set to anything other than `cuda`.")
@@ -155,6 +158,8 @@ def main(cfg: TrainConfig) -> None:
         # Wrap the model in FSDP.
         log.info("Wrapping model with FSDP...")
         assert cfg.fsdp is not None, "DistributedStrategy fsdp needs cfg.fsdp to be set!"
+        if cfg.optimizer.name == OptimizerType.demo and not cfg.fsdp.disable_grad_sync:
+            raise OLMoConfigurationError("DeMo requires that `fsdp.disable_grad_sync` be set to `true`.")
         wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
 
         if version.parse(torch.__version__) >= version.parse("2.1.0"):
-- 
2.34.1
