Source code for flambe.nn.mos

# type: ignore[override]

import torch
import torch.nn as nn
from torch import Tensor

from flambe.nn.mlp import MLPEncoder
from flambe.nn.module import Module


[docs]class MixtureOfSoftmax(Module): """Implement the MixtureOfSoftmax output layer. Attributes ---------- pi: FullyConnected softmax layer over the different softmax layers: [FullyConnected] list of the k softmax layers """ def __init__(self, input_size: int, output_size: int, k: int = 1, take_log: bool = True) -> None: """Initialize the MOS layer. Parameters ---------- input_size: int input dimension output_size: int output dimension k: int (Default: 1) number of softmax in the mixture """ super().__init__() self.pi_w = MLPEncoder(input_size, k) self.softmax = nn.Softmax() self.layers = [MLPEncoder(input_size, output_size) for _ in range(k)] self.tanh = nn.Tanh() self.activation = nn.LogSoftmax() if take_log else nn.Softmax()
[docs] def forward(self, data: Tensor) -> Tensor: """Implement mixture of softmax for language modeling. Parameters ---------- data: torch.Tensor seq_len x batch_size x hidden_size Return ------- out: Variable output matrix of shape seq_len x batch_size x out_size """ w = self.softmax(self.pi_w(data)) # Compute k softmax, and combine using above weights out = [w[:, :, i] * self.tanh(W(data)) for i, W in enumerate(self.layers)] out = torch.cat(out, dim=0).sum(dim=0) return self.activation(out)