train

This module implements a highly efficient, yet flexible, language model trainer.

Features

  • Async checkpointing (optional) with local or remote checkpoint directories.

  • Supports any type of parallel strategy.

  • Async metric logging, with support for custom metrics, even those that need to be reduced across ranks.

  • Flexible Callback system for extending/modifying the training loop behavior.

  • A powerful set of built-in callbacks (olmo_core.train.callbacks).

Overview

Call prepare_training_environment() at the top of your training script, then construct your trainer using a TrainerConfig. Finally, call Trainer.fit(), and clean up at the end of your script by calling teardown_training_environment(). For example:

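from olmo_core.train import prepare_training_environment, teardown_training_environment

if __name__ == "__main__":
    prepare_training_environment()

    # Build the train module, the data loader, and a TrainerConfig (trainer_config)...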

    # Build trainer.
    trainer = trainer_config.build(train_module, data_loader)

    # Run the trainer.
    trainer.fit()

    # Clean up.
    teardown_training_environment()

See the "train a language model" example for a complete, runnable demonstration.

API Reference

olmo_core.train.prepare_training_environment(*, seed=None, backend='cpu:gloo,cuda:nccl', timeout=datetime.timedelta(seconds=900), log_filter_type=None, shared_filesystem=None)

Prepare the environment for training, including setting up the distributed process group for distributed training.

Tip

Internally this calls:

So there’s no need to call those separately.

Important

This should be invoked at the very start of your training script, such as at the beginning of the if __name__ == "__main__": ... block.

Parameters:
  • seed (Optional[int], default: None) – The seed to initialize RNG states with.

  • backend (Optional[str], default: 'cpu:gloo,cuda:nccl') – The distributed backend to use, if any. Set to None for non-distributed training. When using NCCL, you should ideally also include a CPU-only backend such as GLOO (as the default does), which allows the trainer to run async checkpointing and bookkeeping collectives on the CPU backend without blocking training operations.

  • timeout (timedelta, default: datetime.timedelta(seconds=900)) – The timeout for initializing the distributed process group.

  • log_filter_type (Optional[LogFilterType], default: None) –

    Determines which ranks are allowed to emit log messages below the WARNING level. You can also configure this through the env var LOG_FILTER_TYPE. If neither is set, this defaults to “rank0_only”.

    Note

    All ranks will always emit messages at the WARNING level or higher.

  • shared_filesystem (Optional[bool], default: None) – Should be set to True if the checkpoint and working directories are in a local filesystem shared by all ranks, e.g. on an NFS drive.
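The fallback order for log_filter_type (explicit argument, then the LOG_FILTER_TYPE env var, then “rank0_only”), together with the rule that WARNING and above always gets through, can be sketched with the standard logging module. This is a minimal illustration under stated assumptions, not the library's implementation: resolve_log_filter_type and Rank0OnlyFilter are hypothetical names, and the rank is passed in explicitly rather than read from a process group.

```python
import logging
import os
from typing import Optional


def resolve_log_filter_type(arg: Optional[str]) -> str:
    """Hypothetical helper: explicit argument, then env var, then the documented default."""
    if arg is not None:
        return arg
    return os.environ.get("LOG_FILTER_TYPE", "rank0_only")


class Rank0OnlyFilter(logging.Filter):
    """Hypothetical filter: non-zero ranks only emit WARNING and above."""

    def __init__(self, rank: int) -> None:
        super().__init__()
        self.rank = rank

    def filter(self, record: logging.LogRecord) -> bool:
        # Every rank always emits WARNING or higher; lower levels pass only on rank 0.
        return self.rank == 0 or record.levelno >= logging.WARNING
```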

olmo_core.train.teardown_training_environment()

To be run at the end of training. Tears down all distributed process groups.

class olmo_core.train.TrainerConfig(save_folder, work_dir=None, load_path=None, load_strategy='if_available', load_optim_state=None, load_trainer_state=None, checkpointer=<factory>, device=None, save_overwrite=False, max_duration=<factory>, cancel_check_interval=25, hard_stop=None, metrics_collect_interval=5, callbacks=<factory>, async_bookkeeping=None, bookkeeping_soft_timeout=30, no_checkpoints=False, no_evals=False, steps_to_skip=None)

Bases: Config

A configuration class for easily building Trainer instances.

See also

See the Trainer documentation for a description of the fields.

add_callback(name, callback)

Add another callback.

add_callbacks(callbacks)

Add a set of callbacks.

with_callback(name, callback)

Return a new trainer config with an additional callback.

Parameters:
  • name (str) – A name to assign the callback. Must be unique.

  • callback (Callback) – The callback to add.

Return type:

TrainerConfig

with_callbacks(callbacks)

Return a new trainer config with additional callbacks.

Parameters:

callbacks (Dict[str, Callback]) – A dictionary of callbacks to add. Keys must be unique.

Return type:

TrainerConfig
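The difference between the mutating add_callback() and the copy-returning with_callback() can be sketched with a plain dataclass. MiniTrainerConfig is a hypothetical stand-in reduced to the callbacks field, callbacks are arbitrary objects, and raising ValueError on a duplicate name is an assumption about how the uniqueness requirement might be enforced. with_callbacks() would apply the same copy-then-add pattern to each entry of a dict.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class MiniTrainerConfig:
    """Hypothetical stand-in for TrainerConfig, reduced to the callbacks field."""

    save_folder: str
    callbacks: Dict[str, Any] = field(default_factory=dict)

    def add_callback(self, name: str, callback: Any) -> None:
        # Mutates this config in place. Assumption: duplicates raise ValueError.
        if name in self.callbacks:
            raise ValueError(f"callback '{name}' already exists")
        self.callbacks[name] = callback

    def with_callback(self, name: str, callback: Any) -> "MiniTrainerConfig":
        # Leaves this config untouched and returns a modified deep copy.
        new_config = copy.deepcopy(self)
        new_config.add_callback(name, callback)
        return new_config
```

Returning a copy keeps a shared base config reusable across several derived runs.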

Return a new trainer config with added callbacks for downstream evaluation and validation sets.

Return type:

TrainerConfig

build(train_module, data_loader, *, dp_process_group=None, checkpointer_pg=None)

Build the corresponding trainer.

Parameters:
Return type:

Trainer

class olmo_core.train.Trainer(work_dir, train_module, data_loader, device, save_folder, checkpointer, callbacks, max_duration, save_overwrite=False, load_path=None, load_strategy='if_available', load_trainer_state=None, load_optim_state=None, metrics_collect_interval=5, dp_process_group=None, global_step=0, global_train_tokens_seen=0, global_train_petaflops=0.0, epoch=1, cancel_check_interval=25, hard_stop=None, async_bookkeeping=None, bookkeeping_soft_timeout=30, no_checkpoints=False, no_evals=False, steps_to_skip=None, _metrics=<factory>, _metrics_reduce_type=<factory>, _canceled=False, _cancel_reason=None, _canceling_rank=None, _error=None, _rank_batch_size=None, _multi_thread_pool=None, _single_thread_pool=None, _bookkeeping_queue=<factory>, _bookkeeping_pg=None, _blocking_ephemeral_checkpoints=<factory>, _checkpoint_loaded=False, _metrics_consistent=None)

Bases: object

Language model trainer.

Tip

Use TrainerConfig instead of constructing this class directly.

work_dir: Path

A local working directory to use for temporary files needed during training. Files added to this working directory can be persisted to the save_folder via persist_working_file().

Note

When constructing your trainer through a TrainerConfig this will default to the save_folder if it’s a local directory.

train_module: TrainModule

The train module to fit.

data_loader: DataLoaderBase

The train data loader.

device: device

The default device to use. Should match the device the model is on and be appropriate for the main distributed backend.

save_folder: str

The folder to save all checkpoints to. Could be a local directory (if using a shared filesystem) or a URL.

Warning

If you try to use a local directory without a globally shared filesystem across all ranks you will get an error.

checkpointer: Checkpointer

The checkpointer. This is a wrapper around the functionality in olmo_core.distributed.checkpoint, which means you can use unshard_checkpoint() to unshard the model and optimizer state from a train checkpoint after the fact.

callbacks: Dict[str, Callback]

Trainer callbacks.

max_duration: Duration

The duration to train for.

Important

The total number of training steps must be known ahead of time for various reasons, such as setting a learning rate schedule. Therefore, if your data loader’s number of batches (total_batches) is unknown ahead of time, you must set the max_duration in terms of tokens or steps, not epochs.

save_overwrite: bool = False

Whether to overwrite existing files/checkpoints in the save_folder.

load_path: Union[Path, PathLike, str, None] = None

An alternative location to load a checkpoint from if no checkpoint is found in the current save_folder.

This can be set to a checkpoint path or the path to a folder of checkpoints such as the save_folder from a different run.

load_strategy: LoadStrategy = 'if_available'

The strategy for loading a checkpoint prior to training.

load_trainer_state: Optional[bool] = None

Whether to load the trainer state (including dataloader state). If None, this will attempt to load the trainer state if it exists in the checkpoint, but will not error if it doesn’t.

load_optim_state: Optional[bool] = None

Whether to load the optimizer state. If None, this will attempt to load the optimizer state if it exists in the checkpoint, but will not error if it doesn’t.

metrics_collect_interval: int = 5

How often (in steps) to collect, reduce, and pass on metrics to the Callback.log_metrics method on callbacks.

Note

Regardless of what this is set to, the Callback.log_metrics methods are still called with the metrics for every single step, but will be delayed according to this value.

For example, if this is set to 5, then every 5 steps the metrics from the past 5 steps are collected and reduced together, then passed on to Callback.log_metrics all together.

Tip

Increasing this can improve throughput since logging metrics always requires a host-device sync.
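The buffering behavior described above, where every step's metrics still reach Callback.log_metrics but only in groups of metrics_collect_interval steps, can be sketched as follows. MetricBuffer is hypothetical; the cross-rank reduction and the host-device sync are represented only by the point where the buffer is flushed, and steps are assumed to be 1-based.

```python
from typing import Callable, Dict, List, Tuple

StepMetrics = Dict[str, float]


class MetricBuffer:
    """Hypothetical sketch: hold per-step metrics and flush them every `interval` steps."""

    def __init__(self, interval: int, log_metrics: Callable[[int, StepMetrics], None]) -> None:
        self.interval = interval
        self.log_metrics = log_metrics  # stands in for Callback.log_metrics
        self.pending: List[Tuple[int, StepMetrics]] = []

    def record_step(self, step: int, metrics: StepMetrics) -> None:
        self.pending.append((step, metrics))
        if step % self.interval == 0:
            self.flush()

    def flush(self) -> None:
        # In a real trainer this is where metrics would be reduced across ranks,
        # which is also where the host-device sync would happen.
        for step, metrics in self.pending:
            self.log_metrics(step, metrics)
        self.pending.clear()
```

With an interval of 5, steps 1 to 4 are held back and all five steps reach the logger together at step 5.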

dp_process_group: Optional[ProcessGroup] = None

The distributed process group for all data parallel ranks.

global_step: int = 0

The current step/batch. 1-based, though it’s initialized to 0 before the first step. This does not reset after an epoch.

global_train_tokens_seen: int = 0

The total number of training tokens seen.

global_train_petaflops: float = 0.0

The total number of training petaflops computed.

epoch: int = 1

The current epoch (1-based).

cancel_check_interval: int = 25

The interval (in steps) to check if the run is canceled. Checking requires distributed comms, but if you’ve configured a separate CPU-only backend (like “gloo”) then this shouldn’t impact training throughput.

hard_stop: Optional[Duration] = None

Set a hard stopping point for the trainer. This is useful for ablations when you don’t want to do a complete training run, but also don’t want to change max_duration, so as not to affect the learning rate schedule.

async_bookkeeping: Optional[bool] = None

Do collective bookkeeping operations like reducing metrics asynchronously. This requires a separate CPU-only backend, and will default to True if one is available.

bookkeeping_soft_timeout: int = 30

A soft timeout (in seconds) for bookkeeping operations. If a bookkeeping operation takes longer than this, a warning is emitted.

no_checkpoints: bool = False

Set this to True to disable automatic saving/loading of checkpoints altogether. This is useful for benchmarking.

no_evals: bool = False

Set this to True to disable evaluator callbacks. This is useful for benchmarking.

steps_to_skip: Optional[List[StepSkipRange]] = None

Ranges of steps to completely skip training on.

property global_batch_size: int

Global training batch size in tokens.

property rank_batch_size: int

The number of tokens in each training batch per rank.

property tokens_per_batch: int

The number of tokens in each training batch.

property steps_per_epoch: int | None

The total number of training steps in an epoch, if known.

property tokens_per_epoch: int | None

The total number of tokens in the training dataset, minus left-overs.

property max_steps: int | None

The maximum number of steps to train for, as determined by max_duration.

property max_tokens: int | None

The maximum number of tokens to train for, as determined by max_duration.
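The batch-size properties above are related by simple arithmetic. The sketch assumes that global_batch_size is rank_batch_size multiplied by the number of data parallel ranks, and that the "left-overs" excluded from tokens_per_epoch are the tokens that do not fill a final complete global batch; both are assumptions, and BatchMath is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass
class BatchMath:
    """Hypothetical helper relating the batch-size properties to each other."""

    rank_batch_size: int  # tokens per batch on each rank
    dp_world_size: int  # number of data parallel ranks
    dataset_tokens: int  # total tokens available in one pass over the data

    @property
    def global_batch_size(self) -> int:
        return self.rank_batch_size * self.dp_world_size

    @property
    def steps_per_epoch(self) -> int:
        # Only complete global batches count; the remainder is left over.
        return self.dataset_tokens // self.global_batch_size

    @property
    def tokens_per_epoch(self) -> int:
        return self.steps_per_epoch * self.global_batch_size
```

For example, 32,768 tokens per rank on 4 ranks gives a 131,072-token global batch, so a 1,000,000-token dataset yields 7 steps (917,504 tokens) per epoch.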

convert_duration_to_steps(duration)

Convert a duration to steps.

Return type:

int
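A possible conversion in the spirit of convert_duration_to_steps() is sketched below. It also shows why, as noted under max_duration, an epoch-based duration cannot be converted when the number of steps per epoch is unknown. The function and its arguments are hypothetical, and rounding a token budget up to whole steps is an assumption.

```python
from typing import Optional


def duration_to_steps(
    value: int,
    unit: str,
    *,
    global_batch_size: int,
    steps_per_epoch: Optional[int],
) -> int:
    """Hypothetical sketch: convert a (value, unit) duration to a number of steps."""
    if unit == "steps":
        return value
    if unit == "tokens":
        # Integer ceiling division. Assumption: a partial final batch still counts as a step.
        return (value + global_batch_size - 1) // global_batch_size
    if unit == "epochs":
        if steps_per_epoch is None:
            raise ValueError("cannot convert epochs to steps without steps_per_epoch")
        return value * steps_per_epoch
    raise ValueError(f"unknown unit: {unit}")
```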

property bookkeeping_device: device

The device used for collective bookkeeping (non-training) operations, which can potentially use a different backend.

property bookkeeping_pg: ProcessGroup | None

The process group used for bookkeeping collectives.

Since bookkeeping collectives might be done in a separate thread, we need a separate process group to avoid potential race conditions.

property multi_thread_pool: ThreadPoolExecutor

A multi-threaded executor for bookkeeping tasks that don’t involve distributed communication.

property single_thread_pool: ThreadPoolExecutor

A single-threaded executor for bookkeeping tasks that involve distributed communication.

property checkpoint_loaded: bool

Whether a checkpoint has been loaded.

cancel_run(reason, no_sync=False)

Mark the run canceled.

Parameters:
  • reason (str) – The reason for canceling.

  • no_sync (bool, default: False) – Set this to True only if you’re calling this from all ranks at the same time, otherwise you’ll get a distributed deadlock.

check_if_canceled()

Asynchronously check if the run is canceled. Use is_canceled to see the result. This needs to be called by all ranks at the same point in the training loop.

fit()

Fit the model, potentially loading a checkpoint first depending on the load_strategy.

state_dict()

Get the trainer state to save.

Return type:

TrainerStateDict

load_state_dict(state_dict)

Load trainer state (not model or optimizer state).

load_checkpoint(dir, *, load_trainer_state=None, load_optim_state=None)

Load a checkpoint.

Note

fit() may call this method automatically depending on the load_strategy.

Parameters:
  • dir (Union[Path, PathLike, str]) – The path/URL to a checkpoint or a folder of checkpoints.

  • load_trainer_state (Optional[bool], default: None) – Load trainer state (data loader state, RNG states, and other bookkeeping).

  • load_optim_state (Optional[bool], default: None) – Load optimizer state in the train module.

maybe_load_checkpoint(dir=None, *, load_trainer_state=None, load_optim_state=None)

Like load_checkpoint() but is a no-op if there is no checkpoint in the dir provided.

Note

fit() may call this method automatically depending on the load_strategy.

Return type:

bool

Returns:

True if a checkpoint was loaded.

save_checkpoint(ephemeral=False)

Save a checkpoint for the current step to the save_folder.

Parameters:

ephemeral (bool, default: False) – Whether to mark the checkpoint as ephemeral in its metadata. Note that the trainer itself won’t remove ephemeral checkpoints. That’s up to the CheckpointerCallback.

Return type:

Union[Path, PathLike, str]

Returns:

The path/URL to the checkpoint.

save_checkpoint_async(ephemeral=False)

Save a checkpoint for the current step to the save_folder asynchronously.

Parameters:

ephemeral (bool, default: False) – Whether to mark the checkpoint as ephemeral in its metadata. Note that the trainer itself won’t remove ephemeral checkpoints. That’s up to the CheckpointerCallback.

Return type:

Tuple[Union[Path, PathLike, str], Future]

Returns:

The path/URL to the checkpoint and a future which will complete when the checkpoint is successfully saved.

record_metric(name, value, reduce_type=None, namespace=None, merge_strategy='warn')

Record a new metric for the current step.

See also

Use record_ce_loss() to record the cross-entropy loss, specifically.

Parameters:
  • name (str) – The name of the metric.

  • value (Union[float, Tensor]) – The value of the metric.

  • reduce_type (Optional[ReduceType], default: None) – Specifies how to reduce the metric across the distributed process group. None means no reduction.

  • namespace (Optional[str], default: None) – A namespace to record the metric under, e.g. “train” or “optim”.

  • merge_strategy (MetricMergeStrategy, default: 'warn') – How to merge metrics when duplicates are logged.

record_ce_loss(value, reduce_type=None)

Record the cross-entropy loss metric specifically.

get_metric(name, namespace=None)

Get the value of a metric recorded during the current step.

Warning

Metrics will only be available from the time they’re recorded until the end of the current step.

Warning

Accessing a metric can inadvertently trigger a host-device sync, which slows down training.

Parameters:

name (str) – The name of the metric.

Return type:

Optional[Tensor]

write_file(name, contents, dir=None)

Write a file to the save_folder or dir, if provided.

Parameters:
  • name – The name of the file to write, relative to the save_folder or dir.

  • contents (Union[str, bytes]) – The contents of the file to write.

  • dir (Union[Path, PathLike, str, None], default: None) – The path/URL to a directory to write the file to. Defaults to save_folder.

Return type:

Union[Path, PathLike, str]

Returns:

The path/URL of the file.
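For a local save_folder, write_file() amounts to writing text or bytes to a path relative to the target directory and returning that path. The sketch below is a hypothetical local-only stand-in (remote URLs are not handled); it picks the write mode from the type of contents and creates missing parent directories, which is an assumption about the real behavior.

```python
from pathlib import Path
from typing import Optional, Union


def write_local_file(
    save_folder: Union[str, Path],
    name: str,
    contents: Union[str, bytes],
    dir: Optional[Union[str, Path]] = None,  # mirrors the documented parameter name
) -> Path:
    """Hypothetical local-only sketch of Trainer.write_file()."""
    target = Path(dir if dir is not None else save_folder) / name
    target.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(contents, bytes):
        target.write_bytes(contents)
    else:
        target.write_text(contents)
    return target
```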

persist_working_file(name)

Persist a file in the work_dir by saving/uploading it to the save_folder.

Parameters:

name (Union[Path, PathLike, str]) – The name/path of the file relative to the work_dir.

Return type:

Union[Path, PathLike, str]

Returns:

The full path/URL to the saved file.

Raises:
persist_working_subdir(name)

Persist a directory in the work_dir by saving/uploading it to the save_folder.

Parameters:

name (Union[Path, PathLike, str]) – The name/path of the directory relative to the work_dir.

Return type:

Union[Path, PathLike, str]

Returns:

The full path/URL to the saved directory.

Raises:
add_callback(name, callback)

Add a callback to the trainer.

has_callback(cb_class)

Check if the trainer already has a registered instance of the given callback class.

Return type:

bool

run_bookkeeping_op(op, *args, cb=None, op_name=None, cancel_in_progress=None, allow_multiple=True, soft_timeout=None, distributed=True, **kwargs)

Run a bookkeeping operation, potentially in a background thread.

Parameters:
  • op (Callable[..., TypeVar(T)]) – The operation to run.

  • args – Positional arguments to pass to the operation.

  • kwargs – Keyword arguments to pass to the operation.

  • cb (Optional[Callable[[TypeVar(T)], None]], default: None) – A callback to call with the result of the operation when it finishes.

  • op_name (Optional[str], default: None) – A name for the operation, used for logging, debugging, and potentially canceling old invocations of the same operation when allow_multiple is False.

  • allow_multiple (bool, default: True) – If False, only one bookkeeping operation with the given name is allowed to run, so if there are other ops with the same name that are queued, those will be canceled, and if there’s another one that’s already running, the current invocation will be ignored.

  • soft_timeout (Optional[int], default: None) – A soft timeout, in seconds, to wait for the operation to finish. If the op takes longer than this a warning will be issued.

  • distributed (bool, default: True) – This should only be set to False if the op doesn’t use distributed communication, in which case it will be allowed to run concurrently with other ops.
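The cb, allow_multiple, and soft_timeout semantics described above can be sketched on top of concurrent.futures with a single worker thread, similar in spirit to the single_thread_pool used for ops that involve distributed communication. BookkeepingRunner is hypothetical: it ignores the distributed flag and process groups entirely, and it measures the soft timeout from when the op starts running rather than from when it was queued.

```python
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional

log = logging.getLogger(__name__)


class BookkeepingRunner:
    """Hypothetical sketch of run_bookkeeping_op() on a single background thread."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._futures: Dict[str, List[Future]] = {}

    def run(
        self,
        op: Callable[..., Any],
        *args: Any,
        cb: Optional[Callable[[Any], None]] = None,
        op_name: Optional[str] = None,
        allow_multiple: bool = True,
        soft_timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> Optional[Future]:
        name = op_name or op.__name__
        previous = self._futures.get(name, [])
        if not allow_multiple:
            # Cancel queued invocations; cancel() only succeeds for ops not yet started.
            for fut in previous:
                fut.cancel()
            # If one with the same name is already running, ignore this invocation.
            if any(fut.running() for fut in previous):
                return None

        def wrapped() -> Any:
            start = time.monotonic()
            result = op(*args, **kwargs)
            elapsed = time.monotonic() - start
            if soft_timeout is not None and elapsed > soft_timeout:
                log.warning("Bookkeeping op '%s' took %.1fs", name, elapsed)
            if cb is not None:
                cb(result)  # runs on the worker thread before the future completes
            return result

        fut = self._pool.submit(wrapped)
        # Keep only unfinished futures (canceled ones count as done).
        self._futures[name] = [f for f in previous if not f.done()] + [fut]
        return fut

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)
```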

class olmo_core.train.CheckpointerConfig(work_dir=None, save_overwrite=None, pre_download=False, save_thread_count=None, load_thread_count=None, throttle_uploads=False)

Bases: Config

A configuration class for building Checkpointer instances.

class olmo_core.train.Checkpointer(work_dir, save_overwrite=False, pre_download=False, process_group=None, save_thread_count=None, load_thread_count=None, throttle_uploads=False)

Bases: object

Trainer checkpointer.

save(dir, train_module, train_state, ephemeral=False)

Save model, optim, and other training state to a local or remote directory.

save_async(dir, train_module, train_state, ephemeral=False)

An async version of save().

Return type:

Future[None]

load(dir, train_module, *, load_trainer_state=None, load_optim_state=None)

Load model, optim, and other training state from a local or remote checkpoint directory created via save() or save_async().

Return type:

Optional[Dict[str, Any]]

write_file(dir, fname, contents)

Write something to a file in a local or remote directory.

Parameters:
  • dir (Union[Path, PathLike, str]) – The path/URL of the directory to write the file to.

  • fname (str) – The name of the file to write, relative to dir.

  • contents (Union[str, bytes]) – The contents of the file to write.

Return type:

Union[Path, PathLike, str]

Returns:

The path/URL of the file.

classmethod dir_is_checkpoint(dir)

Check if a directory is a checkpoint directory.

Return type:

bool

classmethod find_checkpoints(dir, ephemeral=None)

Find checkpoints within a directory.

Return type:

Generator[Tuple[int, str], None, None]

classmethod contains_checkpoint(dir)

Check if a directory is a checkpoint directory or contains a child checkpoint directory.

Return type:

bool

classmethod latest_checkpoint(dir)

Find the latest checkpoint in a directory of checkpoints.

Raises:

FileNotFoundError – If no checkpoints are found.

Return type:

str
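A local-filesystem sketch of what find_checkpoints() and latest_checkpoint() do is shown here. It assumes, hypothetically, that checkpoints are subdirectories named step<N>; the real directory layout, the ephemeral filter, and the way dir_is_checkpoint() recognizes a checkpoint may differ, and remote URLs are not handled.

```python
import re
from pathlib import Path
from typing import Iterator, Tuple, Union

STEP_DIR = re.compile(r"^step(\d+)$")  # assumed naming convention, not confirmed


def find_local_checkpoints(dir: Union[str, Path]) -> Iterator[Tuple[int, str]]:
    """Yield (step, path) for each subdirectory that looks like a checkpoint."""
    for child in Path(dir).iterdir():
        match = STEP_DIR.match(child.name)
        if match and child.is_dir():
            yield int(match.group(1)), str(child)


def latest_local_checkpoint(dir: Union[str, Path]) -> str:
    """Return the checkpoint with the highest step, mirroring latest_checkpoint()."""
    checkpoints = list(find_local_checkpoints(dir))
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoints found in {dir}")
    return max(checkpoints)[1]  # tuples compare by the integer step first
```

Comparing the parsed integer rather than the directory name matters: a plain string comparison would rank step999 above step1000.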

class olmo_core.train.LoadStrategy(value)

Bases: StrEnum

Determines the strategy for loading checkpoints prior to training.

if_available = 'if_available'

The trainer will attempt to load a checkpoint from the save folder or load path (in that order) but will train from scratch if no checkpoint is found.

always = 'always'

The trainer will attempt to load a checkpoint from the save folder or load path (in that order) and raise an error if no checkpoint is found.

never = 'never'

The trainer will never load a checkpoint even if one exists in the save folder or load path.
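The three strategies can be read as a small decision procedure. In this sketch, Strategy is a local enum with the same values, has_checkpoint stands in for something like Checkpointer.contains_checkpoint(), and choose_load_dir is a hypothetical helper that returns the directory to load from, or None to train from scratch. Raising FileNotFoundError for always is an assumption; the text only says an error is raised.

```python
from enum import Enum
from typing import Callable, Optional


class Strategy(str, Enum):
    """Local mirror of the LoadStrategy values, for illustration only."""

    if_available = "if_available"
    always = "always"
    never = "never"


def choose_load_dir(
    strategy: Strategy,
    save_folder: str,
    load_path: Optional[str],
    has_checkpoint: Callable[[str], bool],
) -> Optional[str]:
    """Hypothetical: pick where to load a checkpoint from, per the strategy."""
    if strategy == Strategy.never:
        return None
    # The save folder takes precedence over load_path.
    for candidate in (save_folder, load_path):
        if candidate is not None and has_checkpoint(candidate):
            return candidate
    if strategy == Strategy.always:
        raise FileNotFoundError("no checkpoint found in save_folder or load_path")
    return None  # if_available: train from scratch
```

Checking save_folder first means a restarted run resumes from its own latest checkpoint instead of the initial load_path.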

class olmo_core.train.Duration(value, unit)

Bases: object

value: int

The value of the duration.

unit: DurationUnit

The unit associated with the value.

classmethod steps(steps)

Define a duration from a number of steps.

Return type:

Duration

classmethod epochs(epochs)

Define a duration from a number of epochs.

Return type:

Duration

classmethod tokens(tokens)

Define a duration from a number of tokens.

Return type:

Duration

classmethod chinchilla_tokens(multiple, *, model_params, _tok_per_param=20)

Define a duration based on a multiple of the Chinchilla-optimal number of tokens.

The rule of thumb for Chinchilla compute optimality is 20 tokens-per-parameter for decoder-only natural language models trained with AdamW on dataset mixtures similar to the Pile.

Chinchilla optimality refers to training-time compute only, and does not account for inference-time compute. In practice, models are often trained with more tokens than the Chinchilla optimal value (“overtrained”) to improve inference-time performance.

Chinchilla: https://arxiv.org/abs/2203.15556

Chinchilla replication: https://arxiv.org/abs/2404.10102

Parameters:
  • multiple (float) – The Chinchilla multiplier. 1.0 is the Chinchilla optimal value. Values less than 1.0 will undertrain relative to Chinchilla, and values greater than 1.0 will overtrain relative to Chinchilla.

  • model_params (int) – The number of active, non-embedding parameters in the target model.

Return type:

Duration
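The token budget behind chinchilla_tokens() is multiple * model_params * 20 tokens. A worked sketch follows; chinchilla_token_budget is a hypothetical helper, and truncating the result to an int is an assumption about how the library rounds.

```python
def chinchilla_token_budget(multiple: float, model_params: int, tok_per_param: int = 20) -> int:
    """Tokens for a given multiple of the Chinchilla-optimal budget (20 tokens per parameter)."""
    return int(multiple * model_params * tok_per_param)


# A model with 1B active, non-embedding parameters at 1x Chinchilla: 20B tokens.
print(chinchilla_token_budget(1.0, 1_000_000_000))  # 20000000000
# Overtraining 5x relative to Chinchilla: 100B tokens.
print(chinchilla_token_budget(5.0, 1_000_000_000))  # 100000000000
```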

due(*, step, tokens, epoch)

Check if the duration is due.

Return type:

bool
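One way Duration, its constructors, and due() might fit together is sketched here. SketchDuration and Unit are hypothetical. Treating a steps or tokens duration as due once the counter reaches its value, and an epochs duration as due once the 1-based current epoch exceeds its value (so all N epochs have completed), are assumptions rather than confirmed library behavior.

```python
from dataclasses import dataclass
from enum import Enum


class Unit(str, Enum):
    steps = "steps"
    epochs = "epochs"
    tokens = "tokens"


@dataclass(frozen=True)
class SketchDuration:
    value: int
    unit: Unit

    @classmethod
    def steps(cls, steps: int) -> "SketchDuration":
        return cls(steps, Unit.steps)

    @classmethod
    def epochs(cls, epochs: int) -> "SketchDuration":
        return cls(epochs, Unit.epochs)

    @classmethod
    def tokens(cls, tokens: int) -> "SketchDuration":
        return cls(tokens, Unit.tokens)

    def due(self, *, step: int, tokens: int, epoch: int) -> bool:
        if self.unit == Unit.steps:
            return step >= self.value
        if self.unit == Unit.tokens:
            return tokens >= self.value
        # `epoch` is the current (1-based) epoch, so N epochs are done once it exceeds N.
        return epoch > self.value
```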

class olmo_core.train.DurationUnit(value)

Bases: StrEnum

Units that can be used to define a Duration.

steps = 'steps'

Steps (batches).

epochs = 'epochs'

Epochs.

tokens = 'tokens'

Tokens.

class olmo_core.train.ReduceType(value)

Bases: StrEnum

An enumeration of the allowed ways to reduce a metric across ranks.

mean = 'mean'

Average across the process group.

sum = 'sum'

Add across the process group.

max = 'max'

Take the max across the process group.

l2_norm = 'l2_norm'

For metrics that are computed as L2 norms on each rank, this will correctly reduce the norm across the process group to produce the global L2 norm.
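The l2_norm reduction works because squared L2 norms add over disjoint pieces of a vector: the global norm is the square root of the sum of the squared per-rank norms. A plain-Python sketch, with ranks simulated as a list of disjoint shards and no process group involved:

```python
import math
from typing import List


def l2_norm(values: List[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def reduce_l2_norms(per_rank_norms: List[float]) -> float:
    """Combine L2 norms of disjoint shards into the L2 norm of the full vector."""
    return math.sqrt(sum(n * n for n in per_rank_norms))


# Each "rank" holds a disjoint shard of the same vector; shard norms are 5, 12, and 84.
shards = [[3.0, 4.0], [12.0], [0.0, 84.0]]
global_norm = reduce_l2_norms([l2_norm(s) for s in shards])
print(global_norm)  # 85.0
```

Summing the per-rank norms would give 101.0 instead, which is why a dedicated reduce type is needed. In a distributed setting the same idea would square the local norm, all-reduce with a sum, and take the square root.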

class olmo_core.train.MetricMergeStrategy(value)

Bases: StrEnum

Determines how duplicate metrics are merged.

warn = 'warn'

Warn when a duplicate is logged, keeping the current value.

latest = 'latest'

The latest value is used.

oldest = 'oldest'

The oldest (first logged) is used.

mean = 'mean'

When a duplicate is logged, it is averaged with the last value.

sum = 'sum'

The sum of the duplicates is used.

max = 'max'

Take the maximum value of the duplicates.

min = 'min'

Take the minimum value of the duplicates.
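The merge strategies can be sketched as a function that combines an already-recorded value with a duplicate. merge_metric is hypothetical. warn is left out because the text does not pin down which value it keeps, and mean follows the description literally (averaging the new value with the last one), so repeated duplicates do not produce an equal-weight mean.

```python
def merge_metric(strategy: str, old: float, new: float) -> float:
    """Hypothetical: combine an already-recorded metric value with a duplicate."""
    if strategy == "latest":
        return new
    if strategy == "oldest":
        return old
    if strategy == "mean":
        return (old + new) / 2  # pairwise average with the last value
    if strategy == "sum":
        return old + new
    if strategy == "max":
        return max(old, new)
    if strategy == "min":
        return min(old, new)
    raise ValueError(f"unsupported strategy: {strategy}")


value = 1.0
for duplicate in (3.0, 5.0):
    value = merge_metric("mean", value, duplicate)
print(value)  # 3.5 (an equal-weight mean of 1, 3, 5 would be 3.0)
```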

class olmo_core.train.StepSkipRange(start, stop)

Bases: object

Defines a range of steps to skip during training.

start: int

The first step to skip (steps start at 1, not 0).

stop: int

The endpoint of the range (exclusive).
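Because steps are 1-based and stop is exclusive, StepSkipRange(start=10, stop=13) skips steps 10, 11, and 12. A sketch of the membership check, using a hypothetical SkipRange dataclass and should_skip helper:

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass(frozen=True)
class SkipRange:
    start: int  # first step to skip (1-based, inclusive)
    stop: int  # exclusive end of the range


def should_skip(step: int, ranges: Iterable[SkipRange]) -> bool:
    return any(r.start <= step < r.stop for r in ranges)


ranges = [SkipRange(10, 13)]
print([s for s in range(8, 15) if should_skip(s, ranges)])  # [10, 11, 12]
```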