Source code for olmo_core.nn.moe.router

import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union, cast

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module

import olmo_core.ops.moe as ops
from olmo_core.config import DType, StrEnum
from olmo_core.distributed.utils import (
    _HiddenTensor,
    distribute_like,
    get_local_tensor,
    hide_from_torch,
    is_distributed,
    unhide_from_torch,
)
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.utils import get_default_device

from ..config import ModuleConfig
from .loss import MoELoadBalancingLossGranularity, load_balancing_loss, router_z_loss

if TYPE_CHECKING:
    from olmo_core.train.common import ReduceType

__all__ = [
    "MoERouter",
    "MoELinearRouter",
    "MoERouterConfig",
    "MoERouterType",
    "MoERouterGatingFunction",
]


log = logging.getLogger(__name__)


# NOTE: To enable end-to-end benchmarking without convergence we
# support a flag to force the router to assign items/tokens uniformly
# across the experts. We do this with a custom autograd operation
# so that PyTorch still executes the full set of router operation.
class _UniformExpertAssignment(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, num_experts: int):
        del ctx
        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        out = torch.remainder(out, num_experts)
        return out.view(x.shape)


_uniform_expert_assignment: Callable[
    [torch.Tensor, int], torch.Tensor
] = _UniformExpertAssignment.apply  # type: ignore


