Source code for flambe.field.label

from typing import Optional
import torch

from flambe.field.text import TextField
from flambe.tokenizer import LabelTokenizer


[docs]class LabelField(TextField): """Featurizes input labels. The class also handles multilabel inputs and one hot encoding. """ def __init__(self, one_hot: bool = False, multilabel_sep: Optional[str] = None) -> None: """Initializes the LabelFetaurizer. Parameters ---------- one_hot : bool, optional Set for one-hot encoded outputs, defaults to False multilabel_sep : str, optional If given, splits the input label into multiple labels using the given separator, defaults to None. """ self.one_hot = one_hot self.multilabel_sep = multilabel_sep tokenizer = LabelTokenizer(self.multilabel_sep) super().__init__(tokenizer=tokenizer, unk_token=None, pad_token=None)
[docs] def process(self, example): """Featurize a single example. Parameters ---------- example: str The input label Returns ------- torch.Tensor A list of integer tokens """ # First process normally n = super().process(example) if self.one_hot: n = [int(i in n) for i in range(len(self.vocab))] n = torch.tensor(n) # Back to Tensor return n