"""Source code for ``olmo_core.nn.transformer.block``: transformer block implementations."""

import math
from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import Placement, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module

from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel
from olmo_core.distributed.utils import get_local_tensor
from olmo_core.doc_utils import beta_feature
from olmo_core.ops import attach_auxiliary_loss

from ..attention.base import SequenceMixerConfig
from ..attention.ring import RingContextParallelStyle, UlyssesContextParallelStyle
from ..buffer_cache import BufferCache
from ..feed_forward import FeedForward, FeedForwardConfig
from ..functional import l2_normalize
from ..layer_norm import LayerNormConfig
from ..moe import MoEConfig, MoERouter
from ..moe.parallel_mlp import ParallelMLPBase
from ..residual_stream import ResidualStream
from .config import TransformerDataParallelWrappingStrategy

if TYPE_CHECKING:
    from olmo_core.train.common import ReduceType


class TransformerBlockBase(nn.Module):
    """
    Common interface shared by every transformer block implementation in this module.

    Concrete blocks implement :meth:`forward` together with the parallelism hooks
    (:meth:`apply_tp`, :meth:`apply_cp`, :meth:`apply_fsdp`) and :meth:`num_flops_per_token`.

    :param n_layers: The total number of blocks in the model this block belongs to.
    """

    def __init__(self, *, n_layers: int):
        super().__init__()
        self.n_layers = n_layers

    @property
    def is_moe(self) -> bool:
        """``True`` for blocks that contain a mixture-of-experts layer; ``False`` here."""
        return False

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the block on the input ``x``.

        :param x: The input of shape ``(batch_size, seq_len, d_model)``.
        :param loss_div_factor: Optional factor forwarded to sub-modules that produce
            auxiliary losses (e.g. MoE routers). Blocks without such losses ignore it.
        """
        raise NotImplementedError

    def apply_pp(self, pp_mesh: DeviceMesh):
        """Pipeline-parallel preparation hook. Does nothing unless a subclass overrides it."""
        del pp_mesh

    @abstractmethod
    def apply_tp(
        self, tp_mesh: DeviceMesh, *, input_layout: Placement, float8_enabled: bool = False
    ):
        """
        Apply tensor parallelism over ``tp_mesh``.

        :param input_layout: The placement of the block's input tensor.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        """Apply context parallelism over ``cp_mesh`` with a ring or Ulysses style."""
        raise NotImplementedError

    def apply_compile(self):
        """Compile the block in place, allowing graph breaks (``fullgraph=False``)."""
        self.compile(fullgraph=False)

    @abstractmethod
    def apply_fsdp(
        self,
        dp_mesh: Optional[DeviceMesh] = None,
        prefetch_factor: int = 0,
        wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full,
        **fsdp_kwargs,
    ):
        """Shard the block's parameters with FSDP over ``dp_mesh``."""
        raise NotImplementedError

    @abstractmethod
    def num_flops_per_token(self, seq_len: int) -> int:
        """Return the number of FLOPs per token for sequences of length ``seq_len``."""
        raise NotImplementedError
