/ examples / medusa / callback.py
callback.py
  1  import os
  2  import time
  3  from dataclasses import dataclass
  4  
  5  import torch
  6  import transformers
  7  from accelerate.utils.constants import FSDP_SHARDING_STRATEGY
  8  from transformers import TrainerControl, TrainerState, TrainingArguments
  9  
 10  from liger_kernel.utils import infer_device
 11  
 12  # https://simple.wikipedia.org/wiki/Byte
 13  # For memory, we use binary system
 14  M_BIN_UNIT = 2**20
 15  # For metrics (tflops), we use decimal system
 16  T_DEC_UNIT = 10**12
 17  
 18  
 19  def round_to_n_decimal(x, n):
 20      return round(x, n)
 21  
 22  
 23  @dataclass
 24  class Precision:
 25      """
 26      Precision is a dataclass to store the number of decimal points for each metric.
 27      """
 28  
 29      n_decimal_time: int
 30      n_decimal_memory: int
 31      n_decimal_TPS: int
 32      n_decimal_MFU: int
 33  
 34  
 35  @dataclass
 36  class State:
 37      """
 38      State is a dataclass to store the internal state of the efficiency callback.
 39      """
 40  
 41      n_warmup_steps: int = 0
 42      total_peak_memory_allocated: float = float("-inf")
 43      total_peak_memory_reserved: float = float("-inf")
 44  
 45      step_start_time: float = 0.0
 46      elapsed_time: float = 0.0
 47  
 48      elapsed_step: int = 0
 49  
 50      step_start_tokens_seen: int = 0
 51      elapsed_tokens_seen: int = 0
 52  
 53      step_start_flos: float = 0.0
 54      elapsed_flos: float = 0.0
 55  
 56      global_start_step: int = 0
 57  
 58  
 59  @dataclass
 60  class Time:
 61      """
 62      Time is a dataclass to store the time-related metrics.
 63      """
 64  
 65      step: int = 0
 66      step_time_sec: float = 0.0
 67      avg_step_time_sec: float = 0.0
 68      time_to_completion_sec: float = 0.0
 69      estimated_total_time_sec: float = 0.0
 70  
 71  
 72  @dataclass
 73  class Memory:
 74      """
 75      Memory is a dataclass to store the memory-related metrics.
 76      """
 77  
 78      step_peak_memory_allocated_MB: float = 0.0
 79      total_peak_memory_allocated_MB: float = 0.0
 80  
 81  
 82  @dataclass
 83  class TPS:
 84      """
 85      TPS is a dataclass to store the tokens per second metrics.
 86      """
 87  
 88      step_tokens_per_second: float = 0.0
 89      avg_tokens_per_second: float = 0.0
 90  
 91  
 92  @dataclass
 93  class MFU:
 94      """
 95      MFU is a dataclass to store the MFU metrics.
 96      """
 97  
 98      step_MFU: float = 0.0
 99      avg_MFU: float = 0.0
