from typing import TYPE_CHECKING, Optional, Union, cast
import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor
from olmo_core.config import StrEnum
from olmo_core.distributed.utils import distribute_like, get_local_tensor
if TYPE_CHECKING:
from ..attention import SequenceMixer
from ..feed_forward import FeedForward
from ..moe import MoEBase
def _apply_init(init_fun, x: torch.Tensor, *args, **kwargs):
if not isinstance(x, DTensor):
init_fun(x, *args, **kwargs)
return
# Initialize full version of x locally, then apply init to that.
full_x = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
init_fun(full_x, *args, **kwargs)
full_x = distribute_like(x, full_x)
# Now copy over the corresponding shard of `full_x` into `x`.
get_local_tensor(x).copy_(get_local_tensor(full_x))
def init_linear(
    m: nn.Linear | nn.Conv1d, *, std: float = 0.02, generator: Optional[torch.Generator] = None
):
    """
    Initialize a linear (or 1-D convolution) module in place.

    The weight is drawn from a zero-mean truncated normal distribution with standard
    deviation ``std``, truncated to ``[-3 * std, 3 * std]``. The bias, when the module
    has one, is set to zero.

    :param m: The module whose ``weight`` (and ``bias``, if present) is initialized.
    :param std: Standard deviation of the truncated normal distribution.
    :param generator: Optional random number generator forwarded to the initializer.
    """
    lower, upper = -3 * std, 3 * std
    _apply_init(
        nn.init.trunc_normal_,
        m.weight,
        mean=0.0,
        std=std,
        a=lower,
        b=upper,
        generator=generator,
    )

    bias = m.bias
    if bias is None:
        return
    nn.init.zeros_(bias)
[docs]
class InitMethod(StrEnum):
normal = "normal"
"""
Every linear and embedding layer and initialized from a truncated normal distributed
with standard deviation 0.02.
"""
normalized = "normalized"
"""
Follow the nGPT initialization scheme.
"""
llama = "llama"
"""
Like :data:`normal`, but "output" layers are initialized with a standard deviation that's
dependent on either ``d_model`` or the number of layers.
"""
llama_depth = "llama_depth"
"""
Like :data:`normal`, but "output" layers are initialized with a standard deviation that's
dependent on either ``d_model`` or the layer index.
"""
fan_in = "fan_in"
"""
Per-layer fan-in initialization where each weight matrix is initialized with
``std = 1/√d_in`` where ``d_in`` is the fan-in (number of input features) of that
specific layer. Embeddings use ``std = 1.0`` with normal distribution.
This provides forward-pass variance-preserving initialization adapted to each layer's
specific dimensions, with no depth scaling.
"""
def init_embeddings(
self,
m: nn.Embedding,
*,
d_model: int,
embed_scale: Optional[float] = None,
std: float = 0.02,
generator: Optional[torch.Generator] = None,
):
if self in (InitMethod.llama, InitMethod.llama_depth):
_apply_init(nn.init.normal_, m.weight, generator=generator)
elif self == InitMethod.normalized:
_apply_init(nn.init.normal_, m.weight, generator=generator, std=d_model**-0.5)
elif self == InitMethod.fan_in:
# Fan-in init uses std = 1.0 for embeddings, scaled down by embed_scale if set
emb_std = 1.0 / embed_scale if embed_scale is not None else 1.0
_apply_init(nn.init.normal_, m.weight, generator=generator, std=emb_std)
else:
_apply_init(
nn.init.trunc_normal_,
m.weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
generator=generator,
)
def init_final_w_out(
self,
m: nn.Linear,
*,
d_model: int,
std: float = 0.02,
generator: Optional[torch.Generator] = None,
):
if self in (
InitMethod.llama,
InitMethod.llama_depth,
InitMethod.normalized,
InitMethod.fan_in,
):
std = d_model**-0.5
init_linear(m, std=std, generator=generator)
def init_attention(
self,
m: "SequenceMixer",
*,
d_model: int,
block_idx: int,
num_blocks: int,
std: float = 0.02,
generator: Optional[torch.Generator] = None,
):
m.init_weights(
init_method=self,
d_model=d_model,
block_idx=block_idx,
num_blocks=num_blocks,
std=std,
generator=generator,
)
def init_feed_forward(
self,
m: "FeedForward",
*,
d_model: int,
block_idx: int,
num_blocks: int,
std: float = 0.02,
generator: Optional[torch.Generator] = None,
):
# Compute std for w1 initialization
if self == InitMethod.fan_in:
# For fan_in, w1 uses 1/√d_in where d_in = d_model (ignores base std parameter)
std = m.w1.in_features**-0.5
elif self == InitMethod.normalized:
std = d_model**-0.5
init_linear(m.w1, std=std, generator=generator)
# Compute std for w3 initialization
if self == InitMethod.fan_in:
# For fan_in, w3 uses 1/√d_in where d_in = d_model
std = m.w3.in_features**-0.5
elif self == InitMethod.llama:
std = std / (2 * num_blocks) ** 0.5
elif self == InitMethod.llama_depth:
std = std / (2 * (block_idx + 1)) ** 0.5
init_linear(m.w3, std=std, generator=generator)
# Compute std for w2 initialization
if self == InitMethod.fan_in:
# For fan_in, w2 uses 1/√d_in where d_in = hidden_size
std = m.w2.in_features**-0.5
elif self == InitMethod.normalized:
std = std / (2 * num_blocks) ** 0.5
init_linear(m.w2, std=std, generator=generator)
def init_feed_forward_moe(
self,
m: "MoEBase",
*,
d_model: int,
block_idx: int,
num_blocks: int,
std: float = 0.02,
generator: Optional[torch.Generator] = None,
):
from ..moe import DroplessMoEMLP, MoELinearRouter, MoEMLP
if self == InitMethod.llama:
std = std / (2 * num_blocks) ** 0.5
elif self == InitMethod.llama_depth:
std = std / (2 * (block_idx + 1)) ** 0.5
elif self == InitMethod.fan_in:
# For fan_in, router weight uses 1/√d_model
std = d_model**-0.5
_apply_init(
nn.init.trunc_normal_,
cast(MoELinearRouter, m.router).weight,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
generator=generator,
)
mlp = cast(Union[MoEMLP, DroplessMoEMLP], m.experts.mlp)
# Initialize w1 (maps d_model -> hidden_size, fan-in = d_model)
if self == InitMethod.fan_in:
std = mlp.d_model**-0.5
_apply_init(
nn.init.trunc_normal_,
mlp.w1,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
generator=generator,
)
# Initialize w2 (maps hidden_size -> d_model, fan-in = hidden_size)
if self == InitMethod.fan_in:
std = mlp.hidden_size**-0.5
_apply_init(
nn.init.trunc_normal_,
mlp.w2,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
generator=generator,
)
# Initialize w3 (maps d_model -> hidden_size, fan-in = d_model)
if self == InitMethod.fan_in:
std = mlp.d_model**-0.5
_apply_init(
nn.init.trunc_normal_,
mlp.w3,
mean=0.0,
std=std,
a=-3 * std,
b=3 * std,
generator=generator,
)