nn.attention
- class olmo_core.nn.attention.SlidingWindowAttentionConfig(pattern, force_full_attention_on_first_layer=True, force_full_attention_on_last_layer=True)
Bases: Config
- pattern: List[int]
The pattern of window sizes to use for attention, repeated to cover all layers. A value of -1 indicates full attention. For example, a pattern of [4096, 4096, 4096, -1] means that for each set of 4 layers, the first 3 will use a window size of 4096, and the last layer will use full attention.
- force_full_attention_on_first_layer: bool = True
If True, the first transformer layer will always use full attention, regardless of the pattern.
- force_full_attention_on_last_layer: bool = True
If True, the last transformer layer will always use full attention, regardless of the pattern.
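To make the tiling concrete, here is a minimal sketch of how a pattern could expand into one window size per layer. It assumes layer i takes pattern[i % len(pattern)] and that the first/last-layer overrides are applied afterwards; layer_window_sizes is a hypothetical helper, not an olmo_core function.

```python
from typing import List


def layer_window_sizes(
    pattern: List[int],
    n_layers: int,
    force_full_attention_on_first_layer: bool = True,
    force_full_attention_on_last_layer: bool = True,
) -> List[int]:
    """Hypothetical: expand a repeating window-size pattern into one entry per layer.

    A value of -1 means full (non-windowed) attention.
    """
    if not pattern:
        raise ValueError("pattern must not be empty")
    # Assumption: tile the pattern so that layer i uses pattern[i % len(pattern)].
    sizes = [pattern[i % len(pattern)] for i in range(n_layers)]
    # Optionally override the first and last layers with full attention.
    if force_full_attention_on_first_layer and n_layers > 0:
        sizes[0] = -1
    if force_full_attention_on_last_layer and n_layers > 0:
        sizes[-1] = -1
    return sizes


print(layer_window_sizes([4096, 4096, 4096, -1], n_layers=8))
# [-1, 4096, 4096, -1, 4096, 4096, 4096, -1]
```

With the default overrides, layer 0 switches to full attention even though the pattern would give it a 4096-token window.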
- class olmo_core.nn.attention.GateGranularity(value)
Bases: StrEnum
An enumeration of the different gating granularities.
- headwise = 'headwise'
Head-wise gating: one gate value per attention head, broadcast across the head dimension.
- elementwise = 'elementwise'
Element-wise gating: one gate value per output element.
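The difference between the two granularities is mainly the shape of the gate. The sketch below is hypothetical: it assumes a sigmoid gate applied multiplicatively to one token's per-head attention output, and it takes the gate logits as given rather than computing them from the input.

```python
import math
from typing import List

Matrix = List[List[float]]  # one token's attention output, shape (n_heads, head_dim)


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (assumed gate activation).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gate_headwise(attn_out: Matrix, gate_logits: List[float]) -> Matrix:
    # One logit per head; its sigmoid scales every element of that head.
    return [[v * sigmoid(gate_logits[h]) for v in head] for h, head in enumerate(attn_out)]


def gate_elementwise(attn_out: Matrix, gate_logits: Matrix) -> Matrix:
    # One logit per output element, same shape as attn_out.
    return [
        [v * sigmoid(g) for v, g in zip(head, gates)]
        for head, gates in zip(attn_out, gate_logits)
    ]


out = [[1.0, 2.0], [3.0, 4.0]]  # 2 heads, head_dim = 2
print(gate_headwise(out, [0.0, 0.0]))  # [[0.5, 1.0], [1.5, 2.0]]
```

Head-wise gating needs n_heads gate values per token, while element-wise gating needs n_heads * head_dim.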
- class olmo_core.nn.attention.GateConfig(granularity='headwise', full_precision=True)
Bases: Config
- granularity: GateGranularity = 'headwise'
The granularity of gating to use.
- class olmo_core.nn.attention.AttentionType(value)
Bases: StrEnum
An enumeration of the different attention implementations.
- fused = 'fused'
- normalized = 'normalized'
- class olmo_core.nn.attention.AttentionBackendName(value)
Bases: StrEnum
An enumeration of the different attention backends.
- torch = 'torch'
PyTorch’s built-in SDPA. Works on all devices. ➡️ TorchAttentionBackend
- flash_2 = 'flash_2'
Flash Attention 2 from the flash-attn library. Supports Ampere (SM 8.0+) and newer NVIDIA GPUs. To use this with context parallelism, ring-flash-attn is also required. ➡️ FlashAttention2Backend
- flash_3 = 'flash_3'
Flash Attention 3 (beta) from the hopper/ subdirectory of the flash-attn library. Supports Hopper (SM 9.0) GPUs only (H100/H800). ➡️ FlashAttention3Backend
- flash_4 = 'flash_4'
Flash Attention 4, the CuTe implementation from the flash_attn/cute subdirectory of flash-attn. Supports Blackwell (SM 10.0, e.g. B200) GPUs only. ➡️ FlashAttention4Backend
- te = 'te'
Transformer Engine attention. Supports Hopper (SM 9.0+) and newer NVIDIA GPUs. ➡️ TEAttentionBackend
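The hardware requirements listed above can be summarized as a lookup on CUDA compute capability. This hypothetical helper encodes only those documented requirements. It does not check whether flash-attn or Transformer Engine is installed, and it is not how the backend is chosen automatically. The assert_supported() classmethods remain the authoritative check.

```python
from typing import List, Tuple


def supported_backends(capability: Tuple[int, int]) -> List[str]:
    """Hypothetical: backends whose documented GPU requirement is met by (major, minor)."""
    backends = ["torch"]  # PyTorch SDPA: works on all devices
    if capability >= (8, 0):
        backends.append("flash_2")  # Ampere (SM 8.0) and newer
    if capability == (9, 0):
        backends.append("flash_3")  # Hopper (SM 9.0) only
    if capability == (10, 0):
        backends.append("flash_4")  # Blackwell (SM 10.0) only
    if capability >= (9, 0):
        backends.append("te")  # Hopper (SM 9.0) and newer
    return backends


print(supported_backends((9, 0)))  # ['torch', 'flash_2', 'flash_3', 'te']
```

On a CUDA machine, torch.cuda.get_device_capability() returns the (major, minor) tuple this helper expects.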
- class olmo_core.nn.attention.AttentionBackend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)
Bases: Module
Encapsulates a backend for the scaled dot-product attention (SDPA) operation.
- abstract classmethod assert_supported()
Validates that this backend is supported on the current system. Raises an error if not supported.
- abstract classmethod assert_supports_swa()
Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.
- abstract classmethod assert_supports_ring_cp()
Validates that this backend supports ring context parallelism. Raises an error if not supported.
- abstract classmethod assert_supports_ulysses_cp()
Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.
- abstract classmethod assert_supports_packed_qkv()
Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.
- abstract classmethod assert_supports_kv_cache()
Validates that this backend supports KV caching. Raises an error if not supported.
- class olmo_core.nn.attention.TorchAttentionBackend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)
Bases: AttentionBackend
PyTorch’s built-in scaled dot-product attention (SDPA) backend.
- classmethod assert_supported()
Validates that this backend is supported on the current system. Raises an error if not supported.
- classmethod assert_supports_swa()
Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.
- classmethod assert_supports_ring_cp()
Validates that this backend supports ring context parallelism. Raises an error if not supported.
- classmethod assert_supports_ulysses_cp()
Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.
- classmethod assert_supports_packed_qkv()
Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.
- class olmo_core.nn.attention.FlashAttention2Backend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)
Bases: AttentionBackend
SDPA from the flash-attn package. Additionally, ring-flash-attn is required for context parallelism.
- classmethod assert_supported()
Validates that this backend is supported on the current system. Raises an error if not supported.
- classmethod assert_supports_swa()
Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.
- classmethod assert_supports_ring_cp()
Validates that this backend supports ring context parallelism. Raises an error if not supported.
- classmethod assert_supports_ulysses_cp()
Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.
- classmethod assert_supports_packed_qkv()
Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.
- class olmo_core.nn.attention.FlashAttention3Backend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)
Bases: AttentionBackend
SDPA from the flash-attn 3 package. Does not currently support context parallelism.
- classmethod assert_supported()
Validates that this backend is supported on the current system. Raises an error if not supported.
- classmethod assert_supports_swa()
Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.
- classmethod assert_supports_ring_cp()
Validates that this backend supports ring context parallelism. Raises an error if not supported.
- classmethod assert_supports_ulysses_cp()
Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.
- classmethod assert_supports_packed_qkv()
Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.
- class olmo_core.nn.attention.FlashAttention4Backend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)
Bases: AttentionBackend
SDPA from flash-attn 4 (CuTe implementation).
- classmethod assert_supported()
Validates that this backend is supported on the current system. Raises an error if not supported.
- classmethod assert_supports_swa()
Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.
- classmethod assert_supports_ring_cp()
Validates that this backend supports ring context parallelism. Raises an error if not supported.
- classmethod assert_supports_ulysses_cp()
Validates that this backend supports Ulysses context parallelism. Raises an error if not supported.
- classmethod assert_supports_packed_qkv()
Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.
- class olmo_core.nn.attention.TEAttentionBackend(*, head_dim, n_heads, n_kv_heads=None, scale=None, dropout_p=0.0, window_size=(-1, -1), cache=None)
Bases: AttentionBackend
Transformer Engine attention backend.
- classmethod assert_supported()
Validates that this backend is supported on the current system. Raises an error if not supported.
- classmethod assert_supports_swa()
Validates that this backend supports sliding window attention (SWA). Raises an error if not supported.
- classmethod assert_supports_packed_qkv()
Validates that this backend supports taking QKV as a single packed tensor. Raises an error if not supported.
- classmethod assert_supports_kv_cache()
Validates that this backend supports KV caching. Raises an error if not supported.
- class olmo_core.nn.attention.AttentionConfig(name='default', n_heads=16, n_kv_heads=None, head_dim=None, bias=None, gate=None, rope=None, clip_qkv=None, qk_norm=None, dropout=None, use_flash=None, backend=None, dtype='float32', sliding_window=None, use_head_qk_norm=None, *, type=None)
Bases: SequenceMixerConfig[SequenceMixer]
A configuration class for easily building any of the different attention modules. See the individual Attention subclasses for a description of the configuration options.
- name: AttentionType = 'default'
The name of the implementation.
- num_params(d_model)
The number of parameters the attention implementation will have once built.
- build(d_model, *, layer_idx, n_layers, init_device='cpu', cache=None)
Build the corresponding attention module.
- registered_base
Alias of SequenceMixerConfig.
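For intuition about what num_params() counts, here is a rough, hypothetical parameter count for a plain attention block. It assumes separate Q, K, V and output linear layers, head_dim defaulting to d_model // n_heads, and one bias vector per projection when bias is enabled. It ignores qk_norm, gating and any other extras, so it will not match num_params() in general.

```python
from typing import Optional


def approx_attention_params(
    d_model: int,
    n_heads: int,
    n_kv_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
    bias: bool = True,
) -> int:
    """Hypothetical rough count for Q/K/V/output projections only."""
    n_kv_heads = n_kv_heads or n_heads  # None means one KV head per query head
    head_dim = head_dim or d_model // n_heads  # assumed default
    q_out = n_heads * head_dim
    kv_out = n_kv_heads * head_dim
    # W_q, W_k, W_v map d_model -> heads * head_dim; W_out maps q_out back to d_model.
    weights = d_model * q_out + 2 * d_model * kv_out + q_out * d_model
    # One bias vector per projection, sized to that projection's output.
    biases = (q_out + 2 * kv_out + d_model) if bias else 0
    return weights + biases


print(approx_attention_params(64, n_heads=4, n_kv_heads=2, bias=False))  # 12288
```

Halving n_kv_heads shrinks only the K and V projections, which is where GQA saves parameters.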
- class olmo_core.nn.attention.Attention(*, d_model, n_heads, n_kv_heads=None, head_dim=None, bias=True, gate=None, rope=None, clip_qkv=None, qk_norm=None, dropout=0.0, softmax_scale=None, use_flash=None, backend=None, window_size=None, dtype=torch.float32, init_device='cpu', cache=None, use_head_qk_norm=False)
Bases: SequenceMixer
An implementation of multi-head self-attention with support for multi-query (MQA) and grouped-query (GQA) attention.
Intra-document masking is also supported by passing the max_doc_len and cu_doc_lens parameters to forward(). This requires a backend that supports it, such as the flash backend.
See also FusedAttention if you have flash-attn installed and you’re not using MQA or GQA.
- Parameters:
d_model (int) – The model hidden size.
n_heads (int) – The number of attention heads.
n_kv_heads (Optional[int], default: None) – The number of key and value heads, if different from n_heads.
bias (bool, default: True) – Include biases with linear layers.
gate (Optional[GateConfig], default: None) – Configuration for attention gating. If None, no gating is applied.
rope (Optional[RoPEConfig], default: None) – The config for RoPE, if RoPE should be used.
clip_qkv (Optional[float], default: None) – Clip QKV to this value, if set.
qk_norm (Optional[LayerNormConfig], default: None) – Configuration for a layer norm for queries and keys.
dropout (float, default: 0.0) – Dropout probability.
use_flash (Optional[bool], default: None) – Deprecated; use backend="flash_2" instead.
backend (Optional[AttentionBackendName], default: None) – The attention backend to use. If not set, it will be chosen automatically.
dtype (dtype, default: torch.float32) – The default data type to use for parameters.
init_device (str, default: 'cpu') – The device to initialize weights on.
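With GQA, each key/value head is shared by a group of n_heads // n_kv_heads query heads; MQA is the special case n_kv_heads = 1. The sketch below assumes consecutive query heads share a KV head (the grouping flash-attn documents, and what torch.repeat_interleave on the KV heads produces) and that n_heads is divisible by n_kv_heads. kv_head_for_query_heads is a hypothetical helper.

```python
from typing import List, Optional


def kv_head_for_query_heads(n_heads: int, n_kv_heads: Optional[int] = None) -> List[int]:
    """Hypothetical: for each query head, the index of the KV head it attends with."""
    n_kv_heads = n_kv_heads or n_heads  # None means standard multi-head attention
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    group_size = n_heads // n_kv_heads
    # Consecutive query heads share one KV head.
    return [h // group_size for h in range(n_heads)]


print(kv_head_for_query_heads(8))     # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print(kv_head_for_query_heads(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_head_for_query_heads(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```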
- forward(x, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, pos_sin=None, pos_cos=None, freqs_cis=None, cache_leftpad=None)
Apply attention to the input.
- Parameters:
x (Tensor) – The input of shape (batch_size, seq_len, d_model).
cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0). Required together with max_doc_len when using intra-document masking.
max_doc_len (Optional[int], default: None) – The maximum document length in the input x. Required together with cu_doc_lens when using intra-document masking.
- Return type: Tensor
- Returns: The output of attention with shape (batch_size, seq_len, d_model).
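The cu_doc_lens and max_doc_len arguments can be derived from the lengths of the documents packed into a sequence. Here is a minimal sketch for a single packed sequence, using a hypothetical helper over plain Python lists:

```python
from itertools import accumulate
from typing import List, Tuple


def doc_lens_to_cu_doc_lens(doc_lens: List[int]) -> Tuple[List[int], int]:
    """Hypothetical: cumulative document offsets (starting at 0) and the max doc length."""
    cu_doc_lens = [0] + list(accumulate(doc_lens))  # len(doc_lens) + 1 entries
    max_doc_len = max(doc_lens)
    return cu_doc_lens, max_doc_len


# Three documents packed into one sequence of length 10.
cu, max_len = doc_lens_to_cu_doc_lens([3, 5, 2])
print(cu, max_len)  # [0, 3, 8, 10] 5
```

Before calling forward(), the offsets would be converted to a 1D torch.int32 tensor, for example torch.tensor(cu, dtype=torch.int32, device=x.device).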
- apply_cp(cp_mesh, ring=None, uly=None)
Prepare the module for context parallelism. The ring and uly arguments configure ring-attention and Ulysses-style context parallelism, respectively.
Important: This requires a backend that supports CP, such as “flash_2” or “te”.
- Parameters:
cp_mesh (DeviceMesh) – The context parallel device sub-mesh.
ring (Optional[RingContextParallelStyle], default: None) – The ring context parallel style.
uly (Optional[UlyssesContextParallelStyle], default: None) – The Ulysses context parallel style.
- class olmo_core.nn.attention.FusedAttention(*, d_model, n_heads, bias=True, rope=None, clip_qkv=None, dropout=0.0, dtype=torch.float32, backend=None, init_device='cpu', cache=None)
Bases: SequenceMixer
A “fused” implementation of multi-head self-attention.
Intra-document masking is supported by passing the max_doc_len and cu_doc_lens parameters to forward().
Warning: Currently this is only supported with the “flash_2” backend.
Warning: If using RoPE, this requires that you use the “fused” RoPE implementation (FusedRotaryEmbedding).
- Parameters:
d_model (int) – The model hidden size.
n_heads (int) – The number of attention heads.
bias (bool, default: True) – Include biases with linear layers.
rope (Optional[RoPEConfig], default: None) – The config for RoPE, if RoPE should be used.
clip_qkv (Optional[float], default: None) – Clip QKV to this value, if set.
dropout (float, default: 0.0) – Dropout probability.
dtype (dtype, default: torch.float32) – The default data type to use for parameters.
init_device (str, default: 'cpu') – The device to initialize weights on.
- forward(x, max_doc_len=None, cu_doc_lens=None, pos_sin=None, pos_cos=None, freqs_cis=None, cache_leftpad=None)
Apply attention to the input.
- Parameters:
x (Tensor) – The input of shape (batch_size, seq_len, d_model).
max_doc_len (Optional[int], default: None) – The maximum document length in the input x. Required together with cu_doc_lens when using intra-document masking.
cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0). Required together with max_doc_len when using intra-document masking.
- Return type: Tensor
- Returns: The output of attention with shape (batch_size, seq_len, d_model).
- class olmo_core.nn.attention.NormalizedAttention(*, d_model, n_heads, n_kv_heads=None, rope=None, qk_norm=None, use_flash=None, backend=None, dtype=torch.float32, init_device='cpu', cache=None)
Bases: Attention
An nGPT attention implementation.
Warning: This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature, please read the CHANGELOG before upgrading your version of this library.
- forward(x, cu_doc_lens=None, cu_doc_lens_q=None, cu_doc_lens_k=None, max_doc_len=None, max_doc_len_q=None, max_doc_len_k=None, local_k_slice=None, pos_sin=None, pos_cos=None, freqs_cis=None, cache_leftpad=None)
Apply attention to the input.
- Parameters:
x (Tensor) – The input of shape (batch_size, seq_len, d_model).
cu_doc_lens (Optional[Tensor], default: None) – Cumulative document lengths in the input x, a 1D torch.int32 tensor that should always have one more element than there are documents (the first element in the tensor should always be 0). Required together with max_doc_len when using intra-document masking.
max_doc_len (Optional[int], default: None) – The maximum document length in the input x. Required together with cu_doc_lens when using intra-document masking.
- Return type: Tensor
- Returns: The output of attention with shape (batch_size, seq_len, d_model).
- normalize_matrices()
Normalize the weights in all matrices. This should be called after each optimizer step, which the TransformerTrainModule will handle for you.
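nGPT keeps weight vectors on the unit hypersphere, which is why the matrices must be renormalized after every optimizer step. The sketch assumes row-wise L2 normalization over plain lists; which axis each matrix in NormalizedAttention is normalized along is not stated here, and normalize_rows is hypothetical.

```python
import math
from typing import List


def normalize_rows(matrix: List[List[float]], eps: float = 1e-12) -> List[List[float]]:
    """Hypothetical stand-in for normalize_matrices: rescale each row to unit L2 norm."""
    normalized = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        # eps guards against division by zero for an all-zero row.
        normalized.append([v / max(norm, eps) for v in row])
    return normalized


w = normalize_rows([[3.0, 4.0], [0.0, 2.0]])
print(w)  # [[0.6, 0.8], [0.0, 1.0]]
```

In a training loop, a call like this would follow each optimizer step, since the update generally moves rows off the unit sphere.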
- class olmo_core.nn.attention.RingAttentionLoadBalancerType(value)
Bases: StrEnum
An enumeration of the different RingAttentionLoadBalancer implementations.
- zig_zag = 'zig_zag'
- llama3 = 'llama3'
- ulysses = 'ulysses'
- class olmo_core.nn.attention.RingAttentionLoadBalancer(*, cp_rank, cp_world_size)
Bases: object
A class that handles the logic of sharding inputs on the sequence dimension for ring attention (context parallelism).
- abstract batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)
Shard inputs on their sequence dimension, optionally adding padding if needed.
Important: If using intra-document masking, use batch_shard_by_document() instead.
- abstract batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)
Same as batch_shard() but for strategies that support intra-document masking.
- class olmo_core.nn.attention.RingAttentionZigZagLoadBalancer(*, cp_rank, cp_world_size)
Bases: RingAttentionLoadBalancer
Implements the zig-zag load-balancing strategy.
- batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)
Shard inputs on their sequence dimension, optionally adding padding if needed.
Important: If using intra-document masking, use batch_shard_by_document() instead.
- batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)
Same as batch_shard() but for strategies that support intra-document masking.
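This section does not spell out the chunk assignment. The sketch below follows the zig-zag scheme used by ring-flash-attn: the sequence is split into 2 * cp_world_size chunks and rank r keeps chunks r and 2 * cp_world_size - 1 - r, so each rank gets a mix of early (cheap) and late (expensive) positions under causal masking. That olmo_core uses exactly this assignment is an assumption, and zig_zag_positions is a hypothetical helper that raises instead of padding.

```python
from typing import List


def zig_zag_positions(seq_len: int, cp_rank: int, cp_world_size: int) -> List[int]:
    """Hypothetical: token positions held by one rank under zig-zag sharding."""
    n_chunks = 2 * cp_world_size
    if seq_len % n_chunks != 0:
        raise ValueError("seq_len must be divisible by 2 * cp_world_size (pad first)")
    chunk = seq_len // n_chunks
    # Chunk cp_rank from the front, paired with chunk (2W - 1 - cp_rank) from the back.
    first = range(cp_rank * chunk, (cp_rank + 1) * chunk)
    second_idx = n_chunks - 1 - cp_rank
    second = range(second_idx * chunk, (second_idx + 1) * chunk)
    return list(first) + list(second)


for rank in range(2):
    print(rank, zig_zag_positions(8, rank, cp_world_size=2))
# 0 [0, 1, 6, 7]
# 1 [2, 3, 4, 5]
```

With a contiguous split, rank 0 would hold only early positions and do far less causal-attention work than the last rank; pairing chunks evens this out.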
- class olmo_core.nn.attention.RingAttentionLlama3LoadBalancer(*, cp_rank, cp_world_size)
Bases: RingAttentionLoadBalancer
Implements Llama3’s load-balancing strategy for context parallelism.
The Llama3 strategy assigns each rank a contiguous slice of the full sequence. Rank i receives positions [i * local_len, (i + 1) * local_len), where local_len = total_seq_len // cp_world_size.
This strategy is designed specifically for intra-document masking with variable-length documents packed into a single sequence. It computes separate cumulative sequence lengths for queries (cu_doc_lens_q) and keys (cu_doc_lens_k) per rank, enabling proper causal masking across document boundaries within the ring attention loop. When the total sequence length is not divisible by the context parallel world size, padding is added as a synthetic document at the end.
Note: This strategy only supports batch_shard_by_document() and will raise an error if batch_shard() is called directly.
- batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)
Shard inputs on their sequence dimension, optionally adding padding if needed.
Important: If using intra-document masking, use batch_shard_by_document() instead.
- batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)
Same as batch_shard() but for strategies that support intra-document masking.
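The slicing and padding rules above can be traced on a small example. This hypothetical sketch pads with a synthetic document, takes each rank's contiguous slice, and clips the document boundaries to that slice to get query-side offsets. It omits the key-side lengths (cu_doc_lens_k) and is not necessarily the exact format olmo_core produces.

```python
from typing import List, Tuple


def llama3_shard(
    cu_doc_lens: List[int], cp_rank: int, cp_world_size: int
) -> Tuple[Tuple[int, int], List[int]]:
    """Hypothetical: this rank's [start, end) slice and slice-relative document boundaries."""
    total = cu_doc_lens[-1]
    # If needed, append a synthetic padding document so the length divides evenly.
    pad = (-total) % cp_world_size
    if pad:
        cu_doc_lens = cu_doc_lens + [total + pad]
        total += pad
    local_len = total // cp_world_size
    start, end = cp_rank * local_len, (cp_rank + 1) * local_len
    # Clip each global boundary into [start, end], shift so the slice starts at 0, dedupe.
    local_bounds = sorted({min(max(b, start), end) - start for b in cu_doc_lens})
    return (start, end), local_bounds


# Documents of lengths 3, 5 and 2 (total 10) split across 4 ranks.
for rank in range(4):
    print(rank, llama3_shard([0, 3, 8, 10], rank, cp_world_size=4))
# 0 ((0, 3), [0, 3])
# 1 ((3, 6), [0, 3])
# 2 ((6, 9), [0, 2, 3])
# 3 ((9, 12), [0, 1, 3])
```

Rank 3 ends with two padding positions (10 and 11), which form their own synthetic document so they never attend to real tokens.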
- class olmo_core.nn.attention.UlyssesLoadBalancer(*, cp_rank, cp_world_size)
Bases: RingAttentionLoadBalancer
Implements simple contiguous sequence sharding for Ulysses-style context parallelism.
Unlike ring attention, which uses zig-zag or other interleaving strategies, Ulysses only needs simple contiguous chunking, since the all-to-all communication handles the sequence/head exchange.
- batch_shard(*, inputs, seq_dims, pad_values=None, length_multiple=None)
Shard inputs on their sequence dimension, optionally adding padding if needed.
Important: If using intra-document masking, use batch_shard_by_document() instead.
- batch_shard_by_document(*, inputs, seq_dims, cu_doc_lens, pad_values=None, length_multiple=None)
Same as batch_shard() but for strategies that support intra-document masking.
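The all-to-all can be pictured as a change of layout. Before it, each rank holds a contiguous block of positions for all heads; after it, each rank holds every position for a subset of heads. The sketch below simulates only that data movement with dictionaries (real implementations use a collective such as torch.distributed.all_to_all) and assumes n_heads is divisible by the world size. Both helpers are hypothetical.

```python
from typing import Dict, List, Tuple

# One rank's data: maps (sequence position, head index) -> a label for that slice.
Shard = Dict[Tuple[int, int], str]


def contiguous_shards(seq_len: int, n_heads: int, world_size: int) -> List[Shard]:
    """Sequence-sharded layout: rank r holds a contiguous block of positions, all heads."""
    local_len = seq_len // world_size
    return [
        {
            (p, h): f"p{p}h{h}"
            for p in range(r * local_len, (r + 1) * local_len)
            for h in range(n_heads)
        }
        for r in range(world_size)
    ]


def all_to_all_seq_to_heads(shards: List[Shard], n_heads: int) -> List[Shard]:
    """Simulated all-to-all: afterwards rank r holds every position for its own heads."""
    world_size = len(shards)
    if n_heads % world_size != 0:
        raise ValueError("n_heads must be divisible by the context parallel world size")
    heads_per_rank = n_heads // world_size
    out: List[Shard] = [{} for _ in range(world_size)]
    for shard in shards:
        for (pos, head), value in shard.items():
            # Route each (position, head) slice to the rank that owns this head.
            out[head // heads_per_rank][(pos, head)] = value
    return out


before = contiguous_shards(seq_len=4, n_heads=2, world_size=2)
after = all_to_all_seq_to_heads(before, n_heads=2)
print(sorted(before[0]))  # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(sorted(after[0]))   # [(0, 0), (1, 0), (2, 0), (3, 0)]
```

Because each rank then sees the full sequence for its heads, ordinary (non-ring) attention can run locally, and a second all-to-all restores the sequence-sharded layout.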
- class olmo_core.nn.attention.RingContextParallelStyle(load_balancer='zig_zag', head_stride=1)
Bases: Config
Configuration for ring attention-style context parallelism.
- load_balancer: RingAttentionLoadBalancerType = 'zig_zag'
The type of load balancer to use for ring attention.
- class olmo_core.nn.attention.UlyssesContextParallelStyle
Bases: Config
Configuration for Ulysses-style context parallelism.
- class olmo_core.nn.attention.GatedDeltaNetConfig(n_heads=16, n_v_heads=None, head_dim=None, expand_v=2.0, allow_neg_eigval=True, conv_size=4, conv_bias=False, norm_eps=1e-05, dtype='float32', *, type=None)
Bases: SequenceMixerConfig[GatedDeltaNet]
Configuration for GatedDeltaNet. See GatedDeltaNet for a description of the configuration options.
- n_v_heads: Optional[int] = None
The number of value heads. If None, defaults to n_heads. If n_v_heads > n_heads, GVA (Grouped Value Attention) is applied.
GVA is preferred over GQA for linear RNNs like GDN because the recurrent state has shape (n_v_heads, head_k_dim, head_v_dim). In softmax attention, the KV cache grows with sequence length, which is what motivates GQA to shrink it. The linear RNN state, by contrast, has a constant size regardless of sequence length. Since there is no memory-scaling problem to solve, we can instead increase the state size to improve the model’s capacity to compress long-range context. Increasing n_v_heads directly increases this fixed state size.
- head_dim: Optional[int] = None
The dimension of each head. If None, defaults to d_model // n_heads.
- expand_v: float = 2.0
The expansion ratio for the value dimension (head_v_dim = head_dim * expand_v). Like n_v_heads, this increases the constant-size recurrent state, improving capacity without memory-scaling concerns.
- registered_base
Alias of SequenceMixerConfig.
- build(d_model, *, layer_idx, n_layers, init_device='cpu', cache=None)
Build the GatedDeltaNet module.
- Parameters:
d_model (int) – The model dimensionality.
layer_idx (int) – The layer index (unused).
n_layers (int) – The total number of layers (unused).
init_device (str, default: 'cpu') – The device to initialize the parameters on, e.g. “cpu”, “meta”.
cache (Optional[BufferCache], default: None) – Optional buffer cache (unused).
- Return type: GatedDeltaNet
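The constant-size-state argument can be checked with a little arithmetic. The sketch below assumes head_k_dim equals head_dim and counts elements for one sequence in one layer. The helpers are hypothetical, and the KV-cache formula is the usual 2 * seq_len * n_kv_heads * head_dim for softmax attention.

```python
from typing import Optional


def gdn_state_elements(
    d_model: int,
    n_heads: int,
    n_v_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
    expand_v: float = 2.0,
) -> int:
    """Hypothetical: elements in the recurrent state, n_v_heads * head_k_dim * head_v_dim."""
    n_v_heads = n_v_heads or n_heads
    head_k_dim = head_dim or d_model // n_heads  # assumes head_k_dim == head_dim
    head_v_dim = int(head_k_dim * expand_v)
    return n_v_heads * head_k_dim * head_v_dim  # independent of sequence length


def kv_cache_elements(seq_len: int, n_kv_heads: int, head_dim: int) -> int:
    """Elements in a softmax-attention KV cache (keys + values) for one sequence."""
    return 2 * seq_len * n_kv_heads * head_dim


print(gdn_state_elements(2048, n_heads=16))                # 524288
print(gdn_state_elements(2048, n_heads=16, n_v_heads=32))  # 1048576
for seq_len in (4096, 32768):
    print(seq_len, kv_cache_elements(seq_len, n_kv_heads=16, head_dim=128))
# 4096 16777216
# 32768 134217728
```

Doubling n_v_heads doubles the GDN state once and for all, whereas the KV cache keeps growing with every additional token.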
- class olmo_core.nn.attention.GatedDeltaNet(*, d_model, n_heads, n_v_heads=None, head_dim=None, expand_v=2.0, allow_neg_eigval=True, conv_size=4, conv_bias=False, norm_eps=1e-05, dtype=torch.float32, init_device='cpu')
Bases: SequenceMixer
The layer implementation for Gated Delta Networks.
Modified from: https://github.com/fla-org/flash-linear-attention/blob/3cf180339b8a1cbad823f553541cd531d18670ea/fla/layers/gated_deltanet.py#L34
This is a linear attention variant that uses a gated delta rule for recurrent state updates, giving sequence modeling whose cost scales linearly, O(n), with sequence length.
- Parameters:
d_model (int) – The model hidden size.
n_heads (int) – The number of attention heads.
n_v_heads (Optional[int], default: None) – The number of value heads. If None, defaults to n_heads. GVA is applied if n_v_heads > n_heads.
head_dim (Optional[int], default: None) – The dimension of each head. If None, defaults to d_model // n_heads.
expand_v (float, default: 2.0) – The expansion ratio for the value dimension.
allow_neg_eigval (bool, default: True) – Allow negative eigenvalues. If set to True, beta is multiplied by 2. See: Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues.
conv_size (int, default: 4) – The kernel size of the short convolution.
conv_bias (bool, default: False) – Whether to use a bias in the short convolution.
norm_eps (float, default: 1e-05) – The epsilon value for the normalization layer.
dtype (dtype, default: torch.float32) – The default data type to use for parameters.
init_device (str, default: 'cpu') – The device to initialize weights on.
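To illustrate the gated delta rule, here is a naive single-head recurrence over plain lists. It follows the update S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T with output o_t = S_t q_t, as formulated in the Gated DeltaNet paper. It is a hypothetical sketch: the real layer computes alpha, beta, q, k and v from the input (including the short convolution) and relies on optimized kernels rather than a Python loop. If beta comes from a sigmoid (an assumption here), the doubling described for allow_neg_eigval widens its range from (0, 1) to (0, 2).

```python
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]  # recurrent state S, shape (head_v_dim, head_k_dim)


def gated_delta_step(
    S: Mat, q: Vec, k: Vec, v: Vec, alpha: float, beta: float
) -> Tuple[Mat, Vec]:
    """Hypothetical naive reference: one gated delta rule step for a single head."""
    d_v, d_k = len(S), len(k)
    # S k: the value the current state associates with key k.
    Sk = [sum(S[i][j] * k[j] for j in range(d_k)) for i in range(d_v)]
    # alpha * S (I - beta k k^T) + beta v k^T, written out elementwise.
    new_S = [
        [alpha * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j] for j in range(d_k)]
        for i in range(d_v)
    ]
    # Read out with the query: o = S_t q.
    o = [sum(new_S[i][j] * q[j] for j in range(d_k)) for i in range(d_v)]
    return new_S, o


S = [[0.0, 0.0], [0.0, 0.0]]  # head_v_dim = 2, head_k_dim = 2
key = [1.0, 0.0]              # unit-norm key (keys are assumed L2-normalized)
S, o = gated_delta_step(S, q=key, k=key, v=[3.0, 4.0], alpha=1.0, beta=1.0)
print(o)  # [3.0, 4.0]
S, o = gated_delta_step(S, q=key, k=key, v=[5.0, 6.0], alpha=1.0, beta=1.0)
print(o)  # [5.0, 6.0]
```

With beta = 1, writing a new value under the same key replaces the old association instead of adding to it, which is the "delta" in the delta rule; alpha < 1 additionally decays the whole state at each step.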