distributed.fsdp#
This is a light-weight, experimental rewrite of PyTorch’s FullyShardedDataParallel
with a few improvements, including:
Well-defined “hands off” handling of buffers. FSDP never shards buffers; they are left as-is.
Well-defined handling of frozen params. You can mix and match within an FSDP instance as long as you’re consistent across the process group with which parameters are frozen.
Full support for CPU-only training and inference via the GLOO backend.
Low-overhead checkpointing with
olmo_core.distributed.checkpoint.
Usage Tips#
Always initialize your optimizer after wrapping your model with FSDP.
When you initialize your model (prior to wrapping with FSDP), use device=torch.device("meta") when initializing parameters to save memory. FSDP will automatically materialize and move parameters to the right device when wrapping. Then you can use FSDP.apply() to initialize parameters how you want.
As with PyTorch’s FullyShardedDataParallel, you should use FSDP.clip_grad_norm_() for clipping gradient norms instead of torch.nn.utils.clip_grad_norm_().
Use activation checkpointing via torch.utils.checkpoint.checkpoint() to save more memory during the forward and backward pass at the expense of more computation.
To save and load checkpoints for your FSDP model and its optimizer, use save_model_and_optim_state() and load_model_and_optim_state(), respectively.
Implementation Details#
When you wrap a Module with FSDP, the wrapping FSDP instance will replace
each original parameter in the module with a ShardedFlatParameter instance,
and each rank will only keep a shard of the original data. Buffers are left as-is.
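As a rough illustration of what "each rank will only keep a shard" means, the sketch below flattens a parameter, pads it so it divides evenly, and slices it per rank. It is plain Python, not olmo_core code: the helper names `shard_flat` and `all_gather` are hypothetical, and zero-padding the tail is an assumption made for the example.

```python
from typing import List


def shard_flat(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened parameter into one equally sized shard per rank.

    Assumption: the tail is zero-padded so every rank holds the same number
    of elements. The real library may handle uneven sizes differently.
    """
    shard_size = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_size * world_size - len(flat))
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Rebuild the full flattened parameter from every rank's shard."""
    full = [x for shard in shards for x in shard]
    return full[:numel]  # drop the padding


# A 2x3 weight matrix, flattened row by row, sharded across 4 ranks.
weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
flat_weight = [x for row in weight for x in row]
shards = shard_flat(flat_weight, world_size=4)
restored = all_gather(shards, numel=len(flat_weight))
```

Each rank stores only its own entry of `shards`; the full data exists only transiently, after a gather.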
Note
Further, the sharded data for all of the ShardedFlatParameter
instances will be collected into a single FlatParamHandle, and each flat parameter will
just hold a view into a slice of the data managed by the handle. This makes gathering the full
params more efficient as it only requires a single all-gather per FSDP node.
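The idea of parameters being views into one buffer can be sketched with the standard library alone. The names `handle_data`, `weight_shard`, and `bias_shard` are hypothetical, and `array`/`memoryview` merely stand in for tensor storage; this is not how olmo_core implements FlatParamHandle.

```python
from array import array

# One contiguous buffer owned by the (hypothetical) handle on this rank.
# It stores the local shards of two parameters back to back.
handle_data = array("d", [0.1, 0.2, 0.3, 0.4, 0.5])
handle_view = memoryview(handle_data)

# Each flat parameter is just a slice (a view, not a copy) of that buffer.
weight_shard = handle_view[0:3]
bias_shard = handle_view[3:5]

# Writing through the handle's buffer is visible through the parameter views,
# so a single collective over `handle_data` would refresh every parameter at once.
handle_data[3] = 9.0
```

Because the slices share memory with the buffer, no per-parameter copy is needed after the buffer is updated.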
Forward Pass#
When the forward() method is called on the wrapping FSDP instance, it will gather
the full unsharded data for each parameter in the desired dtype
(as defined by the FSDPPrecision settings) while caching the sharded data behind the scenes.
Then it runs the forward method of the wrapped module, which is completely unsharded at that point.
After the forward method of the wrapped module returns, the wrapping FSDP instance will reshard the parameters and, if gradients are enabled, register backward hooks to manage the state of parameters and gradients during the backward pass.
During the first forward pass the root FSDP instance will also record the order of execution of all
FSDP children, and use that order to prefetch the full parameters for its FSDP children during
subsequent forward passes. The number of children that are prefetched at once is controlled by the
max_prefetch_count setting.
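The record-then-prefetch behavior can be modeled with a small toy class. `PrefetchPlanner` and its methods are hypothetical names for this sketch; it only tracks which children would be prefetched and performs no communication.

```python
from typing import List


class PrefetchPlanner:
    """Toy model of execution-order recording and prefetching (hypothetical)."""

    def __init__(self, max_prefetch_count: int = 1) -> None:
        self.max_prefetch_count = max_prefetch_count
        self.order: List[str] = []
        self.recorded = False

    def on_child_forward(self, name: str) -> List[str]:
        """Return the children to prefetch when `name` starts its forward."""
        if not self.recorded:
            # First pass: only record the order, nothing to prefetch yet.
            self.order.append(name)
            return []
        idx = self.order.index(name)
        return self.order[idx + 1: idx + 1 + self.max_prefetch_count]

    def end_forward(self) -> None:
        """Mark the first forward pass as finished."""
        self.recorded = True


planner = PrefetchPlanner(max_prefetch_count=2)
children = ["embed", "block0", "block1", "head"]

first_pass = [planner.on_child_forward(c) for c in children]
planner.end_forward()
second_pass = {c: planner.on_child_forward(c) for c in children}
```

With `max_prefetch_count=2`, running `embed` on the second pass would trigger prefetching for `block0` and `block1`.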
Note
When CUDA is available, FSDP instances utilize multiple CUDA streams in order to overlap
communication (e.g. unsharding params or reducing gradients) with computation
(e.g. the forward pass or computing gradients during the backward pass).
Backward Pass#
At the end of the forward method, the wrapping FSDP instance registers ephemeral “pre-backward” and “post-backward” hooks to unshard the parameters and reduce-scatter the gradients, respectively, during the backward pass.
At the end of the backward pass the grad attribute of each (non-frozen) parameter will
be the shard of the full gradient corresponding to the shard of the full parameter, i.e. it will
have the same shape/size as the sharded parameter.
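To make the shape claim concrete, here is a minimal simulation of a reduce-scatter over per-rank gradients. It assumes the reduction is a sum and that the gradient length divides evenly by the number of ranks; the text above states neither, and `reduce_scatter` is a hypothetical helper rather than a library call.

```python
from typing import List


def reduce_scatter(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Sum full gradients element-wise across ranks, then give rank r the r-th slice.

    Assumptions: the reduction is a sum and the gradient length divides evenly
    by the number of ranks.
    """
    world_size = len(per_rank_grads)
    numel = len(per_rank_grads[0])
    assert numel % world_size == 0, "sketch assumes an evenly divisible size"
    summed = [sum(g[i] for g in per_rank_grads) for i in range(numel)]
    shard_size = numel // world_size
    return [summed[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


# Two ranks, each with a full local gradient for a 4-element parameter.
local_grads = [
    [1.0, 2.0, 3.0, 4.0],  # computed on rank 0's batch
    [10.0, 20.0, 30.0, 40.0],  # computed on rank 1's batch
]
grad_shards = reduce_scatter(local_grads)
```

Each entry of `grad_shards` has the same length as that rank's parameter shard, which is what lets an optimizer update the shard locally.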
Just as the root FSDP instance records the execution order of its FSDP children during the first forward pass, it also records the order during the first backward pass and uses that to prefetch the full parameters of its children during subsequent backward passes.
API Reference#
- class olmo_core.distributed.fsdp.FSDP(module, process_group=None, device_mesh=None, precision=None, sharding_strategy=FSDPShardingStrategy.FULL_SHARD, max_prefetch_count=1, free_root_after_forward=False, _debug_config=None)[source]#
FSDP, a.k.a. Fully Sharded Data Parallel, a ZeRO-3 model wrapper.
- Parameters:
process_group (Optional[ProcessGroup], default: None) – The distributed process group to shard across.
device_mesh (Optional[DeviceMesh], default: None) – Mutually exclusive with process_group. This is required for FSDPShardingStrategy.HYBRID_SHARD, in which case the first dimension should specify the number of model replicas (hybrid groups), and the second dimension should specify the number of shards within each replica. If you’re not using FSDPShardingStrategy.HYBRID_SHARD and you specify device_mesh, the process group in the first dimension will be used.
precision (Optional[FSDPPrecision], default: None) – Mixed precision settings.
sharding_strategy (FSDPShardingStrategy, default: 'FULL_SHARD') – The sharding strategy to use.
max_prefetch_count (int, default: 1) – The number of nested FSDP modules that can be prefetched during the forward and backward passes. This is like PyTorch’s limit_all_gathers except it allows more control.
free_root_after_forward (bool, default: False) – By default the root FSDP instance keeps its full params in memory after the forward pass when grads are enabled to avoid immediately regathering during the backward pass. Setting this to True can save some memory at the expense of throughput.
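The two-dimensional layout described for device_mesh under HYBRID_SHARD can be pictured with plain lists. The sketch assumes ranks are assigned row-major, which the text above does not state, and `build_mesh` and `groups_for` are hypothetical helpers, not part of olmo_core or PyTorch.

```python
from typing import List, Tuple


def build_mesh(num_replicas: int, shards_per_replica: int) -> List[List[int]]:
    """Lay out global ranks as a (num_replicas, shards_per_replica) grid.

    Assumption: ranks are assigned row-major, so each row is one model replica.
    """
    return [
        [r * shards_per_replica + s for s in range(shards_per_replica)]
        for r in range(num_replicas)
    ]


def groups_for(rank: int, mesh: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Return (shard_group, replica_group) for a given global rank."""
    for row in mesh:
        if rank in row:
            col = row.index(rank)
            shard_group = row  # ranks that together hold one full copy
            replica_group = [other[col] for other in mesh]  # ranks holding the same shard
            return shard_group, replica_group
    raise ValueError(f"rank {rank} is not in the mesh")


mesh = build_mesh(num_replicas=2, shards_per_replica=4)
shard_group, replica_group = groups_for(5, mesh)
```

Here rank 5 shards parameters with ranks 4 to 7 and holds the same shard as rank 1 in the other replica.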
- WRAPPED_MODULE_PREFIX = '_fsdp_wrapped_module'#
The prefix the wrapped module is stored under. In general you don’t need to know this as the wrapping FSDP instance behaves like the wrapped module itself for most APIs, and otherwise you should access the wrapped module through the
module property.
- classmethod auto_wrap(module, children_to_wrap, **fsdp_kwargs)[source]#
Wrap a module and specific children of the module, as specified by children_to_wrap.
- Parameters:
children_to_wrap (Union[Sequence[Union[str, Module, Type[Module]]], Callable[[Module], bool]]) – Specify which child modules to wrap. This can be a list of child FQNs (wildcards allowed), module instances, module types, or a function that takes a module and returns a boolean that indicates whether it should be wrapped.
fsdp_kwargs – Keyword args to the FSDP constructor.
- Return type:
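To give a feel for selecting children by FQN with wildcards, the sketch below matches module names against patterns using the standard library's `fnmatchcase`. This is an assumption for illustration only: the documentation does not say which matching rules auto_wrap uses (for example, whether `*` crosses dots, as it does here), and `select_children` is a hypothetical helper.

```python
from fnmatch import fnmatchcase
from typing import List, Sequence


def select_children(fqns: Sequence[str], patterns: Sequence[str]) -> List[str]:
    """Return the FQNs that match at least one pattern, keeping their order."""
    return [name for name in fqns if any(fnmatchcase(name, p) for p in patterns)]


# Hypothetical FQNs of a small transformer's submodules.
module_fqns = [
    "embeddings",
    "blocks.0",
    "blocks.0.attention",
    "blocks.1",
    "blocks.1.attention",
    "lm_head",
]

to_wrap = select_children(module_fqns, ["blocks.*", "lm_head"])
```

Under these matching rules `blocks.*` also selects nested names such as `blocks.0.attention`, so a narrower pattern or a callable may be preferable when only the top-level blocks should be wrapped.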
- property module: M#
Get the wrapped module.
- forward(*args, **kwargs)[source]#
Run the forward pass on the wrapped module, gathering full parameters when necessary.
- state_dict(*args, **kwargs)[source]#
Return the state dict.
See also
For saving and loading FSDP checkpoints, see olmo_core.distributed.checkpoint.
Tip
The data in the state dict will be sharded flat data unless you’re within the summon_full_params() context or have gathered the full parameters another way.
Tip
The parameter names will be the original parameter names of the wrapped module, i.e. without the WRAPPED_MODULE_PREFIX.
- load_state_dict(state_dict, *args, **kwargs)[source]#
Load a state dict. The data in the state dict should correspond to the current state of the FSDP wrapper, either sharded or unsharded.
See also
For saving and loading FSDP checkpoints, see olmo_core.distributed.checkpoint.
- named_buffers(*args, **kwargs)[source]#
Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Tip
The buffer names will be the original buffer names of the wrapped module, i.e. without the WRAPPED_MODULE_PREFIX.
- named_parameters(*args, **kwargs)[source]#
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
Tip
The parameter names will be the original parameter names of the wrapped module, i.e. without the
WRAPPED_MODULE_PREFIX.
- summon_full_params(recurse=True, writeback=True, rank0_only=False, cast=False)[source]#
Gather full unsharded params in-place with this context manager.
- Parameters:
recurse (bool, default: True) – Gather unsharded params for all child FSDP instances as well.
writeback (bool, default: True) – Write the unsharded data back from rank 0 to all other ranks while exiting the context manager.
rank0_only (bool, default: False) – Only summon full params on rank 0.
cast (bool, default: False) – If using a mixed-precision strategy, params are cast to the same dtype as they are during the forward and backward passes. If this is True, writeback must be False.
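The gather-on-enter, writeback-on-exit behavior can be imitated with a small context manager over per-rank lists. `summon_full` is a hypothetical stand-in, not the library method: it assumes equally sized shards and treats the yielded list as rank 0's copy of the full parameter.

```python
from contextlib import contextmanager
from typing import Iterator, List


@contextmanager
def summon_full(shards: List[List[float]], writeback: bool = True) -> Iterator[List[float]]:
    """Toy stand-in for a gather/writeback context manager over per-rank shards.

    `shards[r]` is rank r's local shard. The yielded list plays the role of the
    full parameter as seen on rank 0.
    """
    shard_size = len(shards[0])
    full = [x for shard in shards for x in shard]  # "all-gather"
    try:
        yield full
    finally:
        if writeback:
            # Re-slice rank 0's (possibly modified) full data into every rank's shard.
            for r, shard in enumerate(shards):
                shard[:] = full[r * shard_size:(r + 1) * shard_size]


rank_shards = [[1.0, 2.0], [3.0, 4.0]]
with summon_full(rank_shards) as full_param:
    full_param[3] = 40.0  # edit the unsharded data in place

discarded = [[1.0, 2.0], [3.0, 4.0]]
with summon_full(discarded, writeback=False) as full_param:
    full_param[0] = -1.0  # this edit is dropped on exit
```

With writeback enabled the edit lands in the last rank's shard; with it disabled the shards are left untouched.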
- apply(fn)[source]#
Apply fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model.
Compared to torch.nn.Module.apply(), this version additionally gathers the full parameters for all sharded parameters that are directly managed by the given FSDP instance before applying fn. This should not be called from within another summon_full_params() context.
- class olmo_core.distributed.fsdp.FSDPPrecision(param_dtype=None, reduce_dtype=None)[source]#
Bases: object
Mixed precision settings for FSDP.
- class olmo_core.distributed.fsdp.FSDPShardingStrategy(value)[source]#
Bases: StrEnum
Defines the sharding strategy used by FSDP.
- FULL_SHARD = 'FULL_SHARD'#
Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward (except potentially for the root FSDP instance), unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank.
- HYBRID_SHARD = 'HYBRID_SHARD'#
Apply FULL_SHARD within a process group, and replicate parameters across process groups. This results in reduced communication volume as expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium to large-sized models.
- SHARD_GRAD_OP = 'SHARD_GRAD_OP'#
Like FULL_SHARD except parameters are not resharded after the forward pass when gradients are enabled, but only after the backward pass.
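The difference between FULL_SHARD and SHARD_GRAD_OP comes down to when parameters are resharded, which can be written out as an event list for one training step. This is a descriptive sketch for a non-root FSDP instance, not library code: `step_events` is a hypothetical helper, the relative order of the final reshard and reduce-scatter is arbitrary here, and HYBRID_SHARD is left out because its cross-replica behavior is not detailed above.

```python
from typing import List


def step_events(strategy: str, grads_enabled: bool = True) -> List[str]:
    """List parameter/gradient events for one non-root FSDP instance in one step."""
    if strategy not in ("FULL_SHARD", "SHARD_GRAD_OP"):
        raise ValueError(f"unsupported strategy in this sketch: {strategy}")
    events = ["all-gather params", "forward"]
    if not grads_enabled:
        # Without gradients there is no backward pass to keep params around for.
        return events + ["reshard params"]
    if strategy == "FULL_SHARD":
        # Free the full params after forward, then gather them again for backward.
        events += ["reshard params", "all-gather params"]
    events += ["backward", "reshard params", "reduce-scatter grads"]
    return events


full_shard = step_events("FULL_SHARD")
shard_grad_op = step_events("SHARD_GRAD_OP")
```

In this model SHARD_GRAD_OP performs one all-gather per step instead of two, at the cost of keeping the full parameters in memory between the forward and backward passes.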