Source code for flambe.nn.embedding

import math
from typing import Tuple, Union, Optional

import torch
from torch import nn
from torch import Tensor

from flambe.compile import registrable_factory
from flambe.nn.module import Module

class Embeddings(Module):
    r"""Implement an Embeddings module.

    This object replicates the usage of nn.Embedding but
    registers the from_pretrained classmethod to be used inside
    a Flambé configuration, as this does not happen automatically
    during the registration of PyTorch objects.

    The module also adds optional positional encoding, which can
    either be sinusoidal or learned during training. For the
    non-learned positional embeddings, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where ``pos`` is the word position and ``i`` is the embed idx.

    Attributes
    ----------
    num_embeddings : int
        Size of the dictionary of embeddings.
    embedding_dim : int
        The size of each embedding vector.
    num_positions : int
        Number of rows of the positional embedding matrix.
    token_embedding : nn.Embedding
        The token embedding layer.
    pos_embedding : Optional[nn.Embedding]
        The positional embedding layer, or None when positional
        encoding is disabled.

    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: int = 0,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 positional_encoding: bool = False,
                 positional_learned: bool = False,
                 positonal_max_length: int = 5000) -> None:
        """Initialize an Embeddings module.

        Parameters
        ----------
        num_embeddings : int
            Size of the dictionary of embeddings.
        embedding_dim : int
            The size of each embedding vector.
        padding_idx : int, optional
            Pads the output with the embedding vector at
            :attr:`padding_idx` (initialized to zeros) whenever it
            encounters the index, by default 0
        max_norm : Optional[float], optional
            If given, each embedding vector with norm larger than
            :attr:`max_norm` is normalized to have norm :attr:`max_norm`
        norm_type : float, optional
            The p of the p-norm to compute for the :attr:`max_norm`
            option. Default ``2``.
        scale_grad_by_freq : bool, optional
            If given, this will scale gradients by the inverse of
            frequency of the words in the mini-batch. Default ``False``.
        sparse : bool, optional
            If ``True``, gradient w.r.t. :attr:`weight` matrix will
            be a sparse tensor. Default ``False``.
        positional_encoding : bool, optional
            If True, adds positional encoding to the token embeddings.
            By default, the embeddings are frozen sinusoidal embeddings.
            To learn these during training, set positional_learned.
            Default ``False``.
        positional_learned : bool, optional
            Learns the positional embeddings during training instead
            of using frozen sinusoidal ones. Default ``False``.
        positonal_max_length : int, optional
            The maximum length of a sequence used for the positional
            embedding matrix. Default ``5000``.

        Raises
        ------
        ValueError
            If ``positional_learned`` is set while
            ``positional_encoding`` is not.

        """
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_positions = positonal_max_length

        self.token_embedding = nn.Embedding(num_embeddings,
                                            embedding_dim,
                                            padding_idx,
                                            max_norm,
                                            norm_type,
                                            scale_grad_by_freq,
                                            sparse)

        self.pos_embedding = None
        if positional_learned and not positional_encoding:
            raise ValueError("positional_encoding is False, but positional_learned is True")
        elif positional_encoding and positional_learned:
            self.pos_embedding = nn.Embedding(positonal_max_length, embedding_dim)
        elif positional_encoding and not positional_learned:
            # Use frozen sinusoidal encoding
            table = self._sinusoidal_table(positonal_max_length, embedding_dim)
            self.pos_embedding = nn.Embedding.from_pretrained(table, freeze=True)

    @staticmethod
    def _sinusoidal_table(max_length: int, embedding_dim: int) -> Tensor:
        """Build the [max_length x embedding_dim] sinusoidal position table."""
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.arange(0, embedding_dim, 2).float()
        div_term = torch.exp(div_term * (-math.log(10000.0) / embedding_dim))
        angles = position * div_term

        table = torch.zeros(max_length, embedding_dim)
        table[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer odd column than
        # even column, so the cosine argument is truncated to fit.
        table[:, 1::2] = torch.cos(angles[:, :embedding_dim // 2])
        return table

    @registrable_factory
    @classmethod
    def from_pretrained(cls,
                        embeddings: Tensor,
                        freeze: bool = True,
                        padding_idx: int = 0,
                        max_norm: Optional[float] = None,
                        norm_type: float = 2.0,
                        scale_grad_by_freq: bool = False,
                        sparse: bool = False,
                        positional_encoding: bool = False,
                        positional_learned: bool = False,
                        positonal_max_length: int = 5000,
                        positonal_embeddings: Optional[Tensor] = None,
                        positonal_freeze: bool = True) -> 'Embeddings':
        """Create an Embeddings instance from pretrained embeddings.

        Parameters
        ----------
        embeddings: torch.Tensor
            FloatTensor containing weights for the Embedding.
            First dimension is being passed to Embedding as
            num_embeddings, second as embedding_dim.
        freeze: bool
            If True, the tensor does not get updated in the learning
            process. Default: True
        padding_idx : int, optional
            Pads the output with the embedding vector at
            :attr:`padding_idx` (initialized to zeros) whenever it
            encounters the index, by default 0
        max_norm : Optional[float], optional
            If given, each embedding vector with norm larger than
            :attr:`max_norm` is normalized to have norm :attr:`max_norm`
        norm_type : float, optional
            The p of the p-norm to compute for the :attr:`max_norm`
            option. Default ``2``.
        scale_grad_by_freq : bool, optional
            If given, this will scale gradients by the inverse of
            frequency of the words in the mini-batch. Default ``False``.
        sparse : bool, optional
            If ``True``, gradient w.r.t. :attr:`weight` matrix will
            be a sparse tensor. Default ``False``.
        positional_encoding : bool, optional
            If True, adds positional encoding to the token embeddings.
            By default, the embeddings are frozen sinusoidal embeddings.
            To learn these during training, set positional_learned.
            Forced to True when ``positonal_embeddings`` is given.
            Default ``False``.
        positional_learned : bool, optional
            Learns the positional embeddings during training instead
            of using frozen sinusoidal ones. Default ``False``.
        positonal_max_length : int, optional
            The maximum length of a sequence used for the positional
            embedding matrix. Ignored when ``positonal_embeddings``
            is given. Default ``5000``.
        positonal_embeddings: torch.Tensor, optional
            If given, also replaces the positional embeddings with
            this matrix. The max length will be ignored and replaced
            by the first dimension of this matrix; the second
            dimension must match the one of ``embeddings``.
        positonal_freeze: bool, optional
            Whether the positional embeddings should be frozen.
            Only used when ``positonal_embeddings`` is given.

        Returns
        -------
        Embeddings
            The module holding the given pretrained weights.

        Raises
        ------
        ValueError
            If one of the given matrices is not 2-dimensional, or if
            their embedding dimensions differ.

        """
        if embeddings.dim() != 2:
            raise ValueError('Embeddings parameter is expected to be 2-dimensional')
        if positonal_embeddings is not None:
            if positonal_embeddings.dim() != 2:
                raise ValueError('Positonal embeddings parameter is expected to be 2-dimensional')
            # Only the embedding dimension has to agree: the number of
            # rows is the vocabulary size for one matrix and the
            # maximum sequence length for the other.
            if positonal_embeddings.size(1) != embeddings.size(1):
                raise ValueError('Both pretrained matrices must have the same embedding dimension')
            positonal_max_length = positonal_embeddings.size(0)

        rows, cols = embeddings.shape
        positional_encoding = positional_encoding or (positonal_embeddings is not None)

        embedding = cls(num_embeddings=rows,
                        embedding_dim=cols,
                        padding_idx=padding_idx,
                        max_norm=max_norm,
                        norm_type=norm_type,
                        scale_grad_by_freq=scale_grad_by_freq,
                        sparse=sparse,
                        positional_encoding=positional_encoding,
                        positional_learned=positional_learned,
                        positonal_max_length=positonal_max_length)

        embedding.token_embedding.weight.data = embeddings
        embedding.token_embedding.weight.requires_grad = not freeze

        if positonal_embeddings is not None:
            embedding.pos_embedding.weight.data = positonal_embeddings  # type: ignore
            embedding.pos_embedding.weight.requires_grad = not positonal_freeze  # type: ignore

        return embedding

    def forward(self, data: Tensor) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        data : Tensor
            The input tensor of shape [S x B]

        Returns
        -------
        Tensor
            The output tensor of shape [S x B x E]

        """
        out = self.token_embedding(data)

        if self.pos_embedding is not None:
            # Row s of `positions` holds the index s for every batch
            # column; it is created with the dtype and device of `data`.
            column = torch.arange(data.size(0), dtype=data.dtype, device=data.device).unsqueeze(1)
            positions = column.repeat(1, data.size(1))
            out = out + self.pos_embedding(positions)

        return out
