import concurrent.futures
import json
import logging
import math
import typing
from abc import ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
import rich
from cached_path import cached_path
import olmo_core.distributed.utils as dist_utils
import olmo_core.eval.task_groups as task_groups
import olmo_core.io as io
import olmo_core.train.callbacks as callbacks
from olmo_core.aliases import PathOrStr
from olmo_core.config import Config
from olmo_core.data import DataMix, NumpyPaddedFSLDatasetConfig, TokenizerConfig
from olmo_core.data.composable import (
ComposableDataLoaderConfig,
InstanceSourceConfig,
set_composable_seed,
)
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.nn.config import ModelConfig
from olmo_core.optim import OptimConfig, Scheduler
from olmo_core.train import (
Checkpointer,
Duration,
DurationUnit,
TrainerConfig,
prepare_training_environment,
teardown_training_environment,
)
from olmo_core.train.train_module import TrainModule
from olmo_core.utils import warn_once
from .utils import format_count, format_tokens
if TYPE_CHECKING:
from pandas import DataFrame
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
[docs]
class DeviceMeshSpec(NamedTuple):
"""
Describes the relevant dimensions of a device mesh needed to train a model of a certain size.
"""
world_size: int
"""The mininum numbers of devices required."""
dp_world_size: int | None
"""
The mininum size of the data parallel group. This can be set to ``None`` if the data parallel
world size should equal the world size. This, along with the per-device micro-batch size, is
needed to determine the right global batch size.
"""
[docs]
@dataclass(frozen=True)
class RunCheckpointInfo:
    """An immutable record describing one checkpoint of a model run."""

    name: str
    """A descriptive name for the checkpoint, assigned by the :class:`RunConfigurator`."""

    step: int
    """The training step number of the checkpoint."""

    tokens: int
    """The number of training tokens processed up to this checkpoint."""

    path: PathOrStr
    """A path to the checkpoint directory."""

    metrics_path: PathOrStr | None
    """A path to the metrics JSON file for this checkpoint, if it exists."""

    exists: bool
    """Whether the checkpoint actually exists."""

    def display(self) -> str:
        """
        Render this checkpoint as a string containing rich console markup.

        A missing checkpoint yields a single line with a yellow cross. An existing one yields
        a green check mark, followed by the checkpoint path and, when available, the metrics
        path, each on its own line.
        """
        summary = f"Step {self.step:,d} ({format_tokens(self.tokens)}) [b cyan]{self.name}[/]"
        if not self.exists:
            return f"[b yellow]✘[/] {summary}"

        lines = [
            f"[b green]✔[/] {summary}",
            f" ↳ checkpoint: [u blue]{self.path}[/]",
        ]
        if self.metrics_path is not None:
            lines.append(f" ↳ metrics: [u blue]{self.metrics_path}[/]")
        return "\n".join(lines)
# Type variable for the concrete ModelConfig subtype that a ModelConfigurator is generic over.
M = TypeVar("M", bound=ModelConfig)
[docs]
@dataclass(kw_only=True)
class ModelConfigurator(Config, Generic[M], metaclass=ABCMeta):
    """
    Abstract recipe describing how to configure a model of a particular size.

    The class is generic over ``M``, the concrete :class:`ModelConfig` type it works with.
    """

    @abstractmethod
    def build_train_module(
        self,
        *,
        size_spec: str,
        sequence_length: int,
        rank_microbatch_size: int,
        model_config: M,
        optim_config: OptimConfig,
        scheduler: Scheduler,
        device_type: str,
    ) -> TrainModule:
        """
        Build the train module for the given model and optimizer configs.

        All arguments are keyword-only.

        :param size_spec: The size spec of the model being trained.
        :param sequence_length: The training sequence length.
        :param rank_microbatch_size: The micro-batch size per device.
        :param model_config: The model configuration to build from.
        :param optim_config: The optimizer configuration.
        :param scheduler: The learning rate scheduler.
        :param device_type: The type of accelerator device the run will use.

        :returns: The train module to hand to the trainer.
        """
        raise NotImplementedError()
