Source code for olmo_core.nn.moe.loss

from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Replicate, Shard

from olmo_core.config import StrEnum
from olmo_core.distributed.utils import get_local_tensor


[docs] class MoELoadBalancingLossGranularity(StrEnum): """ Defines the granularity for the router's load balancing loss. """ local_batch = "local_batch" """ The loss is always computed over the rank-local shard of the batch, ignoring any parallelism strategies used. This is ideal for minimizing the number of dropped tokens for any parallel strategy. """ instance = "instance" """ The loss is computed over each instance, taking into account any parallelism strategies used. """
def _resolve_local_loss_div_factor(
    loss_div_factor: Optional[Union[torch.Tensor, float]],
    num_local_tokens: int,
    tp_mesh: Optional[dist.DeviceMesh],
    cp_mesh: Optional[dist.DeviceMesh],
) -> Union[torch.Tensor, float]:
    """
    Resolve the divisor for a loss that is computed rank-locally and then summed over the TP mesh.

    Because of that final TP reduction, the divisor should represent the total number of tokens
    across the TP group, but not the CP group.

    :param loss_div_factor: The caller-supplied divisor, or ``None`` to derive one.
    :param num_local_tokens: The number of tokens in the rank-local input (``B * S``).
    :param tp_mesh: The tensor-parallel mesh, if any.
    :param cp_mesh: The context-parallel mesh, if any.

    :returns: ``num_local_tokens`` (times the TP mesh size when ``tp_mesh`` is given) if no
        divisor was supplied; otherwise the supplied divisor, divided by the CP mesh size when
        ``cp_mesh`` is given.
    """
    if loss_div_factor is None:
        loss_div_factor = num_local_tokens
        if tp_mesh is not None:
            loss_div_factor = loss_div_factor * tp_mesh.size()
    elif cp_mesh is not None:
        # NOTE(review): presumably a caller-supplied factor counts tokens across the CP group
        # as well, which is why it is scaled back down here -- confirm against the callers.
        loss_div_factor = loss_div_factor / cp_mesh.size()
    return loss_div_factor


