Source code for flambe.nlp.fewshot.model

# type: ignore[override]

from typing import Tuple, Dict, Any, Union, Optional

import torch
from torch import Tensor

from flambe.nn import Embedder, Module
from flambe.nn.distance import get_distance_module, get_mean_module


class PrototypicalTextClassifier(Module):
    """Implements a prototypical text classifier.

    The model embeds both query and support examples with a shared
    embedder, builds one prototype per class from the support
    encodings, and classifies queries by their distance to the
    prototypes.

    Attributes
    ----------
    embedder: Embedder
        the embedder used to encode query and support examples
    distance_module: Module
        computes distances between query encodings and prototypes
    mean_module: Module
        aggregates support encodings into a class prototype
    detach_mean: bool
        if True, the support set is detached from the computation
        graph before prototypes are computed

    """

    def __init__(self,
                 embedder: Embedder,
                 distance: str = 'euclidean',
                 detach_mean: bool = False) -> None:
        """Initialize the PrototypicalTextClassifier model.

        Parameters
        ----------
        embedder: Embedder
            The embedder layer
        distance: str, optional
            The distance metric to use, by default 'euclidean'
        detach_mean: bool, optional
            Whether to detach the support set before computing
            prototypes, by default False

        """
        super().__init__()

        self.embedder = embedder
        self.distance_module = get_distance_module(distance)
        self.mean_module = get_mean_module(distance)
        self.detach_mean = detach_mean
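    # A hedged sketch of the distance API assumed by this class, inferred from
    # its usage below rather than confirmed in this file: get_distance_module
    # is assumed to return a module mapping queries of shape [n, d] and
    # prototypes of shape [k, d] to a distance matrix of shape [n, k], and
    # get_mean_module a module that reduces a stack of encodings [m, d] to a
    # single prototype of shape [1, d] (squeezed to [d] in compute_prototypes).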
    def compute_prototypes(self, support: Tensor, label: Tensor) -> Tensor:
        """Compute the class prototypes used for classification.

        Parameters
        ----------
        support : torch.Tensor
            Support set encodings
        label : torch.Tensor
            Corresponding labels, expected to be contiguous integers
            starting at 0

        Returns
        -------
        torch.Tensor
            One prototype per class, stacked along the first dimension

        """
        # Group support encodings by class label
        means_dict: Dict[int, Any] = {}
        for i in range(support.size(0)):
            means_dict.setdefault(int(label[i]), []).append(support[i])

        means = []
        n_means = len(means_dict)

        for i in range(n_means):
            # Assumes the labels are the contiguous indices 0..n_means - 1
            supports = torch.stack(means_dict[i], dim=0)
            if supports.size(0) > 1:
                mean = self.mean_module(supports).squeeze(0)
            else:
                mean = supports.squeeze(0)
            means.append(mean)

        prototypes = torch.stack(means, dim=0)
        return prototypes
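    # A minimal illustration of the expected behaviour (not from the original
    # source; it assumes the euclidean mean_module is an arithmetic mean and
    # that labels are contiguous integers starting at 0):
    #
    #   support = torch.tensor([[1., 1.], [3., 3.], [10., 10.]])
    #   label = torch.tensor([0, 0, 1])
    #   model.compute_prototypes(support, label)
    #   # -> tensor([[ 2.,  2.],
    #   #            [10., 10.]])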
    def forward(self,  # type: ignore
                query: Tensor,
                query_label: Optional[Tensor] = None,
                support: Optional[Tensor] = None,
                support_label: Optional[Tensor] = None,
                prototypes: Optional[Tensor] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run a forward pass through the network.

        Parameters
        ----------
        query: Tensor
            The query examples to classify
        query_label: Tensor, optional
            Labels for the query examples; if given, they are returned
            alongside the predictions
        support: Tensor, optional
            Support examples used to compute prototypes
        support_label: Tensor, optional
            Labels for the support examples
        prototypes: Tensor, optional
            Precomputed prototypes; if given, the support set is ignored

        Returns
        -------
        Union[Tensor, Tuple[Tensor, Tensor]]
            The negated distances to the prototypes, used as class
            scores, together with the query labels if they were provided

        """
        query_encoding = self.embedder(query)
        if isinstance(query_encoding, tuple):  # RNN
            query_encoding = query_encoding[0]

        if prototypes is not None:
            # Use the precomputed prototypes as-is
            pass
        elif support is not None and support_label is not None:
            if self.detach_mean:
                support = support.detach()
                support_label = support_label.detach()  # type: ignore

            support_encoding = self.embedder(support)
            if isinstance(support_encoding, tuple):  # RNN
                support_encoding = support_encoding[0]

            # Compute prototypes from the support set
            prototypes = self.compute_prototypes(support_encoding, support_label)
        else:
            raise ValueError("No prototypes set or provided")

        dist = self.distance_module(query_encoding, prototypes)

        if query_label is not None:
            return -dist, query_label
        else:
            return -dist
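

# A rough usage sketch, not part of the original flambe source: it runs one
# few-shot episode end to end. ``ToyEmbedder`` and the shapes below are
# hypothetical stand-ins; in practice ``embedder`` would be a flambe.nn.Embedder
# that maps batches of token ids to fixed-size encodings.


class ToyEmbedder(torch.nn.Module):
    """Hypothetical embedder: token ids [batch, seq_len] -> encodings [batch, d]."""

    def __init__(self, vocab_size: int = 100, d: int = 16) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, d)

    def forward(self, data: Tensor) -> Tensor:
        # Average the token embeddings over the sequence dimension
        return self.embedding(data).mean(dim=1)


if __name__ == '__main__':
    model = PrototypicalTextClassifier(embedder=ToyEmbedder(), distance='euclidean')

    # One episode: 2 classes, 2 support examples per class, 3 queries
    support = torch.randint(0, 100, (4, 7))      # token ids, [n_support, seq_len]
    support_label = torch.tensor([0, 0, 1, 1])   # contiguous labels starting at 0
    query = torch.randint(0, 100, (3, 7))        # token ids, [n_query, seq_len]

    scores = model(query, support=support, support_label=support_label)
    predictions = scores.argmax(dim=1)           # scores are negated distances, [n_query, n_classes]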