distributed.tensors

Distributed tensor and parameter classes.

class olmo_core.distributed.tensors.ShardedFlatTensor(data: Tensor, requires_grad: bool = False)

Bases: Tensor

ShardedFlatTensor represents a sharded tensor with the assumption that every shard is a contiguous slice into the flattened unsharded tensor.
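To illustrate this layout, the sketch below computes one contiguous (start, end) range per rank for a flattened tensor. It assumes a ceil-divided even split in which trailing ranks may receive fewer or zero elements. `even_flat_offsets` is a hypothetical helper, not part of olmo_core, and the library's default sharding spec may split differently.

```python
from math import prod
from typing import List, Tuple


def even_flat_offsets(shape: Tuple[int, ...], world_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: one contiguous (start, end) slice of the flattened tensor per rank.

    Assumes a ceil-divided even split; trailing ranks may get fewer (or zero) elements.
    """
    numel = prod(shape)
    chunk = -(-numel // world_size)  # ceil division
    offsets = []
    for rank in range(world_size):
        start = min(rank * chunk, numel)
        end = min(start + chunk, numel)
        offsets.append((start, end))
    return offsets


# A (3, 5) tensor has 15 elements; over 4 ranks each rank owns a slice of at most 4.
print(even_flat_offsets((3, 5), 4))
# [(0, 4), (4, 8), (8, 12), (12, 15)]
```

Because each range starts where the previous one ends, the shards tile the flattened tensor with no gaps or overlaps.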

classmethod shard(tensor, sharding_spec=None, process_group=None, synchronize=True, device=None, requires_grad=None)

Shard a tensor across a process group.

Return type:

TypeVar(T, bound=ShardedFlatTensor)

gather(dtype=None, rank0_only=False)

Gather the sharded flat parameter across a process group into a full unsharded parameter.

Return type:

Tensor
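Conceptually, gathering writes each rank's flat shard into its (start, end) range of a buffer with `unsharded_numel` elements, then views that buffer with `unsharded_shape`. The sketch below models only that bookkeeping, with plain Python lists standing in for tensors and for the collective communication. `gather_flat` and `unflatten` are hypothetical helpers, and the sketch assumes one contiguous range per rank.

```python
from math import prod
from typing import Any, List, Sequence, Tuple


def gather_flat(
    shards: Sequence[Sequence[Any]], offsets: Sequence[Tuple[int, int]], numel: int
) -> List[Any]:
    """Hypothetical helper: place each rank's flat shard at its (start, end) range."""
    full: List[Any] = [None] * numel
    for shard, (start, end) in zip(shards, offsets):
        if len(shard) != end - start:
            raise ValueError(f"shard of length {len(shard)} does not fit range ({start}, {end})")
        full[start:end] = shard
    if any(x is None for x in full):
        raise ValueError("offsets do not cover the whole flattened tensor")
    return full


def unflatten(flat: Sequence[Any], shape: Tuple[int, ...]) -> Any:
    """Reshape a flat sequence into nested lists in row-major order, like Tensor.view(shape)."""
    if len(shape) <= 1:
        return list(flat)
    step = prod(shape[1:])
    return [unflatten(flat[i * step : (i + 1) * step], shape[1:]) for i in range(shape[0])]


# Shards of a (3, 5) tensor held by 4 ranks, each covering a contiguous flat range.
offsets = [(0, 4), (4, 8), (8, 12), (12, 15)]
shards = [list(range(s, e)) for s, e in offsets]
full = unflatten(gather_flat(shards, offsets, numel=15), (3, 5))
print(full)
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
```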

unshard_(unsharded_data=None, dtype=None, rank0_only=False)

Unshard this parameter’s data in-place. You should generally call reshard_() afterwards.

If rank0_only=True, processes other than rank 0 will hold an empty tensor as their data.

reshard_(writeback=False)

Reshard this parameter’s data in-place. This should only be called after unshard_(). It does not touch the parameter’s gradient, if it has one; that must be handled separately by the calling code.

mark_as_sharded(sharding_spec, process_group=None)

wrap(tensor, requires_grad=None)

Wrap another tensor and mark it as sharded with this tensor’s sharding spec. tensor should have the same shape as this tensor.

Return type:

ShardedFlatTensor

chunk_unsharded(tensor, pad=False)

Split an unsharded tensor whose shape matches self.unsharded_shape into flat chunks, one per rank, where each chunk has the shape of that rank’s sharded data.

Parameters:

pad (bool, default: False) – Whether or not to add right padding to the chunks to ensure they’re all the same size.

Return type:

List[Tensor]
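The sketch below models what `pad` changes, using nested Python lists in place of tensors. `chunk_flat` and `flatten` are hypothetical helpers, and the zero pad value is an assumption, since the padding value used by chunk_unsharded() is not documented here. Equal-size chunks are what collectives such as `torch.distributed.all_gather_into_tensor` expect, which is one reason padding can be useful. Picking out a single rank's entry corresponds, conceptually, to sharded_chunk().

```python
from typing import Any, List, Sequence, Tuple


def flatten(nested: Any) -> List[Any]:
    """Row-major flatten of nested lists (a stand-in for Tensor.flatten())."""
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]


def chunk_flat(
    tensor: Any,
    offsets: Sequence[Tuple[int, int]],
    pad: bool = False,
    pad_value: Any = 0,  # assumption: the real pad value is not documented here
) -> List[List[Any]]:
    """Hypothetical model of chunk_unsharded(): one flat chunk per rank, optionally right-padded."""
    flat = flatten(tensor)
    chunks = [flat[start:end] for start, end in offsets]
    if pad:
        # Right-pad every chunk to the length of the largest one.
        max_len = max(len(c) for c in chunks)
        chunks = [c + [pad_value] * (max_len - len(c)) for c in chunks]
    return chunks


tensor = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]  # shape (3, 5)
offsets = [(0, 4), (4, 8), (8, 12), (12, 15)]
print(chunk_flat(tensor, offsets))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14]]
print(chunk_flat(tensor, offsets, pad=True))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 0]]
```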

sharded_chunk(tensor)

Get this rank’s sharded chunk of an unsharded tensor with the same shape as self.unsharded_shape.

Return type:

Tensor

property metadata_set: bool
property is_sharded: bool
property sharding_spec: ShardingSpec
property process_group: ProcessGroup | None
property unsharded_flattened_offsets: Tuple[Tuple[int, int], ...]
property unsharded_numel: int
property unsharded_shape: Tuple[int, ...]
property sharded_numel: int
property sharded_shape: Tuple[int, ...]
property sharded_data: Tensor
property unsharded_data: Tensor | None

class olmo_core.distributed.tensors.ShardedFlatParameter(data: Tensor | None = None, requires_grad: bool = True)

Bases: ShardedFlatTensor, Parameter

A Parameter version of ShardedFlatTensor.

class olmo_core.distributed.tensors.ShardingSpec(unsharded_shape, unsharded_flattened_offsets)

Bases: object

unsharded_shape: Tuple[int, ...]

The shape of the full unsharded (unflattened) parameter.

unsharded_flattened_offsets: Tuple[Tuple[Tuple[int, int], ...], ...]

The offsets ((start_idx, end_idx)) within the full unsharded flattened parameter that each local shard within the process group corresponds to.

This tuple is indexed by rank within the process group. For example, the offsets within the full unsharded flattened parameter for the local shard of the current rank are given by unsharded_flattened_offsets[dist.get_rank(process_group)].

property unsharded_numel: int

The number of elements in the full unsharded tensor.

property sharded_numels: Tuple[int, ...]

The number of elements in each shard.

property unsharded_flattened_shape: Tuple[int, ...]

The shape of the unsharded flattened tensor.
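To make the relationship between these fields concrete, the sketch below re-derives the three properties from unsharded_shape and unsharded_flattened_offsets. `ToyShardingSpec` is a hypothetical stand-in rather than the real class, and computing each shard's size as the total length of the ranges it owns is an assumption based on the field descriptions above.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple


@dataclass(frozen=True)
class ToyShardingSpec:
    """Hypothetical stand-in mirroring the fields documented for ShardingSpec."""

    unsharded_shape: Tuple[int, ...]
    # One entry per rank; each entry is a tuple of (start_idx, end_idx) ranges.
    unsharded_flattened_offsets: Tuple[Tuple[Tuple[int, int], ...], ...]

    @property
    def unsharded_numel(self) -> int:
        return prod(self.unsharded_shape)

    @property
    def sharded_numels(self) -> Tuple[int, ...]:
        # Assumption: a shard's size is the total length of the ranges it owns.
        return tuple(
            sum(end - start for start, end in ranges)
            for ranges in self.unsharded_flattened_offsets
        )

    @property
    def unsharded_flattened_shape(self) -> Tuple[int, ...]:
        return (self.unsharded_numel,)


spec = ToyShardingSpec(
    unsharded_shape=(3, 5),
    unsharded_flattened_offsets=(((0, 4),), ((4, 8),), ((8, 12),), ((12, 15),)),
)
rank = 3  # in real code this would come from dist.get_rank(process_group)
print(spec.unsharded_numel)  # 15
print(spec.sharded_numels)  # (4, 4, 4, 3)
print(spec.unsharded_flattened_shape)  # (15,)
print(spec.unsharded_flattened_offsets[rank])  # ((12, 15),)
```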

Helper functions for dealing with PyTorch’s DTensor.

olmo_core.distributed.tensors.dtensor_utils.get_local_shape_and_global_offset(dtensor, rank=None)

Like compute_local_shape_and_global_offset(), but acts directly on a DTensor instance.

Parameters:
  • dtensor (DTensor) – A DTensor instance.

  • rank (Optional[int], default: None) – The global rank to compute the local shape and global offsets for. If None, defaults to the current rank.

Return type:

Tuple[Tuple[int, ...], Tuple[int, ...]]

Returns:

The local shape and global offset.

olmo_core.distributed.tensors.dtensor_utils.compute_local_shape_and_global_offset(global_shape, mesh, placements, rank=None)

Compute the local tensor shape of a DTensor, and its global offsets into the original tensor, on a given global rank (the current rank by default). This is useful for checkpointing purposes.

Parameters:
  • global_shape (Union[Size, List[int], Tuple[int, ...]]) – The shape of the global unsharded tensor.

  • mesh (DeviceMesh) – The device mesh.

  • placements (Sequence[Placement]) – The placements of the DTensor.

  • rank (Optional[int], default: None) – The global rank to compute the local shape and global offsets for. If None, defaults to the current rank.

Return type:

Tuple[Tuple[int, ...], Tuple[int, ...]]

Returns:

The local shape and global offset.

Example (2 hosts with 4 GPUs each):

from torch.distributed.device_mesh import DeviceMesh

# A DeviceMesh with mesh_shape (2, 4); run this on each of the 8 ranks.
mesh = DeviceMesh(device_type="cuda", mesh=[
    [0, 1, 2, 3],
    [4, 5, 6, 7],
])

Suppose we distribute a global_tensor of shape (8, 4) over the above DeviceMesh with placements [Shard(0), Shard(0)].

The local shape and global offset will be as follows:

  • rank0 -- local_shape:[1, 4], global_offset:[0, 0]

  • rank1 -- local_shape:[1, 4], global_offset:[1, 0]

  • rank2 -- local_shape:[1, 4], global_offset:[2, 0]

  • rank3 -- local_shape:[1, 4], global_offset:[3, 0]

  • rank4 -- local_shape:[1, 4], global_offset:[4, 0]

  • rank5 -- local_shape:[1, 4], global_offset:[5, 0]

  • rank6 -- local_shape:[1, 4], global_offset:[6, 0]

  • rank7 -- local_shape:[1, 4], global_offset:[7, 0]

Now suppose we distribute a global_tensor of shape (2,) over a 1-D DeviceMesh spanning the same 8 ranks, with placements [Shard(0)]. Not every rank will have a non-empty local tensor.

The local shape and global offset will be as follows:

  • rank0 -- local_shape:[1,], global_offset:[0,]

  • rank1 -- local_shape:[1,], global_offset:[1,]

  • rank2 -- local_shape:[0,], global_offset:[2,]

  • rank3 -- local_shape:[0,], global_offset:[2,]

  • rank4 -- local_shape:[0,], global_offset:[2,]

  • rank5 -- local_shape:[0,], global_offset:[2,]

  • rank6 -- local_shape:[0,], global_offset:[2,]

  • rank7 -- local_shape:[0,], global_offset:[2,]
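The shapes and offsets above follow from torch.chunk-style splitting: along a sharded dimension every shard gets ceil(size / mesh_dim_size) elements except at the tail, and, as in the listing, an empty shard's offset equals the size of that dimension. The pure-Python sketch below reproduces both examples under stated assumptions: a placement is encoded as the sharded tensor dimension (or None for Replicate), ranks are laid out row-major in the mesh, and the second example uses a one-dimensional mesh of the same eight ranks, since a single placement needs a 1-D mesh. All names are hypothetical; this is not the olmo_core or PyTorch implementation.

```python
from typing import List, Optional, Sequence, Tuple


def unravel(rank: int, mesh_shape: Sequence[int]) -> List[int]:
    """Row-major coordinate of `rank` in a mesh laid out like torch.arange(n).view(mesh_shape)."""
    coord = []
    for size in reversed(mesh_shape):
        coord.append(rank % size)
        rank //= size
    return coord[::-1]


def chunk_size_and_offset(dim_size: int, num_chunks: int, index: int) -> Tuple[int, int]:
    """torch.chunk-style split: each chunk has ceil(dim_size / num_chunks) elements except the tail."""
    full_chunk = -(-dim_size // num_chunks)  # ceil division
    start = full_chunk * index
    if start >= dim_size:
        # Empty shard: report the dimension's size as the offset, as in the listing above.
        return 0, dim_size
    return min(dim_size, start + full_chunk) - start, start


def local_shape_and_global_offset(
    global_shape: Sequence[int],
    mesh_shape: Sequence[int],
    placements: Sequence[Optional[int]],  # tensor dim sharded on each mesh dim, or None (Replicate)
    rank: int,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    local_shape = list(global_shape)
    global_offset = [0] * len(global_shape)
    coord = unravel(rank, mesh_shape)
    for mesh_dim, shard_dim in enumerate(placements):
        if shard_dim is None:  # Replicate: this mesh dim does not split the tensor
            continue
        size, offset = chunk_size_and_offset(
            local_shape[shard_dim], mesh_shape[mesh_dim], coord[mesh_dim]
        )
        local_shape[shard_dim] = size
        global_offset[shard_dim] += offset  # offsets compose across nested mesh dims
    return tuple(local_shape), tuple(global_offset)


# Example 1: (8, 4) over a (2, 4) mesh with [Shard(0), Shard(0)].
print(local_shape_and_global_offset((8, 4), (2, 4), [0, 0], rank=5))
# ((1, 4), (5, 0))

# Example 2: (2,) over a 1-D mesh of 8 ranks with [Shard(0)].
print([local_shape_and_global_offset((2,), (8,), [0], rank=r) for r in range(3)])
# [((1,), (0,)), ((1,), (1,)), ((0,), (2,))]
```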