Source code for olmo_core.distributed.parallel.tensor_parallel

import logging
from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Placement, Shard, distribute_module
from torch.distributed.tensor.parallel import SequenceParallel as _SequenceParallel

from olmo_core.config import Config

log = logging.getLogger(__name__)


@dataclass
class TensorParallelConfig(Config):
    """
    Settings that control tensor parallelism (TP).
    """

    degree: int
    """
    The TP degree.
    """

    enable_async: bool = False
    """
    Turn on experimental async tensor parallelism.
    """

    def maybe_enable_async_tp(self, tp_mesh: DeviceMesh):
        """
        Switch on async tensor parallelism if :data:`enable_async` is set.

        Does nothing when :data:`enable_async` is ``False``. Otherwise it logs
        a message, sets the ``_micro_pipeline_tp`` flag on
        ``torch._inductor.config`` and enables symmetric memory for the
        process group backing ``tp_mesh``.

        :param tp_mesh: The device mesh whose process group is used for TP.
        """
        if not self.enable_async:
            return

        log.info("Enabling async tensor parallel")

        # Imported lazily so the private torch module is only touched when
        # async TP is actually requested.
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True  # type: ignore

        tp_group_name = tp_mesh.get_group().group_name
        enable_symm_mem_for_group(tp_group_name)
class SequenceParallel(_SequenceParallel): def __init__( self, *, sequence_dim: int = 1, use_local_output: bool = False, output_layouts: Optional[Placement] = None, ): super().__init__(sequence_dim=sequence_dim, use_local_output=use_local_output) self.output_layouts = (output_layouts or Shard(sequence_dim),) @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): del mod, device_mesh if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=True) return outputs.to_local() if use_local_output else outputs def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, self._replicate_module_fn, partial(self._prepare_input_fn, self.sequence_sharding), # type: ignore partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), # type: ignore )