distributed.parallel

olmo_core.distributed.parallel.build_world_mesh(*, dp=None, tp=None, cp=None, pp=None, ep=None, device_type=None)[source]

Build a DeviceMesh suitable for the given parallel strategies.

See also

Pass the mesh created by this function to any of the get_*_mesh() functions in this module to get the right sub-mesh for a given parallel strategy.

  • get_dp_model_mesh() gives you the 1 or 2D sub-mesh suitable for data parallel model wrappers like FSDP(2) or DDP.

  • get_dp_mesh() gives you the 1D sub-mesh suitable for configuring data loaders.

  • get_tp_mesh() gives you the 1D sub-mesh for tensor parallelism.

  • get_cp_mesh() gives you the 1D sub-mesh for context parallelism.

  • get_pp_mesh() gives you the 1D sub-mesh for pipeline parallelism.

  • get_ep_mesh() gives you the 1D sub-mesh for expert parallelism.
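To make the idea of a 1D sub-mesh concrete, here is a minimal, library-free sketch of how ranks in a multi-dimensional mesh group along one named dimension. It does not call olmo_core or PyTorch. The helper ranks_along_dim, the row-major rank layout, and the ("dp", "tp") dimension order are assumptions made for illustration, not a description of the layout that build_world_mesh() actually produces.

```python
from math import prod


def ranks_along_dim(shape, dim_names, dim, rank):
    """Return the ranks that share a 1D sub-mesh with `rank` along `dim`.

    Hypothetical helper: assumes ranks are laid out in row-major order
    over `shape`, which may differ from the real mesh layout.
    """
    assert len(shape) == len(dim_names)
    assert 0 <= rank < prod(shape)
    axis = dim_names.index(dim)

    # Recover the coordinate of `rank` in the mesh (row-major order).
    coords = []
    remainder = rank
    for size in reversed(shape):
        coords.append(remainder % size)
        remainder //= size
    coords.reverse()

    # Vary only the requested axis and convert each coordinate back to a rank.
    group = []
    for i in range(shape[axis]):
        c = list(coords)
        c[axis] = i
        flat = 0
        for size, idx in zip(shape, c):
            flat = flat * size + idx
        group.append(flat)
    return group


# Hypothetical 8-rank world: 4-way data parallel x 2-way tensor parallel.
shape, names = (4, 2), ("dp", "tp")
tp_group = ranks_along_dim(shape, names, "tp", rank=5)
dp_group = ranks_along_dim(shape, names, "dp", rank=5)
```

Under these assumptions, each rank belongs to exactly one group per dimension, which is the role the 1D sub-meshes listed above play for their respective strategies.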

Important

A data parallel config is required if any other parallel config is set.

Important

Not all parallel strategies are compatible with each other.

Parameters:
Return type:

DeviceMesh

Returns:

The world mesh with a shape compatible with the given parallel configs.
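The sizes of a device mesh's dimensions multiply to the number of devices it covers, so the parallel degrees must factor the world size. The sketch below shows that arithmetic in plain Python. The helper infer_dp_degree is hypothetical, and whether build_world_mesh() derives the data parallel degree in exactly this way is an assumption. Expert parallelism is left out because its relation to the other dimensions is not described here.

```python
from math import prod


def infer_dp_degree(world_size, *, tp=1, cp=1, pp=1):
    """Derive the data parallel degree left over after the other strategies.

    Hypothetical helper for illustration only; not part of olmo_core.
    """
    other = prod((tp, cp, pp))
    if world_size % other != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by tp*cp*pp = {other}"
        )
    return world_size // other


# 64 ranks with 2-way TP and 4-way PP leave 8-way data parallelism.
dp_degree = infer_dp_degree(64, tp=2, pp=4)
```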

olmo_core.distributed.parallel.get_world_mesh()[source]

Get the global world mesh built with build_world_mesh().

Return type:

Optional[DeviceMesh]

class olmo_core.distributed.parallel.MeshDimName(value)[source]

Bases: StrEnum

DeviceMesh dimension names for different forms of parallelism. These are the dimension names that you will find in the mesh created by build_world_mesh().

dp = 'dp'

Data parallel (DP).

dp_replicate = 'dp_replicate'

The DP dimension over which the model is replicated.

dp_shard = 'dp_shard'

The DP dimension over which the model is sharded.

tp = 'tp'

Tensor parallel (TP).

cp = 'cp'

Context parallel (CP).

pp = 'pp'

Pipeline parallel (PP).

ep = 'ep'

Expert parallel (EP).

olmo_core.distributed.parallel.get_dp_model_mesh(device_mesh)[source]

Get the right sub-mesh for a data parallel model wrapper like FSDP or DDP from a DeviceMesh created by build_world_mesh().

Important

You should use get_dp_mesh() instead for getting the sub-mesh to assign ranks to data loading workers. In many cases these two functions will return the same result, but there are cases where they could be different.

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

DeviceMesh

olmo_core.distributed.parallel.get_dp_mesh(device_mesh)[source]

Get the data parallel sub-mesh associated with a DeviceMesh created by build_world_mesh().

Important

This is the mesh that should be used to assign ranks to data loading workers. To get the mesh for DDP/FSDP, use get_dp_model_mesh() instead.

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

DeviceMesh

olmo_core.distributed.parallel.get_tp_mesh(device_mesh)[source]

Get the tensor parallel sub-mesh associated with a DeviceMesh created from build_world_mesh().

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

DeviceMesh

olmo_core.distributed.parallel.get_cp_mesh(device_mesh)[source]

Get the context parallel sub-mesh associated with a DeviceMesh created from build_world_mesh().

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

DeviceMesh

olmo_core.distributed.parallel.get_pp_mesh(device_mesh)[source]

Get the pipeline parallel sub-mesh associated with a DeviceMesh created from build_world_mesh().

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

DeviceMesh

olmo_core.distributed.parallel.get_pp_stage_mesh(device_mesh, pp_mesh=None)[source]

Get the sub-mesh for a single pipeline stage.

Parameters:
Return type:

DeviceMesh

olmo_core.distributed.parallel.get_ep_mesh(device_mesh)[source]

Get the expert parallel sub-mesh associated with a DeviceMesh that was potentially created from build_world_mesh().

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

DeviceMesh

olmo_core.distributed.parallel.get_dp_process_group(device_mesh)[source]

Get the data parallel process group associated with a DeviceMesh created from build_world_mesh().

Like get_dp_mesh(), this should be used for data loading, but not necessarily for data parallel model wrappers.

Parameters:

device_mesh (DeviceMesh) – The world mesh created by build_world_mesh().

Return type:

ProcessGroup

olmo_core.distributed.parallel.get_device_mesh_info(device_mesh)[source]

Get a human-readable string representation of a DeviceMesh.

Parameters:

device_mesh (DeviceMesh) – The device mesh to get info for.

Return type:

str

olmo_core.distributed.parallel.flatten_mesh(device_mesh, name=None)[source]

Flatten a multi-dimensional DeviceMesh into a 1D DeviceMesh.

Parameters:
  • device_mesh (DeviceMesh) – The multi-dimensional DeviceMesh to flatten.

  • name (Optional[str], default: None) – Optional name for the flattened dimension.

Return type:

DeviceMesh

Important

The device_mesh is modified in-place.

class olmo_core.distributed.parallel.DataParallelType(value)[source]

Bases: StrEnum

An enumeration.

class olmo_core.distributed.parallel.DataParallelConfig(name, param_dtype=None, reduce_dtype='float32', num_replicas=None, shard_degree=None)[source]

Bases: Config

get_replicate_and_shard_degree(dp_world_size)[source]

Defaults to one replica per node, with the shard degree set to the number of GPUs per node.

Parameters:

dp_world_size (int) – The data parallel world size.

Return type:

Tuple[int, int]

Returns:

A tuple of (num_replicas, shard_degree)
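The documented default can be sketched as simple arithmetic: with one replica per node, the number of replicas is the number of nodes, and the shard degree is the number of GPUs per node. The function below is a hypothetical stand-in, not the library's implementation. It takes gpus_per_node as an explicit argument (how the library discovers that value is not covered here), ignores explicitly configured num_replicas or shard_degree, and assumes the data parallel ranks fill whole nodes.

```python
def default_replicate_and_shard_degree(dp_world_size, gpus_per_node):
    """Sketch of the default: one replica per node, sharded within the node.

    Hypothetical helper; assumes dp_world_size is a multiple of gpus_per_node.
    """
    if dp_world_size % gpus_per_node != 0:
        raise ValueError(
            f"dp_world_size {dp_world_size} is not divisible by "
            f"gpus_per_node {gpus_per_node}"
        )
    num_replicas = dp_world_size // gpus_per_node
    shard_degree = gpus_per_node
    return num_replicas, shard_degree


# 32 data parallel ranks on nodes with 8 GPUs each.
num_replicas, shard_degree = default_replicate_and_shard_degree(32, 8)
```

In this sketch the two values always multiply back to dp_world_size, which is what a 2D (replicate, shard) mesh over the data parallel ranks requires.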

class olmo_core.distributed.parallel.DPMeshDimName(value)[source]

Bases: StrEnum

DeviceMesh dimension names for data parallelism.

replicate = 'dp_replicate'

The device mesh dimension over which the model is replicated.

shard = 'dp_shard'

The device mesh dimension over which the model is sharded.

class olmo_core.distributed.parallel.TensorParallelConfig(degree, enable_async=False)[source]

Bases: Config

Configuration class for tensor parallelism (TP).

degree: int

The TP degree.

enable_async: bool = False

Enable experimental async tensor parallelism.

class olmo_core.distributed.parallel.ExpertParallelConfig(degree)[source]

Bases: Config

Configuration class for expert parallelism (EP).

degree: int

The EP degree.

class olmo_core.distributed.parallel.PipelineParallelConfig(degree, schedule='Interleaved1F1B', style=None)[source]

Bases: Config

Configuration class for pipeline parallelism (PP).

degree: int

The PP degree.

schedule: PipelineScheduleType = 'Interleaved1F1B'

The name of the schedule.

style: Optional[PipelineSplitStyle] = None

The split style.

rank_completion_order()[source]

The order that ranks within the PP group will complete a batch.

Return type:

Iterable[int]

stage_ids_this_rank(pp_rank, num_stages)[source]

Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule.

Return type:

Tuple[int, ...]
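As an illustration of the difference between looped and V style stage assignment, here is a minimal sketch of one common convention. It is an assumption that this matches the method above. The function stage_ids_for_rank is hypothetical and takes the PP degree and style as explicit arguments, and the style strings "loop" and "v" are placeholders, since the members of PipelineSplitStyle are not listed on this page.

```python
def stage_ids_for_rank(pp_rank, pp_degree, num_stages, style="loop"):
    """Return the stage ids assigned to `pp_rank` (illustrative convention).

    Hypothetical helper; "loop" and "v" are placeholder style names.
    """
    if num_stages % pp_degree != 0:
        raise ValueError("num_stages must be divisible by pp_degree")
    stages_per_rank = num_stages // pp_degree

    if style == "loop":
        # Looped: each rank takes every pp_degree-th stage, starting at its rank.
        return tuple(pp_rank + i * pp_degree for i in range(stages_per_rank))
    if style == "v":
        # V style: stages go down the ranks and back up, so each rank holds
        # one stage from the first half and its mirror from the second half.
        if stages_per_rank != 2:
            raise ValueError("this V style sketch needs exactly 2 stages per rank")
        return (pp_rank, num_stages - 1 - pp_rank)
    raise ValueError(f"unknown style: {style!r}")


# 4 PP ranks and 8 stages.
looped = [stage_ids_for_rank(r, 4, 8, "loop") for r in range(4)]
v_style = [stage_ids_for_rank(r, 4, 8, "v") for r in range(4)]
```

In the looped layout rank 0 holds stages 0 and 4, while in the V layout it holds the first and last stage (0 and 7), and the last rank holds the two middle stages (3 and 4).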

class olmo_core.distributed.parallel.PipelineScheduleType(value)[source]

Bases: StrEnum

An enumeration of the different pipeline schedules available.

Warning

The zero-bubble variants have several issues at the moment, including not being compatible with torch.compile.

class olmo_core.distributed.parallel.PipelineSplitStyle(value)[source]

Bases: StrEnum

An enumeration.

class olmo_core.distributed.parallel.PipelineSchedule(*, model_parts, stages, pp_mesh, schedule_name, loss_fn=None, num_microbatches=None)[source]

Bases: object

A thin wrapper around PyTorch pipeline schedule classes.

Parameters:

num_microbatches – How many microbatches to split the global training batch into. The global training batch size must be evenly divisible by this. If not specified, the default will be the number of pipeline stages.
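The divisibility requirement on the microbatch count can be seen in a small, list-based sketch. The function split_into_microbatches is hypothetical and works on plain Python lists. The actual class operates on tensors and delegates to the PyTorch pipeline schedule classes, so this only illustrates the constraint, not the implementation.

```python
def split_into_microbatches(batch, num_microbatches):
    """Split `batch` into `num_microbatches` equally sized chunks.

    Hypothetical helper; raises if the batch cannot be split evenly.
    """
    if num_microbatches <= 0:
        raise ValueError("num_microbatches must be positive")
    if len(batch) % num_microbatches != 0:
        raise ValueError(
            f"batch size {len(batch)} is not divisible by "
            f"num_microbatches {num_microbatches}"
        )
    size = len(batch) // num_microbatches
    return [batch[i * size:(i + 1) * size] for i in range(num_microbatches)]


# A batch of 8 examples split into 4 microbatches of 2 examples each.
microbatches = split_into_microbatches(list(range(8)), 4)
```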

step(*args, target=None, **kwargs)[source]
Parameters:
  • args – Only passed to first stage.

  • kwargs – Passed to all stages.

Return type:

Tuple[Any, Optional[Tensor]]