"""Source code for ``olmo_core.optim.scheduler``: learning-rate schedulers."""

import logging
import warnings
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from math import cos, pi, sqrt
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
import torch

from ..config import Config, Registrable, StrEnum
from ..exceptions import OLMoConfigurationError
from .config import INITIAL_LR_FIELD, LR_FIELD

if TYPE_CHECKING:
    from olmo_core.train import Trainer

log = logging.getLogger(__name__)


class SchedulerUnits(StrEnum):
    """
    Units in which a :class:`Scheduler` measures training progress.
    """

    # Progress counted in optimizer steps.
    steps = "steps"
    # Progress counted in training tokens seen.
    tokens = "tokens"
@dataclass
class Scheduler(Config, Registrable, metaclass=ABCMeta):
    """
    Learning rate scheduler base class.

    Subclasses implement :meth:`get_lr`, a mapping from an initial (peak) learning rate and
    the current progress to a learning rate. :meth:`set_lr` applies that mapping to an
    optimizer param group using a trainer's state.
    """

    # Param-group key holding the learning rate that is overwritten on every update.
    lr_field: str = LR_FIELD
    # Param-group key holding the initial/peak learning rate the schedule is computed from.
    initial_lr_field: str = INITIAL_LR_FIELD
    # Whether progress is measured in optimizer steps or in training tokens.
    units: SchedulerUnits = SchedulerUnits.steps

    @abstractmethod
    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Get the learning rate given the initial/max learning rate, current step/token count,
        and the maximum number of steps/tokens.
        """
        raise NotImplementedError

    def set_lr(self, group: Dict[str, Any], trainer: "Trainer") -> Union[float, torch.Tensor]:
        """
        Set the learning rate on an optimizer param group given a trainer's state.

        If ``initial_lr_field`` is missing (or ``None``) it is initialized from the current
        value of ``lr_field``. A tensor learning rate is updated in place.

        :param group: The optimizer param group to update.
        :param trainer: Supplies the progress counter (``global_step`` or
            ``global_train_tokens_seen``) and the horizon (``max_steps`` or ``max_tokens``).

        :returns: The new learning rate.

        :raises RuntimeError: If ``group`` has neither a current LR nor a usable initial LR.
        :raises OLMoConfigurationError: If the trainer's horizon for :data:`units` is unknown.
        """
        lr_field = self.lr_field
        initial_lr_field = self.initial_lr_field

        # Fail with a descriptive error (rather than a bare KeyError further down) when
        # there is nothing to schedule from: no current LR and no usable initial LR.
        if lr_field not in group and group.get(initial_lr_field) is None:
            group_fields_list = "\n - ".join(
                [f"{k}: {v}" for k, v in group.items() if k != "params"]
            )
            raise RuntimeError(
                f"learning rate field '{lr_field}' and initial learning rate field "
                f"'{initial_lr_field}' not found in optimizer param group "
                f"with {len(group.get('params', ()))} parameter(s):\n"
                f" - {group_fields_list}"
            )

        # Ensure 'initial_lr' is set.
        if group.get(initial_lr_field) is None:
            initial_lr = group[lr_field]
            # A tensor LR is updated in place with ``fill_`` below. Storing that same tensor
            # object as the initial LR would make every update also overwrite the reference
            # LR the schedule is computed from, so keep an independent copy.
            if isinstance(initial_lr, torch.Tensor):
                initial_lr = initial_lr.clone()
            group[initial_lr_field] = initial_lr

        # Set new LR.
        if self.units == SchedulerUnits.steps:
            if trainer.max_steps is None:
                raise OLMoConfigurationError(
                    "'max_steps' must be known in the trainer for step-based scheduling."
                )
            new_lr = self.get_lr(
                group[initial_lr_field],
                trainer.global_step,
                trainer.max_steps,
            )
        elif self.units == SchedulerUnits.tokens:
            if trainer.max_tokens is None:
                raise OLMoConfigurationError(
                    "'max_tokens' must be known in the trainer for token-based scheduling."
                )
            new_lr = self.get_lr(
                group[initial_lr_field],
                trainer.global_train_tokens_seen,
                trainer.max_tokens,
            )
        else:
            raise NotImplementedError(self.units)

        # Keep the same tensor object in the group when the LR is a tensor (presumably so
        # anything holding a reference to it sees the update); otherwise store the value.
        current_lr = group.get(lr_field)
        if isinstance(current_lr, torch.Tensor):
            current_lr.fill_(new_lr)
        else:
            group[lr_field] = new_lr

        return new_lr
@Scheduler.register("constant")
@dataclass
class ConstantScheduler(Scheduler):
    """
    Constant learning rate schedule, basically a no-op.

    The learning rate is always the initial learning rate, regardless of progress.
    """

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        # Progress and horizon are irrelevant for a constant schedule.
        _ = (current, t_max)
        return initial_lr
@Scheduler.register("constant_with_warmup")
@dataclass
class ConstantWithWarmup(Scheduler):
    """
    Constant learning rate schedule with a warmup.

    The LR ramps linearly from ``warmup_min_lr`` to the initial LR over the warmup and then
    stays at the initial LR. Exactly one of ``warmup`` (absolute length) or
    ``warmup_fraction`` (fraction of ``t_max``) must be given.
    """

    # Warmup length in the scheduler's units.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of ``t_max``, in [0, 1].
    warmup_fraction: Optional[float] = None
    # LR at the very start of the warmup.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        cls_name = self.__class__.__name__

        # Migrate the deprecated alias before validating.
        if self.warmup_steps is not None and self.warmup is None:
            self.warmup, self.warmup_steps = self.warmup_steps, None
            warnings.warn(
                f"'{cls_name}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )

        has_exactly_one = (self.warmup is None) != (self.warmup_fraction is None)
        if not has_exactly_one:
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")

        frac = self.warmup_fraction
        if frac is not None and (frac < 0 or frac > 1):
            raise OLMoConfigurationError("'warmup_fraction' must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        if self.warmup is not None:
            warmup = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup = round(t_max * self.warmup_fraction)

        # Past the warmup the schedule is flat.
        if current > warmup:
            return initial_lr
        return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)
@Scheduler.register("wsd")
@dataclass
class WSD(Scheduler):
    """
    Warmup-stable-decay scheduler.

    The LR warms up linearly from ``warmup_min_lr`` to the initial LR, holds it constant,
    and then decays linearly to ``decay_min_lr`` over the final ``decay`` steps/tokens.
    Exactly one of ``warmup``/``warmup_fraction`` and exactly one of
    ``decay``/``decay_fraction`` must be given. Because ``decay_fraction`` defaults to
    ``0.1``, passing ``decay`` also requires ``decay_fraction=None``.
    """

    # Warmup length in the scheduler's units.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of ``t_max``.
    warmup_fraction: Optional[float] = None
    # Decay length in the scheduler's units.
    decay: Optional[int] = None
    decay_steps: Optional[int] = None  # deprecated, use 'decay' instead.
    # Decay length as a fraction of ``t_max``.
    decay_fraction: Optional[float] = 0.1
    # LR at the very start of the warmup.
    warmup_min_lr: float = 0.0
    # LR reached at the end of the decay.
    decay_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        cls_name = self.__class__.__name__

        # --- warmup ---
        if self.warmup_steps is not None and self.warmup is None:
            self.warmup, self.warmup_steps = self.warmup_steps, None
            warnings.warn(
                f"'{cls_name}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if not ((self.warmup_fraction is None) != (self.warmup is None)):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        w_frac = self.warmup_fraction
        if w_frac is not None and (w_frac < 0 or w_frac > 1):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

        # --- decay ---
        if self.decay_steps is not None and self.decay is None:
            self.decay, self.decay_steps = self.decay_steps, None
            warnings.warn(
                f"'{cls_name}.decay_steps' is deprecated, please use '.decay' instead.",
                DeprecationWarning,
            )
        if not ((self.decay_fraction is None) != (self.decay is None)):
            raise OLMoConfigurationError(
                "Either 'decay_fraction' or 'decay' must be specified. Never both."
            )
        d_frac = self.decay_fraction
        if d_frac is not None and (d_frac < 0 or d_frac > 1):
            raise OLMoConfigurationError("decay_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        if self.warmup is not None:
            warmup = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup = round(t_max * self.warmup_fraction)
        if current <= warmup:
            return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)

        if self.decay is not None:
            decay = self.decay
        else:
            assert self.decay_fraction is not None
            decay = round(t_max * self.decay_fraction)

        # Stable phase until the decay window begins.
        decay_start = t_max - decay
        if current < decay_start:
            return initial_lr
        return _linear_decay(initial_lr, t_max - current, decay, self.decay_min_lr)
@Scheduler.register("power_lr")
@dataclass
class PowerLR(Scheduler):
    """
    Power learning-rate schedule with

    1. **Linear warm-up** to a reference peak LR (``initial_lr``) during the first ``warmup``
       steps/tokens.
    2. **Power phase** where the LR decays following a power law
       ``lr = initial_lr * (current / warmup) ** b``. This makes the LR independent of the
       eventual training horizon.
    3. **Optional linear decay tail** during the last ``decay`` steps/tokens to smoothly
       anneal to ``decay_min_lr``.

    Notes
    -----
    * ``b`` must be *negative* (e.g. -0.51); its magnitude controls how quickly the LR
      decays in the power phase.
    * Exactly one of ``warmup``/``warmup_fraction`` and exactly one of
      ``decay``/``decay_fraction`` must be specified; otherwise an
      ``OLMoConfigurationError`` is raised, mirroring the other schedulers in this file.
      Because ``decay_fraction`` defaults to ``0.1``, passing ``decay`` also requires
      ``decay_fraction=None``.
    * A zero-length warm-up anchors the power law at 1 instead of 0, and a zero-length
      decay disables the tail.
    """

    b: float = -0.51  # power-law exponent (negative)
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated alias
    warmup_fraction: Optional[float] = None
    warmup_min_lr: float = 0.0
    decay: Optional[int] = None
    decay_steps: Optional[int] = None  # deprecated alias
    decay_fraction: Optional[float] = 0.1
    decay_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        # --- handle deprecated aliases ---------------------------------------------
        if self.warmup is None and self.warmup_steps is not None:
            self.warmup = self.warmup_steps
            self.warmup_steps = None
            warnings.warn(
                f"'{self.__class__.__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if self.decay is None and self.decay_steps is not None:
            self.decay = self.decay_steps
            self.decay_steps = None
            warnings.warn(
                f"'{self.__class__.__name__}.decay_steps' is deprecated, please use '.decay' instead.",
                DeprecationWarning,
            )

        # --- sanity checks -----------------------------------------------------------
        if (self.warmup_fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if self.warmup_fraction is not None and not (0 <= self.warmup_fraction <= 1):
            raise OLMoConfigurationError("'warmup_fraction' must be between 0 and 1.")
        if (self.decay_fraction is None) == (self.decay is None):
            raise OLMoConfigurationError(
                "Either 'decay_fraction' or 'decay' must be specified. Never both."
            )
        if self.decay_fraction is not None and not (0 <= self.decay_fraction <= 1):
            raise OLMoConfigurationError("'decay_fraction' must be between 0 and 1.")
        if self.b >= 0:
            raise OLMoConfigurationError("'b' must be negative for a decaying power‑law.")

    # -------------------------------------------------------------------------
    # Core scheduling logic
    # -------------------------------------------------------------------------

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Compute the learning rate for the given *current* step/token count.

        *Linear warm-up*:
            lr = warmup_min_lr + (initial_lr - warmup_min_lr) * current / warmup
        *Power phase*:
            lr = initial_lr * (current / anchor) ** b, with ``anchor = max(warmup, 1)``
        *Linear decay tail* (last ``decay`` steps/tokens):
            lr is linearly annealed from the power-phase value at the start of the tail to
            ``decay_min_lr``.
        """
        # --- warm-up and decay extents ---------------------------------------------
        if self.warmup is not None:
            warmup = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup = round(t_max * self.warmup_fraction)

        if self.decay is not None:
            decay = self.decay
        else:
            assert self.decay_fraction is not None
            decay = round(t_max * self.decay_fraction)

        # --- phase 1: warm-up (skipped when empty, which would divide by zero) -------
        if warmup > 0 and current <= warmup:
            return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)

        # The power law is anchored at the end of warm-up. Use at least 1 so an empty
        # warm-up never divides by zero.
        anchor = max(warmup, 1)

        # --- phase 3: linear decay tail (skipped when empty) -------------------------
        tail_start = t_max - decay
        if decay > 0 and current >= tail_start:
            # LR at the beginning of the tail. Clamping to the anchor keeps the base
            # positive (a base <= 0 raised to a negative power errors or yields a complex
            # number) and keeps the value <= initial_lr if the tail overlaps warm-up.
            lr_start_tail = initial_lr * (max(tail_start, anchor) / anchor) ** self.b
            return _linear_decay(lr_start_tail, t_max - current, decay, self.decay_min_lr)

        # --- phase 2: power-law region -----------------------------------------------
        # max() only matters for an empty warm-up at current == 0 (0 ** b is undefined).
        return initial_lr * (max(current, anchor) / anchor) ** self.b