class TransformerBlock(TransformerBlockBase):
    """
    A typical "Llama-style" (pre-norm) transformer block.

    The input is normalized before each sub-layer, and each sub-layer's output is combined
    with its input through a :class:`ResidualStream`.

    :param d_model: The model dimensionality.
    :param block_idx: The index/position of the block within the model. Ranges from 0
        to ``n_layers - 1``.
    :param n_layers: The total number of blocks in the model.
    :param sequence_mixer: The sequence mixer module config (e.g. attention, recurrent,
        convolution, etc.).
    :param feed_forward: The feed forward module config.
    :param layer_norm: The layer norm config for both the attention LN and the feed forward LN.
    :param dropout: Dropout probability used by both residual streams.
    :param attention_residual_alpha: Passed as ``alpha`` to the attention residual stream.
    :param feed_forward_residual_alpha: Passed as ``alpha`` to the feed-forward residual stream.
    :param init_device: The device used when initializing parameters.
    :param cache: Optional buffer cache handed to the sequence mixer builder.
    """

    def __init__(
        self,
        *,
        d_model: int,
        block_idx: int,
        n_layers: int,
        sequence_mixer: SequenceMixerConfig,
        feed_forward: FeedForwardConfig,
        layer_norm: LayerNormConfig,
        dropout: float = 0.0,
        attention_residual_alpha: float = 1.0,
        feed_forward_residual_alpha: float = 1.0,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(n_layers=n_layers)
        self.d_model = d_model
        self.block_idx = block_idx

        # NOTE: The `self.attention` naming is kept for backwards compatibility with old checkpoints.
        # `self.attention` could contain any `SequenceMixer` implementation, such as a `GatedDeltaNet`.
        # Generally it's ok to think of these as "attention" modules at the block level.
        self.attention = sequence_mixer.build(
            d_model, layer_idx=block_idx, n_layers=n_layers, init_device=init_device, cache=cache
        )
        self.attention_norm = layer_norm.build(d_model, init_device=init_device)
        self.attention_residual_stream = ResidualStream(
            alpha=attention_residual_alpha, dropout=dropout
        )

        self.feed_forward = feed_forward.build(d_model=d_model, init_device=init_device)
        self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device)
        self.feed_forward_residual_stream = ResidualStream(
            alpha=feed_forward_residual_alpha, dropout=dropout
        )

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Pre-norm attention sub-layer followed by a pre-norm feed-forward sub-layer."""
        del loss_div_factor  # dense blocks produce no auxiliary loss

        mixed = self.attention(self.attention_norm(x), **kwargs)
        h = self.attention_residual_stream(x, mixed)

        projected = self.feed_forward(self.feed_forward_norm(h))
        return self.feed_forward_residual_stream(h, projected)

    def apply_tp(
        self, tp_mesh: DeviceMesh, *, input_layout: Placement, float8_enabled: bool = False
    ):
        # Redistribute the block input so it arrives sharded along the sequence dimension.
        parallelize_module(
            self,
            device_mesh=tp_mesh,
            parallelize_plan=PrepareModuleInput(
                input_layouts=(input_layout,),
                desired_input_layouts=(Shard(1),),
            ),
        )

        # Each sub-layer: sequence-parallel norm and dropout, then the tensor-parallel module,
        # which consumes and produces sequence-sharded DTensors.
        stages = (
            (self.attention_norm, self.attention_residual_stream.dropout, self.attention),
            (
                self.feed_forward_norm,
                self.feed_forward_residual_stream.dropout,
                self.feed_forward,
            ),
        )
        for norm, residual_dropout, sublayer in stages:
            parallelize_module(norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel())
            parallelize_module(
                residual_dropout,
                device_mesh=tp_mesh,
                parallelize_plan=SequenceParallel(),
            )
            sublayer.apply_tp(
                tp_mesh,
                input_layout=Shard(1),
                output_layout=Shard(1),
                use_local_output=False,
                float8_enabled=float8_enabled,
            )

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        # Only the sequence mixer needs to be aware of context parallelism.
        self.attention.apply_cp(cp_mesh, ring=ring, uly=uly)

    def apply_fsdp(
        self,
        dp_mesh: Optional[DeviceMesh] = None,
        prefetch_factor: int = 0,
        wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full,
        **fsdp_kwargs,
    ):
        if wrapping_strategy != TransformerDataParallelWrappingStrategy.fine_grained:
            # Coarse wrapping: one FSDP unit for the whole block.
            fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
            return

        # Fine-grained wrapping: separate FSDP units for attention and feed-forward, then the
        # block itself (which picks up the remaining parameters such as the norms).
        attn_unit = cast(FSDPModule, fully_shard(self.attention, mesh=dp_mesh, **fsdp_kwargs))
        mlp_unit = cast(FSDPModule, fully_shard(self.feed_forward, mesh=dp_mesh, **fsdp_kwargs))
        block_unit = cast(FSDPModule, fully_shard(self, mesh=dp_mesh, **fsdp_kwargs))
        if prefetch_factor > 0:
            block_unit.set_modules_to_forward_prefetch([attn_unit])
            attn_unit.set_modules_to_forward_prefetch([mlp_unit])

    def num_flops_per_token(self, seq_len: int) -> int:
        return sum(
            sublayer.num_flops_per_token(seq_len)
            for sublayer in (self.attention, self.feed_forward)
        )
class LayerNormScaledTransformerBlock(TransformerBlock):
    """
    :class:`TransformerBlock` with `LayerNorm Scaling (LNS)
    <https://github.com/lmsdss/LayerNorm-Scaling>`_ applied.

    The output of both LayerNorms is multiplied by ``1 / sqrt(layer_id)``, where
    ``layer_id = block_idx + 1`` is the block's 1-based position in the transformer.
    This lives in its own subclass so that the plain ``TransformerBlock`` stays simple.
    """

    def __init__(
        self,
        *,
        d_model: int,
        block_idx: int,
        n_layers: int,
        sequence_mixer: SequenceMixerConfig,
        feed_forward: FeedForwardConfig,
        layer_norm: LayerNormConfig,
        dropout: float = 0.0,
        attention_residual_alpha: float = 1.0,
        feed_forward_residual_alpha: float = 1.0,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(
            d_model=d_model,
            block_idx=block_idx,
            n_layers=n_layers,
            sequence_mixer=sequence_mixer,
            feed_forward=feed_forward,
            layer_norm=layer_norm,
            dropout=dropout,
            attention_residual_alpha=attention_residual_alpha,
            feed_forward_residual_alpha=feed_forward_residual_alpha,
            init_device=init_device,
            cache=cache,
        )
        layer_id = block_idx + 1  # 1-based position of this block
        # NOTE(review): the buffer is created on the default device, not `init_device`;
        # forward() moves it to the input's device/dtype on every call.
        self.register_buffer(
            "ln_scale", torch.tensor(1.0 / math.sqrt(layer_id), dtype=torch.float32)
        )

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Same as :meth:`TransformerBlock.forward` but with scaled LayerNorm outputs."""
        del loss_div_factor
        scale = self.ln_scale.to(dtype=x.dtype, device=x.device)

        attn_in = self.attention_norm(x) * scale
        h = self.attention_residual_stream(x, self.attention(attn_in, **kwargs))

        ff_in = self.feed_forward_norm(h) * scale
        return self.feed_forward_residual_stream(h, self.feed_forward(ff_in))
