# Source code for olmo_core.train.callbacks.comet

import logging
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

from olmo_core.config import StrEnum
from olmo_core.distributed.utils import get_rank
from olmo_core.exceptions import OLMoEnvironmentError
from olmo_core.utils import set_env_var

from .callback import Callback

if TYPE_CHECKING:
    from comet_ml import Experiment

# Module-level logger for this callback module.
log = logging.getLogger(__name__)

# Name of the environment variable that must hold the Comet.ml API key; it is read
# from ``os.environ`` when starting the experiment and when querying tags for cancellation.
COMET_API_KEY_ENV_VAR = "COMET_API_KEY"


class CometNotificationSetting(StrEnum):
    """
    Selects which experiment lifecycle events cause :class:`CometCallback` to send a
    Comet.ml notification.
    """

    #: Notify on every event: when training starts, when it completes successfully,
    #: and when it fails.
    all = "all"

    #: Notify only once the experiment ends, whether it succeeded or failed.
    end_only = "end_only"

    #: Notify only when the experiment fails.
    failure_only = "failure_only"

    #: Never send any notifications.
    none = "none"
@dataclass
class CometCallback(Callback):
    """
    Logs metrics to Comet.ml from rank 0.

    .. important::
        Requires the ``comet_ml`` package and the environment variable ``COMET_API_KEY``.

    .. note::
        This callback logs metrics from every single step to Comet.ml, regardless of the value
        of :data:`Trainer.metrics_collect_interval <olmo_core.train.Trainer.metrics_collect_interval>`.
    """

    enabled: bool = True
    """
    Set to false to disable this callback.
    """

    name: Optional[str] = None
    """
    The name to give the Comet.ml experiment.
    """

    project: Optional[str] = None
    """
    The Comet.ml project to use.
    """

    workspace: Optional[str] = None
    """
    The name of the Comet.ml workspace to use.
    """

    tags: Optional[List[str]] = None
    """
    Tags to assign the experiment.
    """

    config: Optional[Dict[str, Any]] = None
    """
    The config to save to Comet.ml.
    """

    cancel_tags: Optional[List[str]] = field(
        default_factory=lambda: ["cancel", "canceled", "cancelled"]
    )
    """
    If you add any of these tags to an experiment on Comet.ml, the run will cancel itself.
    Matching is case-insensitive.

    Defaults to ``["cancel", "canceled", "cancelled"]``.
    """

    cancel_check_interval: Optional[int] = None
    """
    Check for cancel tags every this many steps. Defaults to
    :data:`olmo_core.train.Trainer.cancel_check_interval`.
    """

    notifications: CometNotificationSetting = CometNotificationSetting.none
    """
    The notification settings.
    """

    failure_tag: str = "failed"
    """
    The tag to assign to failed experiments.
    """

    auto_resume: bool = False
    """
    If ``True``, an existing experiment will be resumed from a checkpoint if the experiment
    name matches.
    """

    _exp = None
    _exp_key: Optional[str] = None
    _finalized: bool = False

    @property
    def exp(self) -> "Experiment":
        """
        The active Comet.ml experiment. It is only assigned on rank 0 by :meth:`pre_train`
        and is ``None`` before that.
        """
        return self._exp  # type: ignore

    @exp.setter
    def exp(self, exp: "Experiment"):
        self._exp = exp

    @property
    def finalized(self) -> bool:
        """Whether :meth:`finalize` has already ended the experiment."""
        return self._finalized

    def state_dict(self) -> Dict[str, Any]:
        """
        Save the experiment key and run name so a restarted run can re-attach to the same
        experiment (see ``auto_resume``).
        """
        return {"experiment_key": self._exp_key, "name": self.name}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Restore the experiment key, but only when ``auto_resume`` is enabled and the saved
        run name matches the current one.
        """
        if self.auto_resume and self.name == state_dict.get("name"):
            self._exp_key = state_dict.get("experiment_key")

    def finalize(self):
        """
        End the Comet.ml experiment at most once.

        This is a no-op when no experiment was ever created, for example because
        :meth:`pre_train` raised before starting one. Calling ``end()`` on ``None`` there would
        otherwise raise an ``AttributeError`` during cleanup.
        """
        if self.finalized or self._exp is None:
            return
        self.exp.end()
        self._finalized = True

    def pre_train(self):
        """
        On rank 0, create (or resume) the Comet.ml experiment and record its metadata.

        :raises OLMoEnvironmentError: If ``COMET_API_KEY`` is not set.
        """
        if not (self.enabled and get_rank() == 0):
            return

        # Set before importing ``comet_ml``; presumably the SDK reads this at import or
        # experiment-creation time, so auto logging stays disabled.
        set_env_var("COMET_DISABLE_AUTO_LOGGING", "1")
        import comet_ml as comet

        if COMET_API_KEY_ENV_VAR not in os.environ:
            raise OLMoEnvironmentError(f"missing env var '{COMET_API_KEY_ENV_VAR}'")

        if self.auto_resume and self._exp_key is not None:
            log.info(f"Resuming Comet logging from existing experiment '{self._exp_key}'")
            self.exp = cast(
                "Experiment",
                comet.start(
                    api_key=os.environ[COMET_API_KEY_ENV_VAR],
                    mode="get",
                    experiment_key=self._exp_key,
                    experiment_config=comet.ExperimentConfig(
                        auto_output_logging="simple",
                        display_summary_level=0,
                    ),
                ),
            )
        else:
            self.exp = comet.Experiment(
                api_key=os.environ[COMET_API_KEY_ENV_VAR],
                project_name=self.project,
                workspace=self.workspace,
                auto_output_logging="simple",
                display_summary_level=0,
            )

        # Remember the key so it ends up in ``state_dict()`` for future resumption.
        self._exp_key = self.exp.get_key()

        if self.name is not None:
            self.exp.set_name(self.name)
        if self.tags:
            self.exp.add_tags(self.tags)
        if self.config is not None:
            self.exp.log_parameters(self.config)

        if self.notifications == CometNotificationSetting.all:
            self.exp.send_notification(
                f"Experiment {self.exp.get_name()} ({self.exp.get_key()})",
                status="started",
            )

    def log_metrics(self, step: int, metrics: Dict[str, float]):
        """Send ``metrics`` for ``step`` to Comet.ml (rank 0 only)."""
        if self.enabled and get_rank() == 0:
            self.exp.log_metrics(metrics, step=step)

    def post_step(self):
        """Every ``cancel_check_interval`` steps, schedule :meth:`check_if_canceled` on rank 0."""
        cancel_check_interval = self.cancel_check_interval or self.trainer.cancel_check_interval
        if self.enabled and get_rank() == 0 and self.step % cancel_check_interval == 0:
            self.trainer.run_bookkeeping_op(
                self.check_if_canceled,
                allow_multiple=False,
                distributed=False,
            )

    def post_train(self):
        """Send the success notification if configured. The experiment is ended in :meth:`close`."""
        if self.enabled and get_rank() == 0:
            log.info("Finalizing successful Comet.ml experiment...")
            if self.notifications in (
                CometNotificationSetting.all,
                CometNotificationSetting.end_only,
            ):
                self.exp.send_notification(
                    f"Experiment {self.exp.get_name()} ({self.exp.get_key()})",
                    status="completed successfully",
                )

    def on_error(self, exc: BaseException):
        """
        Tag the experiment with ``failure_tag`` and send a failure notification if configured.

        Does nothing if no experiment was created. For example, :meth:`pre_train` may itself
        have raised the error being handled, and dereferencing ``None`` here would mask it.
        """
        del exc
        if self.enabled and get_rank() == 0 and self._exp is not None:
            log.warning("Finalizing failed Comet.ml experiment...")
            self.exp.add_tag(self.failure_tag)
            if self.notifications in (
                CometNotificationSetting.all,
                CometNotificationSetting.end_only,
                CometNotificationSetting.failure_only,
            ):
                self.exp.send_notification(
                    f"Experiment {self.exp.get_name()} ({self.exp.get_key()})",
                    status="failed",
                )

    def close(self):
        """End the experiment on rank 0 (see :meth:`finalize`)."""
        if self.enabled and get_rank() == 0:
            self.finalize()

    def check_if_canceled(self):
        """
        Fetch the experiment's current tags from the Comet.ml API and cancel the run if any
        of them matches ``cancel_tags`` (case-insensitively).

        API failures are logged and ignored, so a flaky network call does not stop training.
        """
        if not (self.enabled and not self.finalized and self.cancel_tags):
            return

        from comet_ml.api import API

        try:
            api = API(api_key=os.environ[COMET_API_KEY_ENV_VAR])
            exp = api.get_experiment_by_key(self.exp.get_key())
            assert exp is not None
            tags = exp.get_tags()
        except Exception as exc:
            log.exception(exc)
            return

        # Normalize both sides so e.g. a configured "Cancel" still matches a "cancel" tag.
        cancel_tags = {cancel_tag.lower() for cancel_tag in self.cancel_tags}
        for tag in tags or []:
            if tag.lower() in cancel_tags:
                self.trainer.cancel_run("canceled from Comet.ml tag")
                return