Source code for olmo_core.nn.attention

import logging
import math
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Placement, Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

from olmo_core.config import Config, DType, StrEnum
from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel
from olmo_core.doc_utils import beta_feature
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.nn.attention.base import SequenceMixer, SequenceMixerConfig
from olmo_core.nn.attention.kv_cache import KVCacheManager
from olmo_core.nn.attention.recurrent import GatedDeltaNet, GatedDeltaNetConfig

from ..buffer_cache import BufferCache
from ..config import ModuleConfig
from ..functional import l2_normalize
from ..layer_norm import LayerNorm, LayerNormConfig
from ..rope import (
    ComplexRotaryEmbedding,
    FusedRotaryEmbedding,
    RoPEConfig,
    RotaryEmbedding,
)
from ..utils import get_tp_wrappers
from . import flash_attn_api
from .backend import (
    AttentionBackend,
    AttentionBackendName,
    FlashAttention2Backend,
    FlashAttention3Backend,
    FlashAttention4Backend,
    TEAttentionBackend,
    TorchAttentionBackend,
)
from .ring import (
    RingAttentionLlama3LoadBalancer,
    RingAttentionLoadBalancer,
    RingAttentionLoadBalancerType,
    RingAttentionZigZagLoadBalancer,
    RingContextParallelStyle,
    UlyssesContextParallelStyle,
    UlyssesLoadBalancer,
)

if TYPE_CHECKING:
    from olmo_core.nn.transformer.init import InitMethod

__all__ = [
    "SlidingWindowAttentionConfig",
    "GateGranularity",
    "GateConfig",
    "AttentionType",
    "AttentionBackendName",
    "AttentionBackend",
    "TorchAttentionBackend",
    "FlashAttention2Backend",
    "FlashAttention3Backend",
    "FlashAttention4Backend",
    "TEAttentionBackend",
    "AttentionConfig",
    "Attention",
    "FusedAttention",
    "NormalizedAttention",
    "RingAttentionLoadBalancerType",
    "RingAttentionLoadBalancer",
    "RingAttentionZigZagLoadBalancer",
    "RingAttentionLlama3LoadBalancer",
    "UlyssesLoadBalancer",
    "RingContextParallelStyle",
    "UlyssesContextParallelStyle",
    "GatedDeltaNetConfig",
    "GatedDeltaNet",
]

log = logging.getLogger(__name__)


