import logging
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Placement, Replicate, Shard
from torch.distributed.tensor.parallel import (
PrepareModuleInput,
PrepareModuleOutput,
parallelize_module,
)
from olmo_core.config import DType, StrEnum
from olmo_core.distributed.parallel import (
flatten_mesh,
get_pp_stage_mesh,
get_world_mesh,
)
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.ops import attach_auxiliary_loss
from ..buffer_cache import BufferCache
from ..config import ModuleConfig
from ..feed_forward import FeedForwardConfig
from .loss import MoELoadBalancingLossGranularity
from .mlp import DroplessMoEMLP, MoEMLP
from .parallel_mlp import ParallelDroplessMLP, ParallelMLP, ParallelMLPBase
from .router import MoERouterConfig
if TYPE_CHECKING:
    # Imported for type annotations only; it is never imported at runtime
    # (presumably to avoid an import cycle with the training package — verify).
    from olmo_core.train.common import ReduceType

# Public API of this module.
__all__ = [
    "MoEBase",
    "MoE",
    "DroplessMoE",
    "MoEConfig",
    "MoEType",
]

# Module-level logger.
log = logging.getLogger(__name__)
class MoEType(StrEnum):
    """
    The available MoE implementations, selected via :attr:`MoEConfig.name`.
    """

    #: ➡️ :class:`MoE`
    default = "default"

    #: ➡️ :class:`DroplessMoE`
    dropless = "dropless"
@dataclass
class MoEConfig(ModuleConfig):
    """
    Configuration for an MoE layer. Call :meth:`build` to construct the module.
    """

    #: The name of the implementation.
    name: MoEType = MoEType.default
    #: The number of experts.
    num_experts: int = 1
    #: The hidden size of each expert MLP.
    hidden_size: int = 256
    #: Expert capacity factor. When ``None`` it is left out of the build kwargs
    #: (``build`` uses ``exclude_none=True``), so the implementation's own default applies.
    capacity_factor: Optional[float] = None
    #: Config for the router.
    router: MoERouterConfig = field(default_factory=MoERouterConfig)
    #: Optional config for a shared MLP that runs alongside the experts.
    shared_mlp: Optional[FeedForwardConfig] = None
    #: Weight for the router's load-balancing loss.
    lb_loss_weight: Optional[float] = 0.01
    #: Granularity at which the load-balancing loss is computed.
    lb_loss_granularity: MoELoadBalancingLossGranularity = (
        MoELoadBalancingLossGranularity.local_batch
    )
    #: Weight for the router's Z-loss.
    z_loss_weight: Optional[float] = None
    #: If ``True``, the loss weights are divided by the ``n_layers`` passed to :meth:`build`.
    scale_loss_by_num_layers: bool = True
    #: Data type the router, experts, and shared MLP are built with.
    dtype: DType = DType.float32

    def num_params(self, d_model: int) -> int:
        """
        Count the parameters of the MoE layer: router, all experts, and the shared MLP.

        :param d_model: The model dimension.
        """
        router_params = self.router.num_params(d_model, self.num_experts)
        # Each expert contributes three d_model x hidden_size weight matrices.
        expert_params = 3 * d_model * self.hidden_size * self.num_experts
        shared_params = 0
        if self.shared_mlp is not None:
            shared_params = self.shared_mlp.num_params(d_model)
        return router_params + expert_params + shared_params

    def num_active_params(self, d_model: int) -> int:
        """
        Count the parameters used per token.

        This is the same as :meth:`num_params`, except that only ``router.top_k``
        experts are counted instead of all ``num_experts``.

        :param d_model: The model dimension.
        """
        per_expert = 3 * d_model * self.hidden_size
        total = self.num_params(d_model)
        return total - per_expert * self.num_experts + per_expert * self.router.top_k

    def build(
        self,
        d_model: int,
        *,
        n_layers: int = 1,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ) -> "MoEBase":
        """
        Build the MoE module described by this config.

        :param d_model: The model dimension.
        :param n_layers: The number of layers. Used to scale the loss weights when
            ``scale_loss_by_num_layers`` is set.
        :param init_device: The device to initialize parameters on.
        :param cache: An optional buffer cache for the module.

        :raises NotImplementedError: If ``name`` is not a known :class:`MoEType`.
        :raises OLMoConfigurationError: If the selected implementation raises a ``TypeError``
            while being constructed with these options.
        """
        build_kwargs = self.as_dict(exclude_none=True, recurse=False)
        del build_kwargs["name"]
        torch_dtype = build_kwargs.pop("dtype").as_pt()
        build_kwargs.update(
            d_model=d_model,
            n_layers=n_layers,
            init_device=init_device,
            dtype=torch_dtype,
            cache=cache,
        )

        if self.name == MoEType.default:
            moe_cls = MoE
        elif self.name == MoEType.dropless:
            moe_cls = DroplessMoE
        else:
            raise NotImplementedError(self.name)

        try:
            return moe_cls(**build_kwargs)
        except TypeError as e:
            # Unknown or unsupported options for the chosen implementation surface as TypeError.
            raise OLMoConfigurationError(
                f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
            ) from e
