import logging
import math
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Placement, Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module
from olmo_core.config import Config, DType, StrEnum
from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel
from olmo_core.doc_utils import beta_feature
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.nn.attention.base import SequenceMixer, SequenceMixerConfig
from olmo_core.nn.attention.kv_cache import KVCacheManager
from olmo_core.nn.attention.recurrent import GatedDeltaNet, GatedDeltaNetConfig
from ..buffer_cache import BufferCache
from ..config import ModuleConfig
from ..functional import l2_normalize
from ..layer_norm import LayerNorm, LayerNormConfig
from ..rope import (
ComplexRotaryEmbedding,
FusedRotaryEmbedding,
RoPEConfig,
RotaryEmbedding,
)
from ..utils import get_tp_wrappers
from . import flash_attn_api
from .backend import (
AttentionBackend,
AttentionBackendName,
FlashAttention2Backend,
FlashAttention3Backend,
FlashAttention4Backend,
TEAttentionBackend,
TorchAttentionBackend,
)
from .ring import (
RingAttentionLlama3LoadBalancer,
RingAttentionLoadBalancer,
RingAttentionLoadBalancerType,
RingAttentionZigZagLoadBalancer,
RingContextParallelStyle,
UlyssesContextParallelStyle,
UlyssesLoadBalancer,
)
if TYPE_CHECKING:
from olmo_core.nn.transformer.init import InitMethod
# Public API of this module. Besides the classes defined below, this re-exports the
# attention backends (from ``.backend``), the context-parallel load balancers / styles
# (from ``.ring``) and the gated delta-net classes (from ``.recurrent``).
__all__ = [
    "SlidingWindowAttentionConfig",
    "GateGranularity",
    "GateConfig",
    "AttentionType",
    "AttentionBackendName",
    "AttentionBackend",
    "TorchAttentionBackend",
    "FlashAttention2Backend",
    "FlashAttention3Backend",
    "FlashAttention4Backend",
    "TEAttentionBackend",
    "AttentionConfig",
    "Attention",
    "FusedAttention",
    "NormalizedAttention",
    "RingAttentionLoadBalancerType",
    "RingAttentionLoadBalancer",
    "RingAttentionZigZagLoadBalancer",
    "RingAttentionLlama3LoadBalancer",
    "UlyssesLoadBalancer",
    "RingContextParallelStyle",
    "UlyssesContextParallelStyle",
    "GatedDeltaNetConfig",
    "GatedDeltaNet",
]

# Module-level logger, used below to report the selected attention backend.
log = logging.getLogger(__name__)
[docs]
class GateGranularity(StrEnum):
    """
    The granularity at which the attention output gate is applied.

    The chosen value determines the output size of the gate projection built by
    :class:`Attention` (``n_heads`` for head-wise, ``n_heads * head_dim`` for element-wise).
    """

    headwise = "headwise"
    """Head-wise gating: one gate value per attention head, broadcast across head dimension."""

    elementwise = "elementwise"
    """Element-wise gating: one gate value per output element."""
[docs]
@dataclass
class GateConfig(Config):
    """
    Configuration for the optional output gate of :class:`Attention`.

    When set, the attention output is multiplied by ``sigmoid(w_g(x))`` before the output
    projection.
    """

    granularity: GateGranularity = GateGranularity.headwise
    """The granularity of gating to use."""

    full_precision: bool = True
    """Whether to always apply gating in full precision regardless of the input data type."""
[docs]
@dataclass
class SlidingWindowAttentionConfig(Config):
    """
    Describes which transformer layers use sliding window attention (SWA) and with which
    window size.
    """

    pattern: List[int]
    """
    The pattern of window sizes to use for attention, repeated to cover all layers.
    A value of -1 indicates full attention. For example, a pattern of ``[4096, 4096, 4096, -1]``
    means that for each set of 4 layers, the first 3 will use a window size of 4096,
    and the last layer will use full attention.
    """

    force_full_attention_on_first_layer: bool = True
    """
    If `True`, the first transformer layer will always use full attention, regardless of the pattern.
    """

    force_full_attention_on_last_layer: bool = True
    """
    If `True`, the last transformer layer will always use full attention, regardless of the pattern.
    """

    def _get_window_size(self, layer_idx: int, n_layers: int) -> int:
        """
        Get the window size for a given layer, returning -1 for full attention.

        :param layer_idx: The zero-based index of the layer.
        :param n_layers: The total number of layers.

        :raises OLMoConfigurationError: If the pattern is empty, or if the selected entry is
            neither positive nor ``-1``.
        """
        if self.force_full_attention_on_first_layer and layer_idx == 0:
            return -1
        if self.force_full_attention_on_last_layer and layer_idx == (n_layers - 1):
            return -1

        # An empty pattern would otherwise surface as an opaque ZeroDivisionError from the
        # modulo below; report it as a configuration problem instead.
        if not self.pattern:
            raise OLMoConfigurationError(
                "Sliding window 'pattern' must contain at least one window size"
            )

        # Adjust the layer index if the first layer is special-cased to full attention
        # (in which case the pattern is applied starting from the second layer)
        effective_layer_idx = layer_idx
        if self.force_full_attention_on_first_layer:
            effective_layer_idx -= 1

        window_size = self.pattern[effective_layer_idx % len(self.pattern)]
        if window_size <= 0 and window_size != -1:
            raise OLMoConfigurationError(
                f"Sliding window size must be positive or -1 (got {window_size})"
            )
        return window_size

    def should_use_swa(self, layer_idx: int, n_layers: int) -> bool:
        """
        Returns `True` if the given layer uses sliding window attention.

        :param layer_idx: The zero-based index of the layer.
        :param n_layers: The total number of layers.
        """
        return self._get_window_size(layer_idx, n_layers) != -1

    def get_window_size(self, layer_idx: int, n_layers: int) -> int:
        """
        Get the sliding window size for a given layer.

        :param layer_idx: The zero-based index of the layer.
        :param n_layers: The total number of layers.

        :raises ValueError: If the layer uses full attention, i.e. has no window size.
        """
        window_size = self._get_window_size(layer_idx, n_layers)
        if window_size == -1:
            raise ValueError(f"Layer {layer_idx} is not configured for sliding window attention.")
        return window_size
