import logging
import warnings
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from math import cos, pi, sqrt
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np
import torch
from ..config import Config, Registrable, StrEnum
from ..exceptions import OLMoConfigurationError
from .config import INITIAL_LR_FIELD, LR_FIELD
if TYPE_CHECKING:
from olmo_core.train import Trainer
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
[docs]
class SchedulerUnits(StrEnum):
    """
    The unit a scheduler uses to measure training progress.
    """

    # Progress is the trainer's global step count.
    steps = "steps"
    # Progress is the trainer's count of training tokens seen.
    tokens = "tokens"
[docs]
@dataclass
class Scheduler(Config, Registrable, metaclass=ABCMeta):
    """
    Abstract base class for learning rate schedules.

    Subclasses implement :meth:`get_lr`. :meth:`set_lr` then applies the computed value
    to an optimizer param group, using progress read from a trainer.
    """

    # Param-group key holding the current learning rate.
    lr_field: str = LR_FIELD
    # Param-group key holding the initial/max learning rate.
    initial_lr_field: str = INITIAL_LR_FIELD
    # Whether progress is measured in steps or in tokens.
    units: SchedulerUnits = SchedulerUnits.steps

    @abstractmethod
    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Compute the learning rate.

        :param initial_lr: The initial/max learning rate.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens.

        :returns: The learning rate to use at ``current``.
        """
        raise NotImplementedError

    def set_lr(self, group: Dict[str, Any], trainer: "Trainer") -> Union[float, torch.Tensor]:
        """
        Compute the scheduled learning rate from the trainer's state and store it on an
        optimizer param group.

        :param group: The optimizer param group to update in place.
        :param trainer: The trainer providing the current and maximum step/token counts.

        :returns: The new learning rate.

        :raises RuntimeError: If the group has neither the LR field nor the initial LR field.
        :raises OLMoConfigurationError: If the trainer doesn't know the maximum for the
            configured units.
        """
        lr_key = self.lr_field
        init_key = self.initial_lr_field

        if lr_key not in group and init_key not in group:
            # Show every non-parameter entry of the group to help debugging.
            described = "\n - ".join(
                f"{name}: {value}" for name, value in group.items() if name != "params"
            )
            raise RuntimeError(
                f"learning rate field '{lr_key}' and initial learning rate field "
                f"'{init_key}' not found in optimizer param group "
                f"with {len(group['params'])} parameter(s):\n"
                f" - {described}"
            )

        # Seed the initial LR from the current LR the first time through.
        if group.get(init_key) is None:
            group[init_key] = group[lr_key]

        # Pick the progress counters that match the configured units.
        units = self.units
        if units == SchedulerUnits.steps:
            limit = trainer.max_steps
            if limit is None:
                raise OLMoConfigurationError(
                    "'max_steps' must be known in the trainer for step-based scheduling."
                )
            progress = trainer.global_step
        elif units == SchedulerUnits.tokens:
            limit = trainer.max_tokens
            if limit is None:
                raise OLMoConfigurationError(
                    "'max_tokens' must be known in the trainer for token-based scheduling."
                )
            progress = trainer.global_train_tokens_seen
        else:
            raise NotImplementedError(units)

        new_lr = self.get_lr(group[init_key], progress, limit)

        # A tensor LR is updated in place so the same tensor object stays in the group.
        existing = group.get(lr_key)
        if isinstance(existing, torch.Tensor):
            existing.fill_(new_lr)
        else:
            group[lr_key] = new_lr
        return new_lr
[docs]
@Scheduler.register("constant")
@dataclass
class ConstantScheduler(Scheduler):
    """
    A schedule that always returns the initial learning rate unchanged.
    """

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return ``initial_lr`` as is; ``current`` and ``t_max`` are ignored.
        """
        _ = (current, t_max)  # progress has no effect on a constant schedule
        return initial_lr
