Source code for olmo_core.train.common

from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, Optional, Tuple

import torch

from ..config import StrEnum
from ..data.utils import get_labels
from ..utils import format_float, format_int, format_timedelta

# Metric-name string constants. The "train/" and "optim/" prefixes group the
# names by area; where these names are recorded is not visible in this module.
# NOTE(review): "CE" / "PPL" presumably stand for cross-entropy / perplexity — confirm.
TRAIN_CE_LOSS_METRIC = "train/CE loss"
TRAIN_PPL_METRIC = "train/PPL"
TRAIN_Z_LOSS_METRIC = "train/Z loss"
OPTIM_GRAD_NORM_METRIC = "optim/total grad norm"


[docs] class DurationUnit(StrEnum): """ Units that can be used to define a :class:`Duration`. """ steps = "steps" """ Steps (batches). """ epochs = "epochs" """ Epochs. """ tokens = "tokens" """ Tokens. """
@dataclass
class Duration:
    """
    A training duration: an integer :data:`value` measured in a :class:`DurationUnit`.

    Use the :meth:`steps`, :meth:`epochs`, :meth:`tokens`, or
    :meth:`chinchilla_tokens` constructors to build one, and :meth:`due` to
    check whether the duration has been reached.
    """

    value: int
    """
    The value of the duration.
    """

    unit: DurationUnit
    """
    The unit associated with the :data:`value`.
    """

    @classmethod
    def steps(cls, steps: int) -> "Duration":
        """
        Define a duration from a number of steps.
        """
        return cls(value=steps, unit=DurationUnit.steps)

    @classmethod
    def epochs(cls, epochs: int) -> "Duration":
        """
        Define a duration from a number of epochs.
        """
        return cls(value=epochs, unit=DurationUnit.epochs)

    @classmethod
    def tokens(cls, tokens: int) -> "Duration":
        """
        Define a duration from a number of tokens.
        """
        return cls(value=tokens, unit=DurationUnit.tokens)

    @classmethod
    def chinchilla_tokens(
        cls, multiple: float, *, model_params: int, _tok_per_param: int = 20
    ) -> "Duration":
        """
        Define a duration based on a multiple of the Chinchilla-optimal number of tokens.

        The rule of thumb for Chinchilla compute optimality is 20 tokens-per-parameter
        for decoder-only natural language models trained with AdamW on dataset mixtures
        similar to the Pile.

        Chinchilla optimality refers to training-time compute only, and does not account
        for inference-time compute. In practice, models are often trained with more tokens
        than the Chinchilla optimal value ("overtrained") to improve inference-time
        performance.

        Chinchilla: https://arxiv.org/abs/2203.15556
        Chinchilla replication: https://arxiv.org/abs/2404.10102

        :param multiple: The Chinchilla multiplier. 1.0 is the Chinchilla optimal value.
            Values less than 1.0 will undertrain relative to Chinchilla, and values
            greater than 1.0 will overtrain relative to Chinchilla.
        :param model_params: The number of *active, non-embedding* parameters in the
            target model.
        :param _tok_per_param: The tokens-per-parameter ratio that ``multiple`` scales
            (defaults to the rule-of-thumb value of 20).

        :returns: A token-based duration; the product is truncated to an ``int``.
        """
        num_tokens = int(_tok_per_param * model_params * multiple)
        # Go through ``cls`` like the other constructors so subclasses get an
        # instance of their own type.
        return cls.tokens(num_tokens)

    def due(self, *, step: int, tokens: int, epoch: int) -> bool:
        """
        Check if the duration is due.

        :param step: The current step count.
        :param tokens: The current token count.
        :param epoch: The current epoch.

        :returns: ``True`` once the counter matching :data:`unit` has reached
            :data:`value`.
        :raises NotImplementedError: If :data:`unit` is not a known unit.
        """
        if self.unit == DurationUnit.steps:
            return step >= self.value
        elif self.unit == DurationUnit.tokens:
            return tokens >= self.value
        elif self.unit == DurationUnit.epochs:
            # Strictly greater than, unlike steps/tokens.
            # NOTE(review): presumably because epochs are counted from 1, so the
            # duration is due only after ``value`` full epochs — confirm with the caller.
            return epoch > self.value
        else:
            raise NotImplementedError(f"Unsupported duration unit: {self.unit!r}")
