flambe.metric.dev.auc
Module Contents
-
flambe.metric.dev.auc.one_hot(indices: torch.Tensor, width: int) → torch.Tensor
Converts a tensor of integer indices into one-hot format.
Parameters: - indices (torch.Tensor) – the indices to be converted
- width (int) – the width of the one-hot encoding, i.e. the number of classes; every index must be smaller than width (width = maximal index value + 1)
Returns: A one-hot representation of the input indices.
Return type: torch.Tensor
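As a rough illustration of the encoding, here is a plain-Python sketch that uses lists instead of torch.Tensor. It is not flambe's implementation. It shows why every index must be smaller than width: index i sets column i in a row of width columns.

```python
from typing import List


def one_hot(indices: List[int], width: int) -> List[List[int]]:
    """Plain-Python sketch of one-hot encoding (lists stand in for tensors)."""
    rows = []
    for i in indices:
        # Column i must exist, so valid indices are 0 .. width - 1.
        if not 0 <= i < width:
            raise ValueError(f"index {i} out of range for width {width}")
        row = [0] * width
        row[i] = 1
        rows.append(row)
    return rows


print(one_hot([0, 2, 1], width=3))
# [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```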
-
class flambe.metric.dev.auc.AUC(max_fpr=1.0)
Bases: flambe.metric.metric.Metric
-
static aggregate(state: dict, *args, **kwargs)
Aggregates by simply storing the predictions and targets in the state.
Parameters: - state (dict) – the metric state
- args – the (pred, target) tuple
Returns: the state dict
Return type: dict
-
finalize(self, state: Dict)
Finalizes the metric computation.
Parameters: state (dict) – the metric state
Returns: The final score.
Return type: float
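aggregate and finalize follow an accumulate-then-compute pattern. The sketch below is a plain-Python approximation, not flambe's code. The state keys "pred" and "target", the helper rank_auc, the explicit pred/target parameters (instead of *args), and the use of lists instead of tensors are all hypothetical. Here finalize computes a full binary AUC as the fraction of correctly ordered (positive, negative) pairs, which equals the area under the ROC curve.

```python
from typing import Dict, List


def aggregate(state: Dict[str, list], pred: List[float], target: List[int]) -> Dict[str, list]:
    """Append one batch to the running state (hypothetical "pred"/"target" keys)."""
    state.setdefault("pred", []).extend(pred)
    state.setdefault("target", []).extend(target)
    return state


def rank_auc(pred: List[float], target: List[int]) -> float:
    """Binary AUC: fraction of (positive, negative) pairs where the positive scores higher.

    Ties count as half a win. O(P * N), fine for a sketch.
    """
    pos = [p for p, t in zip(pred, target) if t == 1]
    neg = [p for p, t in zip(pred, target) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative target")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def finalize(state: Dict[str, list]) -> float:
    """Score everything accumulated so far in a single pass."""
    return rank_auc(state["pred"], state["target"])


state: Dict[str, list] = {}
aggregate(state, [0.9, 0.2], [1, 0])  # batch 1
aggregate(state, [0.6, 0.4], [0, 1])  # batch 2
print(finalize(state))  # 0.75: 3 of the 4 positive/negative pairs are ordered correctly
```

Storing the raw predictions, rather than a running score, is what makes this work: AUC depends on the ordering of all samples together, so it cannot be averaged batch by batch.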
-
compute(self, pred: torch.Tensor, target: torch.Tensor)
Compute AUC at the given max false positive rate.
Parameters: - pred (torch.Tensor) – The model predictions of shape numsamples
- target (torch.Tensor) – The binary targets of shape numsamples
Returns: The computed AUC
Return type: torch.Tensor
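To show what "AUC at the given max false positive rate" means, this sketch builds the ROC curve and integrates it with the trapezoidal rule only up to max_fpr, interpolating linearly at the cutoff. It is a plain-Python approximation, not flambe's code. Whether the library rescales the partial area (for example by dividing by max_fpr) is not stated in this reference, so the sketch returns the raw area. With max_fpr=1.0 it gives the ordinary AUC.

```python
from typing import List, Tuple


def roc_points(pred: List[float], target: List[int]) -> List[Tuple[float, float]]:
    """(fpr, tpr) points of the ROC curve, lowering the threshold one distinct score at a time."""
    n_pos = sum(1 for t in target if t == 1)
    n_neg = len(target) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC needs at least one positive and one negative target")
    ranked = sorted(zip(pred, target), key=lambda pt: pt[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(ranked):
        score = ranked[i][0]
        # Samples with tied scores cross the threshold together.
        while i < len(ranked) and ranked[i][0] == score:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def partial_auc(pred: List[float], target: List[int], max_fpr: float = 1.0) -> float:
    """Trapezoidal area under the ROC curve for fpr in [0, max_fpr] (not rescaled)."""
    pts = roc_points(pred, target)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        if x0 >= max_fpr:
            break
        if x1 > max_fpr:
            # Cut the last segment at max_fpr, interpolating tpr linearly.
            y1 = y0 + (y1 - y0) * (max_fpr - x0) / (x1 - x0)
            x1 = max_fpr
        area += (x1 - x0) * (y0 + y1) / 2
    return area


pred, target = [0.9, 0.2, 0.6, 0.4], [1, 0, 0, 1]
print(partial_auc(pred, target))               # 0.75 (ordinary AUC)
print(partial_auc(pred, target, max_fpr=0.5))  # 0.25
```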
-
static
-
class flambe.metric.dev.auc.MultiClassAUC
Bases: flambe.metric.dev.auc.AUC
N-ary (multiclass) AUC for k-way classification.
-
compute(self, pred: torch.Tensor, target: torch.Tensor)
Compute multiclass AUC at the given max false positive rate.
Parameters: - pred (torch.Tensor) – The model predictions of shape numsamples x numclasses
- target (torch.Tensor) – The targets, in one of two shapes:
- numsamples: each element is the index of the positive class
- numsamples x numclasses: only the index of the maximum value in each row is treated as the positive label
Returns: The computed AUC
Return type: torch.Tensor
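One plausible reading of the two target shapes is sketched below. Both are converted to one-hot rows (argmax for score rows), then predictions and labels are flattened and scored as a single binary problem, i.e. a micro-averaged one-vs-rest AUC. This averaging scheme is an assumption; the library may aggregate per class instead. The helper names are hypothetical, and plain Python lists stand in for tensors.

```python
from typing import List, Sequence, Union

Target = Union[int, Sequence[float]]  # a class index, or one row of numclasses values


def binary_auc(scores: List[float], labels: List[int]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def to_one_hot(target: List[Target], num_classes: int) -> List[List[int]]:
    """Map either target shape to one-hot rows."""
    rows = []
    for t in target:
        if isinstance(t, int):
            idx = t  # shape numsamples: the element is the class index
        else:
            # shape numsamples x numclasses: only the argmax counts as positive
            idx = max(range(num_classes), key=lambda k: t[k])
        row = [0] * num_classes
        row[idx] = 1
        rows.append(row)
    return rows


def multiclass_auc(pred: List[List[float]], target: List[Target]) -> float:
    """Micro-averaged one-vs-rest AUC (assumed scheme): flatten, then score as binary."""
    labels = to_one_hot(target, num_classes=len(pred[0]))
    flat_scores = [s for row in pred for s in row]
    flat_labels = [y for row in labels for y in row]
    return binary_auc(flat_scores, flat_labels)


pred = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
print(multiclass_auc(pred, [0, 1, 0]))  # index targets
print(multiclass_auc(pred, [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]]))  # same labels as rows
```

Both calls print the same value, because the argmax of each target row recovers the indices [0, 1, 0].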
-