Source code for olmo_core.nn.attention.ring

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.distributed import DeviceMesh

from olmo_core.config import Config, StrEnum
from olmo_core.distributed.utils import get_rank, get_world_size
from olmo_core.utils import ensure_multiple_of


class RingAttentionLoadBalancerType(StrEnum):
    """
    Names of the available :class:`RingAttentionLoadBalancer` implementations.
    """

    zig_zag = "zig_zag"
    """
    ➡️ :class:`RingAttentionZigZagLoadBalancer`
    """

    llama3 = "llama3"
    """
    ➡️ :class:`RingAttentionLlama3LoadBalancer`
    """

    ulysses = "ulysses"
    """
    ➡️ :class:`UlyssesLoadBalancer`
    """

    def build(self, cp_mesh: DeviceMesh) -> "RingAttentionLoadBalancer":
        """
        Instantiate the load balancer named by this enum member.

        :param cp_mesh: The context-parallel device mesh. Its process group is used to
            look up the rank and world size handed to the load balancer.

        :returns: A load balancer constructed with this process's CP rank and the CP
            world size.

        :raises NotImplementedError: If no implementation is registered for this member.
        """
        group = cp_mesh.get_group()
        rank = get_rank(group)
        world_size = get_world_size(group)

        # Member -> implementation lookup table.
        implementations = {
            RingAttentionLoadBalancerType.zig_zag: RingAttentionZigZagLoadBalancer,
            RingAttentionLoadBalancerType.llama3: RingAttentionLlama3LoadBalancer,
            RingAttentionLoadBalancerType.ulysses: UlyssesLoadBalancer,
        }
        balancer_cls = implementations.get(self)
        if balancer_cls is None:
            raise NotImplementedError(self)
        return balancer_cls(cp_rank=rank, cp_world_size=world_size)