class ReorderedNormTransformerBlock(TransformerBlock):
    """
    Variant of :class:`TransformerBlock` that normalizes each sub-layer's *output*
    rather than its input: ``attention_norm`` is applied to the attention output and
    ``feed_forward_norm`` to the feed-forward output.
    """

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        del loss_div_factor  # no auxiliary losses in a dense block

        attn_branch = self.attention_norm(self.attention(x, **kwargs))
        h = self.attention_residual_stream(x, attn_branch)

        ff_branch = self.feed_forward_norm(self.feed_forward(h))
        return self.feed_forward_residual_stream(h, ff_branch)
class PeriNormTransformerBlock(TransformerBlock):
    """
    A transformer block in the style of `Peri-LN <https://arxiv.org/pdf/2502.02732>`_.

    Each sub-layer is normalized on both sides: the inherited input norms
    (``attention_norm`` / ``feed_forward_norm``) plus an extra output norm
    (``post_attention_norm`` / ``post_feed_forward_norm``) applied before the residual add.
    """

    def __init__(
        self,
        *,
        d_model: int,
        block_idx: int,
        n_layers: int,
        sequence_mixer: SequenceMixerConfig,
        feed_forward: FeedForwardConfig,
        layer_norm: LayerNormConfig,
        dropout: float = 0.0,
        attention_residual_alpha: float = 1.0,
        feed_forward_residual_alpha: float = 1.0,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(
            d_model=d_model,
            block_idx=block_idx,
            n_layers=n_layers,
            sequence_mixer=sequence_mixer,
            feed_forward=feed_forward,
            layer_norm=layer_norm,
            dropout=dropout,
            attention_residual_alpha=attention_residual_alpha,
            feed_forward_residual_alpha=feed_forward_residual_alpha,
            init_device=init_device,
            cache=cache,
        )
        # Output-side norms, one per sub-layer.
        self.post_attention_norm = layer_norm.build(d_model, init_device=init_device)
        self.post_feed_forward_norm = layer_norm.build(d_model, init_device=init_device)

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        del loss_div_factor

        attn_out = self.attention(self.attention_norm(x), **kwargs)
        h = self.attention_residual_stream(x, self.post_attention_norm(attn_out))

        ff_out = self.feed_forward(self.feed_forward_norm(h))
        return self.feed_forward_residual_stream(h, self.post_feed_forward_norm(ff_out))

    def apply_tp(
        self, tp_mesh: DeviceMesh, *, input_layout: Placement, float8_enabled: bool = False
    ):
        super().apply_tp(tp_mesh, input_layout=input_layout, float8_enabled=float8_enabled)
        # The extra output norms also see sequence-sharded activations.
        for post_norm in (self.post_feed_forward_norm, self.post_attention_norm):
            parallelize_module(
                post_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()
            )
@beta_feature
class NormalizedTransformerBlock(TransformerBlockBase):
    """
    An nGPT block implementation to be used with the
    :class:`~olmo_core.nn.attention.NormalizedAttention` attention type and
    :class:`~olmo_core.nn.feed_forward.NormalizedFeedForward` feed-forward type.

    There are no additive residuals. Instead, the hidden state is linearly interpolated
    (``torch.lerp``) toward each L2-normalized sub-layer output, using learnable
    per-channel rates, and the result is L2-normalized again.
    """

    def __init__(
        self,
        *,
        d_model: int,
        block_idx: int,
        n_layers: int,
        sequence_mixer: SequenceMixerConfig,
        feed_forward: FeedForwardConfig,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(n_layers=n_layers)
        self.d_model = d_model
        self.block_idx = block_idx

        # NOTE: The `self.attention` naming is kept for backwards compatibility with old checkpoints.
        # `self.attention` could contain any `SequenceMixer` implementation, such as a `GatedDeltaNet`.
        # Generally it's ok to think of these as "attention" modules at the block level.
        self.attention = sequence_mixer.build(
            d_model, layer_idx=block_idx, n_layers=n_layers, init_device=init_device, cache=cache
        )
        self.feed_forward = feed_forward.build(d_model=d_model, init_device=init_device)

        # The alphas are stored pre-scaled by `*_init_scaling` (see reset_parameters()) and
        # rescaled by `*_init_value / *_init_scaling` in forward(), so the effective
        # interpolation rate starts at `*_init_value`.
        alpha_scaling = 1.0 / math.sqrt(d_model)

        self.attn_alpha_init_value = 0.05
        self.attn_alpha_init_scaling = alpha_scaling
        self.attn_alpha = nn.Parameter(
            torch.empty(d_model, dtype=torch.float32, device=init_device)
        )

        self.mlp_alpha_init_value = 0.05
        self.mlp_alpha_init_scaling = alpha_scaling
        self.mlp_alpha = nn.Parameter(torch.empty(d_model, dtype=torch.float32, device=init_device))

        self.reset_parameters()

    def reset_parameters(self):
        """Set both alpha vectors to their stored init scaling (ones times ``*_init_scaling``)."""
        for alpha in (self.attn_alpha, self.mlp_alpha):
            nn.init.ones_(alpha)
        with torch.no_grad():
            self.attn_alpha.mul_(self.attn_alpha_init_scaling)
            self.mlp_alpha.mul_(self.mlp_alpha_init_scaling)

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        del loss_div_factor

        attn_target = l2_normalize(self.attention(x, **kwargs))
        attn_rate = (
            self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling)
        ).abs()
        h = l2_normalize(torch.lerp(x, attn_target, attn_rate))

        mlp_target = l2_normalize(self.feed_forward(h))
        mlp_rate = (
            self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling)
        ).abs()
        return l2_normalize(torch.lerp(h, mlp_target, mlp_rate))

    def apply_tp(
        self, tp_mesh: DeviceMesh, *, input_layout: Placement, float8_enabled: bool = False
    ):
        del tp_mesh, input_layout, float8_enabled
        raise NotImplementedError(
            "TP is not implemented yet for the normalized transformer block variant"
        )

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        self.attention.apply_cp(cp_mesh, ring=ring, uly=uly)

    def apply_fsdp(
        self,
        dp_mesh: Optional[DeviceMesh] = None,
        prefetch_factor: int = 0,
        wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full,
        **fsdp_kwargs,
    ):
        fine_grained = wrapping_strategy == TransformerDataParallelWrappingStrategy.fine_grained

        if fine_grained:
            fully_shard(self.attention, mesh=dp_mesh, **fsdp_kwargs)
            fully_shard(self.feed_forward, mesh=dp_mesh, **fsdp_kwargs)
        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)

        if fine_grained and prefetch_factor > 0:
            block_unit = cast(FSDPModule, self)
            attn_unit = cast(FSDPModule, self.attention)
            block_unit.set_modules_to_forward_prefetch([attn_unit])
            attn_unit.set_modules_to_forward_prefetch([cast(FSDPModule, self.feed_forward)])

    @torch.no_grad()
    def normalize_matrices(self):
        """
        Normalize the weights in all matrices. This should be called after each optimizer step,
        which the :class:`~olmo_core.train.train_module.TransformerTrainModule` will handle for you.

        Sub-modules without a ``normalize_matrices`` method are skipped.
        """
        for submodule in (self.attention, self.feed_forward):
            if hasattr(submodule, "normalize_matrices"):
                submodule.normalize_matrices()  # type: ignore

    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
        # In-place: overwrite `w` with its L2-normalized version along `dim`.
        w.copy_(l2_normalize(w, dim=dim))

    def num_flops_per_token(self, seq_len: int) -> int:
        return sum(
            submodule.num_flops_per_token(seq_len)
            for submodule in (self.attention, self.feed_forward)
        )