[docs]
class AttentionType(StrEnum):
    """
    An enumeration of the different attention implementations.

    :meth:`AttentionConfig.build` uses this value to pick the module class to instantiate.
    """

    default = "default"
    """
    ➡️ :class:`Attention`
    """

    fused = "fused"
    """
    ➡️ :class:`FusedAttention`
    """

    normalized = "normalized"
    """
    ➡️ :class:`NormalizedAttention`
    """
[docs]
@SequenceMixerConfig.register("attention")
@dataclass
class AttentionConfig(SequenceMixerConfig["SequenceMixer"]):
    """
    A configuration class for easily building any of the different attention modules.

    See the individual :class:`Attention` subclasses for a description of the configuration options.
    """

    name: AttentionType = AttentionType.default
    """
    The name of the implementation.
    """
    n_heads: int = 16
    n_kv_heads: Optional[int] = None
    head_dim: Optional[int] = None
    bias: Optional[bool] = None
    gate: Optional[GateConfig] = None
    rope: Optional[RoPEConfig] = None
    clip_qkv: Optional[float] = None
    qk_norm: Optional[LayerNormConfig] = None
    dropout: Optional[float] = None
    use_flash: Optional[bool] = None
    backend: Optional[AttentionBackendName] = None
    dtype: DType = DType.float32
    sliding_window: Optional[SlidingWindowAttentionConfig] = None
    use_head_qk_norm: Optional[bool] = None

    def num_params(self, d_model: int) -> int:
        """
        Count the parameters the built attention module will have.

        :param d_model: The model dimensionality.

        :returns: The total number of parameters across the Q/K/V/output projections, the
            optional QK norms, the optional gate projection and, for the normalized variant,
            the Q/K scaling vectors.
        """
        n_heads = self.n_heads
        n_kv_heads = self.n_kv_heads or n_heads
        head_dim = self.head_dim or d_model // n_heads

        # When unset, biases are on for every variant except the normalized one.
        if self.bias is not None:
            with_bias = self.bias
        else:
            with_bias = self.name != AttentionType.normalized

        q_dim = n_heads * head_dim
        kv_dim = n_kv_heads * head_dim

        def linear(in_features: int, out_features: int) -> int:
            # Weight matrix plus (optionally) one bias entry per output feature.
            return in_features * out_features + (out_features if with_bias else 0)

        # Q, K and V projections.
        total = linear(d_model, q_dim) + 2 * linear(d_model, kv_dim)

        # QK norm, either per head or over the full projection width.
        if self.qk_norm is not None:
            if self.use_head_qk_norm:
                total += 2 * self.qk_norm.num_params(head_dim)
            else:
                total += self.qk_norm.num_params(q_dim)  # q_norm
                total += self.qk_norm.num_params(kv_dim)  # k_norm

        # Output projection.
        total += linear(q_dim, d_model)

        # Gate projection.
        if self.gate is not None:
            if self.gate.granularity == GateGranularity.headwise:
                total += linear(d_model, n_heads)
            elif self.gate.granularity == GateGranularity.elementwise:
                total += linear(d_model, q_dim)

        # QK scaling factors of the normalized variant.
        if self.name == AttentionType.normalized:
            total += q_dim + kv_dim

        return total

    def build(
        self,
        d_model: int,
        *,
        layer_idx: int,
        n_layers: int,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ) -> "SequenceMixer":
        """
        Build the corresponding attention module.

        :param d_model: The model dimensionality.
        :param layer_idx: The index of the layer the module is built for; used to resolve the
            sliding window pattern.
        :param n_layers: The total number of layers; used to resolve the sliding window pattern.
        :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
        :param cache: An optional buffer cache handed to the module constructor.

        :raises OLMoConfigurationError: If the options are not valid for the selected variant.
        """
        kwargs = self.as_dict(exclude_none=True, recurse=False)
        kwargs.pop("name")

        swa_config: Optional[SlidingWindowAttentionConfig] = kwargs.pop("sliding_window", None)
        is_swa_layer = swa_config is not None and swa_config.should_use_swa(layer_idx, n_layers)
        if is_swa_layer:
            assert swa_config is not None
            kwargs["window_size"] = swa_config.get_window_size(layer_idx, n_layers)
        else:
            # Global (non-SWA) layer: RoPE can be switched off for these via the RoPE config.
            global_rope: Optional[RoPEConfig] = kwargs.get("rope")
            if global_rope is not None and global_rope.no_global_rope:
                kwargs["rope"] = None

        kwargs["dtype"] = kwargs.pop("dtype").as_pt()
        kwargs["d_model"] = d_model
        kwargs["init_device"] = init_device
        kwargs["cache"] = cache

        try:
            if self.name == "default":
                return Attention(**kwargs)

            if self.name == "fused":
                kwargs.pop("use_flash", None)
                if "window_size" in kwargs:
                    raise OLMoConfigurationError(
                        "'window_size' is not supported with fused attention"
                    )
                return FusedAttention(**kwargs)

            if self.name == "normalized":
                if "window_size" in kwargs:
                    raise OLMoConfigurationError(
                        "'window_size' is not supported with normalized attention"
                    )
                return NormalizedAttention(**kwargs)

            raise NotImplementedError(self.name)
        except TypeError as e:
            # An unexpected keyword argument means the config carries an option that the
            # selected variant does not accept.
            raise OLMoConfigurationError(
                f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
            ) from e
