distributed.tensors
Distributed tensor and parameter classes.
- class olmo_core.distributed.tensors.ShardedFlatTensor(data: Tensor, requires_grad: bool = False)
Bases: Tensor
ShardedFlatTensor represents a sharded tensor with the assumption that every shard is a contiguous slice into the flattened unsharded tensor.
- classmethod shard(tensor, sharding_spec=None, process_group=None, synchronize=True, device=None, requires_grad=None)
Shard a tensor across a process group.
- Return type:
TypeVar(T, bound=ShardedFlatTensor)
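As a rough illustration of the contiguous-slice assumption, the plain-Python sketch below treats a tensor as its flattened index space, splits [0, numel) into one contiguous (start, end) range per rank, and keeps only one rank's slice. The helpers flat_offsets and local_shard are hypothetical, and the even split that hands any remainder to the lowest ranks is an assumption; shard() also accepts an explicit sharding_spec, and its default split policy may differ.

```python
from math import prod
from typing import List, Sequence, Tuple


def flat_offsets(unsharded_shape: Sequence[int], world_size: int) -> List[Tuple[int, int]]:
    """Split the flat index range [0, numel) into one contiguous (start, end) range per rank."""
    numel = prod(unsharded_shape)
    base, remainder = divmod(numel, world_size)
    offsets: List[Tuple[int, int]] = []
    start = 0
    for rank in range(world_size):
        # Assumption: the first `remainder` ranks each take one extra element.
        size = base + (1 if rank < remainder else 0)
        offsets.append((start, start + size))
        start += size
    return offsets


def local_shard(flat_data: Sequence[float], offsets: Sequence[Tuple[int, int]], rank: int) -> List[float]:
    """Return the contiguous slice of the flattened data owned by `rank`."""
    start, end = offsets[rank]
    return list(flat_data[start:end])


# A 3x3 tensor (9 elements, flattened row-major) sharded over 4 ranks.
offsets = flat_offsets((3, 3), world_size=4)
print(offsets)  # [(0, 3), (3, 5), (5, 7), (7, 9)]
print(local_shard(list(range(9)), offsets, rank=1))  # [3, 4]
```

Because each shard is a single contiguous range, a rank only needs its (start, end) pair to locate its data inside the flattened tensor; no per-element index map is required.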
- gather(dtype=None, rank0_only=False)
Gather the sharded flat parameter across a process group into a full unsharded parameter.
- Return type:
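In a real run, gather() communicates across the process group. The single-process sketch below only simulates the resulting data layout: assuming the per-rank contiguous slices are laid out in rank order, concatenating the shards and reshaping recovers the full tensor. gather_flat and unflatten_2d are hypothetical helpers, and the dtype and rank0_only options are not modeled.

```python
from math import prod
from typing import List, Sequence


def gather_flat(shards: Sequence[Sequence[float]], unsharded_shape: Sequence[int]) -> List[float]:
    """Concatenate per-rank flat shards (given in rank order) into the full flat tensor."""
    flat = [x for shard in shards for x in shard]
    if len(flat) != prod(unsharded_shape):
        raise ValueError("shards do not cover the full unsharded tensor")
    return flat


def unflatten_2d(flat: Sequence[float], shape: Sequence[int]) -> List[List[float]]:
    """Restore a 2-D row-major view of a flat list."""
    rows, cols = shape
    return [list(flat[r * cols:(r + 1) * cols]) for r in range(rows)]


# Shards as each of 4 ranks would hold them for a 3x3 tensor.
shards = [[0, 1, 2], [3, 4], [5, 6], [7, 8]]
print(unflatten_2d(gather_flat(shards, (3, 3)), (3, 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```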
- unshard_(unsharded_data=None, dtype=None, rank0_only=False)
Unshard this parameter’s data in-place. You should generally call reshard_() afterwards.
If rank0_only=True, non-rank-0 processes will have an empty tensor in their data.
- reshard_(writeback=False)
Reshard this parameter’s data in-place. This should only be called after unshard_(). It does not do anything with the parameter’s gradient, if it has one; that should be handled separately by the calling code.
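Since unshard_() should generally be paired with reshard_(), one way to guarantee the pairing is a context manager that reshards in a finally clause. The sketch below assumes only the unshard_() and reshard_(writeback=...) methods documented here; the helper name temporarily_unsharded is hypothetical, and it deliberately leaves gradients alone, as noted above.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def temporarily_unsharded(param: Any, writeback: bool = False, **unshard_kwargs: Any) -> Iterator[Any]:
    """Unshard `param` in-place for the duration of a `with` block, then reshard it."""
    # Extra keyword arguments (e.g. dtype, rank0_only) are forwarded to unshard_() as-is.
    param.unshard_(**unshard_kwargs)
    try:
        yield param
    finally:
        # Runs even if the body raises, so the parameter is always resharded.
        # Gradients are not touched here; the caller must handle them separately.
        param.reshard_(writeback=writeback)
```

A caller might write `with temporarily_unsharded(flat_param, rank0_only=True): ...` to inspect the full data on rank 0 and have the sharded layout restored afterwards, even on error.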
- wrap(tensor, requires_grad=None)
Wrap another tensor and mark it as sharded with the same sharding spec. tensor should have the same shape.
- Return type:
- chunk_unsharded(tensor, pad=False)
Chunk an unsharded tensor with the same shape as self.unsharded_shape and split it into flat chunks, where each chunk has the shape of the sharded data corresponding to that rank.
- sharded_chunk(tensor)
Get this rank’s sharded chunk of an unsharded tensor with the same shape as self.unsharded_shape.
- Return type:
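chunk_unsharded() and sharded_chunk() apply an existing sharded layout to a different tensor of the same unsharded shape (for instance, and only as an illustration, a full-size gradient). The plain-Python sketch below mimics that with one (start, end) range per rank; chunk_like and sharded_chunk_like are hypothetical names, and the pad option is not modeled.

```python
from math import prod
from typing import List, Sequence, Tuple


def chunk_like(flat: Sequence[float], unsharded_shape: Sequence[int],
               offsets: Sequence[Tuple[int, int]]) -> List[List[float]]:
    """Split a flattened tensor of the same shape into the per-rank chunks given by `offsets`."""
    if len(flat) != prod(unsharded_shape):
        raise ValueError("tensor must have the same number of elements as unsharded_shape")
    return [list(flat[start:end]) for start, end in offsets]


def sharded_chunk_like(flat: Sequence[float], unsharded_shape: Sequence[int],
                       offsets: Sequence[Tuple[int, int]], rank: int) -> List[float]:
    """Return only `rank`'s chunk of a same-shaped tensor."""
    chunks = chunk_like(flat, unsharded_shape, offsets)
    return chunks[rank]


# A 2x3 tensor laid out unevenly over 2 ranks: rank 0 owns 4 elements, rank 1 owns 2.
offsets = [(0, 4), (4, 6)]
full = [10, 11, 12, 13, 14, 15]
print(chunk_like(full, (2, 3), offsets))             # [[10, 11, 12, 13], [14, 15]]
print(sharded_chunk_like(full, (2, 3), offsets, 1))  # [14, 15]
```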
- property sharding_spec: ShardingSpec
- class olmo_core.distributed.tensors.ShardedFlatParameter(data: Tensor | None = None, requires_grad: bool = True)
Bases: ShardedFlatTensor, Parameter
A Parameter version of ShardedFlatTensor.
- class olmo_core.distributed.tensors.ShardingSpec(unsharded_shape, unsharded_flattened_offsets)
Bases: object
- unsharded_flattened_offsets: Tuple[Tuple[Tuple[int, int], ...], ...]
The offsets ((start_idx, end_idx)) within the full unsharded flattened parameter that each local shard within the process group corresponds to.
This tuple is indexed by rank within the process group. For example, the offsets within the full unsharded flattened parameter for the local shard of the current rank are given by unsharded_flattened_offsets[dist.get_rank(process_group)].
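To make the nested type concrete, the sketch below builds an offsets value with one (start, end) pair per rank, checks that all ranges together tile [0, numel) with no gaps or overlaps, and looks up one rank's entry. check_offsets is a hypothetical validator, not part of ShardingSpec; in real code the rank index would come from dist.get_rank(process_group).

```python
from math import prod
from typing import Sequence, Tuple

# Indexed by rank; each rank holds a tuple of (start_idx, end_idx) pairs.
Offsets = Tuple[Tuple[Tuple[int, int], ...], ...]


def check_offsets(unsharded_shape: Sequence[int], offsets: Offsets) -> None:
    """Raise ValueError unless the ranges exactly tile [0, numel) of the flattened tensor."""
    ranges = sorted(r for rank_ranges in offsets for r in rank_ranges)
    expected_start = 0
    for start, end in ranges:
        if start != expected_start or end < start:
            raise ValueError(f"gap or overlap at {(start, end)}")
        expected_start = end
    if expected_start != prod(unsharded_shape):
        raise ValueError("offsets do not cover the whole flattened tensor")


offsets: Offsets = (((0, 3),), ((3, 5),), ((5, 7),), ((7, 9),))
check_offsets((3, 3), offsets)  # passes silently
rank = 2  # stand-in for dist.get_rank(process_group)
print(offsets[rank])  # ((5, 7),)
```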
Helper functions for dealing with PyTorch’s DTensor.
- olmo_core.distributed.tensors.dtensor_utils.get_local_shape_and_global_offset(dtensor, rank=None)
Like compute_local_shape_and_global_offset(), but acts directly on a DTensor instance.
- olmo_core.distributed.tensors.dtensor_utils.compute_local_shape_and_global_offset(global_shape, mesh, placements, rank=None)
Compute the local tensor shape and the global offsets into the original tensor of a DTensor on a given global rank (the current rank by default). This is useful for checkpointing purposes.
- Parameters:
global_shape (Union[Size, List[int], Tuple[int, ...]]) – The shape of the global unsharded tensor.
mesh (DeviceMesh) – The device mesh.
placements (Sequence[Placement]) – The placements of the DTensor.
rank (Optional[int], default: None) – The global rank to compute the local shape and global offsets for. If None, defaults to the current rank.
- Return type:
- Returns:
The local shape and global offset.
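For checkpointing, the returned pair identifies exactly which region of the global tensor a rank's local shard covers. The sketch below turns (local_shape, global_offset) into one slice per dimension; local_region is a hypothetical helper. With PyTorch, a loader could then, under these assumptions, copy a shard into a full-size buffer with something like `full[local_region(shape, offset)] = dtensor.to_local()`, though that line is not exercised here.

```python
from typing import Sequence, Tuple


def local_region(local_shape: Sequence[int], global_offset: Sequence[int]) -> Tuple[slice, ...]:
    """Return one slice per dimension selecting [offset, offset + size) in the global tensor."""
    if len(local_shape) != len(global_offset):
        raise ValueError("local_shape and global_offset must have the same length")
    return tuple(slice(o, o + s) for s, o in zip(local_shape, global_offset))


print(local_region([1, 4], [5, 0]))  # (slice(5, 6, None), slice(0, 4, None))
```

An empty shard (local size 0) yields an empty slice such as slice(2, 2), so writing it back is a no-op rather than an error.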
Example (2 hosts with 4 GPUs each):
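```python
from torch.distributed.device_mesh import DeviceMesh

# Below is a DeviceMesh with mesh_shape of (2, 4); requires an initialized 8-rank process group.
mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
```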
Let’s say we distribute a global_tensor of shape (8, 4) over the above DeviceMesh with placements of [Shard(0), Shard(0)]. The local shape and global offset will be as follows:
rank0 -- local_shape:[1, 4], global_offset:[0, 0]
rank1 -- local_shape:[1, 4], global_offset:[1, 0]
rank2 -- local_shape:[1, 4], global_offset:[2, 0]
rank3 -- local_shape:[1, 4], global_offset:[3, 0]
rank4 -- local_shape:[1, 4], global_offset:[4, 0]
rank5 -- local_shape:[1, 4], global_offset:[5, 0]
rank6 -- local_shape:[1, 4], global_offset:[6, 0]
rank7 -- local_shape:[1, 4], global_offset:[7, 0]
Let’s say we distribute a global_tensor of shape (2,) over the above DeviceMesh with placements of [Shard(0)]. Not every rank will have a non-empty local tensor. The local shape and global offset will be as follows:
rank0 -- local_shape:[1,], global_offset:[0,]
rank1 -- local_shape:[1,], global_offset:[1,]
rank2 -- local_shape:[0,], global_offset:[2,]
rank3 -- local_shape:[0,], global_offset:[2,]
rank4 -- local_shape:[0,], global_offset:[2,]
rank5 -- local_shape:[0,], global_offset:[2,]
rank6 -- local_shape:[0,], global_offset:[2,]
rank7 -- local_shape:[0,], global_offset:[2,]
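The arithmetic behind these tables can be sketched in plain Python. The function below is a hypothetical re-derivation, not the library's implementation: it encodes placements as shard_dims (the tensor dim each mesh dim shards, or None for replication), assumes torch.chunk-style splitting (ceiling-sized chunks, with trailing ranks getting empty shards whose offset is clamped to the dimension size), and applies mesh dimensions in order. The second example's results are reproduced by treating the eight ranks as a one-dimensional mesh.

```python
from typing import List, Optional, Sequence, Tuple


def sketch_local_shape_and_offset(
    global_shape: Sequence[int],
    mesh_shape: Sequence[int],
    shard_dims: Sequence[Optional[int]],
    rank: int,
) -> Tuple[List[int], List[int]]:
    """shard_dims[i] is the tensor dim sharded over mesh dim i (None means replicated)."""
    # Row-major coordinate of `rank` in the device mesh.
    coords: List[int] = []
    remaining = rank
    for size in reversed(mesh_shape):
        coords.append(remaining % size)
        remaining //= size
    coords.reverse()

    local_shape = list(global_shape)
    global_offset = [0] * len(global_shape)
    for mesh_dim, tensor_dim in enumerate(shard_dims):
        if tensor_dim is None:
            continue
        num_chunks = mesh_shape[mesh_dim]
        dim_size = local_shape[tensor_dim]
        chunk = -(-dim_size // num_chunks)  # ceiling division, like torch.chunk
        # Ranks past the end of the data get an empty shard whose start is clamped to the end.
        start = min(coords[mesh_dim] * chunk, dim_size)
        end = min(start + chunk, dim_size)
        local_shape[tensor_dim] = end - start
        global_offset[tensor_dim] += start
    return local_shape, global_offset


# First example: shape (8, 4) on a (2, 4) mesh with [Shard(0), Shard(0)].
for r in range(8):
    print(r, sketch_local_shape_and_offset((8, 4), (2, 4), (0, 0), r))  # ([1, 4], [r, 0])

# Second example, treated as a 1-D mesh of 8 ranks sharding dim 0.
print(sketch_local_shape_and_offset((2,), (8,), (0,), 5))  # ([0], [2])
```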