Source code for olmo_core.nn.attention.backend

from abc import abstractmethod
from typing import Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import DeviceMesh

from olmo_core.config import StrEnum
from olmo_core.distributed.parallel.context_parallel import (
    all_to_all_cp2hp,
    all_to_all_single_cp2hp,
    all_to_all_single_cp2hp_qkvpacked,
    all_to_all_single_hp2cp,
)
from olmo_core.nn.attention.kv_cache import KVCacheManager
from olmo_core.nn.buffer_cache import BufferCache

from .flash_attn_api import (
    dispatch_flash_attn,
    dispatch_flash_attn_3,
    dispatch_flash_attn_3_qkvpacked,
    dispatch_flash_attn_3_with_kvcache,
    dispatch_flash_attn_4,
    dispatch_flash_attn_qkvpacked,
    dispatch_flash_attn_with_kvcache,
    dispatch_ring_flash_attn,
    dispatch_ring_flash_attn_qkvpacked,
    has_flash_attn_2,
    has_flash_attn_3,
    has_flash_attn_4,
    has_ring_flash_attn,
)
from .ring import (
    RingAttentionLoadBalancerType,
    RingContextParallelStyle,
    UlyssesContextParallelStyle,
)
from .te_attn_api import TEDotProductAttention, has_te_attn


class AttentionBackendName(StrEnum):
    """
    An enumeration of the different attention backends.
    """

    torch = "torch"
    """
    PyTorch's built-in SDPA. Works on all devices. ➡️ :class:`TorchAttentionBackend`
    """

    flash_2 = "flash_2"
    """
    Flash attention 2 from the `flash-attn <https://github.com/Dao-AILab/flash-attention>`_ library.
    Supports Ampere (SM 8.0+) and newer NVIDIA GPUs.
    To use this with context-parallelism, `ring-flash-attn <https://github.com/zhuzilin/ring-flash-attention>`_
    is also required. ➡️ :class:`FlashAttention2Backend`
    """

    flash_3 = "flash_3"
    """
    Flash attention 3 (beta) from the `flash-attn <https://github.com/Dao-AILab/flash-attention>`_ library
    ``hopper/`` subdirectory. Supports Hopper (SM 9.0) GPUs only (H100/H800).
    ➡️ :class:`FlashAttention3Backend`
    """

    flash_4 = "flash_4"
    """
    Flash attention 4, the CUTE implementation from
    `flash-attn <https://github.com/Dao-AILab/flash-attention>`_ in the ``flash_attn/cute`` subdirectory.
    Supports Blackwell (SM 10.0, e.g. B200) GPUs only. ➡️ :class:`FlashAttention4Backend`
    """

    te = "te"
    """
    Transformer Engine attention. Supports Hopper (SM 9.0+) and newer NVIDIA GPUs.
    ➡️ :class:`TEAttentionBackend`.
    """

    def get_class(self) -> Type["AttentionBackend"]:
        """
        Return the :class:`AttentionBackend` subclass that implements this backend.

        :raises NotImplementedError: If the member has no associated backend class.
        """
        # NOTE: members are strings, so these must be equality checks. Using ``in``
        # here would silently turn into a substring test.
        if self == AttentionBackendName.torch:
            return TorchAttentionBackend
        elif self == AttentionBackendName.flash_2:
            return FlashAttention2Backend
        elif self == AttentionBackendName.flash_3:
            return FlashAttention3Backend
        elif self == AttentionBackendName.flash_4:
            return FlashAttention4Backend
        elif self == AttentionBackendName.te:
            return TEAttentionBackend
        else:
            raise NotImplementedError(self)

    def build(
        self,
        *,
        head_dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        scale: Optional[float] = None,
        dropout_p: float = 0.0,
        window_size: Tuple[int, int] = (-1, -1),
        cache: Optional[BufferCache] = None,
    ) -> "AttentionBackend":
        """
        Instantiate the backend class for this member.

        All arguments are forwarded unchanged to the backend's constructor, except ``cache``
        which (when given) is re-scoped with a per-backend namespace via
        ``cache.with_namespace(...)``.
        """
        # Compare against ``None`` explicitly rather than relying on truthiness: the cache
        # supports item assignment, so it is likely mapping-like, and an *empty* mapping
        # would be falsy and get silently discarded.
        namespaced_cache = (
            cache.with_namespace(f"attn_backend={self}") if cache is not None else None
        )
        return self.get_class()(
            head_dim=head_dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            scale=scale,
            dropout_p=dropout_p,
            window_size=window_size,
            cache=namespaced_cache,
        )

    def assert_supported(self):
        """Raise if the backend is not supported on the current system."""
        self.get_class().assert_supported()

    def assert_supports_swa(self):
        """Raise if the backend doesn't support sliding window attention."""
        self.get_class().assert_supports_swa()

    def assert_supports_ring_cp(self):
        """Raise if the backend doesn't support ring context parallelism."""
        self.get_class().assert_supports_ring_cp()

    def assert_supports_ulysses_cp(self):
        """Raise if the backend doesn't support ulysses context parallelism."""
        self.get_class().assert_supports_ulysses_cp()

    def assert_supports_packed_qkv(self):
        """Raise if the backend doesn't accept QKV as a single packed tensor."""
        self.get_class().assert_supports_packed_qkv()

    def assert_supports_kv_cache(self):
        """Raise if the backend doesn't support KV caching."""
        self.get_class().assert_supports_kv_cache()