[docs]
@dataclass(kw_only=True)
class RunConfigurator(Config, metaclass=ABCMeta):
    """
    Abstract recipe describing how to configure a run for a model of a particular size.
    """

    @abstractmethod
    def plot_lr_schedule(
        self,
        num_params: int,
        batch_size: int,
        *,
        show: bool = True,
        save_path: PathOrStr | None = None,
    ) -> PathOrStr | None:
        """
        Render a plot of the learning rate schedule.

        :param num_params: The number of model parameters the schedule is configured for.
        :param batch_size: The global batch size the schedule is configured for.
        :param show: Whether to display the plot.
        :param save_path: Where to save the plot, if anywhere.

        :returns: A path to the plot, or ``None``.
        """
        raise NotImplementedError()
[docs]
@dataclass(kw_only=True)
class ModelLadder(Config):
"""
Represents a complete model ladder of runs: one run per size spec in :data:`sizes`, all
sharing the same data configuration, tokenizer, sequence length, and seed.

Construction validates that :data:`max_devices` is positive and that it covers the minimum
number of devices the model configurator reports for every size spec.
"""
name: str
"""A name to assign to the ladder."""
dir: str
"""A unique directory where ladder run results and intermediate checkpoints should be saved."""
project: str | None = None
"""
An optional project name to associate with the ladder runs. Defaults to :data:`name`.
This is used by some logging backends (e.g. Weights & Biases).
"""
sizes: list[str]
"""A list of model size specs to run as part of the ladder."""
max_devices: int
"""The number of accelerator devices available to use for each run."""
device_type: str
"""The type of accelerator device available to use for each run (e.g. "NVIDIA H100 80GB HBM3")."""
model_configurator: ModelConfigurator
"""The model configurator to use."""
run_configurator: RunConfigurator
"""The run configurator to use."""
data_loader: ComposableDataLoaderConfig
"""The data loader configuration to use for each run."""
instance_sources: list[InstanceSourceConfig]
"""The instance sources to use for each run."""
sequence_length: int = 8192
"""The sequence length to train each run on."""
tokenizer: TokenizerConfig
"""The tokenizer to use."""
seed: int = 42
"""The initial random seed to use for all runs in the ladder."""
backend: str = "cpu:gloo,cuda:nccl"
"""The distributed backend to use for each run."""
def __post_init__(self):
    """
    Validate the ladder configuration.

    :raises OLMoConfigurationError: If ``max_devices`` is not positive, or if any size spec
        needs more devices than ``max_devices`` according to the model configurator's
        minimal device mesh spec.
    """
    if self.max_devices <= 0:
        raise OLMoConfigurationError("max_devices must be a positive integer.")

    for spec in self.sizes:
        # Only the minimum world size matters here; the data parallel size is ignored.
        required, _ = self.model_configurator.configure_minimal_device_mesh_spec(
            size_spec=spec,
            sequence_length=self.sequence_length,
            device_type=self.device_type,
        )
        if required <= self.max_devices:
            continue
        raise OLMoConfigurationError(
            f"Model of size {spec} requires at least {required} devices, "
            f"but max_devices is set to {self.max_devices}."
        )
@property
def work_dir(self) -> PathOrStr:
    """
    The working directory handed to the data loader, instance sources, and trainer.

    This is ``"./cache"`` when :data:`dir` is a URL, otherwise a ``cache`` entry joined
    onto :data:`dir`.
    """
    if io.is_url(self.dir):
        return "./cache"
    return str(io.join_path(self.dir, "cache"))