[docs]
@Scheduler.register("constant_with_warmup")
@dataclass
class ConstantWithWarmup(Scheduler):
    """
    A linear warmup followed by a constant learning rate.
    """

    # Warmup length in steps/tokens. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of ``t_max``.
    warmup_fraction: Optional[float] = None
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``warmup_steps`` field and validate the warmup settings.

        :raises OLMoConfigurationError: If not exactly one of ``warmup`` / ``warmup_fraction``
            is set, or if ``warmup_fraction`` is outside ``[0, 1]``.
        """
        del args

        legacy_warmup = self.warmup_steps
        if legacy_warmup is not None and self.warmup is None:
            self.warmup, self.warmup_steps = legacy_warmup, None
            warnings.warn(
                f"'{type(self).__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )

        has_fraction = self.warmup_fraction is not None
        has_absolute = self.warmup is not None
        if has_fraction == has_absolute:
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")

        fraction = self.warmup_fraction
        if fraction is not None and (fraction > 1 or fraction < 0):
            raise OLMoConfigurationError("'warmup_fraction' must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the warmup learning rate while ``current <= warmup``, then ``initial_lr``.

        :param initial_lr: The learning rate reached at the end of warmup.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens; only used to resolve
            ``warmup_fraction``.
        """
        if self.warmup is not None:
            warmup_len = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup_len = round(t_max * self.warmup_fraction)

        if current > warmup_len:
            return initial_lr
        return _linear_warmup(initial_lr, current, warmup_len, self.warmup_min_lr)
[docs]
@Scheduler.register("wsd")
@dataclass
class WSD(Scheduler):
    """
    Warmup-stable-decay schedule: a linear warmup, a constant phase at the initial LR,
    and a linear decay over the final stretch of training.
    """

    # Warmup length in steps/tokens. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of ``t_max``.
    warmup_fraction: Optional[float] = None
    # Decay length in steps/tokens. Exactly one of 'decay' / 'decay_fraction' is required.
    decay: Optional[int] = None
    decay_steps: Optional[int] = None  # deprecated, use 'decay' instead.
    # Decay length as a fraction of ``t_max``.
    decay_fraction: Optional[float] = 0.1
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0
    # Learning rate the decay ends at.
    decay_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``*_steps`` fields and validate the warmup and decay settings.

        :raises OLMoConfigurationError: If the warmup or decay settings are not given exactly
            once each, or if a fraction is outside ``[0, 1]``.
        """
        del args
        cls_name = type(self).__name__

        # --- warmup ---
        legacy_warmup = self.warmup_steps
        if legacy_warmup is not None and self.warmup is None:
            self.warmup, self.warmup_steps = legacy_warmup, None
            warnings.warn(
                f"'{cls_name}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        warmup_fraction = self.warmup_fraction
        if (warmup_fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if warmup_fraction is not None and (warmup_fraction > 1 or warmup_fraction < 0):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

        # --- decay ---
        legacy_decay = self.decay_steps
        if legacy_decay is not None and self.decay is None:
            self.decay, self.decay_steps = legacy_decay, None
            warnings.warn(
                f"'{cls_name}.decay_steps' is deprecated, please use '.decay' instead.",
                DeprecationWarning,
            )
        decay_fraction = self.decay_fraction
        if (decay_fraction is None) == (self.decay is None):
            raise OLMoConfigurationError(
                "Either 'decay_fraction' or 'decay' must be specified. Never both."
            )
        if decay_fraction is not None and (decay_fraction > 1 or decay_fraction < 0):
            raise OLMoConfigurationError("decay_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the learning rate for the phase that ``current`` falls in.

        The warmup check runs first, so warmup wins if the warmup and decay windows overlap.

        :param initial_lr: The learning rate of the stable phase.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens.
        """
        if self.warmup is not None:
            warmup_len = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup_len = round(t_max * self.warmup_fraction)
        if current <= warmup_len:
            return _linear_warmup(initial_lr, current, warmup_len, self.warmup_min_lr)

        if self.decay is not None:
            decay_len = self.decay
        else:
            assert self.decay_fraction is not None
            decay_len = round(t_max * self.decay_fraction)
        remaining = t_max - current
        if current >= t_max - decay_len:
            return _linear_decay(initial_lr, remaining, decay_len, self.decay_min_lr)

        return initial_lr