@Scheduler.register("linear_with_warmup")
@dataclass
class LinearWithWarmup(Scheduler):
    """
    Linear learning rate schedule with a warmup.

    After the warmup, the LR falls linearly from the initial LR to ``alpha_f * initial_lr``
    at ``t_max`` (``self.t_max`` if set, otherwise the trainer's horizon) and stays there.
    """

    # Final LR as a fraction of the initial LR.
    alpha_f: float = 0.1
    # Optional override for the horizon passed to :meth:`get_lr`.
    t_max: Optional[int] = None
    # Warmup length in the scheduler's units.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of the horizon.
    warmup_fraction: Optional[float] = None
    # LR at the very start of the warmup.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        if self.warmup_steps is not None and self.warmup is None:
            self.warmup, self.warmup_steps = self.warmup_steps, None
            warnings.warn(
                f"'{self.__class__.__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if not ((self.warmup_fraction is None) != (self.warmup is None)):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        frac = self.warmup_fraction
        if frac is not None and (frac < 0 or frac > 1):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        horizon = self.t_max if self.t_max is not None else t_max
        eta_min = initial_lr * self.alpha_f

        if self.warmup is not None:
            warmup = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup = round(horizon * self.warmup_fraction)

        if current < warmup:
            return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)
        if current >= horizon:
            return eta_min

        # Fraction of the post-warmup span that has elapsed.
        progress = (current - warmup) / (horizon - warmup)
        return initial_lr - (initial_lr - eta_min) * progress
