flambe.metric.dev.accuracy

Module Contents

class flambe.metric.dev.accuracy.Accuracy

Bases: flambe.metric.metric.Metric

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the accuracy of the predictions against the target.

Parameters:
  • pred (Tensor) – input logits of shape (B x N), where B is the batch size and N the number of classes
  • target (LongTensor) – target tensor of shape (B) or (B x N)
Returns:

accuracy – single label accuracy, of shape (B)

Return type:

torch.Tensor
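The parameter list above gives only shapes, so the following framework-free sketch illustrates one plausible reading of the computation. It assumes the predicted class is the argmax of each row of logits, and that a (B x N) target (for example, one-hot rows) is reduced to class indices by its own argmax. Plain Python lists stand in for torch tensors, and the helper names `argmax`, `per_example_correct`, and `accuracy` are hypothetical, not part of flambe's API. Treat it as an illustration under those assumptions, not flambe's implementation.

```python
from typing import List, Sequence, Union

Row = Sequence[float]


def argmax(row: Row) -> int:
    """Index of the largest value in a row (the first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def per_example_correct(pred: Sequence[Row],
                        target: Sequence[Union[int, Row]]) -> List[float]:
    """Length-B list holding 1.0 where the top-scoring class matches the target.

    pred:   B rows of N logits (the (B x N) case).
    target: B class indices (the (B) case) or B rows of N scores (the (B x N) case).
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same batch size B")
    hits = []
    for logits, t in zip(pred, target):
        # Assumption: a 2-D target row is collapsed to its argmax class index.
        label = t if isinstance(t, int) else argmax(t)
        hits.append(1.0 if argmax(logits) == label else 0.0)
    return hits


def accuracy(pred: Sequence[Row], target: Sequence[Union[int, Row]]) -> float:
    """Mean of the per-example hits: the fraction of correct predictions."""
    hits = per_example_correct(pred, target)
    return sum(hits) / len(hits)


if __name__ == "__main__":
    # B = 4 examples, N = 3 classes; predicted classes are 0, 1, 1, 2.
    pred = [[2.0, 0.5, -1.0], [0.1, 0.3, 0.2], [1.0, 3.0, 0.0], [0.0, 0.0, 5.0]]
    target_idx = [0, 2, 1, 2]                                   # shape (B)
    target_onehot = [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]  # shape (B x N)
    print(per_example_correct(pred, target_idx))  # [1.0, 0.0, 1.0, 1.0]
    print(accuracy(pred, target_idx))             # 0.75
    print(accuracy(pred, target_onehot))          # 0.75
```

Here `per_example_correct` yields a length-B vector like the one described under Returns, and `accuracy` averages it into a single batch-level score; both target layouts give the same result because each one-hot row reduces to the matching class index.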