float8

Utilities for training in Float8 via torchao.

class olmo_core.float8.Float8Config(ao=None, ao_recipe=None, ao_mx=None, modules_to_ignore=None, enabled=True)[source]

Bases: Config

A configuration class for specifying Float8 options.

Parameters:
modules_to_ignore: Optional[List[str]] = None

A list of fully-qualified module names to skip during Float8 conversion.
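As a rough illustration of what "fully-qualified" means here, the sketch below walks a toy module tree and builds dotted names of the kind that torch.nn.Module.named_modules() reports. The tree, the layer names (w_q, w_out, lm_head, and so on), and the exact-match filtering are all assumptions made for the example; they are not taken from olmo_core, whose matching rules may differ.

```python
def iter_qualified_names(tree, prefix=""):
    """Yield (dotted_name, child) for every entry in a nested dict of modules."""
    for name, child in tree.items():
        fqn = f"{prefix}.{name}" if prefix else name
        yield fqn, child
        if isinstance(child, dict):
            # Recurse into sub-modules, extending the dotted prefix.
            yield from iter_qualified_names(child, prefix=fqn)


# Hypothetical model layout: nested dicts are containers, "Linear" marks a linear layer.
toy_model = {
    "blocks": {
        "0": {
            "attention": {"w_q": "Linear", "w_out": "Linear"},
            "feed_forward": {"w1": "Linear"},
        },
    },
    "lm_head": {"w_out": "Linear"},
}

# Assumption: names are compared by exact match against the full dotted path.
modules_to_ignore = ["lm_head.w_out"]

linear_names = [fqn for fqn, child in iter_qualified_names(toy_model) if child == "Linear"]
to_convert = [fqn for fqn in linear_names if fqn not in modules_to_ignore]

print("linear layers:", linear_names)
print("would convert:", to_convert)
```

Under that assumption, a bare name such as w_out would not be enough to exclude the output head, since two different layers in the toy tree end in w_out; the full path lm_head.w_out identifies exactly one of them.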

validate()[source]

Validate fields in self. This may modify self in-place.

apply_float8_linear(model, *, modules_to_ignore=None)[source]

This method converts the linear layers of model to Float8Linear or MXLinear.

Warning

This will mutate the model in place.

Warning

This should be called before compiling the model, applying activation checkpointing, or wrapping it with FSDP(2) or any other parallel wrapper.

class olmo_core.float8.AOFloat8LinearConfig(cast_config_input=None, cast_config_input_for_grad_weight=None, cast_config_weight=None, cast_config_weight_for_grad_input=None, cast_config_grad_output=None, cast_config_grad_output_for_grad_weight=None, gemm_config_output=None, gemm_config_grad_input=None, gemm_config_grad_weight=None, enable_fsdp_float8_all_gather=None, pad_inner_dim=None, emulate=None, force_recompute_fp8_weight_in_bwd=None, round_scales_to_power_of_2=None)[source]

Bases: Config, _AOTypePlaceholder[Float8LinearConfig]

This matches the config from torchao.

class olmo_core.float8.AOFloat8LinearRecipe(value)[source]

Bases: _AOTypePlaceholder[Float8LinearRecipeName], StrEnum

An enumeration.

class olmo_core.float8.AOMXLinearConfig(block_size=None, elem_dtype=None, elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, kernel_preference=None, mxfp8_cast_kernel_choice=None, scale_calculation_mode=None)[source]

Bases: Config, _AOTypePlaceholder[MXLinearConfig]

This matches the config from torchao and applies to the MXFP8 and MXFP4 formats. See https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/config.py#L106

Useful reference for MXFP8 training: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html

block_size: Optional[int] = None

Block size. Defaults to 32 if not specified.

elem_dtype: Optional[DType] = None

Element dtype, used for activations, weights, and gradients. Defaults to e4m3fn if not specified.

elem_dtype_weight_override: Optional[DType] = None

Optional element dtype override for weights.

elem_dtype_grad_output_override: Optional[DType] = None

Optional element dtype override for gradients. Note that e4m3 is thought to be fine here because of the block-wise nature of MXFP8.
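To make the block-wise argument concrete, here is a minimal pure-Python sketch, not torchao's implementation, that computes one power-of-two scale per block of 32 values and compares it with a single tensor-wide scale. The helper block_scales, the sample gradient values, and the choice of rceil-style rounding are assumptions made for illustration; the e4m3 constants (largest value 448, smallest subnormal 2**-9) are properties of the format.

```python
import math

E4M3_MAX = 448.0                 # largest finite e4m3fn value
E4M3_MIN_SUBNORMAL = 2.0 ** -9   # smallest positive e4m3fn value


def block_scales(values, block_size=32):
    """Return one power-of-two scale per contiguous block (rceil-style rounding)."""
    scales = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        max_abs = max(abs(v) for v in block)  # assumes the block is not all zeros
        scales.append(2.0 ** math.ceil(math.log2(max_abs / E4M3_MAX)))
    return scales


# Hypothetical gradients: 64 tiny values with one large outlier in the second block.
grads = [1e-5] * 64
grads[40] = 10.0

per_block = block_scales(grads, block_size=32)
per_tensor = block_scales(grads, block_size=len(grads))

print("per-block scales: ", per_block)
print("per-tensor scale: ", per_tensor)

# A tiny value from the first block, after dividing by each candidate scale.
small = grads[0]
print("scaled with its block scale: ", small / per_block[0])
print("scaled with the tensor scale:", small / per_tensor[0])
print("below e4m3 subnormal range with tensor scale:",
      small / per_tensor[0] < E4M3_MIN_SUBNORMAL)
```

In this toy case the outlier only inflates the scale of its own block. The first block keeps a scale matched to its small values, whereas a single tensor-wide scale would push those values below the smallest magnitude e4m3 can represent.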

kernel_preference: Optional[AOKernelPreference] = None

The preferred kernel. If it is not supported on the given hardware, an exception is raised.

mxfp8_cast_kernel_choice: Optional[AOMXFP8Dim1CastKernelChoice] = None

Which kernel to use for the MXFP8 cast along dim1 (the dim0 cast always uses torch). torch is slow and cuda is fastest. triton only supports the “floor” scale calculation mode.

scale_calculation_mode: Optional[AOScaleCalculationMode] = None

How to calculate the MX block scaling factors:

* floor [default]: straightforward method, but the most prone to overflow and bad for gradient calculation (do not use).
* rceil (ratio ceil): computes the tightest valid ceiling. Has good support from NVIDIA.
* ceil: similar to floor but avoids overflow. Prone to underflow, precision loss, and quantization to zero.
* even: best choice from a mathematical standpoint, with an unbiased error distribution, but does not yet work with torch.compile.
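The difference between floor, ceil, and rceil can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not torchao's kernel code: it assumes e4m3 elements (largest value 448, top exponent 8), works on a single block maximum, and leaves out the even mode. The helper names scale_exponent and scaled_max are hypothetical.

```python
import math

E4M3_MAX = 448.0   # largest finite e4m3fn value (1.75 * 2**8)
E4M3_EMAX = 8      # exponent of the largest e4m3fn binade


def scale_exponent(max_abs: float, mode: str) -> int:
    """Return the power-of-two exponent of the shared block scale (max_abs > 0)."""
    if mode == "floor":
        return math.floor(math.log2(max_abs)) - E4M3_EMAX
    if mode == "ceil":
        return math.ceil(math.log2(max_abs)) - E4M3_EMAX
    if mode == "rceil":
        # Round up the ratio between the block maximum and the largest e4m3 value.
        return math.ceil(math.log2(max_abs / E4M3_MAX))
    raise ValueError(f"unsupported mode: {mode}")


def scaled_max(max_abs: float, mode: str) -> float:
    """Block maximum after dividing by the shared scale."""
    return max_abs / 2.0 ** scale_exponent(max_abs, mode)


for max_abs in (300.0, 500.0):
    for mode in ("floor", "ceil", "rceil"):
        value = scaled_max(max_abs, mode)
        status = "overflows" if value > E4M3_MAX else "fits"
        print(f"max_abs={max_abs:6.1f} mode={mode:5s} scaled_max={value:6.1f} ({status})")
```

With a block maximum of 500, floor leaves the scaled value at 500, which exceeds 448, while ceil and rceil both halve it to 250. With a block maximum of 300, ceil still halves the value to 150 even though 300 already fits, whereas rceil keeps it at 300, which is the sense in which rceil is the tighter ceiling.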

classmethod mxfp8_cublas_rceil(**kwargs)[source]

Standard MXFP8 recipe predefined in torchao.

Return type:

AOMXLinearConfig