[docs]classSkipStepOptimizer(Optimizer):""" A :class:`SkipStepOptimizer` is an optimizer that can skip updates when the loss or gradient norm for a step is above a certain threshold of standard deviations computed over a rolling interval. .. important:: When using a :class:`SkipStepOptimizer` you must always set :data:`latest_loss` and :data:`latest_grad_norm` to the current loss and grad norm, respectively, *before* calling :meth:`step()`. The :class:`~olmo_core.train.train_module.TransformerTrainModule` will automatically set the :data:`latest_loss` and :data:`latest_grad_norm` whenever its optimizer is a subclass of :class:`SkipStepOptimizer`. .. tip:: When implementing a :class:`SkipStepOptimizer` you should be careful to avoid host-device syncs. You can use :meth:`get_step_factor()` within your :meth:`step()` method to do this. See the implementation of :class:`SkipStepLion` for an example. """def__init__(self,params:ParamsT,defaults:Dict[str,Any],rolling_interval_length:int=128,sigma_factor:int=6,)->None:super().__init__(params,defaults)self.rolling_interval_length=rolling_interval_lengthself.sigma_factor=sigma_factorself._losses:List[torch.Tensor]=[]self._grad_norms:List[torch.Tensor]=[]self._device:Optional[torch.device]=None@propertydefdevice(self)->torch.device:ifself._deviceisNone:forgroupinself.param_groups:forpingroup["params"]:ifp.numel()>0:self._device=p.devicebreakifself._deviceisNone:self._device=get_default_device()returnself._device@propertydeflatest_loss(self)->Optional[torch.Tensor]:ifnotself._losses:returnNoneelse:returnself._losses[-1]@latest_loss.setterdeflatest_loss(self,loss:torch.Tensor):self._losses.append(loss)whilelen(self._losses)>self.rolling_interval_length+1:self._losses.pop(0)@propertydeflatest_grad_norm(self)->Optional[torch.Tensor]:ifnotself._grad_norms:returnNoneelse:returnself._grad_norms[-1]@latest_grad_norm.setterdeflatest_grad_norm(self,grad_norm:torch.Tensor):self._grad_norms.append(grad_norm)whilelen(self._grad_norms)>self.rolling_interval_length+1:self._grad_norms.pop(0)
[docs]@torch._dynamo.disable()defget_step_factor(self)->torch.Tensor:""" Returns a float tensor which will be `1.0` if the optimizer should proceed with the step and `0.0` if the optimizer should skip the step. The tensor can be used within the optimizer's step computation to essentially skip a step without a host-device sync. """iflen(self._losses)<max(2,self.rolling_interval_length//2):returnmove_to_device(torch.tensor(1.0),self.device)loss_std,loss_mean=torch.std_mean(torch.stack(self._losses[:-1]))assertself.latest_lossisnotNoneifself._grad_norms:assertself.latest_grad_normisnotNonegrad_norm_std,grad_norm_mean=torch.std_mean(torch.stack(self._grad_norms[:-1]))step_factor=torch.logical_and((self.latest_loss-loss_mean)<=self.sigma_factor*loss_std,(self.latest_grad_norm-grad_norm_mean)<=self.sigma_factor*grad_norm_std,)else:step_factor=(self.latest_loss-loss_mean)<=self.sigma_factor*loss_stdreturnstep_factor.float()
@propertydefstep_skipped(self)->torch.Tensor:""" Returns a float tensor which will be `1.0` if the step was skipped and `0.0` otherwise. """return1-self.get_step_factor()