flambe.metric.dev.recall

Module Contents

class flambe.metric.dev.recall.Recall(top_k: int = 1)

Bases: flambe.metric.metric.AverageableMetric

__str__(self)

Return the name of the Metric (for use in logging).

compute(self, pred: torch.Tensor, target: torch.Tensor)

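Compute recall@k, where k is the top_k value passed to the constructor. An example counts as recalled when its target label is among the top_k highest-scoring entries of its row in pred.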

Parameters:
  • pred (Tensor) – input logits of shape (B x N)
  • target (LongTensor) – target tensor of shape (B) or (B x N)
Returns:

recall – single-label recall, of shape (B)

Return type:

torch.Tensor
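The recall@k computation described above can be illustrated with a minimal, framework-free sketch that uses nested Python lists in place of tensors. The function name recall_at_k is hypothetical, and reducing a (B x N) target row to its argmax label is an assumption about how such targets are handled. This is not flambe's implementation.

```python
from typing import List, Sequence, Union


def recall_at_k(pred: Sequence[Sequence[float]],
                target: Sequence[Union[int, Sequence[float]]],
                top_k: int = 1) -> List[float]:
    """Hypothetical sketch: per-example recall@k as a list of length B.

    pred   -- B rows of N logits each (stands in for a B x N tensor)
    target -- B integer labels, or B rows of N scores (B x N)
    """
    scores: List[float] = []
    for row, tgt in zip(pred, target):
        # Assumption: a (B x N) target row is reduced to its argmax label.
        if isinstance(tgt, (list, tuple)):
            label = max(range(len(tgt)), key=tgt.__getitem__)
        else:
            label = tgt
        # Indices of the top_k largest logits in this row.
        top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:top_k]
        # 1.0 if the target label was recalled within the top_k, else 0.0.
        scores.append(1.0 if label in top else 0.0)
    return scores


pred = [[0.1, 2.0, 0.5], [1.5, 0.2, 0.9]]
print(recall_at_k(pred, [2, 0], top_k=1))  # [0.0, 1.0]
print(recall_at_k(pred, [2, 0], top_k=2))  # [1.0, 1.0]
```

Averaging the per-example scores yields a single batch-level recall (0.5 for top_k=1 in the example above). Passing one-hot targets such as [[0, 0, 1], [1, 0, 0]] gives the same result as the integer labels [2, 0].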