Source code for olmo_core.nn.lm_head

import logging
import math
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    SequenceParallel,
    parallelize_module,
)
from torch.distributed.tensor.placement_types import Placement

from olmo_core.config import DType, StrEnum
from olmo_core.distributed.utils import get_local_tensor
from olmo_core.doc_utils import beta_feature
from olmo_core.exceptions import OLMoConfigurationError

from .config import ModuleConfig
from .functional import (
    cross_entropy_loss,
    fused_linear_cross_entropy_loss,
    l2_normalize,
)
from .layer_norm import LayerNormConfig

__all__ = [
    "LMHeadType",
    "LMLossImplementation",
    "LMHeadConfig",
    "LMHead",
    "NormalizedLMHead",
    "LMOutputWithLoss",
]


log = logging.getLogger(__name__)


[docs] class LMHeadType(StrEnum): """ An enumeration of the different LM head types. """ default = "default" """ ➡️ :class:`LMHead` """ normalized = "normalized" """ ➡️ :class:`NormalizedLMHead` """
[docs] class LMLossImplementation(StrEnum): """ An enumeration of the different loss implementations. """ default = "default" """ Uses native PyTorch's operations. """ fused_linear = "fused_linear" """ A low-memory triton implementation from Liger-Kernel that fused the linear logits projection with the loss computation. """
@dataclass
class LMHeadConfig(ModuleConfig):
    """
    A configuration class for building any of the :class:`LMHead` implementations.

    See the :class:`LMHead` subclasses to learn which fields are valid for each implementation.
    """

    name: LMHeadType = LMHeadType.default
    """
    The name of the implementation.
    """
    # Optional layer norm config; forwarded to the head's constructor when set.
    layer_norm: Optional[LayerNormConfig] = None
    # Whether to use a bias. When ``None`` the option is not forwarded to the head's
    # constructor, so the implementation's own default applies.
    bias: Optional[bool] = None
    # Data type, converted with ``as_pt()`` before being handed to the head.
    dtype: DType = DType.float32
    # Which loss implementation the head should use.
    loss_implementation: LMLossImplementation = LMLossImplementation.default

    def num_params(self, d_model: int, vocab_size: int) -> int:
        """
        The number of parameters in the module once built.

        :param d_model: The model dimensionality.
        :param vocab_size: The size of the vocabulary.
        """
        is_normalized = self.name == LMHeadType.normalized
        # An unset bias means "on" for every head except the normalized one.
        has_bias = (not is_normalized) if self.bias is None else self.bias

        # Output projection weight.
        total = d_model * vocab_size
        if self.layer_norm is not None:
            total += self.layer_norm.num_params(d_model)
        if has_bias:
            total += vocab_size
        if is_normalized:
            # Final scaling factor.
            total += vocab_size
        return total

    def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead":
        """
        Construct the corresponding LM head implementation.

        :param d_model: The model dimensionality.
        :param vocab_size: The size of the vocabulary.
        :param init_device: The device initialize the parameters on, e.g. "cpu", "meta".

        :raises OLMoConfigurationError: If the configured options are rejected by the
            selected implementation's constructor with a ``TypeError``.
        :raises NotImplementedError: If ``name`` does not match a known implementation.
        """
        options = self.as_dict(exclude_none=True, recurse=False)
        del options["name"]
        options["dtype"] = options["dtype"].as_pt()
        options["d_model"] = d_model
        options["vocab_size"] = vocab_size
        options["init_device"] = init_device

        if self.name == LMHeadType.default:
            head_cls = LMHead
        elif self.name == LMHeadType.normalized:
            head_cls = NormalizedLMHead
        else:
            raise NotImplementedError(self.name)

        try:
            return head_cls(**options)
        except TypeError as err:
            # Typically an option that the selected implementation does not accept.
            raise OLMoConfigurationError(
                f"invalid options for '{self.name}' {self.__class__.__name__}, {err}"
            ) from err
class LMOutputWithLoss(NamedTuple):
    """
    The output of an LM head's forward pass when labels are provided.
    """

    logits: Optional[torch.Tensor]
    """The LM logits."""

    loss: torch.Tensor
    """The loss to optimize for."""

    ce_loss: torch.Tensor
    """The CE loss (for logging only)."""

    z_loss: Optional[torch.Tensor]
    """The Z loss (for logging only)."""
