"""Utilities for training in Float8 via `torchao <https://github.com/pytorch/ao>`_."""importloggingfromdataclassesimportdataclassfromtypingimportList,Optional,Setimporttorch.nnasnnfromolmo_core.utilsimporthas_compute_capabilityfrom..configimportConfigfrom..exceptionsimportOLMoConfigurationErrorfrom.aoimportAOFloat8LinearConfig,AOFloat8LinearRecipe,AOMXLinearConfig__all__=["Float8Config","AOFloat8LinearConfig","AOFloat8LinearRecipe","AOMXLinearConfig"]log=logging.getLogger(__name__)
@dataclass
class Float8Config(Config):
    """
    A configuration class for specifying Float8 options.

    :param ao: A torchao ``Float8Linear`` linear configuration.
    :param ao_recipe: Alternatively you can specify a recipe name from torchao.
    :param ao_mx: A torchao ``MXLinearConfig`` configuration for MX formats (MXFP8/MXFP4).
    :param modules_to_ignore: Fully-qualified names of modules to leave unconverted. These
        are combined with any names passed directly to :meth:`apply_float8_linear`.
    :param enabled: If ``False`` this will be a no-op.
    """

    ao: Optional[AOFloat8LinearConfig] = None
    ao_recipe: Optional[AOFloat8LinearRecipe] = None
    ao_mx: Optional[AOMXLinearConfig] = None
    modules_to_ignore: Optional[List[str]] = None
    """A list of fully-qualified module names to ignore for Float8 conversion."""
    enabled: bool = True

    def __post_init__(self):
        self.validate()

    def validate(self):
        """
        Check that the configuration is internally consistent.

        :raises OLMoConfigurationError: If more than one of ``ao``, ``ao_recipe``, and
            ``ao_mx`` is set.
        """
        config_count = sum(
            [self.ao is not None, self.ao_recipe is not None, self.ao_mx is not None]
        )
        if config_count > 1:
            raise OLMoConfigurationError(
                "'ao', 'ao_recipe', and 'ao_mx' configs are mutually exclusive"
            )

    def apply_float8_linear(
        self, model: nn.Module, *, modules_to_ignore: Optional[Set[str]] = None
    ):
        """
        This method converts the linear layers of ``model`` to ``Float8Linear`` or ``MXLinear``.

        .. warning::
            This will mutate the model in place.

        .. warning::
            This should be called before compiling the model, applying activation checkpointing,
            or wrapping it with FSDP(2) or any other parallel wrapper.

        :param model: The model to convert.
        :param modules_to_ignore: Fully-qualified names of modules to skip. These are merged
            with :data:`modules_to_ignore` from the config. Every name must match a module
            that the conversion filter was actually asked about.

        :raises OLMoConfigurationError: If the config is invalid (see :meth:`validate`) or if
            any name in ``modules_to_ignore`` was never encountered during conversion.
        :raises RuntimeError: If an MX format is requested on hardware older than SM100.
        """
        if not self.enabled:
            return

        self.validate()

        # Combine caller-provided names with the ones from the config. Normalizing to a set
        # also keeps the set arithmetic below working when a list is passed in.
        ignore: Set[str] = set(modules_to_ignore or ())
        if self.modules_to_ignore is not None:
            ignore.update(self.modules_to_ignore)

        # Names from `ignore` that the filters below actually saw; used to detect typos.
        ignored_modules_found: Set[str] = set()

        def module_filter_fn(m: nn.Module, fqn: str) -> bool:
            # Filter for `convert_to_float8_training`: True means "convert this module".
            if fqn in ignore:
                ignored_modules_found.add(fqn)
                return False
            # Linear layers must have all dimensions divisible by 16.
            if isinstance(m, nn.Linear) and any(d % 16 != 0 for d in m.weight.shape):
                return False
            return True

        def quantize_filter_fn(m: nn.Module, fqn: str) -> bool:
            # Filter for `quantize_`: True means "quantize this module". Unlike
            # `module_filter_fn`, anything that isn't an nn.Linear is rejected here, i.e.
            # modules are opted in rather than opted out.
            if fqn in ignore:
                ignored_modules_found.add(fqn)
                return False
            return isinstance(m, nn.Linear) and hasattr(m, "weight")

        # NOTE: there's a bug with `Float8Linear.from_float()` where it will override `requires_grad=False`
        # when `enable_fsdp_float8_all_gather=True`. So we have to reset frozen params after the fact.
        # https://github.com/pytorch/ao/issues/1871
        frozen_params: Set[str] = {n for n, p in model.named_parameters() if not p.requires_grad}

        if self.ao_mx is not None:
            # Handle MX format conversion.
            if not has_compute_capability(10, 0):
                raise RuntimeError("MX format training is only supported on SM100 or later")

            from torchao.quantization import quantize_ as ao_quantize_

            # NOTE(review): unlike the Float8 path, no dimension-divisibility check is applied
            # here — confirm torchao's MX conversion handles arbitrary linear shapes.
            ao_quantize_(
                model,
                config=self.ao_mx.to_ao_type(),
                filter_fn=quantize_filter_fn,
            )
        else:
            from torchao.float8 import Float8LinearConfig, convert_to_float8_training

            float8_linear_config: Float8LinearConfig
            if self.ao_recipe is not None:
                float8_linear_config = Float8LinearConfig.from_recipe_name(
                    self.ao_recipe.to_ao_type()
                )
            else:
                float8_linear_config = (
                    self.ao if self.ao is not None else AOFloat8LinearConfig()
                ).to_ao_type()

            # Mutates the model in place, replacing instances of nn.Linear with Float8Linear.
            convert_to_float8_training(
                model,
                config=float8_linear_config,
                module_filter_fn=module_filter_fn,
            )

        # Restore frozen params before any validation error below can abort, so the model is
        # never left with `requires_grad` silently flipped.
        for n in frozen_params:
            model.get_parameter(n).requires_grad = False

        missing = ignore - ignored_modules_found
        if missing:
            raise OLMoConfigurationError(
                f"invalid module name(s) in 'modules_to_ignore': {sorted(missing)}"
            )

        if ignored_modules_found:
            log.info("Ignored modules for Float8 conversion: %s", sorted(ignored_modules_found))