Source code for olmo_core.float8.ao

from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, TypeVar

from ..config import Config, DType, StrEnum

if TYPE_CHECKING:
    from torchao.float8.config import (
        CastConfig,
        Float8GemmConfig,
        Float8LinearConfig,
        Float8LinearRecipeName,
        ScalingGranularity,
        ScalingType,
    )
    from torchao.prototype.mx_formats.config import (
        MXFP8Dim1CastKernelChoice,
        MXLinearConfig,
        ScaleCalculationMode,
    )
    from torchao.quantization.quantize_.common.kernel_preference import KernelPreference


T = TypeVar("T")


class _AOTypePlaceholder(Generic[T]):
    """
    Mixin for local stand-ins of torchao types.

    A subclass names the torchao class it stands in for through :attr:`ao_type`;
    :meth:`to_ao_type` then builds (for configs) or looks up (for string enums)
    the corresponding torchao object.
    """

    @property
    @abstractmethod
    def ao_type(self) -> Type[T]:
        """The torchao class this placeholder stands in for."""
        raise NotImplementedError

    def to_ao_type(self) -> T:
        """
        Convert this placeholder into an instance/member of :attr:`ao_type`.

        :raises ValueError: If ``self`` is a string enum whose value matches no
            member of :attr:`ao_type`.
        :raises NotImplementedError: If ``self`` is neither a ``Config`` nor a
            ``StrEnum``.
        """

        def convert(value: Any) -> Any:
            # Nested placeholders are converted recursively; ``DType`` values go
            # through ``as_pt()``; anything else is passed along untouched.
            if isinstance(value, _AOTypePlaceholder):
                return value.to_ao_type()
            if isinstance(value, DType):
                return value.as_pt()
            return value

        if isinstance(self, Config):
            # Non-recursive dump requested with ``exclude_none=True``, so unset
            # fields are not forwarded to the torchao constructor.
            fields = self.as_dict(exclude_none=True, recurse=False)
            ao_kwargs: Dict[str, Any] = {name: convert(value) for name, value in fields.items()}
            return self.ao_type(**ao_kwargs)

        if isinstance(self, StrEnum):
            # Match members by their value rather than by name.
            for member in self.ao_type:  # type: ignore
                if member.value == self:
                    return member
            raise ValueError(self)

        raise NotImplementedError


class AOScalingType(_AOTypePlaceholder["ScalingType"], StrEnum):
    """String-valued placeholder for torchao's ``ScalingType`` enum."""

    dynamic = "dynamic"
    disabled = "disabled"

    @property
    def ao_type(self) -> Type["ScalingType"]:
        """Return torchao's ``ScalingType`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.float8.config as ao_config

        return ao_config.ScalingType


class AOScalingGranularity(_AOTypePlaceholder["ScalingGranularity"], StrEnum):
    """String-valued placeholder for torchao's ``ScalingGranularity`` enum."""

    tensorwise = "tensorwise"
    axiswise = "axiswise"

    @property
    def ao_type(self) -> Type["ScalingGranularity"]:
        """Return torchao's ``ScalingGranularity`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.float8.config as ao_config

        return ao_config.ScalingGranularity


@dataclass
class AOCastConfig(Config, _AOTypePlaceholder["CastConfig"]):
    """
    Placeholder for torchao's ``CastConfig``.

    All fields default to ``None``.
    """

    # How scaling is applied for the cast.
    scaling_type: Optional[AOScalingType] = None
    # Granularity at which scales are computed.
    scaling_granularity: Optional[AOScalingGranularity] = None
    # Dtype the tensor is cast to.
    target_dtype: Optional[DType] = None

    @property
    def ao_type(self) -> Type["CastConfig"]:
        """Return torchao's ``CastConfig`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.float8.config as ao_config

        return ao_config.CastConfig


