distributed.checkpoint
A low-overhead, fast, distributed checkpointing module with a unified API for saving and
loading both local and remote checkpoints. Built on top of safetensors
and inspired by torch.distributed.checkpoint, but better suited for handling distributed models and
optimizer state without unnecessary distributed communication and GPU allocations.
Features
Sharded distributed models, such as OLMo-core’s FSDP or PyTorch’s FullyShardedDataParallel (with use_orig_params=True), are supported out of the box.
Uses safetensors under the hood for fast, efficient, and safe serialization and deserialization.
Save with one distributed topology, seamlessly load with a different one. For example, with FSDP you can save/load checkpoints with different world sizes or wrapping strategies.
Save/load directly to/from a remote object store like S3 or GCS. When loading from a remote object store each rank only downloads the fraction of the data it needs for its local (potentially sharded) tensors.
Checkpoints are always loaded in-place and one tensor at a time to avoid unnecessary allocations. This results in virtually no additional memory overhead.
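The partial-download behavior described above can be pictured as interval intersection over a flattened tensor. The following is a minimal sketch, not the library's implementation: the helper name `ranges_to_fetch`, the file names, and the layout (each file holding one contiguous half-open range of the flattened tensor) are assumptions made for illustration.

```python
from typing import Dict, List, Tuple

Range = Tuple[int, int]  # half-open [start, end) over the flattened tensor


def ranges_to_fetch(
    needed: Range, stored: Dict[str, Range]
) -> List[Tuple[str, int, int]]:
    """Return (filename, offset_in_file, num_elements) for each overlap.

    Hypothetical helper: `stored` maps a file name to the flattened range
    that file holds; `needed` is the range the local rank wants.
    """
    out = []
    for filename, (start, end) in sorted(stored.items()):
        lo, hi = max(start, needed[0]), min(end, needed[1])
        if lo < hi:  # non-empty overlap, so part of this file is required
            out.append((filename, lo - start, hi - lo))
    return out


# A 12-element tensor saved by 4 ranks (3 elements each), then loaded by a
# rank that owns elements 4..7 under a different topology.
stored = {f"rank{r}.safetensors": (3 * r, 3 * r + 3) for r in range(4)}
plan = ranges_to_fetch((4, 8), stored)
```

In this toy layout only two of the four files are touched, and only part of each, which is the sense in which a rank reads just the fraction of the data it needs.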
Overview
Use save_model_and_optim_state() to write a checkpoint with your model and optimizer’s state, then
use load_model_and_optim_state() to load the checkpoint in-place. You can also generate unsharded, full
state dictionaries from a checkpoint with unshard_model_state() and unshard_optim_state().
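A minimal usage sketch of the save/load round trip, using only the signatures listed in the API reference on this page. The wrapper name `save_then_reload` and its `checkpoint_dir` argument are hypothetical, and the sketch assumes `model` and `optim` are already constructed (and, in a distributed job, that every rank calls the wrapper).

```python
from pathlib import Path
from typing import Union


def save_then_reload(model, optim, checkpoint_dir: Union[Path, str]) -> None:
    """Write a checkpoint, then load it back in-place (illustrative helper)."""
    # Imported here so the helper can be defined without olmo_core installed.
    from olmo_core.distributed.checkpoint import (
        load_model_and_optim_state,
        save_model_and_optim_state,
    )

    # Documented signature: (dir, model, optim, save_overwrite=False).
    save_model_and_optim_state(checkpoint_dir, model, optim, save_overwrite=True)

    # Loads into the existing tensors and returns None; validate=True is the
    # documented default and is spelled out here only for clarity.
    load_model_and_optim_state(checkpoint_dir, model, optim, validate=True)
```

According to the parameter descriptions, `checkpoint_dir` may be a local path or a URL to a remote store.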
API Reference
- olmo_core.distributed.checkpoint.save_model_and_optim_state(dir, model, optim, save_overwrite=False)
Save model and optimizer state dictionaries. The model can be sharded, in which case this function handles the optimizer state so that it can be loaded again with a different distributed topology through load_model_and_optim_state().
Returns all of the files created by the current rank.
Tip
With FullyShardedDataParallel models it’s not necessary to set the state dict type before calling this (or load_model_and_optim_state()) via state_dict_type() or other methods. In fact those settings will always be ignored.
Attention
At the moment FullyShardedDataParallel models must have use_orig_params=True.
- olmo_core.distributed.checkpoint.load_model_and_optim_state(dir, model, optim=None, validate=True)
Load model and optimizer state in-place from a checkpoint saved via save_model_and_optim_state(). This method is agnostic to the distributed topology in that it can load checkpoints saved with a different distributed topology (e.g. FSDP vs DDP, or FSDP with a different world size).
Tip
Internally this function handles calling torch.nn.Module.load_state_dict() and torch.optim.Optimizer.load_state_dict() for you, hence the return type is None.
- Parameters:
dir (Union[Path, PathLike, str]) – Path/URL to the checkpoint saved via save_model_and_optim_state().
model (Module) – The model to load the state into.
optim (Optional[Optimizer], default: None) – The optimizer to load the state into.
validate (bool, default: True) – Validate that all tensors have been loaded completely from the checkpoint by pre-filling each tensor with NaNs prior to loading in-place, then checking afterwards that there are no NaNs remaining.
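The NaN-based check behind `validate` can be illustrated with plain Python lists standing in for tensors. This is a sketch of the idea only: the helper name `load_in_place_with_validation` and the `(offset, values)` chunk format are assumptions, not part of the library.

```python
import math
from typing import Iterable, List, Sequence, Tuple


def load_in_place_with_validation(
    dest: List[float], chunks: Iterable[Tuple[int, Sequence[float]]]
) -> None:
    """Fill `dest` from (offset, values) chunks and verify full coverage."""
    # Pre-fill with NaN so any element that is never written stays detectable.
    for i in range(len(dest)):
        dest[i] = math.nan
    # Copy each chunk into place; an out-of-range offset raises IndexError.
    for offset, values in chunks:
        for j, value in enumerate(values):
            dest[offset + j] = value
    # Any NaN still present marks an element that no chunk covered.
    missing = [i for i, x in enumerate(dest) if math.isnan(x)]
    if missing:
        raise ValueError(f"elements never loaded: {missing}")
```

In this sketch, source data that legitimately contains NaN values would also trip the check, which is one reason such a check might be made optional.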
- olmo_core.distributed.checkpoint.unshard_model_state(dir, device=None, rank0_only=False, no_dist=False)
Unshard model state saved via save_model_and_optim_state().
See also
- Parameters:
dir (Union[Path, PathLike, str]) – Local or remote checkpoint directory.
device (Optional[device], default: None) – Device to load the checkpoint onto. Defaults to CPU.
rank0_only (bool, default: False) – Set to true if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary.
no_dist (bool, default: False) – Set to true to avoid any distributed communication whatsoever.
- Return type:
- olmo_core.distributed.checkpoint.unshard_optim_state(dir, device=None, rank0_only=False, no_dist=False)
Unshard optimizer state saved via save_model_and_optim_state().
See also
- Parameters:
dir (Union[Path, PathLike, str]) – Local or remote checkpoint directory.
device (Optional[device], default: None) – Device to load the checkpoint onto. Defaults to CPU.
rank0_only (bool, default: False) – Set to true if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary.
no_dist (bool, default: False) – Set to true to avoid any distributed communication whatsoever.
- Return type:
OptimStateDict
- class olmo_core.distributed.checkpoint.Checkpointer
Bases: object
A distributed checkpointer for saving and loading non-nested state dictionaries, i.e. where keys are strings and values are either regular torch.Tensor instances, torch.nn.Parameter instances, DTensor instances, or any sharded tensors from this library.
For saving and loading model and optimizer states together, use save_model_and_optim_state() and load_model_and_optim_state() instead.
- METADATA_FILENAME = 'metadata.json'
- save(dir, state_dict, save_overwrite=False)
Save a state dict. The state dict can contain regular Tensors, Parameters, or any sharded tensors from this library.
When calling this from a distributed context, all ranks must call this at the same time and the state dict must have the same keys and tensor types across each rank.
Returns the storage metadata and a list of files created by the local rank.
- Parameters:
- Return type:
- load(dir, state_dict, no_dist=False, metadata=None, _safetensors_mfl=None, _check_for_nans=False)
Load a state dict in-place.
- Parameters:
dir (Union[Path, PathLike, str]) – The path or URL to the checkpoint saved via save().
state_dict (Dict[str, Tensor]) – The state dictionary to load into. This should contain all of the tensors you want to load.
no_dist (bool, default: False) – Disable distributed communication even if within a distributed context.
- unshard(dir, device=None, rank0_only=False, no_dist=False, num_threads=None)
Unshard a checkpoint, returning the full state dict. This can be used in both distributed and non-distributed contexts. If you only want to load a single copy to rank 0 in a distributed context, set rank0_only=True, in which case other ranks will receive an empty state dict.
Alternatively, setting no_dist=True will return a full state dict from whatever process calls this.
- Parameters:
dir (Union[Path, PathLike, str]) – Local or remote checkpoint directory.
device (Optional[device], default: None) – Device to load the checkpoint onto. Defaults to CPU.
rank0_only (bool, default: False) – Set to true if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary.
no_dist (bool, default: False) – Set to true to avoid any distributed communication whatsoever.
num_threads (Optional[int], default: None) – The maximum number of threads to use to unshard the checkpoint. Increasing num_threads can lead to a substantial speed up, especially when loading from a remote checkpoint. Set to 0 to disable threading.
- Return type:
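The role of `num_threads` can be sketched with the standard library's thread pool. This is an illustration only, not how `unshard()` is implemented: `load_all` and `load_one` are hypothetical names, and treating `None` as "let the executor choose" is an assumption of this sketch.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable, Optional, TypeVar

T = TypeVar("T")


def load_all(
    keys: Iterable[str],
    load_one: Callable[[str], T],
    num_threads: Optional[int] = None,
) -> Dict[str, T]:
    """Load every key, optionally in parallel (hypothetical helper)."""
    keys = list(keys)
    if num_threads == 0:
        # Threading disabled: load sequentially in the calling thread.
        return {k: load_one(k) for k in keys}
    # max_workers=None lets ThreadPoolExecutor pick its own default size.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # pool.map preserves input order, so keys and results line up.
        return dict(zip(keys, pool.map(load_one, keys)))
```

Threads help most when `load_one` is I/O bound, for example when each call fetches data from a remote store, which matches the note above about remote checkpoints.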
- class olmo_core.distributed.checkpoint.TensorShardSpec(**data)
Bases: BaseModel
- flattened_offsets: Optional[Tuple[Tuple[int, int], ...]]
Offsets within the full flattened tensor that the given shard corresponds to.
- local_shape: Optional[Tuple[int, ...]]
The (unflattened) shape of the local shard.
- global_offset: Optional[Tuple[int, ...]]
The starting offset for each dimension in the global unsharded (unflattened) tensor that the local shard corresponds to.
- get_flattened_offsets(full_shape)
Get flattened offsets into the full flattened tensor that the given shard corresponds to. If self.flattened_offsets is set, this just returns a generator over those, otherwise it computes them from self.local_shape and self.global_offset.
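One way to derive flattened offsets from a local shape and a global offset is sketched below. It assumes row-major (C-contiguous) layout, half-open `(start, end)` ranges, and tensors with at least one dimension; the standalone function `flattened_offsets` is illustrative and may differ from what `get_flattened_offsets()` actually does.

```python
from itertools import product
from typing import Iterator, Tuple


def flattened_offsets(
    full_shape: Tuple[int, ...],
    local_shape: Tuple[int, ...],
    global_offset: Tuple[int, ...],
) -> Iterator[Tuple[int, int]]:
    """Yield half-open (start, end) ranges into the row-major flattened tensor."""
    # Row-major strides: stride[d] is the product of full_shape[d + 1:].
    strides = [1] * len(full_shape)
    for d in range(len(full_shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * full_shape[d + 1]
    # Each run along the last dimension is contiguous in the flattened tensor.
    row_len = local_shape[-1]
    for idx in product(*(range(n) for n in local_shape[:-1])):
        start = global_offset[-1] + sum(
            (global_offset[d] + i) * strides[d] for d, i in enumerate(idx)
        )
        yield (start, start + row_len)
```

For a 4×6 tensor, a 2×3 shard starting at row 1, column 3 covers rows 1 and 2, columns 3 to 5, so `list(flattened_offsets((4, 6), (2, 3), (1, 3)))` gives `[(9, 12), (15, 18)]`. Adjacent ranges are not merged in this sketch.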
- class olmo_core.distributed.checkpoint.TensorStorageMetadata(**data)
Bases: BaseModel
- shape: Tuple[int, ...]
The shape of the full (unflattened) tensor.
- is_sharded: bool
Whether the original tensor (when saved) was sharded.
- dtype: str
The data type of the tensor.
- shard_spec_per_file: Dict[str, TensorShardSpec]
Maps each filename to the sharding spec of the local shard within that file.
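To show how per-file shard specs could be used to rebuild a full tensor, here is a small sketch with plain dataclasses standing in for the pydantic models above. The classes `ShardSpec` and `StorageMetadata`, the function `unshard_flat`, the file names, and the assumption that each file stores its ranges back to back in spec order are all hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ShardSpec:  # stand-in for TensorShardSpec (flattened_offsets only)
    flattened_offsets: Tuple[Tuple[int, int], ...]


@dataclass
class StorageMetadata:  # stand-in for TensorStorageMetadata
    shape: Tuple[int, ...]
    is_sharded: bool
    dtype: str
    shard_spec_per_file: Dict[str, ShardSpec]


def unshard_flat(meta: StorageMetadata, files: Dict[str, List[float]]) -> List[float]:
    """Rebuild the full flattened tensor from per-file shard data."""
    numel = 1
    for n in meta.shape:
        numel *= n
    full = [0.0] * numel
    for filename, spec in meta.shard_spec_per_file.items():
        data, pos = files[filename], 0
        # Assumed layout: a file's data is stored range after range, in order.
        for start, end in spec.flattened_offsets:
            full[start:end] = data[pos : pos + (end - start)]
            pos += end - start
    return full


# A 2x4 tensor split column-wise across two files (half-open ranges).
meta = StorageMetadata(
    shape=(2, 4),
    is_sharded=True,
    dtype="float32",
    shard_spec_per_file={
        "rank0.safetensors": ShardSpec(((0, 2), (4, 6))),
        "rank1.safetensors": ShardSpec(((2, 4), (6, 8))),
    },
)
files = {
    "rank0.safetensors": [0.0, 1.0, 4.0, 5.0],
    "rank1.safetensors": [2.0, 3.0, 6.0, 7.0],
}
full = unshard_flat(meta, files)
```

Because the metadata records where every shard sits in the flattened tensor, the reader does not need to know which topology produced the files.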