Train an LLMΒΆ
The following snippets can be found in src/examples/llm/.
The train.py script is meant to be launched via torchrun.
You can also use the python -m olmo_core.launch.beaker CLI to quickly launch this script on Beaker.
"""
Example of how to train a transformer language model.
Launch this with torchrun:
torchrun --nproc-per-node=4 src/examples/llm/train.py run_name [OVERRIDES...]
"""
import argparse
import logging
import os
import sys
from dataclasses import dataclass
from typing import List, Optional, cast
import rich
from olmo_core.config import Config, DType
from olmo_core.data import (
NumpyDataLoaderConfig,
NumpyFSLDatasetConfig,
NumpyPaddedFSLDatasetConfig,
TokenizerConfig,
)
from olmo_core.data.numpy_dataset import NumpyDatasetConfig
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import get_rank
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
Duration,
TrainerConfig,
prepare_training_environment,
teardown_training_environment,
)
from olmo_core.train.callbacks import (
CheckpointerCallback,
CometCallback,
ConfigSaverCallback,
DownstreamEvaluatorCallbackConfig,
GAPMonitorCallback,
GPUMemoryMonitorCallback,
LMEvaluatorCallbackConfig,
ProfilerCallback,
WandBCallback,
)
from olmo_core.train.train_module import (
TransformerDataParallelConfig,
TransformerTrainModuleConfig,
)
from olmo_core.utils import seed_all
log = logging.getLogger(__name__)
# Check for the data on common Ai2 drives. If those don't exist we'll stream the data over the internet,
# which can be a lot slower. Alternatively you can download the files with wget, for example:
# > wget http://olmo-data.org/examples/c4-en/gpt2/c4-train.00000-00099.npy
DEFAULT_DATA_ROOT = "http://olmo-data.org/examples/c4-en/gpt2"
for dir in (
"/net/nfs/allennlp/llm-data/c4/en/",
"/weka/oe-training-default/ai2-llm/examples/c4-en/gpt2/",
):
if os.path.exists(dir):
DEFAULT_DATA_ROOT = dir
break
DATA_ROOT = os.environ.get("OLMO_DATA_ROOT", DEFAULT_DATA_ROOT).rstrip("/")
DATA_PATHS = [
f"{DATA_ROOT}/c4-train.00000-00099.npy",
# Uncomment for full dataset which might not be available on NFS or Weka.
# f"{DATA_ROOT}/c4-train.00100-00199.npy",
# f"{DATA_ROOT}/c4-train.00200-00299.npy",
# f"{DATA_ROOT}/c4-train.00300-00399.npy",
# f"{DATA_ROOT}/c4-train.00400-00499.npy",
# f"{DATA_ROOT}/c4-train.00500-00599.npy",
# f"{DATA_ROOT}/c4-train.00600-00699.npy",
# f"{DATA_ROOT}/c4-train.00700-00799.npy",
# f"{DATA_ROOT}/c4-train.00800-00899.npy",
# f"{DATA_ROOT}/c4-train.00900-00999.npy",
# f"{DATA_ROOT}/c4-train.01000-01023.npy",
]
EVAL_DATA_PATHS = [f"{DATA_ROOT}/c4-validation.00000-00008.npy"]
# docs: start-define-config
@dataclass
class ExperimentConfig(Config):
model: TransformerConfig
"""Model config."""
dataset: NumpyDatasetConfig
"""Dataset config."""
data_loader: NumpyDataLoaderConfig
"""Data loader config."""
trainer: TrainerConfig
"""Trainer config."""
train_module: TransformerTrainModuleConfig
"""Train module config. Contains settings for optimizer."""
init_seed: int = 12536
"""Random seed to initialize model weights."""
load_path: Optional[str] = None
"""Path to load checkpoint from if no checkpoint is found in the save folder.
Mainly used when you want to fine-tune from a pretrained model."""
load_trainer_state: bool = False
"""Whether to load the trainer state (including data loader state) when loading from `load_path`.
This only makes sense when trainer state is available in the checkpoint and you're resuming
on the same dataset."""
# docs: end-define-config
def train(config: ExperimentConfig):
if get_rank() == 0:
rich.print(config)
# Set RNG states on all devices.
seed_all(config.init_seed)
# docs: start-build-components
# Build components.
model = config.model.build(init_device="meta")
train_module = config.train_module.build(model)
dataset = config.dataset.build()
data_loader = config.data_loader.build(dataset, dp_process_group=train_module.dp_process_group)
trainer = config.trainer.build(train_module, data_loader)
# docs: end-build-components
# Save config to W&B and each checkpoint dir.
config_dict = config.as_config_dict()
cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict
# docs: start-load-path
# If we have a load path set and there is no checkpoint in the save folder, load the
# checkpoint from the load path.
if not trainer.no_checkpoints and not trainer.maybe_load_checkpoint() and config.load_path:
log.info(
f"Loading checkpoint from {config.load_path} since no checkpoints were found in the save folder..."
)
trainer.load_checkpoint(config.load_path, load_trainer_state=config.load_trainer_state)
# docs: end-load-path
# Train.
trainer.fit()
def build_config(opts, overrides: List[str]) -> ExperimentConfig:
save_folder = opts.save_folder
if not save_folder:
save_folder = f"/tmp/{opts.run_name}"
work_dir = opts.work_dir
if not work_dir:
work_dir = "/tmp/dataset-cache"
tokenizer_config = TokenizerConfig.gpt2()
# docs: start-model-config
try:
factory = getattr(TransformerConfig, opts.model_factory)
except AttributeError:
raise ValueError(f"Unknown model factory: {opts.model_factory}")
model_config = factory(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
)
# docs: end-model-config
log.info(f"Using data root: {DATA_ROOT}")
dataset_config = NumpyFSLDatasetConfig(
paths=DATA_PATHS,
sequence_length=opts.sequence_length,
tokenizer=tokenizer_config,
work_dir=work_dir,
)
data_loader_config = NumpyDataLoaderConfig(
global_batch_size=256 * 1024, # NOTE: this is specified in tokens, not instances
seed=0,
num_workers=4,
)
train_module_config = TransformerTrainModuleConfig(
rank_microbatch_size=16 * 1024, # NOTE: this is specified in tokens, not instances
max_sequence_length=opts.sequence_length,
optim=AdamWConfig(
lr=1e-3,
group_overrides=[
OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0))
],
),
compile_model=True,
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
max_grad_norm=1.0,
scheduler=CosWithWarmup(warmup_steps=100),
)
trainer_config = (
TrainerConfig(
save_folder=save_folder,
save_overwrite=True,
metrics_collect_interval=5,
cancel_check_interval=5,
)
.with_callback("gpu_monitor", GPUMemoryMonitorCallback())
.with_callback(
"checkpointer",
CheckpointerCallback(
save_interval=1000,
ephemeral_save_interval=100,
save_async=True,
),
)
.with_callback(
"comet",
CometCallback(
name=opts.run_name,
cancel_check_interval=10,
enabled=False, # change to true to enable
),
)
.with_callback(
"wandb",
WandBCallback(
name=opts.run_name,
cancel_check_interval=10,
enabled=False, # change to true to enable
),
)
.with_callback("config_saver", ConfigSaverCallback())
.with_callback("profiler", ProfilerCallback(enabled=False))
.with_callback(
"lm_evaluator",
LMEvaluatorCallbackConfig(
eval_dataset=NumpyPaddedFSLDatasetConfig(
paths=EVAL_DATA_PATHS,
metadata=[{"label": "c4-validation"}],
sequence_length=opts.sequence_length,
tokenizer=tokenizer_config,
work_dir=work_dir,
),
eval_interval=250,
eval_duration=Duration.steps(50),
),
)
.with_callback(
"downstream_evaluator",
DownstreamEvaluatorCallbackConfig(
tasks=["hellaswag"],
tokenizer=tokenizer_config,
eval_interval=250,
),
)
.with_callback(
"gap_monitor",
GAPMonitorCallback(enabled=False, dump_gradients=False),
)
)
config = ExperimentConfig(
model=model_config,
dataset=dataset_config,
data_loader=data_loader_config,
train_module=train_module_config,
trainer=trainer_config,
)
# Apply overrides.
# docs: start-config-merge
config = config.merge(overrides)
# docs: end-config-merge
return config
def parser_args():
parser = argparse.ArgumentParser(
prog=sys.argv[0],
usage=f"python {sys.argv[0]} RUN_NAME [OPTIONS...] [CONFIG_OVERRIDES...]",
description="Train a transformer language model on c4.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("run_name", type=str, help="""The name of the run.""")
parser.add_argument(
"--model-factory",
type=str,
default="llama2_271M",
help="""The name of the model factory to use.
This can be any classmethod on the TransformerConfig class.""",
)
parser.add_argument(
"--sequence-length",
type=int,
default=2048,
help="""The sequence length to train and eval on.""",
)
parser.add_argument(
"--save-folder",
type=str,
help="""A local or remote directory to save checkpoints to.
Defaults to a temporary directory if not provided.""",
)
parser.add_argument(
"--work-dir",
type=str,
help="""A local working directory for dataset preprocessing.
Defaults to a temporary directory if not provided.""",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="""Print the config and exit.""",
)
opts, overrides = parser.parse_known_args()
return opts, overrides
def main():
opts, overrides = parser_args()
config = build_config(opts, overrides)
if opts.dry_run:
rich.print(config)
return
prepare_training_environment()
train(config)
teardown_training_environment()
if __name__ == "__main__":
main()