Source code for olmo_core.train.callbacks.hf_converter
"""Callback for converting the final checkpoint to HuggingFace format at the end of training."""importloggingfromdataclassesimportdataclassfromtypingimportAny,ClassVar,Dict,Optionalimporttorchimporttorch.distributed.checkpoint.state_dictasdist_cp_sdfromolmo_core.configimportDTypefromolmo_core.distributed.utilsimportbarrier,get_rankfrom.callbackimportCallbackfrom.checkpointerimportCheckpointerCallbacklog=logging.getLogger(__name__)
@dataclass
class HFConverterCallback(Callback):
    """
    Converts the final saved checkpoint to HuggingFace format at the end of a training job.

    This callback runs after training completes and uses
    :func:`olmo_core.nn.hf.convert_checkpoint_to_hf` to convert the final OLMo Core
    checkpoint to a HuggingFace-compatible format.

    .. note::
        This callback requires the ``transformers`` library to be installed.

    .. warning::
        In distributed training, ALL ranks must participate in this callback because
        gathering the full model state dict from FSDP requires collective operations.
        Only rank 0 performs the actual HF conversion and saving.
    """

    priority: ClassVar[int] = -1  # Run after checkpointer callback.

    enabled: bool = True
    """
    Whether this callback is enabled. Set to ``False`` to disable HF conversion.
    """

    output_folder: Optional[str] = None
    """
    The folder to save the HuggingFace checkpoint to. If not specified, defaults to
    ``{checkpoint_path}-hf`` where ``checkpoint_path`` is the final checkpoint path.
    """

    dtype: Optional[DType] = DType.bfloat16
    """
    The dtype to save the HuggingFace model weights as. Defaults to bfloat16.
    """

    validate: bool = False
    """
    Whether to validate the converted model against the original model.
    Validation loads both models and compares their outputs.
    """

    debug: bool = False
    """
    Whether to output debug information during validation.
    Only has an effect if ``validate`` is ``True``.
    """

    tokenizer_id: Optional[str] = None
    """
    The HuggingFace tokenizer identifier to save with the model.
    If not specified, uses the tokenizer from the experiment config.
    """

    max_sequence_length: Optional[int] = None
    """
    The maximum sequence length for the model.
    If not specified, uses the tokenizer's default max length.
    """

    device: Optional[str] = None
    """
    The device to use for conversion. Defaults to CPU.
    """

    moe_capacity_factor: Optional[float] = None
    """
    The MoE capacity factor. Higher values can decrease validation false negatives
    but may cause OOM errors. Only relevant for MoE models.
    """

    def _get_checkpointer_callback(self) -> Optional[CheckpointerCallback]:
        """
        Return the first :class:`CheckpointerCallback` registered on the trainer,
        or ``None`` if the trainer has none.
        """
        for callback in self.trainer.callbacks.values():
            if isinstance(callback, CheckpointerCallback):
                return callback
        return None

    def _get_latest_checkpoint_path(self) -> Optional[str]:
        """
        Return the path of the most recent checkpoint known to the checkpointer callback.

        Prefers the checkpointer's ``_latest_checkpoint_path`` and falls back to the last
        entry of its ``_checkpoints`` list. Returns ``None`` (after logging a warning) when
        no checkpointer callback is registered, and ``None`` when neither attribute holds
        a path.
        """
        checkpointer = self._get_checkpointer_callback()
        if checkpointer is None:
            log.warning("CheckpointerCallback not found, cannot determine latest checkpoint path")
            return None
        if checkpointer._latest_checkpoint_path:
            return checkpointer._latest_checkpoint_path
        if checkpointer._checkpoints:
            return checkpointer._checkpoints[-1]
        return None

    def _get_full_model_state_dict(self) -> Dict[str, Any]:
        """
        Get the full model state dict from the trainer's model.

        This is a collective operation - ALL ranks must call this method.
        The full state dict is gathered to rank 0.
        """
        model = self.trainer.train_module.model
        # full_state_dict=True gathers the complete model state to rank 0.
        # cpu_offload=True avoids GPU OOM for large models.
        sd_options = dist_cp_sd.StateDictOptions(full_state_dict=True, cpu_offload=True)
        return dist_cp_sd.get_model_state_dict(model, options=sd_options)

    def post_train(self):
        """
        Convert the latest checkpoint to HuggingFace format once training has finished.

        Every rank calls ``barrier()`` exactly once before this method returns or
        propagates an ``Exception``, whichever path the conversion took (disabled,
        no checkpoint, missing dependency, success, or failure). Without that guarantee
        an unexpected error on rank 0 would leave the other ranks blocked in the barrier.

        :raises Exception: Re-raises any error from gathering the state dict or from the
            conversion itself, after the barrier has been reached.
        """
        # NOTE: In distributed training with FSDP, getting the full model state dict requires
        # ALL ranks to participate in the collective operation. Only rank 0 performs the actual
        # HF conversion; all ranks synchronize at a barrier before returning.
        try:
            self._run_hf_conversion()
        except Exception:
            # Reach the barrier before propagating so peer ranks are not left waiting.
            barrier()
            raise
        else:
            barrier()

    def _run_hf_conversion(self) -> None:
        """
        Run the conversion workflow on the calling rank without any synchronization.

        Skips (with a log message) when the callback is disabled, when no checkpoint path
        can be determined, or when the HF conversion utilities cannot be imported.
        Otherwise all ranks gather the full state dict and rank 0 performs the conversion.
        """
        if not self.enabled:
            log.info("HFConverterCallback is disabled, skipping conversion")
            return

        checkpoint_path = self._get_latest_checkpoint_path()
        if checkpoint_path is None:
            log.warning("No checkpoint found, skipping HF conversion")
            return

        # Imported lazily so the callback can be constructed without the optional dependency.
        try:
            from olmo_core.nn.hf import convert_checkpoint_to_hf, load_config
        except ImportError:
            log.error(
                "Failed to import HF conversion utilities. "
                "Make sure the 'transformers' library is installed."
            )
            return

        is_rank_zero = get_rank() == 0

        experiment_config: Optional[dict] = None
        if is_rank_zero:
            try:
                experiment_config = load_config(checkpoint_path)
            except Exception as e:
                # Deliberately not re-raised: the other ranks are about to enter the
                # collective gather below and rank 0 has to join them.
                log.exception(f"Failed to load config from checkpoint: {e}")

        # ALL ranks must participate in gathering the full state dict (FSDP collective).
        log.info("Gathering full model state dict (collective operation)...")
        try:
            model_state_dict = self._get_full_model_state_dict()
        except Exception as e:
            log.error(f"Failed to get model state dict: {e}")
            raise

        if is_rank_zero:
            self._convert_on_rank_zero(
                checkpoint_path=checkpoint_path,
                experiment_config=experiment_config,
                model_state_dict=model_state_dict,
                convert_fn=convert_checkpoint_to_hf,
            )

    def _convert_on_rank_zero(
        self,
        *,
        checkpoint_path: str,
        experiment_config: Optional[dict],
        model_state_dict: Dict[str, Any],
        convert_fn: Any,
    ) -> None:
        """
        Perform the actual HuggingFace conversion; intended to be called on rank 0 only.

        :param checkpoint_path: Path of the checkpoint being converted.
        :param experiment_config: Experiment config loaded from the checkpoint, or ``None``
            if loading failed (in which case the conversion is skipped).
        :param model_state_dict: The gathered full model state dict.
        :param convert_fn: The ``convert_checkpoint_to_hf`` callable.

        :raises Exception: Re-raises any error raised by ``convert_fn`` after logging it.
        """
        log.info(f"Converting checkpoint at '{checkpoint_path}' to HuggingFace format")

        if experiment_config is None:
            log.error("Experiment config not found in checkpoint, cannot convert to HF format")
            return

        transformer_config_dict = experiment_config.get("model")
        if transformer_config_dict is None:
            log.error("Model config not found in experiment config, cannot convert to HF format")
            return

        # ``or {}`` also covers a config that stores ``"dataset": None`` explicitly,
        # which a plain ``.get("dataset", {})`` would turn into an AttributeError.
        tokenizer_config_dict = (experiment_config.get("dataset") or {}).get("tokenizer")
        if tokenizer_config_dict is None:
            log.warning(
                "Tokenizer config not found in experiment config, "
                "conversion will proceed without tokenizer"
            )
            tokenizer_config_dict = {}

        if self.output_folder is not None:
            output_path = self.output_folder
        else:
            output_path = checkpoint_path + "-hf"

        device = torch.device(self.device) if self.device else None

        try:
            convert_fn(
                original_checkpoint_path=checkpoint_path,
                output_path=output_path,
                transformer_config_dict=transformer_config_dict,
                tokenizer_config_dict=tokenizer_config_dict,
                model_state_dict=model_state_dict,
                dtype=self.dtype,
                tokenizer_id=self.tokenizer_id,
                max_sequence_length=self.max_sequence_length,
                validate=self.validate,
                debug=self.debug,
                device=device,
                moe_capacity_factor=self.moe_capacity_factor,
            )
            log.info(f"Successfully converted checkpoint to HuggingFace format at '{output_path}'")
        except Exception as e:
            log.error(f"Failed to convert checkpoint to HuggingFace format: {e}")
            raise