[docs] class MoERouterType(StrEnum): """ An enumeration of the different MoE router implementations. """ default = "default" """ ➡️ :class:`MoELinearRouter` """
[docs] class MoERouterGatingFunction(StrEnum): softmax = "softmax" sigmoid = "sigmoid"
[docs] @dataclass class MoERouterConfig(ModuleConfig): """ A configuration class for easily building any of the different MoE router modules. """ name: MoERouterType = MoERouterType.default """ The name of the implementation. """ top_k: int = 1 jitter_eps: Optional[float] = None normalize_expert_weights: Optional[float] = None uniform_expert_assignment: bool = False bias_gamma: Optional[float] = None gating_function: MoERouterGatingFunction = MoERouterGatingFunction.softmax dtype: Optional[DType] = None
[docs] def num_params(self, d_model: int, num_experts: int) -> int: """ The number of params that the module will have once built. :param d_model: The model dimensionality. """ num_params = 0 if self.name == MoERouterType.default: num_params += d_model * num_experts else: raise NotImplementedError return num_params
[docs] def build( self, d_model: int, num_experts, *, lb_loss_weight: Optional[float] = None, lb_loss_granularity: MoELoadBalancingLossGranularity = MoELoadBalancingLossGranularity.local_batch, z_loss_weight: Optional[float] = None, dtype: Optional[torch.dtype] = None, init_device: str = "cpu", ) -> "MoERouter": """ Build the corresponding MoE router module. :param d_model: The model dimensionality. :param num_experts: The number of experts. :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". """ kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") kwargs.update( d_model=d_model, num_experts=num_experts, init_device=init_device, lb_loss_weight=lb_loss_weight, lb_loss_granularity=lb_loss_granularity, z_loss_weight=z_loss_weight, ) if self.dtype is not None: kwargs["dtype"] = self.dtype.as_pt() elif dtype is not None: kwargs["dtype"] = dtype try: if self.name == MoERouterType.default: return MoELinearRouter(**kwargs) else: raise NotImplementedError(self.name) except TypeError as e: raise OLMoConfigurationError( f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" ) from e
[docs] class MoERouter(nn.Module): """ A base class for MoE router modules. :param d_model: The model dimensionality (hidden size). :param num_experts: The total number of experts. :param top_k: The number of experts to assign to each item/token. :param jitter_eps: Controls the amount of noise added to the input during training. :param normalize_expert_weights: The type of norm (e.g. ``2.0`` for L2 norm) to use to normalize the expert weights. :param uniform_expert_assignment: Force uniform assignment. Useful for benchmarking. :param bias_gamma: If set to a positive float, experts scores for top-k routing will be adjusted by a bias following the "auxiliary-loss-free load balancing" strategy from DeepSeek-v3. A reasonable value is on the order of 0.0001. """ def __init__( self, *, d_model: int, num_experts: int, top_k: int = 1, jitter_eps: Optional[float] = None, normalize_expert_weights: Optional[float] = None, uniform_expert_assignment: bool = False, bias_gamma: Optional[float] = None, gating_function: MoERouterGatingFunction = MoERouterGatingFunction.softmax, lb_loss_weight: Optional[float] = None, lb_loss_granularity: MoELoadBalancingLossGranularity = MoELoadBalancingLossGranularity.local_batch, z_loss_weight: Optional[float] = None, init_device: str = "cpu", ): super().__init__() self.d_model = d_model self.num_experts = num_experts self.top_k = top_k self.jitter_eps = jitter_eps self.normalize_expert_weights = normalize_expert_weights self.uniform_expert_assignment = uniform_expert_assignment self.bias_gamma = bias_gamma self.gating_function = gating_function self.lb_loss_weight = lb_loss_weight self.lb_loss_granularity = lb_loss_granularity self.z_loss_weight = z_loss_weight self.group: Optional[dist.ProcessGroup] = None self.cp_mesh: Optional[dist.DeviceMesh] = None self.tp_mesh: Optional[dist.DeviceMesh] = None if self.bias_gamma is not None: assert self.bias_gamma > 0 self.register_buffer("score_bias", torch.zeros(self.num_experts, device=init_device)) else: self.register_buffer("score_bias", None) # NOTE: we don't use buffers for t hese because we don't want FSDP to manage them, and we # don't use a BufferCache because `torch.compile()` doesn't handle that well when we're modifying # values in the cache. self._batch_size_per_expert = hide_from_torch( torch.zeros(self.num_experts, device=init_device) ) self._score_bias_batch_size_per_expert: Optional[_HiddenTensor] = None self._load_balancing_loss: Optional[_HiddenTensor] = None self._z_loss: Optional[_HiddenTensor] = None def reset_parameters(self): self._batch_size_per_expert = hide_from_torch( torch.zeros(self.num_experts, device=self.device) ) if self.bias_gamma is not None: assert self.score_bias is not None score_bias = cast(torch.Tensor, self.score_bias) score_bias.zero_() self._score_bias_batch_size_per_expert = hide_from_torch( torch.zeros(self.num_experts, device=self.device) ) if self.lb_loss_weight is not None: self._load_balancing_loss = hide_from_torch(torch.zeros([], device=self.device)) if self.z_loss_weight is not None: self._z_loss = hide_from_torch(torch.zeros([], device=self.device)) @property def device(self) -> torch.device: return get_default_device() @property def score_bias_batch_size_per_expert(self) -> Optional[torch.Tensor]: if self.bias_gamma is not None: if self._score_bias_batch_size_per_expert is None: self._score_bias_batch_size_per_expert = hide_from_torch( torch.zeros(self.num_experts, device=self.device) ) elif self._score_bias_batch_size_per_expert.device != self.device: self._score_bias_batch_size_per_expert = self._score_bias_batch_size_per_expert.to( self.device ) return ( None if self._score_bias_batch_size_per_expert is None else unhide_from_torch(self._score_bias_batch_size_per_expert) ) @score_bias_batch_size_per_expert.setter def score_bias_batch_size_per_expert(self, value: torch.Tensor): self._score_bias_batch_size_per_expert = hide_from_torch(value) @property def batch_size_per_expert(self) -> torch.Tensor: if self._batch_size_per_expert.device != self.device: self._batch_size_per_expert = self._batch_size_per_expert.to(self.device) return unhide_from_torch(self._batch_size_per_expert) @batch_size_per_expert.setter def batch_size_per_expert(self, value: torch.Tensor): self._batch_size_per_expert = hide_from_torch(value) @property def load_balancing_loss(self) -> Optional[torch.Tensor]: if self.lb_loss_weight is not None: if self._load_balancing_loss is None: self._load_balancing_loss = hide_from_torch(torch.zeros([], device=self.device)) elif self._load_balancing_loss.device != self.device: self._load_balancing_loss = self._load_balancing_loss.to(self.device) return ( None if self._load_balancing_loss is None else unhide_from_torch(self._load_balancing_loss) ) @load_balancing_loss.setter def load_balancing_loss(self, value: torch.Tensor): self._load_balancing_loss = hide_from_torch(value) @property def z_loss(self) -> Optional[torch.Tensor]: if self.z_loss_weight is not None: if self._z_loss is None: self._z_loss = hide_from_torch(torch.zeros([], device=self.device)) elif self._z_loss.device != self.device: self._z_loss = self._z_loss.to(self.device) return None if self._z_loss is None else unhide_from_torch(self._z_loss) @z_loss.setter def z_loss(self, value: torch.Tensor): self._z_loss = hide_from_torch(value) @torch.no_grad() def post_batch(self, dry_run: bool = False): if self.bias_gamma is None or not self.training: return assert self.score_bias is not None assert self.score_bias_batch_size_per_expert is not None score_bias = cast(torch.Tensor, self.score_bias) batch_size_per_expert = self.score_bias_batch_size_per_expert # Maybe reduce across the process group. if is_distributed(): dist.all_reduce(batch_size_per_expert, group=self.group) ideal_batch_size_per_expert = batch_size_per_expert.mean( dim=0, keepdim=True, dtype=torch.float32 ) bias_delta = self.bias_gamma * (ideal_batch_size_per_expert - batch_size_per_expert).sign() # NOTE: have to be careful here to manage the case where `score_bias` is a DTensor. bias_delta = distribute_like(score_bias, bias_delta) if not dry_run: get_local_tensor(score_bias).add_(get_local_tensor(bias_delta)) # Reset the accumulator. batch_size_per_expert.zero_() def jitter(self, x: torch.Tensor) -> torch.Tensor: if self.jitter_eps is None or not self.training: return x else: low = 1.0 - self.jitter_eps high = 1.0 + self.jitter_eps noise = torch.rand_like(x) return x * (low + noise * (high - low)) def get_top_k(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: expert_weights: torch.Tensor expert_indices: torch.Tensor if self.bias_gamma is None: if self.top_k == 1: expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) else: expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1) else: assert self.score_bias is not None with torch.no_grad(): _, expert_indices = torch.topk( scores + self.score_bias.unsqueeze(0), self.top_k, dim=-1 # type: ignore ) expert_weights = scores.gather(-1, expert_indices) if self.uniform_expert_assignment: expert_indices = _uniform_expert_assignment(expert_indices, self.num_experts) expert_weights = scores.gather(-1, expert_indices) return expert_weights, expert_indices
[docs] @abstractmethod def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: """ Given the input ``x`` of shape ``(*, d_model)``, compute the un-normalized expert scores. :returns: The expert logits, shape ``(*, num_experts)``. """ raise NotImplementedError
@torch.no_grad() def compute_metrics( self, reset: bool = True ) -> Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]]: from olmo_core.train.common import ReduceType out: Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]] = {} # Load imbalance. batch_size_per_expert = self.batch_size_per_expert out["load imbalance"] = ( batch_size_per_expert.max() / batch_size_per_expert.mean(dtype=torch.float), ReduceType.max, ) # Load balancing loss. if self.lb_loss_weight is not None: assert self.load_balancing_loss is not None out["load balancing loss"] = ( self.lb_loss_weight * self.load_balancing_loss, ReduceType.mean, ) out["load balancing loss unscaled"] = ( self.load_balancing_loss.clone(), ReduceType.mean, ) # Router Z loss. if self.z_loss_weight is not None: assert self.z_loss is not None out["router Z loss"] = (self.z_loss_weight * self.z_loss, ReduceType.mean) out["router Z loss unscaled"] = (self.z_loss.clone(), ReduceType.mean) if reset: self.reset_metrics() return out def reset_metrics(self): if (bz_per_expert := self.batch_size_per_expert) is not None: bz_per_expert.zero_() if (lb_loss := self.load_balancing_loss) is not None: lb_loss.zero_() if (z_loss := self.z_loss) is not None: z_loss.zero_()
[docs] def forward( self, x: torch.Tensor, *, loss_div_factor: Optional[Union[torch.Tensor, float]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Given the input ``x`` of shape ``(B, S, d_model)``, compute the experts assignment. :returns: The expert weights of shape ``(B, S, top_k)``, the expert indices of shape ``(B, S, top_k)``, the total number of items routed to each expert, with shape ``(num_experts,)``, and optionally the auxiliary losses. """ # shape: (batch_size, seq_len, d_model) x = self.jitter(x) # shape: (batch_size, seq_len, num_experts) logits = self.get_expert_logits(x).float() # shape: (batch_size, seq_len, num_experts) if self.gating_function == MoERouterGatingFunction.softmax: scores = logits.softmax(dim=-1) elif self.gating_function == MoERouterGatingFunction.sigmoid: scores = F.sigmoid(logits) + 1e-7 else: raise NotImplementedError(self.gating_function) # shape: (batch_size, seq_len, top_k) expert_weights, expert_indices = self.get_top_k(scores) if self.normalize_expert_weights is not None: expert_weights = expert_weights.div( torch.norm( expert_weights, p=self.normalize_expert_weights, dim=-1, keepdim=True, ) ) with torch.no_grad(): # Histogram the expert ids to identify the number of items/tokens routed to each expert. # shape: (batch_size, seq_len, num_experts) batched_batch_size_per_expert = ops.batched_histc(expert_indices, self.num_experts) # shape: (batch_size, num_experts) batched_batch_size_per_expert = batched_batch_size_per_expert.sum(dim=1) # shape: (num_experts,) batch_size_per_expert = batched_batch_size_per_expert.sum(dim=0) # Maybe compute auxiliary losses and accumulate metrics. aux_loss: Optional[torch.Tensor] = None if self.training and torch.is_grad_enabled(): with torch.autocast(enabled=False, device_type=x.device.type): if self.lb_loss_weight is not None: assert self.load_balancing_loss is not None # Make sure scores are normalized, otherwise load balancing loss doesn't work well. if self.gating_function == MoERouterGatingFunction.sigmoid: scores = scores / scores.sum(dim=-1, keepdim=True) lb_loss = load_balancing_loss( num_experts=self.num_experts, top_k=self.top_k, expert_scores=scores, batch_size_per_expert=batch_size_per_expert, batched_batch_size_per_expert=batched_batch_size_per_expert, granularity=self.lb_loss_granularity, loss_div_factor=loss_div_factor, tp_mesh=self.tp_mesh, cp_mesh=self.cp_mesh, ) self.load_balancing_loss += lb_loss.detach() scaled_lb_loss = self.lb_loss_weight * lb_loss aux_loss = scaled_lb_loss if self.z_loss_weight is not None: assert self.z_loss is not None z_loss = router_z_loss( expert_logits=logits, loss_div_factor=loss_div_factor, tp_mesh=self.tp_mesh, cp_mesh=self.cp_mesh, ) self.z_loss += z_loss.detach() scaled_z_loss = self.z_loss_weight * z_loss aux_loss = scaled_z_loss if aux_loss is None else aux_loss + scaled_z_loss self.batch_size_per_expert += batch_size_per_expert if self.bias_gamma is not None: assert self.score_bias_batch_size_per_expert is not None self.score_bias_batch_size_per_expert += batch_size_per_expert return expert_weights, expert_indices, batch_size_per_expert, aux_loss
def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): del float8_enabled parallelize_module( self, device_mesh=tp_mesh, parallelize_plan=PrepareModuleInput( input_layouts=(Shard(1),), desired_input_layouts=(Shard(1),), use_local_output=True, ), ) self.tp_mesh = tp_mesh def apply_cp(self, cp_mesh: DeviceMesh): self.cp_mesh = cp_mesh
[docs] class MoELinearRouter(MoERouter): """ A simple, learned, linear router. """ def __init__( self, *, dtype: torch.dtype = torch.float32, init_device: str = "cpu", **kwargs, ): super().__init__(init_device=init_device, **kwargs) # NOTE: this parameter needs to have a large enough first dimension (which would be num experts) # in order to be sharded over big world sizes with FSDP. So we flatten it to a single dimension tensor. # And for that reason we don't support a 'bias' option. self.weight = nn.Parameter( torch.empty(self.num_experts * self.d_model, device=init_device, dtype=dtype) ) self.reset_parameters() @property def device(self) -> torch.device: return self.weight.device if self.weight.device.type != "meta" else torch.device("cpu") def reset_parameters(self) -> None: super().reset_parameters() nn.init.trunc_normal_(self.weight, std=0.02, a=-3 * 0.02, b=3 * 0.02)
[docs] def extra_repr(self): return f"in_features={self.d_model}, num_experts={self.num_experts}"
[docs] def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: return F.linear( x.float(), get_local_tensor(self.weight).view(self.num_experts, self.d_model).float() )
def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): super().apply_tp(tp_mesh, float8_enabled=float8_enabled) self.register_parameter( "weight", nn.Parameter(distribute_tensor(self.weight, tp_mesh, [Replicate()])) )