flambe.metric.dev.auc

Module Contents

class flambe.metric.dev.auc.AUC(max_fpr=1.0)

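Bases: flambe.metric.metric.Metric

Parameters:
  • max_fpr (float) – the maximum false positive rate up to which the AUC is computed. The default, 1.0, covers the full ROC curve.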

__str__(self)

Return the name of the Metric (for use in logging).

static aggregate(state: dict, *args, **kwargs)

Aggregates by storing the predictions and targets in the state.

Parameters:
  • state (dict) – the metric state
  • args (tuple) – the (pred, target) pair
Returns:

The updated state dict.

Return type:

dict
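The accumulation step can be pictured with a minimal, framework-free sketch. It assumes the state is a plain dict holding two lists under the hypothetical keys "pred" and "target", and it uses Python floats instead of torch.Tensor; flambe's actual state layout may differ.

```python
from typing import Dict, List, Sequence


def aggregate(state: Dict[str, List[float]], *args: Sequence[float]) -> Dict[str, List[float]]:
    """Store one batch of predictions and targets in the running state.

    args is expected to be the (pred, target) pair; the "pred" and
    "target" keys are hypothetical, not flambe's actual state layout.
    """
    pred, target = args
    # Create each buffer on first use, then extend it with this batch.
    state.setdefault("pred", []).extend(float(p) for p in pred)
    state.setdefault("target", []).extend(float(t) for t in target)
    return state


state: Dict[str, List[float]] = {}
aggregate(state, [0.9, 0.2], [1, 0])
aggregate(state, [0.6], [1])
print(state)  # {'pred': [0.9, 0.2, 0.6], 'target': [1.0, 0.0, 1.0]}
```

Keeping raw predictions rather than a running score matters for AUC: the metric depends on the ranking of all examples, so it cannot be averaged batch by batch.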

finalize(self, state: Dict)

Finalizes the metric computation, producing the final score from the accumulated state.

Parameters:
  • state (dict) – the metric state
Returns:

The final score.

Return type:

float
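As a rough illustration of what finalize has to produce for the default max_fpr=1.0, the sketch below computes the full AUC from an accumulated state using the Mann-Whitney formulation: the fraction of (positive, negative) pairs in which the positive example scores higher, with ties counted as one half. The state keys are again hypothetical, targets are assumed to be 0/1, and the quadratic pair loop is chosen for clarity, not speed; this is not flambe's implementation.

```python
from typing import Dict, List


def finalize(state: Dict[str, List[float]]) -> float:
    """Full ROC AUC from accumulated predictions and 0/1 targets."""
    pos = [p for p, t in zip(state["pred"], state["target"]) if t == 1]
    neg = [p for p, t in zip(state["pred"], state["target"]) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both positive and negative targets")
    wins = 0.0
    # Compare every positive score against every negative score.
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # a tie counts as half a correct ordering
    return wins / (len(pos) * len(neg))


print(finalize({"pred": [0.9, 0.2, 0.3, 0.4], "target": [1, 0, 1, 0]}))  # 0.75
```

In the example, three of the four (positive, negative) pairs are ordered correctly (0.3 ranks below the negative at 0.4), giving 0.75.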

compute(self, pred: torch.Tensor, target: torch.Tensor)

Compute the area under the ROC curve, restricted to false positive rates in [0, max_fpr].

Parameters:
  • pred (torch.Tensor) – The model predictions
  • target (torch.Tensor) – The binary targets
Returns:

The computed AUC

Return type:

torch.Tensor
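The max_fpr restriction can be sketched without any framework: build the ROC curve by sweeping a threshold over the distinct scores, then integrate it with the trapezoidal rule only up to fpr = max_fpr, interpolating linearly at the cut-off. The helper names roc_points and partial_auc are hypothetical, targets are assumed to be 0/1, and the result is the raw partial area (at most max_fpr). Whether AUC.compute rescales that area is not stated here; dividing by max_fpr is one option, while scikit-learn's roc_auc_score returns a standardized partial AUC when max_fpr is set.

```python
from typing import List, Sequence, Tuple


def roc_points(pred: Sequence[float], target: Sequence[int]) -> List[Tuple[float, float]]:
    """Return (fpr, tpr) points, one per distinct score threshold."""
    n_pos = sum(1 for t in target if t == 1)
    n_neg = len(target) - n_pos  # assumes 0/1 targets
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC is undefined without both positive and negative targets")
    # Sort by score, highest first, so lowering the threshold admits examples in order.
    pairs = sorted(zip(pred, target), key=lambda pt: pt[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        score = pairs[i][0]
        # Consume every example tied at this score, so ties form a single diagonal step.
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def partial_auc(pred: Sequence[float], target: Sequence[int], max_fpr: float = 1.0) -> float:
    """Trapezoidal area under the ROC curve for fpr in [0, max_fpr], not rescaled."""
    points = roc_points(pred, target)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 >= max_fpr:
            break
        if x1 > max_fpr:
            # Interpolate tpr at fpr = max_fpr and stop the segment there.
            y1 = y0 + (y1 - y0) * (max_fpr - x0) / (x1 - x0)
            x1 = max_fpr
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


pred = [0.9, 0.2, 0.3, 0.4]
target = [1, 0, 1, 0]
print(partial_auc(pred, target))       # 0.75
print(partial_auc(pred, target, 0.5))  # 0.25
```

With max_fpr=1.0 the result matches the pairwise (Mann-Whitney) AUC of 0.75 for the same data; with max_fpr=0.5 only the first half of the false positive range contributes, leaving 0.25.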