nn.moe

MoE layers.

class olmo_core.nn.moe.MoEBase(*, d_model, num_experts, hidden_size, router, shared_mlp=None, init_device='cpu', lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, n_layers=1, scale_loss_by_num_layers=True, dtype=torch.float32, cache=None, **kwargs)

Bases: Module

Base class for MoE implementations.

post_batch(dry_run=False)

Should be called right after the final backward of a complete batch but before the optimizer step.

forward(x, *, loss_div_factor=None)

Run the MoE on the input x of shape (*, d_model).

Parameters:

x (Tensor) – The input of shape (*, d_model).

Return type:

Tensor

Returns:

The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss.
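To make the data flow concrete, here is a minimal pure-Python sketch of what an MoE layer does conceptually: each token goes to its top-k experts, and the experts' outputs are summed, weighted by the router probabilities. This is not olmo-core's implementation, which works on batched tensors. The helper `moe_forward`, the toy experts, and the choice not to renormalize the top-k weights are all assumptions made for illustration.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Expert = Callable[[Vector], Vector]


def moe_forward(
    tokens: Sequence[Vector],
    router_logits: Sequence[Vector],
    experts: Sequence[Expert],
    top_k: int = 1,
) -> List[Vector]:
    """Send each token to its top-k experts and sum their probability-weighted outputs."""
    outputs: List[Vector] = []
    for x, logits in zip(tokens, router_logits):
        # Numerically stable softmax over this token's expert logits.
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Indices of the k most probable experts (ties keep the lower index first).
        chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
        # Weighted sum of the chosen experts' outputs; weights are NOT renormalized.
        y = [0.0] * len(x)
        for i in chosen:
            y = [acc + probs[i] * o for acc, o in zip(y, experts[i](x))]
        outputs.append(y)
    return outputs


# Two toy "experts" acting on 2-dimensional tokens.
toy_experts: List[Expert] = [
    lambda v: [2.0 * a for a in v],
    lambda v: [-a for a in v],
]
out = moe_forward([[1.0, 2.0]], [[0.0, 0.0]], toy_experts, top_k=2)
print(out)  # [[0.5, 1.0]]
```

With equal logits, both experts get weight 0.5. The output is therefore 0.5 * (2x) + 0.5 * (-x) = 0.5x.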

apply_ep(ep_mesh, **kwargs)

Apply expert parallelism.

prepare_experts_for_fsdp(**kwargs)

Should be called before wrapping this module with FSDP2.

prepare_experts_for_ddp(**kwargs)

Should be called before wrapping this module with DDP2.

class olmo_core.nn.moe.DroplessMoE(*, d_model, num_experts, hidden_size, router, shared_mlp=None, init_device='cpu', lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, n_layers=1, scale_loss_by_num_layers=True, dtype=torch.float32, cache=None, **kwargs)

Bases: MoEBase

A dropless MoE implementation.
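To show what "dropless" means, the following pure-Python sketch contrasts capacity-limited dispatch with dropless dispatch for top-1 routing. The capacity formula, ceil(capacity_factor * num_tokens / num_experts), is a common choice in the MoE literature and is an assumption here; this page does not say how MoEConfig's capacity_factor is applied. `dispatch_top1` is a hypothetical helper.

```python
import math
from typing import Dict, List, Optional


def dispatch_top1(
    assignments: List[int],
    num_experts: int,
    capacity_factor: Optional[float] = None,
) -> Dict[int, List[int]]:
    """Group token indices by their (top-1) expert.

    capacity_factor=None keeps every token (dropless). Otherwise each expert
    accepts at most ceil(capacity_factor * num_tokens / num_experts) tokens, and
    later tokens that overflow an expert are dropped (assumed formula).
    """
    capacity = None
    if capacity_factor is not None:
        capacity = math.ceil(capacity_factor * len(assignments) / num_experts)
    buckets: Dict[int, List[int]] = {e: [] for e in range(num_experts)}
    for token_idx, expert in enumerate(assignments):
        if capacity is None or len(buckets[expert]) < capacity:
            buckets[expert].append(token_idx)
    return buckets


assignments = [0, 0, 0, 1]  # expert 0 is oversubscribed
print(dispatch_top1(assignments, num_experts=2))                       # {0: [0, 1, 2], 1: [3]}
print(dispatch_top1(assignments, num_experts=2, capacity_factor=1.0))  # {0: [0, 1], 1: [3]}
```

With a capacity of 2 per expert, token 2 is silently dropped. A dropless implementation instead handles a variable number of tokens per expert.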

class olmo_core.nn.moe.MoEConfig(name='default', num_experts=1, hidden_size=256, capacity_factor=None, router=<factory>, shared_mlp=None, lb_loss_weight=0.01, lb_loss_granularity='local_batch', z_loss_weight=None, scale_loss_by_num_layers=True, dtype='float32')

Bases: ModuleConfig

name: MoEType = 'default'

The name of the implementation.

build(d_model, *, n_layers=1, init_device='cpu', cache=None)

Build the corresponding module.

Return type:

MoEBase

class olmo_core.nn.moe.MoEType(value)

Bases: StrEnum

An enumeration of the different MoE implementations.

default = 'default'

➡️ MoE

dropless = 'dropless'

➡️ DroplessMoE

class olmo_core.nn.moe.MoEMLP(*, d_model, hidden_size, num_experts, dtype=torch.float32, init_device='cpu')

Bases: MoEMLPBase

A basic expert MLP module with SwiGLU activation.

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x)

Compute the expert outputs.

Parameters:

x (Tensor) – The input of shape (num_local_experts, N, d_model).

Return type:

Tensor
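The following pure-Python sketch shows the SwiGLU computation applied independently to each local expert's slice of an input shaped (num_local_experts, N, d_model). The weight names w1, w2, w3 and their layouts are assumptions chosen for illustration and are not necessarily MoEMLP's actual parameters. The real module uses batched tensor ops rather than Python loops.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(rows x inner) @ (inner x cols) using plain lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def swiglu_experts(x: List[Matrix], w1: List[Matrix], w2: List[Matrix], w3: List[Matrix]) -> List[Matrix]:
    """Apply one SwiGLU MLP per local expert.

    x:      (num_local_experts, N, d_model)
    w1, w3: (num_local_experts, d_model, hidden_size)   (assumed layout)
    w2:     (num_local_experts, hidden_size, d_model)   (assumed layout)
    returns (num_local_experts, N, d_model)
    """
    out: List[Matrix] = []
    for xe, w1e, w2e, w3e in zip(x, w1, w2, w3):
        gate = matmul(xe, w1e)  # (N, hidden_size)
        up = matmul(xe, w3e)    # (N, hidden_size)
        hidden = [[silu(g) * u for g, u in zip(g_row, u_row)] for g_row, u_row in zip(gate, up)]
        out.append(matmul(hidden, w2e))  # (N, d_model)
    return out


# One expert, d_model = hidden_size = 1, all weights 1.0, so y = silu(x) * x.
y = swiglu_experts([[[2.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]])
print(y)  # equals silu(2.0) * 2.0 in the single entry
```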

class olmo_core.nn.moe.DroplessMoEMLP(*, d_model, hidden_size, num_experts, dtype=torch.float32, init_device='cpu')

Bases: MoEMLPBase

A dropless expert MLP module with SwiGLU activation.

forward(x, batch_size_per_expert)

Compute the expert outputs.

Parameters:
  • x (Tensor) – The input of shape (*, d_model).

  • batch_size_per_expert (Tensor) – Specifies how many items/tokens go to each expert. Should be a 1-D LongTensor.

Return type:

Tensor
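The following sketch illustrates how a batch_size_per_expert vector can drive a dropless computation. Tokens are stably sorted by their assigned expert, so each expert owns one contiguous run of rows whose length is its count. Real implementations typically use a grouped GEMM here instead of a Python loop. The helper `grouped_expert_forward` is hypothetical. The sorted-input layout is an assumption consistent with the parameter description above, not a documented contract of DroplessMoEMLP.

```python
from typing import Callable, List, Optional, Sequence

Row = List[float]


def grouped_expert_forward(
    x: Sequence[Row],
    batch_size_per_expert: Sequence[int],
    experts: Sequence[Callable[[Row], Row]],
) -> List[Row]:
    """Rows of x must already be sorted by expert: the first batch_size_per_expert[0]
    rows belong to expert 0, the next batch_size_per_expert[1] to expert 1, and so on."""
    if sum(batch_size_per_expert) != len(x):
        raise ValueError("batch_size_per_expert must account for every row of x")
    out: List[Row] = []
    start = 0
    for expert, count in zip(experts, batch_size_per_expert):
        out.extend(expert(row) for row in x[start:start + count])
        start += count
    return out


expert_indices = [1, 0, 1, 1]  # top-1 expert of each token
tokens = [[1.0], [2.0], [3.0], [4.0]]
num_experts = 2

# Stable sort of token positions by expert, plus the per-expert counts.
order = sorted(range(len(tokens)), key=lambda i: expert_indices[i])
batch_size_per_expert = [expert_indices.count(e) for e in range(num_experts)]  # [1, 3]

experts = [lambda v: [10.0 * a for a in v], lambda v: [-a for a in v]]
y_sorted = grouped_expert_forward([tokens[i] for i in order], batch_size_per_expert, experts)

# Scatter results back into the original token order.
y: List[Optional[Row]] = [None] * len(tokens)
for pos, i in enumerate(order):
    y[i] = y_sorted[pos]
print(y)  # [[-1.0], [20.0], [-3.0], [-4.0]]
```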

class olmo_core.nn.moe.MoERouter(*, d_model, num_experts, top_k=1, jitter_eps=None, normalize_expert_weights=None, uniform_expert_assignment=False, bias_gamma=None, gating_function='softmax', lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, init_device='cpu')

Bases: Module

A base class for MoE router modules.

Parameters:
  • d_model (int) – The model dimensionality (hidden size).

  • num_experts (int) – The total number of experts.

  • top_k (int, default: 1) – The number of experts to assign to each item/token.

  • jitter_eps (Optional[float], default: None) – Controls the amount of noise added to the input during training.

  • normalize_expert_weights (Optional[float], default: None) – The order of the norm (e.g. 2.0 for the L2 norm) used to normalize the expert weights.

  • uniform_expert_assignment (bool, default: False) – Force uniform assignment. Useful for benchmarking.

  • bias_gamma (Optional[float], default: None) – If set to a positive float, expert scores for top-k routing are adjusted by a bias, following the “auxiliary-loss-free load balancing” strategy from DeepSeek-V3. A reasonable value is on the order of 0.0001.
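The following sketch illustrates the auxiliary-loss-free idea as described in the DeepSeek-V3 report. A per-expert bias is added to the scores only when choosing the top-k experts. After a batch, each overloaded expert's bias is lowered by gamma and each underloaded expert's bias is raised by gamma. This follows the paper's description, not olmo-core's code. When and how often olmo-core applies the update is not stated on this page, and both helper names are hypothetical.

```python
from typing import List, Sequence


def biased_top_k(scores: Sequence[float], bias: Sequence[float], k: int) -> List[int]:
    """Choose experts by score + bias. The bias only affects which experts are
    chosen; the gating weights would still come from the unbiased scores."""
    adjusted = [s + b for s, b in zip(scores, bias)]
    return sorted(range(len(scores)), key=lambda i: adjusted[i], reverse=True)[:k]


def update_bias(bias: Sequence[float], tokens_per_expert: Sequence[int], gamma: float) -> List[float]:
    """Lower the bias of overloaded experts and raise it for underloaded ones by gamma."""
    mean_load = sum(tokens_per_expert) / len(tokens_per_expert)
    new_bias: List[float] = []
    for b, load in zip(bias, tokens_per_expert):
        if load > mean_load:
            new_bias.append(b - gamma)
        elif load < mean_load:
            new_bias.append(b + gamma)
        else:
            new_bias.append(b)
    return new_bias


bias = update_bias([0.0, 0.0, 0.0, 0.0], tokens_per_expert=[10, 2, 4, 0], gamma=1e-4)
print(bias)  # [-0.0001, 0.0001, 0.0, 0.0001]
```

Because the bias never enters the gating weights, it steers the load without adding a gradient term to the training loss.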

abstract get_expert_logits(x)

Given the input x of shape (*, d_model), compute the un-normalized expert scores.

Return type:

Tensor

Returns:

The expert logits, shape (*, num_experts).
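These logits are also what a router z-loss (see z_loss_weight on MoEBase) is usually computed from. The sketch below uses the ST-MoE formulation: the mean over tokens of the squared log-sum-exp of the logits. It penalizes large logits even when the resulting softmax is unchanged. Whether olmo-core uses exactly this formula is an assumption, and `router_z_loss` is a hypothetical helper.

```python
import math
from typing import Sequence


def router_z_loss(logits: Sequence[Sequence[float]]) -> float:
    """Mean over tokens of logsumexp(logits) ** 2, following the ST-MoE router z-loss."""
    total = 0.0
    for row in logits:
        m = max(row)  # subtract the max for numerical stability
        lse = m + math.log(sum(math.exp(v - m) for v in row))
        total += lse ** 2
    return total / len(logits)


# Same softmax (logits differ by a constant), but the larger logits cost far more.
small = router_z_loss([[0.0, 1.0]])
large = router_z_loss([[100.0, 101.0]])
print(small < large)  # True
```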

forward(x, *, loss_div_factor=None)

Given the input x of shape (B, S, d_model), compute the expert assignments.

Return type:

Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]

Returns:

The expert weights of shape (B, S, top_k), the expert indices of shape (B, S, top_k), the total number of items routed to each expert, with shape (num_experts,), and optionally the auxiliary losses.
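The following sketch produces the first three outputs from precomputed logits, with B and S flattened into one token axis. It applies softmax gating (the documented default), then top-k selection, then an optional p-norm over the selected weights, and counts the items routed to each expert. Applying normalize_expert_weights after top-k selection is an assumption. Jitter, bias adjustment, and the auxiliary losses are omitted. `route` is a hypothetical helper, not the library's method.

```python
import math
from typing import List, Optional, Sequence, Tuple


def route(
    logits: Sequence[Sequence[float]],
    top_k: int = 1,
    normalize_expert_weights: Optional[float] = None,
) -> Tuple[List[List[float]], List[List[int]], List[int]]:
    """Return (expert weights, expert indices, items routed to each expert).

    logits has shape (num_tokens, num_experts), i.e. B and S flattened together.
    """
    num_experts = len(logits[0])
    weights: List[List[float]] = []
    indices: List[List[int]] = []
    batch_size_per_expert = [0] * num_experts
    for row in logits:
        # Softmax gating.
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        probs = [e / z for e in exps]
        # Top-k experts for this token.
        idx = sorted(range(num_experts), key=lambda i: probs[i], reverse=True)[:top_k]
        w = [probs[i] for i in idx]
        # Optional p-norm normalization of the selected weights (p=1.0 makes them sum to 1).
        if normalize_expert_weights is not None:
            p = normalize_expert_weights
            norm = sum(abs(v) ** p for v in w) ** (1.0 / p)
            w = [v / norm for v in w]
        for i in idx:
            batch_size_per_expert[i] += 1
        weights.append(w)
        indices.append(idx)
    return weights, indices, batch_size_per_expert


weights, indices, counts = route([[2.0, 1.0, 0.0], [0.0, 1.0, 2.0]], top_k=2, normalize_expert_weights=1.0)
print(indices)  # [[0, 1], [2, 1]]
print(counts)   # [1, 2, 1]
```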

class olmo_core.nn.moe.MoELinearRouter(*, dtype=torch.float32, init_device='cpu', **kwargs)

Bases: MoERouter

A simple, learned, linear router.

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

get_expert_logits(x)

Given the input x of shape (*, d_model), compute the un-normalized expert scores.

Return type:

Tensor

Returns:

The expert logits, shape (*, num_experts).
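For a linear router, computing the logits amounts to one matrix product. The sketch assumes a weight of shape (num_experts, d_model) and no bias term, which is a plausible layout for a "simple, learned, linear router" but is not confirmed on this page. `linear_router_logits` is a hypothetical helper.

```python
from typing import List, Sequence


def linear_router_logits(
    x: Sequence[Sequence[float]], weight: Sequence[Sequence[float]]
) -> List[List[float]]:
    """x: (N, d_model), weight: (num_experts, d_model) -> logits: (N, num_experts).

    Computes x @ weight.T with no bias term (assumed).
    """
    return [[sum(a * b for a, b in zip(row, w_e)) for w_e in weight] for row in x]


logits = linear_router_logits([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(logits)  # [[1.0, 2.0, 3.0]]
```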

class olmo_core.nn.moe.MoERouterConfig(name='default', top_k=1, jitter_eps=None, normalize_expert_weights=None, uniform_expert_assignment=False, bias_gamma=None, gating_function='softmax', dtype=None)

Bases: ModuleConfig

A configuration class for easily building any of the different MoE router modules.

name: MoERouterType = 'default'

The name of the implementation.

num_params(d_model, num_experts)

The number of parameters that the module will have once built.

Parameters:
  • d_model (int) – The model dimensionality.

  • num_experts – The number of experts.

Return type:

int
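As a worked example, a bias-free linear router holds one d_model-sized weight row per expert, so it has d_model * num_experts parameters. The bias-free layout and the assumption that every MoE layer builds its own router are illustrative. Both helpers are hypothetical and are not MoERouterConfig.num_params itself.

```python
def linear_router_num_params(d_model: int, num_experts: int) -> int:
    """Parameter count of a bias-free linear router: one d_model-sized row per expert."""
    return d_model * num_experts


def total_router_params(d_model: int, num_experts: int, n_layers: int) -> int:
    """Assumes each of the n_layers MoE layers has its own router."""
    return n_layers * linear_router_num_params(d_model, num_experts)


print(linear_router_num_params(4096, 64))          # 262144
print(total_router_params(4096, 64, n_layers=16))  # 4194304
```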

build(d_model, num_experts, *, lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, dtype=None, init_device='cpu')

Build the corresponding MoE router module.

Parameters:
  • d_model (int) – The model dimensionality.

  • num_experts – The number of experts.

  • init_device (str, default: 'cpu') – The device to initialize the parameters on, e.g. “cpu”, “meta”.

Return type:

MoERouter

class olmo_core.nn.moe.MoERouterType(value)

Bases: StrEnum

An enumeration of the different MoE router implementations.

default = 'default'

➡️ MoELinearRouter

class olmo_core.nn.moe.MoERouterGatingFunction(value)

Bases: StrEnum

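An enumeration of the gating functions the router can apply to the expert logits (selected via the gating_function parameter, which defaults to 'softmax').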
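The sketch below shows what a gating function does to a token's expert logits. Softmax is the documented default. Sigmoid gating, the affinity function used in DeepSeek-V3, is included only for comparison; whether MoERouterGatingFunction offers it is not stated on this page. Both helpers are hypothetical.

```python
import math
from typing import List, Sequence


def softmax_gate(logits: Sequence[float]) -> List[float]:
    """Softmax gating: scores are coupled and sum to 1 across experts."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def sigmoid_gate(logits: Sequence[float]) -> List[float]:
    """Sigmoid gating: each expert's score lies in (0, 1) independently of the others."""
    return [1.0 / (1.0 + math.exp(-v)) for v in logits]


logits = [2.0, 0.0, -2.0]
print(sum(softmax_gate(logits)))  # sums to 1 up to floating-point rounding
print(sigmoid_gate(logits))       # independent per-expert scores
```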

class olmo_core.nn.moe.MoELoadBalancingLossGranularity(value)

Bases: StrEnum

Defines the granularity for the router’s load balancing loss.

local_batch = 'local_batch'

The loss is always computed over the rank-local shard of the batch, ignoring any parallelism strategies used. This is ideal for minimizing the number of dropped tokens for any parallel strategy.

instance = 'instance'

The loss is computed over each instance, taking into account any parallelism strategies used.
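The sketch below shows why the scope of the routing statistics matters. The Switch Transformer load-balancing loss, num_experts * sum_i f_i * P_i, is a stand-in here, because this page does not give olmo-core's exact formula. The two computations below only approximate the idea behind the enum members and do not reproduce their exact semantics. In the example, the routing looks perfectly balanced over the combined batch yet is maximally imbalanced on each rank. Only a rank-local loss penalizes that imbalance, which is what causes dropped tokens. `switch_lb_loss` is a hypothetical helper.

```python
from typing import Sequence


def switch_lb_loss(probs: Sequence[Sequence[float]], top1: Sequence[int]) -> float:
    """Switch-Transformer-style loss: num_experts * sum_i f_i * P_i, where f_i is the
    fraction of tokens routed to expert i and P_i is the mean router probability for i."""
    n, num_experts = len(probs), len(probs[0])
    f = [sum(1 for t in top1 if t == i) / n for i in range(num_experts)]
    p = [sum(row[i] for row in probs) / n for i in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Two ranks: each sends all of its tokens to a different single expert.
shard_a = ([[1.0, 0.0], [1.0, 0.0]], [0, 0])
shard_b = ([[0.0, 1.0], [0.0, 1.0]], [1, 1])

per_shard = [switch_lb_loss(*shard_a), switch_lb_loss(*shard_b)]
combined = switch_lb_loss(shard_a[0] + shard_b[0], shard_a[1] + shard_b[1])
print(per_shard)  # [2.0, 2.0]  (maximal imbalance on each rank)
print(combined)   # 1.0         (the minimum: looks balanced overall)
```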