nn.lm_head

class olmo_core.nn.lm_head.LMHeadType(value)

Bases: StrEnum

An enumeration of the different LM head types.

default = 'default'

➡️ LMHead

normalized = 'normalized'

➡️ NormalizedLMHead

class olmo_core.nn.lm_head.LMLossImplementation(value)

Bases: StrEnum

An enumeration of the different loss implementations.

default = 'default'

Uses native PyTorch operations.

fused_linear = 'fused_linear'

A low-memory Triton implementation from Liger-Kernel that fuses the linear logits projection with the loss computation.
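Why a fused implementation helps: the default path materializes the full logits tensor (and, during backward, a gradient of the same shape) before the loss is reduced. The back-of-the-envelope estimate below shows how large that tensor gets. The shapes are illustrative assumptions, not taken from any particular model config, and logits_nbytes is a hypothetical helper.

```python
def logits_nbytes(batch_size: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to hold one dense (batch_size, seq_len, vocab_size) logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_elem


# Illustrative shapes (assumed): 4 sequences of 4096 tokens, a 100k vocabulary, float32.
nbytes = logits_nbytes(4, 4096, 100_000)
print(f"{nbytes / 2**30:.2f} GiB")  # prints "6.10 GiB"
```

As we understand Liger-Kernel's fused linear cross-entropy, it processes the projection and loss in chunks so that a tensor of this size never has to exist all at once.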

class olmo_core.nn.lm_head.LMHeadConfig(name='default', layer_norm=None, bias=None, dtype='float32', loss_implementation='default')

Bases: ModuleConfig

A configuration class for building any of the LMHead implementations.

See the LMHead subclasses to learn which fields are valid for each implementation.

name: LMHeadType = 'default'

The name of the implementation.

num_params(d_model, vocab_size)

The number of parameters in the module once built.

Return type:

int
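The head's size is dominated by its d_model × vocab_size output projection. The rough estimate below assumes a plain linear projection with an optional per-vocabulary bias, plus whatever parameters the optional layer norm holds (passed in as norm_params, e.g. d_model for a weight-only RMSNorm). lm_head_num_params is a hypothetical helper; the real num_params() may count differently, for example for NormalizedLMHead's scaling parameter.

```python
def lm_head_num_params(d_model: int, vocab_size: int, *, bias: bool, norm_params: int = 0) -> int:
    """Hypothetical estimate: linear projection + optional bias + optional norm parameters."""
    n = d_model * vocab_size  # output projection weight
    if bias:
        n += vocab_size  # one bias term per vocabulary entry
    return n + norm_params


print(lm_head_num_params(2048, 100_000, bias=False, norm_params=2048))  # prints 204802048
```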

build(*, d_model, vocab_size, init_device='cpu')

Construct the corresponding LM head implementation.

Parameters:
  • d_model (int) – The model dimensionality.

  • vocab_size (int) – The vocabulary size.

  • init_device (str, default: 'cpu') – The device to initialize the parameters on, e.g. “cpu” or “meta”.

Return type:

LMHead
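build() turns the name field (an LMHeadType) into the matching class listed under LMHeadType. The sketch below mimics that dispatch pattern with stand-in classes. HeadKind, DefaultHeadStub, NormalizedHeadStub and HeadConfigSketch are hypothetical names; the real build() presumably also forwards the other config fields (layer_norm, bias, dtype, loss_implementation), which are omitted here.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Union


class HeadKind(str, Enum):
    """Stand-in for LMHeadType: members compare equal to their string values."""

    default = "default"
    normalized = "normalized"


class DefaultHeadStub:
    """Hypothetical placeholder for LMHead; it only records its constructor arguments."""

    def __init__(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> None:
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.init_device = init_device


class NormalizedHeadStub(DefaultHeadStub):
    """Hypothetical placeholder for NormalizedLMHead."""


@dataclass
class HeadConfigSketch:
    name: Union[HeadKind, str] = HeadKind.default

    def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> DefaultHeadStub:
        # HeadKind(...) accepts a member or its string value and raises ValueError otherwise.
        registry = {HeadKind.default: DefaultHeadStub, HeadKind.normalized: NormalizedHeadStub}
        cls = registry[HeadKind(self.name)]
        return cls(d_model=d_model, vocab_size=vocab_size, init_device=init_device)


head = HeadConfigSketch(name="normalized").build(d_model=64, vocab_size=1000, init_device="meta")
print(type(head).__name__)  # prints NormalizedHeadStub
```

With the real library, the documented equivalent is LMHeadConfig(name=...).build(d_model=..., vocab_size=..., init_device=...).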

class olmo_core.nn.lm_head.LMHead(*, d_model, vocab_size, layer_norm=None, dtype=torch.float32, bias=True, init_device='cpu', loss_implementation='default')

Bases: Module

The default language modeling head implementation.

forward(x, *, labels=None, ignore_index=-100, loss_reduction='mean', z_loss_multiplier=None, loss_div_factor=None, return_logits=None, logits_to_keep=0)

Applies the language modeling (LM) head to the input hidden states.

Parameters:
  • x (Tensor) – The input hidden states of shape (batch_size, seq_len, d_model).

  • labels (Optional[Tensor], default: None) – (Optional) Target token IDs of shape (batch_size, seq_len). If provided, the method computes and returns the loss.

  • ignore_index (int, default: -100) – Specifies a target value that is ignored and does not contribute to the loss.

  • loss_reduction (Literal['mean', 'sum', 'none'], default: 'mean') – Specifies the reduction to apply to the output loss: “mean”, “sum”, or “none”.

  • z_loss_multiplier (Optional[float], default: None) – (Optional) Multiplier for the z-loss regularization term.

  • loss_div_factor (Union[Tensor, float, None], default: None) – (Optional) Divisor for the loss; can be a scalar or a tensor.

  • return_logits (Optional[bool], default: None) – If True, returns logits along with the loss when labels are provided.

  • logits_to_keep (Union[int, Tensor], default: 0) – If nonzero, restricts computation to the last N positions (if int) or to specific positions (if tensor).
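logits_to_keep matters most during generation, where typically only the last position's logits are needed, so keeping 1 position avoids projecting every position onto the full vocabulary. A minimal sketch of the position selection for a single sequence, assuming the integer form slices from the end of the sequence dimension (0 keeps everything, as documented) and the tensor form lists explicit indices. select_positions is a hypothetical helper; the real module works on (batch_size, seq_len, d_model) tensors.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")


def select_positions(hidden: Sequence[T], logits_to_keep: Union[int, Sequence[int]]) -> List[T]:
    """Pick which positions of one sequence would be projected to logits."""
    if isinstance(logits_to_keep, int):
        # hidden[-0:] is the whole sequence, so 0 keeps every position.
        return list(hidden[-logits_to_keep:])
    return [hidden[i] for i in logits_to_keep]


print(select_positions(["h0", "h1", "h2", "h3"], 1))  # prints ['h3']
```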

Return type:

Union[Tensor, LMOutputWithLoss]

Returns:

If labels is None, returns the logits tensor of shape (batch_size, seq_len, vocab_size). If labels is provided, returns an LMOutputWithLoss named tuple containing the loss and optionally the logits.
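When labels are given, forward() computes a token-level cross-entropy loss and, optionally, a z-loss. The pure-Python sketch below spells out that arithmetic on already-flattened logits. It assumes the common z-loss form z_loss_multiplier · logsumexp(logits)², that "mean" averages over non-ignored tokens, that loss_div_factor divides the reduced loss, and that loss = ce_loss + z_loss. lm_loss_sketch and logsumexp are hypothetical helpers, not library functions.

```python
import math
from typing import List, Optional, Tuple


def logsumexp(row: List[float]) -> float:
    """Numerically stable log(sum(exp(row)))."""
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))


def lm_loss_sketch(
    logits: List[List[float]],  # one row of vocab_size scores per token (batch and seq flattened)
    labels: List[int],  # one target token ID per row
    *,
    ignore_index: int = -100,
    loss_reduction: str = "mean",
    z_loss_multiplier: Optional[float] = None,
    loss_div_factor: Optional[float] = None,
) -> Tuple[float, float, Optional[float]]:
    """Return (loss, ce_loss, z_loss) for the 'mean' and 'sum' reductions."""
    ce_sum, z_sum, num_valid = 0.0, 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # ignored targets contribute nothing
        lse = logsumexp(row)
        ce_sum += lse - row[label]  # -log softmax(row)[label]
        z_sum += lse**2  # z-loss penalizes large log-partition values
        num_valid += 1
    if loss_reduction == "sum":
        ce, z = ce_sum, z_sum
    elif loss_reduction == "mean":
        denom = max(num_valid, 1)
        ce, z = ce_sum / denom, z_sum / denom
    else:
        raise ValueError(f"this sketch only supports 'mean' and 'sum', got {loss_reduction!r}")
    if loss_div_factor is not None:
        # Assumption: the reduced loss is divided by this factor (e.g. a global token count).
        ce, z = ce / loss_div_factor, z / loss_div_factor
    z_loss = z_loss_multiplier * z if z_loss_multiplier is not None else None
    # Assumption: the optimized loss is the CE loss plus the optional z-loss.
    loss = ce + (z_loss if z_loss is not None else 0.0)
    return loss, ce, z_loss
```

For example, uniform logits over two tokens give a CE loss of log 2 regardless of the target, and any position labeled -100 is skipped entirely.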

class olmo_core.nn.lm_head.NormalizedLMHead(*, d_model, vocab_size, dtype=torch.float32, init_device='cpu', loss_implementation='default')

Bases: LMHead

An nGPT LM head implementation.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

reset_parameters()

Reset the scaling parameter.
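The nGPT paper keeps hidden states and embedding rows on the unit hypersphere, so raw logits are cosine similarities in [-1, 1], and a learned per-vocabulary scaling vector (s_z in the paper) stretches them before the softmax. A minimal sketch of that computation, assuming the scaling parameter reset above plays this role; NormalizedLMHead's exact parameterization may differ (nGPT, for instance, renormalizes weights after each optimizer step rather than inside the forward pass). l2_normalize and ngpt_logits are hypothetical helpers.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def ngpt_logits(h: Sequence[float], weight_rows: Sequence[Sequence[float]], scale: Sequence[float]) -> List[float]:
    """Scaled cosine similarities between a hidden state and each output-embedding row."""
    h_hat = l2_normalize(h)
    logits = []
    for w, s in zip(weight_rows, scale):
        cos = sum(a * b for a, b in zip(l2_normalize(w), h_hat))
        logits.append(s * cos)  # per-vocabulary scale applied to a value in [-1, 1]
    return logits


print(ngpt_logits([3.0, 4.0], [[1.0, 0.0], [0.0, 2.0]], [10.0, 10.0]))
```

Because each cosine is bounded by 1 in magnitude, every logit satisfies |logit_i| ≤ |s_i|, which is why the learned scale is needed for the softmax to become sharp.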

forward(x, *, labels=None, ignore_index=-100, loss_reduction='mean', z_loss_multiplier=None, loss_div_factor=None, return_logits=None, logits_to_keep=0)

Applies the language modeling (LM) head to the input hidden states.

Parameters:
  • x (Tensor) – The input hidden states of shape (batch_size, seq_len, d_model).

  • labels (Optional[Tensor], default: None) – (Optional) Target token IDs of shape (batch_size, seq_len). If provided, the method computes and returns the loss.

  • ignore_index (int, default: -100) – Specifies a target value that is ignored and does not contribute to the loss.

  • loss_reduction (Literal['mean', 'sum', 'none'], default: 'mean') – Specifies the reduction to apply to the output loss: “mean”, “sum”, or “none”.

  • z_loss_multiplier (Optional[float], default: None) – (Optional) Multiplier for the z-loss regularization term.

  • loss_div_factor (Union[Tensor, float, None], default: None) – (Optional) Divisor for the loss; can be a scalar or a tensor.

  • return_logits (Optional[bool], default: None) – If True, returns logits along with the loss when labels are provided.

  • logits_to_keep (Union[int, Tensor], default: 0) – If nonzero, restricts computation to the last N positions (if int) or to specific positions (if tensor).

Return type:

Union[Tensor, LMOutputWithLoss]

Returns:

If labels is None, returns the logits tensor of shape (batch_size, seq_len, vocab_size). If labels is provided, returns an LMOutputWithLoss named tuple containing the loss and optionally the logits.

class olmo_core.nn.lm_head.LMOutputWithLoss(logits, loss, ce_loss, z_loss)

Bases: NamedTuple

logits: Optional[Tensor]

The LM logits.

loss: Tensor

The loss to optimize for.

ce_loss: Tensor

The cross-entropy (CE) loss (for logging only).

z_loss: Optional[Tensor]

The z-loss (for logging only).
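Because LMOutputWithLoss is an ordinary NamedTuple, its fields can be read by name or unpacked positionally. In a training step only loss would be back-propagated, while ce_loss and z_loss are meant for logging. The stand-in below uses floats instead of tensors; LMOutputWithLossSketch, loss_metrics and the metric keys are hypothetical.

```python
from typing import Dict, List, NamedTuple, Optional


class LMOutputWithLossSketch(NamedTuple):
    # Float stand-in for the tensor fields; field order mirrors LMOutputWithLoss.
    logits: Optional[List[List[float]]]
    loss: float
    ce_loss: float
    z_loss: Optional[float]


def loss_metrics(out: LMOutputWithLossSketch) -> Dict[str, float]:
    """Collect the logging-only fields, skipping z_loss when it was not computed."""
    metrics = {"ce_loss": out.ce_loss}
    if out.z_loss is not None:
        metrics["z_loss"] = out.z_loss
    return metrics


out = LMOutputWithLossSketch(logits=None, loss=2.51, ce_loss=2.5, z_loss=0.01)
logits, loss, ce_loss, z_loss = out  # positional unpacking works like any tuple
print(loss_metrics(out))  # prints {'ce_loss': 2.5, 'z_loss': 0.01}
```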