class RingAttentionLoadBalancer(metaclass=ABCMeta):
    """
    Abstract interface for sharding inputs along their sequence dimension for
    ring attention (context parallelism).

    :param cp_rank: This process's rank within the context-parallel group.
    :param cp_world_size: The number of ranks in the context-parallel group.
    """

    def __init__(self, *, cp_rank: int, cp_world_size: int):
        self.cp_world_size = cp_world_size
        self.cp_rank = cp_rank

    @abstractmethod
    def batch_shard(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """
        Split each input along its sequence dimension, padding first when necessary.

        .. important::
            If using intra-document masking, use :meth:`batch_shard_by_document` instead.

        :param inputs: The tensors to shard.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param pad_values: Optional per-input padding values.
        :param length_multiple: Optional multiple that sequence lengths are padded to.

        :returns: The local shards of the inputs.
        """
        raise NotImplementedError

    @abstractmethod
    def batch_shard_by_document(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        cu_doc_lens: torch.Tensor,
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        Counterpart of :meth:`batch_shard` for strategies that support
        intra-document masking.

        :param inputs: The tensors to shard.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param cu_doc_lens: Cumulative document lengths describing the document boundaries.
        :param pad_values: Optional per-input padding values.
        :param length_multiple: Optional multiple that sequence lengths are padded to.

        :returns: The local shards of the inputs and any other additional inputs required
            for the corresponding ring attention implementation.
        """
        raise NotImplementedError
class RingAttentionZigZagLoadBalancer(RingAttentionLoadBalancer):
    """
    Implements the zig-zag load-balancing strategy.

    A sequence is cut into ``2 * cp_world_size`` equal chunks and rank ``r`` keeps
    chunk ``r`` followed by chunk ``2 * cp_world_size - 1 - r``.
    """

    def batch_shard(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """
        Zig-zag shard every input along its sequence dimension.

        :param inputs: The tensors to shard. All must have the same sequence length.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param pad_values: Optional per-input padding values, needed whenever a sequence
            length is not already a multiple of ``length_multiple``.
        :param length_multiple: Pad sequences to a multiple of this. Defaults to
            ``2 * cp_world_size`` and must be divisible by it.

        :returns: The local shard of each input, made contiguous.

        :raises RuntimeError: If ``length_multiple`` is not divisible by
            ``2 * cp_world_size``, or padding is needed but no padding value was given.
        """
        assert len(inputs) == len(seq_dims)
        assert len({tensor.shape[dim] for tensor, dim in zip(inputs, seq_dims)}) == 1
        if pad_values is not None:
            assert len(inputs) == len(pad_values)

        num_chunks = 2 * self.cp_world_size
        if length_multiple is None:
            length_multiple = num_chunks
        elif length_multiple % num_chunks != 0:
            raise RuntimeError(
                f"length multiple ({length_multiple}) must be divisible by "
                f"2 x CP degree ({num_chunks})"
            )

        fill_values = pad_values or [None] * len(inputs)  # type: ignore

        shards: List[torch.Tensor] = []
        for tensor, dim, fill in zip(inputs, seq_dims, fill_values):
            seq_len = tensor.shape[dim]
            if seq_len % length_multiple != 0:
                if fill is None:
                    raise RuntimeError(
                        f"sequence dimension size ({seq_len}) must be divisible by "
                        f"{length_multiple}, otherwise provide a padding value"
                    )
                tensor, _ = self.pad(tensor, dim, fill, length_multiple=length_multiple)

            chunks = tensor.chunk(num_chunks, dim=dim)
            # Pair the rank's chunk from the front with its mirror chunk from the back.
            front = chunks[self.cp_rank]
            back = chunks[num_chunks - self.cp_rank - 1]
            shards.append(torch.cat([front, back], dim=dim).contiguous())

        return shards

    def batch_shard_by_document(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        cu_doc_lens: torch.Tensor,
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        Zig-zag shard every document of every input separately.

        Each document is padded (if needed) to a multiple of ``2 * cp_world_size``, cut
        into that many chunks, and the rank's front/back pair of chunks is kept. The kept
        chunks of all documents are concatenated, and that result is optionally padded
        to a multiple of ``length_multiple``.

        :param inputs: The tensors to shard. All must have the same sequence length.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param cu_doc_lens: A 1D CPU tensor of cumulative document lengths starting at 0.
        :param pad_values: Optional per-input padding values.
        :param length_multiple: If given, the local shards are padded to a multiple of
            this and an extra trailing boundary is appended to the local ``cu_doc_lens``.

        :returns: The local shards, and a dict with the local ``cu_doc_lens`` and the
            local ``max_doc_len``.

        :raises RuntimeError: If ``cu_doc_lens`` is malformed, or padding is needed but
            no padding value was given.
        """
        assert len(inputs) == len(seq_dims)
        assert len({tensor.shape[dim] for tensor, dim in zip(inputs, seq_dims)}) == 1
        if pad_values is not None:
            assert len(inputs) == len(pad_values)

        if cu_doc_lens.device.type != "cpu":
            raise RuntimeError("expected 'cu_doc_lens' to be on CPU")
        if cu_doc_lens.ndim != 1:
            raise RuntimeError("expected 'cu_doc_lens' to be a 1D tensor")
        if cu_doc_lens[0] != 0:
            raise RuntimeError("expected 'cu_doc_lens' to start with a 0")

        num_chunks = 2 * self.cp_world_size
        mirror_rank = num_chunks - 1 - self.cp_rank
        num_docs = len(cu_doc_lens) - 1

        # Padding appended to each document, and to the end of the local shard.
        doc_padding = [0] * num_docs
        tail_padding: Optional[int] = 0 if length_multiple is not None else None
        fill_values = pad_values or [None] * len(inputs)  # type: ignore

        shards: List[torch.Tensor] = []
        for tensor, dim, fill in zip(inputs, seq_dims, fill_values):
            pieces = []
            for doc_idx in range(num_docs):
                doc_start = cu_doc_lens[doc_idx]
                doc_end = cu_doc_lens[doc_idx + 1]
                # NOTE: Since 'torch.slice' is not available from the Python API we just call
                # the JIT op directly.
                doc = torch.ops.aten.slice(tensor, dim=dim, start=doc_start, end=doc_end)  # type: ignore
                doc_len = doc.shape[dim]
                if doc_len % num_chunks != 0:
                    if fill is None:
                        raise RuntimeError(
                            f"document length ({doc_len}) must be divisible by "
                            f"2 x CP degree ({num_chunks}), otherwise provide a padding value"
                        )
                    doc, doc_padding[doc_idx] = self.pad(doc, dim, fill)

                chunks = doc.chunk(num_chunks, dim=dim)
                pieces += [chunks[self.cp_rank], chunks[mirror_rank]]

            local = torch.cat(pieces, dim=dim).contiguous()
            if length_multiple is not None and local.shape[dim] % length_multiple != 0:
                if fill is None:
                    raise RuntimeError(
                        "You must provide a 'pad_value' when 'length_multiple' is specified!"
                    )
                local, tail_padding = self.pad(
                    local, dim, fill, length_multiple=length_multiple
                )
            shards.append(local)

        if pad_values is not None:
            # Shift every document boundary by the padding inserted before it.
            leading_zero = torch.tensor(
                [0], dtype=cu_doc_lens.dtype, device=cu_doc_lens.device
            )
            running_padding = torch.tensor(doc_padding, device=cu_doc_lens.device).cumsum(
                0, dtype=cu_doc_lens.dtype
            )
            cu_doc_lens = cu_doc_lens + torch.cat([leading_zero, running_padding])

        local_cu_doc_lens = cu_doc_lens // self.cp_world_size
        if tail_padding is not None:
            closing_boundary = (local_cu_doc_lens[-1] + tail_padding).unsqueeze(0)
            local_cu_doc_lens = torch.cat([local_cu_doc_lens, closing_boundary])

        local_doc_lens = local_cu_doc_lens[1:] - local_cu_doc_lens[:-1]
        return shards, {
            "cu_doc_lens": local_cu_doc_lens,
            "max_doc_len": local_doc_lens.max().item(),
        }

    def pad(
        self,
        x: torch.Tensor,
        seq_dim: int,
        value: Union[int, float],
        length_multiple: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int]:
        """
        Right-pad ``x`` along ``seq_dim`` up to the length given by
        ``ensure_multiple_of(current_length, length_multiple)``.

        :param x: The tensor to pad.
        :param seq_dim: The dimension to pad.
        :param value: The fill value.
        :param length_multiple: The multiple to pad to. Defaults to ``2 * cp_world_size``.

        :returns: The padded tensor and the amount of padding added.
        """
        if length_multiple is None:
            length_multiple = 2 * self.cp_world_size

        current_len = x.shape[seq_dim]
        extra = ensure_multiple_of(current_len, length_multiple) - current_len
        # 'F.pad' takes (left, right) pairs starting from the last dimension, so skip
        # over the dimensions that come after 'seq_dim'.
        trailing_dims = x.ndim - seq_dim - 1
        pad_spec = (0, 0) * trailing_dims + (0, extra)
        return F.pad(x, pad_spec, value=value), extra
class RingAttentionLlama3LoadBalancer(RingAttentionLoadBalancer):
    """
    Implements Llama3's load-balancing strategy for context parallelism.

    The Llama3 strategy assigns each rank a contiguous slice of the full sequence.
    Rank ``i`` receives positions ``[i * local_len, (i + 1) * local_len)`` where
    ``local_len = total_seq_len // cp_world_size``.

    This strategy is designed specifically for **intra-document masking** with
    variable-length documents packed into a single sequence. It computes separate
    cumulative sequence lengths for queries (``cu_doc_lens_q``) and keys (``cu_doc_lens_k``)
    per rank, enabling proper causal masking across document boundaries within the
    ring attention loop.

    Padding is added as a synthetic document at the end when the total sequence
    length is not divisible by the context parallel world size.

    .. note::
        This strategy only supports :meth:`batch_shard_by_document` and will raise
        an error if :meth:`batch_shard` is called directly.
    """

    def batch_shard(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """
        Not supported by this strategy.

        :raises NotImplementedError: Always. Use :meth:`batch_shard_by_document`.
        """
        del inputs, seq_dims, pad_values, length_multiple
        raise NotImplementedError(
            f"{self.__class__.__name__} should only be used with intra-document masking. "
            "Please use the 'batch_shard_by_document()' instead."
        )

    def batch_shard_by_document(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        cu_doc_lens: torch.Tensor,
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        Give each rank a contiguous slice of every input and prepare the per-rank
        cumulative document lengths.

        :param inputs: The tensors to shard. Each must have size ``cu_doc_lens[-1]`` on
            its sequence dimension.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param cu_doc_lens: A 1D CPU tensor of cumulative document lengths starting at 0.
        :param pad_values: Per-input padding values. Required whenever the total length
            has to be padded.
        :param length_multiple: If given, the total length is padded to a multiple of
            ``length_multiple * cp_world_size``. Otherwise it is padded to a multiple of
            ``cp_world_size``.

        :returns: The local shards, and a dict holding ``cu_doc_lens_q``,
            ``cu_doc_lens_k``, ``max_doc_len_q``, ``max_doc_len_k`` and ``local_k_slice``
            as produced by ``ring_flash_attn.llama3_flash_attn_prepare_cu_seqlens``.

        :raises RuntimeError: If ``ring-flash-attn`` is not installed, ``cu_doc_lens`` is
            malformed, an input has the wrong sequence length, or padding is needed but
            ``pad_values`` was not provided.
        """
        try:
            from ring_flash_attn import llama3_flash_attn_prepare_cu_seqlens
        except ImportError as e:
            raise RuntimeError(f"ring-flash-attn is required for {self.__class__.__name__}") from e

        assert len(inputs) == len(seq_dims)
        if pad_values is not None:
            assert len(inputs) == len(pad_values)

        if cu_doc_lens.device.type != "cpu":
            raise RuntimeError("expected 'cu_doc_lens' to be on CPU")
        if cu_doc_lens.ndim != 1:
            raise RuntimeError("expected 'cu_doc_lens' to be a 1D tensor")
        if cu_doc_lens[0] != 0:
            raise RuntimeError("expected 'cu_doc_lens' to start with a 0")

        # Scale the multiple by the CP degree so the padded total splits evenly
        # across ranks.
        if length_multiple is None:
            length_multiple = self.cp_world_size
        else:
            length_multiple = length_multiple * self.cp_world_size

        total_length = int(cu_doc_lens[-1])
        # Padding is the padded length minus the current length, the same order the
        # other load balancers in this module use. The reverse order can never be
        # positive, which left the padding branch below unreachable.
        padded_total_length = ensure_multiple_of(total_length, length_multiple)
        padding_to_add = padded_total_length - total_length
        local_length = padded_total_length // self.cp_world_size

        if padding_to_add > 0:
            if pad_values is None:
                raise RuntimeError("'pad_values' is required since padding is needed")

            # Represent the padding as one extra document at the end.
            cu_doc_lens = torch.cat(
                [
                    cu_doc_lens,
                    torch.tensor(
                        [total_length + padding_to_add],
                        dtype=cu_doc_lens.dtype,
                        device=cu_doc_lens.device,
                    ),
                ]
            )

        out = []
        for x, seq_dim, pad_value in zip(
            inputs,
            seq_dims,
            pad_values or [None for _ in range(len(inputs))],  # type: ignore
        ):
            if x.shape[seq_dim] != total_length:
                raise RuntimeError(
                    f"expected input to be have size {total_length} on the sequence dimension "
                    f"but got {x.shape[seq_dim]}"
                )

            if padding_to_add > 0:
                assert pad_value is not None
                # 'pad()' returns '(tensor, padding)'; keep only the tensor.
                x, _ = self.pad(x, seq_dim, padding_to_add, pad_value)

            # NOTE: Since 'torch.slice' is not available from the Python API we just call
            # the JIT op directly.
            local_value = torch.ops.aten.slice(  # type: ignore
                x,
                dim=seq_dim,
                start=self.cp_rank * local_length,
                end=(self.cp_rank + 1) * local_length,
            ).contiguous()
            out.append(local_value)

        (
            cu_doc_lens_q,
            cu_doc_lens_k,
            max_doc_len_q,
            max_doc_len_k,
            local_k_slice,
        ) = llama3_flash_attn_prepare_cu_seqlens(
            cu_doc_lens,
            causal=True,
            rank=self.cp_rank,
            world_size=self.cp_world_size,
        )

        return out, dict(
            cu_doc_lens_q=cu_doc_lens_q,
            cu_doc_lens_k=cu_doc_lens_k,
            max_doc_len_q=max_doc_len_q,
            max_doc_len_k=max_doc_len_k,
            local_k_slice=local_k_slice,
        )

    def pad(
        self,
        x: torch.Tensor,
        seq_dim: int,
        padding_to_add: int,
        value: Union[int, float],
    ) -> Tuple[torch.Tensor, int]:
        """
        Right-pad ``x`` along ``seq_dim`` by exactly ``padding_to_add`` elements.

        :param x: The tensor to pad.
        :param seq_dim: The dimension to pad.
        :param padding_to_add: The number of elements to append.
        :param value: The fill value.

        :returns: The padded tensor and ``padding_to_add``.
        """
        # 'F.pad' takes (left, right) pairs starting from the last dimension, so skip
        # over the dimensions that come after 'seq_dim'.
        padding = (0, 0) * (x.ndim - seq_dim - 1) + (0, padding_to_add)
        return F.pad(x, padding, value=value), padding_to_add
class UlyssesLoadBalancer(RingAttentionLoadBalancer):
    """
    Implements simple contiguous sequence sharding for Ulysses-style context parallelism.

    Ring attention uses zig-zag or other interleaving strategies. Ulysses only needs
    plain contiguous chunks, since its all-to-all communication handles the
    sequence/head exchange.
    """

    def batch_shard(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """
        Give each rank one contiguous slice of every input.

        :param inputs: The tensors to shard. All must have the same sequence length.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param pad_values: Optional per-input padding values, needed whenever a sequence
            length is not already a multiple of ``length_multiple``.
        :param length_multiple: Pad sequences to a multiple of this. Defaults to
            ``cp_world_size`` and must be divisible by it.

        :returns: The local shard of each input, made contiguous.

        :raises RuntimeError: If ``length_multiple`` is not divisible by
            ``cp_world_size``, or padding is needed but no padding value was given.
        """
        assert len(inputs) == len(seq_dims)
        assert len({tensor.shape[dim] for tensor, dim in zip(inputs, seq_dims)}) == 1
        if pad_values is not None:
            assert len(inputs) == len(pad_values)

        if length_multiple is None:
            length_multiple = self.cp_world_size
        elif length_multiple % self.cp_world_size != 0:
            raise RuntimeError(
                f"length multiple ({length_multiple}) must be divisible by "
                f"CP degree ({self.cp_world_size})"
            )

        fill_values = pad_values or [None] * len(inputs)  # type: ignore

        shards: List[torch.Tensor] = []
        for tensor, dim, fill in zip(inputs, seq_dims, fill_values):
            seq_len = tensor.shape[dim]

            # Pad up to the required multiple before slicing.
            if seq_len % length_multiple != 0:
                if fill is None:
                    raise RuntimeError(
                        f"sequence dimension size ({seq_len}) must be divisible by "
                        f"{length_multiple}, otherwise provide a padding value"
                    )
                tensor, _ = self.pad(tensor, dim, fill, length_multiple=length_multiple)

            # Rank r takes the r-th contiguous slice of the (padded) sequence.
            shard_len = tensor.shape[dim] // self.cp_world_size
            offset = self.cp_rank * shard_len
            shard = torch.ops.aten.slice(
                tensor, dim=dim, start=offset, end=offset + shard_len
            )
            shards.append(shard.contiguous())

        return shards

    def batch_shard_by_document(
        self,
        *,
        inputs: List[torch.Tensor],
        seq_dims: List[int],
        cu_doc_lens: torch.Tensor,
        pad_values: Optional[List[Union[int, float]]] = None,
        length_multiple: Optional[int] = None,
    ) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        Shard the inputs with :meth:`batch_shard` and pass the document boundaries through.

        ``cu_doc_lens`` is not sharded, since Ulysses reconstructs full sequences via
        all-to-all before attention. When padding is required, one extra trailing
        boundary covering the padding is appended to it.

        :param inputs: The tensors to shard.
        :param seq_dims: The sequence dimension of each tensor in ``inputs``.
        :param cu_doc_lens: A 1D CPU tensor of cumulative document lengths starting at 0.
        :param pad_values: Optional per-input padding values.
        :param length_multiple: Pad sequences to a multiple of this. Defaults to
            ``cp_world_size`` and must be divisible by it.

        :returns: The local shards, and a dict with the (possibly extended)
            ``cu_doc_lens`` and ``max_doc_len``.

        :raises RuntimeError: If ``cu_doc_lens`` is malformed or ``length_multiple`` is not
            divisible by ``cp_world_size``.
        """
        assert len(inputs) == len(seq_dims)
        if pad_values is not None:
            assert len(inputs) == len(pad_values)

        if cu_doc_lens.device.type != "cpu":
            raise RuntimeError("expected 'cu_doc_lens' to be on CPU")
        if cu_doc_lens.ndim != 1:
            raise RuntimeError("expected 'cu_doc_lens' to be a 1D tensor")
        if cu_doc_lens[0] != 0:
            raise RuntimeError("expected 'cu_doc_lens' to start with a 0")

        if length_multiple is None:
            length_multiple = self.cp_world_size
        elif length_multiple % self.cp_world_size != 0:
            raise RuntimeError(
                f"length multiple ({length_multiple}) must be divisible by "
                f"CP degree ({self.cp_world_size})"
            )

        unpadded_total = int(cu_doc_lens[-1])
        extra = ensure_multiple_of(unpadded_total, length_multiple) - unpadded_total

        # 'batch_shard' applies the padding to the inputs itself.
        shards = self.batch_shard(
            inputs=inputs,
            seq_dims=seq_dims,
            pad_values=pad_values,
            length_multiple=length_multiple,
        )

        # Longest real document; the padding segment is folded in below.
        doc_lens = cu_doc_lens[1:] - cu_doc_lens[:-1]
        longest = doc_lens.max().item()

        if extra > 0:
            # Close the padded region with one more boundary.
            cu_doc_lens = torch.cat([cu_doc_lens, cu_doc_lens[-1:] + extra])
            longest = max(longest, extra)

        return shards, {"cu_doc_lens": cu_doc_lens, "max_doc_len": longest}

    def pad(
        self,
        x: torch.Tensor,
        seq_dim: int,
        value: Union[int, float],
        length_multiple: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int]:
        """
        Right-pad ``x`` along ``seq_dim`` up to the length given by
        ``ensure_multiple_of(current_length, length_multiple)``.

        :param x: The tensor to pad.
        :param seq_dim: The dimension to pad.
        :param value: The fill value.
        :param length_multiple: The multiple to pad to. Defaults to ``cp_world_size``.

        :returns: The padded tensor and the amount of padding added.
        """
        if length_multiple is None:
            length_multiple = self.cp_world_size

        current_len = x.shape[seq_dim]
        extra = ensure_multiple_of(current_len, length_multiple) - current_len
        # 'F.pad' takes (left, right) pairs starting from the last dimension, so skip
        # over the dimensions that come after 'seq_dim'.
        trailing_dims = x.ndim - seq_dim - 1
        pad_spec = (0, 0) * trailing_dims + (0, extra)
        return F.pad(x, pad_spec, value=value), extra
@dataclass
class UlyssesContextParallelStyle(Config):
    """
    Settings for Ulysses-style context parallelism.
    """

    @property
    def load_balancer(self) -> "RingAttentionLoadBalancerType":
        """
        The load balancer type for this style, which is always
        :data:`RingAttentionLoadBalancerType.ulysses`.
        """
        return RingAttentionLoadBalancerType.ulysses
@dataclass
class RingContextParallelStyle(Config):
    """
    Settings for ring attention-style context parallelism.
    """

    load_balancer: RingAttentionLoadBalancerType = RingAttentionLoadBalancerType.zig_zag
    """
    Which load balancer implementation ring attention should use.
    """

    head_stride: int = 1
    """
    How many k heads and v heads are processed in each iteration of ring attention.
    With a value of 1 every iteration handles one k and one v head, with a value of 2
    it handles two of each, and so on. Larger strides mean fewer communication ops.
    """