from typing import Dict
from abc import abstractmethod
import numpy as np
import torch
from flambe.compile import Component
class Metric(Component):
    """Base Metric interface.

    A metric scores model predictions against ground truth targets
    through `compute`. To score a whole dataset batch by batch, call
    `aggregate` once per batch with a shared state dictionary and then
    `finalize` to obtain the final score.

    """

    @abstractmethod
    def compute(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Computes the metric over the given prediction and target.

        Parameters
        ----------
        pred: torch.Tensor
            The model predictions
        target: torch.Tensor
            The ground truth targets

        Returns
        -------
        torch.Tensor
            The computed metric

        """
        pass

    def aggregate(self, state: dict, *args, **kwargs) -> Dict:
        """Aggregates by simply storing preds and targets

        Parameters
        ----------
        state: dict
            the metric state; it is updated in place
        args: the pred, target tuple
            exactly two positional arguments are expected
        kwargs
            accepted for interface compatibility, not used here

        Returns
        -------
        dict
            the state dict

        Raises
        ------
        ValueError
            If ``args`` does not unpack into exactly (pred, target).

        """
        pred, target = args
        if not state:
            state['pred'] = []
            state['target'] = []
        # Store CPU copies that are detached from the autograd graph.
        state['pred'].append(pred.cpu().detach())
        state['target'].append(target.cpu().detach())
        return state

    def finalize(self, state: Dict) -> float:
        """Finalizes the metric computation

        Concatenates every stored batch along the first dimension and
        runs `compute` once over the result. The score is also written
        to ``state['accumulated_score']``.

        Parameters
        ----------
        state: dict
            the metric state

        Returns
        -------
        float
            The final score, or NaN if the state is empty.

        """
        if not state:
            # call on empty state
            # `np.nan` is the canonical spelling; the `np.NaN` alias
            # was removed in NumPy 2.0.
            return np.nan
        pred = torch.cat(state['pred'], dim=0)
        target = torch.cat(state['target'], dim=0)
        # `.item()` requires `compute` to return a single-element tensor.
        state['accumulated_score'] = self.compute(pred, target).item()
        return state['accumulated_score']

    def __call__(self, *args, **kwargs):
        """Makes the Metric a callable that delegates to `compute`."""
        return self.compute(*args, **kwargs)

    def __str__(self) -> str:
        """Return the name of the Metric (for use in logging)."""
        return self.__class__.__name__
class AverageableMetric(Metric):
    """Metric interface for averageable metrics

    Some metrics, such as accuracy, are averaged as a final step.
    This allows for a more efficient metrics computation: instead of
    storing every prediction and target, only a running average and a
    sample count are kept in the state.
    These metrics should inherit from this class.

    """

    def aggregate(self, state: dict, *args, **kwargs) -> Dict:
        """Updates the running, sample-weighted average with one batch.

        Parameters
        ----------
        state: dict
            the state dictionary
        args:
            normally pred, target; the first positional argument must
            expose ``size(0)``, which is used as the batch sample count
        kwargs
            forwarded to `compute`

        Returns
        -------
        dict
            The updated state (even though the update happens in-place)

        Raises
        ------
        ValueError
            If the number of samples cannot be read from the first
            positional argument.

        """
        score = self.compute(*args, **kwargs)
        if isinstance(score, torch.Tensor):
            score_value = score.cpu().detach().numpy().item()
        elif hasattr(score, 'item'):
            # NumPy scalars / arrays returned directly by `compute`.
            score_value = score.item()
        else:
            # Plain Python numbers have no `.item()`.
            score_value = float(score)

        if not args:
            raise ValueError(
                'Cannot get size: no positional arguments were given'
            )
        try:
            num_samples = args[0].size(0)
        except (ValueError, AttributeError, TypeError, IndexError) as err:
            raise ValueError(
                f'Cannot get size from {type(args[0])}'
            ) from err

        if num_samples == 0:
            # An empty batch carries no information. Skipping it avoids
            # a division by zero on the first call and keeps a NaN
            # score from an empty batch out of the running average.
            return state

        if not state:
            state['accumulated_score'] = 0.
            state['sample_count'] = 0

        previous_count = state['sample_count']
        total_count = previous_count + num_samples
        state['accumulated_score'] = (
            previous_count * state['accumulated_score']
            + num_samples * score_value
        ) / total_count
        state['sample_count'] = total_count
        return state

    def finalize(self, state) -> float:
        """Finalizes the metric computation

        Parameters
        ----------
        state: dict
            the metric state

        Returns
        -------
        Any
            The final score. Can be anything, depending on metric.
            ``None`` is returned when the state holds no
            ``'accumulated_score'`` entry, e.g. nothing was aggregated.

        """
        return state.get('accumulated_score')