class NoOpOptimizer(SkipStepOptimizer):
    """
    An optimizer that never modifies the model parameters.

    The step-skipping machinery inherited from :class:`SkipStepOptimizer` stays active.
    :meth:`step` still asks the base class for a step factor and records ``1 - factor``
    as :attr:`step_skipped`. For every parameter that has a gradient, it also keeps a
    per-parameter ``"step"`` counter in :attr:`state`.

    Parameter values are never changed, which makes this useful for gathering training
    statistics without altering the model.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        rolling_interval_length: int = 128,
        sigma_factor: int = 6,
    ) -> None:
        """
        :param params: The parameters (or parameter groups) to register.
        :param lr: Stored in the param-group defaults only; it is never used to update
            anything.
        :param rolling_interval_length: Forwarded unchanged to :class:`SkipStepOptimizer`.
        :param sigma_factor: Forwarded unchanged to :class:`SkipStepOptimizer`.
        """
        super().__init__(
            params,
            {"lr": lr},
            rolling_interval_length=rolling_interval_length,
            sigma_factor=sigma_factor,
        )
        # Filled in by ``step()``; stays ``None`` until the first step has run.
        self._step_skipped: Optional[torch.Tensor] = None

    @property
    def step_skipped(self) -> torch.Tensor:
        """
        Return ``1 - step_factor`` from the most recent :meth:`step`.

        If :meth:`step` has not run yet, a new ``torch.tensor(0.0)`` is returned instead.
        """
        skipped = self._step_skipped
        if skipped is None:
            return torch.tensor(0.0)
        return skipped

    @torch.no_grad()
    def step(self, closure=None) -> None:
        """
        Run the step-skipping bookkeeping without updating any parameter.

        :param closure: Optional callable, evaluated with autograd enabled. Its return
            value is discarded.
        """
        if closure is not None:
            # ``step`` itself runs under ``no_grad``, so autograd is re-enabled
            # for the closure.
            with torch.enable_grad():
                closure()

        factor = self.get_step_factor()
        self._step_skipped = 1 - factor

        # Only parameters that currently hold a gradient get (or advance) a step counter.
        params_with_grad = (
            param
            for group in self.param_groups
            for param in group["params"]
            if param.grad is not None
        )
        for param in params_with_grad:
            param_state = self.state[param]
            if not param_state:
                param_state["step"] = torch.zeros(
                    (), dtype=torch.float32, device=param.device
                )
            # The counter advances by the step factor rather than by 1, so a
            # factor of 0 leaves it unchanged.
            param_state["step"] += factor
@OptimConfig.register("noop")
@dataclass
class NoOpConfig(OptimConfig[NoOpOptimizer]):
    """
    Config for building a :class:`NoOpOptimizer`.

    That optimizer leaves parameters untouched but still runs the step-skipping logic,
    so training statistics can be collected.
    """

    lr: float = 1e-3
    """Kept for interface compatibility; the optimizer never applies updates with it."""

    rolling_interval_length: int = 128
    """
    Size of the rolling window used to compute the mean and standard deviation of the
    loss and the gradient norm.
    """

    sigma_factor: int = 6
    """
    Number of standard deviations above the rolling mean loss/grad norm at which a
    step is skipped.
    """