Source code for flambe.metric.dev.perplexity

import torch

from flambe.metric import Metric


[docs]class Perplexity(Metric): def __init__(self): """Initalizes the Metric.""" self.entropy = torch.nn.CrossEntropyLoss()
[docs] def compute(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute the preplexity given the input and target. Parameters ---------- pred: torch.Tensor input logits of shape (B x N) target: torch.LontTensor target tensor of shape (B x N) Returns ------- torch.float Output perplexity """ entropy = self.entropy(pred, target).mean() return torch.exp(entropy)