[docs]
@Scheduler.register("linear_with_warmup")
@dataclass
class LinearWithWarmup(Scheduler):
    """
    A linear warmup followed by a linear decay down to ``alpha_f * initial_lr``.
    """

    # Final learning rate as a fraction of the initial learning rate.
    alpha_f: float = 0.1
    # Optional override for the ``t_max`` passed to ``get_lr``.
    t_max: Optional[int] = None
    # Warmup length in steps/tokens. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of the effective ``t_max``.
    warmup_fraction: Optional[float] = None
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``warmup_steps`` field and validate the warmup settings.

        :raises OLMoConfigurationError: If not exactly one of ``warmup`` / ``warmup_fraction``
            is set, or if ``warmup_fraction`` is outside ``[0, 1]``.
        """
        del args

        legacy_warmup = self.warmup_steps
        if legacy_warmup is not None and self.warmup is None:
            self.warmup, self.warmup_steps = legacy_warmup, None
            warnings.warn(
                f"'{type(self).__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )

        fraction = self.warmup_fraction
        if (fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if fraction is not None and (fraction > 1 or fraction < 0):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the warmup LR before ``warmup``, the final LR at or after the effective
        ``t_max``, and a linear interpolation between the two in between.

        :param initial_lr: The peak learning rate.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens; replaced by ``self.t_max`` when
            that is set.
        """
        horizon = self.t_max if self.t_max is not None else t_max
        final_lr = initial_lr * self.alpha_f

        if self.warmup is not None:
            warmup_len = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup_len = round(horizon * self.warmup_fraction)

        if current < warmup_len:
            return _linear_warmup(initial_lr, current, warmup_len, self.warmup_min_lr)
        if current >= horizon:
            return final_lr

        # Fraction of the post-warmup span that has elapsed.
        elapsed = current - warmup_len
        span = horizon - warmup_len
        return initial_lr - (initial_lr - final_lr) * (elapsed / span)
[docs]
@Scheduler.register("inv_sqrt_with_warmup")
@dataclass
class InvSqrtWithWarmup(Scheduler):
    """
    Inverse square root learning rate (LR) schedule with a warmup.

    After warmup the LR is ``eta_min + (initial_lr - eta_min) * sqrt(warmup / current)``
    with ``eta_min = alpha_f * initial_lr``.
    """

    # Floor of the schedule as a fraction of the initial learning rate.
    alpha_f: float = 0.1
    # Warmup length in steps/tokens. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of ``t_max``.
    warmup_fraction: Optional[float] = None
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``warmup_steps`` field and validate the warmup settings.

        :raises OLMoConfigurationError: If not exactly one of ``warmup`` / ``warmup_fraction``
            is set, or if ``warmup_fraction`` is outside ``[0, 1]``.
        """
        del args
        if self.warmup is None and self.warmup_steps is not None:
            self.warmup = self.warmup_steps
            self.warmup_steps = None
            warnings.warn(
                f"'{self.__class__.__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )
        if (self.warmup_fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if self.warmup_fraction is not None and (
            self.warmup_fraction < 0 or self.warmup_fraction > 1
        ):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the warmup LR while ``current < warmup``, then the inverse-square-root decay.

        :param initial_lr: The peak learning rate, reached when ``current == warmup``.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens; only used to resolve
            ``warmup_fraction``.
        """
        if self.warmup is None:
            assert self.warmup_fraction is not None
            warmup = round(t_max * self.warmup_fraction)
        else:
            warmup = self.warmup

        if current < warmup:
            return _linear_warmup(initial_lr, current, warmup, self.warmup_min_lr)

        # With a zero-length warmup (e.g. 'warmup_fraction=0.0') step 0 reaches this point and
        # 'warmup / current' would be 0 / 0. Treat it as the end of the (empty) warmup and
        # return the peak LR, matching what 'current == warmup' yields for a non-zero warmup.
        if current == 0:
            return initial_lr

        eta_min = initial_lr * self.alpha_f
        return eta_min + (initial_lr - eta_min) * sqrt(warmup / current)