[docs]
def dry_run(self, size_spec: str, show_plot: bool = True, save_plot: PathOrStr | None = None):
    """
    Do a dry-run, which prints relevant hyperparameters and the required number of devices,
    and optionally displays and/or saves a plot of the learning rate schedule.

    :param size_spec: The model size spec to inspect. Must be one of :data:`sizes`.
    :param show_plot: Forwarded as ``show`` to :meth:`RunConfigurator.plot_lr_schedule()`.
    :param save_plot: Forwarded as ``save_path`` to :meth:`RunConfigurator.plot_lr_schedule()`.
        Plotting is skipped entirely when ``show_plot`` is ``False`` and this is ``None``.

    :raises ValueError: If ``size_spec`` is not one of :data:`sizes`.
    """
    if size_spec not in self.sizes:
        raise ValueError(f"Invalid size_spec '{size_spec}', must be one of {self.sizes}")

    num_params = self.get_num_params(size_spec)

    # Configure global batch size, make sure request number of devices matches the number
    # of devices available.
    target_global_batch_size = self.run_configurator.configure_target_batch_size(num_params)
    (
        global_batch_size,
        rank_microbatch_size,
        requested_devices,
        dp_world_size,
    ) = self._configure_batch_size_and_num_devices(size_spec, num_params)

    # Sanity checks: the instance counts and grad accumulation steps reported below are
    # only meaningful when these divisions are exact.
    assert (
        rank_microbatch_size % self.sequence_length == 0
    ), "rank micro-batch size must be a multiple of the sequence length"
    assert (
        global_batch_size % rank_microbatch_size == 0
    ), "global batch size must be a multiple of the rank micro-batch size"
    assert (
        global_batch_size % self.sequence_length == 0
    ), "global batch size must be a multiple of the sequence length"
    assert (
        global_batch_size % (rank_microbatch_size * dp_world_size) == 0
    ), "global batch size must be a multiple of (rank micro-batch size x dp world size)"

    num_grad_accum_steps = global_batch_size // (rank_microbatch_size * dp_world_size)

    rich.get_console().print(
        f"Dry run for model size {size_spec}:\n"
        f" ❯ Actual number of non-embedding params is {format_count(num_params)}\n"
        f" ❯ Target batch size is {target_global_batch_size:,d} tokens\n"
        f" ❯ Actual batch size is {global_batch_size:,d} tokens, which is "
        f"{global_batch_size // self.sequence_length:,d} instance(s)\n"
        f" ❯ Micro-batch size per device is {rank_microbatch_size:,d} tokens, which is "
        f"{rank_microbatch_size // self.sequence_length} instance(s)\n"
        f" ❯ And the run requires {requested_devices} out of {self.max_devices} devices, "
        f"with a data-parallel world size of {dp_world_size:,d}\n"
        f" ❯ So there will be {num_grad_accum_steps:,d} grad accumulation step(s) per batch",
        highlight=False,
    )

    if show_plot or save_plot is not None:
        log.info("Plotting LR schedule...")
        path = self.run_configurator.plot_lr_schedule(
            num_params, global_batch_size, show=show_plot, save_path=save_plot
        )
        if path is not None:
            log.info(f"Saved LR schedule plot to '{path}'")
