@dataclass
class ProfilerCallback(Callback):
    """
    Enables profiling/tracing of training steps using :mod:`torch.profiler`.

    Whenever the profiler reports a finished trace, the top ops by self GPU time and by
    self CPU time are logged, and a gzipped Chrome trace is written to a ``profiler``
    subdirectory of the trainer's working directory and then persisted with
    ``trainer.persist_working_file()``.
    """

    skip_first: int = 0
    """
    Ignore this many steps before starting the profiling cycles.
    """

    wait: int = 1
    """
    Idle for this many steps before activating.
    """

    warmup: int = 5
    """
    Start tracing, but discard the results, for this many steps.
    """

    active: int = 3
    """
    Actively trace this many steps.
    """

    repeat: int = 1
    """
    Number of ``wait``/``warmup``/``active`` cycles to run. Passed through to
    :func:`torch.profiler.schedule`, where ``0`` means keep cycling until profiling stops.
    """

    with_stack: bool = True
    """
    Whether to record source information (file and line number) for the ops.
    """

    profile_memory: bool = False
    """
    Whether to track tensor memory allocation/deallocation.
    """

    enable_cuda_sync_events: bool = False
    """
    Whether to enable recording of CUDA sync events.
    Useful for critical-path analysis with
    https://hta.readthedocs.io/en/latest/source/features/lightweight_critical_path_analysis.html
    """

    enabled: bool = True
    """
    Set to ``False`` to disable profiling.
    """

    ranks: str | None = None
    """
    Ranks to profile. Can be:

    - ``None``: Only rank 0 is profiled
    - String shortcuts:
        - ``"dp"``: Profile one rank (local rank 0) in each data parallel group
        - ``"tp"``: Profile one rank (local rank 0) in each tensor parallel group
        - ``"cp"``: Profile one rank (local rank 0) in each context parallel group
        - ``"pp"``: Profile one rank (local rank 0) in each pipeline parallel group
        - ``"ep"``: Profile one rank (local rank 0) in each expert parallel group
        - ``"all"``: Profile all ranks

    The parallel-group shortcuts need the world device mesh. If there is no world mesh,
    or the requested sub-mesh can't be resolved, a warning is logged and only rank 0 is
    profiled. Any other string raises a :class:`ValueError`; a non-string value raises
    a :class:`TypeError`.

    Useful in conjunction with https://github.com/facebookresearch/HolisticTraceAnalysis
    to analyze traces from a distributed training job.
    """

    # Unannotated, so these are plain class attributes rather than dataclass fields.
    # ``pre_train`` only assigns them on ranks that actually profile, so
    # ``_profiler is None`` means "this rank is not profiling".
    _exit_stack = None
    _profiler = None
    # NOTE: annotated, so this one *is* a dataclass field; ``pre_train`` resets it.
    _first_batch: bool = True

    def _should_profile_rank(self) -> bool:
        """
        Decide whether the current rank should run the profiler, based on ``self.ranks``.

        :raises ValueError: If ``ranks`` is an unrecognized string shortcut.
        :raises TypeError: If ``ranks`` is neither ``None`` nor a string.
        """
        current_rank = get_rank()
        if self.ranks is None:
            return current_rank == 0
        if not isinstance(self.ranks, str):
            raise TypeError(f"Invalid ranks specification: {self.ranks}")

        # "all" doesn't depend on any parallel layout, so it must not be downgraded to
        # rank 0 just because no world mesh exists.
        if self.ranks == "all":
            return True

        # Each shortcut selects local rank 0 of the matching sub-mesh.
        mesh_getters = {
            "dp": get_dp_mesh,
            "tp": get_tp_mesh,
            "cp": get_cp_mesh,
            "pp": get_pp_mesh,
            "ep": get_ep_mesh,
        }
        get_mesh = mesh_getters.get(self.ranks)
        if get_mesh is None:
            # Validate up front so a typo fails the same way whether or not a mesh exists.
            raise ValueError(f"Unknown rank shortcut '{self.ranks}'")

        world_mesh = get_world_mesh()
        if world_mesh is None:
            log.warning("No world mesh available, falling back to rank 0 only")
            return current_rank == 0

        try:
            return get_mesh(world_mesh).get_local_rank() == 0
        except RuntimeError as e:
            log.warning(
                f"Failed to determine parallel mesh for '{self.ranks}': {e}, falling back to rank 0 only"
            )
            return current_rank == 0

    def pre_train(self):
        """
        Create and start the :mod:`torch.profiler` profiler on the selected ranks.

        Does nothing when profiling is disabled or this rank was not selected.
        """
        if not self.enabled or not self._should_profile_rank():
            return

        from torch.profiler import (
            ProfilerActivity,
            _ExperimentalConfig,
            profile,
            schedule,
        )

        profiling_schedule = schedule(
            wait=self.wait,
            warmup=self.warmup,
            active=self.active,
            repeat=self.repeat,
            skip_first=self.skip_first,
        )

        activities = [ProfilerActivity.CPU]
        if self.trainer.device.type == "cuda":
            activities.append(ProfilerActivity.CUDA)

        experimental_config = None
        if self.enable_cuda_sync_events:
            experimental_config = _ExperimentalConfig(enable_cuda_sync_events=True)

        # The profiler context is entered via an ExitStack so it stays active across
        # callback invocations; whoever closes ``_exit_stack`` stops the profiler.
        self._exit_stack = ExitStack()
        self._profiler = self._exit_stack.enter_context(
            profile(
                activities=activities,
                record_shapes=False,
                profile_memory=self.profile_memory,
                with_stack=self.with_stack,
                schedule=profiling_schedule,
                on_trace_ready=self._on_trace_ready,
                experimental_config=experimental_config,
            )
        )
        self._first_batch = True

    def pre_load_batch(self):
        """
        Advance the profiler schedule by one step before each batch after the first.
        """
        # ``_profiler`` is only set by ``pre_train`` when profiling is enabled *and* this
        # rank was selected. Checking it avoids re-resolving the device mesh (and
        # re-logging any fallback warning) on every batch.
        if not self.enabled or self._profiler is None:
            return

        if self._first_batch:
            # Nothing has run yet, so there is no completed step to mark.
            self._first_batch = False
        else:
            self._profiler.step()

    def _on_trace_ready(self, prof):
        """
        ``on_trace_ready`` handler: log summary tables and save a Chrome trace.

        :param prof: The profiler instance handed over by :mod:`torch.profiler`.
        """
        step = prof.step_num
        # Aggregate once and reuse it for both tables.
        averages = prof.key_averages()

        output = averages.table(sort_by="self_cuda_time_total", row_limit=32)
        log.info(f"Profile by total GPU time at step {step}:\n{output}")
        output = averages.table(sort_by="self_cpu_time_total", row_limit=32)
        log.info(f"Profile by total CPU time at step {step}:\n{output}")

        log.info("Saving chrome trace from profiler...")
        output_dir = self.trainer.work_dir / "profiler"
        output_dir.mkdir(exist_ok=True, parents=True)
        trace_path = output_dir / f"rank-{get_rank()}-step-{step}.chrome_trace.json.gz"
        prof.export_chrome_trace(str(trace_path))
        final_path = self.trainer.persist_working_file(trace_path)
        log.info(f"Chrome trace saved to '{final_path}'")