[docs] class GateGranularity(StrEnum): headwise = "headwise" """Head-wise gating: one gate value per attention head, broadcast across head dimension.""" elementwise = "elementwise" """Element-wise gating: one gate value per output element."""
[docs] @dataclass class GateConfig(Config): granularity: GateGranularity = GateGranularity.headwise """The granularity of gating to use.""" full_precision: bool = True """Whether to always apply gating in full precision regardless of the input data type."""
[docs] @dataclass class SlidingWindowAttentionConfig(Config): pattern: List[int] """ The pattern of window sizes to use for attention, repeated to cover all layers. A value of -1 indicates full attention. For example, a pattern of ``[4096, 4096, 4096, -1]`` means that for each set of 4 layers, the first 3 will use a window size of 4096, and the last layer will use full attention. """ force_full_attention_on_first_layer: bool = True """ If `True`, the first transformer layer will always use full attention, regardless of the pattern. """ force_full_attention_on_last_layer: bool = True """ If `True`, the last transformer layer will always use full attention, regardless of the pattern. """ def _get_window_size(self, layer_idx: int, n_layers: int) -> int: """ Get the window size for a given layer, returning -1 for full attention. """ if self.force_full_attention_on_first_layer and layer_idx == 0: return -1 if self.force_full_attention_on_last_layer and layer_idx == (n_layers - 1): return -1 # Adjust the layer index if the first layer is special-cased to full attention # (in which case the pattern is applied starting from the second layer) effective_layer_idx = layer_idx if self.force_full_attention_on_first_layer: effective_layer_idx -= 1 window_size = self.pattern[effective_layer_idx % len(self.pattern)] if window_size <= 0 and window_size != -1: raise OLMoConfigurationError( f"Sliding window size must be positive or -1 (got {window_size})" ) return window_size
[docs] def should_use_swa(self, layer_idx: int, n_layers: int) -> bool: """ Returns `True` if the given layer uses sliding window attention. """ return self._get_window_size(layer_idx, n_layers) != -1
[docs] def get_window_size(self, layer_idx: int, n_layers: int) -> int: """ Get the sliding window size for a given layer. """ window_size = self._get_window_size(layer_idx, n_layers) if window_size == -1: raise ValueError(f"Layer {layer_idx} is not configured for sliding window attention.") return window_size
[docs] class AttentionType(StrEnum): """ An enumeration of the different attention implementations. """ default = "default" """ ➡️ :class:`Attention` """ fused = "fused" """ ➡️ :class:`FusedAttention` """ normalized = "normalized" """ ➡️ :class:`NormalizedAttention` """
[docs] @SequenceMixerConfig.register("attention") @dataclass class AttentionConfig(SequenceMixerConfig["SequenceMixer"]): """ A configuration class for easily building any of the different attention modules. See the individual :class:`Attention` subclasses for a description of the configuration options. """ name: AttentionType = AttentionType.default """ The name of the implementation. """ n_heads: int = 16 n_kv_heads: Optional[int] = None head_dim: Optional[int] = None bias: Optional[bool] = None gate: Optional[GateConfig] = None rope: Optional[RoPEConfig] = None clip_qkv: Optional[float] = None qk_norm: Optional[LayerNormConfig] = None dropout: Optional[float] = None use_flash: Optional[bool] = None backend: Optional[AttentionBackendName] = None dtype: DType = DType.float32 sliding_window: Optional[SlidingWindowAttentionConfig] = None use_head_qk_norm: Optional[bool] = None
[docs] def num_params(self, d_model: int) -> int: """ The number of params that the attention implementation will have once built. :param d_model: The model dimensionality. """ n_heads = self.n_heads n_kv_heads = self.n_kv_heads or n_heads head_dim = self.head_dim or d_model // n_heads bias = self.bias if self.bias is not None else self.name != AttentionType.normalized params = 0 # Block attention Q projection. params += d_model * n_heads * head_dim if bias: params += n_heads * head_dim # Block attention KV projections. params += 2 * d_model * n_kv_heads * head_dim if bias: params += 2 * n_kv_heads * head_dim # Block attention QK norm. if self.qk_norm is not None: if self.use_head_qk_norm: params += 2 * self.qk_norm.num_params(head_dim) else: params += self.qk_norm.num_params(n_heads * head_dim) # q_norm params += self.qk_norm.num_params(n_kv_heads * head_dim) # k_norm # Block attention out. params += n_heads * head_dim * d_model if bias: params += d_model # Block attention gate projection. if self.gate is not None: if self.gate.granularity == GateGranularity.headwise: params += d_model * n_heads if bias: params += n_heads elif self.gate.granularity == GateGranularity.elementwise: params += d_model * (n_heads * head_dim) if bias: params += n_heads * head_dim # Block QK scaling factors. if self.name == AttentionType.normalized: params += n_heads * head_dim params += n_kv_heads * head_dim return params
[docs] def build( self, d_model: int, *, layer_idx: int, n_layers: int, init_device: str = "cpu", cache: Optional[BufferCache] = None, ) -> "SequenceMixer": """ Build the corresponding attention module. :param d_model: The model dimensionality. :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta". """ kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") sliding_window_config: Optional[SlidingWindowAttentionConfig] = kwargs.pop( "sliding_window", None ) if sliding_window_config is not None and sliding_window_config.should_use_swa( layer_idx, n_layers ): kwargs["window_size"] = sliding_window_config.get_window_size(layer_idx, n_layers) else: # global (non-SWA) layer rope_config: Optional[RoPEConfig] = kwargs.get("rope") if rope_config is not None and rope_config.no_global_rope: kwargs["rope"] = None kwargs.update( dtype=kwargs.pop("dtype").as_pt(), d_model=d_model, init_device=init_device, cache=cache, ) try: if self.name == "default": return Attention(**kwargs) elif self.name == "fused": kwargs.pop("use_flash", None) if "window_size" in kwargs: raise OLMoConfigurationError( "'window_size' is not supported with fused attention" ) return FusedAttention(**kwargs) elif self.name == "normalized": if "window_size" in kwargs: raise OLMoConfigurationError( "'window_size' is not supported with normalized attention" ) return NormalizedAttention(**kwargs) else: raise NotImplementedError(self.name) except TypeError as e: raise OLMoConfigurationError( f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" ) from e
[docs] class Attention(SequenceMixer): """ An implementation of multi-head self-attention with support for multi-query (MQA) and grouped-query (GQA) attention. Intra-document masking is also supported by passing in the ``max_doc_len`` and ``cu_doc_lens`` parameters to :meth:`forward()`. This requires a backend that supports it, like the flash backend. .. seealso:: :class:`FusedAttention` if you have flash-attn installed and you're not using MQA or GQA. :param d_model: The model hidden size. :param n_heads: The number of attention heads. :param n_kv_heads: The number of key and value heads, if different. :param bias: Include biases with linear layers. :param gate: Configuration for attention gating. If None, no gating is applied. :param rope: The config for RoPE, if RoPE should be used. :param clip_qkv: Clip QKV to this value, if set. :param qk_norm: Configuration a layer norm for queries and keys. :param dropout: Dropout probability. :param use_flash: Deprecated, use ``backend="flash_2"`` instead. :param backend: The attention backend to use. If not set, it will be chosen automatically. :param dtype: The default data type to use for parameters. :param init_device: The device to initialize weights on. """ def __init__( self, *, d_model: int, n_heads: int, n_kv_heads: Optional[int] = None, head_dim: Optional[int] = None, bias: bool = True, gate: Optional[GateConfig] = None, rope: Optional[RoPEConfig] = None, clip_qkv: Optional[float] = None, qk_norm: Optional[LayerNormConfig] = None, dropout: float = 0.0, softmax_scale: Optional[float] = None, use_flash: Optional[bool] = None, backend: Optional[AttentionBackendName] = None, window_size: Optional[int] = None, dtype: torch.dtype = torch.float32, init_device: str = "cpu", cache: Optional[BufferCache] = None, use_head_qk_norm: bool = False, ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads or n_heads self.d_model = d_model # Some models (e.g. Qwen3) use explicit head_dim that differs from d_model // n_heads. if head_dim is not None: self.head_dim = head_dim else: self.head_dim = d_model // n_heads self.w_q = nn.Linear( d_model, n_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device ) self.w_k = nn.Linear( d_model, self.n_kv_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device ) self.w_v = nn.Linear( d_model, self.n_kv_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device ) self.w_out = nn.Linear( n_heads * self.head_dim, d_model, bias=bias, dtype=dtype, device=init_device ) self.gate = gate self.w_g: Optional[nn.Linear] = None if gate is not None: if gate.granularity == GateGranularity.headwise: self.w_g = nn.Linear( d_model, self.n_heads, bias=bias, dtype=dtype, device=init_device ) elif gate.granularity == GateGranularity.elementwise: self.w_g = nn.Linear( d_model, self.n_heads * self.head_dim, bias=bias, dtype=dtype, device=init_device, ) self.clip_qkv = clip_qkv self.use_head_qk_norm = use_head_qk_norm self.q_norm: Optional[LayerNorm] = None self.k_norm: Optional[LayerNorm] = None if qk_norm is not None: if use_head_qk_norm: self.q_norm = qk_norm.build(size=self.head_dim, init_device=init_device) self.k_norm = qk_norm.build(size=self.head_dim, init_device=init_device) else: self.q_norm = qk_norm.build(size=n_heads * self.head_dim, init_device=init_device) self.k_norm = qk_norm.build( size=self.n_kv_heads * self.head_dim, init_device=init_device ) self.rope: Optional[Union[RotaryEmbedding, ComplexRotaryEmbedding]] = None if rope is not None: if rope.name == "fused": raise OLMoConfigurationError( f"fused RoPE is not compatible with {self.__class__.__name__}" ) rope_class = rope.build(self.head_dim, cache=cache) assert isinstance(rope_class, (RotaryEmbedding, ComplexRotaryEmbedding)) self.rope = rope_class if backend is not None: backend = AttentionBackendName(backend) if use_flash: if backend is not None and backend != AttentionBackendName.flash_2: raise OLMoConfigurationError( f"'use_flash' is only compatible with 'flash_2' backend (got '{backend}')" ) elif backend is None: warnings.warn( "'use_flash' is deprecated, use 'backend=flash_2' instead", DeprecationWarning ) backend = AttentionBackendName.flash_2 # Translate window size so that we only look left, not right. self.window_size = window_size window_size_tuple: Tuple[int, int] = (-1, -1) if window_size is not None: if window_size <= 0: raise OLMoConfigurationError(f"'window_size' must be positive (got {window_size})") if backend is None and flash_attn_api.has_flash_attn_2(): # note: flash_3, flash_4, and te backends are faster than flash_2 and also support SWA backend = AttentionBackendName.flash_2 # Window size is [i - window_size[0], i + window_size[1]] inclusive window_size_tuple = (window_size - 1, 0) if backend is None: backend = AttentionBackendName.torch if not torch.cuda.is_available() and backend != AttentionBackendName.torch: warnings.warn( f"Backend is set to {backend}, but GPUs are not available. Defaulting to torch." ) backend = AttentionBackendName.torch backend.assert_supported() log.info(f"Using attention backend '{backend}'") self.backend = backend.build( head_dim=self.head_dim, n_heads=n_heads, n_kv_heads=self.n_kv_heads, scale=softmax_scale, dropout_p=dropout, window_size=window_size_tuple, cache=cache, ) self.kv_cache_manager: Optional[KVCacheManager] = None @property def cp_enabled(self) -> bool: return self.backend.cp_enabled def sdpa( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_doc_lens: Optional[torch.Tensor] = None, cu_doc_lens_q: Optional[torch.Tensor] = None, cu_doc_lens_k: Optional[torch.Tensor] = None, max_doc_len: Optional[int] = None, max_doc_len_q: Optional[int] = None, max_doc_len_k: Optional[int] = None, local_k_slice: Optional[slice] = None, cache_leftpad: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.kv_cache_manager is not None: self.kv_cache_manager.record_leftpad(cache_leftpad) # shape: (batch_size, seq_len, n_heads, head_dim) att = self.backend( (q, k, v), cu_doc_lens=cu_doc_lens, cu_doc_lens_q=cu_doc_lens_q, cu_doc_lens_k=cu_doc_lens_k, max_doc_len=max_doc_len, max_doc_len_q=max_doc_len_q, max_doc_len_k=max_doc_len_k, local_k_slice=local_k_slice, kv_cache_manager=self.kv_cache_manager, ) if self.kv_cache_manager is not None: self.kv_cache_manager.update_seqlen(q.shape[1]) return att def _apply_rope( self, q: torch.Tensor, k: torch.Tensor, start_pos: Optional[int], pos_sin: Optional[torch.Tensor], pos_cos: Optional[torch.Tensor], freqs_cis: Optional[torch.Tensor], cu_doc_lens: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: assert self.rope is not None rope_kwargs = {} if cu_doc_lens is not None: if not isinstance(self.rope, RotaryEmbedding): raise NotImplementedError( "Intra-document RoPE (cu_doc_lens) is only supported by RotaryEmbedding; " f"got {type(self.rope).__name__}" ) rope_kwargs["cu_doc_lens"] = cu_doc_lens return self.rope( q, k, head_first=False, start_pos=start_pos, pos_sin=pos_sin, pos_cos=pos_cos, freqs_cis=freqs_cis, **rope_kwargs, )
[docs] def forward( self, x: torch.Tensor, cu_doc_lens: Optional[torch.Tensor] = None, cu_doc_lens_q: Optional[torch.Tensor] = None, cu_doc_lens_k: Optional[torch.Tensor] = None, max_doc_len: Optional[int] = None, max_doc_len_q: Optional[int] = None, max_doc_len_k: Optional[int] = None, local_k_slice: Optional[slice] = None, pos_sin: Optional[torch.Tensor] = None, pos_cos: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Apply attention to the input. :param x: The input of shape ``(batch_size, seq_len, d_model)``. :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D :class:`torch.int32` tensor that should always have one more element than there are documents (the first element in the tensor should always be ``0``). Required together with ``max_doc_len`` when using intra-document masking. :param max_doc_len: The maximum document length in the input ``x``. Required together with ``cu_doc_lens`` when using intra-document masking. :returns: The output of attention with shape ``(batch_size, seq_len, d_model)``. """ B, T, _ = x.shape # shape: (batch_size, seq_len, n_heads * head_dim), # (batch_size, seq_len, n_kv_heads * head_dim), # (batch_size, seq_len, n_kv_heads * head_dim) q, k, v = self.w_q(x), self.w_k(x), self.w_v(x) if self.clip_qkv is not None: q.clamp_(min=-self.clip_qkv, max=self.clip_qkv) k.clamp_(min=-self.clip_qkv, max=self.clip_qkv) v.clamp_(min=-self.clip_qkv, max=self.clip_qkv) if not self.use_head_qk_norm: if self.q_norm is not None: q = self.q_norm(q) if self.k_norm is not None: k = self.k_norm(k) # NOTE: use -1 instead of `n_heads` / `n_kv_heads` to infer actual local size when # using tensor parallelism. # shape: (batch_size, seq_len, n_heads (local), head_dim) q = q.view(B, T, -1, self.head_dim) # shape: (batch_size, seq_len, n_kv_heads (local), head_dim) k = k.view(B, T, -1, self.head_dim) # shape: (batch_size, seq_len, n_kv_heads (local), head_dim) v = v.view(B, T, -1, self.head_dim) if self.use_head_qk_norm: if self.q_norm is not None: q = self.q_norm(q) if self.k_norm is not None: k = self.k_norm(k) if self.rope is not None: # In context-parallel mode we must be given pre-sharded buffers if self.cp_enabled and pos_sin is None and pos_cos is None and freqs_cis is None: raise RuntimeError( "RoPE buffers must be passed through to attention after being properly " "sharded by the context parallel load balancer" ) start_pos = self.kv_cache_manager.current_position() if self.kv_cache_manager else None q, k = self._apply_rope(q, k, start_pos, pos_sin, pos_cos, freqs_cis, cu_doc_lens) # shape: (batch_size, seq_len, n_heads, head_dim) att = self.sdpa( q, k, v, cu_doc_lens=cu_doc_lens, cu_doc_lens_q=cu_doc_lens_q, cu_doc_lens_k=cu_doc_lens_k, max_doc_len=max_doc_len, max_doc_len_q=max_doc_len_q, max_doc_len_k=max_doc_len_k, local_k_slice=local_k_slice, cache_leftpad=cache_leftpad, ) if self.gate is not None: assert self.w_g is not None g = self.w_g(x) if self.gate.full_precision: g = g.float() gate_values = torch.sigmoid(g).to(att.dtype) if self.gate.granularity == GateGranularity.headwise: # head-wise gating is broadcast across head_dim # shape: (batch_size, seq_len, n_heads, head_dim) att = att * gate_values.unsqueeze(-1) elif self.gate.granularity == GateGranularity.elementwise: att = att.view(B, T, -1) * gate_values # the following att.view op is redundant (a no-op) # shape: (batch_size, seq_len, d_model) att = att.view(B, T, -1) # shape: (batch_size, seq_len, d_model) return self.w_out(att)
def apply_tp( self, tp_mesh: DeviceMesh, input_layout: Optional[Placement] = None, output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): rowwise_parallel, colwise_parallel, prepare_module_input = get_tp_wrappers( float8_enabled=float8_enabled ) parallelize_module( self, device_mesh=tp_mesh, parallelize_plan=prepare_module_input( input_layouts=None if input_layout is None else (input_layout,), desired_input_layouts=(Replicate(),), ), ) plan = { "w_q": colwise_parallel( output_layouts=None if self.q_norm is None else Shard(1), use_local_output=self.q_norm is None, ), "w_k": colwise_parallel( output_layouts=None if self.k_norm is None else Shard(1), use_local_output=self.k_norm is None, ), "w_v": colwise_parallel(), "w_out": rowwise_parallel( output_layouts=output_layout, use_local_output=use_local_output ), } if self.w_g is not None: plan["w_g"] = colwise_parallel() if self.q_norm is not None: # if full-dim norm: output is sharded on the embedding dimension (B, T, E [sharded]) # which will be reshaped into (B, T, H [sharded], D) # if head-wise norm: output is sharded on the head dimension (B, T, H [sharded], D) plan["q_norm"] = SequenceParallel(use_local_output=True, output_layouts=Shard(2)) if self.k_norm is not None: plan["k_norm"] = SequenceParallel(use_local_output=True, output_layouts=Shard(2)) parallelize_module( module=self, device_mesh=tp_mesh, parallelize_plan=plan, )
[docs] def apply_cp( self, cp_mesh: DeviceMesh, ring: Optional[RingContextParallelStyle] = None, uly: Optional[UlyssesContextParallelStyle] = None, ): """ Prepare the module for context-parallelism (ring attention). .. important:: This requires a backend that supports CP, such as "flash_2" or "te". :param cp_mesh: The context parallel device sub-mesh. :param ring: The ring context parallel style. :param uly: The ulysses context parallel style. """ self.backend.apply_cp(cp_mesh, ring=ring, uly=uly)
def init_weights( self, *, init_method: "InitMethod", d_model: int, block_idx: int, num_blocks: int, std: float = 0.02, generator: Optional[torch.Generator] = None, ) -> None: from olmo_core.nn.transformer.init import InitMethod, init_linear # Compute std for Q/K/V initialization if init_method == InitMethod.fan_in: # For fan_in, use 1/√d_in based on actual weight shape (ignores base std parameter) # Each projection may have different output dims (n_heads * head_dim vs n_kv_heads * head_dim) # but they all have the same input dim for w in (self.w_q, self.w_k, self.w_v): w_std = w.in_features**-0.5 init_linear(w, std=w_std, generator=generator) else: if init_method == InitMethod.normalized: std = d_model**-0.5 for w in (self.w_q, self.w_k, self.w_v): init_linear(w, std=std, generator=generator) # Initialize attention gate projection if present if self.w_g is not None: if init_method == InitMethod.fan_in: g_std = self.w_g.in_features**-0.5 else: g_std = std init_linear(self.w_g, std=g_std, generator=generator) # Compute std for w_out initialization if init_method == InitMethod.fan_in: std = self.w_out.in_features**-0.5 elif init_method == InitMethod.llama: std = std / (2 * num_blocks) ** 0.5 elif init_method == InitMethod.llama_depth: std = std / (2 * (block_idx + 1)) ** 0.5 elif init_method == InitMethod.normalized: std = std / (2 * num_blocks) ** 0.5 init_linear(self.w_out, std=std, generator=generator)
[docs] def init_kv_cache_manager(self, batch_size: int, max_seq_len: int): """ Initialize the kv cache manager for attention. When the kv cache manager exists, kv caching will be used during the forward pass. This should only be called during inference. :param batch_size: The batch size for the cache. :param max_seq_len: The maximum sequence length for the cache. """ self.backend.assert_supports_kv_cache() self.kv_cache_manager = KVCacheManager( batch_size=batch_size, max_seq_len=max_seq_len, num_kv_heads=self.n_kv_heads, head_dim=self.head_dim, device=self.w_k.weight.device, )
[docs] def num_flops_per_token(self, seq_len: int) -> int: """ This accounts for: - Linear projections (Q, K, V, output, and gating if enabled) - Attention computation (QK^T and softmax(QK^T) @ V) - Sliding window attention (reduced effective sequence length) """ # 6 FLOPs per parameter (2 ops * 3 for forward+backward) param_flops = 6 * sum(p.numel() for p in self.parameters()) # Attention computation (QK^T and Attn*V) # 12x multiplier: 2 matmuls * 2 ops each * 3 for forward+backward # For sliding window attention, effective sequence length is limited by window size # Note that flash attention technically uses more flops (14x multiplier) due to recomputation, # however, we just compute the idealized flops for SDPA. effective_seq_len = min(self.window_size, seq_len) if self.window_size else seq_len attn_flops = 12 * self.n_heads * self.head_dim * effective_seq_len return param_flops + attn_flops
[docs] @beta_feature class NormalizedAttention(Attention): """ An nGPT attention implementation. """ def __init__( self, *, d_model: int, n_heads: int, n_kv_heads: Optional[int] = None, rope: Optional[RoPEConfig] = None, qk_norm: Optional[LayerNormConfig] = None, use_flash: Optional[bool] = None, backend: Optional[AttentionBackendName] = None, dtype: torch.dtype = torch.float32, init_device: str = "cpu", cache: Optional[BufferCache] = None, ): super().__init__( d_model=d_model, n_heads=n_heads, n_kv_heads=n_kv_heads, rope=rope, qk_norm=qk_norm, use_flash=use_flash, backend=backend, softmax_scale=math.sqrt(d_model // n_heads), bias=False, dtype=dtype, init_device=init_device, cache=cache, ) self.sq_init_value = 1.0 self.sq_init_scaling = 1.0 / math.sqrt(d_model) self.sq = nn.Parameter( torch.empty(self.head_dim * self.n_heads, dtype=dtype, device=init_device) ) self.sk_init_value = 1.0 self.sk_init_scaling = 1.0 / math.sqrt(d_model) self.sk = nn.Parameter( torch.empty(self.head_dim * self.n_kv_heads, dtype=dtype, device=init_device) ) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.sq) nn.init.ones_(self.sk) with torch.no_grad(): self.sq.mul_(self.sq_init_scaling) self.sk.mul_(self.sk_init_scaling)
[docs] def forward( self, x: torch.Tensor, cu_doc_lens: Optional[torch.Tensor] = None, cu_doc_lens_q: Optional[torch.Tensor] = None, cu_doc_lens_k: Optional[torch.Tensor] = None, max_doc_len: Optional[int] = None, max_doc_len_q: Optional[int] = None, max_doc_len_k: Optional[int] = None, local_k_slice: Optional[slice] = None, pos_sin: Optional[torch.Tensor] = None, pos_cos: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, ) -> torch.Tensor: if cache_leftpad: raise NotImplementedError( "cache_leftpad is not supported for the normalized attention variant" ) B, T, _ = x.shape # shape: (batch_size, seq_len, n_heads * head_dim), # (batch_size, seq_len, n_kv_heads * head_dim), # (batch_size, seq_len, n_kv_heads * head_dim) q, k, v = self.w_q(x), self.w_k(x), self.w_v(x) if self.q_norm is not None and self.k_norm is not None: q = self.q_norm(q) k = self.k_norm(k) sq = (self.sq * (self.sq_init_value / self.sq_init_scaling)).view(1, 1, -1) q = sq * q sk = (self.sk * (self.sk_init_value / self.sk_init_scaling)).view(1, 1, -1) k = sk * k # shape: (batch_size, seq_len, n_heads, head_dim) q = q.view(B, T, self.n_heads, self.head_dim) # shape: (batch_size, seq_len, n_kv_heads, head_dim) k = k.view(B, T, self.n_kv_heads, self.head_dim) # shape: (batch_size, seq_len, n_kv_heads, head_dim) v = v.view(B, T, self.n_kv_heads, self.head_dim) if self.rope is not None: if self.cp_enabled and pos_sin is None and pos_cos is None and freqs_cis is None: raise RuntimeError( "RoPE buffers must be passed through to attention after being properly " "sharded by the context parallel load balancer" ) start_pos = self.kv_cache_manager.current_position() if self.kv_cache_manager else None q, k = self._apply_rope(q, k, start_pos, pos_sin, pos_cos, freqs_cis, cu_doc_lens) # shape: (batch_size, seq_len, n_heads, head_dim) att = self.sdpa( q, k, v, cu_doc_lens=cu_doc_lens, cu_doc_lens_q=cu_doc_lens_q, cu_doc_lens_k=cu_doc_lens_k, max_doc_len=max_doc_len, max_doc_len_q=max_doc_len_q, max_doc_len_k=max_doc_len_k, local_k_slice=local_k_slice, cache_leftpad=cache_leftpad, ) # shape: (batch_size, seq_len, d_model) att = att.view(B, T, -1) # shape: (batch_size, seq_len, d_model) return self.w_out(att)
def apply_tp( self, tp_mesh: DeviceMesh, input_layout: Optional[Placement] = None, output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled raise NotImplementedError("TP is not implemented yet for the normalized attention variant")
[docs] @torch.no_grad() def normalize_matrices(self): """ Normalize the weights in all matrices. This should be called after each optimizer step, which the :class:`~olmo_core.train.train_module.TransformerTrainModule` will handle for you. """ self._normalize_matrix(self.w_q.weight) self._normalize_matrix(self.w_k.weight) self._normalize_matrix(self.w_v.weight) self._normalize_matrix(self.w_out.weight, dim=0)
def _normalize_matrix(self, w: torch.Tensor, dim: int = -1): w.copy_(l2_normalize(w, dim=dim))
[docs] class FusedAttention(SequenceMixer): """ An "fused" implementation of multi-head self-attention. Intra-document masking is supported by passing in the ``max_doc_len`` and ``cu_doc_lens`` parameters to :meth:`forward()`. .. warning:: Currently this is only supported with the "flash_2" backend. .. warning:: If using RoPE, this requires that you use the "fused" RoPE implementation (:class:`~olmo_core.nn.rope.FusedRotaryEmbedding`). :param d_model: The model hidden size. :param n_heads: The number of attention heads. :param bias: Include biases with linear layers. :param rope: The config for RoPE, if RoPE should be used. :param clip_qkv: Clip QKV to this value, if set. :param dropout: Dropout probability. :param dtype: The default data type to use for parameters. :param init_device: The device to initialize weights on. """ def __init__( self, *, d_model: int, n_heads: int, bias: bool = True, rope: Optional[RoPEConfig] = None, clip_qkv: Optional[float] = None, dropout: float = 0.0, dtype: torch.dtype = torch.float32, backend: Optional[AttentionBackendName] = None, init_device: str = "cpu", cache: Optional[BufferCache] = None, ): super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads self.w_qkv = nn.Linear(d_model, 3 * d_model, bias=bias, dtype=dtype, device=init_device) self.w_out = nn.Linear(d_model, d_model, bias=bias, dtype=dtype, device=init_device) self.clip_qkv = clip_qkv self.rope: Optional[FusedRotaryEmbedding] = None if rope is not None: if rope.name != "fused": raise OLMoConfigurationError(f"{self.__class__.__name__} requires fused RoPE") rope_class = rope.build(self.head_dim, cache=cache) assert isinstance(rope_class, FusedRotaryEmbedding) self.rope = rope_class if backend is not None: backend = AttentionBackendName(backend) elif backend is None: backend = AttentionBackendName.flash_2 backend.assert_supported() backend.assert_supports_packed_qkv() log.info(f"Using attention backend '{backend}'") self.backend = backend.build( head_dim=self.head_dim, n_heads=self.n_heads, dropout_p=dropout, cache=cache ) @property def cp_enabled(self) -> bool: return self.backend.cp_enabled
[docs] def forward( self, x: torch.Tensor, max_doc_len: Optional[int] = None, cu_doc_lens: Optional[torch.Tensor] = None, pos_sin: Optional[torch.Tensor] = None, pos_cos: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Apply attention to the input. :param x: The input of shape ``(batch_size, seq_len, d_model)``. :param max_doc_len: The maximum document length in the input ``x``. Required together with ``cu_doc_lens`` when using intra-document masking. :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D :class:`torch.int32` tensor that should always have one more element than there are documents (the first element in the tensor should always be ``0``). Required together with ``max_doc_len`` when using intra-document masking. :returns: The output of attention with shape ``(batch_size, seq_len, d_model)``. """ if cache_leftpad: raise NotImplementedError( "cache_leftpad is not supported for the fused attention variant" ) if cu_doc_lens is not None and self.rope is not None: raise NotImplementedError( "Intra-document RoPE (cu_doc_lens) is not yet supported by FusedAttention" ) B, T, _ = x.shape # shape: (batch_size, seq_len, 3, n_heads, head_dim) qkv = self.w_qkv(x).view(B, T, 3, self.n_heads, self.head_dim) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) if self.rope is not None: if self.cp_enabled and pos_sin is None and pos_cos is None and freqs_cis is None: raise RuntimeError( "RoPE buffers must be passed through to attention after being properly " "sharded by the context parallel load balancer" ) qkv = self.rope(qkv, pos_sin=pos_sin, pos_cos=pos_cos, freqs_cis=freqs_cis) att = self.backend( qkv, cu_doc_lens=cu_doc_lens, max_doc_len=max_doc_len, ) # shape: (batch_size, seq_len, d_model) att = att.view(B, T, -1) # type: ignore # shape: (batch_size, seq_len, d_model) return self.w_out(att)
def apply_tp( self, tp_mesh: DeviceMesh, input_layout: Optional[Placement] = None, output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled raise NotImplementedError("TP is not implemented yet for the fused attention variant") def apply_cp( self, cp_mesh: DeviceMesh, ring: Optional[RingContextParallelStyle] = None, uly: Optional[UlyssesContextParallelStyle] = None, ): self.backend.apply_cp(cp_mesh, ring=ring, uly=uly) def init_weights( self, *, init_method: "InitMethod", d_model: int, block_idx: int, num_blocks: int, std: float = 0.02, generator: Optional[torch.Generator] = None, ) -> None: from olmo_core.nn.transformer.init import InitMethod, init_linear # Compute std for fused QKV initialization if init_method == InitMethod.fan_in: std = self.w_qkv.in_features**-0.5 elif init_method == InitMethod.normalized: std = d_model**-0.5 init_linear(self.w_qkv, std=std, generator=generator) # Compute std for w_out initialization if init_method == InitMethod.fan_in: std = self.w_out.in_features**-0.5 elif init_method == InitMethod.llama: std = std / (2 * num_blocks) ** 0.5 elif init_method == InitMethod.llama_depth: std = std / (2 * (block_idx + 1)) ** 0.5 elif init_method == InitMethod.normalized: std = std / (2 * num_blocks) ** 0.5 init_linear(self.w_out, std=std, generator=generator) def num_flops_per_token(self, seq_len: int) -> int: # 6 FLOPs per parameter (2 ops * 3 for forward+backward) param_flops = 6 * sum(p.numel() for p in self.parameters()) # Attention computation (QK^T and Attn*V) # 12x multiplier: 2 matmuls * 2 ops each * 3 for forward+backward attn_flops = 12 * self.n_heads * self.head_dim * seq_len return param_flops + attn_flops