class Embedder(Module):
    """Implements an Embedder module.

    An Embedder takes as input a sequence of index tokens,
    and computes the corresponding embedded representations, and
    padding mask. The embedded sequence is then handed to the
    wrapped encoder, and optionally reduced by a pooling module.

    Attributes
    ----------
    embedding: Module
        The embedding module
    dropout: nn.Dropout
        The dropout layer applied to the embedded input
    encoder: Module
        The sub-encoder that this object is wrapping
    pooling: Optional[Module]
        An optional pooling module
    padding_idx: Optional[int]
        The index used to compute the padding mask, if any

    """

    def __init__(self,
                 embedding: Module,
                 encoder: Module,
                 pooling: Optional[Module] = None,
                 embedding_dropout: float = 0,
                 padding_idx: Optional[int] = 0) -> None:
        """Initializes the Embedder module.

        Parameters
        ----------
        embedding: Module
            The embedding layer
        encoder: Module
            The encoder
        pooling: Module, optional
            An optional pooling module, applied to the output
            of the encoder.
        embedding_dropout: float, optional
            Amount of dropout between the embeddings and the encoder
        padding_idx: int, optional
            Input positions equal to this index are marked as padding
            in the mask given to the encoder. If None, no mask is
            computed and the encoder only receives the embedded input.

        """
        super().__init__()

        self.embedding = embedding
        self.dropout = nn.Dropout(embedding_dropout)
        self.encoder = encoder
        self.pooling = pooling
        self.padding_idx = padding_idx

    def forward(self, data: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Performs a forward pass through the network.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a tensor of token indices
            of shape [S x B]

        Returns
        -------
        Union[Tensor, Tuple[Tensor, Tensor]]
            The output of the encoder, which may be a tuple when
            no pooling is provided, or the pooled encoding otherwise.

        """
        embedded = self.dropout(self.embedding(data))

        if self.padding_idx is None:
            encoding = self.encoder(embedded)
        else:
            # Non-zero wherever the input is an actual token
            padding_mask = (data != self.padding_idx).byte()
            encoding = self.encoder(embedded, padding_mask=padding_mask)

        if self.pooling is None:
            return encoding

        # Only the first element of a tuple output is pooled; the
        # remaining elements (e.g. encoder states) are dropped.
        if isinstance(encoding, tuple):
            encoding = encoding[0]
        return self.pooling(encoding)