Source code for flambe.metric.dev.bpc

from typing import Dict

import numpy as np
import torch

from flambe.metric.dev.perplexity import Perplexity


[docs]class BPC(Perplexity): """Bits per character. Computed as log_2(perplexity) Inherits from Perplexity to share aggregate functionality. """
[docs] def compute(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute the bits per character given the input and target. Parameters ---------- pred: torch.Tensor input logits of shape (B x N) target: torch.LontTensor target tensor of shape (B) Returns ------- torch.float Output perplexity """ entropy = self.entropy(pred, target).mean() return torch.log2(torch.exp(entropy))
[docs] def finalize(self, state: Dict) -> float: """Finalizes the metric computation Parameters ---------- state: dict the metric state Returns ------- float The final score. """ if not state or state['sample_count'] == 0: # call on empty state return np.NaN return torch.log2(torch.exp(state['accumulated_score'] / state['sample_count'])).item()