@Scheduler.register("inv_sqrt_with_warmup")
@dataclass
class InvSqrtWithWarmup(Scheduler):
    """
    Inverse square root learning rate (LR) schedule with a warmup.

    After the warmup the LR is
    ``eta_min + (initial_lr - eta_min) * sqrt(warmup / current)`` with
    ``eta_min = alpha_f * initial_lr``.
    """

    # Asymptotic LR as a fraction of the initial LR.
    alpha_f: float = 0.1
    # Warmup length in the scheduler's units.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of ``t_max``.
    warmup_fraction: Optional[float] = None
    # LR at the very start of the warmup.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        if self.warmup is None and self.warmup_steps is not None:
            self.warmup = self.warmup_steps
            self.warmup_steps = None
            warnings.warn(
                f"'{self.__class__.__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if (self.warmup_fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if self.warmup_fraction is not None and (
            self.warmup_fraction < 0 or self.warmup_fraction > 1
        ):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        if self.warmup is None:
            assert self.warmup_fraction is not None
            warmup = round(t_max * self.warmup_fraction)
        else:
            warmup = self.warmup
        if current < warmup:
            return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)
        if current <= 0:
            # Only reachable with an empty warmup (warmup <= 0, e.g. warmup_fraction=0.0),
            # where ``warmup / current`` would be 0/0. Treat this as the end of the warmup.
            return initial_lr
        eta_min = initial_lr * self.alpha_f
        return eta_min + (initial_lr - eta_min) * sqrt(warmup / current)