@beta_feature
class MoETransformerBlock(TransformerBlockBase):
    """
    Variant of :class:`TransformerBlock` whose dense
    :class:`~olmo_core.nn.feed_forward.FeedForward` sub-layer is replaced by a
    mixture-of-experts (MoE).

    Residuals are plain additions, with ``self.dropout`` applied to each sub-layer's output.
    """

    def __init__(
        self,
        *,
        d_model: int,
        block_idx: int,
        n_layers: int,
        sequence_mixer: SequenceMixerConfig,
        feed_forward_moe: MoEConfig,
        layer_norm: LayerNormConfig,
        dropout: float = 0.0,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ):
        super().__init__(n_layers=n_layers)
        self.d_model = d_model
        self.block_idx = block_idx

        # NOTE: The `self.attention` naming is kept for backwards compatibility with old checkpoints.
        # `self.attention` could contain any `SequenceMixer` implementation, such as a `GatedDeltaNet`.
        # Generally it's ok to think of these as "attention" modules at the block level.
        self.attention = sequence_mixer.build(
            d_model, layer_idx=block_idx, n_layers=n_layers, init_device=init_device, cache=cache
        )
        self.attention_norm = layer_norm.build(d_model, init_device=init_device)

        self.feed_forward_moe = feed_forward_moe.build(
            d_model=d_model, n_layers=n_layers, init_device=init_device, cache=cache
        )
        self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device)

        # A dropout module is only instantiated when the probability is positive.
        self.dropout: nn.Module
        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        # Flipped by apply_ep() / apply_tp() respectively.
        self._ep_enabled = False
        self._tp_enabled = False

    @property
    def is_moe(self) -> bool:
        return True

    @property
    def router(self) -> MoERouter:
        """The MoE layer's router."""
        return self.feed_forward_moe.router

    @property
    def shared_mlp(self) -> Optional[FeedForward]:
        """The MoE layer's shared MLP, or ``None`` if it has none."""
        return self.feed_forward_moe.shared_mlp

    @property
    def experts(self) -> ParallelMLPBase:
        """The MoE layer's experts module."""
        return self.feed_forward_moe.experts

    @property
    def top_k(self) -> int:
        """The MoE layer's ``top_k`` setting."""
        return self.feed_forward_moe.top_k

    @property
    def ep_enabled(self) -> bool:
        """Whether :meth:`apply_ep` has been called."""
        return self._ep_enabled

    @property
    def tp_enabled(self) -> bool:
        """Whether :meth:`apply_tp` has been called."""
        return self._tp_enabled

    def compute_metrics(
        self, reset: bool = True
    ) -> Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]]:
        """Delegate to the MoE layer's ``compute_metrics``."""
        return self.feed_forward_moe.compute_metrics(reset=reset)

    def reset_metrics(self):
        """Delegate to the MoE layer's ``reset_metrics``."""
        self.feed_forward_moe.reset_metrics()

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_out = self.attention(self.attention_norm(x), **kwargs)
        h = x + self.dropout(attn_out)

        moe_out = self.feed_forward_moe(self.feed_forward_norm(h), loss_div_factor=loss_div_factor)
        return h + self.dropout(moe_out)

    def apply_pp(self, pp_mesh: DeviceMesh):
        self.feed_forward_moe.apply_pp(pp_mesh)

    def apply_ep(self, ep_mesh: DeviceMesh, **kwargs):
        """Apply expert parallelism to the MoE layer and mark EP as enabled."""
        self.feed_forward_moe.apply_ep(ep_mesh, **kwargs)
        self._ep_enabled = True

    def apply_tp(
        self, tp_mesh: DeviceMesh, *, input_layout: Placement, float8_enabled: bool = False
    ):
        # Block input is redistributed to be sharded along the sequence dimension.
        parallelize_module(
            self,
            device_mesh=tp_mesh,
            parallelize_plan=PrepareModuleInput(
                input_layouts=(input_layout,),
                desired_input_layouts=(Shard(1),),
            ),
        )

        sub_layer_tp_kwargs = dict(
            input_layout=Shard(1),
            output_layout=Shard(1),
            use_local_output=False,
            float8_enabled=float8_enabled,
        )

        parallelize_module(
            self.attention_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()
        )
        self.attention.apply_tp(tp_mesh, **sub_layer_tp_kwargs)

        parallelize_module(
            self.feed_forward_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()
        )
        self.feed_forward_moe.apply_tp(tp_mesh, **sub_layer_tp_kwargs)

        parallelize_module(self.dropout, device_mesh=tp_mesh, parallelize_plan=SequenceParallel())
        self._tp_enabled = True

    def apply_cp(
        self,
        cp_mesh: DeviceMesh,
        ring: Optional[RingContextParallelStyle] = None,
        uly: Optional[UlyssesContextParallelStyle] = None,
    ):
        self.attention.apply_cp(cp_mesh, ring=ring, uly=uly)
        self.feed_forward_moe.apply_cp(cp_mesh)

    def apply_fsdp(
        self,
        dp_mesh: Optional[DeviceMesh] = None,
        prefetch_factor: int = 0,
        wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full,
        **fsdp_kwargs,
    ):
        if wrapping_strategy != TransformerDataParallelWrappingStrategy.fine_grained:
            fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
            return

        attn_unit = cast(FSDPModule, fully_shard(self.attention, mesh=dp_mesh, **fsdp_kwargs))
        moe_unit = cast(
            FSDPModule, fully_shard(self.feed_forward_moe, mesh=dp_mesh, **fsdp_kwargs)
        )
        block_unit = cast(FSDPModule, fully_shard(self, mesh=dp_mesh, **fsdp_kwargs))
        if prefetch_factor > 0:
            block_unit.set_modules_to_forward_prefetch([attn_unit])
            attn_unit.set_modules_to_forward_prefetch([moe_unit])

    def num_flops_per_token(self, seq_len: int) -> int:
        return sum(
            submodule.num_flops_per_token(seq_len)
            for submodule in (self.attention, self.feed_forward_moe)
        )
