flambe.metric.loss.nll_loss

Module Contents

class flambe.metric.loss.nll_loss.MultiLabelNLLLoss(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')

Bases: flambe.metric.metric.Metric

__str__(self)

Return the name of the Metric (for use in logging).

compute(self, pred: torch.Tensor, target: torch.Tensor)

Computes the negative log likelihood loss for multi-label targets.

Parameters:
  • pred (torch.Tensor) – input logits of shape (B x N)
  • target (torch.LongTensor) – target tensor of shape (B x N)
Returns:
  loss – Multi-label negative log likelihood loss, of shape (B)
Return type:
  torch.float
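
The arithmetic behind a multi-label negative log likelihood can be sketched without any framework. The sketch below is illustrative only and is not flambe's implementation. It assumes that the predictions are already log-probabilities (for example, the output of a log-softmax), that the target is a multi-hot matrix of the same shape, and that each example's loss is the weighted average of the negative log-probabilities of its positive labels. The function name multilabel_nll_sketch and the handling of rows with no positive labels are hypothetical choices made for this example.

```python
import math
from typing import List, Optional, Sequence


def multilabel_nll_sketch(
    log_probs: Sequence[Sequence[float]],
    target: Sequence[Sequence[int]],
    weight: Optional[Sequence[float]] = None,
) -> List[float]:
    """Illustrative per-example multi-label NLL; not flambe's implementation.

    Assumes `log_probs` holds log-probabilities of shape (B x N) and
    `target` is a multi-hot matrix of the same shape.
    """
    losses = []
    for row_lp, row_t in zip(log_probs, target):
        n_pos = sum(row_t)
        if n_pos == 0:
            # Assumption: an example with no positive labels contributes 0.
            losses.append(0.0)
            continue
        w = weight if weight is not None else [1.0] * len(row_lp)
        # Average the weighted negative log-probabilities of the positive labels.
        total = sum(w_j * t_j * lp_j for w_j, t_j, lp_j in zip(w, row_t, row_lp))
        losses.append(-total / n_pos)
    return losses


log_probs = [
    [math.log(0.5), math.log(0.25), math.log(0.25)],
    [math.log(0.25), math.log(0.25), math.log(0.5)],
]
target = [[1, 1, 0], [0, 0, 1]]

# One value per row, i.e. shape (B).
per_example = multilabel_nll_sketch(log_probs, target)
# What a 'mean' reduction over the batch would report.
mean_loss = sum(per_example) / len(per_example)
```

Under these assumptions, the first example has two positive labels, so its loss averages two negative log-probabilities. The second has a single positive label and reduces to the ordinary single-label NLL. With tensors, the same computation would normally be vectorized rather than looped.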