flambe.metric.loss.cross_entropy

Module Contents

class flambe.metric.loss.cross_entropy.MultiLabelCrossEntropy(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')

Bases: flambe.metric.metric.Metric
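The constructor arguments (weight, ignore_index, reduction) have the same names as those of torch.nn.CrossEntropyLoss. The sketch below assumes that `reduction` follows the same PyTorch convention: 'none' keeps one loss per example, 'mean' averages, and 'sum' adds them up. The helper name `apply_reduction` is hypothetical and is not part of flambe.

```python
from typing import List, Union

def apply_reduction(losses: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    # Hypothetical helper; the mode names are assumed to mirror
    # torch.nn.CrossEntropyLoss ('none' | 'mean' | 'sum').
    if reduction == "none":
        return list(losses)            # one loss per example, shape (B)
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"Unknown reduction: {reduction!r}")

per_example = [0.5, 1.5, 1.0]
print(apply_reduction(per_example, "mean"))  # 1.0
print(apply_reduction(per_example, "sum"))   # 3.0
print(apply_reduction(per_example, "none"))  # [0.5, 1.5, 1.0]
```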

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the multi-label cross-entropy loss.

Parameters:
  • pred (torch.Tensor) – input logits of shape (B x N), where B is the batch size and N is the number of classes
  • target (torch.LongTensor) – target tensor of shape (B x N)
Returns:

loss – Multi-label cross-entropy loss, of shape (B)

Return type:

torch.Tensor
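The docstring does not say how the (B x N) target is turned into a loss. The following plain-Python sketch shows one common formulation. Each multi-hot target row is L1-normalized into a distribution and scored against the log-softmax of the logits, which yields one value per example (shape (B)). Several details are assumptions and are not confirmed by the text above: this choice of formulation, treating `weight` as a per-class multiplier, and handling `ignore_index` by zeroing that class column. An alternative multi-label loss applies a per-class sigmoid with binary cross-entropy, and this sketch does not use it. The function names are hypothetical.

```python
import math
from typing import List, Optional, Sequence

def log_softmax(row: Sequence[float]) -> List[float]:
    # Shift by the max logit so exp() cannot overflow.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def multilabel_cross_entropy(
    pred: Sequence[Sequence[float]],      # logits, shape (B x N)
    target: Sequence[Sequence[int]],      # multi-hot labels, shape (B x N)
    weight: Optional[Sequence[float]] = None,
    ignore_index: Optional[int] = None,
) -> List[float]:
    """Hypothetical sketch: one soft-target cross-entropy value per example."""
    n = len(pred[0])
    w = list(weight) if weight is not None else [1.0] * n
    losses = []
    for logits, labels in zip(pred, target):
        labels = list(labels)  # copy so the caller's target is not mutated
        if ignore_index is not None:
            labels[ignore_index] = 0  # assumed: zero out the ignored class
        total = sum(labels)
        # L1-normalize the multi-hot row into a target distribution.
        q = [y / total if total > 0 else 0.0 for y in labels]
        logp = log_softmax(logits)
        losses.append(-sum(wi * qi * lp for wi, qi, lp in zip(w, q, logp)))
    return losses

pred = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
target = [[1, 1, 0], [0, 0, 1]]
losses = multilabel_cross_entropy(pred, target)
print(len(losses))  # 2: one loss per example, i.e. shape (B)
```

With uniform logits over three classes and a single positive label, the second example's loss is ln 3. A scalar result, as with reduction='mean', would then average these per-example values.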