@Scheduler.register("cos_with_warmup")
@dataclass
class CosWithWarmup(Scheduler):
    """
    Cosine learning rate schedule with a warmup.

    After the warmup the LR follows half a cosine period from the initial LR down to
    ``alpha_f * initial_lr`` at ``t_max`` (``self.t_max`` if set, otherwise the trainer's
    horizon), then stays there.
    """

    # Warmup length in the scheduler's units.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of the horizon.
    warmup_fraction: Optional[float] = None
    # Final LR as a fraction of the initial LR.
    alpha_f: float = 0.1
    # Optional override for the horizon passed to :meth:`get_lr`.
    t_max: Optional[int] = None
    # LR at the very start of the warmup.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        if self.warmup_steps is not None and self.warmup is None:
            self.warmup, self.warmup_steps = self.warmup_steps, None
            warnings.warn(
                f"'{self.__class__.__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if not ((self.warmup_fraction is None) != (self.warmup is None)):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        frac = self.warmup_fraction
        if frac is not None and (frac < 0 or frac > 1):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        horizon = self.t_max if self.t_max is not None else t_max
        eta_min = initial_lr * self.alpha_f

        if self.warmup is not None:
            warmup = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup = round(horizon * self.warmup_fraction)

        if current < warmup:
            return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)
        if current >= horizon:
            return eta_min

        elapsed = current - warmup
        span = horizon - warmup
        return eta_min + (initial_lr - eta_min) * (1 + cos(pi * elapsed / span)) / 2
