0001-DeMo.patch
From 45d0a1c7ad6286078b2fbd274b31cfe05b990b9f Mon Sep 17 00:00:00 2001
From: redacted <redacted@redacted.com>
Date: Wed, 2 Oct 2024 01:26:11 +0000
Subject: [PATCH] DeMo

---
 .gitignore         |   2 +
 olmo/config.py     |  23 ++++
 olmo/demo_utils.py | 286 +++++++++++++++++++++++++++++++++++++++++++++
 olmo/optim.py      | 187 ++++++++++++++++++++++++++++-
 olmo/train.py      |  10 +-
 scripts/train.py   |   5 +
 6 files changed, 509 insertions(+), 4 deletions(-)
 create mode 100644 olmo/demo_utils.py

diff --git a/.gitignore b/.gitignore
index 9b1e9978..68739e81 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,3 +56,5 @@ site/
 /wandb/
 /scratch/
 core
+slurm*
+checkpoints/
\ No newline at end of file
diff --git a/olmo/config.py b/olmo/config.py
index ae454bb1..64abeb0b 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -498,6 +498,7 @@ class ModelConfig(BaseConfig):
 class OptimizerType(StrEnum):
     lionw = "lionw"
     adamw = "adamw"
+    demo = "demo"
 
 
 @dataclass
@@ -533,6 +534,20 @@ class OptimizerConfig(BaseConfig):
     of the update with AdamW.
     """
 
+    ### DeMo parameters
+    compression_decay: float = 0.999
+
+    compression_topk: int = 32
+    """
+    How many top-k values to transmit per chunk; if dynamic top-k is enabled, this is the initial value.
+    """
+
+    compression_chunk: int = 64
+    """
+    Chunk size of the gradients. Note that 2D gradients are chunked in 2D, so the top-k sparsity is squared compared to 1D.
+    """
+
+
     def __post_init__(self):
         self.betas = tuple(self.betas)  # type: ignore[assignment]
 
@@ -724,6 +739,12 @@ class DDPGradSyncMode(StrEnum):
     set to True, to prevent errors.
     """
 
+    none = "none"
+    """
+    Completely disable gradient synchronization within the distributed model.
+    Should only be used with some explicit external synchronization (e.g. DeMo), or if you just like spinning your wheels.
+    """
+
 
 @dataclass
 class DDPConfig(BaseConfig):
@@ -818,6 +839,8 @@ class FSDPConfig(BaseConfig):
     PyTorch's default HSDP behavior matches this default behavior.
     """
 
+    disable_grad_sync: bool = False
+
 
 class CheckpointType(StrEnum):
     sharded = "sharded"
diff --git a/olmo/demo_utils.py b/olmo/demo_utils.py
new file mode 100644
index 00000000..316586ca
--- /dev/null
+++ b/olmo/demo_utils.py
@@ -0,0 +1,286 @@
+import math
+import torch
+import torch.fft
+import torch.distributed as dist
+
+from einops import rearrange
+
+
+class TransformDCT:
+    @torch.no_grad()
+    def __init__(self, param_groups, target_chunk, norm="ortho"):
+        self.target_chunk = target_chunk
+
+        self.shape_dict = dict()
+        self.f_dict = dict()
+        self.b_dict = dict()
+
+        # Get all variants of model tensor sizes
+        # Generate all possible valid DCT sizes for model tensors
+        for group in param_groups:
+            for p in group["params"]:
+                if not p.requires_grad:
+                    continue
+                for s in p.shape:
+                    # Get the closest smallest divisor to the targeted DCT size
+                    sc = _get_smaller_split(s, self.target_chunk)
+                    self.shape_dict[s] = sc
+
+                    # Pregenerate DCT basis matrices
+                    if sc not in self.f_dict:
+                        I = torch.eye(sc)
+                        self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device)
+                        self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device)
+
+    @torch.no_grad()
+    def einsum_2d(self, x, b, d=None):
+        if d is None:
+            return torch.einsum("...ij, jb -> ...ib", x, b)
+        else:
+            # Note: b-c axis output is transposed to chunk DCT in 2D
+            return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d)
+
+    @torch.no_grad()
+    def einsum_2d_t(self, x, b, d=None):
+        if d is None:
+            return torch.einsum("...ij, jb -> ...ib", x, b)
+        else:
+            # Note: b-c axis output is transposed to chunk DCT in 2D
+            return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d)
+
+    @torch.no_grad()
+    def encode(self, x):
+        if len(x.shape) > 1:  # 2D weights
+            n1 = self.shape_dict[x.shape[0]]
+            n2 = self.shape_dict[x.shape[1]]
+            n1w = self.f_dict[n1].to(x.device)
+            n2w = self.f_dict[n2].to(x.device)
+            self.f_dict[n1] = n1w
+            self.f_dict[n2] = n2w
+
+            x = rearrange(x, "(y h) (x w) -> y h x w", h=n1, w=n2)
+            x = self.einsum_2d(x, n1w, n2w)
+
+        else:  # 1D weights
+            n1 = self.shape_dict[x.shape[0]]
+            n1w = self.f_dict[n1].to(x.device)
+            self.f_dict[n1] = n1w
+
+            x = rearrange(x, "(x w) -> x w", w=n1)
+            x = self.einsum_2d(x, n1w)
+
+        return x
+
+    @torch.no_grad()
+    def decode(self, x):
+        if len(x.shape) > 2:  # 2D weights
+            n1 = x.shape[2]
+            n2 = x.shape[3]
+            n1w = self.b_dict[n1].to(x.device)
+            n2w = self.b_dict[n2].to(x.device)
+            self.b_dict[n1] = n1w
+            self.b_dict[n2] = n2w
+
+            x = self.einsum_2d_t(x, n1w, n2w)
+            x = rearrange(x, "y h x w -> (y h) (x w)")
+
+        else:  # 1D weights
+            n1 = x.shape[1]
+            n1w = self.b_dict[n1].to(x.device)
+            self.b_dict[n1] = n1w
+
+            x = self.einsum_2d_t(x, n1w)
+            x = rearrange(x, "x w -> (x w)")
+
+        return x
+
+
+class CompressDCT:
+    @torch.no_grad()
+    def __init__(self):
+        pass
+
+    def _clamp_topk(self, x, topk):
+        if topk > x.shape[-1]:
+            topk = x.shape[-1]
+        if topk < 1:
+            topk = 1
+        return topk
+
+    @torch.no_grad()
+    def compress(self, x, topk):
+        xshape = x.shape
+        if len(x.shape) > 2:  # 2D weights
+            x = rearrange(x, "y x h w -> y x (h w)")
+
+        # Limit topk to max size
+        totalk = x.shape[-1]
+        topk = self._clamp_topk(x, topk)
+
+        idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False).indices
+        val = torch.gather(x, dim=-1, index=idx)
+
+        return idx, val, xshape, totalk
+
+    @torch.no_grad()
+    def decompress(self, p, idx, val, xshape, totalk):
+        x = torch.zeros(xshape, device=p.device, dtype=p.dtype)
+
+        if len(xshape) > 2:  # 2D weights
+            x = rearrange(x, "y x h w -> y x (h w)")
+
+        # TODO: Careful, this is nondeterministic across different CUDA devices! might cause errors to accumulate between nodes!
+        x.scatter_reduce_(dim=-1, index=idx, src=val, reduce="mean", include_self=False).reshape(xshape)
+
+        if len(x.shape) > 2:  # 2D weights
+            x = rearrange(x, "y x (h w) -> y x h w", h=xshape[2])
+
+        return x
+
+    @torch.no_grad()
+    def batch_decompress(self, p, idx, val, xshape, totalk):
+        idx = torch.concatenate(idx, dim=-1).to(device=p.device)
+        val = torch.concatenate(val, dim=-1).to(device=p.device)
+        return self.decompress(p, idx, val, xshape, totalk)
+
+
+# Code modified and sourced from https://github.com/zh217/torch-dct
+def _dct_fft_impl(v):
+    return torch.view_as_real(torch.fft.fft(v, dim=1))
+
+
+def _idct_irfft_impl(V):
+    return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
+
+
+def _dct(x, norm=None):
+    """
+    Discrete Cosine Transform, Type II (a.k.a. the DCT)
+
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+    :param x: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the DCT-II of the signal over the last dimension
+    """
+    x_shape = x.shape
+    N = x_shape[-1]
+    x = x.contiguous().view(-1, N)
+
+    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+
+    Vc = _dct_fft_impl(v)
+
+    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
+
+    if norm == "ortho":
+        V[:, 0] /= math.sqrt(N) * 2
+        V[:, 1:] /= math.sqrt(N / 2) * 2
+
+    V = 2 * V.view(*x_shape)
+
+    return V
+
+
+def _idct(X, norm=None):
+    """
+    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
+
+    Our definition of idct is that idct(dct(x)) == x
+
+    For the meaning of the parameter `norm`, see:
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+    :param X: the input signal
+    :param norm: the normalization, None or 'ortho'
+    :return: the inverse DCT-II of the signal over the last dimension
+    """
+
+    x_shape = X.shape
+    N = x_shape[-1]
+
+    X_v = X.contiguous().view(-1, x_shape[-1]) / 2
+
+    if norm == "ortho":
+        X_v[:, 0] *= math.sqrt(N) * 2
+        X_v[:, 1:] *= math.sqrt(N / 2) * 2
+
+    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * math.pi / (2 * N)
+    W_r = torch.cos(k)
+    W_i = torch.sin(k)
+
+    V_t_r = X_v
+    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
+
+    V_r = V_t_r * W_r - V_t_i * W_i
+    V_i = V_t_r * W_i + V_t_i * W_r
+
+    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
+
+    v = _idct_irfft_impl(V)
+    x = v.new_zeros(v.shape)
+    x[:, ::2] += v[:, : N - (N // 2)]
+    x[:, 1::2] += v.flip([1])[:, : N // 2]
+
+    return x.view(*x_shape)
+
+
+def _get_prime_divisors(n):
+    divisors = []
+    while n % 2 == 0:
+        divisors.append(2)
+        n //= 2
+    while n % 3 == 0:
+        divisors.append(3)
+        n //= 3
+    i = 5
+    while i * i <= n:
+        for k in (i, i + 2):
+            while n % k == 0:
+                divisors.append(k)
+                n //= k
+        i += 6
+    if n > 1:
+        divisors.append(n)
+    return divisors
+
+
+def _get_divisors(n):
+    divisors = []
+    if n == 1:
+        divisors.append(1)
+    elif n > 1:
+        prime_factors = _get_prime_divisors(n)
+        divisors = [1]
+        last_prime = 0
+        factor = 0
+        slice_len = 0
+        # Find all the products that are divisors of n
+        for prime in prime_factors:
+            if last_prime != prime:
+                slice_len = len(divisors)
+                factor = prime
+            else:
+                factor *= prime
+            for i in range(slice_len):
+                divisors.append(divisors[i] * factor)
+            last_prime = prime
+        divisors.sort()
+    return divisors
+
+
+def _get_smaller_split(n, close_to):
+    all_divisors = _get_divisors(n)
+    for ix, val in enumerate(all_divisors):
+        if val == close_to:
+            return val
+        if val > close_to:
+            if ix == 0:
+                return val
+            return all_divisors[ix - 1]
+    return n
diff --git a/olmo/optim.py b/olmo/optim.py
index 5460ccee..bbbb102a 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -1,8 +1,9 @@
+import math
 import logging
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass, replace
 from math import cos, pi, sqrt
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
 
 import torch
 import torch.distributed as dist
@@ -14,11 +15,13 @@ from torch.optim.optimizer import Optimizer as OptimizerBase
 from . import LayerNormBase
 from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
 from .torch_util import get_default_device, is_distributed
+from .demo_utils import TransformDCT, CompressDCT
 
 __all__ = [
     "Optimizer",
     "LionW",
     "AdamW",
+    "DeMo",
     "Scheduler",
     "CosWithWarmup",
     "LinearWithWarmup",
@@ -647,6 +650,177 @@ class AdamW(torch.optim.AdamW, Optimizer):
         return metrics
 
 
+class DeMo(torch.optim.SGD, Optimizer):
+    def __init__(
+        self,
+        params,
+        compression_decay: float = 0.999,
+        compression_topk: int = 32,
+        compression_chunk: int = 64,
+        weight_decay: float = 0.0,
+        process_group: Optional[dist.ProcessGroup] = None,
+        record_update_metrics: bool = False,
+        selective_updates: bool = False,
+        **kwargs,
+    ):
+        super().__init__(
+            params,
+            foreach=False,
+            momentum=0.0,
+            dampening=0.0,
+            nesterov=False,
+            maximize=False,
+            weight_decay=0.0,
+            **kwargs,
+        )
+
+        # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
+        # won't be called.
+        self._record_update_metrics = record_update_metrics
+        self._collecting_metrics = False
+        self._selective_updates = selective_updates
+
+        self.compression_decay = compression_decay
+        self.compression_chunk = compression_chunk
+        self.compression_topk = compression_topk
+        self.process_group = process_group
+        self.weight_decay = weight_decay
+
+        if self.compression_topk <= 0:
+            raise ValueError("compression_topk has to be positive")
+        if self.compression_chunk <= 0:
+            raise ValueError("compression_chunk has to be positive")
+        if self.compression_decay < 0:
+            raise ValueError("Negative compression_decay is currently not supported")
+        if self.compression_decay >= 1:
+            raise ValueError("Values of compression_decay greater than or equal to 1.0 are currently not supported")
+
+        self.demo_state = {}
+        self._init_demo_states()
+        self._init_opt_parameters()
+
+        self.default_dtype = self._find_dtype()
+        self.transform = TransformDCT(self.param_groups, self.compression_chunk)
+        self.compress = CompressDCT()
+
+    def _find_dtype(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.requires_grad:
+                    return p.dtype
+        return torch.float32
+
+    def _init_demo_states(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.requires_grad:
+                    self.demo_state[p] = {}
+
+    def _state_parameter(self, p):
+        if p not in self.demo_state:
+            self.demo_state[p] = {}
+        return self.demo_state[p]
+
+    def _init_opt_parameters(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.requires_grad:
+                    state = self._state_parameter(p)
+
+                    state["step"] = 0
+                    state["delta"] = torch.zeros_like(p)
+
+    def _demo_all_gather(self, sparse_idx, sparse_val):
+        world_size = dist.get_world_size() if self.process_group is None else self.process_group.size()
+
+        # Gather all the idx and vals
+        sparse_idx_list = [torch.zeros_like(sparse_idx) for wi in range(world_size)]
+        sparse_val_list = [torch.zeros_like(sparse_val) for wi in range(world_size)]
+
+        sparse_idx_handle = dist.all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True)
+        sparse_val_handle = dist.all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True)
+
+        sparse_idx_handle.wait()
+        sparse_val_handle.wait()
+
+        return sparse_idx_list, sparse_val_list
+
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+
+        self.data_transmit = 0
+        self.data_receive = 0
+
+        for group in self.param_groups:
+            lr = group["lr"]
+            for p in group["params"]:
+                if not p.requires_grad:
+                    continue
+                state = self._state_parameter(p)
+
+                # Update step
+                state["step"] += 1
+
+                # Step-Weight decay
+                if self.weight_decay != 0.0:
+                    p.data.mul_(1.0 - lr * self.weight_decay)
+
+                # Decay delta
+                if self.compression_decay != 1:
+                    state["delta"].mul_(self.compression_decay)
+
+                # Add delta to new gradient
+                state["delta"].add_(p.grad, alpha=lr)
+
+                # Compress delta
+                sparse_idx, sparse_val, xshape, totalk = self.compress.compress(
+                    self.transform.encode(state["delta"]), self.compression_topk
+                )
+
+                # Estimate transmitted delta
+                transmit_grad = self.transform.decode(
+                    self.compress.decompress(p, sparse_idx, sparse_val, xshape, totalk)
+                )
+
+                # Remove transmitted from delta
+                state["delta"].sub_(transmit_grad)
+
+                # All-gather
+                sparse_idx_gather, sparse_val_gather = self._demo_all_gather(sparse_idx, sparse_val)
+
+                # Log I/O data size
+                self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes
+                for si, v in zip(sparse_idx_gather, sparse_val_gather):
+                    self.data_receive += si.nbytes + v.nbytes
+
+                # Decode grad from all nodes
+                new_grad = self.transform.decode(
+                    self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, xshape, totalk)
+                )
+
+                # Set grad to values
+                if p.grad is None:
+                    p.grad = new_grad
+                else:
+                    p.grad.copy_(new_grad)
+
+                # Sign-SGD
+                p.grad.sign_()
+
+        # SGD step
+        return super().step(closure)
+
+
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
+        return {
+            "data_receive": torch.tensor(self.data_receive, device=get_default_device()),
+            "data_transmit": torch.tensor(self.data_transmit, device=get_default_device()),
+        }
+
+
 @dataclass
 class Scheduler(metaclass=ABCMeta):
     # NOTE: these fields are not given default values because otherwise dataclasses complains
@@ -950,6 +1124,17 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
             selective_updates=cfg.optimizer.selective_updates,
             eps=cfg.optimizer.eps,
         )
+    elif cfg.optimizer.name == OptimizerType.demo:
+        return DeMo(
+            param_groups,
+            compression_decay=cfg.optimizer.compression_decay,
+            compression_topk=cfg.optimizer.compression_topk,
+            compression_chunk=cfg.optimizer.compression_chunk,
+            weight_decay=cfg.optimizer.weight_decay,
+            process_group=None,  # TODO: fix for hybrid sharding
+            record_update_metrics=cfg.optimizer.record_update_metrics,
+            selective_updates=cfg.optimizer.selective_updates,
+        )
     else:
         raise NotImplementedError
 
diff --git a/olmo/train.py b/olmo/train.py
index 34105500..77e758b9 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -35,6 +35,7 @@ from .config import (
     CheckpointType,
     DDPGradSyncMode,
     DistributedStrategy,
+    OptimizerType,
     SchedulerUnits,
     ShardedCheckpointerType,
     SpeedMonitorConfig,
@@ -44,7 +45,7 @@ from .data import IterableDataset
 from .eval import Evaluator
 from .exceptions import OLMoConfigurationError
 from .model import OLMo
-from .optim import Optimizer, Scheduler
+from .optim import DeMo, Optimizer, Scheduler
 from .torch_util import (
     barrier,
     gc_cuda,
@@ -785,10 +786,13 @@ class Trainer:
             if (
                 self.cfg.distributed_strategy == DistributedStrategy.ddp
                 and self.cfg.ddp is not None
-                and self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch
+                and self.cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch
             ):
-                if micro_batch_idx != num_micro_batches - 1:
+                if (self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch and micro_batch_idx != num_micro_batches - 1) \
+                    or self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.none:
                     grad_sync_context = self.dist_model.no_sync
+            elif self.cfg.distributed_strategy == DistributedStrategy.fsdp and self.cfg.fsdp is not None and self.cfg.fsdp.disable_grad_sync:
+                grad_sync_context = self.dist_model.no_sync
 
             # Register output hooks
             output_hooks: List[torch.utils.hooks.RemovableHandle] = []
diff --git a/scripts/train.py b/scripts/train.py
index 1f735309..d20d0092 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -20,6 +20,7 @@ from olmo.config import (
     CheckpointType,
     DDPGradSyncMode,
     DistributedStrategy,
+    OptimizerType,
     TrainConfig,
 )
 from olmo.data import build_train_dataloader
@@ -138,6 +139,8 @@ def main(cfg: TrainConfig) -> None:
     if cfg.distributed_strategy == DistributedStrategy.ddp:
         log.info("Wrapping model with DDP...")
         assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!"
+        if cfg.optimizer.name == OptimizerType.demo and cfg.ddp.grad_sync_mode != DDPGradSyncMode.none:
+            raise OLMoConfigurationError("DeMo requires that `ddp.grad_sync_mode` be set to `none`.")
 
         if cfg.model.init_device != "cuda":
             raise OLMoConfigurationError("DDP does not work with init_device set to anything other than `cuda`.")
@@ -155,6 +158,8 @@ def main(cfg: TrainConfig) -> None:
         # Wrap the model in FSDP.
         log.info("Wrapping model with FSDP...")
         assert cfg.fsdp is not None, "DistributedStrategy fsdp needs cfg.fsdp to be set!"
+        if cfg.optimizer.name == OptimizerType.demo and not cfg.fsdp.disable_grad_sync:
+            raise OLMoConfigurationError("DeMo requires that `fsdp.disable_grad_sync` be set to `true`.")
         wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
 
         if version.parse(torch.__version__) >= version.parse("2.1.0"):
-- 
2.34.1
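
The patch does not ship a usage example, so the following is a minimal sketch (not part of the patch) of exercising the new DeMo optimizer directly. The toy nn.Linear model, the single-process "gloo" process group, and the hyperparameter values are illustrative assumptions; only the DeMo constructor arguments and the data_transmit / data_receive counters it sets in step() come from the code above.

# Illustrative sketch only; assumes the patched OLMo package is importable.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from olmo.optim import DeMo

# DeMo.step() calls dist.all_gather(), so a process group must exist even for a
# single-process smoke test.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(64, 64)
opt = DeMo(
    model.parameters(),
    compression_decay=0.999,  # decay applied to the local delta accumulator
    compression_topk=32,      # top-k DCT coefficients transmitted per chunk
    compression_chunk=64,     # chunk size used by the DCT transform
    weight_decay=0.0,
    lr=1e-3,                  # forwarded to the underlying SGD via **kwargs
)

loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()

# Compresses the local delta, all-gathers the sparse components from every rank,
# then applies a sign-SGD update.
opt.step()
opt.zero_grad()
print("bytes sent:", opt.data_transmit, "bytes received:", opt.data_receive)

dist.destroy_process_group()

In a real training run the patch instead selects the optimizer through the config (optimizer.name set to demo) and, per the checks added to scripts/train.py, requires ddp.grad_sync_mode set to none under DDP or fsdp.disable_grad_sync set to true under FSDP, since DeMo replaces the framework's gradient all-reduce with its own compressed all-gather.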