@dataclass
class AOFloat8GemmConfig(Config, _AOTypePlaceholder["Float8GemmConfig"]):
    """Placeholder for torchao's ``Float8GemmConfig``."""

    # Defaults to ``False`` (not ``None``).
    use_fast_accum: Optional[bool] = False

    @property
    def ao_type(self) -> Type["Float8GemmConfig"]:
        """Return torchao's ``Float8GemmConfig`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.float8.config as ao_config

        return ao_config.Float8GemmConfig


class AOFloat8LinearRecipe(_AOTypePlaceholder["Float8LinearRecipeName"], StrEnum):
    """String-valued placeholder for torchao's ``Float8LinearRecipeName`` enum."""

    tensorwise = "tensorwise"
    rowwise = "rowwise"
    rowwise_with_gw_hp = "rowwise_with_gw_hp"

    @property
    def ao_type(self) -> Type["Float8LinearRecipeName"]:
        """Return torchao's ``Float8LinearRecipeName`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.float8.config as ao_config

        return ao_config.Float8LinearRecipeName
class AOKernelPreference(_AOTypePlaceholder["KernelPreference"], StrEnum):
    """String-valued placeholder for torchao's ``KernelPreference`` enum."""

    emulated = "emulated"
    auto = "auto"
    cuda = "cuda"
    torch = "torch"

    @property
    def ao_type(self) -> Type["KernelPreference"]:
        """Return torchao's ``KernelPreference`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.quantization.quantize_.common.kernel_preference as ao_kernel_preference

        return ao_kernel_preference.KernelPreference


class AOMXFP8Dim1CastKernelChoice(_AOTypePlaceholder["MXFP8Dim1CastKernelChoice"], StrEnum):
    """String-valued placeholder for torchao's ``MXFP8Dim1CastKernelChoice`` enum."""

    torch = "torch"
    cuda = "cuda"
    triton = "triton"

    @property
    def ao_type(self) -> Type["MXFP8Dim1CastKernelChoice"]:
        """Return torchao's ``MXFP8Dim1CastKernelChoice`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.prototype.mx_formats.config as ao_mx_config

        return ao_mx_config.MXFP8Dim1CastKernelChoice


class AOScaleCalculationMode(_AOTypePlaceholder["ScaleCalculationMode"], StrEnum):
    """String-valued placeholder for torchao's ``ScaleCalculationMode`` enum."""

    floor = "floor"
    rceil = "rceil"
    ceil = "ceil"
    even = "even"

    @property
    def ao_type(self) -> Type["ScaleCalculationMode"]:
        """Return torchao's ``ScaleCalculationMode`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        import torchao.prototype.mx_formats.config as ao_mx_config

        return ao_mx_config.ScaleCalculationMode
@dataclass
class AOFloat8LinearConfig(Config, _AOTypePlaceholder["Float8LinearConfig"]):
    """
    This matches the config from torchao.

    Every field defaults to ``None``.
    """

    cast_config_input: Optional[AOCastConfig] = None
    cast_config_input_for_grad_weight: Optional[AOCastConfig] = None
    cast_config_weight: Optional[AOCastConfig] = None
    cast_config_weight_for_grad_input: Optional[AOCastConfig] = None
    cast_config_grad_output: Optional[AOCastConfig] = None
    cast_config_grad_output_for_grad_weight: Optional[AOCastConfig] = None
    gemm_config_output: Optional[AOFloat8GemmConfig] = None
    gemm_config_grad_input: Optional[AOFloat8GemmConfig] = None
    gemm_config_grad_weight: Optional[AOFloat8GemmConfig] = None
    enable_fsdp_float8_all_gather: Optional[bool] = None
    pad_inner_dim: Optional[bool] = None
    emulate: Optional[bool] = None
    force_recompute_fp8_weight_in_bwd: Optional[bool] = None  # deprecated, no effect
    round_scales_to_power_of_2: Optional[bool] = None

    @staticmethod
    def recommended(**kwargs: Any) -> "AOFloat8LinearConfig":
        """
        Build a config with the recommended flags switched on.

        The presets are ``enable_fsdp_float8_all_gather``,
        ``force_recompute_fp8_weight_in_bwd`` and ``round_scales_to_power_of_2``,
        all ``True``.

        :param kwargs: Additional field values. A key that matches one of the
            presets overrides it instead of raising ``TypeError`` for a
            duplicated keyword argument.

        :returns: The resulting config.
        """
        options: Dict[str, Any] = {
            "enable_fsdp_float8_all_gather": True,
            # Marked deprecated on the field above; still set here so the
            # preset matches what earlier callers received.
            "force_recompute_fp8_weight_in_bwd": True,
            "round_scales_to_power_of_2": True,
        }
        # Caller-supplied values take precedence over the presets.
        options.update(kwargs)
        return AOFloat8LinearConfig(**options)

    @property
    def ao_type(self) -> Type["Float8LinearConfig"]:
        """Return torchao's ``Float8LinearConfig`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested.
        from torchao.float8.config import Float8LinearConfig

        return Float8LinearConfig
