nn.attention

class olmo_core.nn.attention.SlidingWindowAttentionConfig(pattern, force_full_attention_on_first_layer=True, force_full_attention_on_last_layer=True)[source]

Bases: Config

pattern: List[int]

The pattern of window sizes to use for attention, repeated to cover all layers. A value of -1 indicates full attention. For example, a pattern of [4096, 4096, 4096, -1] means that for each set of 4 layers, the first 3 will use a window size of 4096, and the last layer will use full attention.

force_full_attention_on_first_layer: bool = True

If True, the first transformer layer will always use full attention, regardless of the pattern.

force_full_attention_on_last_layer: bool = True

If True, the last transformer layer will always use full attention, regardless of the pattern.

should_use_swa(layer_idx, n_layers)[source]

Returns True if the given layer uses sliding window attention.

Return type:

bool

get_window_size(layer_idx, n_layers)[source]

Get the sliding window size for a given layer.

Return type:

int
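The following is a minimal sketch of how a pattern could be resolved to a per-layer window size. It is not the library's implementation: the helper name `window_size_for_layer` is hypothetical, and it assumes that "repeated to cover all layers" means modulo indexing into the pattern and that the two force flags take precedence over the pattern.

```python
from typing import List

FULL_ATTENTION = -1  # the docs use -1 to mean "full attention"


def window_size_for_layer(
    pattern: List[int],
    layer_idx: int,
    n_layers: int,
    force_full_attention_on_first_layer: bool = True,
    force_full_attention_on_last_layer: bool = True,
) -> int:
    """Hypothetical stand-in for get_window_size(); not the library's code."""
    if force_full_attention_on_first_layer and layer_idx == 0:
        return FULL_ATTENTION
    if force_full_attention_on_last_layer and layer_idx == n_layers - 1:
        return FULL_ATTENTION
    # Assumption: the pattern repeats via modulo indexing.
    return pattern[layer_idx % len(pattern)]


pattern = [4096, 4096, 4096, -1]
sizes = [window_size_for_layer(pattern, i, n_layers=8) for i in range(8)]
print(sizes)  # [-1, 4096, 4096, -1, 4096, 4096, 4096, -1]
```

Under these assumptions, a layer uses sliding window attention exactly when its resolved size is not -1, which is what should_use_swa() reports.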

class olmo_core.nn.attention.GateGranularity(value)[source]

Bases: StrEnum

An enumeration.

headwise = 'headwise'

Head-wise gating: one gate value per attention head, broadcast across the head dimension.

elementwise = 'elementwise'

Element-wise gating: one gate value per output element.
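To make the difference between the two granularities concrete, here is a small pure-Python sketch for a single token. The helper `apply_gate` is hypothetical, and how the gate values themselves are produced (for example, by a learned projection and an activation) is not specified here and is left out.

```python
from typing import List, Sequence

Heads = List[List[float]]  # [n_heads][head_dim] for a single token


def apply_gate(heads: Heads, gate: Sequence, granularity: str) -> Heads:
    """Illustrative only: multiply attention outputs by precomputed gate values."""
    if granularity == "headwise":
        # One scalar per head, broadcast across that head's elements.
        return [[g * x for x in head] for head, g in zip(heads, gate)]
    if granularity == "elementwise":
        # One gate value per output element.
        return [
            [g * x for x, g in zip(head, head_gates)]
            for head, head_gates in zip(heads, gate)
        ]
    raise ValueError(f"unknown granularity: {granularity!r}")


heads = [[1.0, 2.0], [3.0, 4.0]]  # 2 heads, head_dim = 2
headwise_out = apply_gate(heads, [0.5, 1.0], "headwise")
elementwise_out = apply_gate(heads, [[0.5, 1.0], [0.0, 0.25]], "elementwise")
print(headwise_out)     # [[0.5, 1.0], [3.0, 4.0]]
print(elementwise_out)  # [[0.5, 2.0], [0.0, 1.0]]
```

Head-wise gating needs n_heads gate values per token, while element-wise gating needs n_heads * head_dim.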

class olmo_core.nn.attention.GateConfig(granularity='headwise', full_precision=True)[source]

Bases: Config

granularity: GateGranularity = 'headwise'

The granularity of gating to use.

full_precision: bool = True

Whether to always apply gating in full precision regardless of the input data type.

class olmo_core.nn.attention.AttentionType(value)[source]

Bases: StrEnum

An enumeration of the different attention implementations.

default = 'default'

➡️ Attention

fused = 'fused'

➡️ FusedAttention

normalized = 'normalized'

➡️ NormalizedAttention

class olmo_core.nn.attention.AttentionBackendName(value)[source]

Bases: StrEnum

An enumeration of the different attention backends.

torch = 'torch'

PyTorch’s built-in SDPA. Works on all devices. ➡️ TorchAttentionBackend

flash_2 = 'flash_2'

Flash attention 2 from the flash-attn library. Supports Ampere (SM 8.0+) and newer NVIDIA GPUs. To use this with context-parallelism, ring-flash-attn is also required. ➡️ FlashAttention2Backend

flash_3 = 'flash_3'

Flash attention 3 (beta) from the flash-attn library hopper/ subdirectory. Supports Hopper (SM 9.0) GPUs only (H100/H800). ➡️ FlashAttention3Backend

flash_4 = 'flash_4'

Flash attention 4, the CUTE implementation from flash-attn in the flash_attn/cute subdirectory. Supports Blackwell (SM 10.0, e.g. B200) GPUs only. ➡️ FlashAttention4Backend

te = 'te'

Transformer Engine attention. Supports Hopper (SM 9.0+) and newer NVIDIA GPUs. ➡️ TEAttentionBackend
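The hardware notes above can be summarized as a small lookup. This is a sketch that only encodes the compute-capability statements from this list; the helper name `hardware_compatible_backends` is hypothetical, and it does not check whether the corresponding packages are installed (that is what each backend's assert_supported() is for).

```python
from typing import List, Optional, Tuple


def hardware_compatible_backends(
    compute_capability: Optional[Tuple[int, int]],
) -> List[str]:
    """Backend names whose documented hardware requirement is met."""
    names = ["torch"]  # documented as working on all devices
    if compute_capability is None:  # e.g. CPU or a non-NVIDIA device
        return names
    if compute_capability >= (8, 0):  # Ampere and newer
        names.append("flash_2")
    if compute_capability == (9, 0):  # Hopper only
        names.append("flash_3")
    if compute_capability == (10, 0):  # Blackwell SM 10.0 only
        names.append("flash_4")
    if compute_capability >= (9, 0):  # Hopper and newer
        names.append("te")
    return names


print(hardware_compatible_backends((8, 0)))
print(hardware_compatible_backends((9, 0)))
```

On a CUDA machine, the `(major, minor)` tuple can be obtained from `torch.cuda.get_device_capability()`.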

class olmo_core.nn.attention.AttentionBackend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)[source]

Bases: Module

Encapsulates a backend for the scaled dot-product attention (SDPA) operation.

abstract classmethod assert_supported()[source]

Validates that this backend is supported on the current system. Raises an error if not supported.

abstract classmethod assert_supports_swa()[source]

Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.

abstract classmethod assert_supports_ring_cp()[source]

Validates that this backend supports ring context parallelism. Raises an error if not supported.

abstract classmethod assert_supports_ulysses_cp()[source]

Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.

abstract classmethod assert_supports_packed_qkv()[source]

Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.

abstract classmethod assert_supports_kv_cache()[source]

Validates that this backend supports KV caching. Raises an error if not supported.

warmup_cache(max_seq_len, device)[source]

Warm up the buffer cache.

abstract forward(qkv, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, kv_cache_manager=None)[source]

Run the attention operation.

Return type:

Tensor

apply_cp(cp_mesh, ring=None, uly=None)[source]

Apply context parallelism if supported by the backend.

class olmo_core.nn.attention.TorchAttentionBackend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)[source]

Bases: AttentionBackend

PyTorch’s built-in scaled dot-product attention (SDPA) backend.

classmethod assert_supported()[source]

Validates that this backend is supported on the current system. Raises an error if not supported.

classmethod assert_supports_swa()[source]

Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.

classmethod assert_supports_ring_cp()[source]

Validates that this backend supports ring context parallelism. Raises an error if not supported.

classmethod assert_supports_ulysses_cp()[source]

Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.

classmethod assert_supports_packed_qkv()[source]

Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.

classmethod assert_supports_kv_cache()[source]

Validates that this backend supports KV caching. Raises an error if not supported.

warmup_cache(max_seq_len, device)[source]

Warm up the buffer cache.

forward(qkv, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, kv_cache_manager=None)[source]

Run the attention operation.

Return type:

Tensor

class olmo_core.nn.attention.FlashAttention2Backend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)[source]

Bases: AttentionBackend

SDPA from the flash-attn package. Additionally, ring-flash-attn is required for context parallelism.

classmethod assert_supported()[source]

Validates that this backend is supported on the current system. Raises an error if not supported.

classmethod assert_supports_swa()[source]

Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.

classmethod assert_supports_ring_cp()[source]

Validates that this backend supports ring context parallelism. Raises an error if not supported.

classmethod assert_supports_ulysses_cp()[source]

Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.

classmethod assert_supports_packed_qkv()[source]

Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.

classmethod assert_supports_kv_cache()[source]

Validates that this backend supports KV caching. Raises an error if not supported.

forward(qkv, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, kv_cache_manager=None)[source]

Run the attention operation.

Return type:

Tensor

class olmo_core.nn.attention.FlashAttention3Backend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)[source]

Bases: AttentionBackend

SDPA from the flash-attn 3 package. Does not currently support context parallelism.

classmethod assert_supported()[source]

Validates that this backend is supported on the current system. Raises an error if not supported.

classmethod assert_supports_swa()[source]

Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.

classmethod assert_supports_ring_cp()[source]

Validates that this backend supports ring context parallelism. Raises an error if not supported.

classmethod assert_supports_ulysses_cp()[source]

Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.

classmethod assert_supports_packed_qkv()[source]

Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.

classmethod assert_supports_kv_cache()[source]

Validates that this backend supports KV caching. Raises an error if not supported.

forward(qkv, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, kv_cache_manager=None)[source]

Run the attention operation.

Return type:

Tensor

class olmo_core.nn.attention.FlashAttention4Backend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)[source]

Bases: AttentionBackend

SDPA from flash-attn 4 (CUTE implementation).

classmethod assert_supported()[source]

Validates that this backend is supported on the current system. Raises an error if not supported.

classmethod assert_supports_swa()[source]

Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.

classmethod assert_supports_ring_cp()[source]

Validates that this backend supports ring context parallelism. Raises an error if not supported.

classmethod assert_supports_ulysses_cp()[source]

Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.

classmethod assert_supports_packed_qkv()[source]

Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.

classmethod assert_supports_kv_cache()[source]

Validates that this backend supports KV caching. Raises an error if not supported.

forward(qkv, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, kv_cache_manager=None)[source]

Run the attention operation.

Return type:

Tensor

class olmo_core.nn.attention.TEAttentionBackend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)[source]

Bases: AttentionBackend

classmethod assert_supported()[source]

Validates that this backend is supported on the current system. Raises an error if not supported.

classmethod assert_supports_swa()[source]

Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.

classmethod assert_supports_packed_qkv()[source]

Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.

classmethod assert_supports_kv_cache()[source]

Validates that this backend supports KV caching. Raises an error if not supported.

apply_cp(cp_mesh, ring=None, uly=None)[source]

Apply context parallelism if supported by the backend.

forward(qkv, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, kv_cache_manager=None)[source]

Run the attention operation.

Return type:

Tensor

class olmo_core.nn.attention.AttentionConfig(name='default', n_heads=16, n_kv_heads=None, head_dim=None, bias=None, gate=None, rope=None, clip_qkv=None, qk_norm=None, dropout=None, use_flash=None, backend=None, dtype='float32', sliding_window=None, use_head_qk_norm=None, *, type=None)[source]

Bases: SequenceMixerConfig[SequenceMixer]

A configuration class for easily building any of the different attention modules.

See the individual Attention subclasses for a description of the configuration options.

name: AttentionType = 'default'

The name of the implementation.

num_params(d_model)[source]

The number of params that the attention implementation will have once built.

Parameters:

d_model (int) – The model dimensionality.

Return type:

int

build(d_model, *, layer_idx, n_layers, init_device='cpu', cache=None)[source]

Build the corresponding attention module.

Parameters:
  • d_model (int) – The model dimensionality.

  • init_device (str, default: 'cpu') – The device to initialize the parameters on, e.g. “cpu”, “meta”.

Return type:

SequenceMixer

registered_base

alias of SequenceMixerConfig

class olmo_core.nn.attention.Attention(*, d_model, n_heads, n_kv_heads=None, head_dim=None, bias=True, gate=None, rope=None, clip_qkv=None, qk_norm=None, dropout=0.0, softmax_scale=None, use_flash=None, backend=None, window_size=None, dtype=torch.float32, init_device='cpu', cache=None, use_head_qk_norm=False)[source]

Bases: SequenceMixer

An implementation of multi-head self-attention with support for multi-query (MQA) and grouped-query (GQA) attention.

Intra-document masking is also supported by passing the max_doc_len and cu_doc_lens parameters to forward(). This requires a backend that supports it, such as “flash_2”.

See also

FusedAttention if you have flash-attn installed and you’re not using MQA or GQA.

Parameters:
  • d_model (int) – The model hidden size.

  • n_heads (int) – The number of attention heads.

  • n_kv_heads (Optional[int], default: None) – The number of key and value heads, if different.

  • bias (bool, default: True) – Include biases with linear layers.

  • gate (Optional[GateConfig], default: None) – Configuration for attention gating. If None, no gating is applied.

  • rope (Optional[RoPEConfig], default: None) – The config for RoPE, if RoPE should be used.

  • clip_qkv (Optional[float], default: None) – Clip QKV to this value, if set.

  • qk_norm (Optional[LayerNormConfig], default: None) – Configuration for a layer norm applied to queries and keys.

  • dropout (float, default: 0.0) – Dropout probability.

  • use_flash (Optional[bool], default: None) – Deprecated, use backend="flash_2" instead.

  • backend (Optional[AttentionBackendName], default: None) – The attention backend to use. If not set, it will be chosen automatically.

  • dtype (dtype, default: torch.float32) – The default data type to use for parameters.

  • init_device (str, default: 'cpu') – The device to initialize weights on.

forward(x, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, pos_sin=None, pos_cos=None, freqs_cis=None, cache_leftpad=None)[source]

Apply attention to the input.

Parameters:
  • x (Tensor) – The input of shape (batch_size, seq_len, d_model).

  • cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0). Required together with max_doc_len when using intra-document masking.

  • max_doc_len (Optional[int], default: None) – The maximum document length in the input x. Required together with cu_doc_lens when using intra-document masking.

Return type:

Tensor

Returns:

The output of attention with shape (batch_size, seq_len, d_model).
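As an illustration of the cu_doc_lens and max_doc_len arguments, the sketch below derives both from a list of per-document lengths. The helper name `doc_lens_to_cu_doc_lens` is hypothetical, and plain Python lists are used to keep the example dependency-free; in practice the cumulative lengths would be converted to a 1D torch.int32 tensor before being passed to forward().

```python
from itertools import accumulate
from typing import List, Tuple


def doc_lens_to_cu_doc_lens(doc_lens: List[int]) -> Tuple[List[int], int]:
    """Build cumulative document lengths (leading 0) and the max document length."""
    cu_doc_lens = [0] + list(accumulate(doc_lens))
    max_doc_len = max(doc_lens)
    return cu_doc_lens, max_doc_len


# Three documents of 3, 5, and 2 tokens packed into one sequence of 10 tokens.
cu_doc_lens, max_doc_len = doc_lens_to_cu_doc_lens([3, 5, 2])
print(cu_doc_lens)  # [0, 3, 8, 10]
print(max_doc_len)  # 5
```

Note that the result has one more element than there are documents, and document i spans positions cu_doc_lens[i] up to (but not including) cu_doc_lens[i + 1].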

apply_cp(cp_mesh, ring=None, uly=None)[source]

Prepare the module for context-parallelism (ring attention).

Important

This requires a backend that supports CP, such as “flash_2” or “te”.

init_kv_cache_manager(batch_size, max_seq_len)[source]

Initialize the KV cache manager for attention. When the KV cache manager exists, KV caching will be used during the forward pass. This should only be called during inference.

Parameters:
  • batch_size (int) – The batch size for the cache.

  • max_seq_len (int) – The maximum sequence length for the cache.

num_flops_per_token(seq_len)[source]

This accounts for:
  • Linear projections (Q, K, V, output, and gating if enabled)

  • Attention computation (QK^T and softmax(QK^T) @ V)

  • Sliding window attention (reduced effective sequence length)

Return type:

int
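For intuition, here is a rough textbook-style estimate of the terms listed above. It is not the formula used by num_flops_per_token(): the function name and the counting conventions (forward pass only, 2 FLOPs per multiply-add, no gating, no causal-mask discount) are assumptions made for this sketch.

```python
from typing import Optional


def approx_attention_flops_per_token(
    d_model: int,
    n_heads: int,
    head_dim: int,
    seq_len: int,
    n_kv_heads: Optional[int] = None,
    window_size: Optional[int] = None,
) -> int:
    """Rough forward-pass FLOPs per token; an illustrative estimate only."""
    n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim

    # Linear projections: 2 FLOPs (multiply + add) per weight per token.
    proj = 2 * d_model * q_dim        # Q
    proj += 2 * 2 * d_model * kv_dim  # K and V
    proj += 2 * q_dim * d_model       # output

    # A sliding window caps how many keys each query attends to.
    eff_len = seq_len if window_size is None else min(seq_len, window_size)
    attn = 2 * q_dim * eff_len   # QK^T
    attn += 2 * q_dim * eff_len  # softmax(QK^T) @ V
    return proj + attn


full = approx_attention_flops_per_token(8, 2, 4, seq_len=16)
windowed = approx_attention_flops_per_token(8, 2, 4, seq_len=16, window_size=4)
print(full, windowed)  # 1024 640
```

The projection term is independent of sequence length, so a sliding window only reduces the attention term.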

class olmo_core.nn.attention.FusedAttention(*, d_model, n_heads, bias=True, rope=None, clip_qkv=None, dropout=0.0, dtype=torch.float32, backend=None, init_device='cpu', cache=None)[source]

Bases: SequenceMixer

A “fused” implementation of multi-head self-attention.

Intra-document masking is supported by passing in the max_doc_len and cu_doc_lens parameters to forward().

Warning

Currently this is only supported with the “flash_2” backend.

Warning

If using RoPE, this requires that you use the “fused” RoPE implementation (FusedRotaryEmbedding).

Parameters:
  • d_model (int) – The model hidden size.

  • n_heads (int) – The number of attention heads.

  • bias (bool, default: True) – Include biases with linear layers.

  • rope (Optional[RoPEConfig], default: None) – The config for RoPE, if RoPE should be used.

  • clip_qkv (Optional[float], default: None) – Clip QKV to this value, if set.

  • dropout (float, default: 0.0) – Dropout probability.

  • dtype (dtype, default: torch.float32) – The default data type to use for parameters.

  • init_device (str, default: 'cpu') – The device to initialize weights on.

forward(x, max_doc_len=None, cu_doc_lens=None, pos_sin=None, pos_cos=None, freqs_cis=None, cache_leftpad=None)[source]

Apply attention to the input.

Parameters:
  • x (Tensor) – The input of shape (batch_size, seq_len, d_model).

  • max_doc_len (Optional[int], default: None) – The maximum document length in the input x. Required together with cu_doc_lens when using intra-document masking.

  • cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0). Required together with max_doc_len when using intra-document masking.

Return type:

Tensor

Returns:

The output of attention with shape (batch_size, seq_len, d_model).

class olmo_core.nn.attention.NormalizedAttention(*, d_model, n_heads, n_kv_heads=None, rope=None, qk_norm=None, use_flash=None, backend=None, dtype=torch.float32, init_device='cpu', cache=None)[source]

Bases: Attention

An nGPT attention implementation.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

forward(x, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, pos_sin=None, pos_cos=None, freqs_cis=None, cache_leftpad=None)[source]

Apply attention to the input.

Parameters:
  • x (Tensor) – The input of shape (batch_size, seq_len, d_model).

  • cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0). Required together with max_doc_len when using intra-document masking.

  • max_doc_len (Optional[int], default: None) – The maximum document length in the input x. Required together with cu_doc_lens when using intra-document masking.

Return type:

Tensor

Returns:

The output of attention with shape (batch_size, seq_len, d_model).

normalize_matrices()[source]

Normalize the weights in all matrices. This should be called after each optimizer step, which the TransformerTrainModule will handle for you.

class olmo_core.nn.attention.RingAttentionLoadBalancerType(value)[source]

Bases: StrEnum

An enumeration of the different RingAttentionLoadBalancer implementations.

zig_zag = 'zig_zag'

➡️ RingAttentionZigZagLoadBalancer

llama3 = 'llama3'

➡️ RingAttentionLlama3LoadBalancer

ulysses = 'ulysses'

➡️ UlyssesLoadBalancer

build(cp_mesh)[source]

Build the load balancer.

Return type:

RingAttentionLoadBalancer

class olmo_core.nn.attention.RingAttentionLoadBalancer(*, cp_rank, cp_world_size)[source]

Bases: object

A class that handles the logic of sharding inputs on the sequence dimension for ring attention (context parallelism).

abstract batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)[source]

Shard inputs on their sequence dimension, optionally adding padding if needed.

Important

If using intra-document masking, use batch_shard_by_document() instead.

Return type:

List[Tensor]

Returns:

The local shards of the inputs.

abstract batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)[source]

Same as batch_shard() but for strategies that support intra-document masking.

Return type:

Tuple[List[Tensor], Dict[str, Any]]

Returns:

The local shards of the inputs and any other additional inputs required for the corresponding ring attention implementation.

class olmo_core.nn.attention.RingAttentionZigZagLoadBalancer(*, cp_rank, cp_world_size)[source]

Bases: RingAttentionLoadBalancer

Implements the zig-zag load-balancing strategy.
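The reference text does not spell out the zig-zag layout. The sketch below shows the layout commonly used for zig-zag ring attention, in which the sequence is cut into 2 * cp_world_size chunks and each rank takes one chunk from the front and its mirror from the back, so that causal-attention work is roughly balanced. Treat this as an assumption about the scheme, not a description of this class's code; `zig_zag_shard` is a hypothetical helper and padding is not handled.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def zig_zag_shard(seq: Sequence[T], cp_rank: int, cp_world_size: int) -> List[T]:
    """Return the local shard for one rank under the assumed zig-zag layout."""
    n_chunks = 2 * cp_world_size
    if len(seq) % n_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_world_size")
    chunk = len(seq) // n_chunks
    front = cp_rank                 # chunk index counted from the start
    back = n_chunks - 1 - cp_rank   # mirrored chunk index counted from the end
    return (
        list(seq[front * chunk:(front + 1) * chunk])
        + list(seq[back * chunk:(back + 1) * chunk])
    )


positions = list(range(8))
shard0 = zig_zag_shard(positions, cp_rank=0, cp_world_size=2)
shard1 = zig_zag_shard(positions, cp_rank=1, cp_world_size=2)
print(shard0)  # [0, 1, 6, 7]
print(shard1)  # [2, 3, 4, 5]
```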

batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)[source]

Shard inputs on their sequence dimension, optionally adding padding if needed.

Important

If using intra-document masking, use batch_shard_by_document() instead.

Return type:

List[Tensor]

Returns:

The local shards of the inputs.

batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)[source]

Same as batch_shard() but for strategies that support intra-document masking.

Return type:

Tuple[List[Tensor], Dict[str, Any]]

Returns:

The local shards of the inputs and any other additional inputs required for the corresponding ring attention implementation.

class olmo_core.nn.attention.RingAttentionLlama3LoadBalancer(*, cp_rank, cp_world_size)[source]

Bases: RingAttentionLoadBalancer

Implements Llama3’s load-balancing strategy for context parallelism.

The Llama3 strategy assigns each rank a contiguous slice of the full sequence. Rank i receives positions [i * local_len, (i + 1) * local_len) where local_len = total_seq_len // cp_world_size.

This strategy is designed specifically for intra-document masking with variable-length documents packed into a single sequence. It computes separate cumulative sequence lengths for queries (cu_doc_lens_q) and keys (cu_doc_lens_k) per rank, enabling proper causal masking across document boundaries within the ring attention loop. Padding is added as a synthetic document at the end when the total sequence length is not divisible by the context parallel world size.

Note

This strategy only supports batch_shard_by_document() and will raise an error if batch_shard() is called directly.
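The contiguous slicing rule stated above can be written out directly. This sketch covers only the position arithmetic, assuming the total sequence length is already divisible by cp_world_size (that is, after any padding); the helper name `llama3_rank_slice` is hypothetical, and the per-rank cu_doc_lens_q / cu_doc_lens_k computation is not shown.

```python
from typing import Tuple


def llama3_rank_slice(
    total_seq_len: int, cp_rank: int, cp_world_size: int
) -> Tuple[int, int]:
    """Half-open [start, end) range of positions assigned to a rank."""
    local_len = total_seq_len // cp_world_size
    return cp_rank * local_len, (cp_rank + 1) * local_len


slices = [llama3_rank_slice(16, rank, 4) for rank in range(4)]
print(slices)  # [(0, 4), (4, 8), (8, 12), (12, 16)]
```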

batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)[source]

Shard inputs on their sequence dimension, optionally adding padding if needed.

Important

If using intra-document masking, use batch_shard_by_document() instead.

Return type:

List[Tensor]

Returns:

The local shards of the inputs.

batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)[source]

Same as batch_shard() but for strategies that support intra-document masking.

Return type:

Tuple[List[Tensor], Dict[str, Any]]

Returns:

The local shards of the inputs and any other additional inputs required for the corresponding ring attention implementation.

class olmo_core.nn.attention.UlyssesLoadBalancer(*, cp_rank, cp_world_size)[source]

Bases: RingAttentionLoadBalancer

Implements simple contiguous sequence sharding for Ulysses-style context parallelism.

Unlike ring attention, which uses zig-zag or other interleaving strategies, Ulysses needs only simple contiguous chunking, since the all-to-all communication handles the sequence/head exchange.
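A shape-level sketch of the sequence/head exchange may help. It follows the general Ulysses idea (each rank starts with a contiguous sequence chunk and all heads, and after the all-to-all holds the full sequence for a subset of heads) and is not taken from this library's code; the helper name and the divisibility requirements are assumptions.

```python
from typing import Tuple

Shape = Tuple[int, int, int]  # (seq_len, n_heads, head_dim) on one rank


def ulysses_local_shapes(
    seq_len: int, n_heads: int, head_dim: int, cp_world_size: int
) -> Tuple[Shape, Shape]:
    """Per-rank tensor shapes before and after the all-to-all exchange."""
    if seq_len % cp_world_size or n_heads % cp_world_size:
        raise ValueError("seq_len and n_heads must be divisible by cp_world_size")
    before = (seq_len // cp_world_size, n_heads, head_dim)  # sequence-sharded
    after = (seq_len, n_heads // cp_world_size, head_dim)   # head-sharded
    return before, after


before, after = ulysses_local_shapes(seq_len=16, n_heads=8, head_dim=64, cp_world_size=4)
print(before)  # (4, 8, 64)
print(after)   # (16, 2, 64)
```

Because every rank sees the full sequence for its heads during attention, no interleaving of positions is needed for load balance.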

batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)[source]

Shard inputs on their sequence dimension, optionally adding padding if needed.

Important

If using intra-document masking, use batch_shard_by_document() instead.

Return type:

List[Tensor]

Returns:

The local shards of the inputs.

batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)[source]

Same as batch_shard() but for strategies that support intra-document masking.

Return type:

Tuple[List[Tensor], Dict[str, Any]]

Returns:

The local shards of the inputs and any other additional inputs required for the corresponding ring attention implementation.

class olmo_core.nn.attention.RingContextParallelStyle(load_balancer='zig_zag', head_stride=1)[source]

Bases: Config

Configuration for ring attention-style context parallelism.

load_balancer: RingAttentionLoadBalancerType = 'zig_zag'

The type of load balancer to use for ring attention.

head_stride: int = 1

The stride of the head dimension to process for each iteration of ring attention. A value of 1 means each iteration will process one k and one v head. A value of 2 will process two k and two v heads, etc. A larger stride will reduce the number of communication ops.
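The effect of head_stride can be pictured as grouping the key/value heads into batches that are processed together. The sketch below is illustrative only: `kv_head_groups` is a hypothetical helper, and the requirement that the number of KV heads be divisible by the stride is an assumption.

```python
from typing import List


def kv_head_groups(n_kv_heads: int, head_stride: int) -> List[List[int]]:
    """Group KV head indices into chunks of `head_stride` heads each."""
    if n_kv_heads % head_stride != 0:
        raise ValueError("n_kv_heads must be divisible by head_stride")
    return [
        list(range(start, start + head_stride))
        for start in range(0, n_kv_heads, head_stride)
    ]


groups_stride_1 = kv_head_groups(4, head_stride=1)
groups_stride_2 = kv_head_groups(4, head_stride=2)
print(groups_stride_1)  # [[0], [1], [2], [3]]
print(groups_stride_2)  # [[0, 1], [2, 3]]
```

Doubling the stride halves the number of groups, which is why a larger stride means fewer communication ops.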

class olmo_core.nn.attention.UlyssesContextParallelStyle[source]

Bases: Config

Configuration for Ulysses-style context parallelism.

class olmo_core.nn.attention.GatedDeltaNetConfig(n_heads=16, n_v_heads=None, head_dim=None, expand_v=2.0, allow_neg_eigval=True, conv_size=4, conv_bias=False, norm_eps=1e-05, dtype='float32', *, type=None)[source]

Bases: SequenceMixerConfig[GatedDeltaNet]

Configuration for GatedDeltaNet.

See GatedDeltaNet for a description of the configuration options.

n_heads: int = 16

The number of attention heads.

n_v_heads: Optional[int] = None

The number of value heads. If None, defaults to n_heads. If n_v_heads > n_heads, GVA (Grouped Value Attention) is applied.

GVA is preferred over GQA for linear RNNs like GDN because the recurrent state has shape (n_v_heads, head_k_dim, head_v_dim). Unlike softmax attention, where the KV cache grows with sequence length (motivating GQA to reduce it), the linear RNN state is constant size regardless of sequence length. Since there’s no memory scaling issue to solve, we can instead opt to increase the state size to improve the model’s capacity to compress long-range context. Increasing n_v_heads directly increases this fixed state size.

head_dim: Optional[int] = None

The dimension of each head. If None, defaults to d_model // n_heads.

expand_v: float = 2.0

The expansion ratio for the value dimension (head_v_dim = head_dim * expand_v). Like n_v_heads, this increases the constant-size recurrent state, improving capacity without memory scaling concerns.
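The size argument above can be checked with simple arithmetic. The sketch below counts elements per layer and per sequence, using the state shape (n_v_heads, head_k_dim, head_v_dim) and head_v_dim = head_dim * expand_v from the text; treating head_k_dim as equal to head_dim is an assumption, and both function names are hypothetical.

```python
def gdn_state_elements(n_v_heads: int, head_dim: int, expand_v: float = 2.0) -> int:
    """Elements in the recurrent state; independent of sequence length."""
    head_k_dim = head_dim  # assumption: key head dim equals head_dim
    head_v_dim = int(head_dim * expand_v)
    return n_v_heads * head_k_dim * head_v_dim


def kv_cache_elements(n_kv_heads: int, head_dim: int, seq_len: int) -> int:
    """Elements in a softmax-attention KV cache (keys plus values)."""
    return 2 * n_kv_heads * head_dim * seq_len


state = gdn_state_elements(n_v_heads=16, head_dim=64, expand_v=2.0)
cache_1k = kv_cache_elements(n_kv_heads=16, head_dim=64, seq_len=1024)
cache_4k = kv_cache_elements(n_kv_heads=16, head_dim=64, seq_len=4096)
print(state)     # 131072
print(cache_1k)  # 2097152
print(cache_4k)  # 8388608
```

The KV cache grows fourfold from 1,024 to 4,096 tokens, while the recurrent state stays fixed; doubling n_v_heads or expand_v doubles the state instead.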

allow_neg_eigval: bool = True

Allow negative eigenvalues in the recurrent dynamics.

conv_size: int = 4

The kernel size of the short convolution.

registered_base

alias of SequenceMixerConfig

conv_bias: bool = False

Whether to use bias in the short convolution.

norm_eps: float = 1e-05

The epsilon value for the normalization layer.

dtype: DType = 'float32'

The default data type to use for parameters.

num_params(d_model)[source]

The number of params that the GatedDeltaNet will have once built.

Parameters:

d_model (int) – The model dimensionality.

Return type:

int

build(d_model, *, layer_idx, n_layers, init_device='cpu', cache=None)[source]

Build the GatedDeltaNet module.

Parameters:
  • d_model (int) – The model dimensionality.

  • layer_idx (int) – The layer index (unused).

  • n_layers (int) – The total number of layers (unused).

  • init_device (str, default: 'cpu') – The device to initialize the parameters on, e.g. “cpu”, “meta”.

  • cache (Optional[BufferCache], default: None) – Optional buffer cache (unused).

Return type:

GatedDeltaNet

class olmo_core.nn.attention.GatedDeltaNet(*, d_model, n_heads, n_v_heads=None, head_dim=None, expand_v=2.0, allow_neg_eigval=True, conv_size=4, conv_bias=False, norm_eps=1e-05, dtype=torch.float32, init_device='cpu')[source]

Bases: SequenceMixer

The layer implementation for Gated Delta Networks.

Modified from: https://github.com/fla-org/flash-linear-attention/blob/3cf180339b8a1cbad823f553541cd531d18670ea/fla/layers/gated_deltanet.py#L34

This is a linear attention variant that uses a gated delta rule for recurrent state updates, providing efficient O(n) sequence modeling.

Parameters:
  • d_model (int) – The model hidden size.

  • n_heads (int) – The number of attention heads.

  • n_v_heads (Optional[int], default: None) – The number of value heads. If None, defaults to n_heads. GVA is applied if n_v_heads > n_heads.

  • head_dim (Optional[int], default: None) – The dimension of each head. If None, defaults to d_model // n_heads.

  • expand_v (float, default: 2.0) – The expansion ratio for the value dimension.

  • allow_neg_eigval (bool, default: True) – Allow negative eigenvalues. If True, beta is multiplied by 2. See: Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues.

  • conv_size (int, default: 4) – The kernel size of the short convolution.

  • conv_bias (bool, default: False) – Whether to use bias in the short convolution.

  • norm_eps (float, default: 1e-05) – The epsilon value for the normalization layer.

  • dtype (dtype, default: torch.float32) – The default data type to use for parameters.

  • init_device (str, default: 'cpu') – The device to initialize weights on.
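To see why multiplying beta by 2 permits negative eigenvalues, consider the following sketch. It assumes that beta comes from a sigmoid (so it lies in (0, 1) before scaling) and that the delta-rule state transition has the form I - beta * k k^T with a unit-norm key k, whose eigenvalue along k is 1 - beta. These are assumptions drawn from the general delta-rule formulation, not from this module's code, and the function names are hypothetical.

```python
import math
from typing import Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def beta_and_eigenvalue(logit: float, allow_neg_eigval: bool = True) -> Tuple[float, float]:
    """Return (beta, eigenvalue along k) under the stated assumptions."""
    beta = sigmoid(logit)   # in (0, 1)
    if allow_neg_eigval:
        beta *= 2.0         # now in (0, 2)
    return beta, 1.0 - beta


beta_pos, eig_pos = beta_and_eigenvalue(3.0, allow_neg_eigval=False)
beta_neg, eig_neg = beta_and_eigenvalue(3.0, allow_neg_eigval=True)
print(eig_pos > 0.0)  # True
print(eig_neg < 0.0)  # True
```

Without the scaling the eigenvalue stays in (0, 1); with it the range widens to (-1, 1).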

forward(x, cu_doc_lens=None, **kwargs)[source]

Apply gated delta network sequence mixing to the input.

Parameters:
  • x (Tensor) – The input of shape (batch_size, seq_len, d_model).

  • cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0).

Return type:

Tensor

Returns:

The output with shape (batch_size, seq_len, d_model).

num_flops_per_token(seq_len)[source]

Compute FLOPs per token for Gated Delta Net.

This accounts for:
  • Linear projections (w_q, w_k, w_v, w_a, w_b, w_g, w_out)

  • Short convolutions (q, k, v)

  • Gated delta rule recurrent computation

  • Gated RMS normalization

Return type:

int