[docs]
class Attention(SequenceMixer):
    """
    An implementation of multi-head self-attention with support for multi-query (MQA)
    and grouped-query (GQA) attention.

    Intra-document masking is also supported by passing in the
    ``max_doc_len`` and ``cu_doc_lens`` parameters to :meth:`forward()`. This requires
    a backend that supports it, like the flash backend.

    .. seealso::
        :class:`FusedAttention` if you have flash-attn installed and you're not using MQA or GQA.

    :param d_model: The model hidden size.
    :param n_heads: The number of attention heads.
    :param n_kv_heads: The number of key and value heads, if different.
    :param head_dim: The size of each head. Defaults to ``d_model // n_heads``.
    :param bias: Include biases with linear layers.
    :param gate: Configuration for attention gating. If None, no gating is applied.
    :param rope: The config for RoPE, if RoPE should be used.
    :param clip_qkv: Clip QKV to this value, if set.
    :param qk_norm: Configuration for a layer norm for queries and keys.
    :param dropout: Dropout probability.
    :param softmax_scale: Passed to the backend as its ``scale`` argument.
    :param use_flash: Deprecated, use ``backend="flash_2"`` instead.
    :param backend: The attention backend to use. If not set, it will be chosen automatically.
    :param window_size: The sliding window size, if sliding window attention should be used.
        Must be positive.
    :param dtype: The default data type to use for parameters.
    :param init_device: The device to initialize weights on.
    :param cache: An optional buffer cache, passed on to the RoPE module and the backend.
    :param use_head_qk_norm: If ``True``, the QK norms are built with size ``head_dim`` and
        applied after the projections are reshaped into heads, instead of over the full
        projection width.
    """

    def __init__(
        self,
        *,
        d_model: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        bias: bool = True,
        gate: Optional[GateConfig] = None,
        rope: Optional[RoPEConfig] = None,
        clip_qkv: Optional[float] = None,
        qk_norm: Optional[LayerNormConfig] = None,
        dropout: float = 0.0,
        softmax_scale: Optional[float] = None,
        use_flash: Optional[bool] = None,
        backend: Optional[AttentionBackendName] = None,
        window_size: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
        use_head_qk_norm: bool = False,
    ):
        super().__init__()

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.d_model = d_model
        # Some models (e.g. Qwen3) use explicit head_dim that differs from d_model // n_heads.
        if head_dim is not None:
            self.head_dim = head_dim
        else:
            self.head_dim = d_model // n_heads

        # Q/K/V and output projections. K and V use `n_kv_heads`, which may be smaller than
        # `n_heads` (MQA / GQA).
        self.w_q = nn.Linear(
            d_model, n_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device
        )
        self.w_k = nn.Linear(
            d_model, self.n_kv_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device
        )
        self.w_v = nn.Linear(
            d_model, self.n_kv_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device
        )
        self.w_out = nn.Linear(
            n_heads * self.head_dim, d_model, bias=bias, dtype=dtype, device=init_device
        )

        # Optional gate projection; its output width depends on the gate granularity.
        self.gate = gate
        self.w_g: Optional[nn.Linear] = None
        if gate is not None:
            if gate.granularity == GateGranularity.headwise:
                self.w_g = nn.Linear(
                    d_model, self.n_heads, bias=bias, dtype=dtype, device=init_device
                )
            elif gate.granularity == GateGranularity.elementwise:
                self.w_g = nn.Linear(
                    d_model,
                    self.n_heads * self.head_dim,
                    bias=bias,
                    dtype=dtype,
                    device=init_device,
                )

        self.clip_qkv = clip_qkv
        self.use_head_qk_norm = use_head_qk_norm

        # Optional QK norms, sized per head or over the full projection width.
        self.q_norm: Optional[LayerNorm] = None
        self.k_norm: Optional[LayerNorm] = None
        if qk_norm is not None:
            if use_head_qk_norm:
                self.q_norm = qk_norm.build(size=self.head_dim, init_device=init_device)
                self.k_norm = qk_norm.build(size=self.head_dim, init_device=init_device)
            else:
                self.q_norm = qk_norm.build(size=n_heads * self.head_dim, init_device=init_device)
                self.k_norm = qk_norm.build(
                    size=self.n_kv_heads * self.head_dim, init_device=init_device
                )

        self.rope: Optional[Union[RotaryEmbedding, ComplexRotaryEmbedding]] = None
        if rope is not None:
            if rope.name == "fused":
                raise OLMoConfigurationError(
                    f"fused RoPE is not compatible with {self.__class__.__name__}"
                )
            rope_class = rope.build(self.head_dim, cache=cache)
            assert isinstance(rope_class, (RotaryEmbedding, ComplexRotaryEmbedding))
            self.rope = rope_class

        # Resolve the backend. Order matters here: explicit `backend` first, then the
        # deprecated `use_flash` flag, then the sliding-window preference, then the torch
        # fallback.
        if backend is not None:
            backend = AttentionBackendName(backend)

        if use_flash:
            if backend is not None and backend != AttentionBackendName.flash_2:
                raise OLMoConfigurationError(
                    f"'use_flash' is only compatible with 'flash_2' backend (got '{backend}')"
                )
            elif backend is None:
                warnings.warn(
                    "'use_flash' is deprecated, use 'backend=flash_2' instead", DeprecationWarning
                )
                backend = AttentionBackendName.flash_2

        # Translate window size so that we only look left, not right.
        self.window_size = window_size
        window_size_tuple: Tuple[int, int] = (-1, -1)
        if window_size is not None:
            if window_size <= 0:
                raise OLMoConfigurationError(f"'window_size' must be positive (got {window_size})")
            if backend is None and flash_attn_api.has_flash_attn_2():
                # note: flash_3, flash_4, and te backends are faster than flash_2 and also support SWA
                backend = AttentionBackendName.flash_2
            # Window size is [i - window_size[0], i + window_size[1]] inclusive
            window_size_tuple = (window_size - 1, 0)

        if backend is None:
            backend = AttentionBackendName.torch

        # Any non-torch backend is replaced by the torch backend when CUDA is unavailable.
        if not torch.cuda.is_available() and backend != AttentionBackendName.torch:
            warnings.warn(
                f"Backend is set to {backend}, but GPUs are not available. Defaulting to torch."
            )
            backend = AttentionBackendName.torch

        backend.assert_supported()
        log.info(f"Using attention backend '{backend}'")
        self.backend = backend.build(
            head_dim=self.head_dim,
            n_heads=n_heads,
            n_kv_heads=self.n_kv_heads,
            scale=softmax_scale,
            dropout_p=dropout,
            window_size=window_size_tuple,
            cache=cache,
        )

        # Only set by `init_kv_cache_manager()`; when present, KV caching is used in forward.
        self.kv_cache_manager: Optional[KVCacheManager] = None

    @property
    def cp_enabled(self) -> bool:
        """Whether context parallelism is enabled, as reported by the backend."""
        return self.backend.cp_enabled

    def sdpa(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the attention backend on already-projected queries, keys and values.

        When a KV cache manager has been initialized, ``cache_leftpad`` is recorded on it before
        the backend call and its sequence length is advanced by ``q.shape[1]`` afterwards.

        :param q: The queries; :meth:`forward` passes shape
            ``(batch_size, seq_len, n_heads, head_dim)``.
        :param k: The keys; :meth:`forward` passes shape
            ``(batch_size, seq_len, n_kv_heads, head_dim)``.
        :param v: The values, shaped like ``k``.

        The remaining arguments are forwarded to the backend unchanged.

        :returns: The backend output.
        """
        if self.kv_cache_manager is not None:
            self.kv_cache_manager.record_leftpad(cache_leftpad)

        # shape: (batch_size, seq_len, n_heads, head_dim)
        att = self.backend(
            (q, k, v),
            cu_doc_lens=cu_doc_lens,
            cu_doc_lens_q=cu_doc_lens_q,
            cu_doc_lens_k=cu_doc_lens_k,
            max_doc_len=max_doc_len,
            max_doc_len_q=max_doc_len_q,
            max_doc_len_k=max_doc_len_k,
            local_k_slice=local_k_slice,
            kv_cache_manager=self.kv_cache_manager,
        )

        if self.kv_cache_manager is not None:
            self.kv_cache_manager.update_seqlen(q.shape[1])

        return att

    def forward(
        self,
        x: torch.Tensor,
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        pos_sin: Optional[torch.Tensor] = None,
        pos_cos: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply attention to the input.

        :param x: The input of shape ``(batch_size, seq_len, d_model)``.
        :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D
            :class:`torch.int32` tensor that should always have one more element than there
            are documents (the first element in the tensor should always be ``0``).
            Required together with ``max_doc_len`` when using intra-document masking.
        :param cu_doc_lens_q: Forwarded to the backend unchanged.
        :param cu_doc_lens_k: Forwarded to the backend unchanged.
        :param max_doc_len: The maximum document length in the input ``x``.
            Required together with ``cu_doc_lens`` when using intra-document masking.
        :param max_doc_len_q: Forwarded to the backend unchanged.
        :param max_doc_len_k: Forwarded to the backend unchanged.
        :param local_k_slice: Forwarded to the backend unchanged.
        :param pos_sin: RoPE buffer forwarded to the RoPE module.
        :param pos_cos: RoPE buffer forwarded to the RoPE module.
        :param freqs_cis: RoPE buffer forwarded to the RoPE module.
        :param cache_leftpad: Recorded on the KV cache manager, if one has been initialized.

        :returns: The output of attention with shape ``(batch_size, seq_len, d_model)``.

        :raises RuntimeError: If RoPE is used with context parallelism enabled and none of
            ``pos_sin``, ``pos_cos`` and ``freqs_cis`` is given.
        """
        B, T, _ = x.shape

        # shape: (batch_size, seq_len, n_heads * head_dim),
        #        (batch_size, seq_len, n_kv_heads * head_dim),
        #        (batch_size, seq_len, n_kv_heads * head_dim)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)

        if self.clip_qkv is not None:
            q.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
            k.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
            v.clamp_(min=-self.clip_qkv, max=self.clip_qkv)

        # Full-width QK norm is applied before splitting into heads.
        if not self.use_head_qk_norm:
            if self.q_norm is not None:
                q = self.q_norm(q)
            if self.k_norm is not None:
                k = self.k_norm(k)

        # NOTE: use -1 instead of `n_heads` / `n_kv_heads` to infer actual local size when
        # using tensor parallelism.
        # shape: (batch_size, seq_len, n_heads (local), head_dim)
        q = q.view(B, T, -1, self.head_dim)
        # shape: (batch_size, seq_len, n_kv_heads (local), head_dim)
        k = k.view(B, T, -1, self.head_dim)
        # shape: (batch_size, seq_len, n_kv_heads (local), head_dim)
        v = v.view(B, T, -1, self.head_dim)

        # Head-wise QK norm is applied after splitting into heads.
        if self.use_head_qk_norm:
            if self.q_norm is not None:
                q = self.q_norm(q)
            if self.k_norm is not None:
                k = self.k_norm(k)

        if self.rope is not None:
            # In context-parallel mode we must be given pre-sharded buffers
            if self.cp_enabled and pos_sin is None and pos_cos is None and freqs_cis is None:
                raise RuntimeError(
                    "RoPE buffers must be passed through to attention after being properly "
                    "sharded by the context parallel load balancer"
                )

            # With a KV cache, positions start at the cache's current position.
            start_pos = self.kv_cache_manager.current_position() if self.kv_cache_manager else None
            q, k = self.rope(
                q,
                k,
                head_first=False,
                start_pos=start_pos,
                pos_sin=pos_sin,
                pos_cos=pos_cos,
                freqs_cis=freqs_cis,
            )

        # shape: (batch_size, seq_len, n_heads, head_dim)
        att = self.sdpa(
            q,
            k,
            v,
            cu_doc_lens=cu_doc_lens,
            cu_doc_lens_q=cu_doc_lens_q,
            cu_doc_lens_k=cu_doc_lens_k,
            max_doc_len=max_doc_len,
            max_doc_len_q=max_doc_len_q,
            max_doc_len_k=max_doc_len_k,
            local_k_slice=local_k_slice,
            cache_leftpad=cache_leftpad,
        )

        if self.gate is not None:
            assert self.w_g is not None
            g = self.w_g(x)
            if self.gate.full_precision:
                g = g.float()
            # The sigmoid runs in the (possibly upcast) gate dtype; the result is cast back to
            # the attention dtype before multiplying.
            gate_values = torch.sigmoid(g).to(att.dtype)

            if self.gate.granularity == GateGranularity.headwise:
                # head-wise gating is broadcast across head_dim
                # shape: (batch_size, seq_len, n_heads, head_dim)
                att = att * gate_values.unsqueeze(-1)
            elif self.gate.granularity == GateGranularity.elementwise:
                att = att.view(B, T, -1) * gate_values
                # the following att.view op is redundant (a no-op)

        # shape: (batch_size, seq_len, d_model)
        att = att.view(B, T, -1)

        # shape: (batch_size, seq_len, d_model)
        return self.w_out(att)

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
        float8_enabled: bool = False,
    ):
        """
        Prepare the module for tensor parallelism.

        The Q/K/V (and gate) projections are parallelized column-wise and the output projection
        row-wise. The QK norms, when present, are wrapped with :class:`SequenceParallel`.

        :param tp_mesh: The tensor parallel device sub-mesh.
        :param input_layout: The layout of the module input, if known. The input is
            redistributed to :class:`Replicate`.
        :param output_layout: The desired output layout of the output projection.
        :param use_local_output: Passed to the row-wise wrapper of the output projection.
        :param float8_enabled: Passed to :func:`get_tp_wrappers` to select the wrappers.
        """
        rowwise_parallel, colwise_parallel, prepare_module_input = get_tp_wrappers(
            float8_enabled=float8_enabled
        )

        parallelize_module(
            self,
            device_mesh=tp_mesh,
            parallelize_plan=prepare_module_input(
                input_layouts=None if input_layout is None else (input_layout,),
                desired_input_layouts=(Replicate(),),
            ),
        )

        # When a QK norm follows, w_q / w_k keep a distributed output sharded on dim 1
        # so that the norm's parallel style receives it.
        plan = {
            "w_q": colwise_parallel(
                output_layouts=None if self.q_norm is None else Shard(1),
                use_local_output=self.q_norm is None,
            ),
            "w_k": colwise_parallel(
                output_layouts=None if self.k_norm is None else Shard(1),
                use_local_output=self.k_norm is None,
            ),
            "w_v": colwise_parallel(),
            "w_out": rowwise_parallel(
                output_layouts=output_layout, use_local_output=use_local_output
            ),
        }
        if self.w_g is not None:
            plan["w_g"] = colwise_parallel()
        if self.q_norm is not None:
            # if full-dim norm: output is sharded on the embedding dimension (B, T, E [sharded])
            # which will be reshaped into (B, T, H [sharded], D)
            # if head-wise norm: output is sharded on the head dimension (B, T, H [sharded], D)
            plan["q_norm"] = SequenceParallel(use_local_output=True, output_layouts=Shard(2))
        if self.k_norm is not None:
            plan["k_norm"] = SequenceParallel(use_local_output=True, output_layouts=Shard(2))

        parallelize_module(
            module=self,
            device_mesh=tp_mesh,
            parallelize_plan=plan,
        )

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        """
        Prepare the module for context-parallelism (ring or Ulysses style).

        This simply delegates to the backend.

        .. important::
            This requires a backend that supports CP, such as "flash_2" or "te".

        :param cp_mesh: The context parallel device sub-mesh.
        :param ring: The ring context parallel style.
        :param uly: The ulysses context parallel style.
        """
        self.backend.apply_cp(cp_mesh, ring=ring, uly=uly)

    def init_weights(
        self,
        *,
        init_method: "InitMethod",
        d_model: int,
        block_idx: int,
        num_blocks: int,
        std: float = 0.02,
        generator: Optional[torch.Generator] = None,
    ) -> None:
        """
        Initialize the projection weights according to ``init_method``.

        :param init_method: The initialization scheme.
        :param d_model: The model hidden size; used by the ``normalized`` scheme.
        :param block_idx: The index of the block this module belongs to; used by the
            ``llama_depth`` scheme.
        :param num_blocks: The total number of blocks; used by the ``llama`` and ``normalized``
            schemes.
        :param std: The base standard deviation. Ignored by ``fan_in`` and replaced by
            ``d_model**-0.5`` for ``normalized``.
        :param generator: An optional random generator passed to ``init_linear``.
        """
        # Imported here rather than at module level, where it is only available under
        # TYPE_CHECKING.
        from olmo_core.nn.transformer.init import InitMethod, init_linear

        # Compute std for Q/K/V initialization
        if init_method == InitMethod.fan_in:
            # For fan_in, use 1/√d_in based on actual weight shape (ignores base std parameter)
            # Each projection may have different output dims (n_heads * head_dim vs n_kv_heads * head_dim)
            # but they all have the same input dim
            for w in (self.w_q, self.w_k, self.w_v):
                w_std = w.in_features**-0.5
                init_linear(w, std=w_std, generator=generator)
        else:
            if init_method == InitMethod.normalized:
                std = d_model**-0.5
            for w in (self.w_q, self.w_k, self.w_v):
                init_linear(w, std=std, generator=generator)

        # Initialize attention gate projection if present
        if self.w_g is not None:
            if init_method == InitMethod.fan_in:
                g_std = self.w_g.in_features**-0.5
            else:
                g_std = std
            init_linear(self.w_g, std=g_std, generator=generator)

        # Compute std for w_out initialization
        if init_method == InitMethod.fan_in:
            std = self.w_out.in_features**-0.5
        elif init_method == InitMethod.llama:
            std = std / (2 * num_blocks) ** 0.5
        elif init_method == InitMethod.llama_depth:
            std = std / (2 * (block_idx + 1)) ** 0.5
        elif init_method == InitMethod.normalized:
            std = std / (2 * num_blocks) ** 0.5
        init_linear(self.w_out, std=std, generator=generator)

    def init_kv_cache_manager(self, batch_size: int, max_seq_len: int):
        """
        Initialize the kv cache manager for attention. When the kv cache manager exists,
        kv caching will be used during the forward pass. This should only be called during inference.

        The cache is created on the device of the key projection weights.

        :param batch_size: The batch size for the cache.
        :param max_seq_len: The maximum sequence length for the cache.
        """
        self.backend.assert_supports_kv_cache()
        self.kv_cache_manager = KVCacheManager(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            num_kv_heads=self.n_kv_heads,
            head_dim=self.head_dim,
            device=self.w_k.weight.device,
        )

    def num_flops_per_token(self, seq_len: int) -> int:
        """
        Estimate the training FLOPs (forward + backward) per token.

        This accounts for:

        - Linear projections (Q, K, V, output, and gating if enabled)
        - Attention computation (QK^T and softmax(QK^T) @ V)
        - Sliding window attention (reduced effective sequence length)

        :param seq_len: The sequence length.
        """
        # 6 FLOPs per parameter (2 ops * 3 for forward+backward)
        param_flops = 6 * sum(p.numel() for p in self.parameters())

        # Attention computation (QK^T and Attn*V)
        # 12x multiplier: 2 matmuls * 2 ops each * 3 for forward+backward
        # For sliding window attention, effective sequence length is limited by window size
        # Note that flash attention technically uses more flops (14x multiplier) due to recomputation,
        # however, we just compute the idealized flops for SDPA.
        effective_seq_len = min(self.window_size, seq_len) if self.window_size else seq_len
        attn_flops = 12 * self.n_heads * self.head_dim * effective_seq_len

        return param_flops + attn_flops