@beta_feature
class MoEReorderedNormTransformerBlock(MoETransformerBlock):
    """
    Like :class:`MoETransformerBlock` except that the attention norm is applied on the output
    of attention instead of the input, and likewise the feed-forward norm is applied on the
    output of the feed-forward MoE instead of the input.

    Parallelism hooks, including FSDP wrapping, are inherited unchanged from
    :class:`MoETransformerBlock`.
    """

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Output-side norms: normalize each sub-layer's result before the residual add.
        h = x + self.dropout(self.attention_norm(self.attention(x, **kwargs)))
        return h + self.dropout(
            self.feed_forward_norm(self.feed_forward_moe(h, loss_div_factor=loss_div_factor))
        )
@beta_feature
class MoEHybridTransformerBlockBase(MoETransformerBlock):
    """
    Base class for "hybrid" MoE blocks. These apply a dense path (attention plus a dense
    feed-forward) and a sparse path (the mixture-of-experts) to the same block input and
    sum their outputs.

    Subclasses implement :meth:`dense_forward`, :meth:`sparse_forward` and
    :meth:`combined_forward`. The last is meant to produce the sum of the other two while
    overlapping dense computation with the MoE's all-to-all communication.

    :param feed_forward: Config for the dense feed-forward module of the dense path.
    :param kwargs: Remaining keyword arguments (e.g. ``block_idx``, ``feed_forward_moe``,
        ``dropout``, ``cache``) forwarded to :class:`MoETransformerBlock`.
    """

    def __init__(
        self,
        *,
        d_model: int,
        n_layers: int,
        sequence_mixer: SequenceMixerConfig,
        layer_norm: LayerNormConfig,
        feed_forward: FeedForwardConfig,
        init_device: str = "cpu",
        **kwargs,
    ):
        super().__init__(
            d_model=d_model,
            n_layers=n_layers,
            sequence_mixer=sequence_mixer,
            layer_norm=layer_norm,
            init_device=init_device,
            **kwargs,
        )
        self.feed_forward = feed_forward.build(d_model=d_model, init_device=init_device)
        self.feed_forward_moe_norm = layer_norm.build(d_model, init_device=init_device)
        # `None` means "decide automatically" (see `use_combined_forward`).
        self._use_combined_forward: Optional[bool] = None

    @property
    def use_combined_forward(self) -> bool:
        """
        Whether :meth:`forward` dispatches to :meth:`combined_forward`.

        Unless set explicitly, this is ``True`` exactly when expert or tensor parallelism
        has been applied.
        """
        if self._use_combined_forward is not None:
            return self._use_combined_forward
        return self.ep_enabled or self.tp_enabled

    @use_combined_forward.setter
    def use_combined_forward(self, should_use: bool):
        if should_use and not (self.tp_enabled or self.ep_enabled):
            raise RuntimeError(
                "combined forward can only be used when expert or tensor parallelism is enabled"
            )
        self._use_combined_forward = should_use

    @abstractmethod
    def dense_forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Dense path: attention + dense feed-forward, including the residual on ``x``."""
        raise NotImplementedError

    @abstractmethod
    def sparse_forward(
        self, x: torch.Tensor, *, loss_div_factor: Optional[Union[torch.Tensor, float]] = None
    ) -> torch.Tensor:
        """Sparse path: the MoE contribution that is added to :meth:`dense_forward`'s output."""
        raise NotImplementedError

    @abstractmethod
    def combined_forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Both paths, with dense compute overlapped with MoE all-to-all communication."""
        raise NotImplementedError

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        if self.use_combined_forward:
            # NOTE: running the dense path on a separate CUDA stream was also considered, but
            # even with an extra stream it wasn't as fast as the hand-crafted overlap in
            # combined_forward().
            return self.combined_forward(x, loss_div_factor=loss_div_factor, **kwargs)
        return self.sparse_forward(x, loss_div_factor=loss_div_factor) + self.dense_forward(
            x, **kwargs
        )

    def apply_tp(
        self, tp_mesh: DeviceMesh, *, input_layout: Placement, float8_enabled: bool = False
    ):
        super().apply_tp(tp_mesh, input_layout=input_layout, float8_enabled=float8_enabled)
        # The dense feed-forward consumes sequence-sharded activations, just like in
        # TransformerBlock.apply_tp(), so declare that input layout explicitly. Without it,
        # the Shard(1) DTensor would not be redistributed to the layout the module expects.
        self.feed_forward.apply_tp(
            tp_mesh,
            input_layout=Shard(1),
            output_layout=Shard(1),
            use_local_output=False,
            float8_enabled=float8_enabled,
        )
        parallelize_module(
            self.feed_forward_moe_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()
        )

    def apply_fsdp(
        self,
        dp_mesh: Optional[DeviceMesh] = None,
        prefetch_factor: int = 0,
        wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full,
        **fsdp_kwargs,
    ):
        from torch.distributed.fsdp import MixedPrecisionPolicy

        # Force router to be full-precision.
        # NOTE(review): `fsdp_kwargs` (other than the mp policy) are not passed to the router's
        # FSDP unit; confirm this is intentional.
        fsdp_router = cast(
            FSDPModule,
            fully_shard(
                self.feed_forward_moe.router,
                mesh=dp_mesh,
                mp_policy=MixedPrecisionPolicy(param_dtype=torch.float32),
            ),
        )

        if wrapping_strategy != TransformerDataParallelWrappingStrategy.fine_grained:
            fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
            return

        if not self.use_combined_forward:
            fsdp_att = cast(FSDPModule, fully_shard(self.attention, mesh=dp_mesh, **fsdp_kwargs))
            fsdp_mlp = cast(
                FSDPModule, fully_shard(self.feed_forward, mesh=dp_mesh, **fsdp_kwargs)
            )
            fsdp_moe = cast(
                FSDPModule, fully_shard(self.feed_forward_moe, mesh=dp_mesh, **fsdp_kwargs)
            )
            fsdp_root = cast(FSDPModule, fully_shard(self, mesh=dp_mesh, **fsdp_kwargs))
            if prefetch_factor > 0:
                # forward() runs the sparse path (router + MoE) before the dense path.
                fsdp_root.set_modules_to_forward_prefetch([fsdp_router, fsdp_moe, fsdp_att])
                fsdp_att.set_modules_to_forward_prefetch([fsdp_mlp])
            return

        # Combined forward: the experts are not wrapped separately, so their parameters
        # belong to the root `fully_shard(self)` unit below.
        fsdp_att = cast(FSDPModule, fully_shard(self.attention, mesh=dp_mesh, **fsdp_kwargs))
        fsdp_mlp = cast(FSDPModule, fully_shard(self.feed_forward, mesh=dp_mesh, **fsdp_kwargs))
        fsdp_shared_mlp = (
            None
            if self.feed_forward_moe.shared_mlp is None
            else cast(
                FSDPModule,
                fully_shard(self.feed_forward_moe.shared_mlp, mesh=dp_mesh, **fsdp_kwargs),
            )
        )
        fsdp_root = cast(FSDPModule, fully_shard(self, mesh=dp_mesh, **fsdp_kwargs))
        if prefetch_factor > 0:
            fsdp_root.set_modules_to_forward_prefetch([fsdp_att, fsdp_router])
            if fsdp_shared_mlp is not None:
                fsdp_att.set_modules_to_forward_prefetch([fsdp_mlp, fsdp_shared_mlp])
            else:
                fsdp_att.set_modules_to_forward_prefetch([fsdp_mlp])
@beta_feature
class MoEHybridTransformerBlock(MoEHybridTransformerBlockBase):
    """
    Hybrid MoE block with pre-norm sub-layers.

    The dense path is ``attention(attention_norm(.))`` followed by
    ``feed_forward(feed_forward_norm(.))``, each with an additive residual. The sparse path is
    ``feed_forward_moe(feed_forward_moe_norm(x))``, which is added on top.
    """

    def dense_forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Pre-norm attention then pre-norm dense feed-forward, with residuals."""
        h = x + self.dropout(self.attention(self.attention_norm(x), **kwargs))
        return h + self.dropout(self.feed_forward(self.feed_forward_norm(h)))

    def sparse_forward(
        self, x: torch.Tensor, *, loss_div_factor: Optional[Union[torch.Tensor, float]] = None
    ) -> torch.Tensor:
        """MoE applied to the MoE-normalized block input (no residual; added by the caller)."""
        return self.dropout(
            self.feed_forward_moe(self.feed_forward_moe_norm(x), loss_div_factor=loss_div_factor)
        )

    def combined_forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute ``dense_forward(x) + sparse_forward(x)`` with overlapped communication.

        Attention (and the shared MLP, if any) runs while the dispatch all-to-all is in flight.
        The dense feed-forward runs while the reverse all-to-all is in flight.
        """
        # NOTE: this follows the same code path as the MoE's forward pass, except that we run
        # dense operations while we wait on expert parallel all-to-all comms.
        B, _, D = x.shape
        # The MoE path works on local (non-DTensor) data.
        x_moe = get_local_tensor(self.feed_forward_moe_norm(x))
        expert_weights, expert_indices, batch_size_per_expert, router_aux_loss = self.router(
            x_moe, loss_div_factor=loss_div_factor
        )
        if router_aux_loss is not None:
            x_moe = attach_auxiliary_loss(x_moe, router_aux_loss)

        # shape: (batch_size * seq_len, d_model)
        x_moe = x_moe.view(-1, D)
        # shape: (batch_size * seq_len * top_k,) -- assumes the router returns top_k entries
        # per token; TODO confirm against MoERouter.
        expert_weights = expert_weights.flatten()
        # shape: (batch_size * seq_len * top_k,) -- same assumption as above.
        expert_indices = expert_indices.flatten()

        with torch.no_grad():
            indices, bin_ids, bins = self.experts.indices_and_bins(
                expert_indices, batch_size_per_expert
            )

        # Start the dispatch all-to-all; `handle` is waited on below.
        (
            parallel_x,
            parallel_indices,
            parallel_bin_ids,
            parallel_bins,
            parallel_batch_size_per_expert,
            recv_counts,
            send_counts,
            expert_capacity,
            handle,
        ) = self.experts.permute_and_all_to_all(
            x_moe,
            indices=indices,
            bin_ids=bin_ids,
            bins=bins,
            batch_size_per_expert=batch_size_per_expert,
        )

        # Compute attention while all-to-all is in progress.
        h = x + self.dropout(self.attention(self.attention_norm(x), **kwargs))

        # Maybe compute MoE shared out while all-to-all is in progress.
        moe_shared_out: Optional[torch.Tensor] = None
        if self.shared_mlp is not None:
            # NOTE: -1 on seq dim in case of TP
            moe_shared_out = self.shared_mlp(x_moe.view(B, -1, D))

        handle.wait()
        parallel_x = self.experts.compute_local_experts(
            parallel_x,
            parallel_indices=parallel_indices,
            parallel_bin_ids=parallel_bin_ids,
            parallel_bins=parallel_bins,
            parallel_batch_size_per_expert=parallel_batch_size_per_expert,
            expert_capacity=expert_capacity,
        )

        # Start the reverse all-to-all that returns expert outputs to their source ranks.
        x_moe, handle = self.experts.reverse_all_to_all(
            parallel_x, send_counts=send_counts, recv_counts=recv_counts
        )

        # Compute feed-forward while all-to-all is in progress.
        h = h + self.dropout(self.feed_forward(self.feed_forward_norm(h)))

        handle.wait()
        x_moe = self.experts.unpermute(
            x_moe,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            indices=indices,
            bin_ids=bin_ids,
            bins=bins,
        ).view(B, -1, D)

        # Mix the shared-MLP output with the routed output using weights
        # 1/(top_k+1) and top_k/(top_k+1).
        if moe_shared_out is not None:
            moe_shared_out = moe_shared_out / (self.top_k + 1)
            x_moe = moe_shared_out.add(x_moe, alpha=self.top_k / (self.top_k + 1))

        return h + self.dropout(x_moe)
@beta_feature
class MoEHybridReorderedNormTransformerBlock(MoEHybridTransformerBlockBase):
    """
    Hybrid MoE block whose norms are applied to sub-layer *outputs*.

    The dense path is ``attention_norm(attention(.))`` followed by
    ``feed_forward_norm(feed_forward(.))``, each with an additive residual. The sparse path is
    ``feed_forward_moe_norm(feed_forward_moe(x))``, which is added on top.
    """

    def dense_forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Attention then dense feed-forward, each output-normalized, with residuals."""
        h = x + self.dropout(self.attention_norm(self.attention(x, **kwargs)))
        return h + self.dropout(self.feed_forward_norm(self.feed_forward(h)))

    def sparse_forward(
        self, x: torch.Tensor, *, loss_div_factor: Optional[Union[torch.Tensor, float]] = None
    ) -> torch.Tensor:
        """Output-normalized MoE on the raw block input (no residual; added by the caller)."""
        return self.dropout(
            self.feed_forward_moe_norm(self.feed_forward_moe(x, loss_div_factor=loss_div_factor))
        )

    def combined_forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute ``dense_forward(x) + sparse_forward(x)`` with overlapped communication.

        Attention (and the shared MLP, if any) runs while the dispatch all-to-all is in flight.
        The dense feed-forward runs while the reverse all-to-all is in flight.
        """
        # NOTE: this follows the same code path as the MoE's forward pass, except that we run
        # dense operations while we wait on expert parallel all-to-all comms.
        B, _, D = x.shape
        # The MoE path works on local (non-DTensor) data; no input norm in this variant.
        x_moe = get_local_tensor(x)
        expert_weights, expert_indices, batch_size_per_expert, router_aux_loss = self.router(
            x_moe, loss_div_factor=loss_div_factor
        )
        if router_aux_loss is not None:
            x_moe = attach_auxiliary_loss(x_moe, router_aux_loss)

        # shape: (batch_size * seq_len, d_model)
        x_moe = x_moe.view(-1, D)
        # shape: (batch_size * seq_len * top_k,)
        expert_weights = get_local_tensor(expert_weights).flatten()
        # shape: (batch_size * seq_len * top_k,)
        expert_indices = get_local_tensor(expert_indices).flatten()

        with torch.no_grad():
            indices, bin_ids, bins = self.experts.indices_and_bins(
                expert_indices, batch_size_per_expert
            )

        # Start the dispatch all-to-all; `handle` is waited on below.
        (
            parallel_x,
            parallel_indices,
            parallel_bin_ids,
            parallel_bins,
            parallel_batch_size_per_expert,
            recv_counts,
            send_counts,
            expert_capacity,
            handle,
        ) = self.experts.permute_and_all_to_all(
            x_moe,
            indices=indices,
            bin_ids=bin_ids,
            bins=bins,
            batch_size_per_expert=batch_size_per_expert,
        )

        # Compute attention while all-to-all is in progress.
        h = x + self.dropout(self.attention_norm(self.attention(x, **kwargs)))

        # Maybe compute MoE shared out while all-to-all is in progress.
        moe_shared_out: Optional[torch.Tensor] = None
        if self.shared_mlp is not None:
            # NOTE: -1 on seq dim in case of TP
            moe_shared_out = self.shared_mlp(x_moe.view(B, -1, D))

        handle.wait()
        parallel_x = self.experts.compute_local_experts(
            parallel_x,
            parallel_indices=parallel_indices,
            parallel_bin_ids=parallel_bin_ids,
            parallel_bins=parallel_bins,
            parallel_batch_size_per_expert=parallel_batch_size_per_expert,
            expert_capacity=expert_capacity,
        )

        # Start the reverse all-to-all that returns expert outputs to their source ranks.
        x_moe, handle = self.experts.reverse_all_to_all(
            parallel_x, send_counts=send_counts, recv_counts=recv_counts
        )

        # Compute feed-forward while all-to-all is in progress.
        h = h + self.dropout(self.feed_forward_norm(self.feed_forward(h)))

        handle.wait()
        x_moe = self.experts.unpermute(
            x_moe,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            indices=indices,
            bin_ids=bin_ids,
            bins=bins,
        ).view(B, -1, D)

        # Mix the shared-MLP output with the routed output using weights
        # 1/(top_k+1) and top_k/(top_k+1).
        if moe_shared_out is not None:
            moe_shared_out = moe_shared_out / (self.top_k + 1)
            x_moe = moe_shared_out.add(x_moe, alpha=self.top_k / (self.top_k + 1))

        # The MoE output norm is applied last, matching sparse_forward().
        return h + self.dropout(self.feed_forward_moe_norm(x_moe))