flambe.metric.dev.perplexity

Module Contents

class flambe.metric.dev.perplexity.Perplexity

Bases: flambe.metric.Metric

Token-level perplexity, computed as exp(cross_entropy).
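Written out as a formula, assuming the cross entropy is averaged over the B tokens of a batch (the usual mean reduction; this page does not say which reduction is used), with z = pred the (B x N) logits and y = target the (B) class indices:

```latex
\mathrm{PPL}(z, y) = \exp\left( -\frac{1}{B} \sum_{i=1}^{B}
    \log \frac{\exp(z_{i, y_i})}{\sum_{j=1}^{N} \exp(z_{i, j})} \right)
```

A uniform prediction over N classes gives a per-token cross entropy of log N and therefore PPL = N, which is why perplexity is often read as an effective number of equally likely choices.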

compute(self, pred: torch.Tensor, target: torch.Tensor)

Compute the perplexity given the input and target.

Parameters:
  • pred (torch.Tensor) – input logits of shape (B x N)
  • target (torch.LongTensor) – target tensor of shape (B)
Returns:

Output perplexity

Return type:

torch.float
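A minimal pure-Python sketch of the computation described above, for illustration only: it is not flambe's implementation, and the function name `perplexity` and the list-based inputs are assumptions standing in for the (B x N) logits and (B) targets. Under the default mean reduction it should agree with `torch.exp(torch.nn.functional.cross_entropy(pred, target))`.

```python
import math
from typing import List


def perplexity(pred: List[List[float]], target: List[int]) -> float:
    """Token-level perplexity: exp of the mean cross entropy.

    Hypothetical stand-in for Perplexity.compute:
    pred holds one row of N logits per token (shape B x N),
    target holds the correct class index for each row (shape B).
    """
    total_nll = 0.0
    for logits, label in zip(pred, target):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        # negative log-likelihood (cross entropy) of the target class
        total_nll += log_norm - logits[label]
    # average over the B tokens, then exponentiate
    return math.exp(total_nll / len(target))
```

For example, `perplexity([[0.0, 0.0, 0.0, 0.0]], [2])` returns approximately 4.0 (up to floating-point rounding), matching the uniform-prediction case: four equally likely classes give a perplexity of four.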

aggregate(self, state: dict, *args, **kwargs)

Aggregates by storing only the entropy of each sample.

Parameters:
  • state (dict) – the metric state
  • args – the (pred, target) tuple
Returns:

the state dict

Return type:

dict
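One way to realize "storing only the entropy of each sample" is sketched below. This is an assumption about the behaviour, not flambe's code: the `"entropies"` key is hypothetical, "sample" is taken to mean one token (one row of pred), and `pred`/`target` are passed explicitly instead of through `*args`.

```python
import math
from typing import Dict, List


def token_nlls(pred: List[List[float]], target: List[int]) -> List[float]:
    """Per-token cross entropy: negative log-likelihood of each target class."""
    out = []
    for logits, label in zip(pred, target):
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        out.append(log_norm - logits[label])
    return out


def aggregate(state: Dict, pred: List[List[float]], target: List[int]) -> Dict:
    """Append this batch's per-token entropies to the running state."""
    # "entropies" is a hypothetical key chosen for this sketch
    state.setdefault("entropies", []).extend(token_nlls(pred, target))
    return state
```

Keeping per-token values (rather than the full logits) keeps the state small: memory grows with the number of tokens, not with the vocabulary size N.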

finalize(self, state: Dict)

Finalizes the metric computation.

Parameters:
  • state (dict) – the metric state
Returns:

The final score.

Return type:

float
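A minimal sketch of a finalize step, assuming (hypothetically) that the state holds one cross-entropy value per token under an `"entropies"` key; the real state layout is not documented on this page.

```python
import math
from typing import Dict


def finalize(state: Dict) -> float:
    """Perplexity over everything aggregated: exp of the mean stored entropy."""
    # "entropies" is a hypothetical key holding one cross entropy per token
    entropies = state["entropies"]
    if not entropies:
        raise ValueError("no samples were aggregated")
    return math.exp(sum(entropies) / len(entropies))
```

For example, `finalize({"entropies": [math.log(2), math.log(8)]})` returns approximately 4.0. Averaging the entropies before exponentiating weights every token equally; averaging per-batch perplexities instead would generally give a different, larger value, because the mean of exponentials is at least the exponential of the mean.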