nn.moe
MoE layers.
- class olmo_core.nn.moe.MoEBase(*, d_model, num_experts, hidden_size, router, shared_mlp=None, init_device='cpu', lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, n_layers=1, scale_loss_by_num_layers=True, dtype=torch.float32, cache=None, **kwargs)
Bases: Module
Base class for MoE implementations.
- post_batch(dry_run=False)
Should be called right after the final backward of a complete batch but before the optimizer step.
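The call order that post_batch expects can be sketched with stand-in objects. FakeMoE, FakeOptimizer and train_batch are hypothetical and only record the order of calls; in a real training loop the model's MoE layers and a PyTorch optimizer would take their place, and what post_batch does internally is not shown here.

```python
# Hypothetical stand-ins: they only record call order and are not
# olmo_core's MoEBase or a real optimizer.
class FakeMoE:
    def __init__(self, log):
        self.log = log

    def post_batch(self, dry_run=False):
        self.log.append("post_batch")


class FakeOptimizer:
    def __init__(self, log):
        self.log = log

    def step(self):
        self.log.append("optimizer.step")

    def zero_grad(self):
        self.log.append("zero_grad")


def train_batch(micro_batches, moe, optimizer, log):
    """Process one complete batch that is split into micro-batches (gradient accumulation)."""
    for mb in micro_batches:
        log.append(f"backward[{mb}]")  # stands in for loss.backward() on one micro-batch
    # Right after the final backward of the batch, before the optimizer step:
    moe.post_batch()
    optimizer.step()
    optimizer.zero_grad()


log = []
train_batch(["mb0", "mb1"], FakeMoE(log), FakeOptimizer(log), log)
print(log)
```

Calling post_batch once per complete batch, rather than once per micro-batch, matches the documented "after the final backward" requirement.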
- class olmo_core.nn.moe.DroplessMoE(*, d_model, num_experts, hidden_size, router, shared_mlp=None, init_device='cpu', lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, n_layers=1, scale_loss_by_num_layers=True, dtype=torch.float32, cache=None, **kwargs)
Bases: MoEBase
A dropless MoE implementation.
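What "dropless" avoids can be illustrated by counting overflowing assignments under a fixed per-expert capacity. MoEConfig exposes a capacity_factor, but the capacity formula below is the common Switch Transformer / GShard convention, assumed here rather than taken from olmo_core; count_dropped is a hypothetical helper.

```python
import math
from collections import Counter


def count_dropped(assignments, num_experts, capacity_factor=None):
    """Count token-to-expert assignments that overflow a per-expert capacity.

    assignments: one expert index per (token, k) routing slot.
    capacity_factor=None models a dropless layer: no capacity, nothing is dropped.
    """
    if capacity_factor is None:
        return 0
    # Assumed convention: capacity = ceil(capacity_factor * slots / num_experts).
    capacity = math.ceil(capacity_factor * len(assignments) / num_experts)
    load = Counter(assignments)
    return sum(max(0, load[e] - capacity) for e in range(num_experts))


# 8 routing slots over 4 experts, heavily skewed toward expert 0.
routes = [0, 0, 0, 0, 0, 1, 2, 3]
print(count_dropped(routes, num_experts=4, capacity_factor=1.0))  # capacity 2 per expert
print(count_dropped(routes, num_experts=4))  # dropless
```

With a capacity factor of 1.0, expert 0 receives 5 assignments but can hold only 2, so 3 are dropped; the dropless path processes all of them.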
- class olmo_core.nn.moe.MoEConfig(name='default', num_experts=1, hidden_size=256, capacity_factor=None, router=<factory>, shared_mlp=None, lb_loss_weight=0.01, lb_loss_granularity='local_batch', z_loss_weight=None, scale_loss_by_num_layers=True, dtype='float32')
Bases: ModuleConfig
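MoEConfig accepts a z_loss_weight without describing the loss it scales. A plausible reading is the router z-loss from ST-MoE, which penalizes large router logits; the sketch below computes that loss for a small batch, and both the formula's use here and the router_z_loss helper are assumptions.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def router_z_loss(logits):
    """ST-MoE-style router z-loss: mean over tokens of logsumexp(logits) ** 2.

    logits: one list of num_experts raw router scores per token.
    """
    return sum(logsumexp(row) ** 2 for row in logits) / len(logits)


logits = [[0.0, 0.0], [math.log(3.0), 0.0]]
loss = router_z_loss(logits)
print(loss)  # (log 2)^2 and (log 4)^2, averaged
```

Under this reading, the term added to the training loss would be z_loss_weight * loss.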
- class olmo_core.nn.moe.MoEType(value)
Bases: StrEnum
An enumeration of the different MoE implementations.
- default = 'default' ➡️ MoE
- dropless = 'dropless' ➡️ DroplessMoE
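The enum-to-class mapping above can be sketched as a lookup keyed by enum members, which is one plausible way a config's name field could be resolved. The enum below mirrors the documented values with a str mixin, which behaves like a StrEnum; the MoE and DroplessMoE classes and the resolve helper are hypothetical placeholders, not the olmo_core implementations.

```python
from enum import Enum


class MoEType(str, Enum):
    """Mirror of the documented values (str mixin, so members compare equal to strings)."""
    default = "default"
    dropless = "dropless"


# Hypothetical placeholders standing in for the real implementations.
class MoE: ...
class DroplessMoE: ...


IMPLEMENTATIONS = {MoEType.default: MoE, MoEType.dropless: DroplessMoE}


def resolve(name):
    # MoEType(name) looks the member up by value; unknown names raise ValueError.
    return IMPLEMENTATIONS[MoEType(name)]


print(resolve("dropless").__name__)
```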
- class olmo_core.nn.moe.MoEMLP(*, d_model, hidden_size, num_experts, dtype=torch.float32, init_device='cpu')
Bases: MoEMLPBase
A basic expert MLP module with SwiGLU activation.
- class olmo_core.nn.moe.DroplessMoEMLP(*, d_model, hidden_size, num_experts, dtype=torch.float32, init_device='cpu')
Bases: MoEMLPBase
A dropless expert MLP module with SwiGLU activation.
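Both expert MLPs use SwiGLU. For a single expert, the standard SwiGLU feed-forward computes w2 @ (silu(w1 @ x) * (w3 @ x)), projecting from d_model to hidden_size and back. The weight names w1, w2, w3 and the pure-Python helpers are assumptions for illustration; the modules above hold weights for all num_experts experts and operate on tensors.

```python
import math


def silu(v):
    # SiLU / swish: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def matvec(w, x):
    # w: list of rows (out_dim x in_dim); x: vector of length in_dim.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def swiglu_expert(x, w1, w2, w3):
    """One expert: w2 @ (silu(w1 @ x) * (w3 @ x)).

    w1, w3: hidden_size x d_model (gate and up projections, names assumed).
    w2: d_model x hidden_size (down projection).
    """
    gate = [silu(v) for v in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])


# d_model = 2, hidden_size = 3
x = [1.0, -1.0]
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w3 = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
w2 = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
y = swiglu_expert(x, w1, w2, w3)
print(y)
```

The output has length d_model, so each expert maps a token back to the model dimensionality.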
- class olmo_core.nn.moe.MoERouter(*, d_model, num_experts, top_k=1, jitter_eps=None, normalize_expert_weights=None, uniform_expert_assignment=False, bias_gamma=None, gating_function='softmax', lb_loss_weight=None, lb_loss_granularity='local_batch', z_loss_weight=None, init_device='cpu')
Bases: Module
A base class for MoE router modules.
- Parameters:
  - d_model (int) – The model dimensionality (hidden size).
  - num_experts (int) – The total number of experts.
  - top_k (int, default: 1) – The number of experts to assign to each item/token.
  - jitter_eps (Optional[float], default: None) – Controls the amount of noise added to the input during training.
  - normalize_expert_weights (Optional[float], default: None) – The type of norm (e.g. 2.0 for the L2 norm) used to normalize the expert weights.
  - uniform_expert_assignment (bool, default: False) – Force uniform assignment. Useful for benchmarking.
  - bias_gamma (Optional[float], default: None) – If set to a positive float, expert scores for top-k routing are adjusted by a bias following the “auxiliary-loss-free load balancing” strategy from DeepSeek-V3. A reasonable value is on the order of 0.0001.
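The bias_gamma strategy can be sketched from the DeepSeek-V3 description: a per-expert bias is added to the scores only when choosing the top-k experts, and after each batch the bias of an overloaded expert is lowered by gamma while that of an underloaded expert is raised by gamma. The helpers below are hypothetical, and where olmo_core applies the update is not stated here. gamma is set to 0.1 only to make the effect visible; the documented suggestion is on the order of 0.0001.

```python
def select_top_k(scores, bias, k):
    """Pick the top-k experts by score + bias. The bias affects selection only;
    in DeepSeek-V3 the gating weights still come from the unbiased scores."""
    order = sorted(range(len(scores)), key=lambda e: scores[e] + bias[e], reverse=True)
    return order[:k]


def update_bias(bias, tokens_per_expert, gamma):
    """Lower the bias of overloaded experts and raise it for underloaded ones."""
    mean_load = sum(tokens_per_expert) / len(tokens_per_expert)
    new_bias = []
    for b, load in zip(bias, tokens_per_expert):
        if load > mean_load:
            new_bias.append(b - gamma)
        elif load < mean_load:
            new_bias.append(b + gamma)
        else:
            new_bias.append(b)
    return new_bias


scores = [0.5, 0.45, 0.42, 0.1]
bias = [0.0, 0.0, 0.0, 0.0]
print(select_top_k(scores, bias, 2))  # before any update

# Expert 0 was overloaded in the last batch; experts 2 and 3 were underused.
new_bias = update_bias(bias, tokens_per_expert=[6, 2, 0, 0], gamma=0.1)
print(new_bias)
print(select_top_k(scores, new_bias, 2))  # selection shifts toward expert 2
```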
- abstract get_expert_logits(x)
Given the input x of shape (*, d_model), compute the un-normalized expert scores.
- Returns:
The expert logits, a tensor of shape (*, num_experts).
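The logits returned by get_expert_logits are un-normalized, so a gating step turns them into expert choices and weights. The sketch below assumes the default gating_function='softmax' and reads normalize_expert_weights as a p-norm applied to the selected top-k weights, consistent with the "2.0 for the L2 norm" description; the route helper works on one token and is hypothetical.

```python
import math


def route(logits, top_k=1, normalize_expert_weights=None):
    """Turn one token's expert logits into (expert_indices, expert_weights)."""
    # Softmax gating, shifted by the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top_k most probable experts.
    idx = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:top_k]
    weights = [probs[e] for e in idx]
    if normalize_expert_weights is not None:
        # Rescale the selected weights to unit p-norm (p = normalize_expert_weights).
        p = normalize_expert_weights
        norm = sum(abs(w) ** p for w in weights) ** (1.0 / p)
        weights = [w / norm for w in weights]
    return idx, weights


logits = [2.0, 1.0, 0.0, -1.0]
idx, w = route(logits, top_k=2, normalize_expert_weights=1.0)
print(idx, w)  # experts 0 and 1, weights summing to 1
```

With p = 1.0 the top-k weights sum to one; with p = 2.0 their Euclidean norm is one instead.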
- class olmo_core.nn.moe.MoELinearRouter(*, dtype=torch.float32, init_device='cpu', **kwargs)
Bases: MoERouter
A simple, learned, linear router.
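A learned linear router can be pictured as a projection from d_model to num_experts, which produces logits of the shape get_expert_logits documents. The sketch assumes a bias-free projection with a weight of shape (num_experts, d_model); linear_router_logits is a hypothetical helper, not the module's forward method.

```python
def linear_router_logits(x, weight):
    """logits[e] = sum_j weight[e][j] * x[j].

    x: one token of length d_model.
    weight: num_experts rows of length d_model (bias-free projection, assumed).
    """
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


# d_model = 3, num_experts = 2
weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
logits = linear_router_logits([2.0, 4.0, 6.0], weight)
print(logits)
```

Under this assumed layout, the weight holds d_model * num_experts learned entries.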
- class olmo_core.nn.moe.MoERouterConfig(name='default', top_k=1, jitter_eps=None, normalize_expert_weights=None, uniform_expert_assignment=False, bias_gamma=None, gating_function='softmax', dtype=None)
Bases: ModuleConfig
A configuration class for easily building any of the different MoE router modules.
- name: MoERouterType = 'default'
The name of the implementation.
- num_params(d_model, num_experts)
The number of parameters that the module will have once built.
- class olmo_core.nn.moe.MoERouterType(value)
Bases: StrEnum
An enumeration of the different MoE router implementations.
- default = 'default'
- class olmo_core.nn.moe.MoELoadBalancingLossGranularity(value)
Bases: StrEnum
Defines the granularity for the router’s load balancing loss.
- local_batch = 'local_batch'
The loss is always computed over the rank-local shard of the batch, ignoring any parallelism strategies used. This is ideal for minimizing the number of dropped tokens for any parallel strategy.
- instance = 'instance'
The loss is computed over each instance, taking into account any parallelism strategies used.
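The difference between the two granularities can be seen with a toy batch. The sketch assumes a Switch Transformer-style auxiliary loss, num_experts * sum_e f_e * P_e, where f_e is the fraction of tokens routed to expert e and P_e is its mean router probability; it is not necessarily the exact loss olmo_core computes. It also leaves out the parallelism details that the two options above account for, and all helpers are hypothetical.

```python
def lb_loss(routes, num_experts, probs):
    """Switch-style auxiliary loss: num_experts * sum_e f_e * P_e.

    routes: the chosen expert for each token.
    probs: per-token router probabilities (num_experts values each).
    """
    n = len(routes)
    f = [routes.count(e) / n for e in range(num_experts)]
    p = [sum(row[e] for row in probs) / n for e in range(num_experts)]
    return num_experts * sum(fe * pe for fe, pe in zip(f, p))


def local_batch_loss(instances, num_experts):
    # Pool every token in the rank-local shard, then compute a single loss.
    routes = [r for inst in instances for r in inst["routes"]]
    probs = [p for inst in instances for p in inst["probs"]]
    return lb_loss(routes, num_experts, probs)


def per_instance_loss(instances, num_experts):
    # Compute the loss separately for each instance, then average.
    losses = [lb_loss(i["routes"], num_experts, i["probs"]) for i in instances]
    return sum(losses) / len(losses)


# Two instances that each send every token to a single expert.
instances = [
    {"routes": [0, 0], "probs": [[1.0, 0.0], [1.0, 0.0]]},
    {"routes": [1, 1], "probs": [[0.0, 1.0], [0.0, 1.0]]},
]
print(local_batch_loss(instances, num_experts=2))   # pooled: looks balanced
print(per_instance_loss(instances, num_experts=2))  # per instance: imbalanced
```

Pooling the local batch hides the per-instance skew (loss 1.0, the balanced minimum), while the per-instance view penalizes it (loss 2.0).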