[docs]@dataclassclassCallback(Stateful):""" Trainer callback base class. Callbacks can be used to modify and extend the behavior of the trainer loop. This module contains a number of useful :class:`Callback` implementations, but you can always add your own. """priority:ClassVar[int]=0""" Priority of the callback. Determines the order in which callbacks run relative to each other. The higher the priority, the earlier a callback runs. """# NOTE: omegaconf can't use this annotation# _trainer: Optional["Trainer"] = None_trainer=None@propertydeftrainer(self)->"Trainer":assertself._trainerisnotNonereturnself._trainer@trainer.setterdeftrainer(self,trainer:"Trainer"):self._trainer=trainer@propertydefstep(self)->int:returnself.trainer.global_step
[docs]defstate_dict(self)->Dict[str,Any]:""" Get the state dict to save. """return{}
[docs]defload_state_dict(self,state_dict:Dict[str,Any]):""" Load a state dict. """delstate_dict
[docs]defblock_ephemeral_checkpoints(self):"""Register this callback as blocking ephemeral checkpoint saves. Ephemeral saves are blocked as long as at least one callback is registered."""name=self.trainer.get_callback_name(self)self.trainer._blocking_ephemeral_checkpoints.add(name)
[docs]defunblock_ephemeral_checkpoints(self):"""Unregister this callback from blocking ephemeral checkpoint saves."""name=self.trainer.get_callback_name(self)ifnameinself.trainer._blocking_ephemeral_checkpoints:self.trainer._blocking_ephemeral_checkpoints.remove(name)
[docs]defpost_attach(self):""" Called right after the callback is attached to the :class:`~olmo_core.train.Trainer`. """pass
[docs]defpost_checkpoint_loaded(self,path:PathOrStr):""" Called when a checkpoint is successfully loaded. :param path: The path/URL to the checkpoint. """delpath
[docs]defpre_train(self):""" Runs before the training loop starts. """pass
[docs]defpre_epoch(self):""" Runs before the start of a new epoch. """pass
[docs]defpre_load_batch(self):""" Runs right before the next batch is fetched from the data loader. """pass
[docs]defpre_step(self,batch:Dict[str,Any]):""" Runs right before a training batch is processed. """delbatch
[docs]defpre_optim_step(self):""" Runs right after the forward-backward passes, right before the optimizer step. """pass
[docs]defpost_train_batch(self):""" Runs after a training batch is processed. """pass
[docs]defpost_step(self):""" Runs after a complete step (potentially including evals and checkpointing). """pass
[docs]defpost_checkpoint_saved(self,path:PathOrStr):""" Called when a checkpoint is successfully saved. :param path: The path/URL to the checkpoint. """delpath
[docs]defpre_log_metrics(self,step:int,metrics:Dict[str,float]):""" Called when metrics have been gathered for a given step (possibly a previous step), but right before :meth:`log_metrics()`. This can used to modify, add, or remove metrics by updating the ``metrics`` dict in-place. """delstep,metrics
[docs]deflog_metrics(self,step:int,metrics:Dict[str,float]):""" Called when metrics have been gathered for a given step (possibly a previous step). """delstep,metrics
[docs]defpost_epoch(self):""" Runs at the end of a complete epoch. """pass
[docs]defpost_train(self):""" Runs after the training loop successfully completes. """pass
[docs]defon_error(self,exc:BaseException):""" Called when the training loop exits with an error. """delexc
[docs]defclose(self):""" Always called right before `Trainer.fit()` exits, even on an error. """pass
@dataclass
class CallbackConfig(Callback, Config):
    """
    Config-style counterpart to :class:`Callback`.

    Use this when the callback class itself can't be serialized: the config is what gets
    stored, and :meth:`build` produces the real callback from it.
    """

    @abstractmethod
    def build(self, trainer: "Trainer") -> Optional[Callback]:
        """
        Construct the concrete :class:`Callback` described by this config.

        :param trainer: The trainer passed in by the caller.
        :returns: The built callback. The return annotation also allows ``None``.
        :raises NotImplementedError: Always, in this base class; subclasses must override it.
        """
        del trainer
        raise NotImplementedError