@dataclass
class SpeedMonitorCallback(Callback):
    """
    Monitors throughput.

    .. important::
        This callback gets added automatically if you don't explicitly configure it.
        If you want to override this callback you should subclass it.

    Per-step and running-average batches/tokens/FLOPs per second (and MFU, when the
    device's peak FLOPs/s is known) are recorded through ``trainer.record_metric``.
    The first step after :meth:`pre_train` or :meth:`reset` is not counted.
    """

    priority: ClassVar[int] = -2

    # FLOPs per token. When ``None``, it is derived each step from the train module's
    # ``num_flops_per_token(seq_len)`` (only for a ``TransformerTrainModule``).
    num_flops_per_token: Optional[int] = None
    # Parameter count used for the "chinchilla multiple" metric. When ``None``, it is
    # filled in by ``pre_train()`` from ``model.num_non_embedding_params`` (only for a
    # ``TransformerTrainModule``).
    num_params: Optional[int] = None
    # Peak FLOPs/s of one device, used for MFU. When ``None``, ``pre_train()`` estimates it
    # for CUDA devices running a ``TransformerTrainModule`` in bf16; otherwise it stays
    # ``None`` and MFU is not reported.
    device_peak_flops_per_second: Optional[int] = None

    # Running totals since the first counted step.
    _total_steps: int = 0
    _total_tokens: int = 0
    _total_flops: int = 0
    _start_time: float = 0.0
    # True until the first (uncounted) step after pre_train()/reset() has finished.
    _first_step: bool = True
    _step_last_logged: float = 0.0
    _batch_load_start: float = 0.0
    _batch_load_time: float = 0.0
    # Per-step values written in pre_step() and consumed in post_step().
    _step_tokens: int = 0
    _step_seq_len: int = 0
    _step_flops: int = 0
    # World size divided by data-parallel world size (stays 1 without a DP process group).
    _parallel_degree: int = 1
    # Latest running averages, exposed via the read-only properties below.
    _bps_avg: Optional[float] = None
    _tps_avg: Optional[float] = None
    _mfu_avg: Optional[float] = None

    def reset(self):
        """
        Restart measurement.

        The next step is treated like the first one (not counted), and all running
        averages exposed by :attr:`bps_avg`, :attr:`tps_avg` and :attr:`mfu_avg` are
        cleared so they stay consistent with each other.
        """
        self._first_step = True
        self._bps_avg = None
        self._tps_avg = None
        self._mfu_avg = None

    @property
    def bps_avg(self) -> Optional[float]:
        """Running average batches (steps) per second, or ``None`` if not yet measured."""
        return self._bps_avg

    @property
    def tps_avg(self) -> Optional[float]:
        """Running average tokens per second per device, or ``None`` if not yet measured."""
        return self._tps_avg

    @property
    def mfu_avg(self) -> Optional[float]:
        """Running average model FLOPs utilization (percent), or ``None`` if not measured."""
        return self._mfu_avg

    def _get_num_flops_per_token(self, seq_len: int) -> Optional[int]:
        """
        Return the FLOPs per token for a batch with sequence length ``seq_len``.

        The explicitly configured :attr:`num_flops_per_token` wins. Otherwise the value
        is asked from a ``TransformerTrainModule``. ``None`` means it is unknown.
        """
        if self.num_flops_per_token is not None:
            return self.num_flops_per_token
        elif isinstance(self.trainer.train_module, TransformerTrainModule):
            return self.trainer.train_module.num_flops_per_token(seq_len)
        else:
            return None

    @staticmethod
    def _estimate_bf16_peak_flops(device_name: str) -> int:
        """
        Estimate the dense bf16 peak FLOPs/s of a CUDA device from its name.

        The listed spec figures are halved because they are one-half lower without
        sparsity. Unrecognized devices are assumed to be A100s.
        """
        dense_correction = 0.5  # listed specs are one-half lower without sparsity
        if "H100" in device_name:
            # data from https://www.nvidia.com/en-us/data-center/h100/
            if "NVL" in device_name:
                return int(1671e12 * dense_correction)
            if "PCIe" in device_name:
                return int(1513e12 * dense_correction)
            # for SXM and other variants
            return int(1979e12 * dense_correction)
        if "B200" in device_name:
            # data from https://www.nvidia.com/en-us/data-center/hgx/
            return int(4.5e15 * dense_correction)
        # for other GPU types, assume A100
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return int(624e12 * dense_correction)

    def pre_train(self):
        """
        Reset the first-step flag and fill in any unconfigured values that can be inferred.

        Sets the parallel degree (world size over DP world size), the parameter count, and
        the device peak FLOPs/s. The FLOPs/s estimate is only made for bf16 training on
        CUDA with a ``TransformerTrainModule``.
        """
        self._first_step = True

        if self.trainer.dp_process_group is not None:
            # NOTE(review): presumably this is the number of ranks sharing one DP shard's
            # batch (model parallelism), used to avoid counting tokens multiple times.
            self._parallel_degree = get_world_size() // get_world_size(
                self.trainer.dp_process_group
            )

        train_module = self.trainer.train_module
        is_transformer = isinstance(train_module, TransformerTrainModule)

        if self.num_params is None and is_transformer:
            self.num_params = train_module.model.num_non_embedding_params

        if (
            self.device_peak_flops_per_second is None
            and self.trainer.device.type == "cuda"
            and is_transformer
        ):
            device_name = torch.cuda.get_device_name(self.trainer.device)
            using_half_precision = train_module.autocast_precision == torch.bfloat16 or (
                train_module.dp_config is not None
                and train_module.dp_config.param_dtype == DType.bfloat16
            )
            # The spec table only holds bf16 figures, so no estimate is made for other
            # precisions (MFU is then not reported unless configured explicitly).
            if using_half_precision:
                self.device_peak_flops_per_second = self._estimate_bf16_peak_flops(device_name)
            log.info(
                f"Device: {device_name}, Device peak Flops/s: {self.device_peak_flops_per_second}"
            )

    def pre_load_batch(self):
        """Start the data-loading timer for the upcoming batch."""
        self._batch_load_start = time.perf_counter()

    def pre_step(self, batch: Dict[str, Any]):
        """
        Stop the data-loading timer and collect token/FLOP counts for ``batch``.

        Token counts are taken from ``batch["input_ids"]`` when present. The tensor is
        expected to be 2D with the sequence length in dim 1, based on the indexing below.
        """
        self._batch_load_time = time.perf_counter() - self._batch_load_start
        if self._first_step:
            # We don't record the first batch since the first one tends to take
            # unusually long.
            return

        self._total_steps += 1

        # Clear per-step values first so that a batch without "input_ids" does not make
        # post_step() re-report the previous step's tokens/FLOPs.
        self._step_tokens = 0
        self._step_seq_len = 0
        self._step_flops = 0

        if "input_ids" in batch:
            tokens_in_batch = batch["input_ids"].numel()
            self._step_tokens = tokens_in_batch // self._parallel_degree
            self._step_seq_len = batch["input_ids"].shape[1]
            self._total_tokens += self._step_tokens

            num_flops_per_token = self._get_num_flops_per_token(self._step_seq_len)
            if num_flops_per_token is not None:
                self._step_flops = num_flops_per_token * self._step_tokens
                self._total_flops += self._step_flops

    def post_step(self):
        """
        Record throughput metrics for the step that just finished.

        On the first step after :meth:`pre_train` / :meth:`reset` this only zeroes the
        running totals and starts the clocks. On later steps it records per-step and
        running-average TPS, FLOPs/s, BPS, data-loading share, and MFU when available.
        """
        counter = time.perf_counter()
        self.trainer.record_metric(
            "throughput/device/data loading (s)",
            self._batch_load_time,
            reduce_type=ReduceType.max,
        )

        if self._first_step:
            # Now we can start recording.
            self._total_steps = 0
            self._total_tokens = 0
            self._total_flops = 0
            self._start_time = counter
            self._first_step = False
            self._step_last_logged = counter
            return

        step_time = counter - self._step_last_logged
        total_time = counter - self._start_time
        self._step_last_logged = counter

        if self._step_tokens and self._total_tokens:
            tps = self._step_tokens / step_time
            tps_avg = self._total_tokens / total_time
            self._tps_avg = tps_avg
            self.trainer.record_metric("throughput/device/TPS", tps)
            self.trainer.record_metric("throughput/device/TPS (actual avg)", tps_avg)
            if self.trainer.global_train_tokens_seen is not None:
                self.trainer.record_metric(
                    "throughput/total tokens", self.trainer.global_train_tokens_seen
                )
                if self.num_params is not None:
                    # Tokens seen relative to a budget of 20 tokens per parameter.
                    self.trainer.record_metric(
                        "throughput/chinchilla multiple",
                        self.trainer.global_train_tokens_seen / (20 * self.num_params),
                    )

        flops_ps: Optional[float] = None
        flops_ps_avg: Optional[float] = None
        if self._step_flops and self._total_flops:
            flops_ps = self._step_flops / step_time
            flops_ps_avg = self._total_flops / total_time
            self.trainer.record_metric("throughput/device/flopsPS", flops_ps)
            self.trainer.record_metric("throughput/device/flopsPS (actual avg)", flops_ps_avg)
            self.trainer.record_metric(
                "throughput/total petaflops", self.trainer.global_train_petaflops
            )

        bps = 1 / step_time
        bps_avg = self._total_steps / total_time
        self._bps_avg = bps_avg
        self.trainer.record_metric("throughput/device/BPS", bps)
        self.trainer.record_metric("throughput/device/BPS (actual avg)", bps_avg)

        # Share of the step's wall time spent waiting on data loading.
        data_pct = 100 * self._batch_load_time / step_time
        self.trainer.record_metric(
            "throughput/device/data loading (%)", data_pct, reduce_type=ReduceType.max
        )

        if (
            self.device_peak_flops_per_second is not None
            and flops_ps is not None
            and flops_ps_avg is not None
        ):
            # model FLOPS utilization
            # For its definition and calculation, please refer to the PaLM paper:
            # https://arxiv.org/abs/2204.02311
            # MFU is computed from FLOPs/sec. This stays correct even if sequence length changes.
            mfu = 100 * flops_ps / self.device_peak_flops_per_second
            mfu_avg = 100 * flops_ps_avg / self.device_peak_flops_per_second
            self._mfu_avg = mfu_avg
            self.trainer.record_metric("throughput/device/MFU", mfu)
            self.trainer.record_metric("throughput/device/MFU (actual avg)", mfu_avg)