distributed.checkpoint

A low-overhead, fast, distributed checkpointing module with a unified API for saving and loading both local and remote checkpoints. It is built on top of safetensors and inspired by torch.distributed.checkpoint, but is better suited to handling distributed models and optimizer state without unnecessary distributed communication or GPU allocations.

Features

  • Sharded distributed models, such as OLMo-core’s FSDP or PyTorch’s FullyShardedDataParallel (with use_orig_params=True), are supported out-of-the-box.

  • Utilizes safetensors under the hood for fast, efficient, and safe serialization/deserialization.

  • Save with one distributed topology, seamlessly load with a different one. For example, with FSDP you can save/load checkpoints with different world sizes or wrapping strategies.

  • Save/load directly to/from a remote object store like S3 or GCS. When loading from a remote object store, each rank downloads only the fraction of the data it needs for its local (potentially sharded) tensors.

  • Checkpoints are always loaded in-place and one tensor at a time to avoid unnecessary allocations. This results in virtually no additional memory overhead.
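To make the remote-loading point concrete: a safetensors file starts with an 8-byte little-endian header length, followed by a JSON header whose `data_offsets` are relative to the start of the data section. That is enough for a rank to compute the exact byte range covering only its slice of a tensor and request just those bytes (for example with an HTTP `Range` request). The sketch below builds a tiny safetensors-format buffer in memory and reads a slice through a `fetch(first, last)` callable that stands in for ranged remote reads. `build_safetensors`, `read_slice`, and `fetch` are hypothetical helpers, and this is not necessarily how olmo_core implements partial downloads internally.

```python
import json
import struct
from typing import Callable, List, Tuple

def build_safetensors(name: str, values: List[float]) -> bytes:
    """Serialize a 1-D float32 tensor into a minimal safetensors-format buffer."""
    data = struct.pack(f"<{len(values)}f", *values)
    header = {name: {"dtype": "F32", "shape": [len(values)], "data_offsets": [0, len(data)]}}
    header_bytes = json.dumps(header).encode("utf-8")
    # Layout: 8-byte little-endian header length, JSON header, then raw tensor data.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data

def read_slice(fetch: Callable[[int, int], bytes], name: str, start: int, end: int) -> List[float]:
    """Read elements [start, end) of a 1-D F32 tensor using only byte-range fetches."""
    (header_len,) = struct.unpack("<Q", fetch(0, 8))
    info = json.loads(fetch(8, 8 + header_len))[name]
    if info["dtype"] != "F32":
        raise ValueError("this sketch only handles F32 tensors")
    itemsize = 4
    data_start = 8 + header_len + info["data_offsets"][0]
    raw = fetch(data_start + start * itemsize, data_start + end * itemsize)
    return list(struct.unpack(f"<{end - start}f", raw))

blob = build_safetensors("weight", [float(i) for i in range(10)])
requested: List[Tuple[int, int]] = []

def fetch(first: int, last: int) -> bytes:
    requested.append((first, last))  # record what a remote store would be asked for
    return blob[first:last]

print(read_slice(fetch, "weight", 4, 7))  # [4.0, 5.0, 6.0]
```

Only 12 bytes of tensor data are requested for three float32 elements, instead of the whole 40-byte data section; with large sharded tensors the savings scale with the number of ranks.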

Overview

Use save_model_and_optim_state() to write a checkpoint with your model and optimizer’s state, then use load_model_and_optim_state() to load the checkpoint in-place. You can also generate unsharded, full state dictionaries from a checkpoint with unshard_model_state() and unshard_optim_state().

API Reference

olmo_core.distributed.checkpoint.save_model_and_optim_state(dir, model, optim, save_overwrite=False)

Save model and optimizer state dictionaries. The model can be sharded, in which case this function handles the optimizer state so that the checkpoint can be loaded again with a different distributed topology through load_model_and_optim_state().

Returns all of the files created by the current rank.

Tip

With FullyShardedDataParallel models it’s not necessary to set the state dict type (via state_dict_type() or any other method) before calling this function or load_model_and_optim_state(). In fact, any such setting is ignored.

Attention

At the moment FullyShardedDataParallel models must have use_orig_params=True.

Parameters:
  • dir (Union[Path, PathLike, str]) – Path/URL to save to.

  • model (Module) – The model to save state from.

  • optim (Optimizer) – The optimizer to save state from.

  • save_overwrite (bool, default: False) – Overwrite existing files.

Return type:

List[Union[Path, PathLike, str]]

olmo_core.distributed.checkpoint.load_model_and_optim_state(dir, model, optim=None, validate=True)

Load model and optimizer state in-place from a checkpoint saved via save_model_and_optim_state(). This method is agnostic to the distributed topology in that it can load checkpoints saved with a different distributed topology (e.g. FSDP vs DDP, or FSDP with a different world size).

Tip

Internally this function handles calling torch.nn.Module.load_state_dict() and torch.optim.Optimizer.load_state_dict() for you, hence the return type is None.

Parameters:
  • dir (Union[Path, PathLike, str]) – Path/URL to the checkpoint saved via save_model_and_optim_state().

  • model (Module) – The model to load the state into.

  • optim (Optional[Optimizer], default: None) – The optimizer to load the state into.

  • validate (bool, default: True) – Validate that all tensors have been loaded completely from the checkpoint by pre-filling each tensor with NaNs prior to loading in-place, then checking afterwards that there are no NaNs remaining.
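The `validate` strategy can be illustrated without any tensors: pre-fill a destination buffer with NaNs, copy in whatever the checkpoint provides, and treat any leftover NaN as a region the checkpoint never covered. The sketch below uses plain Python lists; `load_with_validation` and the `(offset, values)` chunk format are hypothetical stand-ins for in-place tensor loading. Note one inherent caveat of this technique: a checkpoint that legitimately stores NaN values would also be flagged as incomplete.

```python
import math
from typing import Iterable, List, Tuple

def load_with_validation(numel: int, chunks: Iterable[Tuple[int, List[float]]]) -> List[float]:
    """Fill a buffer in place from (offset, values) chunks and verify nothing was missed."""
    buf = [math.nan] * numel  # pre-fill every element with NaN
    for offset, values in chunks:
        buf[offset : offset + len(values)] = values  # copy each chunk into place
    missing = [i for i, x in enumerate(buf) if math.isnan(x)]
    if missing:
        raise RuntimeError(f"{len(missing)} element(s) were not loaded, first at index {missing[0]}")
    return buf

print(load_with_validation(4, [(0, [1.0, 2.0]), (2, [3.0, 4.0])]))  # [1.0, 2.0, 3.0, 4.0]
try:
    load_with_validation(4, [(0, [1.0, 2.0])])
except RuntimeError as e:
    print(e)  # 2 element(s) were not loaded, first at index 2
```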

olmo_core.distributed.checkpoint.unshard_model_state(dir, device=None, rank0_only=False, no_dist=False)

Unshard model state saved via save_model_and_optim_state().

Parameters:
  • dir (Union[Path, PathLike, str]) – Local or remote checkpoint directory.

  • device (Optional[device], default: None) – Device to load the checkpoint onto. Defaults to CPU.

  • rank0_only (bool, default: False) – Set to True if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary.

  • no_dist (bool, default: False) – Set to True to avoid any distributed communication whatsoever.

Return type:

Dict[str, Tensor]
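Conceptually, unsharding writes each file's pieces into a full flattened tensor at the flattened offsets recorded for that file. The pure-Python sketch below mimics this for a single tensor and also mirrors the documented `rank0_only` behavior, where non-zero ranks get an empty dictionary. `unshard_sketch` and the `Shards` layout are hypothetical; the real function reads tensors from the checkpoint directory and returns shaped tensors, whereas this sketch keeps the result flat for brevity.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical layout: filename -> list of ((start, end), values) pieces of one flattened tensor.
Shards = Dict[str, List[Tuple[Tuple[int, int], List[float]]]]

def unshard_sketch(
    name: str, shape: Tuple[int, ...], shards: Shards, rank: int = 0, rank0_only: bool = False
) -> Dict[str, List[float]]:
    """Reassemble one tensor (kept flat for brevity) from per-file pieces."""
    if rank0_only and rank != 0:
        return {}  # non-zero ranks receive an empty dictionary
    full = [0.0] * math.prod(shape)
    for pieces in shards.values():
        for (start, end), values in pieces:
            if end - start != len(values):
                raise ValueError(f"piece {start}:{end} has {len(values)} values")
            full[start:end] = values
    return {name: full}

shards: Shards = {
    "rank0.safetensors": [((0, 3), [0.0, 1.0, 2.0])],
    "rank1.safetensors": [((3, 6), [3.0, 4.0, 5.0])],
}
print(unshard_sketch("weight", (2, 3), shards))
# {'weight': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}
print(unshard_sketch("weight", (2, 3), shards, rank=1, rank0_only=True))  # {}
```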

olmo_core.distributed.checkpoint.unshard_optim_state(dir, device=None, rank0_only=False, no_dist=False)

Unshard optimizer state saved via save_model_and_optim_state().

Parameters:
  • dir (Union[Path, PathLike, str]) – Local or remote checkpoint directory.

  • device (Optional[device], default: None) – Device to load the checkpoint onto. Defaults to CPU.

  • rank0_only (bool, default: False) – Set to True if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary.

  • no_dist (bool, default: False) – Set to True to avoid any distributed communication whatsoever.

Return type:

OptimStateDict

class olmo_core.distributed.checkpoint.Checkpointer

Bases: object

A distributed checkpointer for saving and loading non-nested state dictionaries, i.e. where keys are strings and values are either regular torch.Tensor instances, torch.nn.Parameter instances, DTensor instances, or any sharded tensors from this library.

For saving and loading model and optimizer states together, use save_model_and_optim_state() and load_model_and_optim_state() instead.

METADATA_FILENAME = 'metadata.json'
save(dir, state_dict, save_overwrite=False)

Save a state dict. The state dict can contain regular Tensors, Parameters, or any sharded tensors from this library.

When calling this from a distributed context, all ranks must call this at the same time and the state dict must have the same keys and tensor types across each rank.

Returns the storage metadata and a list of files created by the local rank.

Parameters:
  • dir (Union[Path, PathLike, str]) – The location to save the checkpoint to. Could be a path to a local directory or a URL to a “folder” in an S3 or GCS bucket.

  • state_dict (Dict[str, Tensor]) – The state dictionary to save.

  • save_overwrite (bool, default: False) – Overwrite existing data.

Return type:

Tuple[StorageMetadata, List[Union[Path, PathLike, str]]]
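The requirement that every rank pass a state dict with the same keys and tensor types can be checked before calling save(). The sketch below compares per-rank summaries of `{key: tensor type name}`. How those summaries are gathered (for example with a collective such as `torch.distributed.all_gather_object`) is left out, and `check_state_dict_consistency` is a hypothetical helper, not part of olmo_core.

```python
from typing import Dict, List

def check_state_dict_consistency(per_rank: List[Dict[str, str]]) -> None:
    """Raise ValueError if ranks disagree on keys or tensor types.

    per_rank[i] maps each state-dict key on rank i to the name of its tensor type.
    """
    reference = per_rank[0]
    for rank, summary in enumerate(per_rank[1:], start=1):
        if summary.keys() != reference.keys():
            diff = sorted(summary.keys() ^ reference.keys())  # keys present on only one side
            raise ValueError(f"rank {rank} key mismatch with rank 0: {diff}")
        for key, type_name in summary.items():
            if type_name != reference[key]:
                raise ValueError(f"rank {rank} has {type_name} for {key!r}, rank 0 has {reference[key]}")

check_state_dict_consistency([{"w": "Tensor", "b": "Tensor"}, {"w": "Tensor", "b": "Tensor"}])  # passes
try:
    check_state_dict_consistency([{"w": "Tensor"}, {"w": "DTensor"}])
except ValueError as e:
    print(e)  # rank 1 has DTensor for 'w', rank 0 has Tensor
```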

load(dir, state_dict, no_dist=False, metadata=None, _safetensors_mfl=None, _check_for_nans=False)

Load a state dict in-place.

Parameters:
  • dir (Union[Path, PathLike, str]) – The path or URL to the checkpoint saved via save().

  • state_dict (Dict[str, Tensor]) – The state dictionary to load into. This should contain all of the tensors you want to load.

  • no_dist (bool, default: False) – Disable distributed communication even if within a distributed context.

unshard(dir, device=None, rank0_only=False, no_dist=False, num_threads=None)

Unshard a checkpoint, returning the full state dict. This can be used in both distributed and non-distributed contexts. If you only want to load a single copy to rank 0 in a distributed context, set rank0_only=True, in which case other ranks will receive an empty state dict.

Alternatively, setting no_dist=True will return a full state dict from whatever process calls this.

Parameters:
  • dir (Union[Path, PathLike, str]) – Local or remote checkpoint directory.

  • device (Optional[device], default: None) – Device to load the checkpoint onto. Defaults to CPU.

  • rank0_only (bool, default: False) – Set to True if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary.

  • no_dist (bool, default: False) – Set to True to avoid any distributed communication whatsoever.

  • num_threads (Optional[int], default: None) – The maximum number of threads to use to unshard the checkpoint. Increasing num_threads can lead to a substantial speed up, especially when loading from a remote checkpoint. Set to 0 to disable threading.

Return type:

Dict[str, Tensor]
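The speed-up from num_threads comes mostly from I/O-bound reads, where threads overlap network latency. The sketch below fetches each file's piece through a `concurrent.futures.ThreadPoolExecutor` and falls back to a sequential loop when `num_threads=0`, mirroring the documented meaning of that value. `FILES`, `fetch_piece`, and the simulated latency are hypothetical; the real method decides internally how work is split across threads, and this page does not say what `None` resolves to (here it simply lets the executor choose its default pool size).

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple

# Hypothetical checkpoint: filename -> (start offset in the flattened tensor, values).
FILES: Dict[str, Tuple[int, List[float]]] = {
    "rank0.safetensors": (0, [0.0, 1.0]),
    "rank1.safetensors": (2, [2.0, 3.0]),
    "rank2.safetensors": (4, [4.0, 5.0]),
}

def fetch_piece(filename: str) -> Tuple[int, List[float]]:
    time.sleep(0.05)  # stand-in for remote download latency
    return FILES[filename]

def unshard_threaded(numel: int, num_threads: Optional[int] = None) -> List[float]:
    full = [0.0] * numel
    if num_threads == 0:
        pieces = [fetch_piece(f) for f in FILES]  # threading disabled
    else:
        # max_workers=None lets the executor pick its default pool size.
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            pieces = list(pool.map(fetch_piece, FILES))
    for start, values in pieces:
        full[start : start + len(values)] = values
    return full

print(unshard_threaded(6, num_threads=3))  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```

With three files and 50 ms of simulated latency each, the threaded path takes roughly one latency period instead of three; the result is identical either way.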

get_metadata(dir, no_dist=False)

Get the storage metadata from a checkpoint directory.

Return type:

StorageMetadata

class olmo_core.distributed.checkpoint.TensorShardSpec(**data)

Bases: BaseModel

flattened_offsets: Optional[Tuple[Tuple[int, int], ...]]

Offsets within the full flattened tensor that the given shard corresponds to.

local_shape: Optional[Tuple[int, ...]]

The (unflattened) shape of the local shard.

global_offset: Optional[Tuple[int, ...]]

The starting offset for each dimension in the global unsharded (unflattened) tensor that the local shard corresponds to.

property local_numel: int
get_flattened_offsets(full_shape)

Get flattened offsets into the full flattened tensor that the given shard corresponds to. If self.flattened_offsets is set, this just returns a generator over those; otherwise it computes them from self.local_shape and self.global_offset.

Return type:

Generator[Tuple[int, int], None, None]
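When only `local_shape` and `global_offset` are available, the flattened offsets can be derived for a row-major (C-contiguous) full tensor: each row of the local shard along the last dimension maps to one contiguous `(start, end)` run in the flattened full tensor. The standalone function below is a hypothetical sketch of that reasoning, assuming row-major layout; the actual method may order or group its ranges differently.

```python
import itertools
import math
from typing import Generator, Tuple

def flattened_offsets(
    full_shape: Tuple[int, ...], local_shape: Tuple[int, ...], global_offset: Tuple[int, ...]
) -> Generator[Tuple[int, int], None, None]:
    """Yield (start, end) ranges into the flattened full tensor covered by a local shard."""
    # Row-major strides of the full tensor, e.g. (4, 1) for shape (3, 4).
    strides = [math.prod(full_shape[d + 1 :]) for d in range(len(full_shape))]
    # Iterate over every index of the leading (all but last) dimensions of the shard.
    for idx in itertools.product(*(range(n) for n in local_shape[:-1])):
        start = sum((global_offset[d] + i) * strides[d] for d, i in enumerate(idx))
        start += global_offset[-1]  # the last dimension has stride 1
        yield start, start + local_shape[-1]

# A 2x2 shard at offset (1, 1) inside a 3x4 tensor covers two separate runs.
print(list(flattened_offsets((3, 4), (2, 2), (1, 1))))  # [(5, 7), (9, 11)]
# A full-width shard produces back-to-back runs.
print(list(flattened_offsets((3, 4), (2, 4), (1, 0))))  # [(4, 8), (8, 12)]
```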

get_merged_flattened_offsets(full_shape)

Like get_flattened_offsets(), but merges consecutive offset ranges that are contiguous into a single range.

Return type:

Generator[Tuple[int, int], None, None]
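Merging takes a single pass: whenever a range starts exactly where the previous one ended, the previous range is extended instead of a new one being emitted. The standalone sketch below uses an illustrative function name and assumes the input ranges are already in ascending order, as row-major iteration produces.

```python
from typing import Generator, Iterable, Optional, Tuple

def merge_contiguous(offsets: Iterable[Tuple[int, int]]) -> Generator[Tuple[int, int], None, None]:
    """Merge ascending (start, end) ranges where one ends exactly where the next begins."""
    current: Optional[Tuple[int, int]] = None
    for start, end in offsets:
        if current is not None and start == current[1]:
            current = (current[0], end)  # contiguous: extend the open range
        else:
            if current is not None:
                yield current
            current = (start, end)
    if current is not None:
        yield current

print(list(merge_contiguous([(4, 8), (8, 12), (16, 20)])))  # [(4, 12), (16, 20)]
```

Fewer, larger ranges mean fewer read calls, which matters most for ranged requests against a remote object store.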

compute_overlap_with(other, full_shape)
Return type:

Optional[OverlapType]
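The members of OverlapType are not documented on this page, so the sketch below only illustrates the underlying computation: intersecting two shards' flattened ranges, which tells a loader which parts of a saved shard feed a differently sharded destination. It classifies the result with its own hypothetical `Overlap` enum (`NONE`, `PARTIAL`, `FULL`); do not assume these names or semantics match the library's OverlapType.

```python
from enum import Enum
from typing import List, Tuple

Ranges = List[Tuple[int, int]]

class Overlap(Enum):  # hypothetical, not olmo_core's OverlapType
    NONE = "none"
    PARTIAL = "partial"
    FULL = "full"  # every element of `ours` is covered by `theirs`

def intersect(ours: Ranges, theirs: Ranges) -> Ranges:
    """Intersect two ascending lists of non-overlapping [start, end) ranges."""
    out: Ranges = []
    i = j = 0
    while i < len(ours) and j < len(theirs):
        start = max(ours[i][0], theirs[j][0])
        end = min(ours[i][1], theirs[j][1])
        if start < end:
            out.append((start, end))
        # Advance whichever range finishes first.
        if ours[i][1] < theirs[j][1]:
            i += 1
        else:
            j += 1
    return out

def classify(ours: Ranges, theirs: Ranges) -> Overlap:
    common = sum(e - s for s, e in intersect(ours, theirs))
    if common == 0:
        return Overlap.NONE
    total = sum(e - s for s, e in ours)
    return Overlap.FULL if common == total else Overlap.PARTIAL

# A new rank needing elements 4..8 from a saved shard holding 0..6 gets 4..6 from it.
print(intersect([(4, 8)], [(0, 6)]))  # [(4, 6)]
print(classify([(4, 8)], [(0, 6)]))  # Overlap.PARTIAL
```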

class olmo_core.distributed.checkpoint.TensorStorageMetadata(**data)

Bases: BaseModel

shape: Tuple[int, ...]

The shape of the full (unflattened) tensor.

is_sharded: bool

Whether the original tensor (when saved) was sharded.

dtype: str

The data type of the tensor.

shard_spec_per_file: Dict[str, TensorShardSpec]

Maps each filename to the sharding spec of the local shard within that file.

property torch_dtype: dtype
materialize_empty(*, device=None, shape=None)
Return type:

Tensor

get_flattened_offsets_in_file(filename)
Return type:

Generator[Tuple[int, int], None, None]

get_numel_in_file(filename)
Return type:

int

class olmo_core.distributed.checkpoint.StorageMetadata(**data)

Bases: BaseModel

tensors: Dict[str, TensorStorageMetadata]
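Taken together, the three metadata classes describe a checkpoint: StorageMetadata maps each tensor name to its full shape, dtype, sharding flag, and per-file shard specs. The sketch below mirrors those documented fields with standard-library dataclasses, adds a per-file element count in the spirit of get_numel_in_file(), and round-trips everything through JSON, as a file like Checkpointer.METADATA_FILENAME (`metadata.json`) might hold. The exact on-disk JSON layout, the `"float32"` dtype string, and the file names are assumptions for illustration; the real classes derive from `BaseModel` (presumably Pydantic), and their serialization may differ.

```python
import json
import math
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class ShardSpec:  # mirrors TensorShardSpec's documented fields
    flattened_offsets: Optional[List[Tuple[int, int]]] = None
    local_shape: Optional[Tuple[int, ...]] = None
    global_offset: Optional[Tuple[int, ...]] = None

@dataclass
class TensorMeta:  # mirrors TensorStorageMetadata's documented fields
    shape: Tuple[int, ...]
    is_sharded: bool
    dtype: str
    shard_spec_per_file: Dict[str, ShardSpec] = field(default_factory=dict)

    def numel_in_file(self, filename: str) -> int:
        """Number of elements of this tensor stored in `filename`."""
        spec = self.shard_spec_per_file[filename]
        if spec.flattened_offsets is not None:
            return sum(end - start for start, end in spec.flattened_offsets)
        if spec.local_shape is None:
            raise ValueError(f"no shard details recorded for {filename}")
        return math.prod(spec.local_shape)

meta = {
    "weight": TensorMeta(
        shape=(4, 2),
        is_sharded=True,
        dtype="float32",
        shard_spec_per_file={
            "rank0.safetensors": ShardSpec(flattened_offsets=[(0, 4)]),
            "rank1.safetensors": ShardSpec(flattened_offsets=[(4, 8)]),
        },
    )
}
# Top-level "tensors" key mirrors StorageMetadata's single documented field.
text = json.dumps({"tensors": {name: asdict(m) for name, m in meta.items()}}, indent=2)
loaded = json.loads(text)
print(meta["weight"].numel_in_file("rank1.safetensors"))  # 4
```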