Source code for flambe.nn.transformer

# type: ignore[override]

"""
Code taken from the PyTorch source code. Slightly modified to improve
the interface to the TransformerEncoder and TransformerDecoder modules.

"""
import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from flambe.nn import Module


class Transformer(Module):
    """A Transformer model.

    Users are able to modify the attributes as needed. The
    architecture is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin.
    2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    """

    def __init__(self,
                 input_size: int,
                 d_model: int = 512,
                 nhead: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """Initialize the Transformer Model.

        Parameters
        ----------
        input_size : int
            The dimension of the input embeddings. If different from
            d_model, a linear layer is added to project from
            input_size to d_model.
        d_model : int, optional
            The number of expected features in the encoder/decoder
            inputs (default=512).
        nhead : int, optional
            The number of heads in the multihead attention models
            (default=8).
        num_encoder_layers : int, optional
            The number of sub-encoder-layers in the encoder
            (default=6).
        num_decoder_layers : int, optional
            The number of sub-decoder-layers in the decoder
            (default=6).
        dim_feedforward : int, optional
            The dimension of the feedforward network model
            (default=2048).
        dropout : float, optional
            The dropout value (default=0.1).

        """
        super().__init__()

        self.input_size = input_size
        self.d_model = d_model

        self.encoder = TransformerEncoder(input_size=input_size,
                                          d_model=d_model,
                                          nhead=nhead,
                                          num_layers=num_encoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout)
        self.decoder = TransformerDecoder(input_size=input_size,
                                          d_model=d_model,
                                          nhead=nhead,
                                          num_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout)

    def forward(self,  # type: ignore
                src: torch.Tensor,
                tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None,
                tgt_key_padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Take in and process masked source/target sequences.

        Parameters
        ----------
        src: torch.Tensor
            The sequence to the encoder (required).
            shape: :math:`(N, S, E)`.
        tgt: torch.Tensor
            The sequence to the decoder (required).
            shape: :math:`(N, T, E)`.
        src_mask: torch.Tensor, optional
            The additive mask for the src sequence (optional).
            shape: :math:`(S, S)`.
        tgt_mask: torch.Tensor, optional
            The additive mask for the tgt sequence (optional).
            shape: :math:`(T, T)`.
        memory_mask: torch.Tensor, optional
            The additive mask for the encoder output (optional).
            shape: :math:`(T, S)`.
        src_key_padding_mask: torch.Tensor, optional
            The ByteTensor mask for src keys per batch (optional).
            shape: :math:`(N, S)`.
        tgt_key_padding_mask: torch.Tensor, optional
            The ByteTensor mask for tgt keys per batch (optional).
            shape: :math:`(N, T)`.
        memory_key_padding_mask: torch.Tensor, optional
            The ByteTensor mask for memory keys per batch (optional).
            shape: :math:`(N, S)`.

        Returns
        -------
        output: torch.Tensor
            The output sequence, shape: :math:`(N, T, E)`.

        Notes
        -----
        [src/tgt/memory]_mask should be filled with float('-inf') for
        the masked positions and float(0.0) elsewhere. These masks
        ensure that predictions for position i depend only on the
        unmasked positions j and are applied identically for each
        sequence in a batch.

        [src/tgt/memory]_key_padding_mask should be a ByteTensor where
        False values are positions that should be masked with
        float('-inf') and True values will be left unchanged. This
        mask ensures that no information will be taken from position i
        if it is masked, and has a separate mask for each sequence in
        a batch.

        Due to the multi-head attention architecture in the
        transformer model, the output sequence length of a transformer
        is the same as the input sequence (i.e. target) length of the
        decoder.

        S is the source sequence length, T is the target sequence
        length, N is the batch size, and E is the feature number.

        """
        if src.size(0) != tgt.size(0):
            raise RuntimeError("the batch size of src and tgt must be equal")
        if src.size(2) != self.input_size or tgt.size(2) != self.input_size:
            raise RuntimeError("the feature number of src and tgt must be equal to input_size")

        memory = self.encoder(src, mask=src_mask, padding_mask=src_key_padding_mask)
        output = self.decoder(tgt,
                              memory,
                              tgt_mask=tgt_mask,
                              memory_mask=memory_mask,
                              padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

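# Usage sketch (illustrative, not part of the original module; the names and
# sizes below are hypothetical). Build a small Transformer and run a forward
# pass with batch-first tensors:
#
#     model = Transformer(input_size=512, d_model=512, nhead=8,
#                         num_encoder_layers=2, num_decoder_layers=2)
#     src = torch.randn(16, 10, 512)   # (N, S, input_size)
#     tgt = torch.randn(16, 12, 512)   # (N, T, input_size)
#     out = model(src, tgt)            # (N, T, d_model)
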
class TransformerEncoder(Module):
    """TransformerEncoder is a stack of N encoder layers."""

    def __init__(self,
                 input_size: int = 512,
                 d_model: int = 512,
                 nhead: int = 8,
                 num_layers: int = 6,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """Initialize the TransformerEncoder.

        Parameters
        ----------
        input_size : int, optional
            The embedding dimension of the model. If different from
            d_model, a linear projection layer is added.
            Default ``512``.
        d_model : int, optional
            The number of expected features in the encoder/decoder
            inputs. Default ``512``.
        nhead : int, optional
            The number of heads in the multihead attention.
            Default ``8``.
        num_layers : int, optional
            The number of sub-encoder-layers in the encoder.
            Default ``6``.
        dim_feedforward : int, optional
            The inner feedforward dimension. Default ``2048``.
        dropout : float, optional
            The dropout percentage. Default ``0.1``.

        """
        super().__init__()

        self.input_size = input_size
        self.d_model = d_model

        if input_size != d_model:
            self.proj = nn.Linear(input_size, d_model)

        layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers

        self._reset_parameters()

    def forward(self,  # type: ignore
                src: torch.Tensor,
                memory: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pass the input through the encoder layers in turn.

        Parameters
        ----------
        src: torch.Tensor
            The sequence to the encoder (required).
        memory: torch.Tensor, optional
            Optional memory, unused by default.
        mask: torch.Tensor, optional
            The mask for the src sequence (optional).
        padding_mask: torch.Tensor, optional
            The mask for the src keys per batch (optional).
            Should be True for tokens to leave untouched, and False
            for padding tokens.

        """
        output = src.transpose(0, 1)

        if self.input_size != self.d_model:
            output = self.proj(output)

        for i in range(self.num_layers):
            output = self.layers[i](output,
                                    memory=memory,
                                    src_mask=mask,
                                    padding_mask=padding_mask)

        return output.transpose(0, 1)

    def _reset_parameters(self):
        """Initialize parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

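# Usage sketch (illustrative; names and sizes are hypothetical). Encode a
# padded batch: the padding mask is True for real tokens and False for
# padding, matching the docstring above:
#
#     encoder = TransformerEncoder(input_size=300, d_model=512, nhead=8, num_layers=2)
#     src = torch.randn(4, 7, 300)                                # (N, S, input_size)
#     lengths = torch.tensor([7, 5, 3, 2])
#     padding_mask = torch.arange(7)[None, :] < lengths[:, None]  # (N, S), True = keep
#     encoded = encoder(src, padding_mask=padding_mask)           # (N, S, d_model)
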
class TransformerDecoder(Module):
    """TransformerDecoder is a stack of N decoder layers."""

    def __init__(self,
                 input_size: int,
                 d_model: int,
                 nhead: int,
                 num_layers: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """Initialize the TransformerDecoder.

        Parameters
        ----------
        input_size : int
            The embedding dimension of the model. If different from
            d_model, a linear projection layer is added.
        d_model : int
            The number of expected features in encoder/decoder inputs.
        nhead : int
            The number of heads in the multihead attention.
        num_layers : int
            The number of sub-decoder-layers in the decoder.
        dim_feedforward : int, optional
            The inner feedforward dimension, by default 2048.
        dropout : float, optional
            The dropout percentage, by default 0.1.

        """
        super().__init__()

        self.input_size = input_size
        self.d_model = d_model

        if input_size != d_model:
            self.proj = nn.Linear(input_size, d_model)

        layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers

        self._reset_parameters()

    def forward(self,  # type: ignore
                tgt: torch.Tensor,
                memory: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None,
                padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pass the inputs (and masks) through the decoder layers in turn.

        Parameters
        ----------
        tgt: torch.Tensor
            The sequence to the decoder (required), batch first.
        memory: torch.Tensor
            The sequence from the last layer of the encoder (required),
            batch first.
        tgt_mask: torch.Tensor, optional
            The mask for the tgt sequence (optional).
        memory_mask: torch.Tensor, optional
            The mask for the memory sequence (optional).
        padding_mask: torch.Tensor, optional
            The mask for the tgt keys per batch (optional).
            Should be True for tokens to leave untouched, and False
            for padding tokens.
        memory_key_padding_mask: torch.Tensor, optional
            The mask for the memory keys per batch (optional).

        Returns
        -------
        torch.Tensor

        """
        # tgt and memory come in batch first; the layers operate
        # sequence first, so transpose here and transpose back at the end.
        output = tgt.transpose(0, 1)
        memory = memory.transpose(0, 1)

        if self.input_size != self.d_model:
            output = self.proj(output)

        for i in range(self.num_layers):
            output = self.layers[i](output,
                                    memory,
                                    tgt_mask=tgt_mask,
                                    memory_mask=memory_mask,
                                    padding_mask=padding_mask,
                                    memory_key_padding_mask=memory_key_padding_mask)

        return output.transpose(0, 1)

    def _reset_parameters(self):
        """Initialize parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

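# Usage sketch (illustrative; names and sizes are hypothetical). Decode
# against a precomputed memory; both tgt and memory are batch first, matching
# TransformerEncoder's output:
#
#     decoder = TransformerDecoder(input_size=512, d_model=512, nhead=8, num_layers=2)
#     tgt = torch.randn(4, 6, 512)     # (N, T, input_size)
#     memory = torch.randn(4, 9, 512)  # (N, S, d_model)
#     decoded = decoder(tgt, memory)   # (N, T, d_model)
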
class TransformerEncoderLayer(Module):
    """TransformerEncoderLayer is made up of self-attention and feedforward.

    This standard encoder layer is based on the paper "Attention Is
    All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob
    Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia
    Polosukhin. 2017. Attention is all you need. In Advances in Neural
    Information Processing Systems, pages 6000-6010. Users may modify
    or implement it in a different way during application.

    """

    def __init__(self,
                 d_model: int,
                 nhead: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """Initialize a TransformerEncoderLayer.

        Parameters
        ----------
        d_model : int
            The number of expected features in the input.
        nhead : int
            The number of heads in the multihead attention models.
        dim_feedforward : int, optional
            The dimension of the feedforward network (default=2048).
        dropout : float, optional
            The dropout value (default=0.1).

        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,  # type: ignore
                src: torch.Tensor,
                memory: Optional[torch.Tensor] = None,
                src_mask: Optional[torch.Tensor] = None,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pass the input through the encoder layer.

        Parameters
        ----------
        src: torch.Tensor
            The sequence to the encoder layer (required).
        memory: torch.Tensor, optional
            Optional memory from a previous sequence, unused by default.
        src_mask: torch.Tensor, optional
            The mask for the src sequence (optional).
        padding_mask: torch.Tensor, optional
            The mask for the src keys per batch (optional).
            Should be True for tokens to leave untouched, and False
            for padding tokens.

        Returns
        -------
        torch.Tensor
            Output tensor of shape [S x B x H]

        """
        # Invert the padding mask: nn.MultiheadAttention expects True
        # for positions that should be ignored.
        if padding_mask is not None:
            padding_mask = ~padding_mask.bool()

        src2 = self.self_attn(src, src, src,
                              attn_mask=src_mask,
                              key_padding_mask=padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

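# Usage sketch (illustrative; names and sizes are hypothetical). A single
# encoder layer operates sequence first, unlike the batch-first
# TransformerEncoder stack that wraps it:
#
#     layer = TransformerEncoderLayer(d_model=512, nhead=8)
#     x = torch.randn(10, 4, 512)   # (S, N, d_model)
#     y = layer(x)                  # (S, N, d_model)
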
class TransformerDecoderLayer(Module):
    """A TransformerDecoderLayer.

    A TransformerDecoderLayer is made up of self-attention,
    multi-head attention and a feedforward network. This standard
    decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion
    Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017.
    Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    """

    def __init__(self,
                 d_model: int,
                 nhead: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """Initialize a TransformerDecoderLayer.

        Parameters
        ----------
        d_model : int
            The number of expected features in the input.
        nhead : int
            The number of heads in the multihead attention models.
        dim_feedforward : int, optional
            The dimension of the feedforward network (default=2048).
        dropout : float, optional
            The dropout value (default=0.1).

        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self,  # type: ignore
                tgt: torch.Tensor,
                memory: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None,
                padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Pass the inputs (and masks) through the decoder layer.

        Parameters
        ----------
        tgt: torch.Tensor
            The sequence to the decoder layer (required).
        memory: torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask: torch.Tensor, optional
            The mask for the tgt sequence (optional).
        memory_mask: torch.Tensor, optional
            The mask for the memory sequence (optional).
        padding_mask: torch.Tensor, optional
            The mask for the tgt keys per batch (optional).
            Should be True for tokens to leave untouched, and False
            for padding tokens.
        memory_key_padding_mask: torch.Tensor, optional
            The mask for the memory keys per batch (optional).

        Returns
        -------
        torch.Tensor
            Output tensor of shape [T x B x H]

        """
        # Invert the padding mask: nn.MultiheadAttention expects True
        # for positions that should be ignored.
        if padding_mask is not None:
            padding_mask = ~padding_mask.bool()

        tgt2 = self.self_attn(tgt, tgt, tgt,
                              attn_mask=tgt_mask,
                              key_padding_mask=padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        tgt2 = self.multihead_attn(tgt, memory, memory,
                                   attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

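# Usage sketch (illustrative; names and sizes are hypothetical). A single
# decoder layer is also sequence first; tgt and memory are
# (length, batch, d_model) here:
#
#     layer = TransformerDecoderLayer(d_model=512, nhead=8)
#     tgt = torch.randn(6, 4, 512)   # (T, N, d_model)
#     mem = torch.randn(9, 4, 512)   # (S, N, d_model)
#     out = layer(tgt, mem)          # (T, N, d_model)
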
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    r"""Generate a square mask for the sequence.

    The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).

    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).t()
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
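# Illustrative example: a causal mask for a length-3 target sequence. Row i
# allows attention only to positions <= i; such a mask can be passed as
# tgt_mask to the decoder:
#
#     generate_square_subsequent_mask(3)
#     # tensor([[0., -inf, -inf],
#     #         [0., 0., -inf],
#     #         [0., 0., 0.]])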