from abc import abstractmethod
from typing import Dict
import numpy as np
import torch
from flambe.metric.metric import Metric
class BinaryMetric(Metric):
    """Base class for binary classification metrics.

    Subclasses implement :meth:`compute_binary` on thresholded boolean
    tensors; this class handles thresholding, aggregation of batches,
    and finalization.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        """Initialize the Binary metric.

        Parameters
        ----------
        threshold: float
            Given a probability p of belonging to Positive class,
            p < threshold will be considered tagged as Negative by
            the classifier when computing the metric.
            Defaults to 0.5

        """
        self.threshold = threshold

    def __str__(self) -> str:
        """Return the name of the Metric (for use in logging)."""
        return f'{self.__class__.__name__}@{self.threshold}'

    @staticmethod
    def aggregate(state: dict, *args, **kwargs) -> Dict:
        """Aggregate by simply storing preds and targets.

        Parameters
        ----------
        state: dict
            The metric state.
        args: the pred, target tuple

        Returns
        -------
        dict
            The state dict.

        """
        pred, target = args
        if not state:
            # First call: set up the accumulation lists.
            state['pred'] = []
            state['target'] = []
        # Detach and move to CPU so stored batches don't hold onto the
        # computation graph or GPU memory.
        state['pred'].append(pred.cpu().detach())
        state['target'].append(target.cpu().detach())
        return state

    def finalize(self, state: Dict) -> float:
        """Finalize the metric computation.

        Parameters
        ----------
        state: dict
            The metric state.

        Returns
        -------
        float
            The final score (NaN if no batches were aggregated).

        """
        if not state:
            # Called on an empty state: no data was ever aggregated.
            # NOTE: np.NaN was removed in NumPy 2.0; np.nan is the
            # canonical spelling.
            return np.nan
        pred = torch.cat(state['pred'], dim=0)
        target = torch.cat(state['target'], dim=0)
        # Cache the score in the state so repeated finalize calls are cheap
        # for callers inspecting the state afterwards.
        state['accumulated_score'] = self.compute(pred, target).item()
        return state['accumulated_score']

    def compute(self, pred: torch.Tensor, target: torch.Tensor) \
            -> torch.Tensor:
        """Compute the metric given predictions and targets.

        Parameters
        ----------
        pred : Tensor
            The model predictions.
        target : Tensor
            The binary targets.

        Returns
        -------
        torch.Tensor
            The computed binary metric.

        """
        pred = pred.squeeze()
        target = target.squeeze().bool()
        # Binarize the probabilities before delegating to the subclass.
        pred = (pred > self.threshold)
        return self.compute_binary(pred, target)

    @abstractmethod
    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute a binary-input metric.

        Parameters
        ----------
        pred: torch.Tensor
            Predictions made by the model, already thresholded to
            boolean values (True = positive class).
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            The computed binary metric.

        """
        pass
class BinaryAccuracy(BinaryMetric):
    """Compute binary accuracy.

    ```
    |True positives + True negatives| / N
    ```
    """

    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute binary accuracy.

        Parameters
        ----------
        pred: torch.Tensor
            Boolean predictions made by the model (True = positive).
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            The computed binary accuracy as a float tensor.

        """
        correct = pred == target
        # A 0-dim tensor is a single sample; otherwise use the batch size.
        n_samples = target.size()[0] if target.dim() > 0 else 1
        if n_samples == 0:
            # Empty batch: return a float tensor (not integer) so that
            # finalize()'s .item() consistently yields a float.
            return torch.tensor(0.)
        return correct.sum().float() / n_samples
class BinaryPrecision(BinaryMetric):
    """Compute Binary Precision.

    An example is considered negative when its score is below the
    specified threshold. Binary precision is computed as follows:

    ```
    |True positives| / |True Positives| + |False Positives|
    ```
    """

    def __init__(self, threshold: float = 0.5, positive_label: int = 1) -> None:
        """Initialize the Binary metric.

        Parameters
        ----------
        threshold: float
            Given a probability p of belonging to Positive class,
            p < threshold will be considered tagged as Negative by
            the classifier when computing the metric.
            Defaults to 0.5
        positive_label: int
            Specify if the positive class should be 1 or 0.
            Defaults to 1.

        Raises
        ------
        ValueError
            If ``positive_label`` is neither 0 nor 1.

        """
        if positive_label not in [0, 1]:
            raise ValueError("positive_label should be either 0 or 1")
        super().__init__(threshold)
        self.positive_label = positive_label

    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute binary precision.

        Parameters
        ----------
        pred: torch.Tensor
            Boolean predictions made by the model (True = positive).
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            The computed precision as a float tensor (0 when the model
            made no positive predictions).

        """
        if self.positive_label == 0:
            # Flip both tensors so "positive" refers to label 0.
            pred = ~pred
            target = ~target
        correct = pred == target
        # True positives: predicted positive AND actually positive.
        true_p = correct & target
        if pred.sum() == 0:
            # No positive predictions: precision is undefined, report 0.
            # Use a float tensor so finalize()'s .item() yields a float.
            metric = torch.tensor(0.)
        else:
            # Explicit .float() casts avoid pytorch's integer-division
            # typing quirks; check periodically for a fix upstream.
            metric = (true_p.sum().float() / pred.sum().float())
        return metric

    def __str__(self) -> str:
        """Return the name of the Metric (for use in logging)."""
        invert_label = "Negative" if self.positive_label == 0 else "Positive"
        return f"{invert_label}{self.__class__.__name__}"