[docs] class LoadStrategy(StrEnum): """ Determines the strategy for loading checkpoints prior to training. """ if_available = "if_available" """ The trainer will attempt to load a checkpoint from the save folder or load path (in that order) but will train from scratch if no checkoint is found. """ always = "always" """ The trainer will attempt to load a checkpoint from the save folder or load path (in that order) and raise an error if no checkpoint is found. """ never = "never" """ The trainer will never load a checkpoint even if one exists in the save folder or load path. """
[docs] class ReduceType(StrEnum): """ An enumeration of the allowed ways to reduce a metric across ranks. """ mean = "mean" """ Average across the process group. """ sum = "sum" """ Add across the process group. """ max = "max" """ Take the max across the process group. """ l2_norm = "l2_norm" """ For metrics that are computed as L2 norms on each rank, this will correctly reduce the norm across the process group to produce the global L2 norm. """
[docs] class MetricMergeStrategy(StrEnum): """ Determines how duplicate metrics are merged. """ warn = "warn" """ Warn when a duplicate is logged, keeping the current value. """ latest = "latest" """ The latest is used. """ oldest = "oldest" """ The oldest (first logged) is used. """ mean = "mean" """ When a duplicate is logged we take the average with the last value. """ sum = "sum" """ The sum of the duplicates is used. """ max = "max" """ Take the maximum value of the duplicates. """ min = "min" """ Take the minimum value of the duplicates. """
def reshape_inputs_for_loss(
    logits: torch.Tensor, labels: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flatten logits and labels into the 2D / 1D layout used for the loss.

    Both tensors are reshaped with ``view``, so the inputs must be
    view-compatible and no data is copied.

    :param logits: The logits; every dimension except the last is flattened.
    :param labels: The labels; all dimensions are flattened.

    :returns: A ``(logits_for_loss, labels_for_loss)`` tuple.
    """
    # shape: (B * S, V)
    logits_for_loss = logits.view(-1, logits.size(-1))
    # shape: (B, S) -> (B * S,)
    labels_for_loss = labels.view(-1)
    return logits_for_loss, labels_for_loss


def get_inputs_for_loss(
    batch: Dict[str, Any], logits: torch.Tensor, label_ignore_index: int = -100
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pick the labels for a batch and flatten them together with the logits.

    :param batch: The batch. Its ``"labels"`` entry is used when present; otherwise
        labels are derived with ``get_labels``.
    :param logits: The logits to flatten.
    :param label_ignore_index: Forwarded to ``get_labels`` when labels have to be
        derived.

    :returns: The output of :func:`reshape_inputs_for_loss`.
    """
    # Only derive labels when the batch doesn't carry them. Passing the derived
    # labels as the default of ``dict.get`` would evaluate ``get_labels`` on every
    # call, even when its result is thrown away.
    if "labels" in batch:
        labels = batch["labels"]
    else:
        labels = get_labels(batch, label_ignore_index=label_ignore_index)
    return reshape_inputs_for_loss(logits, labels)


@dataclass
class TrainingProgress:
    """
    A snapshot of training progress that renders as a one-line summary via ``str()``.
    """

    current_step: int
    """
    The current training step.
    """

    current_tokens: Optional[int] = None
    """
    The current number of tokens processed during training.
    """

    total_steps: Optional[int] = None
    """
    The step that training will stop at.
    """

    time_remaining: Optional[timedelta] = None
    """
    Estimated time remaining.
    """

    bps: Optional[float] = None
    """
    The average training speed in batches per second.
    """

    tps: Optional[float] = None
    """
    The average training speed in tokens per second per device.
    """

    mfu: Optional[float] = None
    """
    The average model flops utilization (MFU) percentage.
    """

    def __str__(self) -> str:
        """
        Render the progress as a comma-separated summary.

        Fields that are ``None`` are left out. TPS takes precedence over BPS:
        BPS is only shown when TPS is not available.
        """
        if self.total_steps is not None:
            if self.total_steps > 0:
                # Cap at 100% in case the current step overshoots the total.
                progress_perc = min(100, int(100 * self.current_step / self.total_steps))
            else:
                # Nothing to train on; avoid dividing by zero.
                progress_perc = 100
            progress_str = (
                f"{progress_perc}% complete, step {self.current_step:,d}/{self.total_steps:,d}"
            )
        else:
            progress_str = f"step {self.current_step:,d}/???"

        if self.current_tokens is not None:
            progress_str += f", {format_int(self.current_tokens)} tokens"

        if self.time_remaining is not None:
            progress_str += f", eta {format_timedelta(self.time_remaining)}"

        if self.tps is not None:
            progress_str += f", {format_float(self.tps)} TPS"
        elif self.bps is not None:
            progress_str += f", {format_float(self.bps)} BPS"

        if self.mfu is not None:
            progress_str += f", {format_float(self.mfu)}% MFU"

        return progress_str
@dataclass
class StepSkipRange:
    """A half-open range of training steps to skip."""

    start: int
    """The first step that is skipped (steps are numbered from 1, not 0)."""

    stop: int
    """The end of the range; this step itself is not included."""