# Source code for flambe.field.label

from typing import Optional, Union, List, Iterable
from collections import OrderedDict as odict

import torch
import numpy as np

from flambe.field.field import Field
from flambe.tokenizer import LabelTokenizer


class LabelField(Field):
    """Featurizes input labels.

    The class also handles multilabel inputs and one hot encoding.
    """

    def __init__(self,
                 one_hot: bool = False,
                 multilabel_sep: Optional[str] = None,
                 labels: Optional[Union[Iterable[str], str]] = None) -> None:
        """Initialize the LabelField.

        Parameters
        ----------
        one_hot : bool, optional
            Set for one-hot encoded outputs, defaults to False
        multilabel_sep : str, optional
            If given, splits the input label into multiple labels
            using the given separator, defaults to None.
        labels: Union[Iterable[str], str], optional
            If given, sets the labels and the ordering is used to map
            the labels to indices. That means the first item in this
            list will have label id 0, the next one id 1, etc..
            When not provided, indices are assigned as labels are
            encountered during preprocessing. The list can also be
            provided as a file with a label on each line.

        Raises
        ------
        ValueError
            If ``labels`` is given but is neither a file path (str)
            nor an iterable of strings.
        """
        self.one_hot = one_hot
        self.multilabel_sep = multilabel_sep

        self.tokenizer = LabelTokenizer(multilabel_sep=self.multilabel_sep)

        if labels is not None:
            if isinstance(labels, str):
                # Labels provided as a file, one label per line
                with open(labels, 'r') as f:
                    label_list: List[str] = f.read().splitlines()
            elif isinstance(labels, Iterable):
                label_list = list(labels)
            else:
                # Previously this fell through and raised an opaque
                # NameError on `label_list`; fail fast with a clear message.
                raise ValueError("labels must be a file path or an iterable of strings.")
            self.label_given = True
            # Insertion order of the provided labels defines the label ids
            self.vocab = odict((label, i) for i, label in enumerate(label_list))
            self.label_count_dict = {label: 0 for label in self.vocab}
        else:
            self.label_given = False
            self.vocab = odict()
            self.label_count_dict = dict()

        self.register_attrs('vocab')
        self.register_attrs('label_count_dict')
[docs] def setup(self, *data: np.ndarray) -> None: """Build the vocabulary. Parameters ---------- data : Iterable[str] List of input strings. """ # Iterate over all examples examples = (e for dataset in data for e in dataset if dataset is not None) for example in examples: # Tokenize and add to vocabulary for token in self.tokenizer(example): if self.label_given: if token not in self.vocab: raise ValueError(f"Found label {token} not provided in label list.") else: self.label_count_dict[token] += 1 else: if token not in self.vocab: self.vocab[token] = len(self.vocab) self.label_count_dict[token] = 1 else: self.label_count_dict[token] += 1
[docs] def process(self, example): """Featurize a single example. Parameters ---------- example: str The input label Returns ------- torch.Tensor A list of integer tokens """ tokens = self.tokenizer(example) # Numericalize numericals = [] for token in tokens: if token not in self.vocab: raise ValueError("Encounterd out-of-vocabulary label {token}") numerical = self.vocab[token] # type: ignore numericals.append(numerical) out = torch.tensor(numericals).long() if self.one_hot: out = [int(i in out) for i in range(len(self.vocab))] out = torch.tensor(out).long() # Back to Tensor return out.squeeze()
@property
[docs] def vocab_list(self) -> List[str]: """Get the list of tokens in the vocabulary. Returns ------- List[str] The list of tokens in the vocabulary, ordered. """ return list(self.vocab.keys())
@property
[docs] def vocab_size(self) -> int: """Get the vocabulary length. Returns ------- int The length of the vocabulary """ return len(self.vocab)
@property
[docs] def label_count(self) -> torch.Tensor: """Get the label count. Returns ------- torch.Tensor Tensor containing the count for each label, indexed by the id of the label in the vocabulary. """ counts = [self.label_count_dict[label] for label in self.vocab] return torch.tensor(counts).float()
@property
[docs] def label_freq(self) -> torch.Tensor: """Get the frequency of each label. Returns ------- torch.Tensor Tensor containing the frequency of each label, indexed by the id of the label in the vocabulary. """ counts = [self.label_count_dict[label] for label in self.vocab] return torch.tensor(counts).float() / sum(counts)
@property
[docs] def label_inv_freq(self) -> torch.Tensor: """Get the inverse frequency for each label. Returns ------- torch.Tensor Tensor containing the inverse frequency of each label, indexed by the id of the label in the vocabulary. """ return 1. / self.label_freq # type: ignore