class BinaryRecall(BinaryMetric):
    """Compute binary recall.

    An example is considered negative when its score is below the
    specified threshold. Binary recall is computed as follows:

    ```
    |True positives| / |True Positives| + |False Negatives|
    ```
    """

    def __init__(self, threshold: float = 0.5, positive_label: int = 1) -> None:
        """Initialize the Binary metric.

        Parameters
        ----------
        threshold: float
            Given a probability p of belonging to Positive class,
            p < threshold will be considered tagged as Negative by
            the classifier when computing the metric.
            Defaults to 0.5
        positive_label: int
            Specify if the positive class should be 1 or 0.
            Defaults to 1.

        Raises
        ------
        ValueError
            If ``positive_label`` is neither 0 nor 1.

        """
        if positive_label not in [0, 1]:
            raise ValueError("positive_label should be either 0 or 1")
        super().__init__(threshold)
        self.positive_label = positive_label

    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute binary recall.

        Parameters
        ----------
        pred: torch.Tensor
            Boolean predictions made by the model (True = positive).
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            The computed recall as a float tensor (0 when there are no
            positive targets).

        """
        if self.positive_label == 0:
            # Flip both tensors so "positive" refers to label 0.
            pred = ~pred
            target = ~target
        correct = pred == target
        # True positives: predicted positive AND actually positive.
        true_p = correct & target
        if target.sum() == 0:
            # No positive targets: recall is undefined, report 0.
            # Use a float tensor so finalize()'s .item() yields a float.
            metric = torch.tensor(0.)
        else:
            metric = true_p.sum().float() / target.sum().float()
        return metric

    def __str__(self) -> str:
        """Return the name of the Metric (for use in logging)."""
        invert_label = "Negative" if self.positive_label == 0 else "Positive"
        return f"{invert_label}{self.__class__.__name__}"
class F1(BinaryMetric):
    """Compute the F1 score.

    The harmonic mean of binary precision and binary recall:

    ```
    2 * P * R / (P + R)
    ```

    A small epsilon is added to the denominator to avoid division
    by zero when both precision and recall are 0.
    """

    def __init__(self,
                 threshold: float = 0.5,
                 positive_label: int = 1,
                 eps: float = 1e-8) -> None:
        """Initialize the F1 metric.

        Parameters
        ----------
        threshold: float
            Given a probability p of belonging to Positive class,
            p < threshold will be considered tagged as Negative by
            the classifier when computing the metric.
            Defaults to 0.5
        positive_label: int
            Specify if the positive class should be 1 or 0.
            Defaults to 1.
        eps: float
            Float to sum to the denominator, so that we avoid division
            by zero. Defaults to 1e-8.

        """
        super().__init__(threshold)
        # Delegate to the component metrics; they share the threshold
        # and positive-label convention.
        self.recall = BinaryRecall(threshold, positive_label)
        self.precision = BinaryPrecision(threshold, positive_label)
        self.eps = eps

    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute F1 Score, the harmonic mean between precision and
        recall.

        Parameters
        ----------
        pred: torch.Tensor
            Boolean predictions made by the model (True = positive).
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            The computed F1 score as a float tensor.

        """
        recall = self.recall.compute_binary(pred, target)
        precision = self.precision.compute_binary(pred, target)
        return 2 * precision * recall / (precision + recall + self.eps)