[docs]classDPMeshDimName(StrEnum):""" ``DeviceMesh`` dimension names for data parallelism. """replicate="dp_replicate"""" The device mesh dimension over which the model is replicated. """shard="dp_shard"""" The device mesh dimension over which the model is sharded. """
def get_replicate_and_shard_degree(self, dp_world_size: int) -> Tuple[int, int]:
    """
    Resolve the ``(num_replicas, shard_degree)`` pair for the data parallel mesh.

    Defaults to one replica per node, with the shard degree set to the number of
    gpus per node. If only one of ``self.num_replicas`` / ``self.shard_degree`` is
    set, the other is derived from ``dp_world_size``. If both are set, they are
    validated against ``dp_world_size``.

    :param dp_world_size: The data parallel world size.

    :return: A tuple of (num_replicas, shard_degree) whose product equals
        ``dp_world_size``.

    :raises OLMoConfigurationError: If ``dp_world_size`` is not divisible by the
        configured (or default per-node) number of replicas or shard degree, or if
        both are configured and their product does not equal ``dp_world_size``.
    """
    if self.num_replicas is None and self.shard_degree is None:
        # Query the node count once; validate divisibility so we never return a
        # truncated (possibly zero) shard degree.
        num_nodes = get_num_nodes()
        return _check_num_replicas(num_nodes, dp_world_size), dp_world_size // num_nodes
    elif self.num_replicas is not None and self.shard_degree is not None:
        num_replicas = _check_num_replicas(self.num_replicas, dp_world_size)
        shard_degree = _check_shard_degree(self.shard_degree, dp_world_size)
        # Each value dividing the world size individually is not enough: the
        # (replicate, shard) mesh must cover exactly all data parallel ranks.
        if num_replicas * shard_degree != dp_world_size:
            raise OLMoConfigurationError(
                f"'num_replicas' ({num_replicas}) times 'shard_degree' ({shard_degree}) "
                f"must equal the data parallel world size ({dp_world_size})"
            )
        return num_replicas, shard_degree
    elif self.num_replicas is not None:
        return (
            _check_num_replicas(self.num_replicas, dp_world_size),
            dp_world_size // self.num_replicas,
        )
    else:
        # Only shard_degree is set (asserted for type narrowing).
        assert self.shard_degree is not None
        return dp_world_size // self.shard_degree, _check_shard_degree(
            self.shard_degree, dp_world_size
        )
def_check_num_replicas(num_replicas:int,dp_world_size:int)->int:ifdp_world_size%num_replicas!=0:raiseOLMoConfigurationError(f"data parallel world size ({dp_world_size}) must be "f"divisible by 'num_replicas' ({num_replicas})")returnnum_replicasdef_check_shard_degree(shard_degree:int,dp_world_size:int)->int:ifdp_world_size%shard_degree!=0:raiseOLMoConfigurationError(f"data parallel world size ({dp_world_size}) must be "f"divisible by 'shard_degree' ({shard_degree})")returnshard_degree