train.train_module

class olmo_core.train.train_module.TrainModuleConfig[source]

Bases: Config

class olmo_core.train.train_module.TrainModule[source]

Bases: Stateful

A TrainModule is an abstraction around a Module and Optimizer to provide a unified API for the Trainer that’s flexible enough to handle a variety of training paradigms.

Note

TrainModule implementations are responsible for recording all necessary metrics like the training loss, which can be done by calling record_metric().

Note

See BasicTrainModule for a simple example implementation.

property trainer: Trainer

The Trainer being used.

Warning

This property can only be accessed after the trainer has been attached.

property dp_process_group: ProcessGroup | None

Should return the data parallel process group if it’s anything other than the default process group.

abstract property eval_batch_spec: EvalBatchSpec

Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.

on_attach()[source]

Runs as soon as the Trainer has been attached.

pre_train()[source]

Runs before the training loop starts and right after pre_train() has been called on all callbacks.

abstract state_dict(*, optim=None)[source]

Get the state dict to save or load.

Parameters:

optim (Optional[bool], default: None) – If set to False, optimizer state is not returned in the state dict.

Return type:

Dict[str, Any]

state_dict_to_save(*, optim=None)[source]

Can be overridden if the state dict to save should be different from the state dict to load. By default just returns state_dict().

Parameters:

optim (Optional[bool], default: None) – If set to False, optimizer state is not returned in the state dict.

Return type:

Dict[str, Any]

state_dict_to_load(metadata, *, optim=None)[source]

Can be overridden if the state dict to load should be different from the state dict to save. By default just returns state_dict().

Parameters:

optim (Optional[bool], default: None) – If set to False, optimizer state is not returned in the state dict.

Return type:

Dict[str, Any]

abstract load_state_dict(state_dict)[source]

Load a state dict.

Return type:

None

abstract train_batch(batch, dry_run=False)[source]

Run a forward and backward pass on a training batch.

abstract eval_batch(batch, labels=None)[source]

Run a forward pass on an eval batch.

Return type:

Any

abstract optim_step()[source]

Run an optimizer step.

abstract zero_grads()[source]

Zero out gradients.

abstract num_flops_per_token(seq_len)[source]

Returns the number of flops per token for the given sequence length, or None if flops estimation is not supported.

Return type:

Optional[int]

abstract global_num_flops_in_batch(batch)[source]

Return the total (global) number of flops in the batch, or None if flops estimation is not supported.

Return type:

Optional[int]

record_metric(name, value, reduce_type=None, namespace=None, merge_strategy='warn')[source]

Record a metric. This is simply a convenience method that calls out to olmo_core.train.Trainer.record_metric().

See also

Use record_ce_loss() to record the cross-entropy loss, specifically.

record_ce_loss(value, reduce_type=None)[source]

Record the cross-entropy loss metric specifically.
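The shape of the abstract interface above can be pictured with a small stand-in. This sketch does not import olmo_core: ToyTrainModule, CountingModule and run_step are hypothetical names, only a subset of the documented abstract members is mirrored, and the call order in run_step (zero gradients, forward/backward, optimizer step) is an assumption about how a trainer might drive a train module rather than a description of the Trainer itself.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class ToyTrainModule(ABC):
    """Hypothetical stand-in mirroring part of the documented abstract API."""

    @abstractmethod
    def train_batch(self, batch: Dict[str, Any], dry_run: bool = False) -> None:
        """Run a forward and backward pass on a training batch."""

    @abstractmethod
    def optim_step(self) -> None:
        """Run an optimizer step."""

    @abstractmethod
    def zero_grads(self) -> None:
        """Zero out gradients."""

    def num_flops_per_token(self, seq_len: int) -> Optional[int]:
        # None signals that flops estimation is not supported.
        return None


class CountingModule(ToyTrainModule):
    """Records which methods were called instead of doing real work."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def train_batch(self, batch: Dict[str, Any], dry_run: bool = False) -> None:
        self.calls.append("train_batch")

    def optim_step(self) -> None:
        self.calls.append("optim_step")

    def zero_grads(self) -> None:
        self.calls.append("zero_grads")


def run_step(module: ToyTrainModule, batch: Dict[str, Any]) -> None:
    # Assumed order for a single training step.
    module.zero_grads()
    module.train_batch(batch)
    module.optim_step()
```

A real implementation would also record metrics such as the training loss inside train_batch, as the note at the top of this class requires.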

class olmo_core.train.train_module.EvalBatchSpec(rank_batch_size, batch_size_unit='tokens', max_sequence_length=None, fixed_sequence_length=False)[source]

Bases: object

Defines how eval batches should be sized.

rank_batch_size: int

The size of eval batches per rank.

batch_size_unit: EvalBatchSizeUnit = 'tokens'

The unit for the rank_batch_size.

max_sequence_length: Optional[int] = None

The maximum allowed sequence length.

fixed_sequence_length: bool = False

Whether all batches should have a fixed sequence length of max_sequence_length tokens. If this is True then max_sequence_length must be specified.
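The constraint between fixed_sequence_length and max_sequence_length can be illustrated with a stand-in dataclass. ToyEvalBatchSpec is a hypothetical class that only reuses the documented field names and defaults; the instances_per_rank_batch helper is an assumed interpretation of how the two batch size units relate, not part of the library.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyEvalBatchSpec:
    """Hypothetical stand-in using the same field names as EvalBatchSpec."""

    rank_batch_size: int
    batch_size_unit: str = "tokens"  # "tokens" or "instances"
    max_sequence_length: Optional[int] = None
    fixed_sequence_length: bool = False

    def __post_init__(self) -> None:
        if self.batch_size_unit not in ("tokens", "instances"):
            raise ValueError(f"unknown batch size unit: {self.batch_size_unit!r}")
        # Documented constraint: a fixed length needs a maximum length.
        if self.fixed_sequence_length and self.max_sequence_length is None:
            raise ValueError("fixed_sequence_length=True requires max_sequence_length")

    def instances_per_rank_batch(self, seq_len: int) -> int:
        """Assumed helper: how many instances of length seq_len fit per rank."""
        if self.batch_size_unit == "instances":
            return self.rank_batch_size
        return self.rank_batch_size // seq_len
```

With rank_batch_size=4096 tokens and a fixed sequence length of 1024, this sketch yields 4 instances per rank.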

class olmo_core.train.train_module.EvalBatchSizeUnit(value)[source]

Bases: StrEnum

The different units for defining the size for eval batches.

tokens = 'tokens'

Specify in tokens.

instances = 'instances'

Specify in instances.

class olmo_core.train.train_module.BasicTrainModule(model, optim, rank_microbatch_size, max_grad_norm=None, label_ignore_index=-100)[source]

Bases: TrainModule

A basic TrainModule implementation, mainly used as an example and for testing. For a more practical implementation, see TransformerTrainModule.

Parameters:
  • model (Module) – The model to train.

  • optim (Optimizer) – The corresponding optimizer.

  • rank_microbatch_size (int) –

    The microbatch size in tokens per rank, i.e. the number of tokens to process at a time from each rank.

    Note

    The per-rank batch size (the global batch size divided by the data parallel world size) must be evenly divisible by this value. If this is less than the per-rank batch size, gradient accumulation is used.

  • max_grad_norm (Optional[float], default: None) – Clip gradient norms to this value.
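The divisibility rule for rank_microbatch_size amounts to a small piece of arithmetic. The helper below is a hypothetical sketch (grad_accum_steps is not a library function) that assumes all sizes are given in tokens and that the number of gradient accumulation steps is the per-rank batch size divided by the microbatch size.

```python
def grad_accum_steps(
    global_batch_size: int, rank_microbatch_size: int, dp_world_size: int
) -> int:
    """Number of microbatches each rank processes per optimizer step."""
    if global_batch_size % dp_world_size != 0:
        raise ValueError("global batch size must be divisible by the DP world size")
    rank_batch_size = global_batch_size // dp_world_size
    if rank_batch_size % rank_microbatch_size != 0:
        raise ValueError("rank batch size must be divisible by rank_microbatch_size")
    # A result of 1 means no gradient accumulation is needed.
    return rank_batch_size // rank_microbatch_size
```

For example, a global batch of 1,048,576 tokens over 32 data parallel ranks gives 32,768 tokens per rank, so a microbatch of 8,192 tokens implies 4 accumulation steps.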

property eval_batch_spec: EvalBatchSpec

Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.

on_attach()[source]

Runs as soon as the Trainer has been attached.

state_dict(*, optim=None)[source]

Get the state dict to save or load.

Parameters:

optim (Optional[bool], default: None) – If set to False, optimizer state is not returned in the state dict.

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]

Load a state dict.

Return type:

None

train_batch(batch, dry_run=False)[source]

Run a forward and backward pass on a training batch.

eval_batch(batch, labels=None)[source]

Run a forward pass on an eval batch.

Return type:

Any

optim_step()[source]

Run an optimizer step.

zero_grads()[source]

Zero out gradients.

num_flops_per_token(seq_len)[source]

Returns the number of flops per token for the given sequence length, or None if flops estimation is not supported.

Return type:

Optional[int]

global_num_flops_in_batch(batch)[source]

Return the total (global) number of flops in the batch, or None if flops estimation is not supported.

Return type:

Optional[int]

class olmo_core.train.train_module.TransformerTrainModule(model, optim, rank_microbatch_size, max_sequence_length, compile_model=False, float8_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, autocast_precision=None, max_grad_norm=None, scheduler=None, device=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, label_ignore_index=-100)[source]

Bases: TrainModule

A TrainModule for any Transformer model implementation provided by this library.

Tip

Use the TransformerTrainModuleConfig to easily configure and build TransformerTrainModule instances.

Parameters:
  • model (Transformer) – The Transformer model to train.

  • optim (OptimConfig) – The corresponding optimizer config.

  • rank_microbatch_size (int) –

    The microbatch size in tokens per rank, i.e. the number of tokens to process at a time from each rank.

    Note

    The per-rank batch size (the global batch size divided by the data parallel world size) must be evenly divisible by this value. If this is less than the per-rank batch size, gradient accumulation is used.

  • max_sequence_length (int) – The maximum expected sequence length during training and evaluation.

  • compile_model (bool, default: False) – Whether to compile the model.

  • float8_config (Optional[Float8Config], default: None) – Float8 configuration for the model.

  • dp_config (Optional[TransformerDataParallelConfig], default: None) – Data parallel configuration for the model.

  • tp_config (Optional[TransformerTensorParallelConfig], default: None) – Tensor parallel configuration for the model.

  • cp_config (Optional[TransformerContextParallelConfig], default: None) – Context parallel configuration for the model.

  • ac_config (Optional[TransformerActivationCheckpointingConfig], default: None) – Activation checkpointing configuration for the model.

  • z_loss_multiplier (Optional[float], default: None) – Use Z-loss with this multiplier.

  • autocast_precision (Optional[dtype], default: None) – Enable AMP with this data type.

  • max_grad_norm (Optional[float], default: None) – Clip gradient norms to this value.

  • scheduler (Optional[Scheduler], default: None) – Optional learning rate scheduler for the optimizer.

  • device (Optional[device], default: None) – The device to train on.

  • state_dict_save_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when saving a checkpoint.

  • state_dict_load_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when loading a checkpoint.

  • load_key_mapping (Optional[Dict[str, str]], default: None) – Can be used to load a checkpoint where certain parameters have different names. This dictionary should map current keys to keys in the checkpoint to be loaded.
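The direction of load_key_mapping (current key to checkpoint key) is easy to get backwards, so here is a minimal sketch on plain dictionaries. remap_checkpoint_keys and the parameter names in the usage note are hypothetical; how the library applies the mapping internally is not shown here.

```python
from typing import Any, Dict


def remap_checkpoint_keys(
    checkpoint_state: Dict[str, Any], load_key_mapping: Dict[str, str]
) -> Dict[str, Any]:
    """Re-key a checkpoint state dict so it uses the current parameter names."""
    # The mapping goes current -> checkpoint, so invert it for lookup.
    ckpt_to_current = {ckpt: cur for cur, ckpt in load_key_mapping.items()}
    # Keys without an entry in the mapping keep their name.
    return {ckpt_to_current.get(k, k): v for k, v in checkpoint_state.items()}
```

Given a checkpoint containing "transformer.wte.weight" and a mapping {"embeddings.weight": "transformer.wte.weight"}, the value ends up under "embeddings.weight" while all other keys are left unchanged.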

property dp_process_group: ProcessGroup | None

Should return the data parallel process group if it’s anything other than the default process group.

property eval_batch_spec: EvalBatchSpec

Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.

num_flops_per_token(seq_len)[source]

Returns the number of flops per token for the given sequence length, or None if flops estimation is not supported.

Return type:

Optional[int]

class olmo_core.train.train_module.TransformerTrainModuleConfig(rank_microbatch_size, max_sequence_length, optim, max_grad_norm=None, scheduler=None, compile_model=False, float8_config=None, pp_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, autocast_precision=None, label_ignore_index=-100)[source]

Bases: TrainModuleConfig

A configuration class for building TransformerTrainModule or TransformerPipelineTrainModule instances.

See also

See the TransformerTrainModule and TransformerPipelineTrainModule documentation for a description of the fields.

class olmo_core.train.train_module.TransformerPipelineTrainModule(model, optim, rank_microbatch_size, max_sequence_length, pp_config, compile_model=False, float8_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, autocast_precision=None, max_grad_norm=None, scheduler=None, device=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, label_ignore_index=-100)[source]

Bases: TrainModule

A pipeline-parallel TrainModule for most Transformer model implementations provided by this library.

Tip

Use the TransformerPipelineTrainModuleConfig to easily configure and build TransformerPipelineTrainModule instances.

Parameters:
  • model (Transformer) – The Transformer model to train.

  • optim (OptimConfig) – The corresponding optimizer config.

  • rank_microbatch_size (int) –

    The microbatch size in tokens per rank, i.e. the number of tokens to process at a time from each rank.

    Note

    The per-rank batch size (the global batch size divided by the data parallel world size) must be evenly divisible by this value. If this is less than the per-rank batch size, gradient accumulation is used.

  • max_sequence_length (int) – The maximum expected sequence length during training and evaluation.

  • compile_model (bool, default: False) – Whether to compile the model.

  • float8_config (Optional[Float8Config], default: None) – Float8 configuration for the model.

  • dp_config (Optional[TransformerDataParallelConfig], default: None) – Data parallel configuration for the model.

  • tp_config (Optional[TransformerTensorParallelConfig], default: None) – Tensor parallel configuration for the model.

  • cp_config (Optional[TransformerContextParallelConfig], default: None) – Context parallel configuration for the model.

  • pp_config (TransformerPipelineParallelConfig) – Pipeline parallel configuration for the model.

  • ac_config (Optional[TransformerActivationCheckpointingConfig], default: None) – Activation checkpointing configuration for the model.

  • z_loss_multiplier (Optional[float], default: None) – Use Z-loss with this multiplier.

  • autocast_precision (Optional[dtype], default: None) – Enable AMP with this data type.

  • max_grad_norm (Optional[float], default: None) – Clip gradient norms to this value.

  • scheduler (Optional[Scheduler], default: None) – Optional learning rate scheduler for the optimizer.

  • device (Optional[device], default: None) – The device to train on.

  • state_dict_save_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when saving a checkpoint.

  • state_dict_load_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when loading a checkpoint.

  • load_key_mapping (Optional[Dict[str, str]], default: None) – Can be used to load a checkpoint where certain parameters have different names. This dictionary should map current keys to keys in the checkpoint to be loaded.
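For readers unfamiliar with the z_loss_multiplier parameter, Z-loss is commonly defined as the squared log of the softmax normalizer, scaled by a small multiplier. The pure-Python sketch below follows that common definition; the function name is hypothetical, and averaging over tokens is an assumption, so the library's actual reduction may differ.

```python
import math
from typing import List


def z_loss(logits: List[List[float]], z_loss_multiplier: float) -> float:
    """Mean over tokens of multiplier * logsumexp(logits)^2."""
    total = 0.0
    for row in logits:
        # Numerically stable log-sum-exp over the vocabulary dimension.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z ** 2
    return z_loss_multiplier * total / len(logits)
```

The penalty is zero when the normalizer equals 1 (log Z = 0) and grows as logits drift toward large magnitudes.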

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

property dp_process_group: ProcessGroup | None

Should return the data parallel process group if it’s anything other than the default process group.

property eval_batch_spec: EvalBatchSpec

Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.

num_flops_per_token(seq_len)[source]

Returns the number of flops per token for the given sequence length, or None if flops estimation is not supported.

Return type:

Optional[int]

class olmo_core.train.train_module.TransformerPipelineTrainModuleConfig(rank_microbatch_size, max_sequence_length, optim, max_grad_norm=None, scheduler=None, compile_model=False, float8_config=None, pp_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, autocast_precision=None, label_ignore_index=-100)[source]

Bases: TransformerTrainModuleConfig

Kept for backwards compatibility, but please use TransformerTrainModuleConfig instead.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

class olmo_core.train.train_module.TransformerActivationCheckpointingConfig(mode='full', block_interval=None, modules=None, activation_memory_budget=None)[source]

Bases: Config

Defines the activation checkpointing strategy for a transformer model.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

mode: TransformerActivationCheckpointingMode = 'full'

The activation checkpointing mode.

block_interval: Optional[int] = None

Required when mode is “selected_blocks”. Determines which blocks are wrapped.

modules: Optional[List[str]] = None

Required when mode is “selected_modules”. A list of module names to wrap for activation checkpointing. Globs are supported.
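Glob matching against module names can be sketched with the standard library. This assumes fnmatch-style pattern semantics, which may not match the library exactly, and the module names in the usage note are made up for illustration.

```python
from fnmatch import fnmatchcase
from typing import List


def select_modules(module_names: List[str], patterns: List[str]) -> List[str]:
    """Return the module names that match at least one glob pattern."""
    return [
        name
        for name in module_names
        if any(fnmatchcase(name, pattern) for pattern in patterns)
    ]
```

For instance, the pattern "blocks.*.feed_forward" would select "blocks.0.feed_forward" but not "blocks.0.attention" under these semantics.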

activation_memory_budget: Optional[float] = None

Required when mode is “budget”. Memory budget for activation checkpointing, in the range [0, 1]. 0 = recompute all activations, 1 = recompute none (default). Requires compilation to be enabled.

See https://pytorch.org/blog/activation-checkpointing-techniques/ for more details.
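The per-mode requirements described for these fields can be summarized as a validation routine. check_ac_config is a hypothetical helper written from the field descriptions above; it is not the library's own validation and takes plain strings rather than the enum.

```python
from typing import List, Optional


def check_ac_config(
    mode: str,
    block_interval: Optional[int] = None,
    modules: Optional[List[str]] = None,
    activation_memory_budget: Optional[float] = None,
) -> None:
    """Raise ValueError if a field required by the given mode is missing."""
    if mode == "selected_blocks" and block_interval is None:
        raise ValueError("'selected_blocks' mode requires block_interval")
    if mode == "selected_modules" and modules is None:
        raise ValueError("'selected_modules' mode requires modules")
    if mode == "budget":
        if activation_memory_budget is None:
            raise ValueError("'budget' mode requires activation_memory_budget")
        # 0 = recompute all activations, 1 = recompute none.
        if not 0.0 <= activation_memory_budget <= 1.0:
            raise ValueError("activation_memory_budget must be in [0, 1]")
```

The "full" and "selected_ops" modes have no required fields in this sketch, since none are documented for them.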

class olmo_core.train.train_module.TransformerActivationCheckpointingMode(value)[source]

Bases: StrEnum

An enumeration of the different activation checkpointing modes.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

full = 'full'

Checkpoint every block.

selected_blocks = 'selected_blocks'

Checkpoint only selected blocks.

selected_modules = 'selected_modules'

Checkpoint only selected modules.

selected_ops = 'selected_ops'

Checkpoint only a specific set of operations.

budget = 'budget'

Checkpoint based on a budget.

class olmo_core.train.train_module.TransformerDataParallelConfig(name, param_dtype=None, reduce_dtype='float32', num_replicas=None, shard_degree=None, wrapping_strategy='full', prefetch_factor=0)[source]

Bases: DataParallelConfig

Transformer-specific data parallel config.

wrapping_strategy: TransformerDataParallelWrappingStrategy = 'full'

The wrapping strategy.

class olmo_core.train.train_module.TransformerDataParallelWrappingStrategy(value)[source]

Bases: StrEnum

An enumeration of the different wrapping strategies for the data parallel implementations.

full = 'full'

Wrap each block and the LM head (only applies to FSDP).

blocks = 'blocks'

Like full but the LM head is not wrapped separately (only applies to FSDP).

fine_grained = 'fine_grained'

Wrap certain modules within each block in addition to wrapping each block (only applies to FSDP).

class olmo_core.train.train_module.TransformerExpertParallelConfig(degree)[source]

Bases: ExpertParallelConfig

Transformer-specific expert parallel config.

class olmo_core.train.train_module.TransformerTensorParallelConfig(degree, enable_async=False)[source]

Bases: TensorParallelConfig

Transformer-specific tensor parallel config.

class olmo_core.train.train_module.TransformerContextParallelConfig(degree, ring=None, uly=None)[source]

Bases: ContextParallelConfig

Transformer-specific context parallel config.

class olmo_core.train.train_module.TransformerPipelineParallelConfig(degree, schedule='Interleaved1F1B', style=None, split_points=None)[source]

Bases: PipelineParallelConfig

Transformer-specific pipeline parallel config.

Warning

This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.

split_points: Optional[List[int]] = None

A list of unique, increasing block indices that define how to split the model into stages.

For example, split_points = [0, 2] with a 4-layer model means the model will be split into 3 stages, with the first containing just the embedding, the second containing blocks 0 and 1, and the third containing blocks 2 and 3 and the language modeling head.

If not specified the split points are determined automatically based on the schedule type.
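The split_points example above can be reproduced with a short function. This is only a sketch of the documented behavior: stages_from_split_points and the "embeddings", "block_N" and "lm_head" labels are hypothetical, and the library's actual splitting logic may handle edge cases differently.

```python
from typing import List


def stages_from_split_points(n_layers: int, split_points: List[int]) -> List[List[str]]:
    """List the components assigned to each pipeline stage."""
    if any(b <= a for a, b in zip(split_points, split_points[1:])):
        raise ValueError("split_points must be unique and increasing")
    if any(p < 0 or p > n_layers for p in split_points):
        raise ValueError("split_points must lie in [0, n_layers]")

    # Each split point starts a new stage at that block index.
    bounds = [0, *split_points, n_layers]
    n_stages = len(bounds) - 1
    stages: List[List[str]] = []
    for i in range(n_stages):
        parts: List[str] = []
        if i == 0:
            parts.append("embeddings")  # first stage always holds the embedding
        parts.extend(f"block_{j}" for j in range(bounds[i], bounds[i + 1]))
        if i == n_stages - 1:
            parts.append("lm_head")  # last stage always holds the LM head
        stages.append(parts)
    return stages
```

Calling stages_from_split_points(4, [0, 2]) gives three stages: the embedding alone, then blocks 0 and 1, then blocks 2 and 3 with the language modeling head, matching the example in the text.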