class Evaluator(metaclass=ABCMeta):
    """
    Base class for in-loop evaluators.

    .. seealso::
        Pair an evaluator with an :class:`~olmo_core.train.callbacks.EvaluatorCallback`
        to run it inside the training loop.

    :param name: The name assigned to this evaluator.
    :param batches: An iterable of batches to evaluate on. Each batch should contain at least
        an ``"input_ids"`` field and may carry arbitrary additional fields.
    :param batches_factory: A zero-argument callable that returns an iterable of batches. Use
        it as an alternative to passing ``batches`` directly. Exactly one of ``batches`` and
        ``batches_factory`` must be provided. The factory is invoked lazily on the first
        iteration, and its result is cached in :data:`batches` and reused by later passes.
        NOTE: because of this caching, a factory returning a one-shot iterator (e.g. a
        generator) will yield nothing on the second and subsequent passes.
    :param device: The device on which metrics are computed/reduced.
    :param deterministic: Only relevant when :data:`batches` is a
        :class:`~olmo_core.data.data_loader.DataLoaderBase`.

        - When ``True``, each evaluation pass resets the data loader and reshuffles it with
          ``epoch=1``. Repeated evals therefore read the same batches in the same order, which
          matters when eval loops are truncated via :class:`~olmo_core.train.common.Duration`.
        - When ``False``, the data loader is still reset to batch 0 before each pass, but it is
          reshuffled without pinning the epoch, so batch order may change between eval runs.

        Neither mode implements a moving window across evals. If an eval is truncated and the
        reshuffles differ, a different set of instances may be evaluated each time.
    """

    def __init__(
        self,
        *,
        name: str,
        batches: Optional[Iterable[Dict[str, Any]]] = None,
        batches_factory: Optional[Callable[[], Iterable[Dict[str, Any]]]] = None,
        device: Optional[torch.device] = None,
        deterministic: bool = True,
    ):
        # Exactly one source of batches is allowed.
        if batches is not None:
            assert (
                batches_factory is None
            ), "'batches' and 'batches_factory' cannot both be provided."
        else:
            assert (
                batches_factory is not None
            ), "Either 'batches' or 'batches_factory' must be provided."

        self.name = name
        self.device = device
        self.deterministic = deterministic
        self.batches = batches
        self.batches_factory = batches_factory

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """
        Yield the evaluator's batches for one evaluation pass.

        On first use with a ``batches_factory``, the factory is called and its result is
        stored in :data:`batches`. If :data:`batches` is a ``DataLoaderBase``, it is reset and
        reshuffled before the pass and reset again once the pass is fully consumed.
        """
        if self.batches is None:
            assert self.batches_factory is not None
            self.batches = self.batches_factory()

        source = self.batches
        is_loader = isinstance(source, DataLoaderBase)

        if is_loader:
            # Reset the bookkeeping before reshuffling. Otherwise an eval_duration-limited
            # eval would resume from where the previous partial pass stopped, not batch 0.
            source.reset()
            # Pin the epoch only in deterministic mode, so every pass sees the same order.
            reshuffle_kwargs: Dict[str, Any] = {"epoch": 1} if self.deterministic else {}
            source.reshuffle(in_memory=True, **reshuffle_kwargs)

        for item in source:
            yield item

        if is_loader:
            source.reset()

    @property
    def display_name(self) -> str:
        """The name shown for this evaluator; identical to :data:`name`."""
        return self.name

    @property
    def total_batches(self) -> Optional[int]:
        """
        The number of batches in an eval loop, if it is known ahead of time.

        Returns ``None`` when :data:`batches` does not support ``len()``. This includes the
        case where ``batches_factory`` has not been invoked yet, so :data:`batches` is still
        ``None``.
        """
        try:
            count = len(self.batches)  # type: ignore
        except TypeError:
            return None
        return count

    @abstractmethod
    def update_metrics(
        self, batch: Dict[str, Any], ce_loss: Optional[torch.Tensor], logits: Optional[torch.Tensor]
    ) -> None:
        """
        Update metrics from the ``batch`` just processed and its corresponding ``logits``.

        :param batch: A batch produced by :data:`batches`.
        :param ce_loss: The per-token (un-reduced) cross-entropy loss of the batch, with shape
            ``(batch_size, (seq_len - 1))``.
        :param logits: The logits produced by the model's forward pass.
        """
        raise NotImplementedError

    @abstractmethod
    def compute_metrics(self) -> Dict[str, torch.Tensor]:
        """
        Compute the final metric values for the current evaluation loop.

        The returned metrics should already be reduced, where reduction is needed.
        """
        raise NotImplementedError

    @abstractmethod
    def reset_metrics(self) -> None:
        """
        Reset the metric state. Intended to be called after :meth:`compute_metrics()`.
        """
        raise NotImplementedError