[docs]
def run(self, size_spec: str, for_benchmarking: bool = False):
    """
    Execute a particular model run of the experiment locally and store the results.

    :param size_spec: The model size spec to run. Must be one of :data:`sizes`.
    :param for_benchmarking: Passed through to the trainer configuration, which uses it to
        disable checkpoints and evals and to set a hard stop.

    :raises ValueError: If ``size_spec`` is not one of :data:`sizes`.
    :raises OLMoConfigurationError: If the number of devices requested for the run differs
        from the distributed world size, or if the data loader's sequence length differs
        from :data:`sequence_length`.
    """
    if size_spec not in self.sizes:
        raise ValueError(f"Invalid size_spec '{size_spec}', must be one of {self.sizes}")

    prepare_training_environment(seed=self.seed, backend=self.backend)
    try:
        set_composable_seed(self.seed)

        # Configure model.
        model_config = self.get_model_config(size_spec)
        num_params = model_config.num_non_embedding_params

        # Configure global batch size, make sure request number of devices matches the number
        # of devices available.
        (
            global_batch_size,
            rank_microbatch_size,
            requested_devices,
            _,
        ) = self._configure_batch_size_and_num_devices(size_spec, num_params)
        if requested_devices != dist_utils.get_world_size():
            raise OLMoConfigurationError(
                f"Requested {requested_devices} devices for model of size '{size_spec}', "
                f"but {dist_utils.get_world_size()} are available."
            )

        # Configure optimizer and scheduler.
        optim_config = self.run_configurator.configure_optimizer(num_params, global_batch_size)
        scheduler = self.run_configurator.configure_lr_scheduler(num_params, global_batch_size)

        # Configure trainer.
        trainer_config = self._configure_trainer(size_spec, for_benchmarking=for_benchmarking)

        # Build instance sources and data loader.
        instance_sources = [
            source.build(work_dir=self.work_dir) for source in self.instance_sources
        ]
        data_loader = self.data_loader.build(
            *instance_sources,
            work_dir=self.work_dir,
            global_batch_size=global_batch_size,
            tokenizer=self.tokenizer,
        )
        if data_loader.sequence_length != self.sequence_length:
            raise OLMoConfigurationError(
                f"Data loader sequence of {data_loader.sequence_length} does not match "
                f"configured sequence length of {self.sequence_length}."
            )

        # Build train module.
        train_module = self.model_configurator.build_train_module(
            size_spec=size_spec,
            sequence_length=self.sequence_length,
            rank_microbatch_size=rank_microbatch_size,
            model_config=model_config,
            optim_config=optim_config,
            scheduler=scheduler,
            device_type=self.device_type,
        )

        # Build trainer.
        trainer = trainer_config.build(train_module, data_loader)

        # Record all configs.
        config_dict = {
            "seed": self.seed,
            "size": str(size_spec),
            "model": model_config.as_config_dict(),
            "optim": optim_config.as_config_dict(),
            "scheduler": scheduler.as_config_dict(),
            "data_loader": self.data_loader.as_config_dict(),
            "instance_sources": [s.as_config_dict() for s in self.instance_sources],
        }
        typing.cast(
            callbacks.ConfigSaverCallback, trainer.callbacks["config_saver"]
        ).config = config_dict

        # Train.
        trainer.fit()
    finally:
        # Tear down even when configuration or training raises, so the environment set up
        # by `prepare_training_environment()` is not left behind on failure.
        teardown_training_environment()
[docs]
def run_benchmark(self, size_spec: str):
    """
    Do a bench-marking run for a model of the given size spec.

    This simply delegates to :meth:`run` with ``for_benchmarking=True``, i.e. with
    benchmarking-specific settings (no checkpoints, no evals, hard stop).
    """
    self.run(size_spec=size_spec, for_benchmarking=True)
[docs]
def get_model_config(self, size_spec: str) -> ModelConfig:
    """
    Get the model config for a model of the given size spec.

    The work is delegated to the model configurator's ``configure_model()``, which also
    receives the ladder's sequence length, tokenizer, and device type.
    """
    configurator = self.model_configurator
    model_config = configurator.configure_model(
        size_spec=size_spec,
        sequence_length=self.sequence_length,
        tokenizer=self.tokenizer,
        device_type=self.device_type,
    )
    return model_config
[docs]
def get_num_params(self, size_spec: str):
    """
    Get the actual number of non-embedding parameters for a model of the given size spec,
    as reported by the model config's ``num_non_embedding_params``.
    """
    model_config = self.get_model_config(size_spec)
    return model_config.num_non_embedding_params
[docs]
def get_num_devices(self, size_spec: str) -> int:
    """
    Get the number of devices that would be used for a run of the given size spec.

    This is the third element of the tuple computed by the ladder's batch size / device
    configuration for the model's actual parameter count.
    """
    num_params = self.get_num_params(size_spec)
    (
        _global_batch_size,
        _rank_microbatch_size,
        num_devices,
        _dp_world_size,
    ) = self._configure_batch_size_and_num_devices(size_spec, num_params)
    return num_devices
