train.train_module
- class olmo_core.train.train_module.TrainModule
Bases: Stateful
A TrainModule is an abstraction around a Module and Optimizer to provide a unified API for the Trainer that’s flexible enough to handle a variety of training paradigms.
Note
TrainModule implementations are responsible for recording all necessary metrics like the training loss, which can be done by calling record_metric().
Note
See BasicTrainModule for a simple example implementation.
- property trainer: Trainer
The Trainer being used.
Warning
This property can only be accessed after the trainer has been attached.
- property dp_process_group: ProcessGroup | None
Should return the data parallel process group if it’s anything other than the default process group.
- abstract property eval_batch_spec: EvalBatchSpec
Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.
- pre_train()
Runs before the training loop starts and right after pre_train() has been called on all callbacks.
- state_dict_to_save(*, optim=None)
Can be overridden if the state dict to save should be different from the state dict to load. By default, this just returns state_dict().
- state_dict_to_load(metadata, *, optim=None)
Can be overridden if the state dict to load should be different from the state dict to save. By default, this just returns state_dict().
- abstract train_batch(batch, dry_run=False)
Run a forward and backward pass on a training batch.
- abstract num_flops_per_token(seq_len)
Returns the number of flops per token for the given sequence length, or None if flops estimation is not supported.
- abstract global_num_flops_in_batch(batch)
Returns the total (global) number of flops in the batch, or None if flops estimation is not supported.
- record_metric(name, value, reduce_type=None, namespace=None, merge_strategy='warn')
Record a metric. This is simply a convenience method that calls out to olmo_core.train.Trainer.record_metric().
See also
Use record_ce_loss() to record the cross-entropy loss specifically.
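To make the contract above concrete, here is a minimal, framework-free sketch. It does not import olmo_core: RecordingTrainer, SketchTrainModule, ToyTrainModule and attach() are hypothetical stand-ins that only mirror the documented behavior (the trainer property failing before attachment, record_metric() forwarding to the trainer, and the abstract flop-counting methods). The "6 × parameters" flops-per-token estimate is a common rule of thumb for dense transformer training that ignores attention; it is an assumption, not necessarily what OLMo-core computes.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class RecordingTrainer:
    """Hypothetical stand-in for a trainer that just stores recorded metrics."""

    def __init__(self) -> None:
        self.metrics: Dict[str, List[float]] = {}

    def record_metric(self, name: str, value: float) -> None:
        self.metrics.setdefault(name, []).append(value)


class SketchTrainModule(ABC):
    """Hypothetical, simplified mirror of the TrainModule contract."""

    def __init__(self) -> None:
        self._trainer: Optional[RecordingTrainer] = None

    def attach(self, trainer: RecordingTrainer) -> None:
        # Hypothetical hook; in OLMo-core the Trainer handles attachment itself.
        self._trainer = trainer

    @property
    def trainer(self) -> RecordingTrainer:
        # Mirrors the documented warning: only usable after attachment.
        if self._trainer is None:
            raise RuntimeError("the trainer has not been attached yet")
        return self._trainer

    def record_metric(self, name: str, value: float) -> None:
        # Convenience method that simply forwards to the trainer.
        self.trainer.record_metric(name, value)

    @abstractmethod
    def train_batch(self, batch: Dict[str, Any], dry_run: bool = False) -> None: ...

    @abstractmethod
    def num_flops_per_token(self, seq_len: int) -> Optional[int]: ...

    @abstractmethod
    def global_num_flops_in_batch(self, batch: Dict[str, Any]) -> Optional[int]: ...


class ToyTrainModule(SketchTrainModule):
    def __init__(self, num_params: int) -> None:
        super().__init__()
        self.num_params = num_params

    def train_batch(self, batch: Dict[str, Any], dry_run: bool = False) -> None:
        # A real module would run forward and backward passes here; this toy
        # reads precomputed per-token losses from the batch instead.
        # dry_run is accepted for signature parity; its semantics are not modeled.
        losses = batch["token_losses"]
        self.record_metric("train/CE loss", sum(losses) / len(losses))

    def num_flops_per_token(self, seq_len: int) -> Optional[int]:
        # Rule of thumb: ~6 flops per parameter per token (forward + backward),
        # ignoring the attention term, so seq_len is unused here.
        return 6 * self.num_params

    def global_num_flops_in_batch(self, batch: Dict[str, Any]) -> Optional[int]:
        # Single-process sketch: "global" is just this batch's token count.
        total = 0
        for seq in batch["input_ids"]:
            flops = self.num_flops_per_token(len(seq))
            if flops is None:
                return None
            total += flops * len(seq)
        return total


module = ToyTrainModule(num_params=1_000)
trainer = RecordingTrainer()
module.attach(trainer)
module.train_batch({"token_losses": [1.0, 2.0, 3.0]})
print(trainer.metrics)  # {'train/CE loss': [2.0]}
```

Keeping metric recording inside train_batch() matches the note that implementations, not the trainer, are responsible for recording the training loss.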
- class olmo_core.train.train_module.EvalBatchSpec(rank_batch_size, batch_size_unit='tokens', max_sequence_length=None, fixed_sequence_length=False)
Bases: object
Defines how eval batches should be sized.
- batch_size_unit: EvalBatchSizeUnit = 'tokens'
The unit for the rank_batch_size.
- fixed_sequence_length: bool = False
Whether all batches should have a fixed sequence length of max_sequence_length tokens. If this is True, then max_sequence_length must be specified.
- class olmo_core.train.train_module.EvalBatchSizeUnit(value)
Bases: StrEnum
The different units for defining the size of eval batches.
- tokens = 'tokens'
Specify in tokens.
- instances = 'instances'
Specify in instances.
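EvalBatchSpec and EvalBatchSizeUnit work together: the unit decides how rank_batch_size is interpreted, and fixed_sequence_length depends on max_sequence_length. The sketch below uses hypothetical SketchEvalBatchSpec and SketchEvalBatchSizeUnit classes (not the library's own) to show the documented validation rule, plus a hypothetical instances_per_rank() helper. It assumes that a token budget is converted to sequences by integer division, which is an illustration rather than OLMo-core's confirmed behavior.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class SketchEvalBatchSizeUnit(str, Enum):
    # str-mixin Enum as a stand-in for StrEnum (which needs Python 3.11+).
    tokens = "tokens"
    instances = "instances"


@dataclass
class SketchEvalBatchSpec:
    rank_batch_size: int
    batch_size_unit: SketchEvalBatchSizeUnit = SketchEvalBatchSizeUnit.tokens
    max_sequence_length: Optional[int] = None
    fixed_sequence_length: bool = False

    def __post_init__(self) -> None:
        # Accept plain strings such as "instances", like the documented defaults.
        self.batch_size_unit = SketchEvalBatchSizeUnit(self.batch_size_unit)
        # Documented rule: a fixed sequence length needs max_sequence_length.
        if self.fixed_sequence_length and self.max_sequence_length is None:
            raise ValueError("max_sequence_length is required when fixed_sequence_length=True")

    def instances_per_rank(self, seq_len: int) -> int:
        """Hypothetical helper: how many sequences one rank's eval batch holds."""
        if self.batch_size_unit == SketchEvalBatchSizeUnit.instances:
            return self.rank_batch_size
        if self.fixed_sequence_length:
            # Assumption: every sequence then occupies max_sequence_length tokens.
            seq_len = self.max_sequence_length
        return self.rank_batch_size // seq_len


spec = SketchEvalBatchSpec(rank_batch_size=8192, max_sequence_length=2048, fixed_sequence_length=True)
print(spec.instances_per_rank(seq_len=100))  # 4
```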
- class olmo_core.train.train_module.BasicTrainModule(model, optim, rank_microbatch_size, max_grad_norm=None, label_ignore_index=-100)
Bases: TrainModule
A basic TrainModule implementation, mainly used as an example and for testing. For a more practical implementation, see TransformerTrainModule.
- Parameters:
model (Module) – The model to train.
optim (Optimizer) – The corresponding optimizer.
rank_microbatch_size (int) – The microbatch size in tokens per rank, i.e. the number of tokens to process at a time from each rank.
Note
The global batch size divided by the data parallel world size must be an integer multiple of this value. If this is less than that per-rank batch size, gradient accumulation is used.
max_grad_norm (Optional[float], default: None) – Clip gradient norms to this value.
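The note on rank_microbatch_size implies a simple divisibility check. The helper below is a hypothetical sketch, not an OLMo-core function: it computes how many microbatches (gradient accumulation steps) each rank processes per optimizer step, with all sizes measured in tokens.

```python
def num_microbatches(global_batch_size: int, dp_world_size: int, rank_microbatch_size: int) -> int:
    """Hypothetical helper: gradient accumulation steps per rank (all sizes in tokens)."""
    if global_batch_size % dp_world_size != 0:
        raise ValueError("global batch size must be divisible by the data parallel world size")
    rank_batch_size = global_batch_size // dp_world_size
    if rank_batch_size % rank_microbatch_size != 0:
        raise ValueError("rank_microbatch_size must evenly divide the per-rank batch size")
    # 1 means no gradient accumulation; > 1 means accumulating over that many microbatches.
    return rank_batch_size // rank_microbatch_size


# 1,048,576 tokens per step over 8 data parallel ranks = 131,072 tokens per rank.
print(num_microbatches(1_048_576, 8, 32_768))   # 4
print(num_microbatches(1_048_576, 8, 131_072))  # 1 (no accumulation)
```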
- property eval_batch_spec: EvalBatchSpec
Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.
- class olmo_core.train.train_module.TransformerTrainModule(model, optim, rank_microbatch_size, max_sequence_length, compile_model=False, float8_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, autocast_precision=None, max_grad_norm=None, scheduler=None, device=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, label_ignore_index=-100)
Bases: TrainModule
A TrainModule for any Transformer model implementation provided by this library.
Tip
Use the TransformerTrainModuleConfig to easily configure and build TransformerTrainModule instances.
- Parameters:
model (Transformer) – The Transformer model to train.
optim (OptimConfig) – The corresponding optimizer config.
rank_microbatch_size (int) – The microbatch size in tokens per rank, i.e. the number of tokens to process at a time from each rank.
Note
The global batch size divided by the data parallel world size must be an integer multiple of this value. If this is less than that per-rank batch size, gradient accumulation is used.
max_sequence_length (int) – The maximum expected sequence length during training and evaluation.
compile_model (bool, default: False) – Whether to compile the model.
float8_config (Optional[Float8Config], default: None) – Float8 configuration for the model.
dp_config (Optional[TransformerDataParallelConfig], default: None) – Data parallel configuration for the model.
tp_config (Optional[TransformerTensorParallelConfig], default: None) – Tensor parallel configuration for the model.
cp_config (Optional[TransformerContextParallelConfig], default: None) – Context parallel configuration for the model.
ac_config (Optional[TransformerActivationCheckpointingConfig], default: None) – Activation checkpointing configuration for the model.
z_loss_multiplier (Optional[float], default: None) – Use Z-loss with this multiplier.
autocast_precision (Optional[dtype], default: None) – Enable AMP with this data type.
max_grad_norm (Optional[float], default: None) – Clip gradient norms to this value.
scheduler (Optional[Scheduler], default: None) – Optional learning rate scheduler for the optimizer.
device (Optional[device], default: None) – The device to train on.
state_dict_save_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when saving a checkpoint.
state_dict_load_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when loading a checkpoint.
load_key_mapping (Optional[Dict[str, str]], default: None) – Can be used to load a checkpoint where certain parameters have different names. This dictionary should map current keys to keys in the checkpoint to be loaded.
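z_loss_multiplier enables an auxiliary "z-loss" that penalizes large softmax normalizers, as used in PaLM. The plain-Python sketch below shows the commonly used formulation: the multiplier times the squared log of the softmax normalizer Z, averaged over tokens. It is an assumption that OLMo-core's implementation matches this exactly (for instance, in how it averages over tokens or handles label_ignore_index).

```python
import math
from typing import List, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(x))), i.e. log Z for one token's logits.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def z_loss(logits: List[Sequence[float]], z_loss_multiplier: float) -> float:
    """Mean over tokens of multiplier * (log Z)^2."""
    per_token = [logsumexp(row) ** 2 for row in logits]
    return z_loss_multiplier * sum(per_token) / len(per_token)


# Two tokens over a 3-word vocabulary; the result would be added to the CE loss.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
print(z_loss(logits, z_loss_multiplier=1e-4))
```

Because the penalty grows with the square of log Z, it nudges the model to keep logits near a normalized scale, which is the usual motivation for enabling it in large-scale training.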
- property dp_process_group: ProcessGroup | None
Should return the data parallel process group if it’s anything other than the default process group.
- property eval_batch_spec: EvalBatchSpec
Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.
- class olmo_core.train.train_module.TransformerTrainModuleConfig(rank_microbatch_size, max_sequence_length, optim, max_grad_norm=None, scheduler=None, compile_model=False, float8_config=None, pp_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, autocast_precision=None, label_ignore_index=-100)
Bases: TrainModuleConfig
A configuration class for building TransformerTrainModule or TransformerPipelineTrainModule instances.
See also
See the TransformerTrainModule and TransformerPipelineTrainModule documentation for a description of the fields.
- class olmo_core.train.train_module.TransformerPipelineTrainModule(model, optim, rank_microbatch_size, max_sequence_length, pp_config, compile_model=False, float8_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, autocast_precision=None, max_grad_norm=None, scheduler=None, device=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, label_ignore_index=-100)
Bases: TrainModule
A pipeline-parallel TrainModule for most Transformer model implementations provided by this library.
Tip
Use the TransformerPipelineTrainModuleConfig to easily configure and build TransformerPipelineTrainModule instances.
- Parameters:
model (Transformer) – The Transformer model to train.
optim (OptimConfig) – The corresponding optimizer config.
rank_microbatch_size (int) – The microbatch size in tokens per rank, i.e. the number of tokens to process at a time from each rank.
Note
The global batch size divided by the data parallel world size must be an integer multiple of this value. If this is less than that per-rank batch size, gradient accumulation is used.
max_sequence_length (int) – The maximum expected sequence length during training and evaluation.
compile_model (bool, default: False) – Whether to compile the model.
float8_config (Optional[Float8Config], default: None) – Float8 configuration for the model.
dp_config (Optional[TransformerDataParallelConfig], default: None) – Data parallel configuration for the model.
tp_config (Optional[TransformerTensorParallelConfig], default: None) – Tensor parallel configuration for the model.
cp_config (Optional[TransformerContextParallelConfig], default: None) – Context parallel configuration for the model.
pp_config (TransformerPipelineParallelConfig) – Pipeline parallel configuration for the model.
ac_config (Optional[TransformerActivationCheckpointingConfig], default: None) – Activation checkpointing configuration for the model.
z_loss_multiplier (Optional[float], default: None) – Use Z-loss with this multiplier.
autocast_precision (Optional[dtype], default: None) – Enable AMP with this data type.
max_grad_norm (Optional[float], default: None) – Clip gradient norms to this value.
scheduler (Optional[Scheduler], default: None) – Optional learning rate scheduler for the optimizer.
device (Optional[device], default: None) – The device to train on.
state_dict_save_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when saving a checkpoint.
state_dict_load_opts (Optional[StateDictOptions], default: None) – Can be used to override the state dict options used when loading a checkpoint.
load_key_mapping (Optional[Dict[str, str]], default: None) – Can be used to load a checkpoint where certain parameters have different names. This dictionary should map current keys to keys in the checkpoint to be loaded.
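For both TransformerTrainModule and TransformerPipelineTrainModule, load_key_mapping goes from current parameter names to checkpoint names. The hypothetical remap_checkpoint() helper below illustrates that direction with plain dictionaries; the parameter names are made up, and it is an assumption that unmapped keys are looked up under their own name.

```python
from typing import Any, Dict, Iterable


def remap_checkpoint(
    checkpoint_state: Dict[str, Any],
    load_key_mapping: Dict[str, str],
    current_keys: Iterable[str],
) -> Dict[str, Any]:
    """Build a state dict keyed by the current model's names (hypothetical helper)."""
    remapped: Dict[str, Any] = {}
    for key in current_keys:
        # Keys without a mapping entry are assumed to keep the same name in the checkpoint.
        ckpt_key = load_key_mapping.get(key, key)
        if ckpt_key not in checkpoint_state:
            raise KeyError(f"{ckpt_key!r} (for {key!r}) not found in checkpoint")
        remapped[key] = checkpoint_state[ckpt_key]
    return remapped


checkpoint = {"blocks.0.attn.w_q": "tensor-A", "lm_head.w_out": "tensor-B"}
mapping = {"blocks.0.attention.w_q": "blocks.0.attn.w_q"}  # current -> checkpoint
print(remap_checkpoint(checkpoint, mapping, ["blocks.0.attention.w_q", "lm_head.w_out"]))
```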
Warning
This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.
- property dp_process_group: ProcessGroup | None
Should return the data parallel process group if it’s anything other than the default process group.
- property eval_batch_spec: EvalBatchSpec
Should return the desired specification for evaluation batches. This is used for in-loop evaluation, for example, to determine how to build eval batches in a way that will work for the particular TrainModule.
- class olmo_core.train.train_module.TransformerPipelineTrainModuleConfig(rank_microbatch_size, max_sequence_length, optim, max_grad_norm=None, scheduler=None, compile_model=False, float8_config=None, pp_config=None, dp_config=None, tp_config=None, cp_config=None, ep_config=None, ac_config=None, z_loss_multiplier=None, state_dict_save_opts=None, state_dict_load_opts=None, load_key_mapping=None, autocast_precision=None, label_ignore_index=-100)
Bases: TransformerTrainModuleConfig
Kept for backwards compatibility, but please use TransformerTrainModuleConfig instead.
Warning
This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.
- class olmo_core.train.train_module.TransformerActivationCheckpointingConfig(mode='full', block_interval=None, modules=None, activation_memory_budget=None)
Bases: Config
Defines the activation checkpointing strategy for a transformer model.
Warning
This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.
- mode: TransformerActivationCheckpointingMode = 'full'
The activation checkpointing mode.
- block_interval: Optional[int] = None
Required when mode is "selected_blocks". Determines which blocks are wrapped.
- modules: Optional[List[str]] = None
Required when mode is "selected_modules". A list of module names to wrap for activation checkpointing. Globs are supported.
- activation_memory_budget: Optional[float] = None
Required when mode is "budget". Memory budget for activation checkpointing, in the range [0, 1]: 0 means recompute all activations, and 1 means recompute none (the default). Requires compilation to be enabled. See https://pytorch.org/blog/activation-checkpointing-techniques/ for more details.
- class olmo_core.train.train_module.TransformerActivationCheckpointingMode(value)
Bases: StrEnum
An enumeration of the different activation checkpointing modes.
Warning
This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.
- full = 'full'
Checkpoint every block.
- selected_blocks = 'selected_blocks'
Checkpoint only selected blocks.
- selected_modules = 'selected_modules'
Checkpoint only selected modules.
- selected_ops = 'selected_ops'
Checkpoint only a specific set of operations.
- budget = 'budget'
Checkpoint based on a budget.
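The activation checkpointing config has mode-dependent required fields. The sketch below encodes the rules stated above in a hypothetical SketchACConfig and ACMode (not the library classes), and uses Python's fnmatch for the glob matching of module names. That OLMo-core matches globs with fnmatch-style semantics is an assumption, and the "requires compilation" condition for the budget mode is not checked because it depends on compile_model.

```python
import fnmatch
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ACMode(str, Enum):
    full = "full"
    selected_blocks = "selected_blocks"
    selected_modules = "selected_modules"
    selected_ops = "selected_ops"
    budget = "budget"


@dataclass
class SketchACConfig:
    mode: ACMode = ACMode.full
    block_interval: Optional[int] = None
    modules: Optional[List[str]] = None
    activation_memory_budget: Optional[float] = None

    def __post_init__(self) -> None:
        self.mode = ACMode(self.mode)  # accept plain strings such as "budget"
        if self.mode == ACMode.selected_blocks and self.block_interval is None:
            raise ValueError("block_interval is required for mode 'selected_blocks'")
        if self.mode == ACMode.selected_modules and not self.modules:
            raise ValueError("modules is required for mode 'selected_modules'")
        if self.mode == ACMode.budget:
            # Compilation must also be enabled for this mode; not checked here.
            budget = self.activation_memory_budget
            if budget is None or not 0.0 <= budget <= 1.0:
                raise ValueError("activation_memory_budget must be set and lie in [0, 1]")


def match_modules(module_names: List[str], patterns: List[str]) -> List[str]:
    """Return the module names that match any of the glob patterns."""
    return [n for n in module_names if any(fnmatch.fnmatchcase(n, p) for p in patterns)]


cfg = SketchACConfig(mode="selected_modules", modules=["blocks.*.attention"])
names = ["blocks.0.attention", "blocks.0.feed_forward", "blocks.1.attention"]
print(match_modules(names, cfg.modules))  # ['blocks.0.attention', 'blocks.1.attention']
```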
- class olmo_core.train.train_module.TransformerDataParallelConfig(name, param_dtype=None, reduce_dtype='float32', num_replicas=None, shard_degree=None, wrapping_strategy='full', prefetch_factor=0)
Bases: DataParallelConfig
Transformer-specific data parallel config.
- wrapping_strategy: TransformerDataParallelWrappingStrategy = 'full'
The wrapping strategy.
- class olmo_core.train.train_module.TransformerDataParallelWrappingStrategy(value)
Bases: StrEnum
An enumeration of the different wrapping strategies for the data parallel implementations.
- full = 'full'
Wrap each block and the LM head (only applies to FSDP).
- blocks = 'blocks'
Like full, but the LM head is not wrapped separately (only applies to FSDP).
- fine_grained = 'fine_grained'
Wrap certain modules within each block in addition to wrapping each block (only applies to FSDP).
- class olmo_core.train.train_module.TransformerExpertParallelConfig(degree)
Bases: ExpertParallelConfig
Transformer-specific expert parallel config.
- class olmo_core.train.train_module.TransformerTensorParallelConfig(degree, enable_async=False)
Bases: TensorParallelConfig
Transformer-specific tensor parallel config.
- class olmo_core.train.train_module.TransformerContextParallelConfig(degree, ring=None, uly=None)
Bases: ContextParallelConfig
Transformer-specific context parallel config.
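Each of these parallelism configs takes a degree. Under the usual multi-dimensional parallelism convention (for example, a PyTorch DeviceMesh with one dimension per strategy), the tensor, context and pipeline parallel degrees multiply together and must divide the world size, and the remaining ranks form the data parallel dimension. The hypothetical helper below shows only that arithmetic. It is an assumption that OLMo-core lays out its mesh exactly this way, and expert parallelism is left out because how its ranks overlap with the other dimensions is implementation-specific.

```python
def data_parallel_degree(world_size: int, tp_degree: int = 1, cp_degree: int = 1, pp_degree: int = 1) -> int:
    """Hypothetical helper: ranks left for data parallelism after TP, CP and PP."""
    model_parallel = tp_degree * cp_degree * pp_degree
    if world_size % model_parallel != 0:
        raise ValueError(f"world size {world_size} is not divisible by tp*cp*pp = {model_parallel}")
    return world_size // model_parallel


# 64 GPUs with TP=4, CP=2, PP=2 leaves 64 / 16 = 4 data parallel replicas/shards.
print(data_parallel_degree(64, tp_degree=4, cp_degree=2, pp_degree=2))  # 4
```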
- class olmo_core.train.train_module.TransformerPipelineParallelConfig(degree, schedule='Interleaved1F1B', style=None, split_points=None)
Bases: PipelineParallelConfig
Transformer-specific pipeline parallel config.
Warning
This is a beta feature! The API is subject to change even with minor and patch releases. If you choose to use this feature please read the CHANGELOG before upgrading your version of this library.
- split_points: Optional[List[int]] = None
A list of unique, increasing block indices that define how to split the model into stages.
For example, split_points = [0, 2] with a 4-layer model means the model will be split into 3 stages: the first contains just the embedding, the second contains blocks 0 and 1, and the third contains blocks 2 and 3 and the language modeling head.
If not specified, the split points are determined automatically based on the schedule type.
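The split_points example above can be reproduced with a short sketch. The helper and the module names (embeddings, blocks.N, lm_head) are hypothetical, and this does not model the automatic split-point selection; it only shows how a list of split points maps onto stages as described: stage i holds the blocks between consecutive split points, the first stage also owns the embedding, and the last stage also owns the LM head.

```python
from typing import List


def pipeline_stages(num_blocks: int, split_points: List[int]) -> List[List[str]]:
    """Group hypothetical module names into pipeline stages from split_points."""
    if sorted(set(split_points)) != list(split_points):
        raise ValueError("split_points must be unique and increasing")
    if any(p < 0 or p > num_blocks for p in split_points):
        raise ValueError("split_points must lie in [0, num_blocks]")
    # Stage i holds the blocks in [bounds[i], bounds[i + 1]).
    bounds = [0, *split_points, num_blocks]
    stages = [
        [f"blocks.{b}" for b in range(bounds[i], bounds[i + 1])]
        for i in range(len(bounds) - 1)
    ]
    stages[0].insert(0, "embeddings")  # the first stage also owns the embedding
    stages[-1].append("lm_head")       # the last stage also owns the LM head
    return stages


print(pipeline_stages(4, [0, 2]))
# [['embeddings'], ['blocks.0', 'blocks.1'], ['blocks.2', 'blocks.3', 'lm_head']]
```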