def load_balancing_loss(
    *,
    num_experts: int,
    top_k: int,
    expert_scores: torch.Tensor,
    batch_size_per_expert: torch.Tensor,
    batched_batch_size_per_expert: torch.Tensor,
    granularity: MoELoadBalancingLossGranularity,
    loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
    tp_mesh: Optional[dist.DeviceMesh] = None,
    cp_mesh: Optional[dist.DeviceMesh] = None,
) -> torch.Tensor:
    """
    Compute the router's load balancing loss.

    :param num_experts: The total number of experts.
    :param top_k: The number of experts each token is routed to.
    :param expert_scores: The router scores. Must be 3-dimensional; it is treated as
        ``(B, S, num_experts)``.
    :param batch_size_per_expert: Per-expert token counts for the local batch, used by the
        ``local_batch`` granularity. Treated as shape ``(num_experts,)``.
    :param batched_batch_size_per_expert: Per-instance, per-expert token counts, used by the
        ``instance`` granularity. Treated as shape ``(B, num_experts)``.
    :param granularity: The granularity to compute the loss at.
    :param loss_div_factor: An optional divisor for the loss. If not given, a token count is
        derived from the inputs (see the comments in each branch).
    :param tp_mesh: The tensor-parallel mesh, if any. When given, the result is a
        :class:`~torch.distributed.tensor.DTensor` over this mesh.
    :param cp_mesh: The context-parallel mesh, if any.

    :returns: The scalar loss, scaled by ``num_experts / top_k``.

    :raises NotImplementedError: If ``granularity`` is not a supported value.
    """
    expert_scores, batch_size_per_expert, batched_batch_size_per_expert = (
        get_local_tensor(expert_scores),
        get_local_tensor(batch_size_per_expert),
        get_local_tensor(batched_batch_size_per_expert),
    )

    B, S, _ = expert_scores.shape

    loss: torch.Tensor
    if granularity == MoELoadBalancingLossGranularity.instance:
        # shape: (B, num_experts)
        # NOTE(review): 'type_as' is a no-op when the dtype already matches, in which case the
        # in-place all-reduces below also modify the tensor that was passed in -- confirm that
        # no caller reuses it afterwards.
        batched_batch_size_per_expert = batched_batch_size_per_expert.type_as(expert_scores)

        # NOTE: for CP it suffices to reduce the 'batched_batch_size_per_expert' across the CP group
        # and do the rest of the computation locally.
        if cp_mesh is not None:
            dist.all_reduce(batched_batch_size_per_expert, group=cp_mesh.get_group())

        # NOTE: for TP, the end result needs to be a DTensor over the TP mesh, so we handle this case
        # a little differently.
        if tp_mesh is not None:
            # NOTE: assumes sharded on sequence dimension and equal splits across TP group.
            dist.all_reduce(batched_batch_size_per_expert, group=tp_mesh.get_group())
            batched_batch_size_per_expert = DTensor.from_local(
                batched_batch_size_per_expert, tp_mesh, (Replicate(),)
            )

            # shape: (B, S, num_experts) -> (B, 1, num_experts)
            expert_scores = expert_scores.view(B, -1, num_experts).mean(dim=1, keepdim=True)
            # shape: (B, 1, num_experts) -> (B, num_experts)
            expert_scores = DTensor.from_local(expert_scores, tp_mesh, (Shard(1),)).mean(dim=1)
        else:
            # shape: (B, S, num_experts) -> (B, num_experts)
            expert_scores = expert_scores.view(B, -1, num_experts).mean(dim=1)

        # We compute this across the TP and CP groups, so the 'loss_div_factor' should represent
        # the total number of tokens across the TP and CP groups.
        if loss_div_factor is None:
            # this gives us total number of tokens across TP + CP groups.
            loss_div_factor = batched_batch_size_per_expert.sum() / top_k

        # shape: scalar
        loss = (expert_scores * batched_batch_size_per_expert).sum() / loss_div_factor
    elif granularity == MoELoadBalancingLossGranularity.local_batch:
        # NOTE: We essentially ignore CP for this granularity, and for TP we still compute the loss
        # locally, but wrap as a DTensor and reduce it at the end because the end result has to be
        # a DTensor over the TP mesh.
        # Due to that DTensor reduction, with TP the 'loss_div_factor' should be the total number
        # of tokens across the TP group, but not the CP group.
        loss_div_factor = _resolve_local_loss_div_factor(loss_div_factor, B * S, tp_mesh, cp_mesh)

        # shape: (num_experts,)
        batch_size_per_expert = batch_size_per_expert.type_as(expert_scores)

        # shape: (B, S, num_experts) -> (B * S, num_experts)
        expert_scores = expert_scores.view(-1, num_experts)
        # shape: (B * S, num_experts) -> (num_experts,)
        expert_scores = expert_scores.mean(dim=0)

        # shape: scalar
        loss = torch.dot(batch_size_per_expert, expert_scores) / loss_div_factor

        if tp_mesh is not None:
            # Each TP rank contributes one element; summing the sharded DTensor reduces over the mesh.
            loss = DTensor.from_local(loss.unsqueeze(0), tp_mesh, (Shard(0),)).sum()
    else:
        raise NotImplementedError(granularity)

    scale = num_experts / top_k

    return scale * loss


def router_z_loss(
    *,
    expert_logits: torch.Tensor,
    loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
    tp_mesh: Optional[dist.DeviceMesh] = None,
    cp_mesh: Optional[dist.DeviceMesh] = None,
) -> torch.Tensor:
    """
    Compute the router Z-loss: the sum over tokens of the squared log-sum-exp of the expert
    logits, divided by ``loss_div_factor``.

    :param expert_logits: The router logits. Must be 3-dimensional; the log-sum-exp is taken
        over the last dimension.
    :param loss_div_factor: An optional divisor for the loss. If not given, the number of
        rank-local tokens is used (times the TP mesh size when ``tp_mesh`` is given).
    :param tp_mesh: The tensor-parallel mesh, if any. When given, the result is a
        :class:`~torch.distributed.tensor.DTensor` over this mesh.
    :param cp_mesh: The context-parallel mesh, if any.

    :returns: The scalar loss.
    """
    expert_logits = get_local_tensor(expert_logits)

    B, S, _ = expert_logits.shape

    # NOTE: with TP, end result has to be a DTensor over the TP mesh, so we wrap as a DTensor
    # and reduce it. Due to this reduction, the 'loss_div_factor' should represent the total
    # number of tokens across the TP group (but not the CP group).
    loss_div_factor = _resolve_local_loss_div_factor(loss_div_factor, B * S, tp_mesh, cp_mesh)

    loss = torch.logsumexp(expert_logits, dim=-1).square().sum() / loss_div_factor

    if tp_mesh is not None:
        loss = DTensor.from_local(loss.unsqueeze(0), tp_mesh, (Shard(0),)).sum()

    return loss