@dataclass
class AOMXLinearConfig(Config, _AOTypePlaceholder["MXLinearConfig"]):
    """
    This matches the config from torchao. Applies to MXFP8 and MXFP4 formats.
    https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/config.py#L106

    Useful reference for MXFP8 training:
    https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html
    """

    block_size: Optional[int] = None
    """block size, defaults to 32 if not specified"""

    elem_dtype: Optional[DType] = None
    """element dtype, used for activations, weights and gradients, defaults to e4m3fn if not specified"""

    elem_dtype_weight_override: Optional[DType] = None
    """optional element dtype override for weights"""

    elem_dtype_grad_output_override: Optional[DType] = None
    """
    optional element dtype override for gradients.
    note that e4m3 is thought to be fine here because of the block-wise nature of MXFP8.
    """

    kernel_preference: Optional[AOKernelPreference] = None
    """if the preferred kernel is not supported on the given hardware an exception will be thrown"""

    mxfp8_cast_kernel_choice: Optional[AOMXFP8Dim1CastKernelChoice] = None
    """
    which kernel to use for the mx fp8 cast along dim1 (dim0 is always torch).
    torch is slow. cuda is fastest. triton only supports "floor" scale calculation mode.
    """

    scale_calculation_mode: Optional[AOScaleCalculationMode] = None
    """
    how to calculate the mx block scaling factors.

    * floor [default]: straightforward method but most prone to overflow / bad for gradient calculation (don't use)
    * rceil (ratio ceil): computes the tightest valid ceiling. has good support from nvidia.
    * ceil: similar to floor but avoids overflow; prone to underflow / precision loss / quant to zero.
    * even: best choice from a mathematical standpoint. unbiased error distribution.
      but does not yet work with torch.compile.
    """

    @classmethod
    def mxfp8_cublas_rceil(cls, **kwargs: Any) -> "AOMXLinearConfig":
        """
        standard mxfp8 recipe predefined in torchao

        Selects the ``cuda`` dim1 cast kernel and the ``rceil`` scale
        calculation mode.

        :param kwargs: Values for the remaining fields. Passing
            ``mxfp8_cast_kernel_choice`` or ``scale_calculation_mode`` here
            raises ``TypeError``, since the recipe fixes both.

        :returns: An instance of ``cls``, so subclasses get their own type back.
        """
        # Use ``cls`` rather than the hard-coded class name so that calling the
        # recipe on a subclass does not silently return the base class.
        return cls(
            mxfp8_cast_kernel_choice=AOMXFP8Dim1CastKernelChoice.cuda,
            scale_calculation_mode=AOScaleCalculationMode.rceil,
            **kwargs,
        )

    @property
    def ao_type(self) -> Type["MXLinearConfig"]:
        """Return torchao's ``MXLinearConfig`` class."""
        # Imported here, not at module level, so torchao is only loaded when
        # the real type is actually requested. The ``.config`` submodule is the
        # same path this file uses for its type-checking import of this class.
        from torchao.prototype.mx_formats.config import MXLinearConfig

        return MXLinearConfig