class AttentionBackend(nn.Module):
    """
    Encapsulates a backend for the scaled dot-product attention (SDPA) operation.

    :param head_dim: Size of each attention head.
    :param n_heads: Number of query heads.
    :param n_kv_heads: Number of key/value heads. Defaults to ``n_heads``.
    :param scale: Optional softmax scale, stored as-is for the concrete backend to use.
    :param dropout_p: Attention dropout probability.
    :param window_size: ``(left, right)`` sliding window. ``(-1, -1)`` disables sliding
        window attention; any other value requires :meth:`assert_supports_swa` to pass.
    :param cache: Optional buffer cache made available to the concrete backend.
    """

    def __init__(
        self,
        *,
        head_dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        scale: Optional[float] = None,
        dropout_p: float = 0.0,
        window_size: Tuple[int, int] = (-1, -1),
        cache: Optional[BufferCache] = None,
    ):
        # Validate before doing any module setup. These are classmethods, so calling them
        # prior to ``nn.Module.__init__`` is safe.
        self.assert_supported()
        if window_size != (-1, -1):
            self.assert_supports_swa()

        super().__init__()
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.scale = scale
        self.dropout_p = dropout_p
        self.window_size = window_size
        self.cache = cache

        # Context-parallel state. Populated by ``apply_cp()``. ``ring`` and ``uly`` are
        # defined here (rather than only in ``apply_cp()``) so that reading them before
        # context parallelism is configured yields ``None`` instead of an ``AttributeError``.
        self.cp_pg: Optional[dist.ProcessGroup] = None
        self.cp_enabled = False
        self.ring: Optional[RingContextParallelStyle] = None
        self.uly: Optional[UlyssesContextParallelStyle] = None
        self.head_stride: int = 1

    @classmethod
    @abstractmethod
    def assert_supported(cls):
        """
        Validates that this backend is supported on the current system.
        Raises an error if not supported.
        """
        pass

    @classmethod
    @abstractmethod
    def assert_supports_swa(cls):
        """
        Validates that this backend supports sliding window attention (SWA).
        Raises an error if not supported.
        """
        pass

    @classmethod
    @abstractmethod
    def assert_supports_ring_cp(cls):
        """
        Validates that this backend supports ring context parallelism.
        Raises an error if not supported.
        """
        pass

    @classmethod
    @abstractmethod
    def assert_supports_ulysses_cp(cls):
        """
        Validates that this backend supports ulysses context parallelism.
        Raises an error if not supported.
        """
        pass

    @classmethod
    @abstractmethod
    def assert_supports_packed_qkv(cls):
        """
        Validates that this backend supports taking QKV as a single packed tensor.
        Raises an error if not supported.
        """
        pass

    @classmethod
    @abstractmethod
    def assert_supports_kv_cache(cls):
        """
        Validates that this backend supports KV caching.
        Raises an error if not supported.
        """
        pass

    def warmup_cache(self, max_seq_len: int, device: torch.device):
        """
        Warmup the buffer cache.

        The base implementation is a no-op; backends that cache buffers override it.
        """
        del max_seq_len, device

    @abstractmethod
    def forward(
        self,
        qkv: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        kv_cache_manager: Optional[KVCacheManager] = None,
    ) -> torch.Tensor:
        """
        Run the attention operation.

        :param qkv: Either a single packed QKV tensor or a ``(q, k, v)`` tuple. Which forms
            are accepted depends on the concrete backend.
        :param cu_doc_lens: Optional document-length tensor for intra-document masking.
        :param cu_doc_lens_q: Query-side variant of ``cu_doc_lens``.
        :param cu_doc_lens_k: Key-side variant of ``cu_doc_lens``.
        :param max_doc_len: Optional maximum document length.
        :param max_doc_len_q: Query-side variant of ``max_doc_len``.
        :param max_doc_len_k: Key-side variant of ``max_doc_len``.
        :param local_k_slice: Optional key slice; only some backends consume it.
        :param kv_cache_manager: Optional KV cache manager; only some backends support it.
        """
        raise NotImplementedError

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        """
        Apply context parallelism if supported by the backend.

        :param cp_mesh: Device mesh whose process group is used for context parallelism.
        :param ring: Ring context-parallel configuration.
        :param uly: Ulysses context-parallel configuration.

        :raises ValueError: If neither ``ring`` nor ``uly`` is given.
        """
        # NOTE: if both are given, only ring support is validated (ring takes precedence).
        if ring is not None:
            self.assert_supports_ring_cp()
        elif uly is not None:
            self.assert_supports_ulysses_cp()
        else:
            raise ValueError("One of ring or uly must be specified")

        self.cp_pg = cp_mesh.get_group()
        self.ring = ring
        self.uly = uly
        self.cp_enabled = True