[docs]
@beta_feature
class NormalizedAttention(Attention):
    """
    An nGPT attention implementation.

    Compared to :class:`Attention` this variant has no biases, uses a softmax scale of
    ``sqrt(d_model // n_heads)`` and multiplies queries and keys by the learned scaling
    vectors ``sq`` and ``sk``. The projection matrices are expected to be re-normalized after
    each optimizer step via :meth:`normalize_matrices`.

    :param d_model: The model hidden size.
    :param n_heads: The number of attention heads.
    :param n_kv_heads: The number of key and value heads, if different.
    :param rope: The config for RoPE, if RoPE should be used.
    :param qk_norm: Configuration for a layer norm for queries and keys.
    :param use_flash: Deprecated, use ``backend="flash_2"`` instead.
    :param backend: The attention backend to use. If not set, it will be chosen automatically.
    :param dtype: The default data type to use for parameters.
    :param init_device: The device to initialize weights on.
    :param cache: An optional buffer cache, passed on to the RoPE module and the backend.
    """

    def __init__(
        self,
        *,
        d_model: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        rope: Optional[RoPEConfig] = None,
        qk_norm: Optional[LayerNormConfig] = None,
        use_flash: Optional[bool] = None,
        backend: Optional[AttentionBackendName] = None,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            rope=rope,
            qk_norm=qk_norm,
            use_flash=use_flash,
            backend=backend,
            softmax_scale=math.sqrt(d_model // n_heads),
            bias=False,
            dtype=dtype,
            init_device=init_device,
            cache=cache,
        )

        # The scaling vectors are stored at `init_scaling` and multiplied by
        # `init_value / init_scaling` in forward, so their effective initial value is
        # `init_value`.
        self.sq_init_value = 1.0
        self.sq_init_scaling = 1.0 / math.sqrt(d_model)
        self.sq = nn.Parameter(
            torch.empty(self.head_dim * self.n_heads, dtype=dtype, device=init_device)
        )

        self.sk_init_value = 1.0
        self.sk_init_scaling = 1.0 / math.sqrt(d_model)
        self.sk = nn.Parameter(
            torch.empty(self.head_dim * self.n_kv_heads, dtype=dtype, device=init_device)
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``sq`` and ``sk`` to their stored initial value (``*_init_scaling``)."""
        nn.init.ones_(self.sq)
        nn.init.ones_(self.sk)
        with torch.no_grad():
            self.sq.mul_(self.sq_init_scaling)
            self.sk.mul_(self.sk_init_scaling)

    def forward(
        self,
        x: torch.Tensor,
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        pos_sin: Optional[torch.Tensor] = None,
        pos_cos: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply normalized attention to the input.

        :param x: The input of shape ``(batch_size, seq_len, d_model)``.
        :param cu_doc_lens: Cumulative document lengths, forwarded to the backend.
        :param max_doc_len: The maximum document length, forwarded to the backend.
        :param pos_sin: RoPE buffer forwarded to the RoPE module.
        :param pos_cos: RoPE buffer forwarded to the RoPE module.
        :param freqs_cis: RoPE buffer forwarded to the RoPE module.
        :param cache_leftpad: Left-padding amounts for the KV cache. Only ``None`` or an
            all-zero tensor is accepted.

        The other arguments are forwarded to the backend unchanged.

        :returns: The output of attention with shape ``(batch_size, seq_len, d_model)``.

        :raises NotImplementedError: If ``cache_leftpad`` contains any non-zero entry.
        :raises RuntimeError: If RoPE is used with context parallelism enabled and none of
            ``pos_sin``, ``pos_cos`` and ``freqs_cis`` is given.
        """
        # NOTE: the truth value of a multi-element tensor is ambiguous and raises a
        # RuntimeError, so check for non-zero padding explicitly instead of relying on
        # `if cache_leftpad:`.
        if cache_leftpad is not None and bool(torch.any(cache_leftpad != 0)):
            raise NotImplementedError(
                "cache_leftpad is not supported for the normalized attention variant"
            )

        B, T, _ = x.shape

        # shape: (batch_size, seq_len, n_heads * head_dim),
        #        (batch_size, seq_len, n_kv_heads * head_dim),
        #        (batch_size, seq_len, n_kv_heads * head_dim)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)

        # The norms are only applied when both are configured.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Rescale the stored scaling vectors back to their effective values.
        sq = (self.sq * (self.sq_init_value / self.sq_init_scaling)).view(1, 1, -1)
        q = sq * q

        sk = (self.sk * (self.sk_init_value / self.sk_init_scaling)).view(1, 1, -1)
        k = sk * k

        # shape: (batch_size, seq_len, n_heads, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim)
        # shape: (batch_size, seq_len, n_kv_heads, head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)
        # shape: (batch_size, seq_len, n_kv_heads, head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)

        if self.rope is not None:
            # In context-parallel mode we must be given pre-sharded buffers
            if self.cp_enabled and pos_sin is None and pos_cos is None and freqs_cis is None:
                raise RuntimeError(
                    "RoPE buffers must be passed through to attention after being properly "
                    "sharded by the context parallel load balancer"
                )

            start_pos = self.kv_cache_manager.current_position() if self.kv_cache_manager else None
            q, k = self.rope(
                q,
                k,
                head_first=False,
                start_pos=start_pos,
                pos_sin=pos_sin,
                pos_cos=pos_cos,
                freqs_cis=freqs_cis,
            )

        # shape: (batch_size, seq_len, n_heads, head_dim)
        att = self.sdpa(
            q,
            k,
            v,
            cu_doc_lens=cu_doc_lens,
            cu_doc_lens_q=cu_doc_lens_q,
            cu_doc_lens_k=cu_doc_lens_k,
            max_doc_len=max_doc_len,
            max_doc_len_q=max_doc_len_q,
            max_doc_len_k=max_doc_len_k,
            local_k_slice=local_k_slice,
            cache_leftpad=cache_leftpad,
        )

        # shape: (batch_size, seq_len, d_model)
        att = att.view(B, T, -1)

        # shape: (batch_size, seq_len, d_model)
        return self.w_out(att)

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
        float8_enabled: bool = False,
    ):
        """
        Tensor parallelism is not supported by this variant.

        :raises NotImplementedError: Always.
        """
        del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled

        raise NotImplementedError("TP is not implemented yet for the normalized attention variant")

    @torch.no_grad()
    def normalize_matrices(self):
        """
        Normalize the weights in all matrices. This should be called after each optimizer step, which
        the :class:`~olmo_core.train.train_module.TransformerTrainModule` will handle for you.

        The Q/K/V weights are normalized along their last dimension, the output weights along
        dimension 0.
        """
        self._normalize_matrix(self.w_q.weight)
        self._normalize_matrix(self.w_k.weight)
        self._normalize_matrix(self.w_v.weight)
        self._normalize_matrix(self.w_out.weight, dim=0)

    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
        """Overwrite ``w`` in place with ``l2_normalize(w, dim=dim)``."""
        w.copy_(l2_normalize(w, dim=dim))
[docs]
class FusedAttention(SequenceMixer):
    """
    A "fused" implementation of multi-head self-attention, using a single projection for
    queries, keys and values.

    Intra-document masking is supported by passing in the ``max_doc_len`` and ``cu_doc_lens``
    parameters to :meth:`forward()`.

    .. warning::
        Currently this is only supported with the "flash_2" backend.

    .. warning::
        If using RoPE, this requires that you use the "fused" RoPE implementation
        (:class:`~olmo_core.nn.rope.FusedRotaryEmbedding`).

    :param d_model: The model hidden size.
    :param n_heads: The number of attention heads.
    :param bias: Include biases with linear layers.
    :param rope: The config for RoPE, if RoPE should be used.
    :param clip_qkv: Clip QKV to this value, if set.
    :param dropout: Dropout probability.
    :param dtype: The default data type to use for parameters.
    :param backend: The attention backend to use. Defaults to "flash_2".
    :param init_device: The device to initialize weights on.
    :param cache: An optional buffer cache, passed on to the RoPE module and the backend.
    """

    def __init__(
        self,
        *,
        d_model: int,
        n_heads: int,
        bias: bool = True,
        rope: Optional[RoPEConfig] = None,
        clip_qkv: Optional[float] = None,
        dropout: float = 0.0,
        dtype: torch.dtype = torch.float32,
        backend: Optional[AttentionBackendName] = None,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # A single projection produces queries, keys and values at once.
        self.w_qkv = nn.Linear(d_model, 3 * d_model, bias=bias, dtype=dtype, device=init_device)
        self.w_out = nn.Linear(d_model, d_model, bias=bias, dtype=dtype, device=init_device)
        self.clip_qkv = clip_qkv

        self.rope: Optional[FusedRotaryEmbedding] = None
        if rope is not None:
            if rope.name != "fused":
                raise OLMoConfigurationError(f"{self.__class__.__name__} requires fused RoPE")
            rope_class = rope.build(self.head_dim, cache=cache)
            assert isinstance(rope_class, FusedRotaryEmbedding)
            self.rope = rope_class

        if backend is not None:
            backend = AttentionBackendName(backend)
        else:
            backend = AttentionBackendName.flash_2

        backend.assert_supported()
        backend.assert_supports_packed_qkv()
        log.info(f"Using attention backend '{backend}'")
        self.backend = backend.build(
            head_dim=self.head_dim, n_heads=self.n_heads, dropout_p=dropout, cache=cache
        )

    @property
    def cp_enabled(self) -> bool:
        """Whether context parallelism is enabled, as reported by the backend."""
        return self.backend.cp_enabled

    def forward(
        self,
        x: torch.Tensor,
        max_doc_len: Optional[int] = None,
        cu_doc_lens: Optional[torch.Tensor] = None,
        pos_sin: Optional[torch.Tensor] = None,
        pos_cos: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply attention to the input.

        :param x: The input of shape ``(batch_size, seq_len, d_model)``.
        :param max_doc_len: The maximum document length in the input ``x``.
            Required together with ``cu_doc_lens`` when using intra-document masking.
        :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D
            :class:`torch.int32` tensor that should always have one more element than there
            are documents (the first element in the tensor should always be ``0``).
            Required together with ``max_doc_len`` when using intra-document masking.
        :param pos_sin: RoPE buffer forwarded to the RoPE module.
        :param pos_cos: RoPE buffer forwarded to the RoPE module.
        :param freqs_cis: RoPE buffer forwarded to the RoPE module.
        :param cache_leftpad: Left-padding amounts for a KV cache. Only ``None`` or an
            all-zero tensor is accepted.

        :returns: The output of attention with shape ``(batch_size, seq_len, d_model)``.

        :raises NotImplementedError: If ``cache_leftpad`` contains any non-zero entry.
        :raises RuntimeError: If RoPE is used with context parallelism enabled and none of
            ``pos_sin``, ``pos_cos`` and ``freqs_cis`` is given.
        """
        # NOTE: the truth value of a multi-element tensor is ambiguous and raises a
        # RuntimeError, so check for non-zero padding explicitly instead of relying on
        # `if cache_leftpad:`.
        if cache_leftpad is not None and bool(torch.any(cache_leftpad != 0)):
            raise NotImplementedError(
                "cache_leftpad is not supported for the fused attention variant"
            )

        B, T, _ = x.shape

        # shape: (batch_size, seq_len, 3, n_heads, head_dim)
        qkv = self.w_qkv(x).view(B, T, 3, self.n_heads, self.head_dim)

        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)

        if self.rope is not None:
            # In context-parallel mode we must be given pre-sharded buffers
            if self.cp_enabled and pos_sin is None and pos_cos is None and freqs_cis is None:
                raise RuntimeError(
                    "RoPE buffers must be passed through to attention after being properly "
                    "sharded by the context parallel load balancer"
                )
            qkv = self.rope(qkv, pos_sin=pos_sin, pos_cos=pos_cos, freqs_cis=freqs_cis)

        att = self.backend(
            qkv,
            cu_doc_lens=cu_doc_lens,
            max_doc_len=max_doc_len,
        )

        # shape: (batch_size, seq_len, d_model)
        att = att.view(B, T, -1)  # type: ignore

        # shape: (batch_size, seq_len, d_model)
        return self.w_out(att)

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
        float8_enabled: bool = False,
    ):
        """
        Tensor parallelism is not supported by this variant.

        :raises NotImplementedError: Always.
        """
        del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled

        raise NotImplementedError("TP is not implemented yet for the fused attention variant")

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        """
        Prepare the module for context-parallelism by delegating to the backend.

        :param cp_mesh: The context parallel device sub-mesh.
        :param ring: The ring context parallel style.
        :param uly: The ulysses context parallel style.
        """
        self.backend.apply_cp(cp_mesh, ring=ring, uly=uly)

    def init_weights(
        self,
        *,
        init_method: "InitMethod",
        d_model: int,
        block_idx: int,
        num_blocks: int,
        std: float = 0.02,
        generator: Optional[torch.Generator] = None,
    ) -> None:
        """
        Initialize the projection weights according to ``init_method``.

        :param init_method: The initialization scheme.
        :param d_model: The model hidden size; used by the ``normalized`` scheme.
        :param block_idx: The index of the block this module belongs to; used by the
            ``llama_depth`` scheme.
        :param num_blocks: The total number of blocks; used by the ``llama`` and ``normalized``
            schemes.
        :param std: The base standard deviation. Ignored by ``fan_in`` and replaced by
            ``d_model**-0.5`` for ``normalized``.
        :param generator: An optional random generator passed to ``init_linear``.
        """
        # Imported here rather than at module level, where it is only available under
        # TYPE_CHECKING.
        from olmo_core.nn.transformer.init import InitMethod, init_linear

        # Compute std for fused QKV initialization
        if init_method == InitMethod.fan_in:
            std = self.w_qkv.in_features**-0.5
        elif init_method == InitMethod.normalized:
            std = d_model**-0.5
        init_linear(self.w_qkv, std=std, generator=generator)

        # Compute std for w_out initialization
        if init_method == InitMethod.fan_in:
            std = self.w_out.in_features**-0.5
        elif init_method == InitMethod.llama:
            std = std / (2 * num_blocks) ** 0.5
        elif init_method == InitMethod.llama_depth:
            std = std / (2 * (block_idx + 1)) ** 0.5
        elif init_method == InitMethod.normalized:
            std = std / (2 * num_blocks) ** 0.5
        init_linear(self.w_out, std=std, generator=generator)

    def num_flops_per_token(self, seq_len: int) -> int:
        """
        Estimate the training FLOPs (forward + backward) per token.

        :param seq_len: The sequence length.
        """
        # 6 FLOPs per parameter (2 ops * 3 for forward+backward)
        param_flops = 6 * sum(p.numel() for p in self.parameters())

        # Attention computation (QK^T and Attn*V)
        # 12x multiplier: 2 matmuls * 2 ops each * 3 for forward+backward
        attn_flops = 12 * self.n_heads * self.head_dim * seq_len

        return param_flops + attn_flops