@Scheduler.register("half_cos_with_warmup")
@dataclass
class HalfCosWithWarmup(Scheduler):
    """
    Second half of a cosine learning rate schedule, with a warmup before that.

    Note: This assumes that the peak LR set is for the full cosine schedule. The warmup
    therefore ramps to the cosine midpoint ``eta_min + (initial_lr - eta_min) / 2``, and
    the post-warmup span is mapped onto the second half of a cosine period.
    """

    # Warmup length in the scheduler's units.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of the horizon.
    warmup_fraction: Optional[float] = None
    # Final LR as a fraction of the initial LR.
    alpha_f: float = 0.1
    # Optional override for the horizon passed to :meth:`get_lr`.
    t_max: Optional[int] = None
    # LR at the very start of the warmup.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        if self.warmup_steps is not None and self.warmup is None:
            self.warmup, self.warmup_steps = self.warmup_steps, None
            warnings.warn(
                f"'{self.__class__.__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if not ((self.warmup_fraction is None) != (self.warmup is None)):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        frac = self.warmup_fraction
        if frac is not None and (frac < 0 or frac > 1):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        horizon = self.t_max if self.t_max is not None else t_max
        eta_min = initial_lr * self.alpha_f

        if self.warmup is not None:
            warmup = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup = round(horizon * self.warmup_fraction)

        if current < warmup:
            # Warm up to the midpoint of the full cosine curve.
            midpoint_lr = eta_min + (initial_lr - eta_min) / 2
            return _linear_warmup(midpoint_lr, current, warmup, self.warmup_min_lr)
        if current >= horizon:
            return eta_min

        # Shift into the second half of a cosine period twice as long as the remaining span.
        span = horizon - warmup
        phase = (current - warmup) + span
        period = 2 * span
        return eta_min + (initial_lr - eta_min) * (1 + cos(pi * phase / period)) / 2
@Scheduler.register("cos_with_warmup_and_linear_decay")
@dataclass
class CosWithWarmupAndLinearDecay(CosWithWarmup):
    """
    Cosine learning rate schedule with a warmup, cut short at the end and followed by a linear decay.

    During the last ``decay`` steps/tokens the LR decays linearly from the cosine value at
    ``t_max - decay`` to ``decay_min_lr``. Note that the tail is positioned using the
    ``t_max`` argument, while :class:`CosWithWarmup` substitutes ``self.t_max`` when set.
    """

    # Decay length in the scheduler's units.
    decay: Optional[int] = None
    decay_steps: Optional[int] = None  # deprecated, use 'decay' instead.
    # Decay length as a fraction of ``t_max``.
    decay_fraction: Optional[float] = 0.1
    # LR reached at the end of the decay.
    decay_min_lr: float = 0.0

    def __post_init__(self, *args):
        del args
        super().__post_init__()
        if self.decay_steps is not None and self.decay is None:
            self.decay, self.decay_steps = self.decay_steps, None
            warnings.warn(
                f"'{self.__class__.__name__}.decay_steps' is deprecated, please use '.decay' instead.",
                DeprecationWarning,
            )
        if not ((self.decay_fraction is None) != (self.decay is None)):
            raise OLMoConfigurationError("Either 'decay_fraction' or 'decay' must be specified.")
        frac = self.decay_fraction
        if frac is not None and (frac < 0 or frac > 1):
            raise OLMoConfigurationError("'decay_fraction' must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        if self.decay is not None:
            decay = self.decay
        else:
            assert self.decay_fraction is not None
            decay = round(t_max * self.decay_fraction)

        tail_start = t_max - decay
        if current < tail_start:
            return super().get_lr(initial_lr, current, t_max)

        # Anneal linearly from wherever the cosine curve is at the start of the tail.
        cosine_lr_at_tail = super().get_lr(initial_lr, tail_start, t_max)
        return _linear_decay(cosine_lr_at_tail, t_max - current, decay, self.decay_min_lr)
class ComposableSchedulerStageType(StrEnum):
    """
    Interpolation shape used inside a :class:`ComposableSchedulerStage` or :class:`OverrideDecay`.
    """

    linear = "linear"
    cosine = "cosine"


@dataclass
class ComposableSchedulerStage(Config):
    """
    A single stage in :class:`ComposableScheduler`.
    """

    # Length of the stage in the scheduler's units (must be > 0).
    duration: int
    # How the LR is interpolated from the start LR to the end LR.
    shape: ComposableSchedulerStageType = ComposableSchedulerStageType.linear
    # Absolute start LR. If neither this nor ``start_lr_fraction`` is set, the stage starts
    # from the previous stage's end LR (``initial_lr`` for the first stage).
    start_lr: Optional[float] = None
    # Start LR as a fraction of ``initial_lr``.
    start_lr_fraction: Optional[float] = None
    # Absolute end LR (exactly one of ``end_lr`` / ``end_lr_fraction`` is required).
    end_lr: Optional[float] = None
    # End LR as a fraction of ``initial_lr``.
    end_lr_fraction: Optional[float] = None

    def __post_init__(self):
        if self.duration <= 0:
            raise OLMoConfigurationError("'duration' must be > 0 for every stage.")
        if self.start_lr is not None and self.start_lr_fraction is not None:
            raise OLMoConfigurationError(
                "Specify at most one of 'start_lr' or 'start_lr_fraction' for a stage."
            )
        if self.start_lr is not None and self.start_lr < 0:
            raise OLMoConfigurationError("'start_lr' must be >= 0.")
        if self.start_lr_fraction is not None and self.start_lr_fraction < 0:
            raise OLMoConfigurationError("'start_lr_fraction' must be >= 0.")
        if (self.end_lr is None) == (self.end_lr_fraction is None):
            raise OLMoConfigurationError(
                "Specify exactly one of 'end_lr' or 'end_lr_fraction' for each stage."
            )
        if self.end_lr is not None and self.end_lr < 0:
            raise OLMoConfigurationError("'end_lr' must be >= 0.")
        if self.end_lr_fraction is not None and self.end_lr_fraction < 0:
            raise OLMoConfigurationError("'end_lr_fraction' must be >= 0.")


@dataclass
class OverrideDecay(Config):
    """
    Optional decay override for :class:`ComposableScheduler`.

    When active, the main stage schedule is ignored from ``start`` onward. The override
    starts from the LR that the main schedule would have produced at ``start``, and then
    decays to ``end_lr`` (or ``end_lr_fraction`` of ``initial_lr``) over ``duration``.
    """

    # Position (in the scheduler's units) where the override takes over; must be >= 0.
    start: int
    # Length of the override decay; must be > 0.
    duration: int
    shape: ComposableSchedulerStageType = ComposableSchedulerStageType.linear
    end_lr: Optional[float] = None
    end_lr_fraction: Optional[float] = None

    def __post_init__(self):
        if self.start < 0:
            raise OLMoConfigurationError("'start' must be >= 0 for override decay.")
        if self.duration <= 0:
            raise OLMoConfigurationError("'duration' must be > 0 for override decay.")
        if (self.end_lr is None) == (self.end_lr_fraction is None):
            raise OLMoConfigurationError(
                "Specify exactly one of 'end_lr' or 'end_lr_fraction' for override decay."
            )
        if self.end_lr is not None and self.end_lr < 0:
            raise OLMoConfigurationError("'end_lr' must be >= 0 for override decay.")
        if self.end_lr_fraction is not None and self.end_lr_fraction < 0:
            raise OLMoConfigurationError("'end_lr_fraction' must be >= 0 for override decay.")


@Scheduler.register("composable")
@dataclass
class ComposableScheduler(Scheduler):
    """
    Piecewise LR schedule composed of multiple stages.

    - Each stage has a duration and interpolation shape (linear/cosine).
    - Stage start LR defaults to the previous stage's end LR.
    - Stage end LR is required and can be absolute or as a fraction of ``initial_lr``.
    - After all stages are exhausted, LR stays constant at the last stage's end LR.
    - The ``t_max`` argument passed to :meth:`get_lr` is **ignored**: the schedule is
      defined absolutely by the stage durations (and the optional :data:`override_decay`),
      so it does not rescale to fit the trainer's max horizon.
    """

    stages: List[ComposableSchedulerStage] = field(default_factory=list)
    override_decay: Optional[OverrideDecay] = None
    # Ensures the "t_max is ignored" warning is emitted only once per instance.
    _warned_t_max_ignored: bool = field(default=False, init=False, repr=False)

    def __post_init__(self, *args):
        del args
        if len(self.stages) == 0:
            raise OLMoConfigurationError("'stages' must be specified and non-empty.")

    def _resolve_stage_start(
        self,
        stage: ComposableSchedulerStage,
        initial_lr: Union[float, torch.Tensor],
        previous_end_lr: Union[float, torch.Tensor],
    ) -> Union[float, torch.Tensor]:
        """Start LR of ``stage``: explicit value/fraction if given, else the previous end LR."""
        if stage.start_lr is None and stage.start_lr_fraction is None:
            return previous_end_lr
        return _resolve_lr_from_initial(initial_lr, stage.start_lr, stage.start_lr_fraction)

    def _main_schedule_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int
    ) -> Union[float, torch.Tensor]:
        """LR from the stage sequence alone (ignoring :data:`override_decay`)."""
        current = max(current, 0)
        stage_start = 0
        previous_end_lr: Union[float, torch.Tensor] = initial_lr
        for stage in self.stages:
            start_lr = self._resolve_stage_start(stage, initial_lr, previous_end_lr)
            end_lr = _resolve_lr_from_initial(initial_lr, stage.end_lr, stage.end_lr_fraction)
            stage_end = stage_start + stage.duration
            if current < stage_end:
                stage_current = current - stage_start
                return _interpolate_lr(
                    shape=stage.shape,
                    start_lr=start_lr,
                    end_lr=end_lr,
                    current=stage_current,
                    duration=stage.duration,
                )
            previous_end_lr = end_lr
            stage_start = stage_end
        # Past the last stage: hold its end LR.
        return previous_end_lr

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Compute the LR at ``current``.

        .. note::
            ``t_max`` is ignored. The schedule is defined absolutely by the stage durations
            (and the optional :data:`override_decay`), independent of the trainer's max horizon.
        """
        if not self._warned_t_max_ignored:
            total_duration = sum(stage.duration for stage in self.stages)
            warnings.warn(
                f"'{self.__class__.__name__}' ignores 't_max'; the schedule is defined "
                f"absolutely by stage durations (total={total_duration}, "
                f"t_max={t_max}). The LR will not rescale to fit the trainer's horizon.",
                UserWarning,
                stacklevel=2,
            )
            self._warned_t_max_ignored = True
        del t_max

        current = max(current, 0)
        if self.override_decay is None or current < self.override_decay.start:
            return self._main_schedule_lr(initial_lr, current)

        override_decay = self.override_decay
        override_start_lr = self._main_schedule_lr(initial_lr, override_decay.start)
        override_end_lr = _resolve_override_decay_end(override_decay, initial_lr)
        override_current = current - override_decay.start
        if override_current < override_decay.duration:
            return _interpolate_lr(
                shape=override_decay.shape,
                start_lr=override_start_lr,
                end_lr=override_end_lr,
                current=override_current,
                duration=override_decay.duration,
            )
        return override_end_lr


def _resolve_lr_from_initial(
    initial_lr: Union[float, torch.Tensor],
    value: Optional[float],
    fraction: Optional[float],
) -> Union[float, torch.Tensor]:
    """Return ``value`` if given, otherwise ``initial_lr * fraction``."""
    if value is not None:
        return value
    assert fraction is not None
    return initial_lr * fraction


def _resolve_override_decay_end(
    override_decay: OverrideDecay,
    initial_lr: Union[float, torch.Tensor],
) -> Union[float, torch.Tensor]:
    """Target LR at the end of ``override_decay``."""
    return _resolve_lr_from_initial(
        initial_lr, override_decay.end_lr, override_decay.end_lr_fraction
    )


def _interpolate_lr(
    shape: ComposableSchedulerStageType,
    start_lr: Union[float, torch.Tensor],
    end_lr: Union[float, torch.Tensor],
    current: int,
    duration: int,
) -> Union[float, torch.Tensor]:
    """Interpolate from ``start_lr`` (at 0) to ``end_lr`` (at ``duration``) with ``shape``."""
    if shape == ComposableSchedulerStageType.linear:
        return start_lr + (end_lr - start_lr) * current / duration
    elif shape == ComposableSchedulerStageType.cosine:
        return end_lr + (start_lr - end_lr) * (1 + cos(pi * current / duration)) / 2
    else:
        raise NotImplementedError(shape)


def _linear_warmup(
    initial_lr: Union[float, torch.Tensor], current: int, warmup: int, warmup_min_lr: float = 0.0
) -> Union[float, torch.Tensor]:
    """
    Ramp linearly from ``warmup_min_lr`` at ``current == 0`` to ``initial_lr`` at
    ``current == warmup``.

    ``current`` is clamped to ``[0, warmup]`` so the result never leaves that range. An
    empty warmup (``warmup <= 0``) has no ramp and returns ``initial_lr`` instead of
    dividing by zero.
    """
    if warmup <= 0:
        return initial_lr
    if isinstance(initial_lr, float):
        # not worth the potential host-device sync if it's a tensor
        assert 0 <= warmup_min_lr < initial_lr
    progress = min(max(current, 0), warmup)
    return warmup_min_lr + (initial_lr - warmup_min_lr) * progress / warmup


def _linear_decay(
    initial_lr: Union[float, torch.Tensor],
    step_from_end: int,
    decay: int,
    decay_min_lr: float = 0.0,
) -> Union[float, torch.Tensor]:
    """
    Decay linearly from ``initial_lr`` (``step_from_end == decay``) to ``decay_min_lr``
    (``step_from_end == 0``).

    ``step_from_end`` is clamped to ``[0, decay]``, so progress past the end holds at
    ``decay_min_lr`` rather than extrapolating below it. An empty decay (``decay <= 0``)
    returns ``initial_lr`` instead of dividing by zero.
    """
    if decay <= 0:
        return initial_lr
    if isinstance(initial_lr, float):
        # not worth the potential host-device sync if it's a tensor
        assert 0 <= decay_min_lr < initial_lr
    remaining = min(max(step_from_end, 0), decay)
    return decay_min_lr + (initial_lr - decay_min_lr) * remaining / decay
@Scheduler.register("sequential")
@dataclass
class SequentialScheduler(Scheduler):
    """
    A scheduler that calls a sequence of schedulers sequentially during the optimization process.
    The initial LR of a scheduler in the sequence is set to the final LR of the previous scheduler.
    """

    # NOTE(review): ``ConstantWithWarmup()`` with no arguments raises
    # ``OLMoConfigurationError`` in its ``__post_init__`` (neither 'warmup' nor
    # 'warmup_fraction' is set), so in practice ``schedulers`` must always be passed.
    schedulers: List[Scheduler] = field(default_factory=lambda: [ConstantWithWarmup()])
    schedulers_max: Optional[List[int]] = None
    """
    A list of the steps or token counts for which each scheduler runs. The last scheduler is assumed
    to run until the end of training, so any value provided for it is ignored.
    """
    schedulers_max_steps: Optional[List[int]] = None  # deprecated, use 'schedulers_max' instead.
    override_decay: Optional[OverrideDecay] = None
    """
    Optional late-stage override. When ``current >= override_decay.start``, the sub-scheduler
    sequence is bypassed and the LR decays from "whatever the main sequence would have produced at
    ``start``" to the override's target over ``duration`` (linear or cosine). After
    ``start + duration``, the LR is held at the override's end LR.

    .. note::
        While the override is active, ``t_max`` is ignored — the override is defined absolutely by
        ``start`` and ``duration``.
    """
    # Ensures the "t_max is ignored" warning is emitted only once per instance.
    _warned_t_max_ignored: bool = field(default=False, init=False, repr=False)

    def __post_init__(self, *args):
        del args
        if self.schedulers_max is None and self.schedulers_max_steps is not None:
            self.schedulers_max = self.schedulers_max_steps
            self.schedulers_max_steps = None
            warnings.warn(
                f"'{self.__class__.__name__}.schedulers_max_steps' is deprecated, please use '.schedulers_max' instead.",
                DeprecationWarning,
            )
        if self.schedulers_max is None:
            raise OLMoConfigurationError("'schedulers_max' must be specified")
        if len(self.schedulers_max) == len(self.schedulers):
            log.info(
                "Max steps are set for the last scheduler in sequential scheduling. "
                "The last scheduler is assumed to run until the end of training, so this value is ignored."
            )
            # Drop the ignored last entry WITHOUT mutating the caller's list: an in-place
            # ``pop()`` would also shorten any other config sharing that list object.
            self.schedulers_max = list(self.schedulers_max[:-1])
        if len(self.schedulers_max) + 1 != len(self.schedulers):
            raise OLMoConfigurationError(
                f"Max steps must be set for all schedulers except the last when using '{self.__class__.__name__}'"
            )

    def _sequential_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """LR from the scheduler sequence alone (ignoring :data:`override_decay`)."""
        assert self.schedulers_max is not None

        # Call schedulers sequentially until the current step/token count is within the max
        # steps/token count of the scheduler or the last scheduler is reached.
        for scheduler, scheduler_max in zip(self.schedulers[:-1], self.schedulers_max, strict=True):
            if current <= scheduler_max:
                return scheduler.get_lr(initial_lr, current, min(t_max, scheduler_max))
            # The next scheduler's initial LR should be the final LR of the current schedule.
            initial_lr = scheduler.get_lr(initial_lr, scheduler_max, scheduler_max)
            # Re-express progress and horizon relative to the next scheduler's start.
            current -= scheduler_max
            t_max -= scheduler_max
            assert t_max > 0

        return self.schedulers[-1].get_lr(initial_lr, current, t_max)

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        assert 0 <= current <= t_max

        if self.override_decay is None or current < self.override_decay.start:
            return self._sequential_lr(initial_lr, current, t_max)

        if not self._warned_t_max_ignored:
            warnings.warn(
                f"'{self.__class__.__name__}' ignores 't_max' once 'override_decay' is active; "
                f"the override is defined absolutely by 'start' ({self.override_decay.start}) "
                f"and 'duration' ({self.override_decay.duration}) (t_max={t_max}).",
                UserWarning,
                stacklevel=2,
            )
            self._warned_t_max_ignored = True

        override_decay = self.override_decay
        override_start_lr = self._sequential_lr(initial_lr, override_decay.start, t_max)
        override_end_lr = _resolve_override_decay_end(override_decay, initial_lr)
        override_current = current - override_decay.start
        if override_current < override_decay.duration:
            return _interpolate_lr(
                shape=override_decay.shape,
                start_lr=override_start_lr,
                end_lr=override_end_lr,
                current=override_current,
                duration=override_decay.duration,
            )
        return override_end_lr
@Scheduler.register("wsds")
@dataclass
class WSDS(Scheduler):
    """
    Warmup-Stable-Decay-Simplified (WSD-S) scheduler for continual pretraining.

    Training is split into consecutive periods (``period_lengths``). A single linear warmup
    happens at the start of the first period. Each period then holds its peak LR (the
    initial LR, optionally scaled by ``period_lr_multipliers``) and ends with a linear decay
    to ``decay_min_lr``. Past the last period the LR stays at ``decay_min_lr``.

    Reference: https://arxiv.org/abs/2410.05192
    """

    period_lengths: List[int] = field(default_factory=list)
    period_lr_multipliers: Optional[List[float]] = None
    warmup: Optional[int] = None
    warmup_fraction: Optional[float] = None  # fraction of the FIRST period's length
    decay: Optional[int] = None
    decay_fraction: Optional[float] = None  # fraction of each period's length
    warmup_min_lr: float = 0.0
    decay_min_lr: float = 0.0

    # Derived state computed in ``__post_init__``.
    _cum_period_end: List[int] = field(default_factory=list, init=False, repr=False)
    _warmup_steps: int = field(default=0, init=False, repr=False)
    _adjusted_period_lengths: List[int] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self, *args):
        del args
        lengths = self.period_lengths
        multipliers = self.period_lr_multipliers

        # --- periods ---
        if not lengths:
            raise OLMoConfigurationError("'period_lengths' must be provided and non-empty.")
        if any(length <= 0 for length in lengths):
            raise OLMoConfigurationError("All entries in 'period_lengths' must be > 0.")
        if multipliers is not None:
            if len(multipliers) != len(lengths):
                raise OLMoConfigurationError(
                    "'period_lr_multipliers' length must match 'period_lengths' length."
                )
            if any(mult <= 0.0 for mult in multipliers):
                raise OLMoConfigurationError("All entries in 'period_lr_multipliers' must be > 0.")

        # --- warmup ---
        if not ((self.warmup is None) != (self.warmup_fraction is None)):
            raise OLMoConfigurationError(
                "Exactly one of 'warmup' or 'warmup_fraction' must be specified."
            )
        if self.warmup_fraction is not None and not (0.0 <= self.warmup_fraction <= 1.0):
            raise OLMoConfigurationError("'warmup_fraction' must be in [0, 1].")

        # --- decay ---
        if not ((self.decay is None) != (self.decay_fraction is None)):
            raise OLMoConfigurationError(
                "Exactly one of 'decay' or 'decay_fraction' must be specified."
            )
        if self.decay_fraction is not None and not (0.0 <= self.decay_fraction <= 1.0):
            raise OLMoConfigurationError("'decay_fraction' must be in [0, 1].")
        if self.decay_min_lr < 0.0:
            raise OLMoConfigurationError("'decay_min_lr' must be >= 0.")

        # The warmup length is resolved against the first period.
        first_len = lengths[0]
        if self.warmup is None:
            assert self.warmup_fraction is not None
            self._warmup_steps = int(round(self.warmup_fraction * first_len))
        else:
            self._warmup_steps = int(self.warmup)

        # The first period must fit both the warmup and its decay.
        first_decay = self._resolve_decay(first_len)
        if self._warmup_steps + first_decay > first_len:
            raise OLMoConfigurationError(
                f"First period: warmup ({self._warmup_steps}) + decay ({first_decay}) = "
                f"{self._warmup_steps + first_decay} exceeds period length ({first_len})."
            )

        # Every later period must fit its own decay.
        for i, Li in enumerate(lengths[1:], start=1):
            Di = self._resolve_decay(Li)
            if Di > Li:
                raise OLMoConfigurationError(
                    f"Period {i}: decay ({Di}) exceeds period length ({Li})."
                )

        # Period boundaries are measured AFTER the warmup, so the first period loses the
        # warmup's length.
        self._adjusted_period_lengths = [first_len - self._warmup_steps, *lengths[1:]]
        self._cum_period_end = np.cumsum(self._adjusted_period_lengths).tolist()

    def _resolve_decay(self, Li: int) -> int:
        """Decay length for a period of (unadjusted) length ``Li``."""
        if self.decay is None:
            assert self.decay_fraction is not None
            return int(round(self.decay_fraction * Li))
        return int(self.decay)

    def _find_period(self, x: int) -> int:
        """Index of the first period whose cumulative end is >= ``x`` (last period if none)."""
        return next(
            (idx for idx, end in enumerate(self._cum_period_end) if x <= end),
            len(self._cum_period_end) - 1,
        )

    def _get_peak_lr(
        self, initial_lr: Union[float, torch.Tensor], pidx: int
    ) -> Union[float, torch.Tensor]:
        """Peak LR of period ``pidx``."""
        multipliers = self.period_lr_multipliers
        return initial_lr if multipliers is None else initial_lr * multipliers[pidx]

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        # The schedule is defined entirely by the configured periods.
        del t_max

        warmup_steps = self._warmup_steps
        if current < warmup_steps:
            return _linear_warmup(
                self._get_peak_lr(initial_lr, 0), current, warmup_steps, self.warmup_min_lr
            )

        offset = current - warmup_steps
        if offset >= self._cum_period_end[-1]:
            return self.decay_min_lr

        # Locate the period and the position within it (post-warmup coordinates).
        pidx = self._find_period(offset)
        period_start = self._cum_period_end[pidx - 1] if pidx > 0 else 0
        period_len = self._adjusted_period_lengths[pidx]
        pos = min(max(offset - period_start, 0), period_len)

        decay_len = self._resolve_decay(self.period_lengths[pidx])
        stable_len = period_len - decay_len
        peak = self._get_peak_lr(initial_lr, pidx)
        if pos < stable_len:
            return peak
        return _linear_decay(peak, decay_len - (pos - stable_len), decay_len, self.decay_min_lr)
@Scheduler.register("exponential")
@dataclass
class ExponentialScheduler(Scheduler):
    """
    Exponential learning rate schedule that increases from a minimum LR to a maximum LR.

    Thus:
    - lr(0) = lr_min
    - lr(t_max) = initial_lr
    and in between ``lr(t) = lr_min * (initial_lr / lr_min) ** (t / t_max)``.
    """

    # Starting LR; must be strictly positive because it is used as a divisor.
    lr_min: float = 1e-9

    def __post_init__(self, *args):
        del args
        if self.lr_min <= 0:
            raise OLMoConfigurationError("'lr_min' must be positive.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        if current >= t_max:
            return initial_lr
        if current == 0:
            return self.lr_min

        base = initial_lr / self.lr_min
        exponent = current / t_max
        growth = torch.pow(base, exponent) if isinstance(initial_lr, torch.Tensor) else base**exponent
        return self.lr_min * growth