flambe.nlp.fewshot

Package Contents

class flambe.nlp.fewshot.PrototypicalTextClassifier(embedder: Embedder, distance: str = 'euclidean', detach_mean: bool = False)

Bases: flambe.nn.Module

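Implements a prototypical text classifier for few-shot learning.

The classifier encodes inputs with an embedder module, represents each class by a prototype computed from the encodings of its labeled support examples, and classifies each query according to its distance to the prototypes (closer prototypes score higher).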
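A minimal sketch of the scoring rule described above, assuming PyTorch encodings of shape (n, dim) and the default Euclidean distance. `prototype_logits` is a hypothetical helper for illustration, not part of flambe:

```python
import torch


def prototype_logits(query_enc: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """Score each query against each prototype by negative Euclidean distance.

    query_enc:  (n_query, dim) encoded queries
    prototypes: (n_classes, dim) one prototype per class
    returns:    (n_query, n_classes) scores; higher means closer
    """
    return -torch.cdist(query_enc, prototypes)


# Two 2-D prototypes and two queries, each sitting near one of them.
prototypes = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
queries = torch.tensor([[0.5, -0.5], [4.0, 6.0]])
logits = prototype_logits(queries, prototypes)
predictions = logits.argmax(dim=1)  # index of the nearest prototype per query
```

Negating the distance turns "nearest prototype" into "largest score", so the result can be fed to standard tools such as `argmax` or a cross-entropy loss.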

Parameters:
  • embedder (Embedder) – The embedder used to encode both query and support inputs
  • distance (str) – Name of the distance used to compare query encodings with the prototypes (default: 'euclidean')
  • detach_mean (bool) – Whether to detach the computed prototype means from the computation graph (default: False)
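The `distance` string selects how query encodings are compared with prototypes. A minimal sketch of such a name-based lookup, assuming a hypothetical `get_distance` helper; only 'euclidean' is confirmed by the signature default, and the 'cosine' entry is an illustrative assumption rather than a documented flambe option:

```python
from typing import Callable, Dict

import torch
import torch.nn.functional as F

DistanceFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

# Hypothetical registry of pairwise distances, each mapping
# (n_query, dim) x (n_proto, dim) -> (n_query, n_proto).
# 'cosine' is an assumption included only to show how other options could plug in.
DISTANCES: Dict[str, DistanceFn] = {
    "euclidean": lambda q, p: torch.cdist(q, p),
    "cosine": lambda q, p: 1.0 - F.normalize(q, dim=1) @ F.normalize(p, dim=1).T,
}


def get_distance(name: str) -> DistanceFn:
    """Look up a pairwise distance function by name; unknown names raise ValueError."""
    try:
        return DISTANCES[name]
    except KeyError:
        raise ValueError(f"Unknown distance: {name!r}") from None
```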
compute_prototypes(self, support: Tensor, label: Tensor)

Set the current prototypes used for classification.

Parameters:
  • support (torch.Tensor) – Input encodings of the support examples
  • label (torch.Tensor) – Corresponding labels
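Prototypical networks typically take each class prototype to be the mean of that class's support encodings. A minimal sketch, assuming integer labels numbered contiguously from 0 (a class with no support examples would yield NaN here); this standalone `compute_prototypes` function is illustrative and may differ in details from flambe's method:

```python
import torch


def compute_prototypes(support: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Average the support encodings of each class into one prototype.

    support: (n_support, dim) encodings
    label:   (n_support,) integer class ids, assumed to be 0..n_classes-1
    returns: (n_classes, dim) prototypes; row k is the mean encoding of class k
    """
    n_classes = int(label.max().item()) + 1
    # Boolean mask selects the rows of class k; mean over those rows.
    return torch.stack([support[label == k].mean(dim=0) for k in range(n_classes)])


support = torch.tensor([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
label = torch.tensor([0, 0, 1])
prototypes = compute_prototypes(support, label)
```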
forward(self, query: Tensor, query_label: Optional[Tensor] = None, support: Optional[Tensor] = None, support_label: Optional[Tensor] = None, prototypes: Optional[Tensor] = None)

Run a forward pass through the network.

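Parameters:
  • query (Tensor) – The query inputs to classify
  • query_label (Optional[Tensor]) – Labels for the query inputs, if available
  • support (Optional[Tensor]) – Support inputs from which to compute the prototypes
  • support_label (Optional[Tensor]) – Labels for the support inputs
  • prototypes (Optional[Tensor]) – Precomputed prototypes to use instead of a support set
Returns: The output predictions for the query inputs, or a tuple of (predictions, query_label) when query_label is provided
Return type: Union[Tensor, Tuple[Tensor, Tensor]]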
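The optional arguments suggest that prototypes can either be passed in directly or derived from a labeled support set, and the Union return type suggests that query labels are passed through when given. A minimal sketch of that dispatch under those assumptions; `prototypical_forward` and the `embed` callable are hypothetical stand-ins for the method and the embedder, and the real method may handle additional cases:

```python
from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor


def prototypical_forward(
    embed: Callable[[Tensor], Tensor],
    query: Tensor,
    query_label: Optional[Tensor] = None,
    support: Optional[Tensor] = None,
    support_label: Optional[Tensor] = None,
    prototypes: Optional[Tensor] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Hypothetical stand-in for forward(); `embed` plays the role of the embedder."""
    if prototypes is None:
        if support is None or support_label is None:
            raise ValueError("Pass either prototypes or support with support_label")
        support_enc = embed(support)
        n_classes = int(support_label.max().item()) + 1
        # One prototype per class: the mean of that class's support encodings.
        prototypes = torch.stack(
            [support_enc[support_label == k].mean(dim=0) for k in range(n_classes)]
        )
    # Negative Euclidean distance: the nearest prototype gets the highest score.
    scores = -torch.cdist(embed(query), prototypes)
    if query_label is not None:
        return scores, query_label
    return scores


# Usage: an identity "embedder" and a two-class, one-shot support set.
embed = lambda x: x
support = torch.tensor([[0.0, 0.0], [10.0, 10.0]])
support_label = torch.tensor([0, 1])
query = torch.tensor([[9.0, 9.0]])
scores, target = prototypical_forward(embed, query, torch.tensor([1]), support, support_label)
```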