"""
This module implements a highly efficient, yet flexible, language model trainer.

Features
--------

- Async checkpointing (optional) with local or remote checkpoint directories.
- Supports any type of parallel strategy.
- Async metric logging, with support for custom metrics, even those that need to be reduced across
  ranks.
- Flexible :class:`~olmo_core.train.callbacks.Callback` system for extending/modifying the training
  loop behavior.
- A powerful set of built-in callbacks (:mod:`olmo_core.train.callbacks`).

Overview
--------

Call :func:`prepare_training_environment()` at the top of your training script,
then construct your trainer using a :class:`TrainerConfig`. Finally, call :meth:`Trainer.fit()`
and clean up at the end of your script by calling :func:`teardown_training_environment()`.

For example::

    if __name__ == "__main__":
        prepare_training_environment()

        # Build train module and data loader...

        # Build trainer.
        trainer = trainer_config.build(train_module, data_loader)

        # Run the trainer.
        trainer.fit()

        # Clean up.
        teardown_training_environment()

See the `train a language model <examples/train.html>`_ example for a complete, runnable demonstration.

API Reference
-------------
"""
import logging
from datetime import timedelta
from typing import Optional
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from ..distributed.utils import init_distributed, is_distributed
from ..io import add_cached_path_clients
from ..utils import LogFilterType, get_default_device, prepare_cli_environment, seed_all
from .checkpoint import Checkpointer, CheckpointerConfig
from .common import (
Duration,
DurationUnit,
LoadStrategy,
MetricMergeStrategy,
ReduceType,
StepSkipRange,
)
from .config import TrainerConfig
from .trainer import Trainer
__all__ = [
    # Environment setup and teardown helpers defined in this module.
    "prepare_training_environment",
    "teardown_training_environment",
    # The trainer and its configuration (from ``.config`` / ``.trainer``).
    "TrainerConfig",
    "Trainer",
    # Checkpointing (from ``.checkpoint``).
    "CheckpointerConfig",
    "Checkpointer",
    # Names re-exported from ``.common``.
    "LoadStrategy",
    "Duration",
    "DurationUnit",
    "ReduceType",
    "MetricMergeStrategy",
    "StepSkipRange",
]

# Logger for this module, named after the module's import path.
log = logging.getLogger(__name__)
def prepare_training_environment(
    *,
    seed: Optional[int] = None,
    backend: Optional[str] = "cpu:gloo,cuda:nccl",
    timeout: timedelta = timedelta(minutes=15),
    log_filter_type: Optional[LogFilterType] = None,
    shared_filesystem: Optional[bool] = None,
) -> None:
    """
    Prepare the environment for training, including setting up the distributed process group
    for distributed training.

    Besides the process group, this also sets the :mod:`torch.multiprocessing` start method to
    ``"spawn"``, configures the CLI environment (logging, etc.), registers the custom cached-path
    clients, and optionally seeds all RNG states.

    When ``backend`` is ``None`` no process group is created; instead the default torch device is
    set with :func:`torch.set_default_device()` to :func:`~olmo_core.utils.get_default_device()`.

    .. tip::
        Internally this calls:

        - :func:`~olmo_core.distributed.utils.init_distributed()`, which also calls
          :func:`torch.cuda.set_device()` for backends that support CUDA, otherwise
          :func:`torch.set_default_device()`.
        - :func:`~olmo_core.utils.prepare_cli_environment()`

        So there's no need to call those separately.

    .. important::
        This should be invoked at the very start of your training script, such as at the beginning
        of the ``if __name__ == "__main__": ...`` block.

    :param seed: The seed to initialize RNG states with. If ``None``, RNG states are not seeded.
    :param backend: The distributed backend to use, if any. Set to ``None`` for non-distributed training.
        When using NCCL, ideally you should also include a CPU-only backend (the default) like GLOO,
        which allows the trainer to run async checkpointing and bookkeeping collectives on the CPU
        backend without blocking training operations.
    :param timeout: The timeout for initializing the distributed process group. Ignored when
        ``backend`` is ``None``.
    :param log_filter_type: Determines which ranks are allowed to emit log messages below the
        ``WARNING`` level. You can also configure this through the env var ``LOG_FILTER_TYPE``.
        If neither is set, this defaults to "rank0_only".

        .. note::
            All ranks will always emit messages at the ``WARNING`` level or higher.

    :param shared_filesystem: Should be set to ``True`` if the checkpoint and working directories
        are in a local filesystem shared by all ranks, e.g. on an NFS drive. This is only forwarded
        to the process group initialization, so it has no effect when ``backend`` is ``None``.
    """
    # Setting the mp start method to "spawn" avoids some data loader segfaults on LUMI.
    # This runs before logging is configured by ``prepare_cli_environment()`` below,
    # hence ``print`` rather than ``log``.
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError as e:
        print(f"failed to set multiprocessing start method: {e}")

    # Initialize process group.
    if backend is not None:
        # NOTE(review): the keyword is spelled ``shared_filesytem`` (sic). Presumably this matches
        # the parameter name of ``init_distributed()``; confirm there before "fixing" it here.
        init_distributed(backend=backend, timeout=timeout, shared_filesytem=shared_filesystem)
    else:
        torch.set_default_device(get_default_device())

    # Configure logging, warning filters, exception hooks, and other CLI settings.
    prepare_cli_environment(log_filter_type=log_filter_type)

    # Add custom cached-path clients.
    add_cached_path_clients()

    # Init RNG states.
    if seed is not None:
        seed_all(seed)

    if is_distributed():
        log.info(f"Using distributed backend {dist.get_backend()}")
def teardown_training_environment() -> None:
    """
    To be run at the end of training. Tears down all distributed process groups.

    Calling :func:`torch.distributed.destroy_process_group()` without a ``group`` argument
    destroys every process group, including the default one. When
    :func:`~olmo_core.distributed.utils.is_distributed()` reports that this is not a
    distributed run, this function does nothing.
    """
    if is_distributed():
        dist.destroy_process_group()