flambe.metric.dev.auc

Module Contents

flambe.metric.dev.auc.one_hot(indices: torch.Tensor, width: int) → torch.Tensor

Converts a tensor of integer indices into one-hot format.

Parameters:
  • indices (torch.Tensor) – the indices to be converted
  • width (int) – the width of the one-hot encoding; it must be greater than the largest index (for 0-based class indices, the number of classes)
Returns:

A one-hot representation of the input indices.

Return type:

torch.Tensor
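To illustrate the encoding, here is a dependency-free sketch of the same idea that uses Python lists instead of tensors. The name `one_hot_sketch` is hypothetical and not part of flambe; the sketch only shows the expected input/output relationship, including the requirement that every index be smaller than `width`. In PyTorch itself, `torch.nn.functional.one_hot(tensor, num_classes)` performs this conversion and returns an int64 tensor.

```python
from typing import List, Sequence


def one_hot_sketch(indices: Sequence[int], width: int) -> List[List[int]]:
    """Hypothetical list-based illustration of one-hot encoding.

    Each index i becomes a row of length `width` with a 1 at position i
    and 0 everywhere else.
    """
    for i in indices:
        # With 0-based indices, width must be at least max(indices) + 1.
        if not 0 <= i < width:
            raise ValueError(f"index {i} out of range for width {width}")
    return [[1 if c == i else 0 for c in range(width)] for i in indices]


print(one_hot_sketch([0, 2, 1], 3))
```

Passing `width=2` for the same indices raises `ValueError`, because index 2 has no column to occupy.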

class flambe.metric.dev.auc.AUC(max_fpr=1.0)

Bases: flambe.metric.metric.Metric

__str__(self)

Return the name of the Metric (for use in logging).

compute(self, pred: torch.Tensor, target: torch.Tensor)

Compute the AUC (area under the ROC curve) up to the given maximum false positive rate (max_fpr).

Parameters:
  • pred (torch.Tensor) – The model predictions, of shape (numsamples,)
  • target (torch.Tensor) – The binary targets, of shape (numsamples,)
Returns:

The computed AUC

Return type:

torch.Tensor
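The idea of an AUC restricted to false positive rates up to max_fpr can be sketched in pure Python as follows. This is an illustration under stated assumptions, not flambe's implementation: it builds the ROC curve by sweeping thresholds from the highest score down, then integrates it with the trapezoidal rule, linearly interpolating the true positive rate at max_fpr. The helper names `roc_points` and `partial_auc` are hypothetical. The sketch returns the raw (unnormalized) partial area; whether flambe rescales that area, for example by dividing by max_fpr, is not stated in this reference.

```python
from typing import List, Sequence, Tuple


def roc_points(pred: Sequence[float], target: Sequence[int]) -> List[Tuple[float, float]]:
    """Return the (fpr, tpr) points of the ROC curve, starting at (0, 0)."""
    pos = sum(1 for t in target if t == 1)
    neg = len(target) - pos
    if pos == 0 or neg == 0:
        raise ValueError("AUC is undefined unless both classes are present")
    # Sort by score, highest first; tied scores form a single threshold step.
    pairs = sorted(zip(pred, target), key=lambda p: -p[0])
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        score = pairs[i][0]
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / neg, tp / pos))
    return points


def partial_auc(pred: Sequence[float], target: Sequence[int], max_fpr: float = 1.0) -> float:
    """Trapezoidal area under the ROC curve for fpr in [0, max_fpr].

    Assumption: the raw area is returned, without any normalization.
    """
    points = roc_points(pred, target)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 >= max_fpr:
            break
        if x1 > max_fpr:
            # Interpolate tpr at max_fpr and truncate the segment there.
            y1 = y0 + (y1 - y0) * (max_fpr - x0) / (x1 - x0)
            x1 = max_fpr
        area += (x1 - x0) * (y0 + y1) / 2
    return area


print(partial_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))               # 0.75
print(partial_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], max_fpr=0.5))  # 0.25
```

With the default max_fpr=1.0 this is the ordinary ROC AUC; smaller values keep only the low-false-positive part of the curve.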

class flambe.metric.dev.auc.MultiClassAUC

Bases: flambe.metric.dev.auc.AUC

N-ary (multiclass) AUC for k-way classification.

compute(self, pred: torch.Tensor, target: torch.Tensor)

Compute the multiclass AUC up to the given maximum false positive rate (max_fpr).

Parameters:
  • pred (torch.Tensor) – The model predictions, of shape (numsamples, numclasses)
  • target (torch.Tensor) –
    The targets, in one of two shapes:
    • (numsamples,): each element is the index of the target class
    • (numsamples, numclasses): only the index of the maximum value in each row is treated as the positive label
Returns:

The computed AUC

Return type:

torch.Tensor
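The sketch below illustrates, in pure Python, how the two accepted target shapes can be reduced to one-hot labels and then scored. The averaging scheme is an assumption: this reference does not say how flambe combines the per-class results, so the sketch uses micro-averaging (flatten all scores and one-hot labels, then compute a single binary AUC). The binary AUC here is the pairwise (Mann-Whitney) formulation, which equals the full ROC AUC and ignores max_fpr. All function names are hypothetical.

```python
from typing import List, Sequence, Union

Row = Sequence[float]


def normalize_targets(target: Union[Sequence[int], Sequence[Row]], num_classes: int) -> List[List[int]]:
    """Turn class indices or per-class rows into one-hot rows."""
    rows = []
    for t in target:
        if isinstance(t, int):
            idx = t  # shape (numsamples,): the element is the class index
        else:
            # shape (numsamples, numclasses): only the argmax counts as positive
            idx = max(range(num_classes), key=lambda c: t[c])
        rows.append([1 if c == idx else 0 for c in range(num_classes)])
    return rows


def pairwise_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined unless both classes are present")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def micro_auc(pred: Sequence[Row], target: Union[Sequence[int], Sequence[Row]]) -> float:
    """Assumed micro-averaged multiclass AUC over flattened scores and labels."""
    num_classes = len(pred[0])
    onehot = normalize_targets(target, num_classes)
    scores = [s for row in pred for s in row]
    labels = [y for row in onehot for y in row]
    return pairwise_auc(scores, labels)


pred = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5]]
print(micro_auc(pred, [0, 1, 2]))                           # 1.0
print(micro_auc(pred, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]))   # 1.0
```

Both calls give the same result because the one-hot rows have their maximum at the same positions as the class indices. A macro-averaged variant (one binary AUC per class column, then the mean) is an equally plausible design and would generally give different numbers.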