Source code for olmo_core.nn.transformer.model
import logging
from collections import defaultdict
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
Union,
cast,
)
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
from olmo_core.data.utils import get_cumulative_document_lengths
from olmo_core.distributed.parallel import get_pp_mesh
from olmo_core.distributed.utils import hide_from_torch, unhide_from_torch
from olmo_core.doc_utils import beta_feature
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.float8 import Float8Config
from olmo_core.nn.attention.ring import (
RingContextParallelStyle,
UlyssesContextParallelStyle,
)
from olmo_core.utils import get_default_device, mark_dynamic, move_to_device
from ..attention import (
Attention,
FusedAttention,
RingAttentionLoadBalancer,
SequenceMixer,
)
from ..buffer_cache import BufferCache
from ..functional import l2_normalize
from ..layer_norm import LayerNormConfig
from ..lm_head import LMHeadConfig, LMOutputWithLoss
from ..moe import MoEBase
from ..rope import RoPEBuffers, RotaryEmbeddingBase
from ..utils import selective_checkpointing_context_fn
from .block import (
MoETransformerBlock,
NormalizedTransformerBlock,
TransformerBlock,
TransformerBlockBase,
)
from .config import (
TransformerActivationCheckpointingMode,
TransformerBlockConfig,
TransformerDataParallelWrappingStrategy,
resolve_block_configs,
)
from .init import InitMethod
if TYPE_CHECKING:
from olmo_core.train.common import ReduceType
__all__ = [
"Transformer",
"NormalizedTransformer",
"MoETransformer",
"TransformerDataParallelWrappingStrategy",
"TransformerActivationCheckpointingMode",
]
log = logging.getLogger(__name__)
[docs]
class Transformer(nn.Module):
"""
A typical "Llama-style" transformer implementation.
:param d_model: The model dimensionality.
:param vocab_size: The vocab size.
:param n_layers: The number of transformer layers/blocks.
:param block: The block configuration. Can be a single block config or a dict of named blocks.
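:param lm_head: The LM head configuration, used to build the output head.
:param embedding_norm: An optional layer norm config for a norm applied to the embeddings.
:param dtype: The datatype to use for the embeddings. Also the default parameter datatype
when applying FSDP or DDP.
:param init_method: The method used when initializing parameters.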
:param init_device: The device used when initializing parameters.
:param init_seed: The seed used when initializing parameters.
:param init_std: The standard deviation used when initializing parameters.
:param embedding_init_std: The standard deviation used when initializing the embeddings.
:param block_overrides: Overrides for specific blocks. Not supported if `block` is a dict of named blocks.
:param block_pattern: The pattern of blocks to use. Required if `block` is a dict of named blocks.
:param embed_scale: The scale factor for the embeddings.
"""
def __init__(
    self,
    *,
    d_model: int,
    vocab_size: int,
    n_layers: int,
    block: TransformerBlockConfig | dict[str, TransformerBlockConfig],
    lm_head: LMHeadConfig,
    embedding_norm: Optional[LayerNormConfig] = None,
    dtype: torch.dtype = torch.float32,
    init_method: InitMethod = InitMethod.normal,
    init_device: str = "cpu",
    init_seed: int = 0,
    init_std: float = 0.02,
    embedding_init_std: Optional[float] = None,
    block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None,
    block_pattern: Optional[List[str]] = None,
    embed_scale: Optional[float] = None,
):
    """
    Build the embeddings, the optional embedding norm, the transformer blocks, and the LM head,
    and record the initialization settings and parallelism flags.
    """
    super().__init__()

    # One buffer cache is handed to every block and kept on the model.
    cache = BufferCache()

    self.d_model = d_model
    self.vocab_size = vocab_size
    self.n_layers = n_layers
    self.dtype = dtype
    self.embed_scale = embed_scale

    # Sub-modules are registered in this order: embeddings, embedding norm, blocks, LM head.
    self.embeddings = nn.Embedding(vocab_size, d_model, dtype=dtype, device=init_device)
    if embedding_norm is None:
        self.embedding_norm = None
    else:
        self.embedding_norm = embedding_norm.build(d_model, init_device=init_device)

    block_configs: List[TransformerBlockConfig] = resolve_block_configs(
        n_layers=n_layers,
        block=block,
        block_pattern=block_pattern,
        block_overrides=block_overrides,
    )
    self.blocks = nn.ModuleDict()
    for block_idx in range(n_layers):
        built_block = block_configs[block_idx].build(
            d_model=d_model,
            block_idx=block_idx,
            n_layers=n_layers,
            init_device=init_device,
            cache=cache,
        )
        # Blocks are keyed by their index as a string.
        self.blocks[str(block_idx)] = self._validate_block(built_block)

    self.lm_head = lm_head.build(
        d_model=d_model, vocab_size=vocab_size, init_device=init_device
    )

    # Initialization settings, consumed later by 'init_weights()'.
    self.init_device = init_device
    self.init_method = InitMethod(init_method)
    self.init_seed = init_seed
    self.init_std = init_std
    self.embedding_init_std = embedding_init_std

    # Internal state toggled by the various 'apply_*()' methods.
    self._cache = cache
    self._pp_enabled = False
    self._pp_group_size = 1
    self._fp8_enabled = False
    self._precompute_float8_dynamic_scale_for_fsdp = False
    self._compile_enabled = False
    self._device: Optional[torch.device] = None
    self._cp_load_balancer: Optional[RingAttentionLoadBalancer] = None
    self._tp_enabled = False
    self._tp_mesh: Optional[DeviceMesh] = None
    self._fsdp_enabled = False

    # Cache the value of these properties up-front in case the parameters are removed
    # later, like for pipeline parallelism.
    self.num_params
    self.num_non_embedding_params
def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase:
    """
    Hook for subclasses to check a freshly built block. The base implementation accepts
    every block and returns it unchanged.
    """
    return block
def compute_auxiliary_metrics(
    self, reset: bool = True
) -> Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]]:
    """
    Collect auxiliary metrics. The base transformer has none, so this always returns an
    empty dict.

    :param reset: Unused here.
    """
    del reset
    return {}
def reset_auxiliary_metrics(self):
    """
    Reset auxiliary metrics. A no-op for the base transformer.
    """
    return None
@property
def pp_enabled(self) -> bool:
    """
    Whether :meth:`apply_pp()` has been called on this model.
    """
    return self._pp_enabled
@property
def fp8_enabled(self) -> bool:
    """
    Whether :meth:`apply_fp8()` has enabled float8 linear layers on this model.
    """
    return self._fp8_enabled
@property
def tp_enabled(self) -> bool:
    """
    Whether :meth:`apply_tp()` has been called on this model.
    """
    return self._tp_enabled
@property
def fsdp_enabled(self) -> bool:
    """
    Whether :meth:`apply_fsdp()` has been called on this model.
    """
    return self._fsdp_enabled
@property
def is_moe(self) -> bool:
    """
    Whether this is a mixture-of-experts model. Always ``False`` for the base transformer.
    """
    return False
@property
def device(self) -> torch.device:
    """
    The device of the model, resolved on first access and cached afterwards.

    The device of the first non-empty parameter is used. If there is no such parameter,
    the default device is used instead.
    """
    if self._device is None:
        found = next((p.device for p in self.parameters() if p.numel() > 0), None)
        self._device = get_default_device() if found is None else found
    return self._device
@property
def compile_enabled(self) -> bool:
    """
    Whether :meth:`apply_compile()` has been called on this model.
    """
    return self._compile_enabled
[docs]
def get_rope_buffers(
    self, seq_len: int, device: Optional[torch.device] = None
) -> Dict[int, Optional[RoPEBuffers]]:
    """
    Get the RoPE buffers to pass to each layer.

    :param seq_len: The sequence length to get buffers for.
    :param device: The device to get the buffers on. Defaults to the model's device.

    :returns: A mapping from block index to that block's RoPE buffers. The value is ``None``
        for blocks whose sequence mixer is not an attention module or has no RoPE module.
    """
    target_device = self.device if device is None else device

    buffers_by_block = {}
    for block_key, block in self.blocks.items():
        block_buffers = None
        if isinstance(block.attention, (Attention, FusedAttention)):
            rope = cast(Optional[RotaryEmbeddingBase], block.attention.rope)
            if rope is not None:
                block_buffers = rope.get_buffers(seq_len, target_device)
        buffers_by_block[int(block_key)] = block_buffers
    return buffers_by_block
[docs]
@torch.no_grad()
def init_weights(
    self,
    *,
    max_seq_len: Optional[int] = None,
    max_local_microbatch_size: Optional[int] = None,
    device: Optional[torch.device] = None,
    world_mesh: Optional[DeviceMesh] = None,
) -> torch.Generator:
    """
    Materialize the parameters on the target device and initialize the model weights.

    :param max_seq_len: The maximum sequence length expected. This is used
        to warm up the attention backend and RoPE caches.
    :param max_local_microbatch_size: The maximum local (rank) micro-batch size (in tokens)
        expected. This is used to warm-up some MoE cache.
    :param device: The device the local copy of the model will be trained on. Defaults to
        the model's current device.
    :param world_mesh: The world device mesh. Only consulted when pipeline parallelism is
        enabled, in which case the local pipeline rank is added to the seed.

    :returns: The random generator that was used to initialize the weights.
    """
    device = device or self.device
    self.to_empty(device=device)

    # Run the built-in initialization of every module that provides one.
    for module in self.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()  # type: ignore

    # With pipeline parallelism, offset the seed by the local pipeline rank.
    seed = self.init_seed
    if world_mesh is not None and self.pp_enabled:
        seed += get_pp_mesh(world_mesh).get_local_rank()
    generator = torch.Generator(device).manual_seed(seed)

    if self.embeddings is not None:
        if self.embedding_init_std is None:
            embedding_std = self.init_std
        else:
            embedding_std = self.embedding_init_std
        self.init_method.init_embeddings(
            self.embeddings,
            d_model=self.d_model,
            embed_scale=self.embed_scale,
            std=embedding_std,
            generator=generator,
        )

    for block in self.blocks.values():
        # This might fail if it's wrapped.
        # assert isinstance(block, TransformerBlock)
        block = cast(TransformerBlock, block)
        att = cast(SequenceMixer, block.attention)

        # Keyword arguments shared by all of this block's initializers.
        block_init_kwargs = dict(
            d_model=self.d_model,
            block_idx=block.block_idx,
            num_blocks=self.n_layers,
            std=self.init_std,
            generator=generator,
        )

        # Attention weights.
        self.init_method.init_attention(att, **block_init_kwargs)

        # Feed-forward weights.
        if hasattr(block, "feed_forward"):
            self.init_method.init_feed_forward(block.feed_forward, **block_init_kwargs)

        # MoE weights.
        if hasattr(block, "feed_forward_moe"):
            block = cast(MoETransformerBlock, block)
            if max_local_microbatch_size is not None:
                block.feed_forward_moe.warmup_cache(max_local_microbatch_size)
            self.init_method.init_feed_forward_moe(
                block.feed_forward_moe, **block_init_kwargs
            )

        if max_seq_len is not None and isinstance(att, (Attention, FusedAttention)):
            # Warm up attention backend cache.
            if att.backend is not None:
                att.backend.warmup_cache(max_seq_len, device)
            # Warm up RoPE cache.
            if att.rope is not None:
                att.rope.warmup_cache(max_seq_len, device)

    if self.lm_head is not None:
        self.init_method.init_final_w_out(
            self.lm_head.w_out,
            d_model=self.d_model,
            std=self.init_std,
            generator=generator,
        )

    return generator
def _prepare_inputs(
    self,
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    *,
    ignore_index: int = -100,
    loss_reduction: Literal["mean", "sum", "none"] = "mean",
    z_loss_multiplier: Optional[float] = None,
    loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
    return_logits: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> Tuple[
    torch.Tensor,
    Optional[torch.Tensor],
    Dict[str, Any],
    Dict[int, Dict[str, Any]],
    Dict[str, Any],
]:
    """
    Move the forward-pass inputs to the model's device and sort them into keyword arguments
    for the blocks and for the LM head. When a context-parallel load balancer is set, the
    inputs, labels, and RoPE buffers are sharded along the sequence dimension first.

    Recognized entries of ``kwargs`` are ``doc_lens``, ``max_doc_lens``, and
    ``cache_leftpad``; anything else in ``kwargs`` is ignored.

    :returns: A tuple ``(input_ids, labels, all_block_kwargs, per_block_kwargs, lm_head_kwargs)``,
        where ``all_block_kwargs`` are passed to every block and ``per_block_kwargs`` maps a
        block index to extra keyword arguments for just that block.

    :raises NotImplementedError: If ``cache_leftpad`` is given with context parallelism.
    :raises RuntimeError: If context parallelism is used with document lengths and the batch
        has more than one instance.
    :raises ValueError: If document lengths and ``cache_leftpad`` are both given.
    """
    # NOTE: with pipeline parallelism input_ids might actually be an intermediate output,
    # so we have to be careful here.
    B, S = input_ids.shape[:2]

    lm_head_kwargs: Dict[str, Any] = {
        "ignore_index": ignore_index,
        "loss_reduction": loss_reduction,
        "z_loss_multiplier": z_loss_multiplier,
        "return_logits": return_logits,
        "logits_to_keep": logits_to_keep,
    }
    all_block_kwargs: Dict[str, Any] = {}
    per_block_kwargs: Dict[int, Dict[str, Any]] = defaultdict(dict)

    if loss_div_factor is not None:
        loss_div_factor = move_to_device(loss_div_factor, self.device)
        lm_head_kwargs["loss_div_factor"] = loss_div_factor
        all_block_kwargs["loss_div_factor"] = loss_div_factor

    # Prepare document length inputs. 'max_doc_lens' is only consumed when 'doc_lens' is given,
    # and both are required for the document-length inputs to be set.
    max_doc_len: Optional[int] = None
    cu_doc_lens: Optional[torch.Tensor] = None
    cache_leftpad: Optional[torch.Tensor] = kwargs.pop("cache_leftpad", None)
    doc_lens: Optional[torch.Tensor] = kwargs.pop("doc_lens", None)
    if doc_lens is not None:
        max_doc_lens = kwargs.pop("max_doc_lens", None)
        if max_doc_lens is not None:
            max_doc_len = max(max_doc_lens)
            cu_doc_lens = get_cumulative_document_lengths(doc_lens)

    cp_load_balancer = self._cp_load_balancer
    if cp_load_balancer is None:
        input_ids = move_to_device(input_ids, self.device)
        labels = move_to_device(labels, self.device)

        has_doc_inputs = max_doc_len is not None or cu_doc_lens is not None
        if has_doc_inputs and cache_leftpad is not None:
            raise ValueError("max_doc_len/cu_doc_lens and cache_leftpad are mutually exclusive")
        if has_doc_inputs:
            all_block_kwargs["max_doc_len"] = max_doc_len
            all_block_kwargs["cu_doc_lens"] = move_to_device(cu_doc_lens, self.device)
        if cache_leftpad is not None:
            all_block_kwargs["cache_leftpad"] = move_to_device(cache_leftpad, self.device)
    else:
        # Shard inputs and RoPE buffers on sequence dimension if using context parallelism.
        # The four lists below are parallel: tensor, its sequence dim, its pad value, its key.
        inputs = [input_ids]
        seq_dims = [1]
        pad_values: List[Union[int, float]] = [0]
        keys = ["input_ids"]

        # NOTE: initialize buffer(s) on CPU to avoid possible host-device sync when sharding.
        for block_idx, rope_buffers in self.get_rope_buffers(S, torch.device("cpu")).items():
            if rope_buffers is None:
                continue
            # Also shard RoPE buffers based on the context parallelism load balancer.
            for buffer_name in ("pos_sin", "pos_cos", "freqs_cis"):
                buffer = getattr(rope_buffers, buffer_name)
                if buffer is None:
                    continue
                inputs.append(buffer)
                seq_dims.append(0)
                pad_values.append(0.0)
                keys.append(f"block_{block_idx}.{buffer_name}")

        if labels is not None:
            inputs.append(labels)
            seq_dims.append(1)
            pad_values.append(ignore_index)
            keys.append("labels")

        if cache_leftpad is not None:
            raise NotImplementedError("cache_leftpad is not supported with context parallelism")

        if cu_doc_lens is None:
            inputs = cp_load_balancer.batch_shard(
                inputs=inputs,
                seq_dims=seq_dims,
                pad_values=pad_values,
            )
        else:
            # NOTE: Can only shard properly here if 'input_ids' is flat, i.e. a single instance.
            # TODO: (epwalsh) We could just flatten all of the inputs here, but then we risk going
            # beyond the model's maximum sequence length, which might be okay at least
            # with relative positional encodings, but then again if you're resorting to context
            # parallelism you can probably only fit a single instance at a time anyway.
            if B != 1:
                raise RuntimeError(
                    f"Rank micro-batches must consist of a single instance when using "
                    f"context parallelism with intra-document masking (got {B} instances)"
                )
            inputs, additional_inputs = cp_load_balancer.batch_shard_by_document(
                inputs=inputs,
                seq_dims=seq_dims,
                cu_doc_lens=cu_doc_lens,
                pad_values=pad_values,
                length_multiple=16,
            )
            for key, value in additional_inputs.items():
                all_block_kwargs[key] = move_to_device(value, self.device)

        # Route each sharded tensor back to where it belongs based on its key.
        for key, value in zip(keys, inputs):
            if key.startswith("block_"):
                block_key, buffer_key = key.split(".", 1)
                block_idx = int(block_key.replace("block_", ""))
                per_block_kwargs[block_idx][buffer_key] = move_to_device(value, self.device)
            else:
                all_block_kwargs[key] = move_to_device(value, self.device)

        # These two are returned on their own rather than passed to the blocks.
        input_ids = all_block_kwargs.pop("input_ids")
        labels = all_block_kwargs.pop("labels", None)

    if "cu_doc_lens" in all_block_kwargs:
        mark_dynamic(all_block_kwargs["cu_doc_lens"], 0, strict=False)  # type: ignore[arg-type]

    return (
        input_ids,
        labels,
        all_block_kwargs,
        per_block_kwargs,
        lm_head_kwargs,
    )
[docs]
def forward(
    self,
    input_ids: torch.Tensor,
    *,
    labels: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    loss_reduction: Literal["mean", "sum", "none"] = "mean",
    z_loss_multiplier: Optional[float] = None,
    loss_div_factor: Optional[Union[torch.Tensor, float]] = None,
    return_logits: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> Union[torch.Tensor, LMOutputWithLoss]:
    """
    Run the transformer on the token input IDs.

    :param input_ids: The token input IDs, shape ``(batch_size, seq_len)``.
    :param labels: The token labels, shape ``(batch_size, seq_len)``.
    :param ignore_index: The index to ignore in the loss computation. Default is -100.
    :param loss_reduction: The reduction method for the loss. Can be "mean", "sum", or "none".
    :param z_loss_multiplier: Optional multiplier for the z-loss regularization term.
    :param loss_div_factor: Optional divisor for the loss, can be a scalar or tensor.
    :param return_logits: Whether to return logits along with the loss when labels are provided.
    :param logits_to_keep: Number of positions to keep from the end of the sequence (if int),
        or tensor specifying which positions to keep. Default is 0 (keep all).

    :returns: The logits if ``labels`` is ``None`` or the losses if ``labels`` is not ``None``.
        If the model has no LM head, the output of the last block is returned as is.
    """
    prepared = self._prepare_inputs(
        input_ids,
        labels,
        ignore_index=ignore_index,
        loss_reduction=loss_reduction,
        z_loss_multiplier=z_loss_multiplier,
        loss_div_factor=loss_div_factor,
        return_logits=return_logits,
        logits_to_keep=logits_to_keep,
        **kwargs,
    )
    input_ids, labels, all_block_kwargs, per_block_kwargs, lm_head_kwargs = prepared

    # Get embeddings but pass-through for non-existent layers to allow easy
    # pipeline parallel configuration.
    if self.embeddings is None:
        h = input_ids
    else:
        h = self.embeddings(input_ids)
        if self.embed_scale is not None:
            h = h * self.embed_scale
    if self.embedding_norm is not None:
        h = self.embedding_norm(h)

    # Run each block.
    for block_key, block in self.blocks.items():
        block_kwargs = per_block_kwargs.get(int(block_key), {})
        # Mark sizes as dynamic for torch.compile().
        if self.compile_enabled:
            mark_dynamic(h, (0, 1), strict=False)
        h = block(h, **all_block_kwargs, **block_kwargs)

    # Without an LM head (pipeline parallelism) the hidden state is passed through.
    if self.lm_head is None:
        return h

    if self.compile_enabled:
        mark_dynamic(h, (0, 1), strict=False)
        if labels is not None:
            mark_dynamic(labels, (0, 1), strict=False)

    # NOTE: When TP is active we can't pass 'labels=None' or the hook from 'PrepareModuleInput'
    # will throw an exception.
    if labels is not None:
        lm_head_kwargs["labels"] = labels
    return self.lm_head(h, **lm_head_kwargs)
[docs]
def apply_fp8(self, float8_config: Float8Config):
    """
    Use an FP8 recipe on most linear layers. Does nothing if the config is not enabled.

    The LM head's output projection is always excluded, along with any modules the config
    itself asks to ignore.

    :param float8_config: The float8 configuration to apply.
    """
    if not float8_config.enabled:
        return

    excluded = set()
    if self.lm_head is not None:
        excluded.add("lm_head.w_out")
    extra_excluded = float8_config.modules_to_ignore
    if extra_excluded is not None:
        excluded.update(extra_excluded)

    float8_config.apply_float8_linear(self, modules_to_ignore=excluded)

    self._fp8_enabled = True
    self._precompute_float8_dynamic_scale_for_fsdp = (
        float8_config.should_precompute_float8_dynamic_scale_for_fsdp
    )
[docs]
def apply_pp(self, pp_mesh: DeviceMesh):
    """
    Prepare the model for pipeline parallelism after it's been split into stages.

    :param pp_mesh: The pipeline parallel device mesh. Its size is recorded as the
        pipeline group size.
    """
    for raw_block in self.blocks.values():
        cast(TransformerBlockBase, raw_block).apply_pp(pp_mesh)

    self._pp_group_size = pp_mesh.size()
    self._pp_enabled = True
[docs]
def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
    """
    Apply tensor parallelism to the model.

    :param tp_mesh: The tensor parallel device mesh.
    :param float8_enabled: Set this to ``True`` if training with float8 linear layers.
        Defaults to whether FP8 has already been enabled on the model.

    :raises OLMoConfigurationError: If ``float8_enabled`` is ``False`` but FP8 has already
        been enabled on the model.
    """
    fp8_already_enabled = self.fp8_enabled
    if float8_enabled is None:
        float8_enabled = fp8_already_enabled
    elif fp8_already_enabled and not float8_enabled:
        raise OLMoConfigurationError(
            "Got 'float8_enabled=False', but FP8 has already been enabled"
        )

    if self.embeddings is not None:
        # Replicated token IDs in, hidden states sharded on the sequence dimension out.
        embeddings_plan = RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
            use_local_output=False,
        )
        parallelize_module(
            self.embeddings,
            device_mesh=tp_mesh,
            parallelize_plan=embeddings_plan,
        )

    if self.embedding_norm is not None:
        parallelize_module(
            self.embedding_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()
        )

    # Apply tensor/sequence parallelism to every transformer block.
    for raw_block in self.blocks.values():
        cast(TransformerBlockBase, raw_block).apply_tp(
            tp_mesh, input_layout=Shard(1), float8_enabled=float8_enabled
        )

    if self.lm_head is not None:
        self.lm_head.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate()))

    self._tp_mesh = tp_mesh
    self._tp_enabled = True
[docs]
def apply_cp(
    self,
    cp_mesh: DeviceMesh,
    ring: RingContextParallelStyle | None = None,
    uly: UlyssesContextParallelStyle | None = None,
):
    """
    Prepare the model for context-parallelism (CP).

    The load balancer is built from ``ring`` if given, otherwise from ``uly`` if given.
    If neither is given, the model's load balancer is left as it was.

    :param cp_mesh: The CP device mesh.
    :param ring: The ring context parallel style.
    :param uly: The ulysses context parallel style.
    """
    style = ring if ring is not None else uly
    if style is not None:
        self._cp_load_balancer = style.load_balancer.build(cp_mesh)

    for raw_block in self.blocks.values():
        cast(TransformerBlockBase, raw_block).apply_cp(cp_mesh, ring=ring, uly=uly)

    if self.lm_head is not None:
        self.lm_head.apply_cp(cp_mesh)
[docs]
def apply_activation_checkpointing(
    self,
    mode: TransformerActivationCheckpointingMode,
    block_interval: Optional[int] = None,
    modules: Optional[List[str]] = None,
    activation_memory_budget: Optional[float] = None,
):
    """
    Apply activation checkpointing to the model.

    :param mode: Determines how to apply activation checkpointing.
    :param block_interval: Required when :data:`mode` is "selected_blocks". Determines
        which blocks are wrapped.
    :param modules: Required when :data:`mode` is "selected_modules". A list of modules names
        to wrap for activation checkpointing. Globs are supported.
    :param activation_memory_budget: The memory budget for activation checkpointing in the range
        [0, 1]. 0 corresponds to the memory usage when recomputing all activations, and 1
        corresponds to the memory usage when recomputing no activations (which is the default).
        Requires compilation to be enabled.

    :raises ValueError: If the argument required by ``mode`` is missing, or if
        ``activation_memory_budget`` is outside of [0, 1].
    :raises OLMoConfigurationError: If an entire MoE module or MoE block would be wrapped.
    """
    if mode == TransformerActivationCheckpointingMode.budget:
        if activation_memory_budget is None:
            raise ValueError("'activation_memory_budget' is required for 'budget' mode")
        if activation_memory_budget < 0 or activation_memory_budget > 1:
            raise ValueError("'activation_memory_budget' must be in the range [0, 1]")
        # Budget mode only sets a global config value; nothing gets wrapped.
        torch._functorch.config.activation_memory_budget = activation_memory_budget
        return

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper as ptd_checkpoint_wrapper,
    )

    if (
        mode == TransformerActivationCheckpointingMode.selected_blocks
        and block_interval is None
    ):
        raise ValueError("'block_interval' is required for 'selected_blocks' mode")

    if mode == TransformerActivationCheckpointingMode.selected_modules and modules is None:
        raise ValueError("'modules' is required for 'selected_modules' mode")

    # TODO: only preserve RNG state if dropout is active
    preserve_rng_state = False

    if mode == TransformerActivationCheckpointingMode.selected_modules:
        from fnmatch import fnmatch

        assert modules is not None
        wrapped_modules: Set[str] = set()
        for name, module in self.named_modules():
            if not any(fnmatch(name, pattern) for pattern in modules):
                continue

            if isinstance(module, MoEBase):
                raise OLMoConfigurationError(
                    "Wrapping an entire MoE module for activation checkpointing is not supported. "
                    "Please try a finer-grained wrapping strategy."
                )

            # NOTE: have to be careful not to try to wrap submodules of modules that have been
            # wrapped. Check every ancestor, not just the direct parent: a glob can match a
            # module several levels below one that is already wrapped, and its direct parent
            # is then not in 'wrapped_modules' because it was itself skipped.
            if any(name.startswith(f"{wrapped}.") for wrapped in wrapped_modules):
                continue

            parent_name, _, child_name = name.rpartition(".")
            parent = self if not parent_name else self.get_submodule(parent_name)
            module = ptd_checkpoint_wrapper(module, preserve_rng_state=preserve_rng_state)
            parent.register_module(child_name, module)
            log.info(f"Wrapped '{name}' for activation checkpointing")
            wrapped_modules.add(name)
    else:
        # Iterate over a snapshot and re-register every block under its own key. The keys
        # are not guaranteed to be '0..n-1', so the position must not be used as the name.
        # The position is still what 'block_interval' is applied to.
        for block_idx, (block_key, block) in enumerate(list(self.blocks.items())):
            if mode == TransformerActivationCheckpointingMode.selected_blocks:
                assert block_interval is not None
                if block_idx % block_interval == 0:
                    if isinstance(block, MoETransformerBlock):
                        raise OLMoConfigurationError(
                            "Wrapping MoE blocks for activation checkpointing is not supported."
                        )
                    block = ptd_checkpoint_wrapper(block, preserve_rng_state=preserve_rng_state)
            elif mode == TransformerActivationCheckpointingMode.full:
                if isinstance(block, MoETransformerBlock):
                    raise OLMoConfigurationError(
                        "Wrapping MoE blocks for activation checkpointing is not supported."
                    )
                block = ptd_checkpoint_wrapper(block, preserve_rng_state=preserve_rng_state)
            elif mode == TransformerActivationCheckpointingMode.selected_ops:
                block = ptd_checkpoint_wrapper(
                    block,
                    context_fn=selective_checkpointing_context_fn,
                    preserve_rng_state=preserve_rng_state,
                )

            self.blocks.register_module(block_key, block)
[docs]
def apply_compile(self):
    """
    Apply ``torch.compile()`` to each transformer block, which makes compilation efficient
    due to repeated structure. The LM head is compiled as well, if the model has one.

    .. warning::
        This must be called after :meth:`apply_activation_checkpointing()` but
        before :meth:`apply_fsdp()` or :meth:`apply_ddp()`.
    """
    for raw_block in self.blocks.values():
        cast(TransformerBlockBase, raw_block).apply_compile()

    if self.lm_head is not None:
        self.lm_head.compile(fullgraph=False)

    # Register 'max_doc_len' as a dynamic source with the compiler.
    torch.compiler.config.dynamic_sources += "L['kwargs']['max_doc_len'],"

    self._compile_enabled = True
[docs]
def apply_fsdp(
    self,
    dp_mesh: Optional[DeviceMesh] = None,
    param_dtype: Optional[torch.dtype] = None,
    reduce_dtype: torch.dtype = torch.float32,
    pp_enabled: bool = False,
    prefetch_factor: int = 0,
    wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full,
):
    """
    Apply FSDP(2) to the model.

    .. warning::
        This should generally be called last if using any other parallelism strategies or optimizations
        like :meth:`apply_compile()`.

    :param dp_mesh: The model data parallel device mesh.
    :param param_dtype: The data type to materialize params in. Defaults to the current param dtype.
    :param reduce_dtype: The data type for gradient reduction.
    :param pp_enabled: If pipeline parallelism is also enabled.
    :param prefetch_factor: For tuning the prefetch settings. 0 is the default, and higher values result
        in more aggressive prefetching.
    :param wrapping_strategy: The wrapping strategy.
    """
    mp_policy = MixedPrecisionPolicy(
        param_dtype=param_dtype or self.dtype, reduce_dtype=reduce_dtype
    )
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    # For PP, do not reshard after forward to avoid per-microbatch all-gathers,
    # which can be expensive and non-overlapped
    reshard_after_forward = not pp_enabled

    for raw_block in self.blocks.values():
        cast(TransformerBlockBase, raw_block).apply_fsdp(
            dp_mesh=dp_mesh,
            prefetch_factor=prefetch_factor,
            wrapping_strategy=wrapping_strategy,
            reshard_after_forward=reshard_after_forward,
            mp_policy=mp_policy,
        )

    if self.embeddings is not None:
        fully_shard(
            self.embeddings,
            reshard_after_forward=reshard_after_forward,
            **fsdp_config,
        )
        # Embedding params are not needed for backwards computation.
        cast(FSDPModule, self.embeddings).set_unshard_in_backward(False)

    if wrapping_strategy != TransformerDataParallelWrappingStrategy.blocks:
        if self.embedding_norm is not None:
            fully_shard(self.embedding_norm, **fsdp_config)
        if self.lm_head is not None:
            fully_shard(self.lm_head, reshard_after_forward=False, **fsdp_config)

    fully_shard(self, reshard_after_forward=reshard_after_forward, **fsdp_config)

    # Some inputs need to be on CPU initially, but FSDP will move everything to model's
    # device if we don't hide it.
    self.register_forward_pre_hook(_hide_cpu_inputs_from_torch, prepend=True, with_kwargs=True)
    self.register_forward_pre_hook(
        _unhide_cpu_inputs_from_torch, prepend=False, with_kwargs=True
    )

    if prefetch_factor > 0:
        blocks = cast(List[FSDPModule], list(self.blocks.values()))
        last_idx = len(blocks) - 1
        for idx, fsdp_block in enumerate(blocks):
            if idx < last_idx:
                # Prefetch up to 'prefetch_factor' of the blocks that follow this one.
                upcoming = blocks[idx + 1 : idx + 1 + prefetch_factor]
                fsdp_block.set_modules_to_forward_prefetch(upcoming)
            elif isinstance(self.lm_head, FSDPModule):
                # The last block prefetches the LM head, if it was sharded on its own.
                fsdp_block.set_modules_to_forward_prefetch([self.lm_head])

    self._fsdp_enabled = True
[docs]
def apply_ddp(
    self,
    dp_mesh: Optional[DeviceMesh] = None,
    param_dtype: Optional[torch.dtype] = None,
    compile_enabled: bool = False,
    autograd_compile_enabled: bool = False,
):
    """
    Apply DDP to the model.

    :param dp_mesh: The model data parallel device mesh.
    :param param_dtype: The data type to cast the model to before replicating. Defaults to
        the current param dtype.
    :param compile_enabled: If compilation is enabled. Determines whether the dynamo DDP
        optimization setting is changed.
    :param autograd_compile_enabled: If compiled autograd is enabled. Only consulted when
        ``compile_enabled`` is ``True``.
    """
    from torch.distributed._composable.replicate import replicate

    # Cast model explicitly to the specified dtype before applying DDP
    target_dtype = param_dtype or self.dtype
    if target_dtype != self.dtype:
        self.to(dtype=target_dtype)

    # Adapted from
    # https://github.com/pytorch/torchtitan/blob/90c889e972b56b9faadebbb78fc985dedc537ed9/torchtitan/parallelisms/parallelize_llama.py#L328
    if compile_enabled:
        torch._dynamo.config.optimize_ddp = (  # type: ignore
            "python_reducer_without_compiled_forward"
            if autograd_compile_enabled
            else "ddp_optimizer"
        )

    replicate(self, device_mesh=dp_mesh, bucket_cap_mb=100)

    # Some inputs need to be on CPU initially, but DDP will move everything to model's
    # device if we don't hide it.
    self.register_forward_pre_hook(_hide_cpu_inputs_from_torch, prepend=True, with_kwargs=True)
    self.register_forward_pre_hook(
        _unhide_cpu_inputs_from_torch, prepend=False, with_kwargs=True
    )
@cached_property
def num_params(self) -> int:
    """
    The total number of parameter elements in the model. Computed once and then cached.
    """
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total
@property
def num_trainable_params(self) -> int:
    """
    The number of parameter elements that require gradients. Recomputed on every access.
    """
    total = 0
    for param in self.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
@cached_property
def num_non_embedding_params(self) -> int:
    """
    The total number of parameter elements minus those of the embedding weight. Computed
    once and then cached.
    """
    total = self.num_params
    embedding_count = self.embeddings.weight.numel()
    return total - embedding_count
[docs]
def num_flops_per_token(self, seq_len: int) -> int:
    """
    Returns the idealized number of flops per token for the given sequence length. Purposefully
    does not account for wasted flops due to padding, recomputation, etc.

    This is the sum over all blocks, plus the LM head if the model has one.

    :param seq_len: The sequence length.
    """
    blocks = cast(List[TransformerBlockBase], list(self.blocks.values()))
    total_flops = 0
    for transformer_block in blocks:
        total_flops += transformer_block.num_flops_per_token(seq_len)

    if self.lm_head is None:
        return total_flops
    return total_flops + self.lm_head.num_flops_per_token(seq_len)
[docs]
def post_batch(self, dry_run: bool = False):
    """
    Should be called right after the final backward of a complete batch but before the optimizer step.

    A no-op for the base transformer.

    :param dry_run: Unused here.
    """
    del dry_run
    return None
[docs]
def post_optim_step(self):
    """
    Should be called right after an optimizer step.

    Precomputes the float8 dynamic scales for FSDP when FP8 is enabled and the float8
    config asked for it. Otherwise does nothing.
    """
    if not (self.fp8_enabled and self._precompute_float8_dynamic_scale_for_fsdp):
        return

    # Imported lazily so that 'torchao' is only needed when this is actually used.
    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

    precompute_float8_dynamic_scale_for_fsdp(self)
[docs]
@beta_feature
class NormalizedTransformer(Transformer):
    """
    A nGPT transformer implementation, to be used with the :class:`NormalizedTransformerBlock` block
    type.
    """

    def __init__(
        self,
        *,
        d_model: int,
        vocab_size: int,
        n_layers: int,
        block: TransformerBlockConfig | dict[str, TransformerBlockConfig],
        lm_head: LMHeadConfig,
        dtype: torch.dtype = torch.float32,
        init_method: InitMethod = InitMethod.normalized,
        init_device: str = "cpu",
        init_seed: int = 0,
        init_std: float = 0.02,
        embedding_init_std: Optional[float] = None,
        block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None,
        block_pattern: Optional[List[str]] = None,
    ):
        """
        Forward all arguments to :class:`Transformer`, defaulting to the normalized
        initialization method.
        """
        base_kwargs = dict(
            d_model=d_model,
            vocab_size=vocab_size,
            n_layers=n_layers,
            block=block,
            lm_head=lm_head,
            dtype=dtype,
            init_method=init_method,
            init_device=init_device,
            init_seed=init_seed,
            init_std=init_std,
            embedding_init_std=embedding_init_std,
            block_overrides=block_overrides,
            block_pattern=block_pattern,
        )
        super().__init__(**base_kwargs)

    def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase:
        """
        Only accept blocks of type :class:`NormalizedTransformerBlock`.

        :raises OLMoConfigurationError: If the block has any other type.
        """
        if isinstance(block, NormalizedTransformerBlock):
            return block
        raise OLMoConfigurationError(
            f"'{self.__class__.__name__}' requires a '{NormalizedTransformerBlock.__name__}' block"
        )

    @torch.no_grad()
    def init_weights(self, *args, **kwargs) -> torch.Generator:
        """
        Initialize the weights as the base class does, then normalize all matrices.

        :returns: The random generator that was used to initialize the weights.
        """
        generator = super().init_weights(*args, **kwargs)
        self.normalize_matrices()
        return generator

    @torch.no_grad()
    def normalize_matrices(self):
        """
        Normalize the weights in all matrices. This should be called after each optimizer step, which
        the :class:`~olmo_core.train.train_module.TransformerTrainModule` will handle for you.
        """
        if self.embeddings is not None:
            self._normalize_matrix(self.embeddings.weight)

        # Blocks that don't define 'normalize_matrices()' are skipped.
        for block in self.blocks.values():
            if hasattr(block, "normalize_matrices"):
                block.normalize_matrices()  # type: ignore

        if self.lm_head is not None:
            self.lm_head.normalize_matrices()  # type: ignore

    def _normalize_matrix(self, w: torch.Tensor, dim: int = -1):
        """
        L2-normalize ``w`` along ``dim``, in place.
        """
        normalized = l2_normalize(w, dim=dim)
        w.copy_(normalized)

    def apply_tp(
        self,
        tp_mesh: DeviceMesh,
        float8_enabled: Optional[bool] = None,
    ):
        """
        Not supported for this model variant.

        :raises NotImplementedError: Always.
        """
        del tp_mesh, float8_enabled
        raise NotImplementedError(
            "TP is not implemented yet for the normalized transformer variant"
        )

    def apply_compile(self):
        """
        Compile the model as the base class does, and additionally compile
        :meth:`normalize_matrices()`.
        """
        super().apply_compile()
        self.normalize_matrices = torch.compile(self.normalize_matrices)
[docs]
@beta_feature
class MoETransformer(Transformer):
    """
    An MoE transformer implementation, to be used with one of the
    :class:`MoETransformerBlock` block types.
    """

    @property
    def is_moe(self) -> bool:
        """
        Always ``True`` for this model.
        """
        return True

    def compute_auxiliary_metrics(
        self, reset: bool = True
    ) -> Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]]:
        """
        Collect the metrics of every MoE block.

        Each metric is reported once per block under ``"block {idx}/{name}"`` and once
        combined across blocks under ``"{name}"``. Metrics with a mean or sum reduce type
        are added up across blocks, and metrics with a max reduce type take the maximum.

        :param reset: Passed on to each block's ``compute_metrics()``.

        :raises NotImplementedError: If a metric with any other reduce type has to be combined.
        """
        from olmo_core.train.common import ReduceType

        # Change the divisor to 'world_size // pp_group_size'
        mean_offset = self._pp_group_size if self.pp_enabled else 1.0

        out: Dict[str, Tuple[torch.Tensor, Optional["ReduceType"]]] = {}
        for block_idx, block in self.blocks.items():
            if not block.is_moe:
                continue

            moe_block = cast(MoETransformerBlock, block)
            block_prefix = f"block {int(block_idx):02d}"
            for metric_name, (metric_val, reduce_type) in moe_block.compute_metrics(
                reset=reset
            ).items():
                # The per-block entry always holds the value as reported by the block.
                out[f"{block_prefix}/{metric_name}"] = (metric_val, reduce_type)

                if self.pp_enabled and reduce_type == ReduceType.mean:
                    metric_val = metric_val.float() * mean_offset

                if metric_name not in out:
                    out[metric_name] = (metric_val, reduce_type)
                    continue

                running_val = out[metric_name][0]
                if reduce_type in (ReduceType.mean, ReduceType.sum):
                    out[metric_name] = (running_val + metric_val, reduce_type)
                elif reduce_type == ReduceType.max:
                    out[metric_name] = (torch.max(running_val, metric_val), reduce_type)
                else:
                    raise NotImplementedError(reduce_type)
        return out

    def reset_auxiliary_metrics(self):
        """
        Reset the metrics of every MoE block.
        """
        for block in self.blocks.values():
            if block.is_moe:
                cast(MoETransformerBlock, block).reset_metrics()

    def apply_ep(self, ep_mesh: DeviceMesh, **kwargs):
        """
        Apply expert parallelism to every MoE block.

        :param ep_mesh: The expert parallel device mesh.
        :param kwargs: Passed on to each block's ``apply_ep()``.
        """
        for block in self.blocks.values():
            if block.is_moe:
                cast(MoETransformerBlock, block).apply_ep(ep_mesh, **kwargs)

    def prepare_experts_for_fsdp(
        self,
        world_mesh: DeviceMesh,
        param_dtype: Optional[torch.dtype] = None,
        reduce_dtype: torch.dtype = torch.float32,
        pp_enabled: bool = False,
    ):
        """
        Prepare the experts of every MoE block for FSDP.

        :param world_mesh: The world device mesh.
        :param param_dtype: The data type to materialize params in. Defaults to the current param dtype.
        :param reduce_dtype: The data type for gradient reduction.
        :param pp_enabled: If pipeline parallelism is also enabled.
        """
        for block in self.blocks.values():
            if not block.is_moe:
                continue

            moe_block = cast(MoETransformerBlock, block)
            # Only reshard after forward when none of PP, EP, or TP is in use for this block.
            reshard_after_forward = not (
                pp_enabled or moe_block.ep_enabled or moe_block.tp_enabled
            )
            mp_policy = MixedPrecisionPolicy(
                param_dtype=param_dtype or self.dtype, reduce_dtype=reduce_dtype
            )
            moe_block.feed_forward_moe.prepare_experts_for_fsdp(
                world_mesh=world_mesh,
                mp_policy=mp_policy,
                reshard_after_forward=reshard_after_forward,
            )

    def prepare_experts_for_ddp(self, world_mesh: DeviceMesh):
        """
        Prepare the experts of every MoE block for DDP.

        :param world_mesh: The world device mesh.
        """
        for block in self.blocks.values():
            if block.is_moe:
                cast(MoETransformerBlock, block).feed_forward_moe.prepare_experts_for_ddp(
                    world_mesh=world_mesh,
                )

    def post_batch(self, dry_run: bool = False):
        """
        Should be called right after the final backward of a complete batch but before the
        optimizer step. Calls ``post_batch()`` on the MoE module of every MoE block.

        :param dry_run: Passed on to each MoE module's ``post_batch()``.
        """
        for block in self.blocks.values():
            if block.is_moe:
                cast(MoETransformerBlock, block).feed_forward_moe.post_batch(dry_run=dry_run)
def _hide_cpu_inputs_from_torch(m, args, kwargs) -> Optional[Tuple[Any, Dict[str, Any]]]:
    """
    Forward pre-hook that wraps the ``doc_lens`` keyword argument, if present and not ``None``,
    with ``hide_from_torch()``. The keyword arguments are modified in place.

    :returns: The positional and keyword arguments to run the forward pass with.
    """
    del m
    doc_lens = kwargs.get("doc_lens")
    if doc_lens is not None:
        kwargs["doc_lens"] = hide_from_torch(doc_lens)
    return (args, kwargs)
def _unhide_cpu_inputs_from_torch(m, args, kwargs) -> Optional[Tuple[Any, Dict[str, Any]]]:
    """
    Forward pre-hook that unwraps the ``doc_lens`` keyword argument, if present and not
    ``None``, with ``unhide_from_torch()``. The keyword arguments are modified in place.

    :returns: The positional and keyword arguments to run the forward pass with.
    """
    del m
    doc_lens = kwargs.get("doc_lens")
    if doc_lens is not None:
        kwargs["doc_lens"] = unhide_from_torch(doc_lens)
    return (args, kwargs)