[docs]
@Scheduler.register("cos_with_warmup")
@dataclass
class CosWithWarmup(Scheduler):
    """
    A linear warmup followed by a cosine decay down to ``alpha_f * initial_lr``.
    """

    # Warmup length in steps/tokens. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of the effective ``t_max``.
    warmup_fraction: Optional[float] = None
    # Final learning rate as a fraction of the initial learning rate.
    alpha_f: float = 0.1
    # Optional override for the ``t_max`` passed to ``get_lr``.
    t_max: Optional[int] = None
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``warmup_steps`` field and validate the warmup settings.

        :raises OLMoConfigurationError: If not exactly one of ``warmup`` / ``warmup_fraction``
            is set, or if ``warmup_fraction`` is outside ``[0, 1]``.
        """
        del args

        legacy_warmup = self.warmup_steps
        if legacy_warmup is not None and self.warmup is None:
            self.warmup, self.warmup_steps = legacy_warmup, None
            warnings.warn(
                f"'{type(self).__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )

        fraction = self.warmup_fraction
        if (fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if fraction is not None and (fraction > 1 or fraction < 0):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the warmup LR before ``warmup``, the final LR at or after the effective
        ``t_max``, and a half-cosine interpolation between the two in between.

        :param initial_lr: The peak learning rate.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens; replaced by ``self.t_max`` when
            that is set.
        """
        horizon = self.t_max if self.t_max is not None else t_max
        final_lr = initial_lr * self.alpha_f

        if self.warmup is not None:
            warmup_len = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup_len = round(horizon * self.warmup_fraction)

        if current < warmup_len:
            return _linear_warmup(initial_lr, current, warmup_len, self.warmup_min_lr)
        if current >= horizon:
            return final_lr

        # Position within the post-warmup span, mapped onto [0, pi].
        elapsed = current - warmup_len
        span = horizon - warmup_len
        cosine_term = 1 + cos(pi * elapsed / span)
        return final_lr + (initial_lr - final_lr) * cosine_term / 2
[docs]
@Scheduler.register("half_cos_with_warmup")
@dataclass
class HalfCosWithWarmup(Scheduler):
    """
    A linear warmup followed by the second half of a cosine learning rate schedule.

    Note: This assumes that the peak LR set is for the full cosine schedule, so the warmup
    only climbs to the midpoint between the final LR and ``initial_lr``.
    """

    # Warmup length in steps/tokens. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    warmup_steps: Optional[int] = None  # deprecated, use 'warmup' instead.
    # Warmup length as a fraction of the effective ``t_max``.
    warmup_fraction: Optional[float] = None
    # Final learning rate as a fraction of the initial learning rate.
    alpha_f: float = 0.1
    # Optional override for the ``t_max`` passed to ``get_lr``.
    t_max: Optional[int] = None
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``warmup_steps`` field and validate the warmup settings.

        :raises OLMoConfigurationError: If not exactly one of ``warmup`` / ``warmup_fraction``
            is set, or if ``warmup_fraction`` is outside ``[0, 1]``.
        """
        del args

        legacy_warmup = self.warmup_steps
        if legacy_warmup is not None and self.warmup is None:
            self.warmup, self.warmup_steps = legacy_warmup, None
            warnings.warn(
                f"'{type(self).__name__}.warmup_steps' is deprecated, please use '.warmup' instead.",
                DeprecationWarning,
            )

        fraction = self.warmup_fraction
        if (fraction is None) == (self.warmup is None):
            raise OLMoConfigurationError("Either 'warmup_fraction' or 'warmup' must be specified.")
        if fraction is not None and (fraction > 1 or fraction < 0):
            raise OLMoConfigurationError("warmup_fraction must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the warmup LR before ``warmup``, the final LR at or after the effective
        ``t_max``, and the second half of a cosine curve in between.

        :param initial_lr: The peak learning rate of the full cosine schedule.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens; replaced by ``self.t_max`` when
            that is set.
        """
        horizon = self.t_max if self.t_max is not None else t_max
        final_lr = initial_lr * self.alpha_f

        if self.warmup is not None:
            warmup_len = self.warmup
        else:
            assert self.warmup_fraction is not None
            warmup_len = round(horizon * self.warmup_fraction)

        if current < warmup_len:
            # Warm up to the cosine midpoint, which is where the second half starts.
            midpoint_lr = final_lr + (initial_lr - final_lr) / 2
            return _linear_warmup(midpoint_lr, current, warmup_len, self.warmup_min_lr)
        if current >= horizon:
            return final_lr

        # Shift the position by one span so it lands in the second half of a cosine
        # that is twice as long as the post-warmup span.
        span = horizon - warmup_len
        shifted = (current - warmup_len) + span
        full_span = span * 2
        cosine_term = 1 + cos(pi * shifted / full_span)
        return final_lr + (initial_lr - final_lr) * cosine_term / 2