[docs]
def get_save_folder(self, size_spec: str) -> str:
    """
    Get the training save folder for a run of the given size spec, which is the size spec
    joined onto the ladder's :data:`dir`.
    """
    folder = io.join_path(self.dir, size_spec)
    return str(folder)
[docs]
def get_checkpoints(
    self,
    size_spec: str,
    download_metrics: bool = False,
    discover_all: bool = False,
    alternative_dirs: list[PathOrStr] | None = None,
) -> list[RunCheckpointInfo]:
    """
    Get the list of checkpoints from the run for the given size spec, ordered by step.

    :param size_spec: The size specification for the model run.
    :param download_metrics: If ``True``, download metrics files to local cache.
    :param discover_all: If ``True``, discover all checkpoints that exist in the save folder
        rather than only checking at the intervals defined by
        :meth:`RunConfigurator.configure_checkpoint_intervals()`.
    :param alternative_dirs: Optional list of alternative root directories to search for checkpoints.
        The size_spec is appended to each directory. For each checkpoint, the primary save
        directory is checked first, then each alternative directory in order until found.

    :returns: One :class:`RunCheckpointInfo` per step considered, sorted by step. When
        ``discover_all`` is ``False`` this includes expected checkpoints that do not exist
        (``exists=False``), plus an ``"initialization"`` entry for step 0.
    """
    save_folder = self.get_save_folder(size_spec)
    folders_to_search: list[str] = [save_folder] + [
        str(io.join_path(d, size_spec)) for d in (alternative_dirs or [])
    ]
    for folder in folders_to_search:
        io.init_client(folder)

    num_params = self.get_num_params(size_spec)
    global_batch_size, *_ = self._configure_batch_size_and_num_devices(size_spec, num_params)

    def _find_checkpoint_path(step: int) -> tuple[PathOrStr, bool]:
        """Find the checkpoint path across all folders, returning (path, exists)."""
        dirname = Checkpointer.checkpoint_dirname(step)
        for folder in folders_to_search:
            candidate = io.join_path(folder, dirname)
            if Checkpointer.dir_is_checkpoint(candidate):
                return candidate, True
        # Not found in any folder, return the primary path
        return io.join_path(save_folder, dirname), False

    def _find_metrics_path(step: int) -> PathOrStr | None:
        """Find the metrics file across all folders."""
        metrics_filename = f"metrics_step{step}.json"
        for folder in folders_to_search:
            metrics_path = io.join_path(folder, metrics_filename)
            if io.file_exists(metrics_path):
                return metrics_path
        return None

    def _get_checkpoint_info(
        step: int, name: str, path: PathOrStr | None = None
    ) -> RunCheckpointInfo:
        """Resolve the checkpoint and metrics locations for a single step."""
        if path is None:
            path, exists = _find_checkpoint_path(step)
        else:
            exists = Checkpointer.dir_is_checkpoint(path)
        metrics_path = _find_metrics_path(step)
        if metrics_path is not None and download_metrics:
            metrics_path = cached_path(metrics_path, quiet=True)
        return RunCheckpointInfo(
            name=name,
            step=step,
            tokens=step * global_batch_size,
            path=path,
            metrics_path=metrics_path,
            exists=exists,
        )

    # Maps step -> (checkpoint name, known path or None if it still has to be searched for).
    checkpoints_to_check: dict[int, tuple[str, PathOrStr | None]] = {}
    if discover_all:
        # Discover all checkpoints that exist across all folders. Folders are visited in
        # priority order, so the first folder to provide a given step wins.
        for folder in folders_to_search:
            try:
                for step, found_path in Checkpointer.find_checkpoints(folder):
                    if step not in checkpoints_to_check:
                        checkpoints_to_check[step] = (f"step-{step}", found_path)
            except FileNotFoundError:
                continue
    else:
        # Only check at the expected checkpoint intervals.
        checkpoints_to_check[0] = ("initialization", None)
        # Ask the run configurator for the intervals once and convert each duration to a
        # step, rather than computing the intervals twice and zipping the results together.
        intervals = self.run_configurator.configure_checkpoint_intervals(
            num_params, global_batch_size
        )
        for duration, checkpoint_name in intervals:
            step = self._duration_to_steps(duration, global_batch_size)
            checkpoints_to_check[step] = (checkpoint_name, None)

    # Each lookup is submitted to a thread pool so the per-step lookups can overlap.
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(_get_checkpoint_info, step, name, path)
            for step, (name, path) in checkpoints_to_check.items()
        ]
        infos = [future.result() for future in concurrent.futures.as_completed(futures)]

    # Steps are unique (they were dict keys), so sorting by step gives a total order.
    return sorted(infos, key=lambda info: info.step)
