from typing import Any, Dict, Iterable, Optional, Sequence, Set
import torch
import torch.distributed as dist
from ..data import DataCollator, NumpyFSLDataLoader, NumpyPaddedFSLDataset
from ..distributed.utils import get_fs_local_rank, get_rank, get_world_size
from ..exceptions import OLMoConfigurationError
from ..utils import get_default_device
from .evaluator import Evaluator
from .metrics import MeanMetric
[docs]
class LMEvaluator(Evaluator):
    """
    Evaluator for language modeling that reports cross entropy loss and perplexity, broken
    down by dataset label, over one or more datasets.

    .. important::
        The :data:`batches` produced for this evaluator must contain a "metadata" field holding
        a list of dictionaries. Each dictionary in that list must have a string field called
        "label" identifying the dataset the data file is associated with, and that label must
        be one of the ``labels`` passed to this class.

    :param name: The evaluator name, forwarded to the base :class:`Evaluator`.
    :param batches: The evaluation batches, forwarded to the base :class:`Evaluator`.
    :param labels: All of the task labels. One mean metric is tracked per label.
    :param device: The device forwarded to the base class and to every per-label metric.
    """

    def __init__(
        self,
        *,
        name: str,
        batches: Iterable[Dict[str, Any]],
        labels: Sequence[str],
        device: Optional[torch.device] = None,
    ):
        super().__init__(name=name, batches=batches, device=device)
        # One running mean of the token-level loss per dataset label.
        per_label: Dict[str, MeanMetric] = {}
        for label in labels:
            per_label[label] = MeanMetric(device=device)
        self.metrics = per_label

    @classmethod
    def from_numpy_dataset(
        cls,
        dataset: NumpyPaddedFSLDataset,
        *,
        name: str,
        global_batch_size: int,
        collator: DataCollator,
        device: Optional[torch.device] = None,
        dp_process_group: Optional[dist.ProcessGroup] = None,
        seed: int = 0,
        num_threads: Optional[int] = None,
        num_workers: int = 0,
        prefetch_factor: Optional[int] = None,
    ) -> "LMEvaluator":
        """
        Build an :class:`LMEvaluator` on top of a
        :class:`~olmo_core.data.numpy_dataset.NumpyPaddedFSLDataset`.

        The set of labels is collected from ``dataset.metadata``, the dataset is prepared, and
        a :class:`NumpyFSLDataLoader` over it is used as the evaluator's batches.

        :param dataset: The dataset to evaluate on. Every entry of its ``metadata`` must
            contain a "label" key.
        :param name: The evaluator name.
        :param global_batch_size: Forwarded to the data loader.
        :param collator: Forwarded to the data loader.
        :param device: The evaluation device. Falls back to :func:`get_default_device` when
            not given.
        :param dp_process_group: Process group used to look up the data parallel world size
            and rank handed to the data loader.
        :param seed: Forwarded to the data loader.
        :param num_threads: Forwarded to the data loader.
        :param num_workers: Forwarded to the data loader.
        :param prefetch_factor: Forwarded to the data loader.

        :raises OLMoConfigurationError: If a metadata entry has no "label" key.
        """
        found_labels: Set[str] = set()
        for source_path, source_meta in zip(dataset.paths, dataset.metadata):
            if "label" in source_meta:
                found_labels.add(source_meta["label"])
                continue
            raise OLMoConfigurationError(
                f"Missing dataset 'label' in metadata for '{source_path}' dataset"
            )

        dataset.prepare()

        if not device:
            device = get_default_device()

        loader = NumpyFSLDataLoader(
            dataset,
            global_batch_size=global_batch_size,
            collator=collator,
            work_dir=dataset.work_dir,
            seed=seed,
            dp_world_size=get_world_size(dp_process_group),
            dp_rank=get_rank(dp_process_group),
            fs_local_rank=get_fs_local_rank(),
            target_device_type=device.type,
            num_threads=num_threads,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
        )

        return cls(name=name, batches=loader, labels=list(found_labels), device=device)

    def update_metrics(
        self,
        batch: Dict[str, Any],
        ce_loss: Optional[torch.Tensor],
        logits: Optional[torch.Tensor],
    ) -> None:
        """
        Fold the losses of one batch into the per-label metrics.

        :param batch: The batch that was evaluated. Its "metadata" list is paired item by item
            with ``ce_loss``; an optional "label_mask" entry selects which loss values count.
        :param ce_loss: The cross entropy losses for the batch. Nothing is recorded when this
            is ``None``.
        :param logits: Unused here.
        """
        # ``logits`` may be ``None`` when context parallelism (CP) or tensor parallelism (TP) is
        # enabled, since gathering the full logits across ranks is unnecessary for perplexity.
        # Only ``ce_loss`` (already local per-token values) is needed here.
        if ce_loss is None:
            return

        apply_mask = "label_mask" in batch
        for row, (item_meta, item_loss) in enumerate(zip(batch["metadata"], ce_loss)):
            label_metric = self.metrics[item_meta["label"]]
            if apply_mask:
                # Keep only the loss values selected by this item's mask.
                item_loss = item_loss.masked_select(batch["label_mask"][row])
            label_metric.update(item_loss)

    def compute_metrics(self) -> Dict[str, torch.Tensor]:
        """
        Compute the final metrics for every label, in sorted label order.

        :returns: A dictionary with a ``"{label}/CE loss"`` and a ``"{label}/PPL"`` entry for
            each label, where the perplexity is the exponential of the cross entropy loss.
        """
        results: Dict[str, torch.Tensor] = {}
        for label in sorted(self.metrics):
            label_metric = self.metrics[label]
            # '.update()' may never have been called on this metric, e.g. when the evaluator
            # covers multiple tasks/datasets and the current evaluation loop didn't reach this
            # one. Update with dummy values so it can still be computed; the weight is 0.0, so
            # the final value is unaffected.
            label_metric.update(0.0, 0.0)
            mean_loss = label_metric.compute()  # could be nan but that's okay.
            results[f"{label}/CE loss"] = mean_loss
            results[f"{label}/PPL"] = torch.exp(mean_loss)
        return results

    def reset_metrics(self) -> None:
        """
        Reset every per-label metric.
        """
        for label in self.metrics:
            self.metrics[label].reset()