Source code for flambe.nn.pooling

from typing import Optional

import torch

from flambe.nn import Module


class FirstPooling(Module):
    """Get the first hidden state of a sequence."""

    def forward(self,
                data: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Performs a forward pass.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a tensor of shape [B x S x H]
        padding_mask: torch.Tensor, optional
            The input mask, as a tensor of shape [B X S]. It is accepted
            for interface compatibility with the other pooling modules
            but is not used: position 0 is always selected.

        Returns
        ----------
        torch.Tensor
            The output data, as a tensor of shape [B x H]

        """
        # Select the hidden state at sequence position 0 for every batch row.
        return data[:, 0, :]
class LastPooling(Module):
    """Get the last non-padding hidden state of a sequence."""

    def forward(self,
                data: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Performs a forward pass.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a tensor of shape [B x S x H]
        padding_mask: torch.Tensor, optional
            The input mask, as a tensor of shape [B X S]. Non-zero entries
            mark valid positions. The selected position is
            ``mask.sum() - 1`` per row, so valid positions are assumed to
            form a prefix (right padding). If omitted, every position is
            treated as valid and the final position is selected.

        Returns
        ----------
        torch.Tensor
            The output data, as a tensor of shape [B x H]

        Notes
        -----
        A row whose mask is entirely zero yields index ``-1``, which
        wraps around to the final (padding) position.

        """
        batch_size, seq_len = data.size(0), data.size(1)
        device = data.device

        # Build the index tensors directly on ``data``'s device so that no
        # host-to-device copy is needed and the index tensors never end up
        # on mixed devices.
        if padding_mask is None:
            lengths = torch.full((batch_size,), seq_len,
                                 dtype=torch.long, device=device)
        else:
            lengths = padding_mask.long().sum(dim=1).to(device)

        rows = torch.arange(batch_size, dtype=torch.long, device=device)
        return data[rows, lengths - 1, :]
class SumPooling(Module):
    """Get the sum of the hidden states of a sequence."""

    def forward(self,
                data: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Performs a forward pass.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a tensor of shape [B x S x H]
        padding_mask: torch.Tensor, optional
            The input mask, as a tensor of shape [B X S]. Each hidden
            state is multiplied by its mask value before summing. If
            omitted, an all-ones mask is used.

        Returns
        ----------
        torch.Tensor
            The output data, as a tensor of shape [B x H]

        """
        # An all-ones mask with data's dtype and device, covering [B x S].
        mask = data.new_ones(data.shape[:2]) if padding_mask is None else padding_mask
        # Broadcast the [B x S] mask across the hidden dimension.
        weighted = data * mask.unsqueeze(2)
        return weighted.sum(dim=1)
class AvgPooling(Module):
    """Get the average of the hidden states of a sequence."""

    def forward(self,
                data: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Performs a forward pass.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a tensor of shape [B x S x H]
        padding_mask: torch.Tensor, optional
            The input mask, as a tensor of shape [B X S]. The result is the
            mask-weighted sum divided by the per-row mask sum. If omitted,
            an all-ones mask is used (a plain mean over S).

        Returns
        ----------
        torch.Tensor
            The output data, as a tensor of shape [B x H]

        Notes
        -----
        A row whose mask sums to zero divides 0 by 0 and yields NaN for
        floating-point data.

        """
        if padding_mask is None:
            padding_mask = torch.ones((data.size(0), data.size(1))).to(data)

        # [B x S x H] * [B x S x 1] -> summed over S -> [B x H]
        total = (data * padding_mask.unsqueeze(2)).sum(dim=1)
        # keepdim=True gives a [B x 1] divisor that broadcasts across H.
        # A plain [B] divisor would broadcast across the last (H) axis
        # instead, which is wrong whenever B != 1.
        count = padding_mask.sum(dim=1, keepdim=True)
        return total / count