[docs]
@Scheduler.register("cos_with_warmup_and_linear_decay")
@dataclass
class CosWithWarmupAndLinearDecay(CosWithWarmup):
    """
    The :class:`CosWithWarmup` schedule, cut short near the end of training and finished
    with a linear decay.
    """

    # Decay length in steps/tokens. Exactly one of 'decay' / 'decay_fraction' is required.
    decay: Optional[int] = None
    decay_steps: Optional[int] = None  # deprecated, use 'decay' instead.
    # Decay length as a fraction of ``t_max``.
    decay_fraction: Optional[float] = 0.1
    # Learning rate the linear decay ends at.
    decay_min_lr: float = 0.0

    def __post_init__(self, *args):
        """
        Run the parent's warmup validation, then migrate the deprecated ``decay_steps`` field
        and validate the decay settings.

        :raises OLMoConfigurationError: If not exactly one of ``decay`` / ``decay_fraction``
            is set, or if ``decay_fraction`` is outside ``[0, 1]``.
        """
        del args
        super().__post_init__()

        legacy_decay = self.decay_steps
        if legacy_decay is not None and self.decay is None:
            self.decay, self.decay_steps = legacy_decay, None
            warnings.warn(
                f"'{type(self).__name__}.decay_steps' is deprecated, please use '.decay' instead.",
                DeprecationWarning,
            )

        fraction = self.decay_fraction
        if (fraction is None) == (self.decay is None):
            raise OLMoConfigurationError("Either 'decay_fraction' or 'decay' must be specified.")
        if fraction is not None and (fraction > 1 or fraction < 0):
            raise OLMoConfigurationError("'decay_fraction' must be between 0 and 1.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Follow the parent cosine schedule until ``t_max - decay``, then decay linearly from
        the cosine value at that point.

        :param initial_lr: The peak learning rate.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens.
        """
        if self.decay is not None:
            decay_len = self.decay
        else:
            assert self.decay_fraction is not None
            decay_len = round(t_max * self.decay_fraction)

        decay_start = t_max - decay_len
        if current < decay_start:
            return super().get_lr(initial_lr, current, t_max)

        # The linear decay starts from wherever the cosine curve is at the cut-off point.
        lr_at_cutoff = super().get_lr(initial_lr, decay_start, t_max)
        return _linear_decay(lr_at_cutoff, t_max - current, decay_len, self.decay_min_lr)
def _linear_warmup(
initial_lr: Union[float, torch.Tensor], current: int, warmup: int, warmup_min_lr: float = 0.0
) -> Union[float, torch.Tensor]:
if isinstance(initial_lr, float): # not worth the potential host-device sync if it's a tensor
assert 0 <= warmup_min_lr < initial_lr
return warmup_min_lr + (initial_lr - warmup_min_lr) * min(current, warmup) / warmup
def _linear_decay(
initial_lr: Union[float, torch.Tensor],
step_from_end: int,
decay: int,
decay_min_lr: float = 0.0,
) -> Union[float, torch.Tensor]:
if isinstance(initial_lr, float): # not worth the potential host-device sync if it's a tensor
assert 0 <= decay_min_lr < initial_lr
return decay_min_lr + (initial_lr - decay_min_lr) * min(step_from_end, decay) / decay
[docs]
@Scheduler.register("sequential")
@dataclass
class SequentialScheduler(Scheduler):
    """
    A scheduler that calls a sequence of schedulers sequentially during the optimization
    process. The initial LR of a scheduler in the sequence is set to the final LR of the
    previous scheduler.
    """

    # NOTE(review): 'ConstantWithWarmup()' with no arguments raises OLMoConfigurationError in
    # its own '__post_init__' (neither 'warmup' nor 'warmup_fraction' is set), so this default
    # can't actually be constructed; 'schedulers' effectively has to be passed explicitly.
    schedulers: List[Scheduler] = field(default_factory=lambda: [ConstantWithWarmup()])
    schedulers_max: Optional[List[int]] = None
    """
    A list of the steps or token counts for which each scheduler runs.
    The last scheduler is assumed to run until the end of training, so any value provided for it is ignored.
    """
    schedulers_max_steps: Optional[List[int]] = None  # deprecated, use 'schedulers_max' instead.

    def __post_init__(self, *args):
        """
        Migrate the deprecated ``schedulers_max_steps`` field and validate that a maximum is
        given for every scheduler except the last.

        The list stored in ``schedulers_max`` is a private copy, so the list passed in by the
        caller is never modified.

        :raises OLMoConfigurationError: If ``schedulers_max`` is missing or its length doesn't
            match the number of schedulers.
        """
        del args
        if self.schedulers_max is None and self.schedulers_max_steps is not None:
            self.schedulers_max = self.schedulers_max_steps
            self.schedulers_max_steps = None
            warnings.warn(
                f"'{self.__class__.__name__}.schedulers_max_steps' is deprecated, please use '.schedulers_max' instead.",
                DeprecationWarning,
            )
        if self.schedulers_max is None:
            raise OLMoConfigurationError("'schedulers_max' must be specified")

        # Work on a copy: dropping the trailing entry with an in-place 'pop()' would silently
        # change the caller's list (and any other config sharing it).
        limits = list(self.schedulers_max)
        if len(limits) == len(self.schedulers):
            log.info(
                "Max steps are set for the last scheduler in sequential scheduling. "
                "The last scheduler is assumed to run until the end of training, so this value is ignored."
            )
            limits = limits[:-1]
        self.schedulers_max = limits

        if len(self.schedulers_max) + 1 != len(self.schedulers):
            raise OLMoConfigurationError(
                f"Max steps must be set for all schedulers except the last when using '{self.__class__.__name__}'"
            )

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Delegate to the scheduler whose window contains ``current``.

        Each scheduler sees ``current`` and ``t_max`` relative to the start of its own window,
        and receives the final LR of the previous scheduler as its initial LR.

        :param initial_lr: The initial learning rate for the first scheduler.
        :param current: The current step or token count; must satisfy ``0 <= current <= t_max``.
        :param t_max: The maximum number of steps or tokens.
        """
        assert 0 <= current <= t_max
        assert self.schedulers_max is not None

        # Call schedulers sequentially until the current step/token count is within the max steps/token count
        # of the scheduler or the last scheduler is reached.
        for scheduler, scheduler_max in zip(self.schedulers[:-1], self.schedulers_max, strict=True):
            if current <= scheduler_max:
                return scheduler.get_lr(initial_lr, current, min(t_max, scheduler_max))

            # The next scheduler's initial LR should be the final LR of the current schedule
            initial_lr = scheduler.get_lr(initial_lr, scheduler_max, scheduler_max)
            # Re-base progress onto the start of the next scheduler's window.
            current -= scheduler_max
            t_max -= scheduler_max
            assert t_max > 0

        return self.schedulers[-1].get_lr(initial_lr, current, t_max)
[docs]
@Scheduler.register("wsds")
@dataclass
class WSDS(Scheduler):
    """
    Warmup–Stable–Decay—Simplified (WSD‑S) scheduler for continual pretraining.

    Training is split into consecutive periods. The first period starts with a linear warmup;
    every period holds a peak LR and ends with a linear decay to ``decay_min_lr``.

    Reference: https://arxiv.org/abs/2410.05192
    """

    # Length of each period; must be non-empty with all entries > 0.
    period_lengths: List[int] = field(default_factory=list)
    # Optional per-period multiplier applied to the initial LR to get that period's peak LR.
    period_lr_multipliers: Optional[List[float]] = None
    # Warmup length. Exactly one of 'warmup' / 'warmup_fraction' is required.
    warmup: Optional[int] = None
    # Warmup length as a fraction of the first period's length.
    warmup_fraction: Optional[float] = None
    # Decay length. Exactly one of 'decay' / 'decay_fraction' is required.
    decay: Optional[int] = None
    # Decay length as a fraction of each period's length.
    decay_fraction: Optional[float] = None
    # Learning rate the warmup starts from.
    warmup_min_lr: float = 0.0
    # Learning rate every decay ends at.
    decay_min_lr: float = 0.0

    # Derived state, filled in by '__post_init__'.
    _cum_period_end: List[int] = field(default_factory=list, init=False, repr=False)
    _warmup_steps: int = field(default=0, init=False, repr=False)
    _adjusted_period_lengths: List[int] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self, *args):
        """
        Validate the configuration and precompute the warmup length and period boundaries.

        :raises OLMoConfigurationError: If any of the period, multiplier, warmup, or decay
            settings are invalid or inconsistent with the period lengths.
        """
        del args

        lengths = self.period_lengths
        if not lengths:
            raise OLMoConfigurationError("'period_lengths' must be provided and non-empty.")
        if any(length <= 0 for length in lengths):
            raise OLMoConfigurationError("All entries in 'period_lengths' must be > 0.")

        multipliers = self.period_lr_multipliers
        if multipliers is not None:
            if len(multipliers) != len(lengths):
                raise OLMoConfigurationError(
                    "'period_lr_multipliers' length must match 'period_lengths' length."
                )
            if any(mult <= 0.0 for mult in multipliers):
                raise OLMoConfigurationError("All entries in 'period_lr_multipliers' must be > 0.")

        # warmup validation
        if (self.warmup is None) == (self.warmup_fraction is None):
            raise OLMoConfigurationError(
                "Exactly one of 'warmup' or 'warmup_fraction' must be specified."
            )
        if self.warmup_fraction is not None and not (0.0 <= self.warmup_fraction <= 1.0):
            raise OLMoConfigurationError("'warmup_fraction' must be in [0, 1].")

        # decay validation
        if (self.decay is None) == (self.decay_fraction is None):
            raise OLMoConfigurationError(
                "Exactly one of 'decay' or 'decay_fraction' must be specified."
            )
        if self.decay_fraction is not None and not (0.0 <= self.decay_fraction <= 1.0):
            raise OLMoConfigurationError("'decay_fraction' must be in [0, 1].")
        if self.decay_min_lr < 0.0:
            raise OLMoConfigurationError("'decay_min_lr' must be >= 0.")

        # The warmup is resolved against the first period only.
        first_len = lengths[0]
        if self.warmup is None:
            assert self.warmup_fraction is not None
            self._warmup_steps = int(round(self.warmup_fraction * first_len))
        else:
            self._warmup_steps = int(self.warmup)

        # The first period has to fit both its warmup and its decay.
        first_decay = self._resolve_decay(first_len)
        if self._warmup_steps + first_decay > first_len:
            raise OLMoConfigurationError(
                f"First period: warmup ({self._warmup_steps}) + decay ({first_decay}) = "
                f"{self._warmup_steps + first_decay} exceeds period length ({first_len})."
            )

        # Every later period only has to fit its decay.
        for period_idx, period_len in enumerate(lengths[1:], start=1):
            period_decay = self._resolve_decay(period_len)
            if period_decay > period_len:
                raise OLMoConfigurationError(
                    f"Period {period_idx}: decay ({period_decay}) exceeds period length ({period_len})."
                )

        # Boundaries are tracked on a post-warmup timeline, so the first period is shortened
        # by the warmup length before accumulating the period ends.
        self._adjusted_period_lengths = [first_len - self._warmup_steps] + lengths[1:]
        self._cum_period_end = np.cumsum(self._adjusted_period_lengths).tolist()

    def _resolve_decay(self, Li: int) -> int:
        """
        Return the decay length for a period of length ``Li``: the fixed ``decay`` if set,
        otherwise ``decay_fraction`` of ``Li`` rounded to an integer.
        """
        if self.decay is None:
            assert self.decay_fraction is not None
            return int(round(self.decay_fraction * Li))
        return int(self.decay)

    def _find_period(self, x: int) -> int:
        """
        Return the index of the first period whose cumulative end is ``>= x``, or the last
        period's index if ``x`` lies beyond all of them.
        """
        last_idx = len(self._cum_period_end) - 1
        return next(
            (idx for idx, end in enumerate(self._cum_period_end) if x <= end),
            last_idx,
        )

    def _get_peak_lr(
        self, initial_lr: Union[float, torch.Tensor], pidx: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the peak LR of period ``pidx``: ``initial_lr`` scaled by that period's
        multiplier, or ``initial_lr`` itself when no multipliers are configured.
        """
        multipliers = self.period_lr_multipliers
        if multipliers is not None:
            return initial_lr * multipliers[pidx]
        return initial_lr

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return the learning rate at ``current``.

        :param initial_lr: The base learning rate that period peaks are derived from.
        :param current: The current step or token count.
        :param t_max: Ignored; the schedule is fully determined by the configured periods.
        """
        del t_max

        warmup_len = self._warmup_steps
        if current < warmup_len:
            first_peak = self._get_peak_lr(initial_lr, 0)
            return _linear_warmup(first_peak, current, warmup_len, self.warmup_min_lr)

        # Everything below works on the post-warmup timeline.
        offset = current - warmup_len
        period_ends = self._cum_period_end
        if offset >= period_ends[-1]:
            return self.decay_min_lr

        # Find current period (using adjusted boundaries)
        pidx = self._find_period(offset)
        period_start = period_ends[pidx - 1] if pidx != 0 else 0
        period_len = self._adjusted_period_lengths[pidx]
        position = min(max(offset - period_start, 0), period_len)

        # The decay length is resolved from the configured (unadjusted) period length.
        decay_len = self._resolve_decay(self.period_lengths[pidx])
        stable_len = period_len - decay_len
        peak_lr = self._get_peak_lr(initial_lr, pidx)

        if position < stable_len:
            return peak_lr
        into_decay = position - stable_len
        return _linear_decay(peak_lr, decay_len - into_decay, decay_len, self.decay_min_lr)
[docs]
@Scheduler.register("exponential")
@dataclass
class ExponentialScheduler(Scheduler):
    """
    A learning rate schedule that grows exponentially from a minimum LR up to the initial LR:

    - lr(0) = lr_min
    - lr(t_max) = initial_lr
    """

    # Learning rate at ``current == 0``; must be positive.
    lr_min: float = 1e-9

    def __post_init__(self, *args):
        """
        Validate ``lr_min``.

        :raises OLMoConfigurationError: If ``lr_min`` is not positive.
        """
        del args
        if not self.lr_min > 0:
            raise OLMoConfigurationError("'lr_min' must be positive.")

    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], current: int, t_max: int
    ) -> Union[float, torch.Tensor]:
        """
        Return ``lr_min * (initial_lr / lr_min) ** (current / t_max)``, with the two
        endpoints returned exactly.

        :param initial_lr: The learning rate reached at ``t_max``.
        :param current: The current step or token count.
        :param t_max: The maximum number of steps or tokens.
        """
        # Endpoints are returned directly rather than computed.
        if current >= t_max:
            return initial_lr
        if current == 0:
            return self.lr_min

        exponent = current / t_max
        lr_ratio = initial_lr / self.lr_min
        if isinstance(initial_lr, torch.Tensor):
            growth = torch.pow(lr_ratio, exponent)
        else:
            growth = lr_ratio**exponent
        return self.lr_min * growth