Source code for olmo_core.nn.functional.cross_entropy_loss

import logging
from typing import Callable, Literal, Optional, Tuple

import torch
import torch.nn.functional as F

__all__ = ["cross_entropy_loss", "fused_linear_cross_entropy_loss"]

log = logging.getLogger(__name__)


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    ignore_index: int = -100,
    reduction: Literal["mean", "sum", "none"] = "mean",
    compute_z_loss: bool = False,
    z_loss_multiplier: float = 1e-4,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Cross entropy loss that optionally computes the softmax auxiliary loss (z-loss) as well.

    :param logits: Predicted unnormalized logits with shape ``(N, vocab_size)``.
    :param labels: Ground truth class indices with shape ``(N,)``.
    :param ignore_index: Specifies a target value that is ignored and does not contribute to
        the input gradient. Ignored positions are excluded from the z-loss as well, for every
        reduction mode (with ``reduction="none"`` their z-loss entries are 0).
    :param reduction: Specifies the reduction to apply to the output.
        Can be "none", "mean", or "sum".
    :param compute_z_loss: Compute the softmax auxiliary loss as well.
    :param z_loss_multiplier: The multiplier to apply to the z-loss.

    :returns: The cross entropy loss and optionally the z-loss (``None`` when
        ``compute_z_loss`` is ``False``).
    """
    # Upcast to float32 so the softmax / logsumexp are not computed in low precision.
    logits = logits.float()
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)

    if not compute_z_loss:
        return loss, None

    # Per-row squared log-partition function: logsumexp(logits) ** 2.
    z_squared = logits.logsumexp(-1).pow(2)
    mask = labels != ignore_index
    # Zero ignored rows up front so that reduction="none" also excludes them (matching
    # F.cross_entropy, which reports 0 for ignored targets). ``where`` is used instead of
    # ``z_squared * mask`` so a non-finite value at an ignored row cannot become inf * 0 = nan.
    z_squared = torch.where(mask, z_squared, torch.zeros_like(z_squared))
    if reduction == "mean":
        # Average over non-ignored rows only, the same denominator the mean CE loss uses.
        z_squared = z_squared.sum() / mask.sum()
    elif reduction == "sum":
        z_squared = z_squared.sum()

    z_loss = z_loss_multiplier * z_squared

    return loss, z_loss
# Optional fused linear + cross-entropy kernel from liger-kernel. It remains ``None`` whenever
# the package cannot be imported; callers check for that and raise a descriptive error.
_fused_linear_cross_entropy_loss: Optional[Callable] = None

try:
    from liger_kernel.ops.fused_linear_cross_entropy import (  # type: ignore
        LigerFusedLinearCrossEntropyFunction,
    )

    # Expose the autograd Function's ``apply`` so it can be called like a plain function.
    _fused_linear_cross_entropy_loss = LigerFusedLinearCrossEntropyFunction.apply
except ImportError:
    # liger-kernel is an optional dependency: quietly fall back to "not available".
    pass
except Exception:
    # Any other failure (e.g. a broken install) is logged, but importing this module still works.
    log.exception("Error importing liger-kernel")
@torch._dynamo.disable()
def fused_linear_cross_entropy_loss(
    _input: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    *,
    bias: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    reduction: Literal["mean", "sum", "none"] = "mean",
    compute_z_loss: bool = False,
    z_loss_multiplier: float = 1e-4,
    ce_weight: Optional[torch.Tensor] = None,
    label_smoothing: float = 0.0,
    softcap: Optional[float] = None,
    accum_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Cross entropy loss fused with the linear layer that computes the logits, which avoids
    materialization of the large logits tensor. Additionally, this function computes gradients
    during the forward pass (valid when CrossEntropyLoss comes last), so ``_input`` and
    ``labels`` do not need to be stored for the backwards pass.

    :param _input: The inputs to pass through the linear layer to produce the logits ``(N, D)``.
    :param weight: The weight of the linear layer.
    :param labels: Ground truth class indices with shape ``(N,)``.
    :param bias: Optional bias for the linear layer.
    :param ignore_index: Specifies a target value that is ignored and does not contribute to
        the input gradient.
    :param reduction: Specifies the reduction to apply to the output.
        Can be "none", "mean", or "sum".
    :param compute_z_loss: Compute the softmax auxiliary loss as well. When ``False`` the
        z-loss scale handed to the kernel is 0, so ``z_loss_multiplier`` has no effect.
    :param z_loss_multiplier: The multiplier to apply to the z-loss.
    :param ce_weight: Optional class-weight tensor forwarded to the kernel
        (presumably per-class, shape ``(vocab_size,)`` — confirm against liger-kernel).
    :param label_smoothing: Label-smoothing amount forwarded to the kernel.
    :param softcap: Optional logit soft-capping value forwarded to the kernel.
    :param accum_dtype: The dtype of intermediate result buffers for weight and bias gradient
        accumulations. Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`,
        if the training is unstable with original dtype. Default to performing accumulations in
        original dtype.

    :returns: The cross entropy loss and optionally the z-loss (``None`` when
        ``compute_z_loss`` is ``False``).

    :raises RuntimeError: If liger-kernel could not be imported.
    """
    if _fused_linear_cross_entropy_loss is None:
        raise RuntimeError("'fused_linear_cross_entropy_loss' requires liger-kernel")

    # The z-loss scale is its own kernel argument, independent of the ``compute_z_loss`` flag.
    # Passing the non-zero default unconditionally could let the kernel apply a z-loss term the
    # caller never asked for, so only forward it when the z-loss is requested.
    lse_square_scale = z_loss_multiplier if compute_z_loss else 0.0

    # Arguments are passed positionally, so their order must match the kernel's forward
    # signature. Any extra outputs (e.g. per-token accuracy) are discarded via ``*_``.
    # NOTE(review): the kernel may fold the z-loss into the returned loss value itself;
    # confirm whether callers that add ``z_loss`` on top are double-counting it.
    ce_loss, z_loss, *_ = _fused_linear_cross_entropy_loss(
        _input,
        weight,
        labels,
        bias,
        ce_weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        compute_z_loss,
        accum_dtype,
    )

    return ce_loss, (z_loss if compute_z_loss else None)