flambe.nlp.transformers.model

Module Contents

class flambe.nlp.transformers.model.PretrainedTransformerEmbedder(alias: str, cache_dir: Optional[str] = None, padding_idx: Optional[int] = None, pool: bool = False, **kwargs)

Bases: flambe.nn.Module

Embedder integration of the transformers library.

Instantiate this object using any alias available in the transformers library. More information can be found here:

https://huggingface.co/transformers/

forward(self, data: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None)

Perform a forward pass through the network.

If pool is True, only the pooled output of shape [B x H] is returned. Otherwise, the full sequence encoding of shape [B x S x H] is returned.

Parameters:
  • data (torch.Tensor) – The input data of shape [B x S]
  • token_type_ids (Optional[torch.Tensor], optional) – Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]: 0 corresponds to a sentence A token, 1 corresponds to a sentence B token. Has shape [B x S]
  • attention_mask (Optional[torch.Tensor], optional) – FloatTensor of shape [B x S]. Masked values should be 0 for padding tokens, 1 otherwise.
  • position_ids (Optional[torch.Tensor], optional) – Indices of the position of each input token in the position embedding. Defaults to the order given in the input. Has shape [B x S].
  • head_mask (Optional[torch.Tensor], optional) – Mask to nullify selected heads of the self-attention modules. Should be 0 for heads to mask, 1 otherwise. Has shape [num_layers x num_heads]
Returns:

If pool is True, returns a tensor of shape [B x H], else returns an encoding for each token in the sequence of shape [B x S x H].

Return type:

torch.Tensor
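The mask and segment conventions described in the parameters above can be illustrated with a minimal sketch in plain Python. The helper names pad_batch and segment_ids and the token ids are hypothetical and are not part of flambe or transformers; the resulting nested lists would still need to be converted to tensors of shape [B x S] before being passed to forward.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], padding_idx: int
) -> Tuple[List[List[int]], List[List[float]]]:
    """Pad token-id sequences to a common length and build the matching mask.

    Hypothetical helper: the mask holds 1.0 for real tokens and 0.0 for
    padding, following the attention_mask convention described above.
    """
    max_len = max(len(seq) for seq in sequences)
    data, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        data.append(seq + [padding_idx] * n_pad)
        attention_mask.append([1.0] * len(seq) + [0.0] * n_pad)
    return data, attention_mask


def segment_ids(len_a: int, len_b: int) -> List[int]:
    """Return 0 for each sentence A token and 1 for each sentence B token."""
    return [0] * len_a + [1] * len_b


# Illustrative token ids only; real ids come from the model's tokenizer.
data, attention_mask = pad_batch([[101, 7592, 102], [101, 102]], padding_idx=0)
token_type_ids = segment_ids(3, 2)
```

Both data and attention_mask have the same [B x S] layout, so each mask entry lines up with the token it describes.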

__getattr__(self, name: str)

Override __getattr__ to inspect the config.

Parameters:
  • name (str) – The attribute to fetch
Returns:

The attribute

Return type:

Any
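One common way to implement this kind of override is to fall back to a config object when normal attribute lookup fails. The sketch below shows that general pattern in plain Python; the class ConfigBackedModule and the config fields are hypothetical, and this is not necessarily how flambe implements the method.

```python
from types import SimpleNamespace
from typing import Any


class ConfigBackedModule:
    """Hypothetical stand-in for a module that exposes its config's fields."""

    def __init__(self, config: Any) -> None:
        self.config = config

    def __getattr__(self, name: str) -> Any:
        # Python calls __getattr__ only after normal lookup has failed.
        # Reading config through __dict__ avoids infinite recursion if
        # config itself has not been set yet.
        config = self.__dict__.get("config")
        if config is not None and hasattr(config, name):
            return getattr(config, name)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )


# Assumed config fields, chosen only for illustration.
module = ConfigBackedModule(SimpleNamespace(hidden_size=768, num_hidden_layers=12))
hidden = module.hidden_size
```

Attributes set directly on the object take precedence, since the fallback is consulted only when the regular lookup raises AttributeError.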