class TorchAttentionBackend(AttentionBackend):
    """
    PyTorch's built-in scaled dot-product attention (SDPA) backend.

    Supports sliding window attention and ulysses context parallelism. Does not support
    ring context parallelism, packed QKV, KV caching, or intra-document masking.
    """

    @classmethod
    def assert_supported(cls):
        pass

    @classmethod
    def assert_supports_swa(cls):
        pass

    @classmethod
    def assert_supports_ring_cp(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support ring context parallelism")

    @classmethod
    def assert_supports_ulysses_cp(cls):
        pass

    @classmethod
    def assert_supports_packed_qkv(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support packed QKV")

    @classmethod
    def assert_supports_kv_cache(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support KV caching")

    def warmup_cache(self, max_seq_len: int, device: torch.device):
        """
        Pre-build (and cache, if a cache is configured) the sliding window mask for
        ``max_seq_len``.
        """
        # Without a sliding window, ``forward()`` never asks for a mask (it uses SDPA's
        # ``is_causal`` flag), so building an O(max_seq_len^2) mask would be pure waste.
        if self.window_size == (-1, -1):
            return
        self._get_sliding_window_mask(
            seq_len_q=max_seq_len,
            seq_len_kv=max_seq_len,
            device=device,
            window_size=self.window_size,
        )

    def forward(
        self,
        qkv: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        kv_cache_manager: Optional[KVCacheManager] = None,
    ) -> torch.Tensor:
        """
        Run causal attention with ``F.scaled_dot_product_attention``.

        :param qkv: A ``(q, k, v)`` tuple; each tensor is laid out as
            ``(batch_size, seq_len, n_heads or n_kv_heads, head_dim)``.
        :param local_k_slice: Ignored.

        :returns: The attention output, ``(batch_size, seq_len, n_heads, head_dim)``, contiguous.

        :raises RuntimeError: If given packed QKV, a KV cache manager, or any of the
            document-length arguments (none of which this backend supports).
        """
        del local_k_slice

        if isinstance(qkv, torch.Tensor):
            raise RuntimeError(f"'{self.__class__.__name__}' doesn't support packed QKV")
        q, k, v = qkv

        if kv_cache_manager is not None:
            raise RuntimeError(f"'{self.__class__.__name__}' doesn't support KV caching")

        if any(
            opt is not None
            for opt in (
                cu_doc_lens,
                cu_doc_lens_q,
                cu_doc_lens_k,
                max_doc_len,
                max_doc_len_q,
                max_doc_len_k,
            )
        ):
            raise RuntimeError(
                f"'{self.__class__.__name__}' doesn't support intra-document masking"
            )

        use_ulysses = self.cp_enabled and self.uly is not None
        if use_ulysses:
            assert self.cp_pg is not None
            # Transform from context-parallel to head-parallel partitioning
            # [B, T/CP, H, D] -> [B, T, H/CP, D]
            q = all_to_all_single_cp2hp(q, self.cp_pg)
            k, v = all_to_all_cp2hp([k, v], self.cp_pg)

        # Build the mask only *after* the (optional) ulysses gather so that it is sized for
        # the sequence length SDPA actually sees (the full T, not the local T/CP shard).
        attn_mask: Optional[torch.Tensor] = None
        if self.window_size != (-1, -1):
            attn_mask = self._get_sliding_window_mask(
                seq_len_q=q.shape[1],
                seq_len_kv=k.shape[1],
                device=q.device,
                window_size=self.window_size,
            )

        # NOTE: PyTorch's SDPA doesn't support GQA, so we have to do this.
        n_rep = self.n_heads // self.n_kv_heads
        # shape: (batch_size, seq_len, n_heads, head_dim)
        k = _repeat_kv(k, n_rep)
        v = _repeat_kv(v, n_rep)

        # PyTorch's SDPA expects the head dimension to come before the sequence dimension.
        # shape: (batch_size, n_heads, seq_len, head_dim) for each of q, k, v
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # shape: (batch_size, n_heads, seq_len, head_dim)
        att = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_p,
            # The explicit mask already encodes causality; the two options are exclusive.
            is_causal=attn_mask is None,
            scale=self.scale,
        )

        # shape: (batch_size, seq_len, n_heads, head_dim)
        att = att.transpose(1, 2)

        if use_ulysses:
            assert self.cp_pg is not None
            # Transform back from head-parallel to context-parallel partitioning
            # [B, T, H/CP, D] -> [B, T/CP, H, D]
            att = all_to_all_single_hp2cp(att, self.cp_pg)

        return att.contiguous()

    def _get_sliding_window_mask(
        self,
        seq_len_q: int,
        seq_len_kv: int,
        device: torch.device,
        window_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Return the boolean attention mask for the given shape and window, using the buffer
        cache (when one is configured) to avoid rebuilding it.
        """
        if self.cache is None:
            return self._build_sliding_window_mask(
                seq_len_q=seq_len_q,
                seq_len_kv=seq_len_kv,
                device=device,
                window_size=window_size,
            )

        key = f"seq_len_q={seq_len_q},seq_len_kv={seq_len_kv},window_size={window_size}"
        if (mask := self.cache.get_for_device(key, device)) is not None:
            return mask

        mask = self._build_sliding_window_mask(
            seq_len_q=seq_len_q,
            seq_len_kv=seq_len_kv,
            device=device,
            window_size=window_size,
        )
        self.cache[key] = mask
        return mask

    @classmethod
    def _build_sliding_window_mask(
        cls,
        seq_len_q: int,
        seq_len_kv: int,
        device: torch.device,
        window_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Build a boolean ``(seq_len_q, seq_len_kv)`` mask (``True`` = may attend) combining
        causality with the ``(left, right)`` sliding window.

        A negative window value means "unbounded" on that side, so ``(-1, -1)`` yields the
        plain causal mask and e.g. ``(w, -1)`` only restricts how far back a query can look.
        """
        attn_mask = torch.tril(
            torch.ones(seq_len_q, seq_len_kv, device=device, dtype=torch.bool)
        )
        left, right = window_size
        # Only clip a side that is actually bounded. Feeding a negative value straight into
        # ``triu``/``tril`` would be interpreted as a diagonal offset and mask out the
        # diagonal itself (e.g. ``(-1, 0)`` would produce an all-False mask).
        if left >= 0:
            # Drop keys more than ``left`` positions behind the query.
            attn_mask = attn_mask.triu(diagonal=-left)
        if right >= 0:
            # Drop keys more than ``right`` positions ahead of the query.
            attn_mask = attn_mask.tril(diagonal=right)
        return attn_mask
class FlashAttention2Backend(AttentionBackend):
    """
    SDPA from the flash-attn package. Additionally, ring-flash-attn is required for
    ring context parallelism.
    """

    @classmethod
    def assert_supported(cls):
        if not has_flash_attn_2():
            raise RuntimeError(
                f"'{cls.__name__}' is missing the flash-attn package or is not supported on this platform."
            )

    @classmethod
    def assert_supports_swa(cls):
        pass

    @classmethod
    def assert_supports_ring_cp(cls):
        if not has_ring_flash_attn():
            raise RuntimeError(
                f"'{cls.__name__}' requires the ring-flash-attn package for context parallelism."
            )

    @classmethod
    def assert_supports_ulysses_cp(cls):
        pass

    @classmethod
    def assert_supports_packed_qkv(cls):
        pass

    @classmethod
    def assert_supports_kv_cache(cls):
        pass

    def forward(
        self,
        qkv: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        kv_cache_manager: Optional[KVCacheManager] = None,
    ) -> torch.Tensor:
        """
        Run causal attention through flash-attn 2.

        Dispatches on the input form and the parallelism configuration:

        - packed QKV tensor: plain, ring CP, or ulysses CP (no KV cache, no sliding window);
        - ``(q, k, v)`` tuple with a KV cache manager (no context parallelism);
        - ``(q, k, v)`` tuple: plain, ring CP, or ulysses CP.

        :raises RuntimeError: For unsupported combinations (packed QKV with KV caching or
            sliding window, KV caching with context parallelism), or if context parallelism
            is enabled but neither ring nor ulysses is configured.
        """
        if isinstance(qkv, torch.Tensor):
            # ---- Packed QKV ----
            if kv_cache_manager is not None:
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support packed QKV with KV caching"
                )
            if self.window_size != (-1, -1):
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support packed QKV with sliding window attention"
                )

            if self.cp_enabled:
                assert self.cp_pg is not None
                if self.ring is not None:
                    return dispatch_ring_flash_attn_qkvpacked(
                        qkv,
                        group=self.cp_pg,
                        strategy=self.ring.load_balancer,
                        cu_seqlens=cu_doc_lens,
                        max_seqlen=max_doc_len,
                        dropout_p=self.dropout_p,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.window_size,
                    )
                elif self.uly is not None:
                    # Transform packed qkv from context-parallel to head-parallel partitioning
                    # [B, T/CP, 3, H, D] -> [B, T, 3, H/CP, D]
                    qkv = all_to_all_single_cp2hp_qkvpacked(qkv, self.cp_pg)
                    B, T, _, H_local, D = qkv.shape

                    # NOTE: cu_doc_lens and max_doc_len are assumed to describe the FULL sequence
                    # (same on all CP ranks), so we use them directly after gathering the full sequence.

                    # Run attention with full sequence, partitioned heads
                    out = dispatch_flash_attn_qkvpacked(
                        qkv,
                        cu_seqlens=cu_doc_lens,
                        max_seqlen=max_doc_len,
                        dropout_p=self.dropout_p,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.window_size,
                    )

                    # Transform back from head-parallel to context-parallel partitioning
                    # [B, T, H/CP, D] -> [B, T/CP, H, D]
                    return all_to_all_single_hp2cp(out.view(B, T, H_local, D), self.cp_pg)
                else:
                    raise RuntimeError("One of ring or uly must be specified")
            else:
                return dispatch_flash_attn_qkvpacked(
                    qkv,
                    cu_seqlens=cu_doc_lens,
                    max_seqlen=max_doc_len,
                    dropout_p=self.dropout_p,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.window_size,
                )

        # ---- Unpacked (q, k, v) ----
        q, k, v = qkv

        # Explicit ``None`` check, consistent with the packed branch above: presence of the
        # manager should not depend on its truthiness.
        if kv_cache_manager is not None:
            if self.cp_enabled:
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support KV caching with context parallelism"
                )
            return dispatch_flash_attn_with_kvcache(
                q,
                k=k,
                v=v,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.window_size,
                k_cache=kv_cache_manager.k_cache,  # updated in-place
                v_cache=kv_cache_manager.v_cache,  # updated in-place
                cache_leftpad=kv_cache_manager.cache_leftpad,
                # Expand to one entry per row of ``cache_leftpad``.
                cache_seqlens=kv_cache_manager.cache_seqlens.expand(
                    kv_cache_manager.cache_leftpad.shape[0]
                ).contiguous(),
            )

        if self.cp_enabled:
            assert self.cp_pg is not None
            if self.ring is not None:
                return dispatch_ring_flash_attn(
                    q,
                    k,
                    v,
                    group=self.cp_pg,
                    strategy=self.ring.load_balancer,
                    cu_seqlens=cu_doc_lens,
                    cu_seqlens_q=cu_doc_lens_q,
                    cu_seqlens_k=cu_doc_lens_k,
                    max_seqlen=max_doc_len,
                    max_seqlen_q=max_doc_len_q,
                    max_seqlen_k=max_doc_len_k,
                    heads_k_stride=self.ring.head_stride,
                    local_k_slice=local_k_slice,
                    dropout_p=self.dropout_p,
                    causal=True,
                    softmax_scale=self.scale,
                    window_size=self.window_size,
                )
            elif self.uly is not None:
                # Transform from context-parallel to head-parallel partitioning
                # [B, T/CP, H, D] -> [B, T, H/CP, D]
                q = all_to_all_single_cp2hp(q, self.cp_pg)
                k, v = all_to_all_cp2hp([k, v], self.cp_pg)
                B, T, H_local, D = q.shape

                # NOTE: cu_doc_lens and max_doc_len are assumed to describe the FULL sequence
                # (same on all CP ranks), so we use them directly after gathering the full sequence.
                # This is the default state of cu_doc_lens and max_doc_len before a load balancer is applied.

                # Run attention with full sequence, partitioned heads
                out = dispatch_flash_attn(
                    q,
                    k,
                    v,
                    cu_seqlens=cu_doc_lens,
                    cu_seqlens_q=cu_doc_lens_q,
                    cu_seqlens_k=cu_doc_lens_k,
                    max_seqlen=max_doc_len,
                    max_seqlen_q=max_doc_len_q,
                    max_seqlen_k=max_doc_len_k,
                    dropout_p=self.dropout_p,
                    causal=True,
                    softmax_scale=self.scale,
                    window_size=self.window_size,
                )

                # Transform back from head-parallel to context-parallel partitioning
                # [B, T, H/CP, D] -> [B, T/CP, H, D]
                return all_to_all_single_hp2cp(out.view(B, T, H_local, D), self.cp_pg)
            else:
                raise RuntimeError("One of ring or uly must be specified")

        return dispatch_flash_attn(
            q,
            k,
            v,
            cu_seqlens=cu_doc_lens,
            cu_seqlens_q=cu_doc_lens_q,
            cu_seqlens_k=cu_doc_lens_k,
            max_seqlen=max_doc_len,
            max_seqlen_q=max_doc_len_q,
            max_seqlen_k=max_doc_len_k,
            dropout_p=self.dropout_p,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.window_size,
        )
