distributed.checkpoint

A high-level distributed checkpointing module with a unified API for saving and loading both local and remote checkpoints.

Features

  • Save with one distributed topology, seamlessly load with a different one. For example, with FSDP/FSDP2 you can save/load checkpoints with different world sizes or sharding strategies.

  • Save/load directly to/from a remote object store like S3 or GCS. When loading from a remote object store each rank only downloads the fraction of the data it needs for its local (potentially sharded) tensors.
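
To make the resharding idea concrete, here is a minimal, library-independent sketch. It assumes a tensor sharded evenly by rows, which is a simplification and not necessarily how olmo_core lays out shards. The helper names `shard_range` and `overlapping_reads` are hypothetical. The sketch only illustrates why a rank loading with a new world size needs a slice of some saved shards rather than the whole checkpoint.

```python
from typing import List, Tuple


def shard_range(num_rows: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Half-open row range [start, end) owned by `rank` under even row-wise sharding."""
    per_rank = -(-num_rows // world_size)  # ceiling division
    start = min(rank * per_rank, num_rows)
    end = min(start + per_rank, num_rows)
    return start, end


def overlapping_reads(
    num_rows: int, saved_world_size: int, load_world_size: int, load_rank: int
) -> List[Tuple[int, int, int]]:
    """List the (saved_rank, start, end) row slices that `load_rank` has to read."""
    want_start, want_end = shard_range(num_rows, load_world_size, load_rank)
    reads = []
    for saved_rank in range(saved_world_size):
        have_start, have_end = shard_range(num_rows, saved_world_size, saved_rank)
        start, end = max(want_start, have_start), min(want_end, have_end)
        if start < end:  # the two ranges intersect
            reads.append((saved_rank, start, end))
    return reads


# A 10-row tensor saved from 4 ranks and loaded onto 2 ranks.
for rank in range(2):
    print(rank, overlapping_reads(10, 4, 2, rank))
# → 0 [(0, 0, 3), (1, 3, 5)]
# → 1 [(1, 5, 6), (2, 6, 9), (3, 9, 10)]
```

Each loading rank touches only the saved shards that intersect its own row range, which is the property that keeps remote downloads proportional to the local shard size.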

Overview

Use save_model_and_optim_state() to write a checkpoint with your model and optimizer’s state, then use load_model_and_optim_state() to load the checkpoint in-place.

You can unshard a checkpoint saved this way with unshard_checkpoint().
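
As a rough sketch of the save/load round trip, the wrapper below calls the two functions with the argument order given in the API reference. The wrapper name `save_then_reload` and the `checkpoint_dir` argument are illustrative only, and the imports are done lazily so the snippet can be read (and imported) without `olmo_core` installed.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # only needed for type hints
    import torch


def save_then_reload(
    model: "torch.nn.Module",
    optim: "torch.optim.Optimizer",
    checkpoint_dir: str,
) -> None:
    """Write a checkpoint, then load it back in-place into the same objects."""
    from olmo_core.distributed.checkpoint import (
        load_model_and_optim_state,
        save_model_and_optim_state,
    )

    # `checkpoint_dir` may be a local path or a URL.
    # save_overwrite=True replaces an existing checkpoint instead of raising FileExistsError.
    save_model_and_optim_state(checkpoint_dir, model, optim, save_overwrite=True)

    # Loading happens in-place on `model` and `optim`.
    load_model_and_optim_state(checkpoint_dir, model, optim)
```

In a real job the load would typically happen in a later run, possibly with a different world size or sharding strategy than the one used for saving.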

API Reference

olmo_core.distributed.checkpoint.save_state_dict(dir, state_dict, *, process_group=None, save_overwrite=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False, _skip_prepare=False)[source]

Save an arbitrary state dictionary to a distributed format that can be loaded again with a different distributed topology.

Important

Unless you know what you’re doing, use save_model_and_optim_state() to save model/optimizer state dicts instead.

Parameters:
  • dir (Union[Path, PathLike, str]) – Path/URL to save to.

  • state_dict (Dict[str, Any]) – The state dict to save.

  • process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.

  • save_overwrite (bool, default: False) – Overwrite existing files.

  • thread_count (Optional[int], default: None) – Set this to override the number of threads used while writing data.

  • process_count (Optional[int], default: None) – Set this to use a process pool instead of a thread pool when possible (currently not compatible with throttle_uploads).

  • throttle_uploads (bool, default: False) – If this is set to True and dir is a URL then only one rank from each node will upload data at a time.

olmo_core.distributed.checkpoint.async_save_state_dict(dir, state_dict, *, process_group=None, save_overwrite=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False, _skip_prepare=False)[source]

An async version of save_state_dict().

This function first de-stages (copies) the state dict to CPU memory, then writes it in a separate thread.

Return type:

Future[None]
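
The stage-then-write pattern behind the returned future can be illustrated without any distributed machinery. The following is a sketch of the idea only, not the library's implementation: `async_write` is a hypothetical helper, `copy.deepcopy` stands in for copying tensors to CPU memory, and JSON stands in for the real checkpoint format.

```python
import copy
import json
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path

_executor = ThreadPoolExecutor(max_workers=1)


def async_write(path: Path, state: dict) -> Future:
    """Snapshot `state` synchronously, then write the snapshot in a background thread."""
    snapshot = copy.deepcopy(state)  # staging: later mutations of `state` are not saved

    def _write() -> None:
        path.write_text(json.dumps(snapshot))

    return _executor.submit(_write)


state = {"step": 100, "weights": [0.1, 0.2]}
target = Path(tempfile.mkdtemp()) / "state.json"

future = async_write(target, state)
state["step"] = 101  # training continues while the write is in flight
future.result()  # block until the write is done; re-raises any error from the thread

print(json.loads(target.read_text())["step"])  # → 100
```

The same discipline applies to the real function: keep the returned future and call `result()` on it before starting the next save or exiting, so that write errors are not silently lost.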

olmo_core.distributed.checkpoint.load_state_dict(dir, state_dict, *, process_group=None, pre_download=False, work_dir=None, thread_count=None)[source]

Load an arbitrary state dict in-place from a checkpoint saved with save_state_dict().

Parameters:
  • dir (Union[Path, PathLike, str]) – Path/URL to the checkpoint saved via save_state_dict().

  • state_dict (Dict[str, Any]) – The state dict to load the state into.

  • process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.

  • thread_count (Optional[int], default: None) – Set the number of threads used for certain operations.

olmo_core.distributed.checkpoint.save_model_and_optim_state(dir, model, optim=None, *, process_group=None, save_overwrite=False, flatten_optimizer_state=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False)[source]

Save model and optimizer state dictionaries. The model can be sharded, in which case this function handles the optimizer state so that the checkpoint can be loaded again with a different distributed topology through load_model_and_optim_state().

Tip

With FullyShardedDataParallel models it’s not necessary to set the state dict type before calling this (or load_model_and_optim_state()) via state_dict_type() or other methods. This function handles that internally.

Parameters:
  • dir (Union[Path, PathLike, str]) – Path/URL to save to.

  • model (Module) – The model to save state from.

  • optim (Optional[Optimizer], default: None) – The optimizer to save state from.

  • process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.

  • save_overwrite (bool, default: False) – Overwrite existing files.

  • flatten_optimizer_state (bool, default: False) – Flatten the optimizer state before saving. This should match the setting used when loading the state dict and is needed in a distributed setting when the params in some param groups may differ between ranks, such as with pipeline parallelism.

  • thread_count (Optional[int], default: None) – Set this to override the number of threads used while writing data.

  • process_count (Optional[int], default: None) – Set this to use a process pool instead of a thread pool when possible (currently not compatible with throttle_uploads).

  • throttle_uploads (bool, default: False) – If this is set to True and dir is a URL then only one rank from each node will upload data at a time.

Raises:

FileExistsError – If the checkpoint dir exists and is non-empty unless save_overwrite=True.

Return type:

None

olmo_core.distributed.checkpoint.async_save_model_and_optim_state(dir, model, optim=None, *, process_group=None, save_overwrite=False, flatten_optimizer_state=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False)[source]

An async version of save_model_and_optim_state().

This function first de-stages (copies) the state dict to CPU memory, then writes it in a separate thread.

Return type:

Future[None]

olmo_core.distributed.checkpoint.load_model_and_optim_state(dir, model, optim=None, *, process_group=None, key_mapping=None, pre_download=False, work_dir=None, strict=True, flatten_optimizer_state=False, thread_count=None)[source]

Load model and optimizer state in-place from a checkpoint saved via save_model_and_optim_state(). This method is agnostic to the distributed topology in that it can load checkpoints saved with a different distributed topology (e.g. FSDP/FSDP2, DDP).

Tip

With FullyShardedDataParallel models it’s not necessary to set the state dict type before calling this (or save_model_and_optim_state()) via state_dict_type() or other methods. This function handles that internally.

Warning

Due to the way torch.distributed.checkpoint works, if you have keys in the checkpoint dict that are not present in the current state of the model or optimizer, those keys won’t be loaded.

For example, if you added a custom field to one of your optimizer’s param groups before saving the checkpoint, but don’t have that field in the param group of the optimizer you’re loading into, it won’t be added.

This can cause unexpected behavior if you’re not careful. In this case the best thing to do is to ensure all keys are present in the param groups when you initialize the optimizer, before saving or loading a checkpoint.
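
A toy model of this behavior, using plain dictionaries in place of real optimizer param groups. The function `load_in_place` is a hypothetical stand-in that mimics only the rule described in this warning (keys missing from the destination are skipped), and the field name `max_grad_norm` is made up for illustration.

```python
def load_in_place(dest: dict, checkpoint: dict) -> None:
    """Overwrite values in `dest`, but only for keys that `dest` already has."""
    for key in dest:
        if key in checkpoint:
            dest[key] = checkpoint[key]


saved_group = {"lr": 1e-4, "weight_decay": 0.1, "max_grad_norm": 1.0}

# Loading into a param group that lacks the custom field silently drops it.
fresh_group = {"lr": 3e-4, "weight_decay": 0.0}
load_in_place(fresh_group, saved_group)
print(fresh_group)  # → {'lr': 0.0001, 'weight_decay': 0.1}

# Declaring the field up front (with any placeholder value) lets it be restored.
prepared_group = {"lr": 3e-4, "weight_decay": 0.0, "max_grad_norm": None}
load_in_place(prepared_group, saved_group)
print(prepared_group)  # → {'lr': 0.0001, 'weight_decay': 0.1, 'max_grad_norm': 1.0}
```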

Parameters:
  • dir (Union[Path, PathLike, str]) – Path/URL to the checkpoint saved via save_model_and_optim_state().

  • model (Module) – The model to load the state into.

  • optim (Optional[Optimizer], default: None) – The optimizer to load the state into.

  • process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.

  • key_mapping (Optional[Dict[str, str]], default: None) – Can be used to load a checkpoint where certain parameters have different names. This dictionary should map current keys to keys in the checkpoint to be loaded.

  • pre_download (bool, default: False) – Download and cache relevant remote checkpoint files before trying to read from them.

  • work_dir (Union[Path, PathLike, str, None], default: None) – A working directory for caching files/directories.

  • strict (bool, default: True) – Load keys strictly.

  • flatten_optimizer_state (bool, default: False) – Flatten the optimizer state when loading. This should match the setting used when saving the state dict and is needed in a distributed setting when the params in some param groups may differ between ranks, such as with pipeline parallelism.

  • thread_count (Optional[int], default: None) – Set the number of threads used for certain operations.
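
The direction of `key_mapping` is easy to get backwards, so here is a small sketch with plain dictionaries. `checkpoint_key_for` is a hypothetical helper and the parameter names are invented. The only thing taken from the description above is that the mapping goes from the current key to the key stored in the checkpoint.

```python
from typing import Dict, Optional


def checkpoint_key_for(current_key: str, key_mapping: Optional[Dict[str, str]] = None) -> str:
    """Return the checkpoint key to read for `current_key` (unmapped keys are unchanged)."""
    if key_mapping is None:
        return current_key
    return key_mapping.get(current_key, current_key)


# The checkpoint was saved when the layer was still called "proj".
checkpoint = {"proj.weight": [1.0, 2.0], "norm.weight": [0.5]}

# The current model calls the same layer "out_proj": current key -> checkpoint key.
key_mapping = {"out_proj.weight": "proj.weight"}

current_state = {"out_proj.weight": None, "norm.weight": None}
for key in current_state:
    current_state[key] = checkpoint[checkpoint_key_for(key, key_mapping)]

print(current_state)  # → {'out_proj.weight': [1.0, 2.0], 'norm.weight': [0.5]}
```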

olmo_core.distributed.checkpoint.unshard_checkpoint(dir, target_dir, *, optim=None, save_overwrite=False, use_safetensors=False, unshard_strategy=None, pre_download=False, work_dir=None, quiet=False)[source]

Convert a checkpoint saved via save_model_and_optim_state() into unsharded model and optimizer checkpoint files that can be loaded directly with torch.load() or safetensors if use_safetensors=True.

Warning

The safetensors format cannot be used to save optimizer state, since optimizer state can contain arbitrary Python objects that need to be pickled. Therefore optim=True and use_safetensors=True are incompatible.

Warning

This should only be called in a non-distributed context. Otherwise a RuntimeError is raised.

See also

load_keys() if you only need to load and unshard certain keys in the checkpoint.

Parameters:
  • dir (Union[Path, PathLike, str]) – The path/URL to the original checkpoint created via save_model_and_optim_state().

  • target_dir (Union[Path, PathLike, str]) – The directory to save the unsharded model/optimizer checkpoint files to. This must be a local directory. URLs are not supported.

  • optim (Optional[bool], default: None) – Whether to unshard the optimizer state. This defaults to True as long as use_safetensors=False.

  • save_overwrite (bool, default: False) – Overwrite any existing files in target_dir.

  • use_safetensors (bool, default: False) – Save the unsharded files with safetensors.torch.save_file() instead of torch.save().

  • unshard_strategy (Optional[UnshardStrategy], default: None) – The strategy to use. Defaults to UnshardStrategy.one_file().

  • pre_download (bool, default: False) – Download and cache relevant remote checkpoint files before trying to read from them.

  • work_dir (Union[Path, PathLike, str, None], default: None) – A working directory for caching files/directories.

  • quiet (bool, default: False) – Do not show progress messages.

Return type:

Tuple[Path, Optional[Path]]

Returns:

The path to the unsharded model checkpoint and the path to the unsharded optimizer checkpoint if optim=True. These paths may represent files or directories depending on the unshard_strategy.

Raises:

FileExistsError – If the target_dir is non-empty and save_overwrite=False.
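
Two hedged usage sketches for the options documented above. The wrapper names and their arguments are illustrative, and the imports are lazy so the snippet can be loaded without `olmo_core` installed. Both wrappers are meant to be run in a single, non-distributed process.

```python
def unshard_everything(checkpoint_dir: str, target_dir: str):
    """Unshard model and optimizer state using the default strategy and torch.save()."""
    from olmo_core.distributed.checkpoint import unshard_checkpoint

    # `target_dir` must be a local directory; `checkpoint_dir` may be a path or URL.
    model_path, optim_path = unshard_checkpoint(checkpoint_dir, target_dir)
    return model_path, optim_path


def unshard_model_to_safetensors(checkpoint_dir: str, target_dir: str):
    """Unshard only the model weights, written in the safetensors format."""
    from olmo_core.distributed.checkpoint import unshard_checkpoint

    # Optimizer state cannot be stored as safetensors, so it is skipped explicitly.
    model_path, _ = unshard_checkpoint(
        checkpoint_dir,
        target_dir,
        optim=False,
        use_safetensors=True,
    )
    return model_path
```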

olmo_core.distributed.checkpoint.load_keys(dir, keys, *, pre_download=False, work_dir=None)[source]

Load specific keys from a checkpoint.

Warning

This should only be called in a non-distributed context. Otherwise a RuntimeError is raised.

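Parameters:
  • dir (Union[Path, PathLike, str]) – The path/URL to the checkpoint.

  • keys – The keys to load from the checkpoint.

  • pre_download (bool, default: False) – Download and cache relevant remote checkpoint files before trying to read from them.

  • work_dir (Union[Path, PathLike, str, None], default: None) – A working directory for caching files/directories.
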
Return type:

Generator[Any, None, None]

Returns:

The (unsharded) objects from the checkpoint corresponding to the given keys, in the same order as the keys.
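
Because the objects are yielded in the same order as the requested keys, they can be paired back up with `zip`. A minimal sketch, assuming `olmo_core` is installed at call time; the wrapper name `load_selected` is hypothetical.

```python
from typing import Any, Dict, List


def load_selected(checkpoint_dir: str, keys: List[str]) -> Dict[str, Any]:
    """Load a handful of entries from a checkpoint into a plain dictionary."""
    from olmo_core.distributed.checkpoint import load_keys

    # load_keys() is a generator; zip() consumes it and pairs each object with its key.
    return dict(zip(keys, load_keys(checkpoint_dir, keys)))
```

Like load_keys() itself, this wrapper should only be called in a non-distributed context.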

olmo_core.distributed.checkpoint.get_checkpoint_metadata(dir)[source]

Load the metadata from a checkpoint.

Parameters:

dir (Union[Path, PathLike, str]) – The path/URL to the checkpoint.

Return type:

Metadata

class olmo_core.distributed.checkpoint.UnshardStrategy(name='one_file', chunk_size_bytes=None)[source]

Bases: object

Unsharding strategy config for unshard_checkpoint().

name: UnshardStrategyType = 'one_file'

The strategy type.

chunk_size_bytes: Optional[int] = None

The approximate max chunk size (per file size), in bytes, for the UnshardStrategyType.chunks strategy.

classmethod one_file()[source]

Use the UnshardStrategyType.one_file strategy.

Return type:

UnshardStrategy

classmethod one_file_per_tensor()[source]

Use the UnshardStrategyType.one_file_per_tensor strategy.

Return type:

UnshardStrategy

classmethod chunks(chunk_size_in_bytes)[source]

Use the UnshardStrategyType.chunks strategy.

Return type:

UnshardStrategy

class olmo_core.distributed.checkpoint.UnshardStrategyType(value)[source]

Bases: StrEnum

An enumeration of the unsharding strategies that can be used with unshard_checkpoint().

one_file = 'one_file'

Save the unsharded model state into one file, and optionally the optimizer state into another file. The bigger the model, the more memory this requires. For very big models, one_file_per_tensor will scale better.

one_file_per_tensor = 'one_file_per_tensor'

Save each unsharded tensor to its own file. Currently this is not compatible with optimizer state.

chunks = 'chunks'

Like one_file_per_tensor but multiple tensors and objects may be grouped into the same file up to the limit defined by UnshardStrategy.chunk_size_bytes.
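
One way to picture the chunks strategy is as greedy packing under a byte budget. The sketch below is an assumption about the general idea, not the library's actual grouping algorithm; `group_into_chunks`, the tensor names, and the sizes are all made up.

```python
from typing import Dict, List


def group_into_chunks(sizes: Dict[str, int], chunk_size_bytes: int) -> List[List[str]]:
    """Greedily pack names into groups whose total size stays within the limit.

    An item larger than the limit still gets a group of its own, which is why
    the limit is only an approximate maximum.
    """
    chunks: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for name, size in sizes.items():
        # Start a new chunk when adding this item would exceed the budget.
        if current and current_bytes + size > chunk_size_bytes:
            chunks.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
    if current:
        chunks.append(current)
    return chunks


tensor_sizes = {"embed.weight": 600, "layer0.weight": 300, "layer0.bias": 50, "head.weight": 700}
print(group_into_chunks(tensor_sizes, chunk_size_bytes=1000))
# → [['embed.weight', 'layer0.weight', 'layer0.bias'], ['head.weight']]
```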

olmo_core.distributed.checkpoint.prune_state_dict(state_dict, allowed_keys)[source]

Prune a state dict by removing all keys not in allowed_keys.

Return type:

Set[str]

Returns:

The keys that were pruned.
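
The contract described here (remove everything outside `allowed_keys`, report what was removed) can be sketched for a flat dictionary. This is an illustration under two assumptions that the reference does not spell out: that the state dict is flat, and that pruning happens in-place. `prune_flat_state_dict` is a hypothetical name, not the library function.

```python
from typing import Any, Dict, Set


def prune_flat_state_dict(state_dict: Dict[str, Any], allowed_keys: Set[str]) -> Set[str]:
    """Remove, in-place, every key not in `allowed_keys`; return the removed keys."""
    pruned = {key for key in state_dict if key not in allowed_keys}
    for key in pruned:
        del state_dict[key]
    return pruned


state = {"model.weight": 1, "model.bias": 2, "optim.step": 3}
removed = prune_flat_state_dict(state, {"model.weight", "model.bias"})

print(sorted(removed))  # → ['optim.step']
print(state)  # → {'model.weight': 1, 'model.bias': 2}
```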

olmo_core.distributed.checkpoint.merge_state_dicts(lhs, rhs)[source]

Merge rhs state dict into lhs.