Source code for flambe.nn.transformer_sru
# type: ignore[override]
import copy
from typing import Optional, Dict, Any, Tuple
import torch
import torch.nn as nn
from sru import SRUCell
from flambe.nn import Module
class TransformerSRU(Module):
"""A Transformer with an SRU replacing the FFN."""
def __init__(self,
input_size: int = 512,
d_model: int = 512,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
sru_dropout: Optional[float] = None,
                 bidirectional: bool = False,
**kwargs: Dict[str, Any]) -> None:
"""Initialize the TransformerSRU Model.
Parameters
----------
input_size : int, optional
            dimension of embeddings (default=512). If different from
            d_model, a linear layer is added to project from
            input_size to d_model.
d_model : int, optional
the number of expected features in the
encoder/decoder inputs (default=512).
nhead : int, optional
the number of heads in the multiheadattention
models (default=8).
num_encoder_layers : int, optional
the number of sub-encoder-layers in the encoder
(default=6).
num_decoder_layers : int, optional
the number of sub-decoder-layers in the decoder
(default=6).
dim_feedforward : int, optional
the dimension of the feedforward network model
(default=2048).
dropout : float, optional
the dropout value (default=0.1).
sru_dropout: float, optional
Dropout for the SRU cell. If not given, uses the same
dropout value as the rest of the transformer.
        bidirectional: bool, optional
            Whether the SRU encoder module should be bidirectional.
            Default ``False``.
Extra keyword arguments are passed to the SRUCell.
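
        Example
        -------
        A minimal construction sketch; the argument values below are
        illustrative assumptions, not required defaults:

        >>> model = TransformerSRU(input_size=256,
        ...                        d_model=512,
        ...                        nhead=8,
        ...                        num_encoder_layers=2,
        ...                        num_decoder_layers=2)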
"""
super().__init__()
        self.input_size = input_size
        self.d_model = d_model
        self.encoder = TransformerSRUEncoder(input_size,
                                             d_model,
                                             nhead,
                                             num_encoder_layers,
                                             dim_feedforward,
                                             dropout,
                                             sru_dropout,
                                             bidirectional,
                                             **kwargs)
        self.decoder = TransformerSRUDecoder(input_size,
                                             d_model,
                                             nhead,
                                             num_decoder_layers,
                                             dim_feedforward,
                                             dropout,
                                             sru_dropout,
                                             **kwargs)
    def forward(self,  # type: ignore
src: torch.Tensor,
tgt: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Take in and process masked source/target sequences.
Parameters
----------
src: torch.Tensor
the sequence to the encoder (required).
shape: :math:`(N, S, E)`.
tgt: torch.Tensor
the sequence to the decoder (required).
shape: :math:`(N, T, E)`.
src_mask: torch.Tensor, optional
the additive mask for the src sequence (optional).
shape: :math:`(S, S)`.
tgt_mask: torch.Tensor, optional
the additive mask for the tgt sequence (optional).
shape: :math:`(T, T)`.
memory_mask: torch.Tensor, optional
the additive mask for the encoder output (optional).
shape: :math:`(T, S)`.
src_key_padding_mask: torch.Tensor, optional
the ByteTensor mask for src keys per batch (optional).
shape: :math:`(N, S)`.
tgt_key_padding_mask: torch.Tensor, optional
the ByteTensor mask for tgt keys per batch (optional).
shape: :math:`(N, T)`.
memory_key_padding_mask: torch.Tensor, optional
the ByteTensor mask for memory keys per batch (optional).
shape" :math:`(N, S)`.
Returns
-------
        output: torch.Tensor
            The output sequence, shape: :math:`(N, T, E)`.

        Note: [src/tgt/memory]_mask should be filled with
        float('-inf') for the masked positions and float(0.0)
        otherwise. These masks ensure that predictions for position i
        depend only on the unmasked positions j and are applied
        identically for each sequence in a batch.
        [src/tgt/memory]_key_padding_mask should be a ByteTensor
        where False values are positions that should be masked with
        float('-inf') and True values will be unchanged.
        This mask ensures that no information will be taken from
        position i if it is masked, and has a separate mask for each
        sequence in a batch.

        Note: Due to the multi-head attention architecture in the
        transformer model, the output sequence length of a
        transformer is the same as the input (i.e. target) sequence
        length of the decoder.

        S is the source sequence length, T is the target sequence
        length, N is the batch size, and E is the feature dimension.
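
        Example
        -------
        A hedged usage sketch; the tensor sizes and the causal mask
        construction below are illustrative assumptions, not part of
        this module:

        >>> model = TransformerSRU()
        >>> src = torch.randn(2, 5, 512)    # (N, S, E)
        >>> tgt = torch.randn(2, 3, 512)    # (N, T, E)
        >>> t = tgt.size(1)
        >>> # standard additive causal mask over target positions
        >>> tgt_mask = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
        >>> out = model(src, tgt, tgt_mask=tgt_mask)  # -> (N, T, E)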
"""
        if src.size(0) != tgt.size(0):
            raise RuntimeError("the batch number of src and tgt must be equal")
        if src.size(2) != self.input_size or tgt.size(2) != self.input_size:
            raise RuntimeError("the feature number of src and tgt must be equal to input_size")
memory, state = self.encoder(src,
mask=src_mask,
padding_mask=src_key_padding_mask)
output = self.decoder(tgt,
memory,
state=state,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
return output
class TransformerSRUEncoder(Module):
"""A TransformerSRUEncoder with an SRU replacing the FFN."""
def __init__(self,
input_size: int = 512,
d_model: int = 512,
nhead: int = 8,
num_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
sru_dropout: Optional[float] = None,
bidirectional: bool = False,
**kwargs: Dict[str, Any]) -> None:
"""Initialize the TransformerEncoder.
Parameters
---------
input_size : int
The embedding dimension of the model. If different from
d_model, a linear projection layer is added.
d_model : int
the number of expected features in encoder/decoder inputs.
Default ``512``.
        nhead : int, optional
            the number of heads in the multi-head attention models.
            Default ``8``.
        num_layers : int, optional
            the number of sub-encoder-layers in the encoder.
            Default ``6``.
dim_feedforward : int, optional
            the inner feedforward dimension. Default ``2048``.
dropout : float, optional
the dropout percentage. Default ``0.1``.
sru_dropout: float, optional
Dropout for the SRU cell. If not given, uses the same
dropout value as the rest of the transformer.
bidirectional: bool
            Whether the SRU module should be bidirectional.
            Default ``False``.
Extra keyword arguments are passed to the SRUCell.
"""
super().__init__()
self.input_size = input_size
self.d_model = d_model
if input_size != d_model:
self.proj = nn.Linear(input_size, d_model)
layer = TransformerSRUEncoderLayer(d_model,
nhead,
dim_feedforward,
dropout,
sru_dropout,
bidirectional)
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
self.num_layers = num_layers
self._reset_parameters()
    def forward(self,  # type: ignore
src: torch.Tensor,
state: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pass the input through the endocder layers in turn.
Parameters
----------
src: torch.Tensor
            The sequence to the encoder (required).
state: Optional[torch.Tensor]
Optional state from previous sequence encoding.
Only passed to the SRU (not used to perform multihead
attention).
mask: torch.Tensor, optional
The mask for the src sequence (optional).
padding_mask: torch.Tensor, optional
The mask for the src keys per batch (optional).
Should be True for tokens to leave untouched, and False
for padding tokens.
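
        Returns
        -------
        torch.Tensor
            The encoded output, batch first.
        torch.Tensor
            The SRU states, one per layer, stacked along dim 0.

        Example
        -------
        A hedged sketch of carrying the SRU state across two chunks of
        the same stream; the sizes below are illustrative assumptions:

        >>> encoder = TransformerSRUEncoder(num_layers=2)
        >>> first = torch.randn(4, 16, 512)             # (N, S, E)
        >>> out1, state = encoder(first)
        >>> second = torch.randn(4, 16, 512)
        >>> out2, state = encoder(second, state=state)  # reuse the SRU state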
"""
output = src.transpose(0, 1)
if self.input_size != self.d_model:
output = self.proj(output)
new_states = []
for i in range(self.num_layers):
input_state = state[i] if state is not None else None
output, new_state = self.layers[i](output,
state=input_state,
src_mask=mask,
padding_mask=padding_mask)
new_states.append(new_state)
new_states = torch.stack(new_states, dim=0)
return output.transpose(0, 1), new_states
    def _reset_parameters(self):
"""Initiate parameters in the transformer model."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
class TransformerSRUDecoder(Module):
    """A TransformerSRUDecoder with an SRU replacing the FFN."""
def __init__(self,
input_size: int = 512,
d_model: int = 512,
nhead: int = 8,
num_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
sru_dropout: Optional[float] = None,
**kwargs: Dict[str, Any]) -> None:
"""Initialize the TransformerEncoder.
Parameters
---------
input_size : int
The embedding dimension of the model. If different from
d_model, a linear projection layer is added.
d_model : int
the number of expected features in encoder/decoder inputs.
Default ``512``.
        nhead : int, optional
            the number of heads in the multi-head attention models.
            Default ``8``.
        num_layers : int, optional
            the number of sub-decoder-layers in the decoder.
            Default ``6``.
dim_feedforward : int, optional
            the inner feedforward dimension. Default ``2048``.
dropout : float, optional
the dropout percentage. Default ``0.1``.
sru_dropout: float, optional
Dropout for the SRU cell. If not given, uses the same
dropout value as the rest of the transformer.
Extra keyword arguments are passed to the SRUCell.
"""
super().__init__()
self.input_size = input_size
self.d_model = d_model
if input_size != d_model:
self.proj = nn.Linear(input_size, d_model)
layer = TransformerSRUDecoderLayer(d_model,
nhead,
dim_feedforward,
dropout,
sru_dropout)
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
self.num_layers = num_layers
self._reset_parameters()
    def forward(self,  # type: ignore
tgt: torch.Tensor,
memory: torch.Tensor,
state: Optional[torch.Tensor] = None,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer in turn.
Parameters
----------
tgt: torch.Tensor
The sequence to the decoder (required).
memory: torch.Tensor
The sequence from the last layer of the encoder (required).
state: Optional[torch.Tensor]
Optional state from previous sequence encoding.
Only passed to the SRU (not used to perform multihead
attention).
tgt_mask: torch.Tensor, optional
The mask for the tgt sequence (optional).
memory_mask: torch.Tensor, optional
The mask for the memory sequence (optional).
padding_mask: torch.Tensor, optional
The mask for the tgt keys per batch (optional).
Should be True for tokens to leave untouched, and False
for padding tokens.
memory_key_padding_mask: torch.Tensor, optional
The mask for the memory keys per batch (optional).
Returns
-------
        torch.Tensor
            The decoded output, batch first: :math:`(N, T, E)`.
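
        Example
        -------
        A hedged sketch that reuses an encoder's output as ``memory``;
        the sizes below are illustrative assumptions:

        >>> encoder = TransformerSRUEncoder(num_layers=2)
        >>> decoder = TransformerSRUDecoder(num_layers=2)
        >>> src = torch.randn(4, 16, 512)            # (N, S, E)
        >>> tgt = torch.randn(4, 8, 512)             # (N, T, E)
        >>> memory, state = encoder(src)
        >>> out = decoder(tgt, memory, state=state)  # -> (N, T, E)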
"""
        output = tgt.transpose(0, 1)
        # the encoder returns memory batch first; move it to (S, N, E)
        memory = memory.transpose(0, 1)
        state = state if state is not None else [None] * self.num_layers
        if self.input_size != self.d_model:
            output = self.proj(output)
for i in range(self.num_layers):
output = self.layers[i](output,
memory,
state=state[i],
tgt_mask=tgt_mask,
memory_mask=memory_mask,
padding_mask=padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
return output.transpose(0, 1)
    def _reset_parameters(self):
"""Initiate parameters in the transformer model."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
class TransformerSRUEncoderLayer(Module):
"""A TransformerSRUEncoderLayer with an SRU replacing the FFN."""
def __init__(self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
sru_dropout: Optional[float] = None,
bidirectional: bool = False,
**kwargs: Dict[str, Any]) -> None:
"""Initialize a TransformerSRUEncoderLayer.
Parameters
----------
d_model : int
The number of expected features in the input.
        nhead : int
The number of heads in the multiheadattention models.
dim_feedforward : int, optional
The dimension of the feedforward network (default=2048).
dropout : float, optional
The dropout value (default=0.1).
sru_dropout: float, optional
Dropout for the SRU cell. If not given, uses the same
dropout value as the rest of the transformer.
bidirectional: bool
            Whether the SRU module should be bidirectional.
            Default ``False``.
Extra keyword arguments are passed to the SRUCell.
"""
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.sru = SRUCell(d_model,
dim_feedforward,
dropout,
sru_dropout or dropout,
bidirectional=bidirectional,
has_skip_term=False, **kwargs)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
    def forward(self,  # type: ignore
src: torch.Tensor,
state: Optional[torch.Tensor] = None,
src_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pass the input through the endocder layer.
Parameters
----------
src: torch.Tensor
The sequence to the encoder layer (required).
state: Optional[torch.Tensor]
Optional state from previous sequence encoding.
Only passed to the SRU (not used to perform multihead
attention).
src_mask: torch.Tensor, optional
The mask for the src sequence (optional).
padding_mask: torch.Tensor, optional
The mask for the src keys per batch (optional).
Should be True for tokens to leave untouched, and False
for padding tokens.
Returns
-------
torch.Tensor
Output Tensor of shape [S x B x H]
torch.Tensor
Output state of the SRU of shape [N x B x H]
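
        Example
        -------
        A hedged illustration of the padding-mask convention: this
        layer takes ``True`` for real tokens, while
        ``nn.MultiheadAttention`` expects ``True`` for positions to
        ignore, hence the flip performed below in ``forward``:

        >>> keep = torch.tensor([[True, True, False]])  # last token is padding
        >>> (~keep).tolist()                            # what self_attn receives
        [[False, False, True]]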
"""
        # Flip the padding mask: this module uses True for real tokens,
        # while nn.MultiheadAttention expects True for positions to ignore.
        reversed_mask = None
        if padding_mask is not None:
            reversed_mask = ~padding_mask
src2 = self.self_attn(src, src, src, attn_mask=src_mask,
key_padding_mask=reversed_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2, state = self.sru(src, state, mask_pad=padding_mask)
src2 = self.linear2(src2)
src = src + self.dropout2(src2)
src = self.norm2(src)
return src, state
class TransformerSRUDecoderLayer(Module):
"""A TransformerSRUDecoderLayer with an SRU replacing the FFN."""
def __init__(self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
sru_dropout: Optional[float] = None,
**kwargs: Dict[str, Any]) -> None:
"""Initialize a TransformerDecoder.
Parameters
----------
d_model : int
The number of expected features in the input.
        nhead : int
The number of heads in the multiheadattention models.
dim_feedforward : int, optional
The dimension of the feedforward network (default=2048).
dropout : float, optional
The dropout value (default=0.1).
sru_dropout: float, optional
Dropout for the SRU cell. If not given, uses the same
dropout value as the rest of the transformer.
Extra keyword arguments are passed to the SRUCell.
"""
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.sru = SRUCell(d_model,
dim_feedforward,
dropout,
sru_dropout or dropout,
bidirectional=False,
has_skip_term=False, **kwargs)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
    def forward(self,  # type: ignore
tgt: torch.Tensor,
memory: torch.Tensor,
state: Optional[torch.Tensor] = None,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
Parameters
----------
tgt: torch.Tensor
The sequence to the decoder layer (required).
memory: torch.Tensor
The sequence from the last layer of the encoder (required).
state: Optional[torch.Tensor]
Optional state from previous sequence encoding.
Only passed to the SRU (not used to perform multihead
attention).
tgt_mask: torch.Tensor, optional
The mask for the tgt sequence (optional).
memory_mask: torch.Tensor, optional
the mask for the memory sequence (optional).
padding_mask: torch.Tensor, optional
the mask for the tgt keys per batch (optional).
memory_key_padding_mask: torch.Tensor, optional
the mask for the memory keys per batch (optional).
Returns
-------
torch.Tensor
            Output Tensor of shape [T x B x H]
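
        Example
        -------
        A hedged sketch of a single decoder-layer call; the sizes below
        are illustrative assumptions:

        >>> layer = TransformerSRUDecoderLayer(d_model=512, nhead=8)
        >>> tgt = torch.randn(7, 2, 512)      # (T, N, E), sequence first
        >>> memory = torch.randn(11, 2, 512)  # (S, N, E), sequence first
        >>> out = layer(tgt, memory)          # -> (T, N, E)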
"""
        # Flip the padding mask: this module uses True for real tokens,
        # while nn.MultiheadAttention expects True for positions to ignore.
        reversed_mask = None
        if padding_mask is not None:
            reversed_mask = ~padding_mask
tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
key_padding_mask=reversed_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2, _ = self.sru(tgt, state, mask_pad=padding_mask)
tgt2 = self.linear2(tgt2)
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt