callback.py
1 import os 2 import time 3 from dataclasses import dataclass 4 5 import torch 6 import transformers 7 from accelerate.utils.constants import FSDP_SHARDING_STRATEGY 8 from transformers import TrainerControl, TrainerState, TrainingArguments 9 10 from liger_kernel.utils import infer_device 11 12 # https://simple.wikipedia.org/wiki/Byte 13 # For memory, we use binary system 14 M_BIN_UNIT = 2**20 15 # For metrics (tflops), we use decimal system 16 T_DEC_UNIT = 10**12 17 18 19 def round_to_n_decimal(x, n): 20 return round(x, n) 21 22 23 @dataclass 24 class Precision: 25 """ 26 Precision is a dataclass to store the number of decimal points for each metric. 27 """ 28 29 n_decimal_time: int 30 n_decimal_memory: int 31 n_decimal_TPS: int 32 n_decimal_MFU: int 33 34 35 @dataclass 36 class State: 37 """ 38 State is a dataclass to store the internal state of the efficiency callback. 39 """ 40 41 n_warmup_steps: int = 0 42 total_peak_memory_allocated: float = float("-inf") 43 total_peak_memory_reserved: float = float("-inf") 44 45 step_start_time: float = 0.0 46 elapsed_time: float = 0.0 47 48 elapsed_step: int = 0 49 50 step_start_tokens_seen: int = 0 51 elapsed_tokens_seen: int = 0 52 53 step_start_flos: float = 0.0 54 elapsed_flos: float = 0.0 55 56 global_start_step: int = 0 57 58 59 @dataclass 60 class Time: 61 """ 62 Time is a dataclass to store the time-related metrics. 63 """ 64 65 step: int = 0 66 step_time_sec: float = 0.0 67 avg_step_time_sec: float = 0.0 68 time_to_completion_sec: float = 0.0 69 estimated_total_time_sec: float = 0.0 70 71 72 @dataclass 73 class Memory: 74 """ 75 Memory is a dataclass to store the memory-related metrics. 76 """ 77 78 step_peak_memory_allocated_MB: float = 0.0 79 total_peak_memory_allocated_MB: float = 0.0 80 81 82 @dataclass 83 class TPS: 84 """ 85 TPS is a dataclass to store the tokens per second metrics. 86 """ 87 88 step_tokens_per_second: float = 0.0 89 avg_tokens_per_second: float = 0.0 90 91 92 @dataclass 93 class MFU: 94 """ 95 MFU is a dataclass to store the MFU metrics. 96 """ 97 98 step_MFU: float = 0.0 99 avg_MFU: float = 0.0 100 101 102 class EfficiencyCallback(transformers.TrainerCallback): 103 """ 104 EfficiencyCallback is a callback to track the efficiency of the training process. 105 The tracked stats include: step time, memory, throughput, and MFU. 106 107 It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments. 108 109 Args: 110 n_warmup_steps: number of warmup steps 111 The stats in the first n_warmup_steps will not be added into the aggregated stats 112 This is because the first few steps might take longer due to jit compliation and other initialization overheads 113 n_decimal_time: number of decimal points for time 114 n_decimal_memory: number of decimal points for memory 115 n_decimal_TPS: number of decimal points for TPS 116 n_decimal_MFU: number of decimal points for MFU in percentage 117 """ 118 119 def __init__( 120 self, 121 n_warmup_steps=2, 122 n_decimal_time=2, 123 n_decimal_memory=2, 124 n_decimal_TPS=2, 125 n_decimal_MFU=4, 126 ): 127 self.state = State( 128 n_warmup_steps, 129 ) 130 131 self.precision = Precision( 132 n_decimal_time, 133 n_decimal_memory, 134 n_decimal_TPS, 135 n_decimal_MFU, 136 ) 137 138 self.time = Time() 139 self.memory = Memory() 140 self.tps = TPS() 141 self.mfu = MFU() 142 self.device = infer_device() 143 144 def on_init_end( 145 self, 146 args: TrainingArguments, 147 state: TrainerState, 148 control: TrainerControl, 149 **kwargs, 150 ): 151 """ 152 Event called at the end of the initialization of the [`Trainer`]. 153 """ 154 if not args.include_num_input_tokens_seen: 155 raise Exception( 156 'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second' 157 ) 158 if args.logging_steps != 1: 159 raise Exception( 160 "Please set logging_steps=1 to track the efficiency metrics accurately" 161 ) 162 163 def on_train_begin( 164 self, 165 args: TrainingArguments, 166 state: TrainerState, 167 control: TrainerControl, 168 **kwargs, 169 ): 170 # if loaded from checkpoints, global_start_step is not 1 but state.global_step 171 self.state.global_start_step = state.global_step 172 173 def on_log( 174 self, 175 args: TrainingArguments, 176 state: TrainerState, 177 control: TrainerControl, 178 logs: dict[str, float], 179 **kwargs, 180 ): 181 if state.global_step < ( 182 self.state.global_start_step + self.state.n_warmup_steps 183 ): 184 return 185 else: 186 # spread self.time, self.memory, self.tps, self.mfu to logs 187 # logs.update(self.time.__dict__) 188 logs.update(self.memory.__dict__) 189 logs.update(self.tps.__dict__) 190 # logs.update(self.mfu.__dict__) 191 192 def on_step_begin( 193 self, 194 args: TrainingArguments, 195 state: TrainerState, 196 control: TrainerControl, 197 **kwargs, 198 ): 199 """ 200 Event called at the beginning of a training step. If using gradient accumulation, one training step might take 201 several inputs. 202 """ 203 # memory 204 getattr(torch, self.device).reset_peak_memory_stats() 205 206 # time 207 self.state.step_start_time = time.perf_counter() 208 209 def on_step_end( 210 self, 211 args: TrainingArguments, 212 state: TrainerState, 213 control: TrainerControl, 214 **kwargs, 215 ): 216 if state.global_step < ( 217 self.state.global_start_step + self.state.n_warmup_steps 218 ): 219 # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration 220 221 # tokens 222 self.state.step_start_tokens_seen = state.num_input_tokens_seen 223 # flos 224 self.state.step_start_flos = state.total_flos 225 return 226 227 # time 228 current_time = time.perf_counter() 229 step_time = current_time - self.state.step_start_time 230 self.state.elapsed_time += step_time 231 232 # step 233 global_step = state.global_step 234 self.state.elapsed_step += 1 235 avg_step_time = self.state.elapsed_time / self.state.elapsed_step 236 237 self.time.step = global_step 238 self.time.step_time_sec = round_to_n_decimal( 239 step_time, self.precision.n_decimal_time 240 ) 241 self.time.avg_step_time_sec = round_to_n_decimal( 242 avg_step_time, self.precision.n_decimal_time 243 ) 244 self.time.time_to_completion_sec = round_to_n_decimal( 245 avg_step_time * (state.max_steps - global_step), 246 self.precision.n_decimal_time, 247 ) 248 self.time.estimated_total_time_sec = round_to_n_decimal( 249 avg_step_time * state.max_steps, self.precision.n_decimal_time 250 ) 251 252 # memory 253 step_peak_memory_allocated = getattr( 254 torch, self.device 255 ).memory.max_memory_allocated() 256 step_peak_memory_reserved = getattr( 257 torch, self.device 258 ).memory.max_memory_reserved() 259 260 self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( 261 step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory 262 ) 263 self.state.total_peak_memory_allocated = max( 264 self.state.total_peak_memory_allocated, step_peak_memory_allocated 265 ) 266 self.memory.total_peak_memory_allocated_MB = round_to_n_decimal( 267 self.state.total_peak_memory_allocated / M_BIN_UNIT, 268 self.precision.n_decimal_memory, 269 ) 270 271 self.memory.step_peak_memory_reserved_MB = round_to_n_decimal( 272 step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory 273 ) 274 275 self.state.total_peak_memory_reserved = max( 276 self.state.total_peak_memory_reserved, step_peak_memory_reserved 277 ) 278 279 self.memory.total_peak_memory_reserved_MB = round_to_n_decimal( 280 self.state.total_peak_memory_reserved / M_BIN_UNIT, 281 self.precision.n_decimal_memory, 282 ) 283 284 # tokens 285 step_tokens_seen = ( 286 state.num_input_tokens_seen - self.state.step_start_tokens_seen 287 ) 288 289 self.state.elapsed_tokens_seen += step_tokens_seen 290 291 self.tps.step_tokens_per_second = round_to_n_decimal( 292 step_tokens_seen / step_time, 293 self.precision.n_decimal_TPS, 294 ) 295 296 self.tps.avg_tokens_per_second = round_to_n_decimal( 297 self.state.elapsed_tokens_seen / self.state.elapsed_time, 298 self.precision.n_decimal_TPS, 299 ) 300 301 # flos 302 step_flos = state.total_flos - self.state.step_start_flos 303 self.state.elapsed_flos += step_flos 304 305 # MFU 306 # 1. Definition 307 # 308 # MFU is defined as (achieved TPS) / (theoretical maximum TPS) = (achieved floating point operations per sec) / (theoretical maximum floating point operations per sec) 309 # Crucially, the "theoretical maximum" throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization. MFU therefore allows fair comparisons 310 # between training runs on different systems, as the numerator is simply the observed tokens-per-second, and the denominator is only dependent on the model architecture and published maximum FLOPs for a given system. 311 # Ref: https://arxiv.org/pdf/2204.02311 312 # The benefit of MFU is that it 313 # 314 # 2. Implementation in huggingface 315 # 316 # current_flos = 6 * estimate_tokens(input_dict) * num_parameters() 317 # total_flos = sum(current_flos) # across all GPUs 318 # Ref: https://github.com/huggingface/transformers/blob/616bb11d487aabc231bb230b245c42214ea4b254/src/transformers/modeling_utils.py#L1196 319 # 320 # 3. Derive MFU on rank 0 321 # 322 # rank_0_flos = tatal_flos / n_gpus = measured_flos / effecitve_n_gpus 323 # rank_0_MFU = rank_0_flos / step_time 324 # 325 # For FSDP, num_parameters() is (1 / n_gpus) of the total parameters. So, the effective_n_gpus = 1 326 # For HSDP, num_parameters() is (1 / local_world_size) of the total parameters. So, the effective_n_gpus = n_nodes 327 # For no sharding and zero-2, num_parameters() is the total parameters. So, the effective_n_gpus = n_gpus 328 329 num_gpus = EfficiencyCallback._get_effective_num_gpus() 330 step_achieved_tflops = step_flos / step_time / num_gpus / T_DEC_UNIT 331 332 avg_achieved_tflops = ( 333 self.state.elapsed_flos / self.state.elapsed_time / num_gpus / T_DEC_UNIT 334 ) 335 336 precision_bits = 16 if args.bf16 or args.fp16 else 32 337 gpu_peak_tflops = EfficiencyCallback._get_gpu_peak_tflops(precision_bits) 338 339 self.mfu.step_MFU = round_to_n_decimal( 340 step_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU 341 ) 342 343 self.mfu.avg_MFU = round_to_n_decimal( 344 avg_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU 345 ) 346 347 # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration 348 349 # tokens 350 self.state.step_start_tokens_seen = state.num_input_tokens_seen 351 # flos 352 self.state.step_start_flos = state.total_flos 353 354 @staticmethod 355 def _get_effective_num_gpus(): 356 # Calculate the number of effective GPUs for the total FLOPs in order to calculate the single GPU FLOP 357 world_size = int(os.environ.get("WORLD_SIZE", "1")) 358 359 if transformers.utils.strtobool(os.environ.get("ACCELERATE_USE_FSDP", "false")): 360 sharding_strategy = os.environ.get( 361 "FSDP_SHARDING_STRATEGY", FSDP_SHARDING_STRATEGY[0] 362 ).upper() 363 364 # Either specified as string or enum number 365 if sharding_strategy in { 366 "FULL_SHARD", 367 str(FSDP_SHARDING_STRATEGY.index("FULL_SHARD") + 1), 368 }: 369 return 1 370 371 elif sharding_strategy in { 372 "HYBRID_SHARD", 373 str(FSDP_SHARDING_STRATEGY.index("HYBRID_SHARD") + 1), 374 }: 375 return world_size // int(os.environ.get("LOCAL_WORLD_SIZE", 1)) 376 else: 377 return world_size 378 379 assert ( 380 world_size != 0 381 ), "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1." 382 383 # TODO: add deepspeed support 384 return world_size 385 386 @staticmethod 387 def _get_gpu_peak_tflops(precision_bits: int = 16): 388 if precision_bits not in {16, 32}: 389 raise Exception(f"Precision bits {precision_bits} is not supported") 390 391 device_name = getattr(torch, infer_device()).get_device_name() 392 393 if "A100" in device_name: 394 # data from https://www.nvidia.com/en-us/data-center/a100/ 395 return 312 if precision_bits == 16 else 156 396 elif "H100" in device_name: 397 # data from https://www.nvidia.com/en-us/data-center/h100/ 398 # NOTE: Specifications are one-half lower without sparsity. 399 if "NVL" in device_name: 400 return 1979 if precision_bits == 16 else 989 401 elif "PCIe" in device_name: 402 return 756 if precision_bits == 16 else 378 403 else: # for SXM and other variants 404 return 989 if precision_bits == 16 else 494 405 elif "V100" in device_name: 406 if "NVL" in device_name: 407 return 125 408 else: 409 return 112 410 return None