class FlashAttention3Backend(AttentionBackend):
    """
    SDPA from the flash-attn 3 package.

    Supports ulysses context parallelism but not ring context parallelism.
    Attention dropout is not supported.
    """

    def __init__(
        self,
        *,
        head_dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        scale: Optional[float] = None,
        dropout_p: float = 0.0,
        window_size: Tuple[int, int] = (-1, -1),
        cache: Optional[BufferCache] = None,
    ):
        if dropout_p > 0.0:
            raise RuntimeError("dropout_p > 0.0 is not supported for flash-attn 3")
        super().__init__(
            head_dim=head_dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            scale=scale,
            dropout_p=dropout_p,
            window_size=window_size,
            cache=cache,
        )

    @classmethod
    def assert_supported(cls):
        if not has_flash_attn_3():
            raise RuntimeError(
                f"'{cls.__name__}' is missing the flash-attn 3 package or is not supported on this platform."
            )

    @classmethod
    def assert_supports_swa(cls):
        pass

    @classmethod
    def assert_supports_ring_cp(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support ring context parallelism")

    @classmethod
    def assert_supports_ulysses_cp(cls):
        pass

    @classmethod
    def assert_supports_packed_qkv(cls):
        pass

    @classmethod
    def assert_supports_kv_cache(cls):
        pass

    def forward(
        self,
        qkv: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        kv_cache_manager: Optional[KVCacheManager] = None,
    ) -> torch.Tensor:
        """
        Run causal attention through flash-attn 3.

        Dispatches on the input form and the parallelism configuration:

        - packed QKV tensor: plain or ulysses CP (no KV cache, no sliding window);
        - ``(q, k, v)`` tuple with a KV cache manager (no context parallelism);
        - ``(q, k, v)`` tuple: plain or ulysses CP.

        ``local_k_slice`` is not used by this backend.

        :raises RuntimeError: For unsupported combinations (packed QKV with KV caching or
            sliding window, KV caching with context parallelism, ring context parallelism),
            or if context parallelism is enabled but neither ring nor ulysses is configured.
        """
        if isinstance(qkv, torch.Tensor):
            # ---- Packed QKV ----
            if kv_cache_manager is not None:
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support packed QKV with KV caching"
                )
            if self.window_size != (-1, -1):
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support packed QKV with sliding window attention"
                )

            if self.cp_enabled:
                assert self.cp_pg is not None
                if self.ring is not None:
                    raise RuntimeError(
                        f"'{self.__class__.__name__}' doesn't support ring context parallelism"
                    )
                elif self.uly is not None:
                    # Transform packed qkv from context-parallel to head-parallel partitioning
                    # [B, T/CP, 3, H, D] -> [B, T, 3, H/CP, D]
                    qkv = all_to_all_single_cp2hp_qkvpacked(qkv, self.cp_pg)
                    B, T, _, H_local, D = qkv.shape

                    # NOTE: cu_doc_lens and max_doc_len are assumed to describe the FULL sequence
                    # (same on all CP ranks), so we use them directly after gathering the full sequence.

                    # Run attention with full sequence, partitioned heads
                    out = dispatch_flash_attn_3_qkvpacked(
                        qkv,
                        cu_seqlens=cu_doc_lens,
                        max_seqlen=max_doc_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.window_size,
                    )

                    # Transform back from head-parallel to context-parallel partitioning
                    # [B, T, H/CP, D] -> [B, T/CP, H, D]
                    return all_to_all_single_hp2cp(out.view(B, T, H_local, D), self.cp_pg)
                else:
                    raise RuntimeError("One of ring or uly must be specified")

            return dispatch_flash_attn_3_qkvpacked(
                qkv,
                cu_seqlens=cu_doc_lens,
                max_seqlen=max_doc_len,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.window_size,
            )

        # ---- Unpacked (q, k, v) ----
        q, k, v = qkv

        # Explicit ``None`` check, consistent with the packed branch above: presence of the
        # manager should not depend on its truthiness.
        if kv_cache_manager is not None:
            if self.cp_enabled:
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support KV caching with context parallelism"
                )
            return dispatch_flash_attn_3_with_kvcache(
                q,
                k=k,
                v=v,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.window_size,
                k_cache=kv_cache_manager.k_cache,  # updated in-place
                v_cache=kv_cache_manager.v_cache,  # updated in-place
                cache_leftpad=kv_cache_manager.cache_leftpad,
                # Expand to one entry per row of ``cache_leftpad``.
                cache_seqlens=kv_cache_manager.cache_seqlens.expand(
                    kv_cache_manager.cache_leftpad.shape[0]
                ).contiguous(),
            )

        if self.cp_enabled:
            if self.ring is not None:
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support ring context parallelism"
                )
            elif self.uly is not None:
                assert self.cp_pg is not None
                # Transform from context-parallel to head-parallel partitioning
                # [B, T/CP, H, D] -> [B, T, H/CP, D]
                q = all_to_all_single_cp2hp(q, self.cp_pg)
                k, v = all_to_all_cp2hp([k, v], self.cp_pg)
                B, T, H_local, D = q.shape

                # NOTE: cu_doc_lens and max_doc_len are assumed to describe the FULL sequence
                # (same on all CP ranks), so we use them directly after gathering the full sequence.
                # This is the default state of cu_doc_lens and max_doc_len before a load balancer is applied.

                # Run attention with full sequence, partitioned heads
                out = dispatch_flash_attn_3(
                    q,
                    k,
                    v,
                    cu_seqlens=cu_doc_lens,
                    cu_seqlens_q=cu_doc_lens_q,
                    cu_seqlens_k=cu_doc_lens_k,
                    max_seqlen=max_doc_len,
                    max_seqlen_q=max_doc_len_q,
                    max_seqlen_k=max_doc_len_k,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.window_size,
                )

                # Transform back from head-parallel to context-parallel partitioning
                # [B, T, H/CP, D] -> [B, T/CP, H, D]
                return all_to_all_single_hp2cp(out.view(B, T, H_local, D), self.cp_pg)
            else:
                raise RuntimeError("One of ring or uly must be specified")

        return dispatch_flash_attn_3(
            q,
            k,
            v,
            cu_seqlens=cu_doc_lens,
            cu_seqlens_q=cu_doc_lens_q,
            cu_seqlens_k=cu_doc_lens_k,
            max_seqlen=max_doc_len,
            max_seqlen_q=max_doc_len_q,
            max_seqlen_k=max_doc_len_k,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.window_size,
        )
