float8
Utilities for training in Float8 via torchao.
- class olmo_core.float8.Float8Config(ao=None, ao_recipe=None, ao_mx=None, modules_to_ignore=None, enabled=True)
Bases: Config
A configuration class for specifying Float8 options.
- Parameters:
  - ao (Optional[AOFloat8LinearConfig], default: None) – A torchao Float8Linear configuration.
  - ao_recipe (Optional[AOFloat8LinearRecipe], default: None) – Alternatively, you can specify a recipe name from torchao.
  - ao_mx (Optional[AOMXLinearConfig], default: None) – A torchao MXLinearConfig for MX formats (MXFP8/MXFP4).
  - enabled (bool, default: True) – If False, this will be a no-op.
- modules_to_ignore: Optional[List[str]] = None
A set of fully-qualified module names to ignore for Float8 conversion.
- apply_float8_linear(model, *, modules_to_ignore=None)
This method converts the linear layers of model to Float8Linear or MXLinear.
Warning
This will mutate the model in place.
Warning
This should be called before compiling the model, applying activation checkpointing, or wrapping it with FSDP(2) or any other parallel wrapper.
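The snippet below is a minimal, self-contained sketch of what "convert in place while honouring modules_to_ignore" means. It uses a toy Module/Linear hierarchy instead of torch.nn, a hypothetical Float8LinearStub instead of torchao's Float8Linear, and assumes that modules_to_ignore is matched exactly against PyTorch-style dotted names (e.g. "blocks.0.w1"). It is not olmo_core's implementation.

```python
from typing import Dict, List, Optional


class Module:
    """Toy stand-in for torch.nn.Module: it only tracks named child modules."""

    def __init__(self, **children: "Module") -> None:
        self.children: Dict[str, "Module"] = dict(children)


class Linear(Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features


class Float8LinearStub(Linear):
    """Hypothetical placeholder for a float8 linear layer (no FP8 math here)."""

    @classmethod
    def from_linear(cls, linear: Linear) -> "Float8LinearStub":
        return cls(linear.in_features, linear.out_features)


def convert_linears_in_place(
    module: Module, modules_to_ignore: Optional[List[str]] = None, prefix: str = ""
) -> None:
    """Swap plain Linear children for Float8LinearStub, mutating `module` in place.

    Children are addressed by fully-qualified dotted names (e.g. "blocks.0.w1"),
    matched exactly against `modules_to_ignore` (an assumption of this sketch).
    """
    ignore = set(modules_to_ignore or [])
    for name, child in list(module.children.items()):
        fqn = f"{prefix}.{name}" if prefix else name
        if type(child) is Linear:
            if fqn not in ignore:
                module.children[name] = Float8LinearStub.from_linear(child)
        else:
            convert_linears_in_place(child, modules_to_ignore, fqn)


model = Module(
    blocks=Module(**{"0": Module(w1=Linear(8, 32), w2=Linear(32, 8))}),
    lm_head=Linear(8, 100),
)
same_model = model  # a second reference to the same object
convert_linears_in_place(model, modules_to_ignore=["lm_head"])

w1 = same_model.children["blocks"].children["0"].children["w1"]
print(type(w1).__name__)                         # Float8LinearStub
print(type(model.children["lm_head"]).__name__)  # Linear
# Per the warnings above, torch.compile, activation checkpointing and FSDP
# wrapping would only be applied after this point.
```

Because the swap happens on the existing object, every reference to the model (here same_model) sees the converted layers, which is why the conversion has to precede any wrapping or compilation.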
- class olmo_core.float8.AOFloat8LinearConfig(cast_config_input=None, cast_config_input_for_grad_weight=None, cast_config_weight=None, cast_config_weight_for_grad_input=None, cast_config_grad_output=None, cast_config_grad_output_for_grad_weight=None, gemm_config_output=None, gemm_config_grad_input=None, gemm_config_grad_weight=None, enable_fsdp_float8_all_gather=None, pad_inner_dim=None, emulate=None, force_recompute_fp8_weight_in_bwd=None, round_scales_to_power_of_2=None)
Bases: Config, _AOTypePlaceholder[Float8LinearConfig]
This matches the Float8LinearConfig from torchao.
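As a rough illustration of what a float8 linear layer does to its inputs, the sketch below emulates dynamic per-tensor ("tensorwise") scaling in pure Python, assuming that scaling mode: the tensor's absolute maximum is mapped onto the largest float8_e4m3fn value (448), and each element is then rounded to the nearest e4m3fn value. The rounding helper is a simplified emulation written for this example, not torchao code; real training does these steps in fused GPU kernels.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(v: float) -> float:
    """Emulate a saturating cast to float8_e4m3fn (round to nearest, ties to even)."""
    if v == 0.0 or math.isnan(v):
        return v
    if math.isinf(v):
        return math.copysign(E4M3_MAX, v)  # saturate instead of overflowing
    a = abs(v)
    _, e = math.frexp(a)        # a lies in [2**(e - 1), 2**e)
    exp = max(e - 1, -6)        # -6 is the smallest normal exponent; below it, subnormals
    quantum = 2.0 ** (exp - 3)  # 3 mantissa bits -> 8 representable steps per binade
    r = min(round(a / quantum) * quantum, E4M3_MAX)  # Python's round() ties to even
    return math.copysign(r, v)


def quantize_tensorwise(x: List[float]) -> Tuple[List[float], float]:
    """Dynamic per-tensor scaling: map amax(|x|) onto E4M3_MAX, then cast."""
    amax = max(abs(v) for v in x)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in x], scale


x = [1.0, -2.0, 0.5, 3.0]
x_fp8, scale = quantize_tensorwise(x)
x_deq = [q / scale for q in x_fp8]  # dequantized values the matmul effectively uses
print(x_fp8)  # [144.0, -288.0, 72.0, 448.0]
```

With three mantissa bits, every dequantized value stays within 1/16 of the original in relative terms; options such as round_scales_to_power_of_2 only change how the scale itself is chosen.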
- class olmo_core.float8.AOFloat8LinearRecipe(value)
Bases: _AOTypePlaceholder[Float8LinearRecipeName], StrEnum
An enumeration of float8 recipe names, mirroring torchao's Float8LinearRecipeName.
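Recipes differ mainly in scale granularity. Assuming the names mirror torchao's (recent versions include, for example, "tensorwise" and "rowwise"; check the installed version for the exact set), the sketch below contrasts one scale per tensor with one scale per row on a small weight that has a single outlier. The helper functions are hypothetical illustrations, not torchao or olmo_core APIs.

```python
from typing import List

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def _scale_for(amax: float) -> float:
    # Map amax onto E4M3_MAX; an all-zero slice gets a neutral scale of 1.
    return E4M3_MAX / amax if amax > 0 else 1.0


def tensorwise_scale(rows: List[List[float]]) -> float:
    """One scale for the whole tensor, set by the global amax."""
    return _scale_for(max(abs(v) for row in rows for v in row))


def rowwise_scales(rows: List[List[float]]) -> List[float]:
    """One scale per row, each set by that row's own amax."""
    return [_scale_for(max(abs(v) for v in row)) for row in rows]


# Row 0 contains an outlier; row 1 contains only small values.
w = [[1000.0, 1.0, -2.0], [0.01, -0.02, 0.03]]
t = tensorwise_scale(w)
r = rowwise_scales(w)
print(t)  # 0.448: row 1 scales to at most ~0.013, below e4m3fn's smallest normal (2**-6)
print(r)  # row 1 gets its own, much larger scale (448 / 0.03)
```

The trade-off is bookkeeping: rowwise scaling stores and applies one scale per row instead of a single scalar.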
- class olmo_core.float8.AOMXLinearConfig(block_size=None, elem_dtype=None, elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, kernel_preference=None, mxfp8_cast_kernel_choice=None, scale_calculation_mode=None)
Bases: Config, _AOTypePlaceholder[MXLinearConfig]
This matches the MXLinearConfig from torchao and applies to both the MXFP8 and MXFP4 formats. See https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/config.py#L106
Useful reference for MXFP8 training: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html
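In MX formats, each run of block_size consecutive elements (32 in the OCP Microscaling specification) shares one power-of-two (E8M0) scale. The sketch below computes those shared exponents with the specification's rule, floor(log2(amax)) minus the element format's largest exponent (8 for e4m3fn), which corresponds to the "floor" scale_calculation_mode. It omits the E8M0 exponent-range clamp, and the function name is hypothetical.

```python
import math
from typing import List

E4M3_EMAX = 8  # exponent of e4m3fn's largest normal value (448 = 1.75 * 2**8)


def mx_scale_exponents(x: List[float], block_size: int = 32) -> List[int]:
    """Shared power-of-two scale exponent for each block of `block_size` elements.

    Uses the OCP MX rule s = floor(log2(amax_block)) - E4M3_EMAX; elements are then
    stored as x / 2**s. The E8M0 exponent-range clamp is omitted for brevity.
    """
    exps = []
    for start in range(0, len(x), block_size):
        amax = max(abs(v) for v in x[start:start + block_size])
        # Any scale works for an all-zero block; 2**0 is an arbitrary choice here.
        exps.append(math.floor(math.log2(amax)) - E4M3_EMAX if amax > 0 else 0)
    return exps


# Two blocks with very different magnitudes.
x = [0.001 * (i + 1) for i in range(32)] + [50.0 * (i + 1) for i in range(32)]
exps = mx_scale_exponents(x)
print(exps)  # [-13, 2]
```

With a single tensor-wide scale (448 / 1600), every value in the first block would land in e4m3fn's subnormal range; per-block scales keep both blocks near the top of the representable range.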
- elem_dtype: Optional[DType] = None
Element dtype used for activations, weights, and gradients. Defaults to e4m3fn if not specified.
- elem_dtype_grad_output_override: Optional[DType] = None
Optional element dtype override for gradients. Note that e4m3 is thought to be fine here because MXFP8 scales each small block of elements separately, so the element format only has to cover the dynamic range within one block.
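The e4m3-versus-e5m2 trade-off becomes concrete when each format's largest finite value and smallest normal value are derived from its bit layout. Per-tensor FP8 recipes commonly pair e4m3 (more precision) for the forward pass with e5m2 (more range) for gradients, as described in the NVIDIA FP8 primer linked above; MX block scaling reduces the range each element must cover. The helpers below are an illustrative sketch.

```python
import math


def max_finite(exp_bits: int, man_bits: int, has_inf: bool) -> float:
    """Largest finite value of a small float format with an IEEE-style exponent bias."""
    bias = 2 ** (exp_bits - 1) - 1
    if has_inf:
        # IEEE-like (e5m2): the all-ones exponent is reserved for inf/NaN.
        top_exp, top_man = 2 ** exp_bits - 2, 2 ** man_bits - 1
    else:
        # "fn" formats (e4m3fn): no infinities; only all-ones exponent + mantissa is NaN.
        top_exp, top_man = 2 ** exp_bits - 1, 2 ** man_bits - 2
    return 2.0 ** (top_exp - bias) * (1 + top_man / 2 ** man_bits)


def min_normal(exp_bits: int) -> float:
    """Smallest positive normal value (exponent field equal to 1)."""
    bias = 2 ** (exp_bits - 1) - 1
    return 2.0 ** (1 - bias)


for name, e, m, inf in (("e4m3fn", 4, 3, False), ("e5m2", 5, 2, True)):
    hi, lo = max_finite(e, m, inf), min_normal(e)
    print(f"{name}: max={hi}, min normal={lo}, ~{math.log2(hi / lo):.1f} binades")
# e4m3fn tops out at 448 (~14.8 binades); e5m2 reaches 57344 (~29.8 binades)
```

A 32-element gradient block rarely spans anything close to 15 binades, which is why e4m3 with a per-block scale can be sufficient.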
- kernel_preference: Optional[AOKernelPreference] = None
If the preferred kernel is not supported on the given hardware, an exception will be thrown.
- mxfp8_cast_kernel_choice: Optional[AOMXFP8Dim1CastKernelChoice] = None
Which kernel to use for the MXFP8 cast along dim1 (the dim0 cast always uses torch). torch is slow, cuda is the fastest, and triton only supports the “floor” scale calculation mode.
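MX scales are shared along a matmul's reduction dimension, so a tensor used in two matmuls with different reduction dimensions (for example, an activation used in the forward pass and again for the weight gradient) has to be cast with blocks running in two directions. The sketch below computes per-block absolute maxima both ways on a tiny matrix, with block size 2 for readability. Reading "dim0" as blocks within a row and "dim1" as blocks down a column (the transposed layout) is an assumption about torchao's naming, not something this page states.

```python
from typing import List

Matrix = List[List[float]]


def block_amax_rows(m: Matrix, block: int) -> Matrix:
    """amax of each run of `block` consecutive elements within a row."""
    return [
        [max(abs(v) for v in row[j:j + block]) for j in range(0, len(row), block)]
        for row in m
    ]


def block_amax_cols(m: Matrix, block: int) -> Matrix:
    """Same, but runs go down each column (the row version applied to m transposed)."""
    transposed = [list(col) for col in zip(*m)]
    return block_amax_rows(transposed, block)


m = [
    [1.0, -2.0, 3.0, 4.0],
    [5.0, 6.0, -7.0, 8.0],
]
print(block_amax_rows(m, 2))  # [[2.0, 4.0], [6.0, 8.0]]
print(block_amax_cols(m, 2))  # [[5.0], [6.0], [7.0], [8.0]]
```

Each amax would then feed a scale rule such as those listed under scale_calculation_mode. The column-direction pass needs a strided or transposed read, which is the part the kernel choice above is meant to accelerate.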
- scale_calculation_mode: Optional[AOScaleCalculationMode] = None
How to calculate the MX block scaling factors:
  * floor [default]: the straightforward method, but the most prone to overflow and poorly suited to gradient calculation (don't use).
  * rceil (ratio ceil): computes the tightest valid ceiling. Has good support from NVIDIA.
  * ceil: similar to floor but avoids overflow; prone to underflow, precision loss, and quantization to zero.
  * even: the best choice from a mathematical standpoint, with an unbiased error distribution, but does not yet work with torch.compile.
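A worked example of why "floor" can overflow: the OCP MX rule derives the shared exponent from floor(log2(amax)) and ignores amax's mantissa, so a block whose amax falls between 448 and 512 keeps a scale of 2**0 and its largest element exceeds e4m3fn's maximum. The sketch compares that with a ratio-ceil rule, ceil(log2(amax / 448)), which is one reading of the "tightest valid ceiling" described for rceil. Neither function is torchao code, and ceil and even are not modelled.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
E4M3_EMAX = 8     # floor(log2(448))


def scale_exp_floor(amax: float) -> int:
    # OCP MX rule: uses only amax's exponent, ignoring its mantissa.
    return math.floor(math.log2(amax)) - E4M3_EMAX


def scale_exp_rceil(amax: float) -> int:
    # Ratio ceil (assumed reading): smallest s with amax / 2**s <= E4M3_MAX.
    return math.ceil(math.log2(amax / E4M3_MAX))


for amax in (400.0, 500.0):
    for name, rule in (("floor", scale_exp_floor), ("rceil", scale_exp_rceil)):
        s = rule(amax)
        scaled = amax / 2.0 ** s
        status = "overflows" if scaled > E4M3_MAX else "fits"
        print(f"amax={amax} {name}: scale=2**{s}, scaled amax={scaled} ({status})")
# amax=500.0 floor: scale=2**0, scaled amax=500.0 (overflows)
# amax=500.0 rceil: scale=2**1, scaled amax=250.0 (fits)
```

At amax = 400 both rules agree; the difference only appears when amax's mantissa pushes it past 448 within the same power of two.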
- elem_dtype: