[docs]classCheckpointRemovalStrategy(StrEnum):""" An enumeration of the different strategies for removing old checkpoints found in the save folder. """ephemeral_only="ephemeral_only"""" Only remove checkpoints that were saved at the :data:`CheckpointerCallback.ephemeral_save_interval`. """all_non_permanent="all_non_permanent"""" Remove all non-permanent checkpoints found, including ephemeral checkpoints and also any other checkpoints that were not saved at the :data:`CheckpointerCallback.save_interval`. """never="never"""" Never remove any old checkpoints found in the save folder. """
@dataclass
class CheckpointerCallback(Callback):
    """
    Drives checkpointing over the course of a training run.

    Checkpoints are written on the schedules given by :data:`save_interval`,
    :data:`ephemeral_save_interval` and :data:`fixed_steps`, and old checkpoints found in
    the save folder are cleaned up according to the :data:`remove` setting.

    .. important::
        This callback gets added automatically if you don't explicitly configure it.
        If you want to override this callback you should subclass it.
    """

    priority: ClassVar[int] = 1

    save_interval: Optional[int] = 250
    """
    How often, in steps, a permanent checkpoint is saved.
    """

    ephemeral_save_interval: Optional[int] = None
    """
    How often, in steps, a temporary ("ephemeral") checkpoint is saved. Only the most recent
    ephemeral checkpoint is kept; older ones are scheduled for removal. A relatively frequent
    interval can be useful for preemptible jobs. Must be smaller than :data:`save_interval`.
    """

    pre_train_checkpoint: Optional[bool] = None
    """
    Whether to save a checkpoint before training starts (only considered at step 0).
    When left as ``None`` one is saved unless the trainer loaded a checkpoint.
    """

    save_async: Optional[bool] = None
    """
    Whether to save checkpoints asynchronously, which requires a CPU-capable backend.
    When left as ``None`` it is enabled if such a backend is available.
    """

    remove: CheckpointRemovalStrategy = CheckpointRemovalStrategy.ephemeral_only
    """
    The policy for removing old checkpoints found in the save folder.
    """

    ephemeral_cooldown: Optional[int] = None
    """
    The minimum number of steps that must have passed since the latest checkpoint before
    another ephemeral checkpoint may be saved.
    """

    fixed_steps: Optional[List[int]] = None
    """
    Specific steps at which an additional permanent checkpoint is saved.
    """

    enabled: bool = True

    # Bookkeeping

    # NOTE: '_future' holds an 'Optional[Future]' but is deliberately left unannotated
    # because omegaconf doesn't like the annotation.
    _future = None
    _latest_checkpoint_step: int = -1
    _latest_checkpoint_path: str = ""
    _checkpoints: List[str] = field(default_factory=list)
    _ephemeral_checkpoints: List[str] = field(default_factory=list)
    _checkpoints_to_remove: List[str] = field(default_factory=list)

    def __post_init__(self):
        """
        Validate the configured intervals.

        :raises OLMoConfigurationError: If an interval is smaller than 1, or if the
            ephemeral interval is not strictly smaller than the permanent one.
        """
        permanent_every = self.save_interval
        ephemeral_every = self.ephemeral_save_interval

        if permanent_every is not None and permanent_every < 1:
            raise OLMoConfigurationError("'save_interval' must be at least 1")

        if ephemeral_every is None:
            return

        if ephemeral_every < 1:
            raise OLMoConfigurationError("'ephemeral_save_interval' must be at least 1")
        if permanent_every is not None and ephemeral_every >= permanent_every:
            raise OLMoConfigurationError(
                "'ephemeral_save_interval' must be less than 'save_interval'"
            )

    @property
    def checkpointer(self) -> Checkpointer:
        """The trainer's checkpointer."""
        return self.trainer.checkpointer

    @property
    def save_folder(self) -> str:
        """The trainer's save folder."""
        return self.trainer.save_folder

    @property
    def checkpoint_pending(self) -> bool:
        """``True`` while an asynchronous save is still being tracked."""
        return self._future is not None

    def _is_permanent_step(self, step_num: int) -> bool:
        """Whether ``step_num`` is one of the fixed steps or a multiple of the save interval."""
        if self.fixed_steps is not None and step_num in self.fixed_steps:
            return True
        return self.save_interval is not None and step_num % self.save_interval == 0

    def _is_removable_step(self, step_num: int) -> bool:
        """
        Whether a checkpoint found on disk for ``step_num`` qualifies for removal under the
        current :data:`remove` strategy.
        """
        # Never touch the step-0 checkpoint, anything ahead of the current step,
        # or anything that sits on the permanent schedule.
        if step_num == 0 or step_num > self.step:
            return False
        if self._is_permanent_step(step_num):
            return False

        if self.remove == CheckpointRemovalStrategy.all_non_permanent:
            return True
        return (
            self.remove == CheckpointRemovalStrategy.ephemeral_only
            and self.ephemeral_save_interval is not None
            and step_num % self.ephemeral_save_interval == 0
        )

    def _await_last_checkpoint(self, blocking: bool = True) -> Optional[Future]:
        """
        Collect the result of the tracked asynchronous save, if there is one.

        :param blocking: If ``False``, only collect the result when the save has already
            finished instead of waiting for it.

        :returns: The future that was collected, or ``None`` if nothing was collected.
        """
        pending = self._future
        if pending is None:
            return None
        if not blocking and not pending.done():
            # Still running and the caller doesn't want to wait.
            return None

        # Wait for last async checkpoint to finish.
        pending.result()
        if get_rank() == 0:
            # Just to be safe, make sure the checkpointer has finalized the checkpoint.
            wait_for(
                lambda: self.checkpointer.dir_is_checkpoint(self._latest_checkpoint_path),
                "waiting to finalize checkpoint",
            )
        self._future = None
        return pending

    def _save_checkpoint(self, save_async: Optional[bool] = None, ephemeral: bool = False) -> str:
        """
        Save a checkpoint through the trainer and record it as the latest one.

        :param save_async: Overrides :data:`save_async` for this save when not ``None``.
        :param ephemeral: Forwarded to the trainer's save method.

        :returns: The path of the checkpoint as a string.
        """
        if save_async is None:
            save_async = self.save_async

        # Any previous async save has to be done before a new one starts.
        self._await_last_checkpoint()

        if save_async:
            path, self._future = self.trainer.save_checkpoint_async(ephemeral=ephemeral)
        else:
            path = self.trainer.save_checkpoint(ephemeral=ephemeral)

        self._latest_checkpoint_step = self.step
        self._latest_checkpoint_path = str(path)
        return str(path)

    def _remove_checkpoint(self, path: str):
        """
        Delete the checkpoint directory at ``path``.

        The metadata file is removed first (from rank 0), then the directory is cleared
        through a trainer bookkeeping op: from rank 0 for URLs, otherwise from the rank
        for which ``get_fs_local_rank()`` is 0.
        """
        log.info(f"Removing old checkpoint at '{path}'...")

        # Remove metadata file first to invalidate the checkpoint.
        if get_rank() == 0:
            metadata_path = join_path(path, self.trainer.checkpointer.METADATA_FNAME)
            try:
                remove_file(metadata_path)
            except FileNotFoundError:
                pass

        op_name = f"clear_directory {path}"
        if not is_url(path):
            if get_fs_local_rank() == 0:
                self.trainer.run_bookkeeping_op(
                    clear_directory, path, op_name=op_name, distributed=False
                )
            return

        if get_rank() == 0:
            self.trainer.run_bookkeeping_op(
                clear_directory,
                path,
                op_name=op_name,
                distributed=False,
                soft_timeout=180,  # this can take a while on GCS, for example
            )

    def _schedule_for_removal(self, path: str):
        """Queue ``path`` to be deleted by the next :meth:`_remove_old_checkpoints` call."""
        self._checkpoints_to_remove.append(path)

    def _remove_old_checkpoints(self):
        """Delete every queued checkpoint, then empty the queue."""
        queue = self._checkpoints_to_remove
        for stale_path in queue:
            self._remove_checkpoint(stale_path)
        queue.clear()

    def pre_train(self):
        """
        Prepare for checkpointing: resolve :data:`save_async`, set up a process group for
        async checkpointing when needed, optionally save a pre-train checkpoint, and collect
        removable checkpoints already present in the save folder.

        :raises RuntimeError: If async checkpointing is requested in a distributed run
            without a CPU-capable backend.
        """
        if not self.enabled:
            return

        if self.save_async is None:
            self.save_async = backend_supports_cpu()

        # Maybe create a new process group for async checkpointing.
        needs_group = (
            is_distributed() and self.save_async and self.checkpointer.process_group is None
        )
        if needs_group:
            if not backend_supports_cpu():
                raise RuntimeError("a CPU-capable backend is required for async checkpointing")
            log.info(
                "Creating new process group for checkpointing (needed for async checkpointing)"
            )
            self.checkpointer.process_group = dist.new_group(timeout=timedelta(minutes=30))

        # Maybe save a pre-train checkpoint.
        if self.step == 0:
            save_initial = self.pre_train_checkpoint
            if save_initial is None:
                save_initial = not self.trainer.checkpoint_loaded
            if save_initial:
                self._checkpoints.append(self._save_checkpoint())

        if self.remove == CheckpointRemovalStrategy.never:
            return

        # Collect existing ephemeral checkpoints from previous runs.
        found: List[Tuple[int, str]] = []

        # Only search from rank 0 to avoid hammering remote file stores with requests.
        if get_rank() == 0:
            try:
                for step_num, path in self.checkpointer.find_checkpoints(self.save_folder):
                    if self._is_removable_step(step_num):
                        found.append((step_num, path))
            except FileNotFoundError:
                pass

        # Every rank takes part in the broadcast so that all of them end up with rank 0's list.
        found = broadcast_object(found)

        # TODO: handle this if we ever restore callback state.
        assert not self._ephemeral_checkpoints
        by_step = sorted(found, key=lambda entry: entry[0])
        self._ephemeral_checkpoints = [path for _, path in by_step]
        for path in self._ephemeral_checkpoints:
            log.info(
                f"Found existing ephemeral checkpoint at '{path}' which will "
                "be removed when the next checkpoint is saved"
            )

    def post_train_batch(self):
        """
        Run after each training batch: finish bookkeeping for a completed async save, remove
        queued checkpoints when no save is pending, and save a permanent or ephemeral
        checkpoint if the current step calls for one.
        """
        if not self.enabled:
            return

        self._await_last_checkpoint(blocking=False)
        if not self.checkpoint_pending:
            self._remove_old_checkpoints()

        step = self.step

        # Save permanent checkpoint (fixed steps take precedence, then the save interval).
        if self._is_permanent_step(step):
            self._checkpoints.append(self._save_checkpoint())
            return

        ephemeral_due = (
            self.ephemeral_save_interval is not None
            and step % self.ephemeral_save_interval == 0
            and not self.trainer.block_ephemeral_checkpoints
        )
        if not ephemeral_due:
            return

        # Maybe save ephemeral checkpoint: skip it while still inside the cooldown window.
        cooldown = self.ephemeral_cooldown
        if cooldown is not None and (step - self._latest_checkpoint_step) < cooldown:
            return

        history = self._ephemeral_checkpoints
        history.append(self._save_checkpoint(ephemeral=True))

        # Remove old ephemeral checkpoints, keeping only the newest one.
        # NOTE(review): assumed to belong to the ephemeral-save branch -- confirm against
        # the original indentation.
        while len(history) > 1:
            self._schedule_for_removal(history.pop(0))

    def post_train(self):
        """
        Run once training ends: synchronously save a final checkpoint if the current step is
        past the latest checkpoint, then wait for any outstanding async save.
        """
        if not self.enabled:
            return

        if self.step > self._latest_checkpoint_step:
            self._checkpoints.append(self._save_checkpoint(save_async=False))
        self._await_last_checkpoint()