from abc import abstractmethod
import torch
from flambe.metric.metric import Metric
class BinaryMetric(Metric):
    """Base class for metrics computed on binarized predictions.

    ``compute`` thresholds the raw predictions, casts predictions and
    targets to byte tensors, and delegates the actual computation to
    ``compute_binary``, which subclasses are expected to implement.

    """

    def __init__(self, threshold: float = 0.5) -> None:
        """Initialize the Binary metric.

        Parameters
        ----------
        threshold: float
            Given a probability p of belonging to Positive class,
            only p > threshold is tagged as Positive; p <= threshold
            will be considered tagged as Negative by the classifier
            when computing the metric.

        """
        self.threshold = threshold

    def compute(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the metric given predictions and targets

        Parameters
        ----------
        pred : Tensor
            The model predictions
        target : Tensor
            The binary targets

        Returns
        -------
        torch.Tensor
            The computed binary metric, as returned by
            ``compute_binary``

        """
        # Strict comparison: a prediction exactly equal to the threshold
        # is tagged as negative.
        pred = (pred > self.threshold).byte()
        target = target.byte()
        return self.compute_binary(pred, target)

    @abstractmethod
    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute a binary-input metric.

        Parameters
        ----------
        pred: torch.Tensor
            Binarized predictions. When called through ``compute``
            this is a byte tensor holding 1 where the raw prediction
            was above the threshold and 0 elsewhere.
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1. When
            called through ``compute`` it has been cast to a byte
            tensor.

        Returns
        -------
        torch.Tensor
            The computed binary metric

        """
        pass
class BinaryPrecision(BinaryMetric):
    """Compute Binary Precision.

    An example is considered negative when its score is not above the
    specified threshold. Binary precision is computed as follows:

    ```
    |True Positives| / (|True Positives| + |False Positives|)
    ```

    When there are no positive predictions the precision is reported
    as 0.

    """

    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute binary precision.

        Parameters
        ----------
        pred: torch.Tensor
            Binarized predictions: 1 for samples tagged as positive,
            0 otherwise.
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            Scalar float tensor holding the precision; 0 when no
            sample is predicted as positive.

        """
        # With 0/1 predictions, the sum is the number of predicted
        # positives, i.e. true positives + false positives.
        predicted_pos = pred.sum()
        if predicted_pos == 0:
            # Undefined ratio: report 0 with the same dtype and device
            # as the regular branch so callers can combine results.
            return torch.tensor(0., dtype=torch.float, device=pred.device)

        # A true positive is a correct prediction on a positive target.
        # Both operands are comparison results, so `&` never mixes dtypes.
        true_p = (pred == target) & (target == 1)
        return true_p.sum().float() / predicted_pos.float()
class BinaryRecall(BinaryMetric):
    """Compute binary recall.

    An example is considered negative when its score is not above the
    specified threshold. Binary recall is computed as follows:

    ```
    |True Positives| / (|True Positives| + |False Negatives|)
    ```

    When there are no positive targets the recall is reported as 0.

    """

    def compute_binary(self,
                       pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
        """Compute binary recall.

        Parameters
        ----------
        pred: torch.Tensor
            Binarized predictions: 1 for samples tagged as positive,
            0 otherwise.
        target: torch.Tensor
            Ground truth. Each label should be either 0 or 1.

        Returns
        -------
        torch.Tensor
            Scalar float tensor holding the recall; 0 when no target
            is positive.

        """
        # With 0/1 labels, the sum is the number of actual positives,
        # i.e. true positives + false negatives.
        actual_pos = target.sum()
        if actual_pos == 0:
            # Undefined ratio: report 0 with the same dtype and device
            # as the regular branch so callers can combine results.
            return torch.tensor(0., dtype=torch.float, device=target.device)

        # A true positive is a correct prediction on a positive target.
        # Both operands are comparison results, so `&` never mixes dtypes.
        true_p = (pred == target) & (target == 1)
        return true_p.sum().float() / actual_pos.float()