[docs]
def get_metrics(
    self,
    size_spec: str,
    prefix: str | None = None,
    discover_all: bool = False,
    alternative_dirs: list[PathOrStr] | None = None,
) -> "DataFrame | None":
    """
    Get the metrics from the run of the given size spec.

    Each checkpoint that has a metrics file contributes one row, made of the (optionally
    filtered) metrics plus the columns ``name``, ``step``, ``tokens``, ``size``, and
    ``num_params``.

    :param size_spec: The size specification for the model run.
    :param prefix: If provided, only include metrics with keys starting with this prefix.
    :param discover_all: If ``True``, discover all checkpoints that exist in the save folder
        rather than only checking at the intervals defined by
        :meth:`RunConfigurator.configure_checkpoint_intervals()`.
    :param alternative_dirs: Optional list of alternative root directories to search for
        checkpoints and metrics files. The size_spec is appended to each directory.

    :returns: A data frame with one row per checkpoint that has metrics, or ``None`` if no
        checkpoint has a metrics file.
    """
    import pandas as pd

    checkpoints = self.get_checkpoints(
        size_spec,
        download_metrics=True,
        discover_all=discover_all,
        alternative_dirs=alternative_dirs,
    )
    num_params = self.get_num_params(size_spec)

    all_metrics = []
    for checkpoint in checkpoints:
        if checkpoint.metrics_path is None:
            continue

        # JSON is UTF-8 by specification, so don't rely on the locale's default encoding.
        with open(checkpoint.metrics_path, "r", encoding="utf-8") as f:
            metrics = json.load(f)

        if prefix is not None:
            metrics = {k: v for k, v in metrics.items() if k.startswith(prefix)}

        # Bookkeeping columns are set after filtering so they are always present.
        metrics["name"] = checkpoint.name
        metrics["step"] = checkpoint.step
        metrics["tokens"] = checkpoint.tokens
        metrics["size"] = size_spec
        metrics["num_params"] = num_params
        all_metrics.append(metrics)

    if not all_metrics:
        return None
    return pd.DataFrame(all_metrics)
def _get_checkpoint_intervals(self, *, num_params: int, global_batch_size: int) -> list[int]:
    """
    Return the run configurator's checkpoint intervals converted to training steps.

    Each ``(duration, name)`` pair from ``configure_checkpoint_intervals()`` contributes one
    step count, in the same order; the names are dropped.
    """
    intervals = self.run_configurator.configure_checkpoint_intervals(
        num_params, global_batch_size
    )
    steps: list[int] = []
    for duration, _name in intervals:
        steps.append(self._duration_to_steps(duration, global_batch_size))
    return steps
def _duration_to_steps(self, d: Duration, global_batch_size: int) -> int:
    """
    Convert a duration to a number of training steps.

    A duration in steps is returned as is; a duration in tokens is floor-divided by the
    global batch size.

    :raises ValueError: If the duration uses any other unit.
    """
    unit = d.unit
    if unit == DurationUnit.steps:
        return d.value
    if unit == DurationUnit.tokens:
        return d.value // global_batch_size
    raise ValueError(f"Unsupported checkpoint interval duration unit: {d.unit}.")
