flambe.metric

Package Contents

class flambe.metric.Metric

Bases: flambe.compile.Component

Base Metric interface.

Objects implementing this interface should take in a sequence of examples and provide as output a processed list of the same size.

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the metric over the given prediction and target.

Parameters:
  • pred (torch.Tensor) – The model predictions
  • target (torch.Tensor) – The ground truth targets
Returns:

The computed metric

Return type:

torch.Tensor

__call__(self, *args, **kwargs)

Makes Metric a callable.

__str__(self)

Return the name of the Metric (for use in logging).
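The interface above can be sketched without flambe as an abstract base class. This is a minimal illustration, not flambe's code: the class names `SketchMetric` and `MeanAbsoluteError` are hypothetical, `__call__` is assumed to forward its arguments to `compute`, and `__str__` is assumed to return the class name. Plain Python lists stand in for `torch.Tensor` so the sketch has no dependencies.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class SketchMetric(ABC):
    """Hypothetical stand-in for flambe.metric.Metric."""

    @abstractmethod
    def compute(self, pred: Sequence[float], target: Sequence[float]) -> float:
        """Compute the metric over predictions and targets."""

    def __call__(self, *args, **kwargs) -> float:
        # Assumption: calling the metric simply delegates to compute().
        return self.compute(*args, **kwargs)

    def __str__(self) -> str:
        # Assumption: the name used in logging is the class name.
        return self.__class__.__name__


class MeanAbsoluteError(SketchMetric):
    """Hypothetical example subclass."""

    def compute(self, pred: Sequence[float], target: Sequence[float]) -> float:
        if len(pred) != len(target):
            raise ValueError("pred and target must have the same length")
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


metric = MeanAbsoluteError()
print(str(metric), metric([1.0, 2.0, 4.0], [1.0, 3.0, 2.0]))  # MeanAbsoluteError 1.0
```

Keeping `__call__` a thin wrapper means subclasses only override `compute`, and the same object can be passed wherever a plain function is expected.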

class flambe.metric.MultiLabelCrossEntropy(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')

Bases: flambe.metric.metric.Metric

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the multilabel cross-entropy loss.

Parameters:
  • pred (torch.Tensor) – input logits of shape (B x N)
  • target (torch.LongTensor) – target tensor of shape (B x N)
Returns:

loss – Multilabel cross-entropy loss, of shape (B)

Return type:

torch.Tensor
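One common way to define a multilabel cross-entropy is to turn each multi-hot target row into a distribution by dividing it by its sum, then take the cross-entropy against the log-softmax of the logits. The sketch below does this row by row in plain Python. Whether flambe normalizes the target this way, and how it applies `weight`, `ignore_index` and `reduction`, is an assumption not confirmed here; the function names are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def multilabel_cross_entropy_row(logits: List[float], target: List[int]) -> float:
    """Cross-entropy of one (N,) logit row against a multi-hot (N,) target row."""
    total = sum(target)
    if total == 0:
        raise ValueError("target row has no positive labels")
    log_probs = log_softmax(logits)
    # Normalize the multi-hot target into a probability distribution.
    return -sum((t / total) * lp for t, lp in zip(target, log_probs))


# Batch of B=2 rows; the loss has one entry per row, i.e. shape (B).
batch_logits = [[0.0, 0.0, 0.0, 0.0], [2.0, 2.0, -1.0, -1.0]]
batch_target = [[1, 1, 0, 0], [1, 1, 0, 0]]
losses = [multilabel_cross_entropy_row(p, t) for p, t in zip(batch_logits, batch_target)]
```

The second row puts more probability mass on its two positive labels, so its loss is lower than that of the uniform first row.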

class flambe.metric.MultiLabelNLLLoss(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')

Bases: flambe.metric.metric.Metric

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the negative log-likelihood loss for multilabel classification.

Parameters:
  • pred (torch.Tensor) – input logits of shape (B x N)
  • target (torch.LongTensor) – target tensor of shape (B x N)
Returns:

loss – Multilabel negative log-likelihood loss, of shape (B)

Return type:

torch.float
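A per-row negative log-likelihood for multi-hot targets can be sketched as the average negative log-probability of the positive labels. The sketch assumes, following the convention of PyTorch's `NLLLoss`, that the row already holds log-probabilities (for example, the output of a log-softmax) rather than raw logits; this is an assumption about how `pred` should be prepared, and `multilabel_nll_row` is a hypothetical name.

```python
import math
from typing import List


def multilabel_nll_row(log_probs: List[float], target: List[int]) -> float:
    """NLL of one (N,) row of log-probabilities against a multi-hot (N,) target row."""
    total = sum(target)
    if total == 0:
        raise ValueError("target row has no positive labels")
    # Average the negative log-probability over the positive labels.
    return -sum(t * lp for t, lp in zip(target, log_probs)) / total


# Log-probabilities of a uniform distribution over 4 classes.
uniform = [math.log(0.25)] * 4
loss = multilabel_nll_row(uniform, [1, 0, 1, 0])  # equals log(4) ≈ 1.386
```

Separating the log-softmax from the loss, as here, lets the model output log-probabilities once and reuse them for several metrics.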

class flambe.metric.Accuracy

Bases: flambe.metric.metric.Metric

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the accuracy.

Parameters:
  • pred (Tensor) – input logits of shape (B x N)
  • target (LongTensor) – target tensor of shape (B) or (B x N)
Returns:

accuracy – single label accuracy, of shape (B)

Return type:

torch.Tensor
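Single-label accuracy can be sketched as comparing the argmax of each logit row with the target class, where a (B x N) one-hot target row is first reduced to its class index. This plain-Python sketch averages over the batch to a single fraction; the per-row 0/1 comparisons before averaging have one entry per example. The function names are hypothetical and this is not flambe's implementation.

```python
from typing import List, Sequence, Union


def argmax(row: Sequence[float]) -> int:
    # Index of the largest value; ties resolve to the first occurrence.
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(pred: List[List[float]], target: Union[List[int], List[List[int]]]) -> float:
    """Fraction of rows whose argmax logit matches the target class."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same batch size")
    correct = 0
    for row, t in zip(pred, target):
        # A (B x N) one-hot target row is reduced to its class index.
        label = argmax(t) if isinstance(t, list) else t
        correct += int(argmax(row) == label)
    return correct / len(pred)
```

Accepting both index and one-hot targets mirrors the "(B) or (B x N)" shapes listed for `target`.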

class flambe.metric.Perplexity

Bases: flambe.metric.Metric

compute(self, pred: torch.Tensor, target: torch.Tensor)

Compute the perplexity given the input and target.

Parameters:
  • pred (torch.Tensor) – input logits of shape (B x N)
  • target (torch.LongTensor) – target tensor of shape (B x N)
Returns:

Output perplexity

Return type:

torch.float
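Perplexity is commonly defined as the exponential of the mean negative log-likelihood of the true classes. The sketch below assumes `target` holds one class index per row (shape (B)), as in language-model evaluation, rather than the (B x N) shape listed above; `perplexity` is a hypothetical name and this is not flambe's code.

```python
import math
from typing import List


def perplexity(logits: List[List[float]], target: List[int]) -> float:
    """exp of the mean negative log-likelihood of the target classes."""
    if len(logits) != len(target):
        raise ValueError("logits and target must have the same batch size")
    total_nll = 0.0
    for row, t in zip(logits, target):
        # log of the softmax normalizer, computed stably.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # Negative log-probability assigned to the true class.
        total_nll += log_z - row[t]
    return math.exp(total_nll / len(target))
```

A uniform prediction over N classes gives a perplexity of N, and a confident correct prediction approaches 1, which is why perplexity is often read as an effective number of choices.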

class flambe.metric.AUC(max_fpr=1.0)

Bases: flambe.metric.metric.Metric

compute(self, pred: torch.Tensor, target: torch.Tensor)

Compute AUC at the given max false positive rate.

Parameters:
  • pred (torch.Tensor) – The model predictions
  • target (torch.Tensor) – The binary targets
Returns:

The computed AUC

Return type:

torch.Tensor
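The AUC computation can be sketched in plain Python: build the ROC curve by sweeping the threshold down the sorted scores (tied scores form a single step), then integrate it with the trapezoidal rule up to `max_fpr`. How flambe normalizes the area when `max_fpr < 1` is not stated here; this sketch assumes the partial area is divided by `max_fpr` so that a perfect ranking scores 1.0. scikit-learn's `roc_auc_score(max_fpr=...)` instead applies the McClish standardization, so results can differ. `partial_auc` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def partial_auc(scores: Sequence[float], labels: Sequence[int], max_fpr: float = 1.0) -> float:
    """Area under the ROC curve for FPR in [0, max_fpr], divided by max_fpr."""
    if not 0.0 < max_fpr <= 1.0:
        raise ValueError("max_fpr must be in (0, 1]")
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need both positive and negative examples")

    # Walk thresholds from the highest score down, recording (FPR, TPR) points.
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    points: List[Tuple[float, float]] = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Consume every example tied at this score so ties form one ROC step.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / neg, tp / pos))

    # Trapezoidal integration, cut off at max_fpr.
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 >= max_fpr:
            break
        if x1 > max_fpr:
            # Linearly interpolate the TPR at the cut-off.
            y1 = y0 + (y1 - y0) * (max_fpr - x0) / (x1 - x0)
            x1 = max_fpr
        area += (x1 - x0) * (y0 + y1) / 2
    return area / max_fpr
```

With `max_fpr=1.0` this is the ordinary ROC AUC, equal to the fraction of positive/negative pairs ranked correctly (ties counting half).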

class flambe.metric.BinaryPrecision

Bases: flambe.metric.dev.binary.BinaryMetric

Compute binary precision.

An example is considered negative when its score is below the specified threshold. Binary precision is computed as follows:

` |True Positives| / (|True Positives| + |False Positives|) `

compute_binary(self, pred: torch.Tensor, target: torch.Tensor)

Compute binary precision.

Parameters:
  • pred (torch.Tensor) – Predictions made by the model. It should be a probability 0 <= p <= 1 for each sample, 1 being the positive class.
  • target (torch.Tensor) – Ground truth. Each label should be either 0 or 1.
Returns:

The computed binary metric

Return type:

torch.float
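Binary precision as described above can be sketched in plain Python. The default threshold of 0.5 and returning 0.0 when nothing is predicted positive are assumptions, not documented flambe behavior; `binary_precision` is a hypothetical name.

```python
from typing import Sequence


def binary_precision(pred: Sequence[float], target: Sequence[int], threshold: float = 0.5) -> float:
    """|TP| / (|TP| + |FP|), treating scores below threshold as negative."""
    tp = fp = 0
    for p, t in zip(pred, target):
        if p >= threshold:  # scores at or above the threshold count as positive
            if t == 1:
                tp += 1
            else:
                fp += 1
    # Assumption: report 0.0 when no example is predicted positive.
    return tp / (tp + fp) if tp + fp > 0 else 0.0
```

Only examples predicted positive enter the denominator, so missed positives (false negatives) do not lower precision.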

class flambe.metric.BinaryRecall

Bases: flambe.metric.dev.binary.BinaryMetric

Compute binary recall.

An example is considered negative when its score is below the specified threshold. Binary recall is computed as follows:

` |True Positives| / (|True Positives| + |False Negatives|) `

compute_binary(self, pred: torch.Tensor, target: torch.Tensor)

Compute binary recall.

Parameters:
  • pred (torch.Tensor) – Predictions made by the model. It should be a probability 0 <= p <= 1 for each sample, 1 being the positive class.
  • target (torch.Tensor) – Ground truth. Each label should be either 0 or 1.
Returns:

The computed binary metric

Return type:

torch.float
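Binary recall follows the same pattern, counting false negatives instead of false positives. As in the precision case, the 0.5 default threshold and the 0.0 result when there are no positive targets are assumptions; `binary_recall` is a hypothetical name.

```python
from typing import Sequence


def binary_recall(pred: Sequence[float], target: Sequence[int], threshold: float = 0.5) -> float:
    """|TP| / (|TP| + |FN|), treating scores below threshold as negative."""
    tp = fn = 0
    for p, t in zip(pred, target):
        if t == 1:
            if p >= threshold:  # positive correctly predicted positive
                tp += 1
            else:  # positive scored below the threshold
                fn += 1
    # Assumption: report 0.0 when the targets contain no positives.
    return tp / (tp + fn) if tp + fn > 0 else 0.0
```

On predictions `[0.9, 0.7, 0.4, 0.2]` with targets `[1, 1, 1, 0]`, recall is 2/3 because the positive scored 0.4 is missed, while precision on the same data would be 1.0.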