class MoEBase(nn.Module):
    """
    Base class for MoE implementations.

    The base class builds the router and the optional shared MLP. Subclasses supply the
    expert layer (stored as ``self.experts``) by implementing :meth:`_init_parallel_mlp`.

    .. note::
        ``MoEBase`` subclasses :class:`torch.nn.Module`, which does not use
        :class:`abc.ABCMeta`. As a result, ``@abstractmethod`` on :meth:`_init_parallel_mlp`
        is not enforced at instantiation; the ``NotImplementedError`` in the base
        implementation is the actual guard.
    """

    def __init__(
        self,
        *,
        d_model: int,
        num_experts: int,
        hidden_size: int,
        router: MoERouterConfig,
        shared_mlp: Optional[FeedForwardConfig] = None,
        init_device: str = "cpu",
        lb_loss_weight: Optional[float] = None,
        lb_loss_granularity: MoELoadBalancingLossGranularity = MoELoadBalancingLossGranularity.local_batch,
        z_loss_weight: Optional[float] = None,
        n_layers: int = 1,
        scale_loss_by_num_layers: bool = True,
        dtype: torch.dtype = torch.float32,
        cache: Optional[BufferCache] = None,
        **kwargs,
    ):
        """
        :param d_model: The model dimension.
        :param num_experts: The number of experts.
        :param hidden_size: The hidden size of each expert.
        :param router: Config used to build the router.
        :param shared_mlp: Optional config for a shared MLP applied to the input alongside
            the experts.
        :param init_device: The device to initialize parameters on.
        :param lb_loss_weight: Weight for the router's load-balancing loss.
        :param lb_loss_granularity: Granularity at which the load-balancing loss is computed.
        :param z_loss_weight: Weight for the router's Z-loss.
        :param n_layers: The number of layers. Used to scale the loss weights.
        :param scale_loss_by_num_layers: If ``True``, ``lb_loss_weight`` and ``z_loss_weight``
            are divided by ``n_layers``.
        :param dtype: Data type for the router, experts, and shared MLP.
        :param cache: Optional buffer cache passed to the expert layer.
        :param kwargs: Extra options forwarded to :meth:`_init_parallel_mlp`.
        """
        super().__init__()
        if scale_loss_by_num_layers:
            # Presumably keeps the summed auxiliary loss over all layers at the configured
            # weight.
            if lb_loss_weight is not None:
                lb_loss_weight = lb_loss_weight / n_layers
            if z_loss_weight is not None:
                z_loss_weight = z_loss_weight / n_layers
        self.router = router.build(
            d_model,
            num_experts,
            lb_loss_weight=lb_loss_weight,
            lb_loss_granularity=lb_loss_granularity,
            z_loss_weight=z_loss_weight,
            dtype=dtype,
            init_device=init_device,
        )
        self.experts = self._init_parallel_mlp(
            d_model=d_model,
            num_experts=num_experts,
            hidden_size=hidden_size,
            dtype=dtype,
            init_device=init_device,
            cache=cache,
            **kwargs,
        )
        self.shared_mlp = (
            None
            if shared_mlp is None
            else shared_mlp.build(d_model, dtype=dtype, init_device=init_device)
        )
        self._ep_enabled = False

    @property
    def num_experts(self) -> int:
        """The number of experts, as reported by the router."""
        return self.router.num_experts

    @property
    def top_k(self) -> int:
        """The number of experts each token is routed to, as reported by the router."""
        return self.router.top_k

    @property
    def ep_enabled(self) -> bool:
        """Whether :meth:`apply_ep` has been called on this module."""
        return self._ep_enabled

    def warmup_cache(self, max_local_microbatch_size: int):
        """Forward cache warm-up to the expert layer."""
        self.experts.warmup_cache(max_local_microbatch_size)

    def compute_metrics(
        self, reset: bool = True
    ) -> Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]]:
        """
        Return the router's metrics.

        :param reset: Passed through to the router's ``compute_metrics``.
        """
        return self.router.compute_metrics(reset=reset)

    def reset_metrics(self):
        """Reset the router's metrics."""
        self.router.reset_metrics()

    def post_batch(self, dry_run: bool = False):
        """
        Should be called right after the final backward of a complete batch but before the optimizer step.
        """
        self.router.post_batch(dry_run=dry_run)

    @abstractmethod
    def _init_parallel_mlp(
        self,
        *,
        d_model: int,
        num_experts: int,
        hidden_size: int,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
        **kwargs,
    ) -> ParallelMLPBase:
        """
        Build the expert layer. Subclasses must override this.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def forward(
        self,
        x: torch.Tensor,
        *,
        loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
    ) -> torch.Tensor:
        """
        Run the MoE on the input ``x`` of shape ``(*, d_model)``.

        Router auxiliary losses are not returned. When the router produces one, it is
        attached to ``x`` with :func:`attach_auxiliary_loss` before the experts run.

        :param x: The input of shape ``(*, d_model)``.
        :param loss_div_factor: Optional divisor passed through to the router's loss
            computation.

        :returns: The output of the MoE layer.
        """
        expert_weights, expert_indices, batch_size_per_expert, router_aux_loss = self.router(
            x, loss_div_factor=loss_div_factor
        )
        if router_aux_loss is not None:
            x = attach_auxiliary_loss(x, router_aux_loss)

        shared_out: Optional[torch.Tensor] = None
        if self.shared_mlp is not None:
            shared_out = self.shared_mlp(x)

        out = self.experts(x, expert_weights, expert_indices, batch_size_per_expert)

        if shared_out is not None:
            # Mix the outputs as shared / (k + 1) + experts * k / (k + 1).
            shared_out = shared_out / (self.top_k + 1)
            out = shared_out.add(out, alpha=self.top_k / (self.top_k + 1))

        return out

    def apply_pp(self, pp_mesh: DeviceMesh):
        """
        Set the router's process group to the flattened pipeline-stage mesh.

        :param pp_mesh: The pipeline-parallel device mesh.

        :raises RuntimeError: If no world mesh is available.
        """
        world_mesh = get_world_mesh()
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if world_mesh is None:
            raise RuntimeError("a world mesh is required to apply pipeline parallelism")
        stage_mesh = get_pp_stage_mesh(world_mesh, pp_mesh)
        group = flatten_mesh(stage_mesh).get_group()
        self.router.group = group

    def apply_ep(self, ep_mesh: DeviceMesh, **kwargs):
        """
        Apply expert parallelism.
        """
        self.experts.apply_ep(ep_mesh, **kwargs)
        self._ep_enabled = True

    def prepare_experts_for_fsdp(self, **kwargs):
        """
        Should be called before wrapping this module with FSDP2.
        """
        self.experts.prepare_experts_for_fsdp(**kwargs)

    def prepare_experts_for_ddp(self, **kwargs):
        """
        Should be called before wrapping this module with DDP2.
        """
        self.experts.prepare_experts_for_ddp(**kwargs)

    def apply_cp(self, cp_mesh: DeviceMesh):
        """Apply context parallelism to the router."""
        self.router.apply_cp(cp_mesh)

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
        float8_enabled: bool = False,
    ):
        """
        Apply tensor parallelism.

        - The module input is redistributed to ``Shard(1)``.
        - The router and experts apply their own TP plans.
        - The shared MLP (if any) consumes and produces ``Shard(1)``.
        - The module output is redistributed from ``Shard(1)`` to ``output_layout``.

        :param tp_mesh: The tensor-parallel device mesh.
        :param input_layout: Placement of the incoming input, or ``None`` to pass
            ``input_layouts=None`` to :class:`PrepareModuleInput`.
        :param output_layout: Desired output placement. Defaults to ``Replicate()``.
        :param use_local_output: Whether the module output is converted to a local tensor.
        :param float8_enabled: Passed through to the router, experts, and shared MLP.
        """
        # Sequence parallel for the most part.
        parallelize_module(
            self,
            device_mesh=tp_mesh,
            parallelize_plan=PrepareModuleInput(
                input_layouts=None if input_layout is None else (input_layout,),
                desired_input_layouts=(Shard(1),),
                use_local_output=False,
            ),
        )

        # Sequence parallel.
        self.router.apply_tp(tp_mesh, float8_enabled=float8_enabled)

        # Expert parallel.
        self.experts.apply_tp(tp_mesh, float8_enabled=float8_enabled)

        # Model parallel.
        if self.shared_mlp is not None:
            self.shared_mlp.apply_tp(
                tp_mesh,
                input_layout=Shard(1),
                output_layout=Shard(1),
                use_local_output=True,
                float8_enabled=float8_enabled,
            )

        parallelize_module(
            self,
            device_mesh=tp_mesh,
            parallelize_plan=PrepareModuleOutput(
                output_layouts=(Shard(1),),
                desired_output_layouts=(output_layout or Replicate(),),
                use_local_output=use_local_output,
            ),
        )

    def num_flops_per_token(self, seq_len: int) -> int:
        """
        Estimate FLOPs per token.

        The estimate is 6 x the router parameter count (presumably 2 forward + 4 backward
        FLOPs per parameter), plus the shared MLP's and the experts' own estimates.

        :param seq_len: The sequence length, passed to the sub-module estimates.
        """
        router_flops = 6 * sum(p.numel() for p in self.router.parameters())
        shared_mlp_flops = (
            self.shared_mlp.num_flops_per_token(seq_len) if self.shared_mlp is not None else 0
        )
        expert_flops = self.experts.num_flops_per_token(seq_len)
        return router_flops + shared_mlp_flops + expert_flops
class MoE(MoEBase):
    """
    A basic MoE implementation.

    Its expert layer is a :class:`ParallelMLP` over a :class:`MoEMLP`, configured with a
    ``capacity_factor``.
    """

    def __init__(
        self,
        *,
        d_model: int,
        num_experts: int,
        hidden_size: int,
        router: MoERouterConfig,
        shared_mlp: Optional[FeedForwardConfig] = None,
        capacity_factor: float = 1.2,
        init_device: str = "cpu",
        lb_loss_weight: Optional[float] = None,
        lb_loss_granularity: MoELoadBalancingLossGranularity = MoELoadBalancingLossGranularity.local_batch,
        z_loss_weight: Optional[float] = None,
        scale_loss_by_num_layers: bool = True,
        n_layers: int = 1,
        dtype: torch.dtype = torch.float32,
        cache: Optional[BufferCache] = None,
    ):
        """
        Accepts the same options as :class:`MoEBase`, plus ``capacity_factor``, which is
        forwarded to :meth:`_init_parallel_mlp`.
        """
        super().__init__(
            # Shape of the layer.
            d_model=d_model,
            num_experts=num_experts,
            hidden_size=hidden_size,
            # Sub-module configs.
            router=router,
            shared_mlp=shared_mlp,
            # Auxiliary-loss settings.
            lb_loss_weight=lb_loss_weight,
            lb_loss_granularity=lb_loss_granularity,
            z_loss_weight=z_loss_weight,
            scale_loss_by_num_layers=scale_loss_by_num_layers,
            n_layers=n_layers,
            # Construction details.
            init_device=init_device,
            dtype=dtype,
            cache=cache,
            # Forwarded via **kwargs to _init_parallel_mlp.
            capacity_factor=capacity_factor,
        )

    def _init_parallel_mlp(  # type: ignore[override]
        self,
        *,
        d_model: int,
        num_experts: int,
        hidden_size: int,
        capacity_factor: float,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ) -> ParallelMLP:
        """
        Build a :class:`ParallelMLP` wrapping a :class:`MoEMLP`. It uses the router's
        ``top_k`` and the given ``capacity_factor``.
        """
        expert_mlp = MoEMLP(
            d_model=d_model,
            hidden_size=hidden_size,
            num_experts=num_experts,
            dtype=dtype,
            init_device=init_device,
        )
        return ParallelMLP(
            mlp=expert_mlp,
            top_k=self.router.top_k,
            capacity_factor=capacity_factor,
            cache=cache,
        )
class DroplessMoE(MoEBase):
    """
    A dropless MoE implementation.

    Its expert layer is a :class:`ParallelDroplessMLP` over a :class:`DroplessMoEMLP`.
    """

    def _init_parallel_mlp(  # type: ignore[override]
        self,
        *,
        d_model: int,
        num_experts: int,
        hidden_size: int,
        dtype: torch.dtype = torch.float32,
        init_device: str = "cpu",
        cache: Optional[BufferCache] = None,
    ) -> ParallelDroplessMLP:
        """
        Build a :class:`ParallelDroplessMLP` wrapping a :class:`DroplessMoEMLP`. It uses the
        router's ``top_k``.
        """
        dropless_mlp = DroplessMoEMLP(
            d_model=d_model,
            num_experts=num_experts,
            hidden_size=hidden_size,
            dtype=dtype,
            init_device=init_device,
        )
        return ParallelDroplessMLP(mlp=dropless_mlp, top_k=self.router.top_k, cache=cache)