distributed.checkpoint
A high-level distributed checkpointing module with a unified API for saving and loading both local and remote checkpoints.
Features
- Save with one distributed topology and seamlessly load with a different one. For example, with FSDP/FSDP2 you can save and load checkpoints with different world sizes or sharding strategies.
- Save to and load from a remote object store like S3 or GCS directly. When loading from a remote object store, each rank only downloads the fraction of the data it needs for its local (potentially sharded) tensors.
Overview
Use save_model_and_optim_state() to write a checkpoint with your model and optimizer’s state,
then use load_model_and_optim_state() to load the checkpoint in-place.
You can unshard a checkpoint saved this way with unshard_checkpoint().
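The workflow above might look roughly like the sketch below. It assumes that distributed training is already initialized and that a PyTorch model (and optionally an optimizer) exists on every rank. The three wrapper function names are hypothetical; only the imported functions come from this module. The imports are deferred into the function bodies so the snippet can be loaded without olmo_core installed.

```python
from typing import Any, Optional


def save_checkpoint(ckpt_dir: str, model: Any, optim: Optional[Any] = None) -> None:
    """Write a sharded checkpoint. Call this on every rank of the distributed job."""
    from olmo_core.distributed.checkpoint import save_model_and_optim_state

    # save_overwrite=True overwrites existing files instead of raising FileExistsError
    # when the checkpoint directory is non-empty.
    save_model_and_optim_state(ckpt_dir, model, optim, save_overwrite=True)


def load_checkpoint(ckpt_dir: str, model: Any, optim: Optional[Any] = None) -> None:
    """Load the checkpoint in-place; world size or sharding strategy may differ from the saving job."""
    from olmo_core.distributed.checkpoint import load_model_and_optim_state

    load_model_and_optim_state(ckpt_dir, model, optim)


def export_unsharded(ckpt_dir: str, target_dir: str):
    """Convert the checkpoint to regular files. Run this in a separate, non-distributed process."""
    from olmo_core.distributed.checkpoint import unshard_checkpoint

    # Returns the model checkpoint path and, if optimizer state was unsharded too, its path.
    return unshard_checkpoint(ckpt_dir, target_dir, save_overwrite=True)
```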
API Reference
- olmo_core.distributed.checkpoint.save_state_dict(dir, state_dict, *, process_group=None, save_overwrite=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False, _skip_prepare=False)
Save an arbitrary state dictionary to a distributed format that can be loaded again with a different distributed topology.
Important
Please use save_model_and_optim_state() to save model/optimizer state dicts instead, unless you know what you’re doing.
- Parameters:
  - process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.
  - save_overwrite (bool, default: False) – Overwrite existing files.
  - thread_count (Optional[int], default: None) – Set this to override the number of threads used while writing data.
  - process_count (Optional[int], default: None) – Set this to use a process pool instead of a thread pool when possible (currently not compatible with throttle_uploads).
  - throttle_uploads (bool, default: False) – If this is set to True and dir is a URL, then only one rank from each node will upload data at a time.
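The simulation below illustrates the throttle_uploads semantics in a single process. Threads stand in for ranks. A lock per node ensures that at most one "rank" per node uploads at any moment, while different nodes proceed in parallel. This is purely illustrative and assumes nothing about how olmo_core coordinates real ranks, which run in separate processes. All names here are hypothetical.

```python
import threading
import time
from collections import defaultdict
from typing import Dict


def simulate_throttled_uploads(ranks_per_node: int, num_nodes: int) -> Dict[int, int]:
    """Return the peak number of concurrent simulated uploads observed on each node."""
    node_locks = {node: threading.Lock() for node in range(num_nodes)}
    active: Dict[int, int] = defaultdict(int)
    peak: Dict[int, int] = defaultdict(int)
    stats_lock = threading.Lock()

    def upload(rank: int) -> None:
        node = rank // ranks_per_node
        # Only one rank per node can hold its node's lock, so uploads on a node are serialized.
        with node_locks[node]:
            with stats_lock:
                active[node] += 1
                peak[node] = max(peak[node], active[node])
            time.sleep(0.01)  # stand-in for the actual upload
            with stats_lock:
                active[node] -= 1

    threads = [threading.Thread(target=upload, args=(rank,)) for rank in range(ranks_per_node * num_nodes)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dict(peak)
```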
- olmo_core.distributed.checkpoint.async_save_state_dict(dir, state_dict, *, process_group=None, save_overwrite=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False, _skip_prepare=False)
An async version of save_state_dict(). This code first stages the state dict on the CPU (copies it to CPU memory), then writes it in a separate thread.
- Return type:
  Future[None]
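Because the async variant returns a Future[None], callers typically keep the handle and call .result() before starting the next save or exiting. This waits for the write to finish and re-raises any error from the writer thread. The hypothetical helper below sketches that pattern. In practice, save_fn would be async_save_state_dict. Here, a stand-in built on concurrent.futures is used so the example runs on its own.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, List, Optional


class AsyncCheckpointer:
    """Hypothetical helper that keeps at most one asynchronous save in flight."""

    def __init__(self, save_fn: Callable[..., "Future[None]"]):
        self._save_fn = save_fn
        self._pending: Optional[Future] = None

    def save(self, *args: Any, **kwargs: Any) -> None:
        self.wait()  # don't start a new save while the previous one is still being written
        self._pending = self._save_fn(*args, **kwargs)

    def wait(self) -> None:
        if self._pending is not None:
            self._pending.result()  # blocks until done; re-raises any exception from the writer
            self._pending = None


# Stand-in for async_save_state_dict: copy the state dict, then "write" it on a background thread.
executor = ThreadPoolExecutor(max_workers=1)
written: List[Any] = []


def fake_async_save(dir: str, state_dict: dict) -> "Future[None]":
    snapshot = dict(state_dict)
    return executor.submit(lambda: written.append((dir, snapshot)))


checkpointer = AsyncCheckpointer(fake_async_save)
checkpointer.save("ckpt/step1", {"step": 1})
checkpointer.save("ckpt/step2", {"step": 2})
checkpointer.wait()  # make sure the last save has finished before exiting
executor.shutdown()
```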
- olmo_core.distributed.checkpoint.load_state_dict(dir, state_dict, *, process_group=None, pre_download=False, work_dir=None, thread_count=None)
Load an arbitrary state dict in-place from a checkpoint saved with save_state_dict().
- Parameters:
  - dir (Union[Path, PathLike, str]) – Path/URL to the checkpoint saved via save_state_dict().
  - state_dict (Dict[str, Any]) – The state dict to load the state into.
  - process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.
  - thread_count (Optional[int], default: None) – Set the number of threads used for certain operations.
- olmo_core.distributed.checkpoint.save_model_and_optim_state(dir, model, optim=None, *, process_group=None, save_overwrite=False, flatten_optimizer_state=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False)
Save model and optimizer state dictionaries. The model can be sharded, in which case this method handles the optimizer state correctly so that it can be loaded again with a different distributed topology through load_model_and_optim_state().
Tip
With FullyShardedDataParallel models it’s not necessary to set the state dict type via state_dict_type() or other methods before calling this function (or load_model_and_optim_state()). This function handles that internally.
- Parameters:
  - model (Module) – The model to save state from.
  - optim (Optional[Optimizer], default: None) – The optimizer to save state from.
  - process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.
  - save_overwrite (bool, default: False) – Overwrite existing files.
  - flatten_optimizer_state (bool, default: False) – Flatten the optimizer state before saving. This should match the setting used when loading the state dict. It is needed in a distributed setting when the params in some param groups may differ between ranks, such as with pipeline parallelism.
  - thread_count (Optional[int], default: None) – Set this to override the number of threads used while writing data.
  - process_count (Optional[int], default: None) – Set this to use a process pool instead of a thread pool when possible (currently not compatible with throttle_uploads).
  - throttle_uploads (bool, default: False) – If this is set to True and dir is a URL, then only one rank from each node will upload data at a time.
- Raises:
  - FileExistsError – If the checkpoint dir exists and is non-empty, unless save_overwrite=True.
- Return type:
- olmo_core.distributed.checkpoint.async_save_model_and_optim_state(dir, model, optim=None, *, process_group=None, save_overwrite=False, flatten_optimizer_state=False, thread_count=None, process_count=None, throttle_uploads=False, enable_plan_caching=False)
An async version of save_model_and_optim_state(). This code first stages the state dict on the CPU (copies it to CPU memory), then writes it in a separate thread.
- Return type:
  Future[None]
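Both save functions accept flatten_optimizer_state. The following explains why it matters. A standard PyTorch optimizer state dict identifies parameters by integer position (0, 1, ...) on the local rank. With pipeline parallelism, each rank holds different parameters, so the same integer key can refer to different tensors on different ranks. The hypothetical function below keys every entry by the parameter's name instead. The flat key format is made up for illustration and is not necessarily the one olmo_core uses.

```python
from typing import Any, Dict, List


def flatten_optim_state(optim_state: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Flatten a torch-style optimizer state dict into name-keyed entries.

    `param_names[i]` is assumed to be the fully qualified name of the parameter
    with integer id `i` on this rank.
    """
    flat: Dict[str, Any] = {}
    # Per-parameter state such as "step" or "exp_avg".
    for param_id, state in optim_state["state"].items():
        for field, value in state.items():
            flat[f"state.{param_names[param_id]}.{field}"] = value
    # Group hyperparameters (lr, weight_decay, ...) are recorded for each parameter in the group,
    # so the result no longer depends on how groups are numbered on each rank.
    for group in optim_state["param_groups"]:
        for param_id in group["params"]:
            for field, value in group.items():
                if field != "params":
                    flat[f"param_groups.{param_names[param_id]}.{field}"] = value
    return flat


# Two pipeline stages: both ranks use local ids 0 and 1, but for different parameters.
stage0 = {"state": {0: {"step": 5}, 1: {"step": 5}}, "param_groups": [{"lr": 1e-3, "params": [0, 1]}]}
stage1 = {"state": {0: {"step": 5}, 1: {"step": 5}}, "param_groups": [{"lr": 1e-3, "params": [0, 1]}]}
flat0 = flatten_optim_state(stage0, ["blocks.0.weight", "blocks.0.bias"])
flat1 = flatten_optim_state(stage1, ["blocks.1.weight", "blocks.1.bias"])
```

The raw state dicts of the two stages are identical, so their keys would collide. The flattened keys do not collide.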
- olmo_core.distributed.checkpoint.load_model_and_optim_state(dir, model, optim=None, *, process_group=None, key_mapping=None, pre_download=False, work_dir=None, strict=True, flatten_optimizer_state=False, thread_count=None)
Load model and optimizer state in-place from a checkpoint saved via save_model_and_optim_state(). This method is agnostic to the distributed topology in that it can load checkpoints saved with a different distributed topology (e.g. FSDP/FSDP2, DDP).
Tip
With FullyShardedDataParallel models it’s not necessary to set the state dict type via state_dict_type() or other methods before calling this function (or save_model_and_optim_state()). This function handles that internally.
Warning
Due to the way torch.distributed.checkpoint works, if you have keys in the checkpoint dict that are not present in the current state of the model or optimizer, those keys won’t be loaded.
For example, if you added a custom field to one of your optimizer’s param groups before saving the checkpoint, but don’t have that field in the param group of the optimizer you’re loading into, it won’t be added.
This can cause unexpected behavior if you’re not careful. In this case the best thing to do is to ensure all keys are present in the param groups when you initialize the optimizer, before saving or loading a checkpoint.
- Parameters:
  - dir (Union[Path, PathLike, str]) – Path/URL to the checkpoint saved via save_model_and_optim_state().
  - model (Module) – The model to load the state into.
  - optim (Optional[Optimizer], default: None) – The optimizer to load the state into.
  - process_group (Optional[ProcessGroup], default: None) – The process group to use for distributed collectives.
  - key_mapping (Optional[Dict[str, str]], default: None) – Can be used to load a checkpoint where certain parameters have different names. This dictionary should map current keys to keys in the checkpoint to be loaded.
  - pre_download (bool, default: False) – Download and cache relevant remote checkpoint files before trying to read from them.
  - work_dir (Union[Path, PathLike, str, None], default: None) – A working directory for caching files/directories.
  - strict (bool, default: True) – Load keys strictly.
  - flatten_optimizer_state (bool, default: False) – Flatten the optimizer state when loading. This should match the setting used when saving the state dict. It is needed in a distributed setting when the params in some param groups may differ between ranks, such as with pipeline parallelism.
  - thread_count (Optional[int], default: None) – Set the number of threads used for certain operations.
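The direction of key_mapping is easy to get backwards. Keys are names in the current model, and values are names in the checkpoint. The hypothetical helper below models only that lookup, on plain lists of keys. It also applies one plausible reading of strict, which is an assumption: fail if a current key has no source in the checkpoint. Checkpoint keys that the model doesn't have are ignored, consistent with the warning above. This is not olmo_core's implementation.

```python
from typing import Dict, Iterable, List, Optional


def plan_key_loading(
    current_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    key_mapping: Optional[Dict[str, str]] = None,
    strict: bool = True,
) -> Dict[str, str]:
    """Return {current key: checkpoint key it would be read from}."""
    key_mapping = key_mapping or {}
    available = set(checkpoint_keys)
    plan: Dict[str, str] = {}
    missing: List[str] = []
    for key in current_keys:
        source = key_mapping.get(key, key)  # mapping goes current name -> checkpoint name
        if source in available:
            plan[key] = source
        else:
            missing.append(key)
    if strict and missing:
        raise KeyError(f"not found in checkpoint: {missing}")
    # Checkpoint keys that no current key maps to are simply never read.
    return plan


# The model renamed "layers.*" to "blocks.*" after the checkpoint was written.
plan = plan_key_loading(
    current_keys=["blocks.0.weight", "lm_head.weight"],
    checkpoint_keys=["layers.0.weight", "lm_head.weight", "old_buffer"],
    key_mapping={"blocks.0.weight": "layers.0.weight"},
)
```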
- olmo_core.distributed.checkpoint.unshard_checkpoint(dir, target_dir, *, optim=None, save_overwrite=False, use_safetensors=False, unshard_strategy=None, pre_download=False, work_dir=None, quiet=False)
Convert a checkpoint saved via save_model_and_optim_state() into unsharded model and optimizer checkpoint files. These can be loaded directly with torch.load(), or with safetensors if use_safetensors=True.
Warning
The safetensors format cannot be used to save optimizer state, since optimizer state can contain arbitrary Python objects that need to be pickled. Therefore optim=True and use_safetensors=True are incompatible.
Warning
This should only be called in a non-distributed context. Otherwise a RuntimeError is raised.
See also
load_keys() if you only need to load and unshard certain keys in the checkpoint.
- Parameters:
  - dir (Union[Path, PathLike, str]) – The path/URL to the original checkpoint created via save_model_and_optim_state().
  - target_dir (Union[Path, PathLike, str]) – The directory to save the unsharded model/optimizer checkpoint files to. This must be a local directory; URLs are not supported.
  - optim (Optional[bool], default: None) – Whether to unshard the optimizer state. This defaults to True as long as use_safetensors=False.
  - save_overwrite (bool, default: False) – Overwrite any existing files in target_dir.
  - use_safetensors (bool, default: False) – Save the unsharded files with safetensors.torch.save_file() instead of torch.save().
  - unshard_strategy (Optional[UnshardStrategy], default: None) – The strategy to use. Defaults to UnshardStrategy.one_file().
  - pre_download (bool, default: False) – Download and cache relevant remote checkpoint files before trying to read from them.
  - work_dir (Union[Path, PathLike, str, None], default: None) – A working directory for caching files/directories.
  - quiet (bool, default: False) – Do not show progress messages.
- Return type:
- Returns:
  The path to the unsharded model checkpoint and, if optim=True, the path to the unsharded optimizer checkpoint. These paths may represent files or directories depending on the unshard_strategy.
- Raises:
  - FileExistsError – If the target_dir is non-empty and save_overwrite=False.
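For example, exporting only the model weights as safetensors might look like the following. Because safetensors cannot hold optimizer state, optim=False is passed explicitly. The wrapper name is hypothetical, and the import is deferred so the snippet loads without olmo_core installed. Unpacking the result into two paths assumes the return value is a (model path, optimizer path) pair, as described under Returns.

```python
def export_model_as_safetensors(ckpt_dir: str, target_dir: str):
    """Unshard only the model weights to safetensors.

    Must run outside of torch.distributed (e.g. from a standalone script);
    otherwise unshard_checkpoint raises a RuntimeError.
    """
    from olmo_core.distributed.checkpoint import unshard_checkpoint

    model_path, _optim_path = unshard_checkpoint(
        ckpt_dir,  # may be a local path or a URL
        target_dir,  # must be a local directory, not a URL
        optim=False,  # safetensors can't store optimizer state
        use_safetensors=True,
        save_overwrite=True,
    )
    return model_path
```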
- olmo_core.distributed.checkpoint.load_keys(dir, keys, *, pre_download=False, work_dir=None)
Load specific keys from a checkpoint.
Warning
This should only be called in a non-distributed context. Otherwise a RuntimeError is raised.
- Parameters:
  - dir (Union[Path, PathLike, str]) – The path/URL to the original checkpoint created via save_model_and_optim_state(), save_state_dict(), or one of the other functions in this module.
  - pre_download (bool, default: False) – Download and cache relevant remote checkpoint files before trying to read from them.
  - work_dir (Union[Path, PathLike, str, None], default: None) – A working directory for caching files/directories.
- Return type:
- Returns:
  The (unsharded) objects from the checkpoint corresponding to the given keys, in the same order as the keys.
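Since load_keys() returns values in the same order as the requested keys, they can be zipped back together with their keys. The hypothetical helper below does that. The loader parameter exists only so the sketch can be exercised without a real checkpoint. Calling list() on the result means it works whether load_keys() returns a list or a generator. The key names themselves depend on how the checkpoint was saved.

```python
from typing import Any, Callable, Dict, Iterable, List, Optional


def load_keys_as_dict(
    ckpt_dir: str,
    keys: List[str],
    loader: Optional[Callable[[str, List[str]], Iterable[Any]]] = None,
) -> Dict[str, Any]:
    """Load `keys` from a checkpoint (non-distributed context only) and return {key: value}."""
    if loader is None:
        from olmo_core.distributed.checkpoint import load_keys as loader
    values = list(loader(ckpt_dir, keys))  # same order as `keys`
    if len(values) != len(keys):
        raise ValueError(f"expected {len(keys)} values, got {len(values)}")
    return dict(zip(keys, values))
```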
- olmo_core.distributed.checkpoint.get_checkpoint_metadata(dir)
Load the metadata from a checkpoint.
- class olmo_core.distributed.checkpoint.UnshardStrategy(name='one_file', chunk_size_bytes=None)
Bases: object
Unsharding strategy config for unshard_checkpoint().
  - name: UnshardStrategyType = 'one_file'
    The strategy type.
  - chunk_size_bytes: Optional[int] = None
    The approximate maximum size of each chunk (file), in bytes, for the UnshardStrategyType.chunks strategy.
  - classmethod one_file()
    Use the UnshardStrategyType.one_file strategy.
    - Return type: UnshardStrategy
  - classmethod one_file_per_tensor()
    Use the UnshardStrategyType.one_file_per_tensor strategy.
    - Return type: UnshardStrategy
  - classmethod chunks(chunk_size_in_bytes)
    Use the UnshardStrategyType.chunks strategy.
    - Return type: UnshardStrategy
- class olmo_core.distributed.checkpoint.UnshardStrategyType(value)
Bases: StrEnum
An enumeration of the unsharding strategies that can be used with unshard_checkpoint().
  - one_file = 'one_file'
    Save the unsharded model state into one file, and optionally the optimizer state into another file. The bigger the model, the more memory this requires. For very big models, one_file_per_tensor will scale better.
  - one_file_per_tensor = 'one_file_per_tensor'
    Save each unsharded tensor to its own file. Currently this is not compatible with optimizer state.
  - chunks = 'chunks'
    Like one_file_per_tensor, but multiple tensors and objects may be grouped into the same file, up to the limit defined by UnshardStrategy.chunk_size_bytes.
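The sketch below gives a rough picture of how the chunks strategy could group objects. It walks them in order and starts a new file whenever adding the next one would push the current file past chunk_size_bytes. An object larger than the limit still gets a file of its own. This greedy packing is an assumption made for illustration, and the function is hypothetical; olmo_core's actual grouping may differ.

```python
from typing import Dict, List


def group_into_chunks(sizes: Dict[str, int], chunk_size_bytes: int) -> List[List[str]]:
    """Greedily pack items (in insertion order) into files of roughly chunk_size_bytes each."""
    chunks: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for key, size in sizes.items():
        # Close the current file if this item would push it over the limit.
        if current and current_size + size > chunk_size_bytes:
            chunks.append(current)
            current, current_size = [], 0
        current.append(key)
        current_size += size
    if current:
        chunks.append(current)
    return chunks
```

With a 100-byte limit and item sizes a=40, b=40, c=40, d=100, e=10, this produces four files: [a, b], [c], [d], [e].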