100  
101  
102  class EfficiencyCallback(transformers.TrainerCallback):
103      """
104      EfficiencyCallback is a callback to track the efficiency of the training process.
105      The tracked stats include: step time, memory, throughput, and MFU.
106  
107      It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments.
108  
109      Args:
110          n_warmup_steps: number of warmup steps
111              The stats in the first n_warmup_steps will not be added into the aggregated stats
112              This is because the first few steps might take longer due to jit compliation and other initialization overheads
113          n_decimal_time: number of decimal points for time
114          n_decimal_memory: number of decimal points for memory
115          n_decimal_TPS: number of decimal points for TPS
116          n_decimal_MFU: number of decimal points for MFU in percentage
117      """
118  
119      def __init__(
120          self,
121          n_warmup_steps=2,
122          n_decimal_time=2,
123          n_decimal_memory=2,
124          n_decimal_TPS=2,
125          n_decimal_MFU=4,
126      ):
127          self.state = State(
128              n_warmup_steps,
129          )
130  
131          self.precision = Precision(
132              n_decimal_time,
133              n_decimal_memory,
134              n_decimal_TPS,
135              n_decimal_MFU,
136          )
137  
138          self.time = Time()
139          self.memory = Memory()
140          self.tps = TPS()
141          self.mfu = MFU()
142          self.device = infer_device()
143  
144      def on_init_end(
145          self,
146          args: TrainingArguments,
147          state: TrainerState,
148          control: TrainerControl,
149          **kwargs,
150      ):
151          """
152          Event called at the end of the initialization of the [`Trainer`].
153          """
154          if not args.include_num_input_tokens_seen:
155              raise Exception(
156                  'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second'
157              )
158          if args.logging_steps != 1:
159              raise Exception(
160                  "Please set logging_steps=1 to track the efficiency metrics accurately"
161              )
162  
163      def on_train_begin(
164          self,
165          args: TrainingArguments,
166          state: TrainerState,
167          control: TrainerControl,
168          **kwargs,
169      ):
170          # if loaded from checkpoints, global_start_step is not 1 but state.global_step
171          self.state.global_start_step = state.global_step
172  
173      def on_log(
174          self,
175          args: TrainingArguments,
176          state: TrainerState,
177          control: TrainerControl,
178          logs: dict[str, float],
179          **kwargs,
180      ):
181          if state.global_step < (
182              self.state.global_start_step + self.state.n_warmup_steps
183          ):
184              return
185          else:
186              # spread self.time, self.memory, self.tps, self.mfu to logs
187              # logs.update(self.time.__dict__)
188              logs.update(self.memory.__dict__)
189              logs.update(self.tps.__dict__)
190              # logs.update(self.mfu.__dict__)
191  
192      def on_step_begin(
193          self,
194          args: TrainingArguments,
195          state: TrainerState,
196          control: TrainerControl,
197          **kwargs,
198      ):
199          """
200          Event called at the beginning of a training step. If using gradient accumulation, one training step might take
201          several inputs.
202          """
203          # memory
204          getattr(torch, self.device).reset_peak_memory_stats()
205  
206          # time
207          self.state.step_start_time = time.perf_counter()
208  
209      def on_step_end(
210          self,
211          args: TrainingArguments,
212          state: TrainerState,
213          control: TrainerControl,
214          **kwargs,
215      ):
216          if state.global_step < (
217              self.state.global_start_step + self.state.n_warmup_steps
218          ):
219              # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration
220  
221              # tokens
222              self.state.step_start_tokens_seen = state.num_input_tokens_seen
223              # flos
224              self.state.step_start_flos = state.total_flos
225              return
226  
227          # time
228          current_time = time.perf_counter()
229          step_time = current_time - self.state.step_start_time
230          self.state.elapsed_time += step_time
231  
232          # step
233          global_step = state.global_step
234          self.state.elapsed_step += 1
235          avg_step_time = self.state.elapsed_time / self.state.elapsed_step
236  
237          self.time.step = global_step
238          self.time.step_time_sec = round_to_n_decimal(
239              step_time, self.precision.n_decimal_time
240          )
241          self.time.avg_step_time_sec = round_to_n_decimal(
242              avg_step_time, self.precision.n_decimal_time
243          )
244          self.time.time_to_completion_sec = round_to_n_decimal(
245              avg_step_time * (state.max_steps - global_step),
246              self.precision.n_decimal_time,
247          )
248          self.time.estimated_total_time_sec = round_to_n_decimal(
249              avg_step_time * state.max_steps, self.precision.n_decimal_time
250          )
251  
252          # memory
253          step_peak_memory_allocated = getattr(
254              torch, self.device
255          ).memory.max_memory_allocated()
256          step_peak_memory_reserved = getattr(
257              torch, self.device
258          ).memory.max_memory_reserved()
259  
260          self.memory.step_peak_memory_allocated_MB = round_to_n_decimal(
261              step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory
262          )
263          self.state.total_peak_memory_allocated = max(
264              self.state.total_peak_memory_allocated, step_peak_memory_allocated
265          )
266          self.memory.total_peak_memory_allocated_MB = round_to_n_decimal(
267              self.state.total_peak_memory_allocated / M_BIN_UNIT,
268              self.precision.n_decimal_memory,
269          )
270  
271          self.memory.step_peak_memory_reserved_MB = round_to_n_decimal(
272              step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory
273          )
274  
275          self.state.total_peak_memory_reserved = max(
276              self.state.total_peak_memory_reserved, step_peak_memory_reserved
277          )
278  
279          self.memory.total_peak_memory_reserved_MB = round_to_n_decimal(
280              self.state.total_peak_memory_reserved / M_BIN_UNIT,
281              self.precision.n_decimal_memory,
282          )
283  
284          # tokens
285          step_tokens_seen = (
286              state.num_input_tokens_seen - self.state.step_start_tokens_seen
287          )
288  
289          self.state.elapsed_tokens_seen += step_tokens_seen
290  
291          self.tps.step_tokens_per_second = round_to_n_decimal(
292              step_tokens_seen / step_time,
293              self.precision.n_decimal_TPS,
294          )
295  
296          self.tps.avg_tokens_per_second = round_to_n_decimal(
297              self.state.elapsed_tokens_seen / self.state.elapsed_time,
298              self.precision.n_decimal_TPS,
299          )
300  
301          # flos
302          step_flos = state.total_flos - self.state.step_start_flos
303          self.state.elapsed_flos += step_flos
304  
305          # MFU
306          # 1. Definition
307          #
308          # MFU is defined as (achieved TPS) / (theoretical maximum TPS) = (achieved floating point operations per sec) / (theoretical maximum floating point operations per sec)
309          # Crucially, the "theoretical maximum" throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization. MFU therefore allows fair comparisons
310          # between training runs on different systems, as the numerator is simply the observed tokens-per-second, and the denominator is only dependent on the model architecture and published maximum FLOPs for a given system.
311          # Ref: https://arxiv.org/pdf/2204.02311
312          # The benefit of MFU is that it
313          #
314          # 2. Implementation in huggingface
315          #
316          # current_flos = 6 * estimate_tokens(input_dict) * num_parameters()
317          # total_flos = sum(current_flos) # across all GPUs
318          # Ref: https://github.com/huggingface/transformers/blob/616bb11d487aabc231bb230b245c42214ea4b254/src/transformers/modeling_utils.py#L1196
319          #
320          # 3. Derive MFU on rank 0
321          #
322          # rank_0_flos = tatal_flos / n_gpus = measured_flos / effecitve_n_gpus
323          # rank_0_MFU = rank_0_flos / step_time
324          #
325          # For FSDP, num_parameters() is (1 / n_gpus) of the total parameters. So, the effective_n_gpus = 1
326          # For HSDP, num_parameters() is (1 / local_world_size) of the total parameters. So, the effective_n_gpus = n_nodes
327          # For no sharding and zero-2, num_parameters() is the total parameters. So, the effective_n_gpus = n_gpus
328  
329          num_gpus = EfficiencyCallback._get_effective_num_gpus()
330          step_achieved_tflops = step_flos / step_time / num_gpus / T_DEC_UNIT
331  
332          avg_achieved_tflops = (
333              self.state.elapsed_flos / self.state.elapsed_time / num_gpus / T_DEC_UNIT
334          )
335  
336          precision_bits = 16 if args.bf16 or args.fp16 else 32
337          gpu_peak_tflops = EfficiencyCallback._get_gpu_peak_tflops(precision_bits)
338  
339          self.mfu.step_MFU = round_to_n_decimal(
340              step_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU
341          )
342  
343          self.mfu.avg_MFU = round_to_n_decimal(
344              avg_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU
345          )
346  
347          # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration
348  
349          # tokens
350          self.state.step_start_tokens_seen = state.num_input_tokens_seen
351          # flos
352          self.state.step_start_flos = state.total_flos
353  
354      @staticmethod
355      def _get_effective_num_gpus():
356          # Calculate the number of effective GPUs for the total FLOPs in order to calculate the single GPU FLOP
357          world_size = int(os.environ.get("WORLD_SIZE", "1"))
358  
359          if transformers.utils.strtobool(os.environ.get("ACCELERATE_USE_FSDP", "false")):
360              sharding_strategy = os.environ.get(
361                  "FSDP_SHARDING_STRATEGY", FSDP_SHARDING_STRATEGY[0]
362              ).upper()
363  
364              # Either specified as string or enum number
365              if sharding_strategy in {
366                  "FULL_SHARD",
367                  str(FSDP_SHARDING_STRATEGY.index("FULL_SHARD") + 1),
368              }:
369                  return 1
370  
371              elif sharding_strategy in {
372                  "HYBRID_SHARD",
373                  str(FSDP_SHARDING_STRATEGY.index("HYBRID_SHARD") + 1),
374              }:
375                  return world_size // int(os.environ.get("LOCAL_WORLD_SIZE", 1))
376              else:
377                  return world_size
378  
379          assert (
380              world_size != 0
381          ), "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
382  
383          # TODO: add deepspeed support
384          return world_size
385  
386      @staticmethod
387      def _get_gpu_peak_tflops(precision_bits: int = 16):
388          if precision_bits not in {16, 32}:
389              raise Exception(f"Precision bits {precision_bits} is not supported")
390  
391          device_name = getattr(torch, infer_device()).get_device_name()
392  
393          if "A100" in device_name:
394              # data from https://www.nvidia.com/en-us/data-center/a100/
395              return 312 if precision_bits == 16 else 156
396          elif "H100" in device_name:
397              # data from https://www.nvidia.com/en-us/data-center/h100/
398              # NOTE: Specifications are one-half lower without sparsity.
399              if "NVL" in device_name:
400                  return 1979 if precision_bits == 16 else 989
401              elif "PCIe" in device_name:
402                  return 756 if precision_bits == 16 else 378
403              else:  # for SXM and other variants
404                  return 989 if precision_bits == 16 else 494
405          elif "V100" in device_name:
406              if "NVL" in device_name:
407                  return 125
408              else:
409                  return 112
410          return None