class Metric(metaclass=ABCMeta):
    """
    Abstract interface shared by evaluation metrics.

    A concrete metric accumulates state through :meth:`update`, turns that state
    into a result with :meth:`compute`, and clears it with :meth:`reset`. All three
    are abstract, so this class cannot be instantiated directly.

    :param device: Device the metric should use. When ``None``, the result of
        ``get_default_device()`` is used instead.
    :param process_group: Stored unchanged as ``self.process_group`` so subclasses
        can pass it to collective operations.
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ):
        if device is None:
            # No explicit device requested: fall back to get_default_device().
            device = get_default_device()
        self.device = device
        self.process_group = process_group

    @abstractmethod
    def update(self, *args, **kwargs) -> None:
        """
        Fold new observations into the metric's running state.

        :raises NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError

    @abstractmethod
    def compute(self) -> torch.Tensor:
        """
        Produce the metric's value from the state accumulated so far.

        :raises NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self) -> None:
        """
        Discard accumulated state so the metric can be reused.

        :raises NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError
class MeanMetric(Metric):
    """
    Running, optionally weighted, mean over a stream of values.

    Two scalar accumulators are kept on ``self.device``:

    - ``weighted_sum``: the sum of ``value * weight`` over every update.
    - ``weight``: the sum of the weights.

    :meth:`compute` passes both through ``all_reduce_value`` and divides the first
    by the second.
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ):
        super().__init__(device=device, process_group=process_group)
        # Created after the base class has resolved self.device.
        self.weighted_sum = torch.tensor(0.0, device=self.device)
        self.weight = torch.tensor(0.0, device=self.device)

    def update(
        self,
        value: Union[float, torch.Tensor],
        weight: Union[float, torch.Tensor] = 1.0,
    ) -> None:
        """
        Add a value, or a tensor of values, to the running mean.

        :param value: A scalar or a tensor of values.
        :param weight: Weight(s) for ``value``. It is broadcast to ``value``'s shape
            with :func:`torch.broadcast_to`, so a scalar weight applies to every
            element.

        An empty ``value`` leaves the accumulators untouched. The broadcast still
        runs first, so an incompatible ``weight`` raises even then.
        """
        values = self.as_tensor(value)
        weights = torch.broadcast_to(self.as_tensor(weight), values.shape)
        if values.numel() > 0:
            # In-place accumulation into the scalar state tensors.
            self.weighted_sum += (values * weights).sum()
            self.weight += weights.sum()

    def compute(self) -> torch.Tensor:
        """
        Return the weighted mean of everything passed to :meth:`update`.

        Both accumulators go through ``all_reduce_value`` with
        ``self.process_group`` before the division. The weighted sum is always
        reduced before the weight.

        NOTE(review): if the total weight is zero (e.g. no updates yet), this
        divides 0 by 0, which for float tensors presumably yields NaN. Confirm
        callers expect that.
        """
        total = all_reduce_value(
            self.weighted_sum, device=self.device, group=self.process_group
        )
        total_weight = all_reduce_value(
            self.weight, device=self.device, group=self.process_group
        )
        return total / total_weight