@dataclass
class WandBCallback(Callback):
    """
    Logs metrics to Weights & Biases from rank 0.

    .. important::
        Requires the ``wandb`` package and the environment variable ``WANDB_API_KEY``.

    .. note::
        This callback logs metrics from every single step to W&B, regardless of the value
        of :data:`Trainer.metrics_collect_interval <olmo_core.train.Trainer.metrics_collect_interval>`.
    """

    enabled: bool = True
    """
    Set to false to disable this callback.
    """

    name: Optional[str] = None
    """
    The name to give the W&B run.
    """

    project: Optional[str] = None
    """
    The W&B project to use.
    """

    entity: Optional[str] = None
    """
    The W&B entity to use.
    """

    group: Optional[str] = None
    """
    The W&B group to use.
    """

    tags: Optional[List[str]] = None
    """
    Tags to assign the run.
    """

    notes: Optional[str] = None
    """
    A note/description of the run.
    """

    config: Optional[Dict[str, Any]] = None
    """
    The config to upload to W&B.
    """

    cancel_tags: Optional[List[str]] = field(
        default_factory=lambda: ["cancel", "canceled", "cancelled"]
    )
    """
    If you add any of these tags to a run on W&B, the run will cancel itself.
    Tags are matched case-insensitively.
    Defaults to ``["cancel", "canceled", "cancelled"]``.
    """

    cancel_check_interval: Optional[int] = None
    """
    Check for cancel tags every this many steps. Defaults to
    :data:`olmo_core.train.Trainer.cancel_check_interval`.
    """

    # Lazily-imported ``wandb`` module, populated by the ``wandb`` property. This has no
    # annotation, so it is a plain class attribute rather than a dataclass field.
    _wandb = None
    # Path of the active W&B run; recorded in ``pre_train`` and read by ``check_if_canceled``.
    _run_path = None
    # Set once ``finalize`` has called ``wandb.finish`` so the run is only finished once.
    _finalized: bool = False

    @property
    def wandb(self):
        """The ``wandb`` module, imported on first access."""
        if self._wandb is None:
            import wandb  # type: ignore

            self._wandb = wandb
        return self._wandb

    @property
    def run(self) -> "Run":
        """The current W&B run object (``wandb.run``)."""
        return self.wandb.run

    @property
    def run_path(self):
        """The path of the W&B run created in ``pre_train``, or ``None`` before that."""
        return self._run_path

    @property
    def finalized(self) -> bool:
        """Whether ``finalize`` has already finished the W&B run."""
        return self._finalized

    def finalize(self, exit_code: int = 0):
        """
        Finish the W&B run, at most once.

        :param exit_code: Forwarded to ``wandb.finish``. Any non-zero value is reported as a
            failed run.
        """
        if self.finalized:
            return
        # Any non-zero code (including negative, signal-style codes) counts as a failure.
        if exit_code != 0:
            log.warning("Finalizing failed W&B run...")
        else:
            log.info("Finalizing successful W&B run...")
        self.wandb.finish(exit_code=exit_code, quiet=True)
        self._finalized = True

    def pre_train(self):
        """
        Create the W&B run on rank 0 under ``<work_dir>/wandb``.

        :raises OLMoEnvironmentError: If the W&B API key environment variable is not set.
        """
        if not (self.enabled and get_rank() == 0):
            return

        if WANDB_API_KEY_ENV_VAR not in os.environ:
            raise OLMoEnvironmentError(f"missing env var '{WANDB_API_KEY_ENV_VAR}'")

        # Touch the property to import ``wandb`` before creating any directories.
        self.wandb

        wandb_dir = self.trainer.work_dir / "wandb"
        wandb_dir.mkdir(parents=True, exist_ok=True)
        self.wandb.init(
            dir=wandb_dir,
            project=self.project,
            entity=self.entity,
            group=self.group,
            name=self.name,
            tags=self.tags,
            notes=self.notes,
            config=self.config,
        )
        self._run_path = self.run.path  # type: ignore

    def log_metrics(self, step: int, metrics: Dict[str, float]):
        """Log ``metrics`` to W&B at ``step`` (rank 0 only)."""
        if self.enabled and get_rank() == 0:
            self.wandb.log(metrics, step=step)

    def post_step(self):
        """Periodically schedule ``check_if_canceled`` on rank 0."""
        # Nothing to check for if disabled, not rank 0, or no cancel tags are configured.
        if not (self.enabled and get_rank() == 0 and self.cancel_tags):
            return

        # A falsy interval (None or 0) falls back to the trainer's setting, which also
        # avoids a modulo-by-zero below.
        cancel_check_interval = self.cancel_check_interval or self.trainer.cancel_check_interval
        if self.step % cancel_check_interval == 0:
            self.trainer.run_bookkeeping_op(
                self.check_if_canceled,
                allow_multiple=False,
                distributed=False,
            )

    def on_error(self, exc: BaseException):
        """Finish the active W&B run with exit code 1 (rank 0 only)."""
        del exc
        if self.enabled and get_rank() == 0 and self.run is not None:
            self.finalize(exit_code=1)

    def close(self):
        """Finish the active W&B run successfully (rank 0 only)."""
        if self.enabled and get_rank() == 0 and self.run is not None:
            self.finalize()

    def check_if_canceled(self):
        """
        Cancel the trainer's run if the W&B run carries any of ``cancel_tags``.

        Communication failures with the W&B API are logged as warnings and otherwise ignored.
        """
        if not (self.enabled and self.cancel_tags):
            return
        if self.run_path is None:
            # No W&B run has been created on this process, so there is nothing to query.
            return

        from requests.exceptions import RequestException
        from wandb.errors import CommError  # type: ignore

        # Normalize both sides so that e.g. a configured "Cancel" matches a "cancel" tag.
        cancel_tags = {tag.lower() for tag in self.cancel_tags}

        try:
            # NOTE: we need to re-initialize the API client every time; otherwise it
            # appears to return cached run data.
            api = self.wandb.Api(api_key=os.environ[WANDB_API_KEY_ENV_VAR], timeout=5)
            run = api.run(self.run_path)  # type: ignore
            for tag in run.tags or []:
                if tag.lower() in cancel_tags:
                    self.trainer.cancel_run("canceled from W&B tag")
                    return
        except (RequestException, CommError, TimeoutError):
            log.warning("Failed to communicate with W&B API")