class FlashAttention4Backend(AttentionBackend):
    """
    SDPA from flash-attn 4 (CUTE implementation).

    Accepts only unpacked ``(q, k, v)`` input. Ulysses context parallelism and KV caching
    are handled; ring context parallelism and attention dropout are rejected.
    """

    def __init__(
        self,
        *,
        head_dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        scale: Optional[float] = None,
        dropout_p: float = 0.0,
        window_size: Tuple[int, int] = (-1, -1),
        cache: Optional[BufferCache] = None,
    ):
        # Reject dropout up front, before the base class does any setup.
        if dropout_p > 0.0:
            raise RuntimeError("dropout_p > 0.0 is not supported for flash-attn 4")

        super().__init__(
            head_dim=head_dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            scale=scale,
            dropout_p=dropout_p,
            window_size=window_size,
            cache=cache,
        )

    @classmethod
    def assert_supported(cls):
        if has_flash_attn_4():
            return
        raise RuntimeError(
            f"'{cls.__name__}' is missing the flash-attn CUTE implementation or is not supported on this platform."
        )

    @classmethod
    def assert_supports_swa(cls):
        pass

    @classmethod
    def assert_supports_ring_cp(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support ring context parallelism")

    @classmethod
    def assert_supports_ulysses_cp(cls):
        pass

    @classmethod
    def assert_supports_packed_qkv(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support packed QKV")

    @classmethod
    def assert_supports_kv_cache(cls):
        pass

    def forward(
        self,
        qkv: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        kv_cache_manager: Optional[KVCacheManager] = None,
    ) -> torch.Tensor:
        """
        Run causal attention through flash-attn 4.

        ``qkv`` must be a ``(q, k, v)`` tuple and ``local_k_slice`` must be ``None`` (both are
        checked with assertions). With a KV cache manager, the new keys/values are written
        into the cache and attention runs against the cache tensors; otherwise attention
        runs directly, optionally wrapped in the ulysses all-to-all transforms.

        :raises RuntimeError: For KV caching combined with context parallelism, for ring
            context parallelism, or if context parallelism is enabled but neither ring nor
            ulysses is configured.
        """
        assert isinstance(qkv, tuple), f"'{self.__class__.__name__}' requires unpacked QKV"
        assert local_k_slice is None, f"'{self.__class__.__name__}' doesn't support local_k_slice"
        q, k, v = qkv

        if kv_cache_manager is not None:
            if self.cp_enabled:
                raise RuntimeError(
                    f"'{self.__class__.__name__}' doesn't support KV caching with context parallelism"
                )

            # Write the new keys/values into the cache right after the already-filled prefix.
            # ``.item()`` means ``cache_seqlens`` must hold a single value here.
            # NOTE(review): ``cache_seqlens`` is not advanced in this method; presumably the
            # manager/caller does that - confirm.
            start = int(kv_cache_manager.cache_seqlens.item())
            end = start + k.shape[1]
            kv_cache_manager.k_cache[:, start:end] = k
            kv_cache_manager.v_cache[:, start:end] = v

            # Number of valid cache positions, repeated for every row of the batch.
            seqused_k = torch.full((q.shape[0],), end, dtype=torch.int32, device=q.device)
            return dispatch_flash_attn_4(
                q,
                kv_cache_manager.k_cache,
                kv_cache_manager.v_cache,
                seqused_k=seqused_k,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.window_size,
            )

        # Keyword arguments shared by the ulysses and the plain code paths.
        attn_kwargs = dict(
            cu_seqlens=cu_doc_lens,
            cu_seqlens_q=cu_doc_lens_q,
            cu_seqlens_k=cu_doc_lens_k,
            max_seqlen=max_doc_len,
            max_seqlen_q=max_doc_len_q,
            max_seqlen_k=max_doc_len_k,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.window_size,
        )

        if not self.cp_enabled:
            return dispatch_flash_attn_4(q, k, v, **attn_kwargs)

        if self.ring is not None:
            raise RuntimeError(
                f"'{self.__class__.__name__}' doesn't support ring context parallelism"
            )
        if self.uly is None:
            raise RuntimeError("One of ring or uly must be specified")

        assert self.cp_pg is not None
        # Transform from context-parallel to head-parallel partitioning
        # [B, T/CP, H, D] -> [B, T, H/CP, D]
        q = all_to_all_single_cp2hp(q, self.cp_pg)
        k, v = all_to_all_cp2hp([k, v], self.cp_pg)
        batch, full_len, local_heads, dim = q.shape

        # NOTE: cu_doc_lens and max_doc_len are assumed to describe the FULL sequence
        # (same on all CP ranks), so we use them directly after gathering the full sequence.
        # This is the default state of cu_doc_lens and max_doc_len before a load balancer is applied.

        # Run attention with full sequence, partitioned heads
        out = dispatch_flash_attn_4(q, k, v, **attn_kwargs)

        # Transform back from head-parallel to context-parallel partitioning
        # [B, T, H/CP, D] -> [B, T/CP, H, D]
        return all_to_all_single_hp2cp(out.view(batch, full_len, local_heads, dim), self.cp_pg)
class TEAttentionBackend(AttentionBackend):
    """
    Attention backend that delegates to ``TEDotProductAttention`` (TransformerEngine).

    The wrapped module is configured for causal attention in ``bshd`` layout. Packed QKV,
    KV caching and intra-document masking are rejected in :meth:`forward`.
    """

    def __init__(
        self,
        *,
        head_dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        scale: Optional[float] = None,
        dropout_p: float = 0.0,
        window_size: Tuple[int, int] = (-1, -1),
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(
            head_dim=head_dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            scale=scale,
            dropout_p=dropout_p,
            window_size=window_size,
            cache=cache,
        )
        if not has_te_attn():
            raise RuntimeError("TransformerEngine attention is not available")
        assert TEDotProductAttention is not None
        self.te_attn = TEDotProductAttention(
            self.n_heads,
            self.head_dim,
            num_gqa_groups=self.n_kv_heads,
            attention_dropout=self.dropout_p,
            attn_mask_type="causal",
            window_size=(self.window_size[0], 0),  # be explicit about causal mask
            qkv_format="bshd",
            softmax_scale=self.scale,
        )

    @classmethod
    def assert_supported(cls):
        if not has_te_attn():
            raise RuntimeError(
                f"'{cls.__name__}' is missing the TransformerEngine package or is not supported on this platform."
            )

    @classmethod
    def assert_supports_swa(cls):
        pass

    @classmethod
    def assert_supports_ring_cp(cls):
        # Explicit override (rather than relying on the base class's no-op) so that this
        # backend states its support the same way every other backend does.
        pass

    @classmethod
    def assert_supports_ulysses_cp(cls):
        # See ``assert_supports_ring_cp()``.
        pass

    @classmethod
    def assert_supports_cp(cls):
        # Kept for backward compatibility with callers of the old, style-agnostic name.
        cls.assert_supports_ring_cp()
        cls.assert_supports_ulysses_cp()

    @classmethod
    def assert_supports_packed_qkv(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support packed QKV")

    @classmethod
    def assert_supports_kv_cache(cls):
        raise RuntimeError(f"'{cls.__name__}' doesn't support KV caching")

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        """
        Enable context parallelism and configure the wrapped TE module accordingly.

        The communication type handed to TransformerEngine is derived from the style:
        ``"p2p"`` for ring with the zig-zag load balancer, ``"all_gather"`` for ring with the
        llama3 load balancer, and ``"a2a"`` for ulysses.

        :raises ValueError: If neither ``ring`` nor ``uly`` is given, or the ring load
            balancer type is not one of the above.
        """
        super().apply_cp(cp_mesh, ring=ring, uly=uly)

        if self.ring is not None:
            if self.ring.load_balancer == RingAttentionLoadBalancerType.zig_zag:
                # Note: zig-zag/p2p is preferred bc it overlaps with the attention computation
                cp_comm_type = "p2p"
            elif self.ring.load_balancer == RingAttentionLoadBalancerType.llama3:
                cp_comm_type = "all_gather"
            else:
                raise ValueError(self.ring.load_balancer)
        elif self.uly is not None:
            cp_comm_type = "a2a"
        else:
            raise ValueError("One of ring or uly must be specified")

        cp_group = cp_mesh.get_group()
        self.te_attn.set_context_parallel_group(
            cp_group=cp_group,
            cp_global_ranks=dist.get_process_group_ranks(cp_group),
            cp_stream=torch.cuda.default_stream(),
            # cp_stream=get_or_init_stream("cp"),  # this doesn't seem to help
            cp_comm_type=cp_comm_type,
        )

    def forward(
        self,
        qkv: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        cu_doc_lens: Optional[torch.Tensor] = None,
        cu_doc_lens_q: Optional[torch.Tensor] = None,
        cu_doc_lens_k: Optional[torch.Tensor] = None,
        max_doc_len: Optional[int] = None,
        max_doc_len_q: Optional[int] = None,
        max_doc_len_k: Optional[int] = None,
        local_k_slice: Optional[slice] = None,
        kv_cache_manager: Optional[KVCacheManager] = None,
    ) -> torch.Tensor:
        """
        Run causal attention through the wrapped TransformerEngine module.

        :param qkv: A ``(q, k, v)`` tuple.
        :param local_k_slice: Ignored.

        :raises RuntimeError: If given a KV cache manager, packed QKV, or any of the
            document-length arguments (none of which this backend currently supports).
        """
        del local_k_slice

        if kv_cache_manager is not None:
            raise RuntimeError(f"'{self.__class__.__name__}' doesn't support KV caching")

        if isinstance(qkv, torch.Tensor):
            raise RuntimeError(f"'{self.__class__.__name__}' doesn't support packed QKV")

        if any(
            opt is not None
            for opt in (
                cu_doc_lens,
                cu_doc_lens_q,
                cu_doc_lens_k,
                max_doc_len,
                max_doc_len_q,
                max_doc_len_k,
            )
        ):
            raise RuntimeError(
                f"'{self.__class__.__name__}' doesn't currently support intra-document masking"
            )

        q, k, v = qkv
        # NOTE: because of the check above, all of the sequence-length arguments below are
        # currently always ``None``; the plumbing is kept for when masking gets enabled.
        return self.te_attn(
            q,
            k,
            v,
            cu_seqlens_q=cu_doc_lens if cu_doc_lens is not None else cu_doc_lens_q,
            cu_seqlens_kv=cu_doc_lens if cu_doc_lens is not None else cu_doc_lens_k,
            max_seqlen_q=max_doc_len if max_doc_len is not None else max_doc_len_q,
            max_seqlen_kv=max_doc_len if max_doc_len is not None else max_doc_len_k,
        )
def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every head of ``x`` ``n_rep`` times along the head dimension (dim 2), i.e. the
    same result as ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    :param x: A 4-D tensor laid out as ``(batch, seq_len, n_kv_heads, head_dim)``.
    :param n_rep: How many copies of each head to produce. With ``n_rep == 1`` the input
        tensor itself is returned.

    :returns: A tensor of shape ``(batch, seq_len, n_kv_heads * n_rep, head_dim)``.
    """
    batch, seq_len, kv_heads, dim = x.shape
    if n_rep == 1:
        return x

    # Insert a singleton "repeat" axis right after the head axis, broadcast it to ``n_rep``
    # (no copy yet), then fold it into the head axis.
    expanded = x[:, :, :, None, :].expand(batch, seq_len, kv_heads, n_rep, dim)
    return expanded.reshape(batch, seq_len, kv_heads * n_rep, dim)