class LMHead(nn.Module):
    """
    The default language modeling head implementation.

    Applies an optional layer norm followed by a linear projection (``w_out``) from ``d_model``
    to ``vocab_size``, and optionally computes the language modeling loss.

    :param d_model: The model dimensionality.
    :param vocab_size: The size of the vocabulary, i.e. the output dimension of ``w_out``.
    :param layer_norm: (Optional) Config for a layer norm applied to the input before ``w_out``.
    :param dtype: The data type of the ``w_out`` parameters.
    :param bias: Whether ``w_out`` includes a bias term.
    :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
    :param loss_implementation: The loss implementation to use when labels are provided.
    """

    def __init__(
        self,
        *,
        d_model: int,
        vocab_size: int,
        layer_norm: Optional[LayerNormConfig] = None,
        dtype: torch.dtype = torch.float32,
        bias: bool = True,
        init_device: str = "cpu",
        loss_implementation: LMLossImplementation = LMLossImplementation.default,
    ):
        super().__init__()
        self.norm = (
            None if layer_norm is None else layer_norm.build(d_model, init_device=init_device)
        )
        self.w_out = nn.Linear(d_model, vocab_size, bias=bias, dtype=dtype, device=init_device)
        self._d_model = d_model
        self._vocab_size = vocab_size
        self._loss_implementation = loss_implementation
        # Set by `apply_tp()` / `apply_cp()`; `None` means that strategy is not enabled.
        self._tp_mesh: Optional[DeviceMesh] = None
        self._cp_mesh: Optional[DeviceMesh] = None

    @property
    def d_model(self) -> int:
        """The model dimensionality."""
        return self._d_model

    @property
    def vocab_size(self) -> int:
        """The size of the vocabulary."""
        return self._vocab_size

    @property
    def loss_implementation(self) -> LMLossImplementation:
        """The configured loss implementation."""
        return self._loss_implementation

    @property
    def tp_enabled(self) -> bool:
        """Whether a tensor-parallel mesh has been set via :meth:`apply_tp()`."""
        return self._tp_mesh is not None

    @property
    def cp_enabled(self) -> bool:
        """Whether a context-parallel mesh has been set via :meth:`apply_cp()`."""
        return self._cp_mesh is not None

    def forward(
        self,
        x: torch.Tensor,
        *,
        labels: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        loss_reduction: Literal["mean", "sum", "none"] = "mean",
        z_loss_multiplier: Optional[float] = None,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        return_logits: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
    ) -> Union[torch.Tensor, LMOutputWithLoss]:
        """
        Applies the language modeling (LM) head to the input hidden states.

        :param x: The input hidden states of shape ``(batch_size, seq_len, d_model)``.
        :param labels: (Optional) Target token IDs of shape ``(batch_size, seq_len)``.
            If provided, the method computes and returns the loss.
        :param ignore_index: Specifies a target value that is ignored and does not contribute
            to the loss.
        :param loss_reduction: Specifies the reduction to apply to the output loss:
            "mean", "sum", or "none".
        :param z_loss_multiplier: (Optional) Multiplier for the z-loss regularization term.
            The z-loss is only computed when this is not ``None``; an explicit value
            (including ``0.0``) is passed through unchanged.
        :param loss_div_factor: (Optional) Divisor for the loss, can be a scalar or tensor.
        :param return_logits: If True, returns logits along with the loss when labels are
            provided. If False, logits are omitted from the output.
        :param logits_to_keep: If nonzero, restricts computation to the last N positions
            (if int) or to specific positions (if tensor).

        :returns: If ``labels`` is ``None``, returns the logits tensor of shape
            ``(batch_size, seq_len, vocab_size)``. If ``labels`` is provided, returns an
            ``LMOutputWithLoss`` named tuple containing the loss and optionally the logits.

        :raises RuntimeError: If ``return_logits=False`` is given without ``labels``, or if
            ``return_logits=True`` is given with a loss implementation that does not
            materialize logits.
        :raises NotImplementedError: If the configured loss implementation is not supported.
        """
        B = x.shape[0]

        h = self.norm(x) if self.norm is not None else x

        if isinstance(logits_to_keep, int):
            if logits_to_keep != 0:
                # Keep only the last logits_to_keep positions
                h = h[:, -logits_to_keep:, :]
                if labels is not None:
                    labels = labels[:, -logits_to_keep:]
        else:
            # logits_to_keep is a tensor specifying positions to keep
            h = h.gather(1, logits_to_keep.unsqueeze(-1).expand(-1, -1, h.size(-1)))
            if labels is not None:
                labels = labels.gather(1, logits_to_keep)

        if labels is None:
            if return_logits is False:
                raise RuntimeError("'return_logits=False' is only valid when 'labels' is provided")
            return self.w_out(h)

        # NOTE: test against `None` rather than truthiness so that an explicit multiplier of
        # 0.0 is not silently replaced by the fallback value.
        compute_z_loss = z_loss_multiplier is not None
        effective_z_loss_multiplier = z_loss_multiplier if z_loss_multiplier is not None else 1e-4

        logits: Optional[torch.Tensor]
        loss: torch.Tensor
        ce_loss: torch.Tensor
        z_loss: Optional[torch.Tensor]
        if self.loss_implementation == LMLossImplementation.default:
            logits = self.w_out(h)
            assert logits is not None
            ce_loss, z_loss = cross_entropy_loss(
                get_local_tensor(logits).view(-1, self.vocab_size),
                get_local_tensor(labels).contiguous().view(-1),
                ignore_index=ignore_index,
                reduction=loss_reduction,
                compute_z_loss=compute_z_loss,
                z_loss_multiplier=effective_z_loss_multiplier,
            )
            loss = ce_loss if z_loss is None else ce_loss + z_loss
        elif self.loss_implementation == LMLossImplementation.fused_linear:
            # The fused implementation never materializes the logits.
            logits = None
            loss, z_loss = fused_linear_cross_entropy_loss(
                get_local_tensor(h).contiguous().view(-1, self.d_model),
                weight=get_local_tensor(self.w_out.weight),
                labels=get_local_tensor(labels).contiguous().view(-1),
                bias=get_local_tensor(self.w_out.bias) if self.w_out.bias is not None else None,
                ignore_index=ignore_index,
                reduction=loss_reduction,
                compute_z_loss=compute_z_loss,
                z_loss_multiplier=effective_z_loss_multiplier,
                accum_dtype=torch.float32,  # https://github.com/linkedin/Liger-Kernel/issues/512
            )
            # Here the returned loss is the total, so recover the CE part for logging.
            ce_loss = loss if z_loss is None else loss - z_loss
        else:
            raise NotImplementedError(
                f"'{self.loss_implementation}' loss implementation is not supported by {self.__class__.__name__}"
            )

        if return_logits is False:
            logits = None
        elif return_logits is True and logits is None:
            raise RuntimeError(
                f"'return_logits=True' is not compatible with '{self.loss_implementation}' loss implementation"
            )

        return LMOutputWithLoss(
            logits=logits,
            loss=self._finalize_loss(
                loss, B, loss_reduction=loss_reduction, loss_div_factor=loss_div_factor
            ),
            # The CE and Z losses are for logging only, so they're detached and not reduced
            # across the TP group.
            ce_loss=self._finalize_loss(
                ce_loss.detach(),
                B,
                loss_reduction=loss_reduction,
                loss_div_factor=loss_div_factor,
                reduce_across_tp_group=False,
            ),
            z_loss=None
            if z_loss is None
            else self._finalize_loss(
                z_loss.detach(),
                B,
                loss_reduction=loss_reduction,
                loss_div_factor=loss_div_factor,
                reduce_across_tp_group=False,
            ),
        )

    def _finalize_loss(
        self,
        loss: torch.Tensor,
        B: int,
        *,
        loss_reduction: str,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        reduce_across_tp_group: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Post-process a loss tensor: reshape or reduce it according to the parallel strategy
        and apply the optional divide factor.

        :param loss: The loss as returned by the loss function.
        :param B: The batch size, used to reshape an unreduced loss to ``(B, -1)``.
        :param loss_reduction: The reduction that was requested: "mean", "sum", or "none".
        :param loss_div_factor: (Optional) Divisor to apply to the loss.
        :param reduce_across_tp_group: Whether to finish the reduction across the TP group.
            Defaults to whether TP is enabled.

        :raises NotImplementedError: If reducing across the TP group with an unknown
            ``loss_reduction``.
        """
        if reduce_across_tp_group is None:
            reduce_across_tp_group = self.tp_enabled

        if loss_reduction == "none":
            # Reshape to `(B, S)`
            loss = loss.view(B, -1)
            # If TP, wrap with DTensor and mark as sharded on the sequence dimension.
            if self.tp_enabled:
                assert self._tp_mesh is not None
                loss = DTensor.from_local(loss, self._tp_mesh, (Shard(1),))
        elif reduce_across_tp_group:
            # Wrap with DTensor and finish the reduction.
            assert self._tp_mesh is not None
            loss = DTensor.from_local(loss.unsqueeze(0), self._tp_mesh, (Shard(0),))
            loss = loss.redistribute(placements=(Replicate(),))
            if loss_reduction == "sum":
                loss = loss.sum()
            elif loss_reduction == "mean":
                loss = loss.mean()
            else:
                raise NotImplementedError(loss_reduction)

        if loss_div_factor is not None:
            # Adjust divide factor to account for parallel strategy.
            if self.tp_enabled and not reduce_across_tp_group:
                assert self._tp_mesh is not None
                loss_div_factor = loss_div_factor / self._tp_mesh.size()
            if self.cp_enabled:
                assert self._cp_mesh is not None
                loss_div_factor = loss_div_factor / self._cp_mesh.size()
            # Apply divide factor.
            loss = loss / loss_div_factor

        return loss

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layouts: Optional[Tuple[Placement, Placement]] = None,
    ):
        """
        Apply tensor parallelism to this module over the given mesh.

        :param tp_mesh: The tensor-parallel device mesh.
        :param input_layouts: (Optional) The layouts of the incoming hidden states and
            ``labels``, in that order.
        """
        # NOTE: there's a few cases to consider...
        # 1. If we're not using 'fused_linear' loss and we have a norm, then we do sequence-parallel through
        #    the norm, colwise-parallel through 'w_out', then back to sequence-parallel for the loss.
        # 2. If we're not using 'fused_linear' loss and we don't have a norm, then we start with
        #    the input replicated and proceed the same way.
        # 3. If we're using 'fused_linear' loss we do sequence-parallel all the way through.
        uses_fused_linear = self.loss_implementation == LMLossImplementation.fused_linear

        parallelize_module(
            module=self,
            device_mesh=tp_mesh,
            parallelize_plan=PrepareModuleInput(
                input_layouts=None if input_layouts is None else input_layouts[0],
                desired_input_layouts=Shard(1)
                if (uses_fused_linear or self.norm is not None)
                else Replicate(),
                input_kwarg_layouts=None if input_layouts is None else {"labels": input_layouts[1]},
                desired_input_kwarg_layouts={"labels": Shard(1)},
            ),
        )

        if self.norm is not None:
            parallelize_module(
                module=self.norm,
                device_mesh=tp_mesh,
                parallelize_plan=SequenceParallel(),
            )

        if uses_fused_linear:
            parallelize_module(
                module=self.w_out,
                device_mesh=tp_mesh,
                parallelize_plan=SequenceParallel(),
            )
        else:
            parallelize_module(
                module=self.w_out,
                device_mesh=tp_mesh,
                parallelize_plan=ColwiseParallel(
                    input_layouts=Shard(1) if self.norm is not None else Replicate(),
                    output_layouts=Shard(1),
                    use_local_output=False,
                ),
            )

        self._tp_mesh = tp_mesh

    def apply_cp(self, cp_mesh: DeviceMesh):
        """
        Record the context-parallel mesh, which is used to adjust the loss divide factor.

        :param cp_mesh: The context-parallel device mesh.
        """
        self._cp_mesh = cp_mesh

    def num_flops_per_token(self, seq_len: int) -> int:
        """
        Estimate the number of FLOPs per token from the parameter count.

        :param seq_len: Unused by this module.
        """
        del seq_len
        # 6 FLOPs per parameter (2 ops * 3 for forward+backward)
        return 6 * sum(p.numel() for p in self.parameters())
@beta_feature
class NormalizedLMHead(LMHead):
    """
    An nGPT LM head implementation.

    Uses a bias-free ``w_out`` without a layer norm, and scales the logits element-wise by a
    learned per-vocabulary-entry parameter ``sz``.

    :param d_model: The model dimensionality.
    :param vocab_size: The size of the vocabulary.
    :param dtype: The data type of the parameters.
    :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
    :param loss_implementation: The loss implementation to use when labels are provided.
        Only the default implementation is supported by :meth:`forward()`.
    """

    def __init__(
        self,
        *,
        d_model: int,
        vocab_size: int,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
        loss_implementation: LMLossImplementation = LMLossImplementation.default,
    ):
        super().__init__(
            d_model=d_model,
            vocab_size=vocab_size,
            layer_norm=None,
            bias=False,
            dtype=dtype,
            init_device=init_device,
            loss_implementation=loss_implementation,
        )
        self.sz_init_value = 1.0
        self.sz_init_scaling = 1.0 / math.sqrt(d_model)
        self.sz = nn.Parameter(torch.empty(vocab_size, dtype=dtype, device=init_device))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reset the scaling parameter.
        """
        # Fill with ones, then scale in place so every entry equals `sz_init_scaling`.
        nn.init.ones_(self.sz)
        with torch.no_grad():
            self.sz.mul_(self.sz_init_scaling)

    def forward(
        self,
        x: torch.Tensor,
        *,
        labels: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        loss_reduction: Literal["mean", "sum", "none"] = "mean",
        z_loss_multiplier: Optional[float] = None,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        return_logits: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
    ) -> Union[torch.Tensor, LMOutputWithLoss]:
        """
        Applies the normalized LM head to the input hidden states.

        :param x: The input hidden states.
        :param labels: (Optional) Target token IDs. If provided, the loss is computed and
            returned.
        :param ignore_index: Specifies a target value that is ignored and does not contribute
            to the loss.
        :param loss_reduction: Specifies the reduction to apply to the output loss:
            "mean", "sum", or "none".
        :param z_loss_multiplier: (Optional) Multiplier for the z-loss regularization term.
            The z-loss is only computed when this is not ``None``; an explicit value
            (including ``0.0``) is passed through unchanged.
        :param loss_div_factor: (Optional) Divisor for the loss, can be a scalar or tensor.
        :param return_logits: If False, logits are omitted from the output when labels are
            provided.
        :param logits_to_keep: If nonzero, restricts computation to the last N positions
            (if int) or to specific positions (if tensor).

        :returns: The logits if ``labels`` is ``None``, otherwise an ``LMOutputWithLoss``.

        :raises RuntimeError: If ``return_logits=False`` is given without ``labels``.
        :raises NotImplementedError: If labels are provided and the configured loss
            implementation is not the default one.
        """
        B = x.shape[0]

        if isinstance(logits_to_keep, int):
            if logits_to_keep != 0:
                # Keep only the last logits_to_keep positions
                x = x[:, -logits_to_keep:, :]
                if labels is not None:
                    labels = labels[:, -logits_to_keep:]
        else:
            # logits_to_keep is a tensor specifying positions to keep
            x = x.gather(1, logits_to_keep.unsqueeze(-1).expand(-1, -1, x.size(-1)))
            if labels is not None:
                labels = labels.gather(1, logits_to_keep)

        # Rescale the stored parameter by `sz_init_value / sz_init_scaling` before applying it.
        sz = self.sz * (self.sz_init_value / self.sz_init_scaling)
        logits = sz * self.w_out(x)

        if labels is None:
            if return_logits is False:
                raise RuntimeError("'return_logits=False' is only valid when 'labels' is provided")
            return logits

        loss: torch.Tensor
        ce_loss: torch.Tensor
        z_loss: Optional[torch.Tensor]
        if self.loss_implementation == LMLossImplementation.default:
            ce_loss, z_loss = cross_entropy_loss(
                get_local_tensor(logits).view(-1, self.vocab_size),
                get_local_tensor(labels).contiguous().view(-1),
                ignore_index=ignore_index,
                reduction=loss_reduction,
                compute_z_loss=z_loss_multiplier is not None,
                # NOTE: test against `None` rather than truthiness so that an explicit
                # multiplier of 0.0 is not silently replaced by the fallback value.
                z_loss_multiplier=z_loss_multiplier if z_loss_multiplier is not None else 1e-4,
            )
            loss = ce_loss if z_loss is None else ce_loss + z_loss
        else:
            raise NotImplementedError(
                f"'{self.loss_implementation}' loss implementation is not supported by '{self.__class__.__name__}'"
            )

        # Logits are always materialized here, so only an explicit `return_logits=False`
        # changes what is returned.
        out_logits: Optional[torch.Tensor] = None if return_logits is False else logits

        return LMOutputWithLoss(
            logits=out_logits,
            loss=self._finalize_loss(
                loss, B, loss_reduction=loss_reduction, loss_div_factor=loss_div_factor
            ),
            # The CE and Z losses are for logging only, so they're detached and not reduced
            # across the TP group.
            ce_loss=self._finalize_loss(
                ce_loss.detach(),
                B,
                loss_reduction=loss_reduction,
                loss_div_factor=loss_div_factor,
                reduce_across_tp_group=False,
            ),
            z_loss=None
            if z_loss is None
            else self._finalize_loss(
                z_loss.detach(),
                B,
                loss_reduction=loss_reduction,
                loss_div_factor=loss_div_factor,
                reduce_across_tp_group=False,
            ),
        )

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layouts: Optional[Tuple[Placement, Placement]] = None,
    ):
        """
        Not supported for this head.

        :raises NotImplementedError: Always.
        """
        del tp_mesh, input_layouts
        raise NotImplementedError("TP is not implemented yet for the normalized LM head variant")

    @torch.no_grad()
    def normalize_matrices(self):
        """
        L2-normalize the ``w_out`` weight in place along its last dimension.
        """
        self._normalize_matrix(self.w_out.weight)

    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
        """
        Overwrite ``w`` in place with its L2-normalized values along ``dim``.
        """
        w.copy_(l2_normalize(w, dim=dim))