import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Placement
from torch.nn import functional as F
from olmo_core.config import DType
from olmo_core.distributed.parallel.context_parallel import (
all_to_all_cp2hp,
all_to_all_single_cp2hp,
all_to_all_single_hp2cp,
)
from olmo_core.nn.attention.base import SequenceMixer, SequenceMixerConfig
from olmo_core.nn.attention.flash_linear_attn_api import (
dispatch_chunk_gated_delta_rule,
has_fla,
)
from olmo_core.nn.attention.ring import (
RingContextParallelStyle,
UlyssesContextParallelStyle,
)
from olmo_core.nn.buffer_cache import BufferCache
from olmo_core.nn.convolution import CausalConv1d
from olmo_core.nn.feed_forward import ActivationFunction
if TYPE_CHECKING:
from olmo_core.nn.transformer.init import InitMethod
class GatedDeltaNet(SequenceMixer):
    """
    The layer implementation for `Gated Delta Networks <https://arxiv.org/abs/2412.06464>`_.

    Modified from: https://github.com/fla-org/flash-linear-attention/blob/3cf180339b8a1cbad823f553541cd531d18670ea/fla/layers/gated_deltanet.py#L34

    This is a linear attention variant that uses a gated delta rule for recurrent
    state updates, providing efficient O(n) sequence modeling.

    :param d_model: The model hidden size.
    :param n_heads: The number of query/key heads.
    :param n_v_heads: The number of value heads. If ``None``, defaults to ``n_heads``.
        Must be a multiple of ``n_heads``; GVA is applied if ``n_v_heads`` > ``n_heads``
        (query/key heads are repeated to match the value heads).
    :param head_dim: The query/key dimension of each head. If ``None``, defaults to
        ``d_model // n_heads``.
    :param expand_v: The expansion ratio for the value dim
        (``head_v_dim = head_dim * expand_v``, which must be an integer). Default: 2.0.
    :param allow_neg_eigval: Allow negative eigenvalues. Default: ``True``. If set to ``True``,
        beta is multiplied by 2. See reference: `Unlocking State-Tracking in Linear RNNs
        Through Negative Eigenvalues <https://arxiv.org/abs/2411.12537>`_.
    :param conv_size: The kernel size of the short convolution. Default: 4.
    :param conv_bias: Whether to use bias in the short convolution. Default: ``False``.
    :param norm_eps: The epsilon value for the normalization layer. Default: 1e-5.
    :param dtype: The default data type to use for parameters.
    :param init_device: The device to initialize weights on.
    """

    def __init__(
        self,
        *,
        d_model: int,
        n_heads: int,
        n_v_heads: int | None = None,
        head_dim: int | None = None,
        expand_v: float = 2.0,
        allow_neg_eigval: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        norm_eps: float = 1e-5,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
    ):
        super().__init__()
        assert has_fla(), "GatedDeltaNet requires the 'fla' (flash-linear-attention) package"
        from fla.modules import FusedRMSNormGated

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_v_heads = n_v_heads if n_v_heads is not None else n_heads
        self.head_dim = head_dim if head_dim is not None else d_model // n_heads
        self.expand_v = expand_v
        self.allow_neg_eigval = allow_neg_eigval
        self.conv_size = conv_size

        self.head_k_dim = self.head_dim
        self.head_v_dim = int(self.head_dim * self.expand_v)
        self.key_dim = int(self.n_heads * self.head_k_dim)
        self.value_dim = int(self.n_v_heads * self.head_v_dim)

        # Consistency checks: ensure expand_v produces integer dimensions (int() above truncates).
        assert math.isclose(
            self.n_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5
        ), "n_v_heads * head_dim * expand_v must be an integer"
        assert math.isclose(
            self.head_dim * expand_v, self.head_v_dim, rel_tol=1e-5
        ), "head_dim * expand_v must be an integer"
        assert (
            self.n_v_heads >= self.n_heads and self.n_v_heads % self.n_heads == 0
        ), "n_v_heads must be a multiple of n_heads that is >= n_heads"

        self.w_q = nn.Linear(d_model, self.key_dim, bias=False, dtype=dtype, device=init_device)
        self.w_k = nn.Linear(d_model, self.key_dim, bias=False, dtype=dtype, device=init_device)
        self.w_v = nn.Linear(d_model, self.value_dim, bias=False, dtype=dtype, device=init_device)
        # Per-value-head projections for the decay gate (w_a) and the update strength beta (w_b).
        self.w_a = nn.Linear(d_model, self.n_v_heads, bias=False, dtype=dtype, device=init_device)
        self.w_b = nn.Linear(d_model, self.n_v_heads, bias=False, dtype=dtype, device=init_device)
        # Filled in by init_weights().
        self.A_log = nn.Parameter(torch.empty(self.n_v_heads, dtype=dtype, device=init_device))
        self.dt_bias = nn.Parameter(torch.empty(self.n_v_heads, dtype=dtype, device=init_device))

        self.q_conv1d = CausalConv1d(
            hidden_size=self.key_dim,
            kernel_size=conv_size,
            bias=conv_bias,
            activation=ActivationFunction.silu.value,
            dtype=dtype,
            init_device=init_device,
        )
        self.k_conv1d = CausalConv1d(
            hidden_size=self.key_dim,
            kernel_size=conv_size,
            bias=conv_bias,
            activation=ActivationFunction.silu.value,
            dtype=dtype,
            init_device=init_device,
        )
        self.v_conv1d = CausalConv1d(
            hidden_size=self.value_dim,
            kernel_size=conv_size,
            bias=conv_bias,
            activation=ActivationFunction.silu.value,
            dtype=dtype,
            init_device=init_device,
        )

        # Output gate projection consumed by the gated RMS norm.
        self.w_g = nn.Linear(d_model, self.value_dim, bias=False, dtype=dtype, device=init_device)
        self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps, device=init_device)  # type: ignore
        self.w_out = nn.Linear(self.value_dim, d_model, bias=False, dtype=dtype, device=init_device)

        # Context-parallel state, populated by apply_cp(). Initialized here so the module always
        # has a complete, well-defined set of attributes even when CP is never applied.
        self.cp_enabled = False
        self.uly: Optional[UlyssesContextParallelStyle] = None
        self._cp_mesh: Optional[DeviceMesh] = None
        self._cp_group = None

    def forward(
        self,
        x: torch.Tensor,
        cu_doc_lens: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply gated delta network sequence mixing to the input.

        :param x: The input of shape ``(batch_size, seq_len, d_model)``.
        :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D
            :class:`torch.int32` tensor that should always have one more element than there
            are documents (the first element in the tensor should always be ``0``).
        :param kwargs: Extra keyword arguments from the attention interface; ignored.

        :returns: The output with shape ``(batch_size, seq_len, d_model)``.
        """
        del kwargs  # Ignore any extra kwargs passed from attention interface
        # T_og is the sequence length held by this rank (the local shard when CP is enabled).
        B, T_og, _ = x.shape
        use_ulysses = self.cp_enabled and self.uly is not None

        # shape: (batch_size, seq_len, n_heads * head_k_dim),
        #        (batch_size, seq_len, n_heads * head_k_dim),
        #        (batch_size, seq_len, n_v_heads * head_v_dim)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)

        beta = self.w_b(x).sigmoid()
        if self.allow_neg_eigval:
            # Widen beta's range from (0, 1) to (0, 2).
            beta = beta * 2.0
        # Per-value-head decay gate; strictly negative (-exp(.) * softplus(.)), computed in fp32.
        g = -self.A_log.float().exp() * F.softplus(self.w_a(x).float() + self.dt_bias)

        if use_ulysses:
            assert self._cp_group is not None
            # [B, T_local, C] -> [B, T_total, C/CP]
            q, k = all_to_all_cp2hp([q, k], self._cp_group)
            v = all_to_all_single_cp2hp(v, self._cp_group)
            g, beta = all_to_all_cp2hp([g, beta], self._cp_group)

        q = self.q_conv1d(x=q, cu_seqlens=cu_doc_lens)
        k = self.k_conv1d(x=k, cu_seqlens=cu_doc_lens)
        v = self.v_conv1d(x=v, cu_seqlens=cu_doc_lens)

        # T is the full sequence length under Ulysses CP, otherwise equal to T_og.
        T = q.size(1)
        q = q.view(B, T, -1, self.head_k_dim)
        k = k.view(B, T, -1, self.head_k_dim)
        v = v.view(B, T, -1, self.head_v_dim)

        if self.n_v_heads > self.n_heads:
            # GVA: repeat each query/key head so it is shared by a group of value heads.
            repeat_factor = self.n_v_heads // self.n_heads
            q = q.repeat_interleave(repeat_factor, dim=-2)
            k = k.repeat_interleave(repeat_factor, dim=-2)

        o, _ = dispatch_chunk_gated_delta_rule(
            q=q, k=k, v=v, g=g, beta=beta, cu_seqlens=cu_doc_lens, use_qk_l2norm_in_kernel=True
        )

        if use_ulysses:
            assert self._cp_group is not None
            # [B, T, H/CP, D] -> [B, T/CP, H, D]
            o = all_to_all_single_hp2cp(o, self._cp_group)

        # The output gate is projected from the *local* input x, so it must be viewed with the
        # local length T_og. Under Ulysses CP, T is the gathered full length; viewing with T
        # would still succeed (same numel) but silently misalign the gate with `o`.
        gate = self.w_g(x).view(B, T_og, -1, self.head_v_dim)

        # shape: (batch_size, seq_len, d_model)
        return self.w_out(self.o_norm(o, gate).view(B, T_og, -1))

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
        float8_enabled: bool = False,
    ):
        """
        Tensor parallelism is not supported for this layer.

        :raises NotImplementedError: Always.
        """
        del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled
        raise NotImplementedError("Tensor parallelism is not yet implemented for GatedDeltaNet")

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        """
        Enable Ulysses-style context parallelism over ``cp_mesh``.

        This is a no-op when the mesh has a single rank.

        :param cp_mesh: The context-parallel device mesh.
        :param ring: Must be ``None``; ring context parallelism is not supported.
        :param uly: The Ulysses context-parallel style to use (required).

        :raises NotImplementedError: If ``ring`` is given.
        """
        if ring is not None:
            raise NotImplementedError("Ring context parallelism is not supported for GatedDeltaNet")
        assert uly is not None, "GatedDeltaNet only supports Ulysses context parallelism"
        cp_world_size = cp_mesh.size()
        if cp_world_size == 1:
            return

        # Ulysses CP requires divisibility by CP world size for:
        # 1. n_heads and n_v_heads - each rank's channel shard must hold whole q/k and v heads,
        #    since q/k/v are viewed per head (and q/k repeated onto v heads for GVA) locally
        # 2. key_dim and value_dim - for channel partitioning in the conv layers
        assert (
            self.n_heads % cp_world_size == 0
        ), f"n_heads ({self.n_heads}) must be divisible by the CP world size ({cp_world_size})"
        assert (
            self.n_v_heads % cp_world_size == 0
        ), f"n_v_heads ({self.n_v_heads}) must be divisible by the CP world size ({cp_world_size})"
        assert self.key_dim % cp_world_size == 0
        assert self.value_dim % cp_world_size == 0

        self.uly = uly
        self._cp_mesh = cp_mesh
        self._cp_group = cp_mesh.get_group()
        self.cp_enabled = True

        self.q_conv1d.apply_cp(cp_mesh)
        self.k_conv1d.apply_cp(cp_mesh)
        self.v_conv1d.apply_cp(cp_mesh)

    @torch.no_grad()
    def init_weights(
        self,
        *,
        init_method: "InitMethod",
        d_model: int,
        block_idx: int,
        num_blocks: int,
        std: float = 0.02,
        generator: Optional[torch.Generator] = None,
    ) -> None:
        """
        Initialize the layer's parameters in place.

        :param init_method: The initialization scheme. ``InitMethod.fan_in`` is not supported.
        :param d_model: The model hidden size (used by ``InitMethod.normalized``).
        :param block_idx: The index of this block (used by ``InitMethod.llama_depth``).
        :param num_blocks: The total number of blocks (used to scale the output projection).
        :param std: The base standard deviation for the projection and conv weights.
        :param generator: Optional RNG for reproducible initialization.

        :raises NotImplementedError: If ``init_method`` is ``InitMethod.fan_in``.
        """
        from olmo_core.nn.transformer.init import InitMethod, init_linear

        if init_method == InitMethod.fan_in:
            raise NotImplementedError(
                f"init method '{init_method}' is not supported for GatedDeltaNet"
            )

        if init_method == InitMethod.normalized:
            std = d_model**-0.5

        for w in (self.w_q, self.w_k, self.w_v, self.w_a, self.w_b, self.w_g):
            init_linear(w, std=std, generator=generator)
        for w in (self.q_conv1d, self.k_conv1d, self.v_conv1d):
            init_linear(w, std=std, generator=generator)

        # A ~ U[0, 16), stored in log space; forward() uses -exp(A_log).
        self.A_log.copy_(nn.init.uniform_(self.A_log, a=0, b=16, generator=generator).log())

        # dt is log-uniform in [dt_min, dt_max) (floored at dt_init_floor). dt_bias stores
        # softplus^{-1}(dt), so softplus(dt_bias) == dt at initialization.
        dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4
        dt = torch.exp(
            nn.init.uniform_(self.dt_bias, generator=generator)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min),
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias.copy_(inv_dt)

        # The output projection's std is additionally scaled by block count/index.
        if init_method == InitMethod.llama:
            std = std / (2 * num_blocks) ** 0.5
        elif init_method == InitMethod.llama_depth:
            std = std / (2 * (block_idx + 1)) ** 0.5
        elif init_method == InitMethod.normalized:
            std = std / (2 * num_blocks) ** 0.5
        init_linear(self.w_out, std=std, generator=generator)
        # NOTE(review): o_norm's weight is not (re)initialized here. If the module is built on
        # the "meta" device, confirm that something else resets it before training.

    def num_flops_per_token(self, seq_len: int) -> int:
        """
        Compute FLOPs per token for Gated Delta Net.

        This accounts for:

        - Linear projections (w_q, w_k, w_v, w_a, w_b, w_g, w_out)
        - Short convolutions (q, k, v)
        - Gated delta rule recurrent computation

        The estimate is independent of ``seq_len``, which is accepted only for interface
        compatibility.
        """
        del seq_len

        # Linear projection FLOPs (2 ops per multiply-add)
        linear_flops = 2 * sum(
            m.weight.numel()
            for m in (self.w_q, self.w_k, self.w_v, self.w_a, self.w_b, self.w_g, self.w_out)
        )

        # Short convolution FLOPs (2 ops per multiply-add, kernel_size taps per output channel)
        conv_flops = (
            2
            * self.conv_size
            * (self.key_dim + self.key_dim + self.value_dim)  # q_conv1d + k_conv1d + v_conv1d
        )

        # Gated delta rule recurrent computation per token:
        # - Outer product k ⊗ v: n_v_heads * head_k_dim * head_v_dim
        # - State decay: n_v_heads * head_k_dim * head_v_dim
        # - Beta scaling: n_v_heads * head_k_dim * head_v_dim
        # - Query-state matmul: n_v_heads * head_k_dim * head_v_dim
        # Each is 2 FLOPs per element (multiply-add or similar)
        state_size = self.n_v_heads * self.head_k_dim * self.head_v_dim
        recurrent_flops = 2 * 4 * state_size

        return int(linear_flops + conv_flops + recurrent_flops)
@SequenceMixerConfig.register("gated_delta_net")
@dataclass
class GatedDeltaNetConfig(SequenceMixerConfig[GatedDeltaNet]):
    """
    Configuration for :class:`GatedDeltaNet`.

    See :class:`GatedDeltaNet` for a description of the configuration options.
    """

    n_heads: int = 16
    """
    The number of query/key heads.
    """

    n_v_heads: Optional[int] = None
    """
    The number of value heads. If ``None``, defaults to ``n_heads``.

    If ``n_v_heads`` > ``n_heads``, GVA (Grouped Value Attention) is applied.
    GVA is preferred over GQA for linear RNNs like GDN because the recurrent state
    has shape ``(n_v_heads, head_k_dim, head_v_dim)``. Unlike softmax attention, where
    the KV cache grows with sequence length (motivating GQA to reduce it), the linear
    RNN state has a constant size regardless of sequence length. Since there is no memory
    scaling issue to solve, we can instead increase the state size to improve the model's
    capacity to compress long-range context. Increasing ``n_v_heads`` directly
    increases this fixed state size.
    """

    head_dim: Optional[int] = None
    """
    The query/key dimension of each head. If ``None``, defaults to ``d_model // n_heads``.
    """

    expand_v: float = 2.0
    """
    The expansion ratio for the value dimension (``head_v_dim = head_dim * expand_v``).
    Like ``n_v_heads``, this increases the constant-size recurrent state, improving
    capacity without memory scaling concerns.
    """

    allow_neg_eigval: bool = True
    """
    Allow negative eigenvalues in the recurrent dynamics (beta is scaled by 2).
    """

    conv_size: int = 4
    """
    The kernel size of the short convolution.
    """

    conv_bias: bool = False
    """
    Whether to use bias in the short convolution.
    """

    norm_eps: float = 1e-5
    """
    The epsilon value for the normalization layer.
    """

    dtype: DType = DType.float32
    """
    The default data type to use for parameters.
    """

    def num_params(self, d_model: int) -> int:
        """
        The number of params that the GatedDeltaNet will have once built.

        :param d_model: The model dimensionality.
        """
        n_heads = self.n_heads
        # Resolve defaults with ``is None`` checks exactly like GatedDeltaNet.__init__ does
        # (not truthiness), so the count always matches the built module.
        n_v_heads = self.n_v_heads if self.n_v_heads is not None else n_heads
        head_dim = self.head_dim if self.head_dim is not None else d_model // n_heads
        head_v_dim = int(head_dim * self.expand_v)
        key_dim = n_heads * head_dim
        value_dim = n_v_heads * head_v_dim

        params = 0

        # Linear projections: w_q, w_k, w_v, w_a, w_b, w_g, w_out
        params += d_model * key_dim  # w_q
        params += d_model * key_dim  # w_k
        params += d_model * value_dim  # w_v
        params += d_model * n_v_heads  # w_a
        params += d_model * n_v_heads  # w_b
        params += d_model * value_dim  # w_g
        params += value_dim * d_model  # w_out

        # A_log and dt_bias parameters
        params += n_v_heads  # A_log
        params += n_v_heads  # dt_bias

        # Short convolutions (kernel_size * hidden_size for each)
        params += self.conv_size * key_dim  # q_conv1d
        params += self.conv_size * key_dim  # k_conv1d
        params += self.conv_size * value_dim  # v_conv1d
        if self.conv_bias:
            params += key_dim  # q_conv1d bias
            params += key_dim  # k_conv1d bias
            params += value_dim  # v_conv1d bias

        # FusedRMSNormGated (weight only, no bias)
        params += head_v_dim  # o_norm

        return params

    def build(
        self,
        d_model: int,
        *,
        layer_idx: int,
        n_layers: int,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ) -> GatedDeltaNet:
        """
        Build the GatedDeltaNet module.

        :param d_model: The model dimensionality.
        :param layer_idx: The layer index (unused).
        :param n_layers: The total number of layers (unused).
        :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
        :param cache: Optional buffer cache (unused).
        """
        del layer_idx, n_layers, cache  # Unused
        return GatedDeltaNet(
            d_model=d_model,
            n_heads=self.n_heads,
            n_v_heads=self.n_v_heads,
            head_dim=self.head_dim,
            expand_v=self.expand_v,
            allow_neg_eigval=self.allow_neg_eigval,
            conv_size=self.conv_size,
            conv_bias=self.conv_bias,
            norm_eps=self.norm_eps,
            dtype=self.dtype.as_pt(),
            init_device=init_device,
        )