distributed.parallel
- olmo_core.distributed.parallel.build_world_mesh(*, dp=None, tp=None, cp=None, pp=None, ep=None, device_type=None)
Build a DeviceMesh suitable for the given parallel strategies.
See also
Pass the mesh created by this function to any of the get_*_mesh() functions in this module to get the right sub-mesh for a given parallel strategy:
  - get_dp_model_mesh() gives you the 1D or 2D sub-mesh suitable for data parallel model wrappers like FSDP(2) or DDP.
  - get_dp_mesh() gives you the 1D sub-mesh suitable for configuring data loaders.
  - get_tp_mesh() gives you the 1D sub-mesh for tensor parallelism.
  - get_cp_mesh() gives you the 1D sub-mesh for context parallelism.
  - get_pp_mesh() gives you the 1D sub-mesh for pipeline parallelism.
  - get_ep_mesh() gives you the 1D sub-mesh for expert parallelism.
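The idea of a per-strategy sub-mesh can be illustrated without torch. The sketch below uses a hypothetical helper, submesh_groups (not part of olmo_core), that lists the groups of ranks forming 1D sub-meshes along one named dimension. It assumes ranks are laid out in row-major order over the mesh shape, as torch's init_device_mesh does. Each rank belongs to exactly one group per dimension, and that group is what the rank would see as its own 1D sub-mesh.

```python
from itertools import product
from math import prod
from typing import List, Sequence


def submesh_groups(shape: Sequence[int], dim_names: Sequence[str], dim: str) -> List[List[int]]:
    """Hypothetical: rank groups forming the 1D sub-meshes along ``dim``.

    Assumes ranks are laid out row-major over ``shape``, i.e. the rank at
    coordinate (c0, c1, ...) is sum(c_i * stride_i).
    """
    axis = list(dim_names).index(dim)
    strides = [prod(shape[i + 1:]) for i in range(len(shape))]
    other_axes = [i for i in range(len(shape)) if i != axis]
    groups = []
    # Fix every coordinate except ``dim``, then walk along ``dim``.
    for fixed in product(*(range(shape[i]) for i in other_axes)):
        coord = [0] * len(shape)
        for i, c in zip(other_axes, fixed):
            coord[i] = c
        group = []
        for k in range(shape[axis]):
            coord[axis] = k
            group.append(sum(c * s for c, s in zip(coord, strides)))
        groups.append(group)
    return groups


# 8 ranks arranged as 4-way data parallel x 2-way tensor parallel.
print(submesh_groups((4, 2), ("dp", "tp"), "tp"))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(submesh_groups((4, 2), ("dp", "tp"), "dp"))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```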
Important
A data parallel config is required if any other parallel config is set.
Important
Not all parallel strategies are compatible with each other.
- Parameters:
  - dp (Optional[DataParallelConfig], default: None) – Data parallel config.
  - tp (Optional[TensorParallelConfig], default: None) – Tensor parallel config.
  - cp (Optional[ContextParallelConfig], default: None) – Context parallel config.
  - pp (Optional[PipelineParallelConfig], default: None) – Pipeline parallel config.
  - ep (Optional[ExpertParallelConfig], default: None) – Expert parallel config.
  - device_type (Optional[str], default: None) – The device type.
- Return type:
  DeviceMesh
- Returns:
  The world mesh with a shape compatible with the given parallel configs.
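As a rough illustration of how parallel degrees combine into one world mesh, the hypothetical function below derives named dimension sizes from a world size. It assumes a (pp, dp or dp_replicate/dp_shard, cp, tp) dimension order with TP innermost, and that whatever ranks remain after the other degrees go to data parallelism, in line with the rule that a data parallel config is required whenever another one is set. The real build_world_mesh() may order or prune dimensions differently.

```python
from math import prod
from typing import List, Tuple


def sketch_world_mesh_shape(
    world_size: int, *, dp_replicate: int = 1, tp: int = 1, cp: int = 1, pp: int = 1
) -> List[Tuple[str, int]]:
    """Hypothetical: (name, size) pairs for a world mesh, outermost dim first."""
    non_dp = dp_replicate * tp * cp * pp
    if world_size % non_dp != 0:
        raise ValueError(f"world size {world_size} is not divisible by {non_dp}")
    dp_shard = world_size // non_dp  # remaining ranks shard the model

    dims = [("pp", pp)]
    if dp_replicate > 1:
        # Hybrid sharding: a 2D data parallel block (replicate x shard).
        dims += [("dp_replicate", dp_replicate), ("dp_shard", dp_shard)]
    else:
        dims.append(("dp", dp_shard))
    dims += [("cp", cp), ("tp", tp)]  # TP innermost: TP peers get adjacent ranks

    # Drop trivial non-DP dimensions; always keep the data parallel ones.
    mesh = [(name, size) for name, size in dims if size > 1 or name.startswith("dp")]
    assert prod(size for _, size in mesh) == world_size
    return mesh


print(sketch_world_mesh_shape(16, tp=2, pp=2))
# [('pp', 2), ('dp', 4), ('tp', 2)]
```

Placing TP innermost is a common choice because it keeps each TP group on consecutive ranks, which usually means the same node and the fastest interconnect.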
- olmo_core.distributed.parallel.get_world_mesh()
Get the global world mesh built with build_world_mesh().
- Return type:
- class olmo_core.distributed.parallel.MeshDimName(value)
Bases: StrEnum
DeviceMesh dimension names for different forms of parallelism. These are the dimension names you will find in the mesh created by build_world_mesh().
- dp = 'dp'
Data parallel (DP).
- dp_replicate = 'dp_replicate'
The DP dimension over which the model is replicated.
- dp_shard = 'dp_shard'
The DP dimension over which the model is sharded.
- tp = 'tp'
Tensor parallel (TP).
- cp = 'cp'
Context parallel (CP).
- pp = 'pp'
Pipeline parallel (PP).
- ep = 'ep'
Expert parallel (EP).
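Because MeshDimName is a StrEnum, its members compare equal to plain strings, so they can be checked against a mesh's dimension names directly. The sketch uses a (str, Enum) stand-in with the members listed above; olmo_core's own StrEnum base is assumed to behave the same way for equality, though details such as str() output may differ.

```python
from enum import Enum


class MeshDimNameSketch(str, Enum):
    """Stand-in mirroring the MeshDimName members listed above."""

    dp = "dp"
    dp_replicate = "dp_replicate"
    dp_shard = "dp_shard"
    tp = "tp"
    cp = "cp"
    pp = "pp"
    ep = "ep"


# Example dimension names of a hypothetical world mesh.
mesh_dim_names = ("pp", "dp", "tp")

# str-based members compare equal to plain strings...
print(MeshDimNameSketch.tp == "tp")  # True
# ...so they can be used to look up a dimension's position by name.
print(mesh_dim_names.index(MeshDimNameSketch.tp))  # 2
print(MeshDimNameSketch.cp in mesh_dim_names)  # False
```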
- olmo_core.distributed.parallel.get_dp_model_mesh(device_mesh)
Get the right sub-mesh for a data parallel model wrapper like FSDP or DDP from a DeviceMesh created by build_world_mesh().
Important
You should use get_dp_mesh() instead to get the sub-mesh for assigning ranks to data loading workers. In many cases these two functions return the same result, but not always: for example, the model mesh may be 2D (dp_replicate × dp_shard), while get_dp_mesh() always returns a 1D mesh.
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  DeviceMesh
- olmo_core.distributed.parallel.get_dp_mesh(device_mesh)
Get the data parallel sub-mesh associated with a DeviceMesh created by build_world_mesh().
Important
This is the mesh that should be used to assign ranks to data loading workers. To get the mesh for DDP/FSDP, use get_dp_model_mesh() instead.
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  DeviceMesh
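The split between the model mesh and the data loading mesh matters because a 2D data parallel block (dp_replicate × dp_shard) still has to be reduced to a single 1D rank index when sharding data. The hypothetical helper below (not an olmo_core API) sketches that flattening under a row-major rank layout: ranks that differ only in their TP coordinate get the same data parallel rank, so they should load the same data.

```python
from math import prod
from typing import Sequence, Tuple

DP_DIMS = ("dp", "dp_replicate", "dp_shard")


def coords(rank: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major coordinates of ``rank`` in a mesh of the given shape."""
    out = []
    for size in reversed(shape):
        out.append(rank % size)
        rank //= size
    return tuple(reversed(out))


def data_loader_dp_rank(rank: int, shape: Sequence[int], dim_names: Sequence[str]) -> Tuple[int, int]:
    """Hypothetical: flatten a rank's DP coordinates into (dp_rank, dp_world_size)."""
    dp_axes = [i for i, name in enumerate(dim_names) if name in DP_DIMS]
    c = coords(rank, shape)
    dp_rank = 0
    for i in dp_axes:
        dp_rank = dp_rank * shape[i] + c[i]  # row-major over the DP dims only
    return dp_rank, prod(shape[i] for i in dp_axes)


# 16 ranks: HSDP with 2 replicas x 4 shards, plus 2-way TP.
shape, names = (2, 4, 2), ("dp_replicate", "dp_shard", "tp")
print([data_loader_dp_rank(r, shape, names)[0] for r in range(16)])
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
```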
- olmo_core.distributed.parallel.get_tp_mesh(device_mesh)
Get the tensor parallel sub-mesh associated with a DeviceMesh created by build_world_mesh().
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  DeviceMesh
- olmo_core.distributed.parallel.get_cp_mesh(device_mesh)
Get the context parallel sub-mesh associated with a DeviceMesh created by build_world_mesh().
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  DeviceMesh
- olmo_core.distributed.parallel.get_pp_mesh(device_mesh)
Get the pipeline parallel sub-mesh associated with a DeviceMesh created by build_world_mesh().
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  DeviceMesh
- olmo_core.distributed.parallel.get_pp_stage_mesh(device_mesh, pp_mesh=None)
Get the sub-mesh for a single pipeline stage.
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
  - pp_mesh (Optional[DeviceMesh], default: None) – Optional pipeline parallel mesh. If not provided, it is extracted from device_mesh using get_pp_mesh().
- Return type:
  DeviceMesh
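Conceptually, a pipeline stage's sub-mesh holds every rank that shares the same pp coordinate, with the remaining (non-pp) dimensions intact. The helper below is a hypothetical pure-Python illustration under a row-major rank layout; it returns only the list of ranks, not a DeviceMesh, and get_pp_stage_mesh() may construct its result differently.

```python
from math import prod
from typing import List, Sequence


def pp_stage_ranks(shape: Sequence[int], dim_names: Sequence[str], stage: int) -> List[int]:
    """Hypothetical: global ranks whose ``pp`` coordinate equals ``stage``."""
    axis = list(dim_names).index("pp")
    stride = prod(shape[axis + 1:])  # row-major stride of the pp dimension
    return [r for r in range(prod(shape)) if (r // stride) % shape[axis] == stage]


# 8 ranks: 2 pipeline stages x 4-way data parallel.
print(pp_stage_ranks((2, 4), ("pp", "dp"), 0))  # [0, 1, 2, 3]
print(pp_stage_ranks((2, 4), ("pp", "dp"), 1))  # [4, 5, 6, 7]
```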
- olmo_core.distributed.parallel.get_ep_mesh(device_mesh)
Get the expert parallel sub-mesh associated with a DeviceMesh, which may have been created by build_world_mesh().
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  DeviceMesh
- olmo_core.distributed.parallel.get_dp_process_group(device_mesh)
Get the data parallel process group associated with a DeviceMesh created by build_world_mesh().
Like get_dp_mesh(), this should be used for data loading, but not necessarily for data parallel model wrappers.
- Parameters:
  - device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().
- Return type:
  ProcessGroup
- olmo_core.distributed.parallel.get_device_mesh_info(device_mesh)
Get a human-readable string representation of a DeviceMesh.
- Parameters:
  - device_mesh (DeviceMesh) – The device mesh to get info for.
- Return type:
  str
- olmo_core.distributed.parallel.flatten_mesh(device_mesh, name=None)
Flatten a multi-dimensional DeviceMesh into a 1D DeviceMesh.
- Parameters:
  - device_mesh (DeviceMesh) – The multi-dimensional DeviceMesh to flatten.
  - name (Optional[str], default: None) – Optional name for the flattened dimension.
- Return type:
  DeviceMesh
Important
The device_mesh is modified in-place.
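Flattening keeps the same set of ranks and only changes how they are indexed. A hypothetical sketch of that re-indexing, assuming row-major order (the function name is illustrative, not an olmo_core API): each rank's position in the flattened 1D mesh is its position when the nested rank grid is read row by row.

```python
from typing import Dict, List, Union

RankGrid = Union[int, List["RankGrid"]]


def flattened_local_ranks(rank_grid: RankGrid) -> Dict[int, int]:
    """Hypothetical: map each global rank to its index in the flattened 1D mesh."""
    flat: List[int] = []

    def walk(node: RankGrid) -> None:
        if isinstance(node, list):
            for child in node:
                walk(child)
        else:
            flat.append(node)

    walk(rank_grid)
    return {rank: index for index, rank in enumerate(flat)}


# A 2 x 2 mesh over ranks 1, 3, 5, 7 of an 8-rank job.
print(flattened_local_ranks([[1, 3], [5, 7]]))  # {1: 0, 3: 1, 5: 2, 7: 3}
```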
- class olmo_core.distributed.parallel.DataParallelType(value)
Bases: StrEnum
An enumeration.
- class olmo_core.distributed.parallel.DataParallelConfig(name, param_dtype=None, reduce_dtype='float32', num_replicas=None, shard_degree=None)
Bases: Config
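DataParallelConfig takes num_replicas and shard_degree. Assuming these map onto the dp_replicate and dp_shard mesh dimensions, their product must equal the data parallel world size, so one can be inferred from the other. The function below sketches that bookkeeping; its name, and the choice to fall back to a single replica when neither value is given, are assumptions rather than documented olmo_core behavior.

```python
from typing import Optional, Tuple


def resolve_hsdp_degrees(
    dp_world_size: int,
    num_replicas: Optional[int] = None,
    shard_degree: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical: return (num_replicas, shard_degree) for a DP world size."""
    if num_replicas is None and shard_degree is None:
        # Assumption: with neither given, shard over all DP ranks (one replica).
        return 1, dp_world_size
    if num_replicas is None:
        num_replicas = dp_world_size // shard_degree
    elif shard_degree is None:
        shard_degree = dp_world_size // num_replicas
    if num_replicas * shard_degree != dp_world_size:
        raise ValueError(
            f"num_replicas ({num_replicas}) x shard_degree ({shard_degree}) "
            f"must equal the data parallel world size ({dp_world_size})"
        )
    return num_replicas, shard_degree


print(resolve_hsdp_degrees(8, shard_degree=4))  # (2, 4)
```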
- class olmo_core.distributed.parallel.DPMeshDimName(value)
Bases: StrEnum
DeviceMesh dimension names for data parallelism.
- replicate = 'dp_replicate'
The device mesh dimension over which the model is replicated.
- shard = 'dp_shard'
The device mesh dimension over which the model is sharded.
- class olmo_core.distributed.parallel.TensorParallelConfig(degree, enable_async=False)
Bases: Config
Configuration class for tensor parallelism (TP).
- class olmo_core.distributed.parallel.ExpertParallelConfig(degree)
Bases: Config
Configuration class for expert parallelism (EP).
- class olmo_core.distributed.parallel.PipelineParallelConfig(degree, schedule='Interleaved1F1B', style=None)
Bases: Config
Configuration class for pipeline parallelism (PP).
- schedule: PipelineScheduleType = 'Interleaved1F1B'
The name of the schedule.
- style: Optional[PipelineSplitStyle] = None
The split style.
- class olmo_core.distributed.parallel.PipelineScheduleType(value)
Bases: StrEnum
An enumeration of the different pipeline schedules available.
Warning
The zero-bubble variants currently have several issues, including not being compatible with torch.compile.
- class olmo_core.distributed.parallel.PipelineSplitStyle(value)
Bases: StrEnum
An enumeration.
- class olmo_core.distributed.parallel.PipelineSchedule(*, model_parts, stages, pp_mesh, schedule_name, loss_fn=None, num_microbatches=None)
Bases: object
A thin wrapper around PyTorch pipeline schedule classes.
- Parameters:
  - num_microbatches – How many microbatches to split the global training batch into. The global training batch size must be evenly divisible by this. If not specified, it defaults to the number of pipeline stages.
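The num_microbatches rule amounts to a divisibility check plus a default. A minimal sketch, assuming the batch size is counted in instances and the number of pipeline stages is known (the function name is hypothetical, not part of olmo_core):

```python
from typing import Optional


def microbatch_size(global_batch_size: int, num_stages: int, num_microbatches: Optional[int] = None) -> int:
    """Hypothetical: size of each microbatch under the rules described above."""
    if num_microbatches is None:
        num_microbatches = num_stages  # documented default
    if global_batch_size % num_microbatches != 0:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by "
            f"{num_microbatches} microbatches"
        )
    return global_batch_size // num_microbatches


print(microbatch_size(32, num_stages=4))  # 8
print(microbatch_size(32, num_stages=4, num_microbatches=16))  # 2
```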