def _configure_batch_size_and_num_devices(
    self, size_spec: str, num_params: int
) -> tuple[int, int, int, int]:
    """
    Work out the batch sizes and device counts for a run of the given size spec.

    :param size_spec: The model size spec.
    :param num_params: The number of parameters passed to the run configurator to get the
        target global batch size.

    :returns: A tuple ``(global_batch_size, rank_microbatch_size, num_devices, dp_world_size)``.

    :raises OLMoConfigurationError: If the minimal device mesh spec is inconsistent, or if
        it needs more devices than ``max_devices``.
    """
    # Start from the target global batch size and the per-rank micro-batch size, capping
    # the latter at the former.
    target = self.run_configurator.configure_target_batch_size(num_params)
    per_rank = self.model_configurator.configure_rank_microbatch_size(
        size_spec=size_spec,
        sequence_length=self.sequence_length,
        device_type=self.device_type,
    )
    per_rank = min(per_rank, target)

    # The minimal device mesh spec gives the fewest devices that can be used, and the
    # corresponding smallest data parallel world size.
    base_world_size, base_dp_size = self.model_configurator.configure_minimal_device_mesh_spec(
        size_spec=size_spec,
        sequence_length=self.sequence_length,
        device_type=self.device_type,
    )
    if base_dp_size is None:
        base_dp_size = base_world_size

    if base_world_size % base_dp_size != 0:
        raise OLMoConfigurationError(
            f"Invalid device mesh spec for model of size '{size_spec}': "
            f"minimum world size {base_world_size} is not divisible by "
            f"the minimum data parallel world size {base_dp_size}."
        )
    if self.max_devices < base_world_size:
        raise OLMoConfigurationError(
            f"Not enough devices ({self.max_devices}) to run model of size '{size_spec}' "
            f"which requires at least {base_world_size} devices."
        )

    # Snap the target to a multiple of `per_rank x base_dp_size`, i.e. the number of tokens
    # processed per micro-step on the minimal mesh.
    min_step_tokens = per_rank * base_dp_size
    global_batch_size = max(1, round(target / min_step_tokens)) * min_step_tokens

    # The minimal mesh can be scaled up to the available devices, but never by more than
    # the number of grad accumulation steps it would need, since extra devices beyond that
    # would have nothing to do.
    accum_steps_on_min_mesh = global_batch_size // min_step_tokens
    scale = min(self.max_devices // base_world_size, accum_steps_on_min_mesh)
    num_devices = base_world_size * scale
    dp_world_size = base_dp_size * scale

    # Snap again so the global batch size is a multiple of the tokens per micro-step on
    # the scaled mesh.
    step_tokens = per_rank * dp_world_size
    global_batch_size = max(1, round(global_batch_size / step_tokens)) * step_tokens

    # Warn if final global batch size is more than 10% different from target.
    pct_diff = math.fabs(global_batch_size - target) / target
    if pct_diff > 0.1:
        msg = (
            f"Global batch size to use ({format_tokens(global_batch_size)}) "
            f"differs from target global batch size ({format_tokens(target)}) "
            f"by ~{100 * pct_diff:.1f}%."
        )
        # The usual culprit is a rank micro-batch size that is too large, so suggest
        # shrinking it when it holds more than one instance.
        if per_rank // self.sequence_length > 1:
            msg += (
                f"\nConsider decreasing the configured rank micro-batch size, which is set by the "
                f"method {self.model_configurator.__class__.__name__}.configure_rank_microbatch_size()."
            )
        warn_once(msg, UserWarning)

    return global_batch_size, per_rank, num_devices, dp_world_size
def _configure_trainer(
    self,
    size_spec: str,
    for_benchmarking: bool = False,
) -> TrainerConfig:
    """
    Build the trainer configuration for a run of the given size spec.

    Checkpoints, evaluations, and saved metrics are all pinned to the same fixed steps,
    derived from the run configurator's checkpoint intervals. When ``for_benchmarking`` is
    ``True``, checkpointing, evaluation, W&B logging, and metric saving are disabled, the
    profiler is enabled, and the run hard-stops after 100 steps.
    """
    run_name = f"{self.name}-{size_spec}"
    save_folder = self.get_save_folder(size_spec)

    num_params = self.get_num_params(size_spec)
    global_batch_size, *_ = self._configure_batch_size_and_num_devices(size_spec, num_params)
    duration = self.run_configurator.configure_duration(num_params, global_batch_size)
    fixed_steps = self._get_checkpoint_intervals(
        num_params=num_params, global_batch_size=global_batch_size
    )

    # Everything that writes or evaluates is switched off for benchmarking runs.
    enabled = not for_benchmarking

    # The callbacks are added in a fixed order, one entry at a time.
    trainer_callbacks = {}
    trainer_callbacks["gpu_monitor"] = callbacks.GPUMemoryMonitorCallback()
    trainer_callbacks["config_saver"] = callbacks.ConfigSaverCallback()
    trainer_callbacks["garbage_collector"] = callbacks.GarbageCollectorCallback()
    trainer_callbacks["checkpointer"] = callbacks.CheckpointerCallback(
        ephemeral_save_interval=1000,
        ephemeral_cooldown=250,
        save_interval=None,
        save_async=True,
        fixed_steps=fixed_steps,
        enabled=enabled,
    )
    trainer_callbacks["profiler"] = callbacks.ProfilerCallback(enabled=for_benchmarking)
    trainer_callbacks["gap_monitor"] = callbacks.GAPMonitorCallback(enabled=False, interval=10)
    trainer_callbacks["slack_notifier"] = callbacks.SlackNotifierCallback(
        name=run_name, enabled=False
    )
    trainer_callbacks["beaker"] = callbacks.BeakerCallback()
    trainer_callbacks["wandb"] = callbacks.WandBCallback(
        name=run_name,
        group=self.name,
        project=self.project or self.name,
        cancel_check_interval=50,
        enabled=enabled,
    )
    trainer_callbacks["lm_evaluator"] = callbacks.LMEvaluatorCallbackConfig(
        eval_dataset=NumpyPaddedFSLDatasetConfig.from_data_mix(
            DataMix.v3_small_ppl_validation,
            mix_base_dir=self._get_mix_base_dir(),
            sequence_length=self.sequence_length,
            tokenizer=self.tokenizer,
            work_dir=str(self.work_dir),
        ),
        eval_interval=None,
        fixed_steps=fixed_steps,
        enabled=enabled,
    )
    trainer_callbacks["downstream_evaluator"] = callbacks.DownstreamEvaluatorCallbackConfig(
        tokenizer=self.tokenizer,
        tasks=self._get_in_loop_eval_tasks(),
        eval_interval=None,
        fixed_steps=fixed_steps,
        enabled=enabled,
    )
    trainer_callbacks["metric_saver"] = callbacks.MetricSaverCallback(
        fixed_steps=fixed_steps,
        enabled=enabled,
    )

    hard_stop = Duration.steps(100) if for_benchmarking else None
    return TrainerConfig(
        save_folder=save_folder,
        work_dir=str(self.work_dir),
        metrics_collect_interval=10,
        cancel_check_interval=10,
        max_duration=duration,
        hard_stop=hard_stop,
        no_checkpoints=for_benchmarking,
        no_evals=for_benchmarking,
        save_overwrite=True,
        callbacks=trainer_callbacks,
    )
def _get_in_loop_eval_tasks(self) -> list[str]:
    """Return the entries of ``task_groups.FULL_TASKS`` as a new sorted list."""
    tasks = list(task_groups.FULL_TASKS)
    tasks.sort()
    return tasks
def _get_mix_base_dir(self) -> str:
    """
    Pick the base directory of the evaluation data mix: the Weka location when the ladder's
    :data:`dir` is under ``/weka/``, and the GCS bucket otherwise.
    """
    on_weka = self.dir.startswith("/weka/")
    return "/weka/oe-training-default/ai2-llm" if on_weka else "gs://ai2-llm"