train
This module implements a highly efficient, yet flexible, language model trainer.
Features
- Async checkpointing (optional) with local or remote checkpoint directories.
- Supports any type of parallel strategy.
- Async metric logging, with support for custom metrics, even those that need to be reduced across ranks.
- Flexible Callback system for extending/modifying the training loop behavior.
- A powerful set of built-in callbacks (olmo_core.train.callbacks).
Overview
Call prepare_training_environment() at the top of your training script, then construct your trainer using a TrainerConfig. Finally, call Trainer.fit(), and clean up at the end of your script by calling teardown_training_environment().
For example:
from olmo_core.train import prepare_training_environment, teardown_training_environment

if __name__ == "__main__":
    prepare_training_environment()

    # Build train module, data loader, and trainer_config (a TrainerConfig)...

    # Build trainer.
    trainer = trainer_config.build(train_module, data_loader)

    # Run the trainer.
    trainer.fit()

    # Clean up.
    teardown_training_environment()
See the "train a language model" example for a complete, runnable demonstration.
API Reference
- olmo_core.train.prepare_training_environment(*, seed=None, backend='cpu:gloo,cuda:nccl', timeout=datetime.timedelta(seconds=900), log_filter_type=None, shared_filesystem=None)
Prepare the environment for training, including setting up the distributed process group for distributed training.
Tip
Internally this calls init_distributed(), which also calls torch.cuda.set_device() for backends that support CUDA, otherwise torch.set_default_device(). So there’s no need to call those separately.
Important
This should be invoked at the very start of your training script, such as at the beginning of the if __name__ == "__main__": ... block.
- Parameters:
  seed (Optional[int], default: None) – The seed to initialize RNG states with.
  backend (Optional[str], default: 'cpu:gloo,cuda:nccl') – The distributed backend to use, if any. Set to None for non-distributed training. When using NCCL, ideally you should also include a CPU-only backend such as GLOO (as the default does), which allows the trainer to run async checkpointing and bookkeeping collectives on the CPU backend without blocking training operations.
  timeout (timedelta, default: datetime.timedelta(seconds=900)) – The timeout for initializing the distributed process group.
  log_filter_type (Optional[LogFilterType], default: None) – Determines which ranks are allowed to emit log messages below the WARNING level. You can also configure this through the env var LOG_FILTER_TYPE. If neither is set, this defaults to “rank0_only”.
    Note
    All ranks will always emit messages at the WARNING level or higher.
  shared_filesystem (Optional[bool], default: None) – Should be set to True if the checkpoint and working directories are in a local filesystem shared by all ranks, e.g. on an NFS drive.
- olmo_core.train.teardown_training_environment()
To be run at the end of training. Tears down all distributed process groups.
- class olmo_core.train.TrainerConfig(save_folder, work_dir=None, load_path=None, load_strategy='if_available', load_optim_state=None, load_trainer_state=None, checkpointer=<factory>, device=None, save_overwrite=False, max_duration=<factory>, cancel_check_interval=25, hard_stop=None, metrics_collect_interval=5, callbacks=<factory>, async_bookkeeping=None, bookkeeping_soft_timeout=30, no_checkpoints=False, no_evals=False, steps_to_skip=None)
Bases: Config
A configuration class for easily building Trainer instances.
See also
See the Trainer documentation for a description of the fields.
- with_callback(name, callback)
  Return a new trainer config with an additional callback.
- with_callbacks(callbacks)
  Return a new trainer config with additional callbacks.
- with_recommended_evals(tokenizer, sequence_length, cluster, task_set='full', eval_interval=10000, lazy_load=False)
  Return a new trainer config with added callbacks for downstream evaluation and validation sets.
- build(train_module, data_loader, *, dp_process_group=None, checkpointer_pg=None)
  Build the corresponding trainer.
  - Parameters:
    train_module (TrainModule) – The train module to fit.
    data_loader (DataLoaderBase) – The data loader to train on.
    dp_process_group (Optional[ProcessGroup], default: None) – The data parallel process group. Defaults to olmo_core.train.train_module.TrainModule.dp_process_group.
- class olmo_core.train.Trainer(work_dir, train_module, data_loader, device, save_folder, checkpointer, callbacks, max_duration, save_overwrite=False, load_path=None, load_strategy='if_available', load_trainer_state=None, load_optim_state=None, metrics_collect_interval=5, dp_process_group=None, global_step=0, global_train_tokens_seen=0, global_train_petaflops=0.0, epoch=1, cancel_check_interval=25, hard_stop=None, async_bookkeeping=None, bookkeeping_soft_timeout=30, no_checkpoints=False, no_evals=False, steps_to_skip=None, _metrics=<factory>, _metrics_reduce_type=<factory>, _canceled=False, _cancel_reason=None, _canceling_rank=None, _error=None, _rank_batch_size=None, _multi_thread_pool=None, _single_thread_pool=None, _bookkeeping_queue=<factory>, _bookkeeping_pg=None, _blocking_ephemeral_checkpoints=<factory>, _checkpoint_loaded=False, _metrics_consistent=None)
Bases: object
Language model trainer.
Tip
Use TrainerConfig instead of constructing this class directly.
- work_dir: Path
  A local working directory to use for temporary files needed during training. Files added to this working directory can be persisted to the save_folder via persist_working_file().
  Note
  When constructing your trainer through a TrainerConfig, this will default to the save_folder if it’s a local directory.
- train_module: TrainModule
  The train module to fit.
- data_loader: DataLoaderBase
  The train data loader.
- device: device
  The default device to use. It should match the device the model is on and be appropriate for the main distributed backend.
- save_folder: str
  The folder to save all checkpoints to. This can be a local directory (if using a shared filesystem) or a URL.
  Warning
  If you try to use a local directory without a globally shared filesystem across all ranks, you will get an error.
- checkpointer: Checkpointer
  The checkpointer. This is a wrapper around the functionality in olmo_core.distributed.checkpoint, which means you can use unshard_checkpoint() to unshard the model and optimizer state from a train checkpoint after the fact.
- max_duration: Duration
  The duration to train for.
  Important
  The total number of training steps must be known ahead of time for various reasons, such as setting a learning rate schedule. Therefore, if your data loader’s number of batches (total_batches) is unknown ahead of time, you must set the max_duration in terms of tokens or steps, not epochs.
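To make the constraint above concrete, here is a minimal sketch of how a duration might be resolved into a step count. SimpleDuration, Unit, and resolve_max_steps are hypothetical stand-ins for Duration and DurationUnit, not the library's implementation; the sketch assumes every step consumes a fixed global batch of tokens and rounds token budgets up to whole steps.

```python
import math
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Unit(str, Enum):
    """Hypothetical stand-in for DurationUnit."""

    steps = "steps"
    epochs = "epochs"
    tokens = "tokens"


@dataclass
class SimpleDuration:
    """Hypothetical stand-in for Duration."""

    value: int
    unit: Unit


def resolve_max_steps(
    duration: SimpleDuration, global_batch_tokens: int, total_batches: Optional[int] = None
) -> int:
    """Convert a duration into a number of training steps.

    Assumes each step consumes exactly `global_batch_tokens` tokens and that
    `total_batches` is the number of batches per epoch (None if unknown).
    """
    if duration.unit == Unit.steps:
        return duration.value
    if duration.unit == Unit.tokens:
        # Round up so the run sees at least `value` tokens.
        return math.ceil(duration.value / global_batch_tokens)
    if duration.unit == Unit.epochs:
        if total_batches is None:
            raise ValueError("an epoch-based duration requires a known number of batches per epoch")
        return duration.value * total_batches
    raise ValueError(f"unknown duration unit: {duration.unit}")


# 2M tokens with 4 sequences x 4096 tokens per step -> ceil(2_000_000 / 16_384) = 123 steps.
steps = resolve_max_steps(SimpleDuration(2_000_000, Unit.tokens), global_batch_tokens=4 * 4096)
```

Without total_batches only the steps and tokens branches can produce a number, which is why an epoch-based max_duration needs a data loader whose length is known up front.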
- save_overwrite: bool = False
  Whether to overwrite existing files/checkpoints in the save_folder.
- load_path: Union[Path, PathLike, str, None] = None
  An alternative location to load a checkpoint from if no checkpoint is found in the current save_folder. This can be set to a checkpoint path or the path to a folder of checkpoints, such as the save_folder from a different run.
- load_strategy: LoadStrategy = 'if_available'
  The strategy for loading a checkpoint prior to training.
- load_trainer_state: Optional[bool] = None
  Whether to load the trainer state (including data loader state). If None, this will attempt to load the trainer state if it exists in the checkpoint, but will not error if it doesn’t.
- load_optim_state: Optional[bool] = None
  Whether to load the optimizer state. If None, this will attempt to load the optimizer state if it exists in the checkpoint, but will not error if it doesn’t.
- metrics_collect_interval: int = 5
  How often (in steps) to collect, reduce, and pass on metrics to the Callback.log_metrics method on callbacks.
  Note
  Regardless of what this is set to, the Callback.log_metrics methods are still called with the metrics for every single step, but the calls will be delayed according to this value. For example, if this is set to 5, then every 5 steps the metrics from the past 5 steps are collected and reduced together, then passed on to Callback.log_metrics all at once.
  Tip
  Increasing this can improve throughput since logging metrics always requires a host-device sync.
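The deferred delivery described above can be sketched as a small buffer. MetricBuffer and its methods are hypothetical and not part of olmo_core. In the real trainer the flush is where the host-device sync and the cross-rank reduction would take place, which is why flushing less often helps throughput.

```python
from collections import defaultdict
from typing import Callable, Dict


class MetricBuffer:
    """Hypothetical sketch: record metrics every step, hand them off every N steps."""

    def __init__(self, collect_interval: int, log_metrics: Callable[[int, Dict[str, float]], None]):
        self.collect_interval = collect_interval
        self.log_metrics = log_metrics  # called once per buffered step, in step order
        self._pending: Dict[int, Dict[str, float]] = defaultdict(dict)

    def record(self, step: int, name: str, value: float) -> None:
        self._pending[step][name] = value

    def end_step(self, step: int) -> None:
        if step % self.collect_interval == 0:
            self.flush()

    def flush(self) -> None:
        # A real trainer would sync with the device and reduce across ranks here.
        for step in sorted(self._pending):
            self.log_metrics(step, self._pending[step])
        self._pending.clear()


logged = []
buffer = MetricBuffer(collect_interval=5, log_metrics=lambda step, m: logged.append((step, m)))
for step in range(1, 8):  # global_step is 1-based
    buffer.record(step, "train/loss", 10.0 / step)
    buffer.end_step(step)
# Steps 1-5 were delivered together at step 5; steps 6 and 7 are still buffered.
```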
- dp_process_group: Optional[ProcessGroup] = None
  The distributed process group for all data parallel ranks.
- global_step: int = 0
  The current step/batch. It is 1-based, though it’s initialized to 0 before the first step. This does not reset after an epoch.
- cancel_check_interval: int = 25
  The interval (in steps) at which to check if the run is canceled. Checking requires distributed comms, but if you’ve configured a separate CPU-only backend (like “gloo”), this shouldn’t impact training throughput.
- hard_stop: Optional[Duration] = None
  Set a hard stopping point for the trainer. This is useful for ablations when you don’t want to do a complete training run but also don’t want to change max_duration, so as not to affect the learning rate schedule.
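A sketch of why hard_stop leaves the schedule alone, assuming a plain cosine decay (cosine_lr is a hypothetical helper, not olmo_core's scheduler API). The learning rate at each step depends on max_steps, so stopping early reproduces the first part of the full run exactly, whereas shrinking max_duration would compress the whole schedule.

```python
import math


def cosine_lr(step: int, max_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Hypothetical cosine decay from base_lr (step 0) to min_lr (step max_steps)."""
    progress = min(step, max_steps) / max_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


max_steps = 1000  # derived from max_duration
hard_stop = 200   # where the ablation stops

# Stopping at hard_stop: steps 1..200 see exactly the LRs of the full 1000-step run.
full_run = [cosine_lr(s, max_steps, 3e-4) for s in range(1, max_steps + 1)]
ablation = [cosine_lr(s, max_steps, 3e-4) for s in range(1, hard_stop + 1)]

# Shrinking max_duration to 200 steps instead decays the LR all the way to min_lr by step 200.
shrunk = [cosine_lr(s, hard_stop, 3e-4) for s in range(1, hard_stop + 1)]
```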
- async_bookkeeping: Optional[bool] = None
  Whether to run collective bookkeeping operations, like reducing metrics, asynchronously. This requires a separate CPU-only backend, and will default to True if one is available.
- bookkeeping_soft_timeout: int = 30
  A soft timeout (in seconds) for bookkeeping operations. If a bookkeeping operation takes longer than this, a warning is emitted.
- no_checkpoints: bool = False
  Set this to True to disable automatic saving/loading of checkpoints altogether. This is useful for benchmarking.
- no_evals: bool = False
  Set this to True to disable evaluator callbacks. This is useful for benchmarking.
- steps_to_skip: Optional[List[StepSkipRange]] = None
  Ranges of steps to completely skip training on.
- property tokens_per_epoch: int | None
  The total number of tokens in the training dataset, minus leftovers.
- property max_steps: int | None
  The maximum number of steps to train for, as determined by max_duration.
- property max_tokens: int | None
  The maximum number of tokens to train for, as determined by max_duration.
- property bookkeeping_device: device
  The device used for collective bookkeeping (non-training) operations, which can potentially use a different backend.
- property bookkeeping_pg: ProcessGroup | None
  The process group used for bookkeeping collectives.
  Since bookkeeping collectives might be done in a separate thread, we need a separate process group to avoid potential race conditions.
- property multi_thread_pool: ThreadPoolExecutor
  A multi-threaded executor for bookkeeping tasks that don’t involve distributed communication.
- property single_thread_pool: ThreadPoolExecutor
  A single-threaded executor for bookkeeping tasks that involve distributed communication.
- check_if_canceled()
  Asynchronously check if the run is canceled. Use is_canceled to see the result. This needs to be called by all ranks at the same point in the training loop.
- fit()
  Fit the model, potentially loading a checkpoint first depending on the load_strategy.
- load_checkpoint(dir, *, load_trainer_state=None, load_optim_state=None)
  Load a checkpoint.
  Note
  fit() may call this method automatically depending on the load_strategy.
  - Parameters:
    dir (Union[Path, PathLike, str]) – The path/URL to a checkpoint or a folder of checkpoints.
    load_trainer_state (Optional[bool], default: None) – Load trainer state (data loader state, RNG states, and other bookkeeping).
    load_optim_state (Optional[bool], default: None) – Load optimizer state in the train module.
- maybe_load_checkpoint(dir=None, *, load_trainer_state=None, load_optim_state=None)
  Like load_checkpoint() but is a no-op if there is no checkpoint in the dir provided.
  Note
  fit() may call this method automatically depending on the load_strategy.
  - Returns:
    Whether a checkpoint was loaded.
- save_checkpoint(ephemeral=False)
  Save a checkpoint for the current step to the save_folder.
- save_checkpoint_async(ephemeral=False)
  Save a checkpoint for the current step to the save_folder asynchronously.
  - Parameters:
    ephemeral (bool, default: False) – Whether to mark the checkpoint as ephemeral in its metadata. Note that the trainer itself won’t remove ephemeral checkpoints; that’s up to the CheckpointerCallback.
  - Returns:
    The path/URL to the checkpoint and a future which will complete when the checkpoint is successfully saved.
- record_metric(name, value, reduce_type=None, namespace=None, merge_strategy='warn')
  Record a new metric for the current step.
  See also
  Use record_ce_loss() to record the cross-entropy loss, specifically.
  - Parameters:
    name (str) – The name of the metric.
    reduce_type (Optional[ReduceType], default: None) – Specifies how to reduce the metric across the distributed process group. None means no reduction.
    namespace (Optional[str], default: None) – A namespace to record the metric under, e.g. “train” or “optim”.
    merge_strategy (MetricMergeStrategy, default: 'warn') – How to merge metrics when duplicates are logged.
- get_metric(name, namespace=None)
  Get the value of a metric recorded during the current step.
  Warning
  Metrics will only be available from the time they’re recorded until the end of the current step.
  Warning
  Accessing a metric can inadvertently trigger a host-device sync, which slows down training.
- write_file(name, contents, dir=None)
  Write a file to the save_folder or dir, if provided.
  - Parameters:
    name – The name of the file to write, relative to the save_folder or dir.
    contents (Union[str, bytes]) – The contents of the file to write.
    dir (Union[Path, PathLike, str, None], default: None) – The path/URL to a directory to write the file to. Defaults to save_folder.
  - Returns:
    The path/URL of the file.
- persist_working_file(name)
  Persist a file in the work_dir by saving/uploading it to the save_folder.
  - Parameters:
    name (Union[Path, PathLike, str]) – The name/path of the file relative to the work_dir.
  - Returns:
    The full path/URL to the saved file.
  - Raises:
    FileNotFoundError – If the file can’t be found.
    FileExistsError – If the file already exists in the save folder and save_overwrite is False.
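For a local save_folder, the behavior described above (including both exceptions) can be sketched with pathlib and shutil. persist_local is a hypothetical helper, not the trainer's implementation, and it does not handle remote URLs.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Union

PathType = Union[str, Path]


def persist_local(work_dir: PathType, save_folder: PathType, name: PathType, save_overwrite: bool = False) -> Path:
    """Hypothetical sketch: copy work_dir/name to save_folder/name."""
    src = Path(work_dir) / name
    if not src.is_file():
        raise FileNotFoundError(f"{src} does not exist")
    dest = Path(save_folder) / name
    if dest.exists() and not save_overwrite:
        raise FileExistsError(f"{dest} already exists and save_overwrite is False")
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dest)
    return dest


with tempfile.TemporaryDirectory() as work, tempfile.TemporaryDirectory() as save:
    (Path(work) / "logs").mkdir()
    (Path(work) / "logs" / "rank0.log").write_text("step 1 done\n")
    saved_text = persist_local(work, save, "logs/rank0.log").read_text()
    try:
        persist_local(work, save, "logs/rank0.log")  # second attempt without overwrite
        overwrite_blocked = False
    except FileExistsError:
        overwrite_blocked = True
```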
- persist_working_subdir(name)
  Persist a directory in the work_dir by saving/uploading it to the save_folder.
  - Parameters:
    name (Union[Path, PathLike, str]) – The name/path of the directory relative to the work_dir.
  - Returns:
    The full path/URL to the saved directory.
  - Raises:
    FileNotFoundError – If the directory can’t be found.
    FileExistsError – If any files in the directory already exist in the save folder and save_overwrite is False.
- has_callback(cb_class)
  Check if the trainer already has a registered instance of the given callback class.
- run_bookkeeping_op(op, *args, cb=None, op_name=None, cancel_in_progress=None, allow_multiple=True, soft_timeout=None, distributed=True, **kwargs)
  Run a bookkeeping operation, potentially in a background thread.
  - Parameters:
    args – Positional arguments to pass to the operation.
    kwargs – Keyword arguments to pass to the operation.
    cb (Optional[Callable[[TypeVar(T)], None]], default: None) – A callback to call with the result of the operation when it finishes.
    op_name (Optional[str], default: None) – A name for the operation, used for logging, debugging, and potentially canceling old invocations of the same operation when allow_multiple is False.
    allow_multiple (bool, default: True) – If False, only one bookkeeping operation with the given name is allowed to run: any queued ops with the same name are canceled, and if another one is already running, the current invocation is ignored.
    soft_timeout (Optional[int], default: None) – A soft timeout, in seconds, to wait for the operation to finish. If the op takes longer than this, a warning is issued.
    distributed (bool, default: True) – Set this to False only if the op doesn’t use distributed communication, in which case it will be allowed to run concurrently with other ops.
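A minimal sketch of the background-execution and soft-timeout behavior, built on concurrent.futures and threading from the standard library. submit_bookkeeping_op is a hypothetical helper rather than the trainer's implementation. It omits the allow_multiple and distributed handling and only illustrates that a slow op triggers a warning instead of an error.

```python
import logging
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional

log = logging.getLogger("bookkeeping")


def submit_bookkeeping_op(
    executor: ThreadPoolExecutor,
    op: Callable[..., Any],
    *args: Any,
    cb: Optional[Callable[[Any], None]] = None,
    op_name: Optional[str] = None,
    soft_timeout: float = 30.0,
    **kwargs: Any,
) -> Future:
    """Hypothetical sketch: run `op` in the background and warn if it runs long.

    Returns immediately. If the op succeeds, `cb` receives its result.
    """
    name = op_name or getattr(op, "__name__", repr(op))
    future = executor.submit(op, *args, **kwargs)

    def warn_if_still_running() -> None:
        if not future.done():
            log.warning("Bookkeeping op '%s' has exceeded its soft timeout of %ss", name, soft_timeout)

    timer = threading.Timer(soft_timeout, warn_if_still_running)
    timer.daemon = True
    timer.start()

    def on_done(f: Future) -> None:
        timer.cancel()  # the op finished (or failed), so no warning is needed
        if cb is not None and not f.cancelled() and f.exception() is None:
            cb(f.result())

    future.add_done_callback(on_done)
    return future


results = []
# One worker, mirroring a single-threaded pool for ops that use distributed comms.
with ThreadPoolExecutor(max_workers=1) as pool:
    fut = submit_bookkeeping_op(pool, sum, [1, 2, 3], cb=results.append, op_name="sum", soft_timeout=5.0)
# Leaving the `with` block waits for the worker, so the callback has run by now.
```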
- class olmo_core.train.CheckpointerConfig(work_dir=None, save_overwrite=None, pre_download=False, save_thread_count=None, load_thread_count=None, throttle_uploads=False)
Bases: Config
A configuration class for building Checkpointer instances.
- class olmo_core.train.Checkpointer(work_dir, save_overwrite=False, pre_download=False, process_group=None, save_thread_count=None, load_thread_count=None, throttle_uploads=False)
Bases: object
Trainer checkpointer.
- save(dir, train_module, train_state, ephemeral=False)
  Save model, optim, and other training state to a local or remote directory.
- save_async(dir, train_module, train_state, ephemeral=False)
  An async version of save().
  - Return type:
    Future[None]
- load(dir, train_module, *, load_trainer_state=None, load_optim_state=None)
  Load model, optim, and other training state from a local or remote checkpoint directory created via save() or save_async().
- write_file(dir, fname, contents)
  Write something to a file in a local or remote directory.
  - Returns:
    The path/URL of the file.
- classmethod dir_is_checkpoint(dir)
  Check if a directory is a checkpoint directory.
- classmethod contains_checkpoint(dir)
  Check if a directory is a checkpoint directory or contains a child checkpoint directory.
- classmethod latest_checkpoint(dir)
  Find the latest checkpoint in a directory of checkpoints.
  - Raises:
    FileNotFoundError – If no checkpoints are found.
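A local-filesystem sketch of finding the latest checkpoint. It assumes checkpoints are subdirectories named like step1000; that naming convention and the helper latest_step_dir are assumptions for illustration, not documented Checkpointer behavior, and remote URLs are not handled. Step numbers must be compared numerically, since "step1000" sorts before "step500" as a string.

```python
import re
import tempfile
from pathlib import Path
from typing import Union

# Assumed naming convention: one subdirectory per checkpoint, e.g. "step1000".
_STEP_DIR = re.compile(r"^step(\d+)$")


def latest_step_dir(root: Union[str, Path]) -> Path:
    """Return the checkpoint subdirectory of `root` with the highest step number."""
    candidates = []
    for child in Path(root).iterdir():
        match = _STEP_DIR.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found in {root}")
    # Compare step numbers as integers, not strings.
    return max(candidates, key=lambda c: c[0])[1]


with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 1000, 250):
        (Path(tmp) / f"step{step}").mkdir()
    (Path(tmp) / "config.json").write_text("{}")  # non-checkpoint entries are ignored
    latest = latest_step_dir(tmp).name
```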
- class olmo_core.train.LoadStrategy(value)
Bases: StrEnum
Determines the strategy for loading checkpoints prior to training.
- if_available = 'if_available'
  The trainer will attempt to load a checkpoint from the save folder or load path (in that order) but will train from scratch if no checkpoint is found.
- always = 'always'
  The trainer will attempt to load a checkpoint from the save folder or load path (in that order) and raise an error if no checkpoint is found.
- never = 'never'
  The trainer will never load a checkpoint, even if one exists in the save folder or load path.
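The three strategies can be summarized as a small decision function. Strategy and resolve_load_source are hypothetical names, and has_checkpoint stands in for a check such as Checkpointer.contains_checkpoint(). The sketch only covers the order of precedence described above, not the actual loading.

```python
from enum import Enum
from typing import Callable, Optional


class Strategy(str, Enum):
    """Hypothetical mirror of LoadStrategy."""

    if_available = "if_available"
    always = "always"
    never = "never"


def resolve_load_source(
    strategy: Strategy,
    save_folder: str,
    load_path: Optional[str],
    has_checkpoint: Callable[[str], bool],
) -> Optional[str]:
    """Return where to load a checkpoint from, or None to train from scratch."""
    if strategy == Strategy.never:
        return None
    # Check the save folder first, then load_path (the order described above).
    for candidate in (save_folder, load_path):
        if candidate is not None and has_checkpoint(candidate):
            return candidate
    if strategy == Strategy.always:
        raise FileNotFoundError("No checkpoint found in the save folder or load path")
    return None  # if_available: fall back to training from scratch


existing = {"/runs/base-model"}  # pretend only the load path has a checkpoint
source = resolve_load_source(Strategy.if_available, "/runs/new-run", "/runs/base-model", existing.__contains__)
```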
- class olmo_core.train.Duration(value, unit)
Bases: object
- unit: DurationUnit
  The unit associated with the value.
- classmethod chinchilla_tokens(multiple, *, model_params, _tok_per_param=20)
Define a duration based on a multiple of the Chinchilla-optimal number of tokens.
The rule of thumb for Chinchilla compute optimality is 20 tokens-per-parameter for decoder-only natural language models trained with AdamW on dataset mixtures similar to the Pile.
Chinchilla optimality refers to training-time compute only, and does not account for inference-time compute. In practice, models are often trained with more tokens than the Chinchilla optimal value (“overtrained”) to improve inference-time performance.
Chinchilla: https://arxiv.org/abs/2203.15556
Chinchilla replication: https://arxiv.org/abs/2404.10102
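The arithmetic behind the rule of thumb, assuming the token budget is simply multiple × _tok_per_param × model_params. chinchilla_token_budget is a hypothetical helper; the classmethod itself returns a Duration and may round differently.

```python
def chinchilla_token_budget(multiple: float, model_params: int, tok_per_param: int = 20) -> int:
    """Hypothetical helper: `multiple` times the ~20 tokens-per-parameter Chinchilla budget."""
    return int(multiple * tok_per_param * model_params)


one_x = chinchilla_token_budget(1, model_params=1_000_000_000)   # 20B tokens
five_x = chinchilla_token_budget(5, model_params=1_000_000_000)  # 100B tokens ("overtrained")
```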
- class olmo_core.train.DurationUnit(value)
Bases: StrEnum
Units that can be used to define a Duration.
- steps = 'steps'
  Steps (batches).
- epochs = 'epochs'
  Epochs.
- tokens = 'tokens'
  Tokens.
- class olmo_core.train.ReduceType(value)
Bases: StrEnum
An enumeration of the allowed ways to reduce a metric across ranks.
- mean = 'mean'
  Average across the process group.
- sum = 'sum'
  Add across the process group.
- max = 'max'
  Take the max across the process group.
- l2_norm = 'l2_norm'
  For metrics that are computed as L2 norms on each rank, this will correctly reduce the norm across the process group to produce the global L2 norm.
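Why l2_norm needs its own reduction: if each rank holds a disjoint shard of a vector, the global norm is the square root of the sum of the squared local norms, which neither mean nor sum produces. This is a pure-Python sketch. global_l2_norm is a hypothetical helper, and a distributed version would instead all-reduce the squared norms with a sum.

```python
import math
from typing import Iterable


def global_l2_norm(local_norms: Iterable[float]) -> float:
    """Combine L2 norms of disjoint shards: ||x|| = sqrt(sum_r ||x_r||^2)."""
    return math.sqrt(sum(n * n for n in local_norms))


# Two ranks each hold half of the vector [3, 4, 12, 0].
shards = [[3.0, 4.0], [12.0, 0.0]]
local_norms = [math.sqrt(sum(v * v for v in shard)) for shard in shards]  # [5.0, 12.0]

combined = global_l2_norm(local_norms)  # 13.0, the norm of the full vector
naive_mean = sum(local_norms) / len(local_norms)  # 8.5, wrong
naive_sum = sum(local_norms)  # 17.0, also wrong
```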
- class olmo_core.train.MetricMergeStrategy(value)
Bases: StrEnum
Determines how duplicate metrics are merged.
- warn = 'warn'
  Warn when a duplicate is logged, keeping the current value.
- latest = 'latest'
  The latest value is used.
- oldest = 'oldest'
  The oldest (first logged) value is used.
- mean = 'mean'
  When a duplicate is logged, take the average with the last value.
- sum = 'sum'
  The sum of the duplicates is used.
- max = 'max'
  Take the maximum value of the duplicates.
- min = 'min'
  Take the minimum value of the duplicates.
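A sketch of the merge rules as a lookup table, assuming numeric metric values (merge_metric is hypothetical, not olmo_core's code). Read literally, mean averages the new value with the current one, so with three or more duplicates the result is a running pairwise average rather than the arithmetic mean of all values.

```python
import warnings
from functools import reduce
from typing import Callable, Dict

_MERGERS: Dict[str, Callable[[float, float], float]] = {
    "latest": lambda current, new: new,
    "oldest": lambda current, new: current,
    "mean": lambda current, new: (current + new) / 2,  # average with the last value
    "sum": lambda current, new: current + new,
    "max": max,
    "min": min,
}


def merge_metric(strategy: str, current: float, new: float, name: str = "metric") -> float:
    """Hypothetical sketch of merging a duplicate metric value."""
    if strategy == "warn":
        warnings.warn(f"Duplicate value logged for {name!r}; keeping the current value")
        return current
    merger = _MERGERS.get(strategy)
    if merger is None:
        raise ValueError(f"unknown merge strategy: {strategy!r}")
    return merger(current, new)


values = [1.0, 2.0, 3.0]  # the same metric logged three times in one step
running_mean = reduce(lambda acc, v: merge_metric("mean", acc, v), values)  # 2.25, not 2.0
total = reduce(lambda acc, v: merge_metric("sum", acc, v), values)  # 6.0
```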
- class olmo_core.train.StepSkipRange(start, stop)
Bases: object
Defines a range of steps to skip during training.
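A sketch of checking whether a step falls in any skip range. SkipRange and should_skip are hypothetical, and the half-open interval [start, stop) is an assumption; check the StepSkipRange source for whether stop is inclusive.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass(frozen=True)
class SkipRange:
    """Hypothetical mirror of StepSkipRange; `stop` is assumed to be exclusive."""

    start: int
    stop: int

    def __contains__(self, step: int) -> bool:
        return self.start <= step < self.stop


def should_skip(step: int, ranges: Iterable[SkipRange]) -> bool:
    return any(step in r for r in ranges)


ranges = [SkipRange(100, 105), SkipRange(2000, 2010)]
skipped = [step for step in range(1, 3001) if should_skip(step, ranges)]
```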
Submodules
- train.callbacks: Callback, CallbackConfig, CheckpointerCallback, CheckpointRemovalStrategy, CometCallback, CometNotificationSetting, ConfigSaverCallback, ConsoleLoggerCallback, EvaluatorCallback, LMEvaluatorCallbackConfig, DownstreamEvaluatorCallbackConfig, GAPMonitorCallback, GarbageCollectorCallback, GPUMemoryMonitorCallback, HFConverterCallback, ProfilerCallback, SlackNotifierCallback, SlackNotificationSetting, SequenceLengthSchedulerCallback, SpeedMonitorCallback, StabilityMonitorCallback, WandBCallback, BeakerCallback, BatchSizeSchedulerCallback, MonkeyPatcherCallback, MetricSaverCallback, ModelMergeCallback, ListCheckpointerCallback
- train.train_module: TrainModuleConfig, TrainModule, EvalBatchSpec, EvalBatchSizeUnit, BasicTrainModule, TransformerTrainModule, TransformerTrainModuleConfig, TransformerPipelineTrainModule, TransformerPipelineTrainModuleConfig, TransformerActivationCheckpointingConfig, TransformerActivationCheckpointingMode, TransformerDataParallelConfig, TransformerDataParallelWrappingStrategy, TransformerExpertParallelConfig, TransformerTensorParallelConfig, TransformerContextParallelConfig, TransformerPipelineParallelConfig