# type: ignore[override]
from typing import Optional, Tuple, List, Union
from torch import nn
from torch import Tensor
from flambe.nn.module import Module
def conv_block(conv_mod: nn.Module,
               activation: nn.Module,
               pooling: Optional[nn.Module],
               dropout: float,
               batch_norm: Optional[nn.Module] = None) -> nn.Module:
    """Return a convolutional block.

    The block is an ``nn.Sequential`` that chains, in this order: the
    convolution, the pooling layer (if any), the batch normalization
    layer (if any), the activation and a dropout layer.

    Parameters
    ----------
    conv_mod: nn.Module
        The convolution module, always the first layer of the block.
    activation: nn.Module
        The activation module, applied after the optional pooling and
        batch normalization layers.
    pooling: nn.Module, optional
        The pooling module. Skipped when falsy (e.g. ``None``).
    dropout: float
        The probability given to the trailing ``nn.Dropout`` layer.
    batch_norm: nn.Module, optional
        The batch normalization module. Skipped when ``None``.
        Defaults to None.

    Returns
    -------
    nn.Module
        The ``nn.Sequential`` wrapping the selected layers.

    """
    mods = [conv_mod]

    if pooling:
        mods.append(pooling)

    # Only a real module may be added: nn.Sequential rejects None, so
    # the optional normalization layer is skipped when it is missing.
    if batch_norm is not None:
        mods.append(batch_norm)

    mods.append(activation)
    mods.append(nn.Dropout(dropout))
    return nn.Sequential(*mods)
class CNNEncoder(Module):
    """Implements a multi-layer n-dimensional CNN.

    This module can be used to create multi-layer CNN models.

    Attributes
    ----------
    cnn: nn.Module
        The cnn submodule

    """

    def __init__(self,
                 input_channels: int,
                 channels: List[int],
                 conv_dim: int = 2,  # Support only for 1, 2 or 3
                 kernel_size: Union[int, List[Union[Tuple[int, ...], int]]] = 3,
                 activation: Optional[nn.Module] = None,
                 pooling: Optional[nn.Module] = None,
                 dropout: float = 0,
                 batch_norm: bool = True,
                 stride: int = 1,
                 padding: int = 0) -> None:
        """Initializes the CNNEncoder object.

        Parameters
        ----------
        input_channels: int
            The input's channels. For example, 3 for RGB images.
        channels: List[int]
            A list to specify the channels of the convolutional layers.
            The length of this list will be the amount of convolutions
            in the encoder.
        conv_dim: int, optional
            The dimension of the convolutions. Can be 1, 2 or 3.
            Defaults to 2.
        kernel_size: Union[int, List[Union[Tuple[int], int]]], optional
            The kernel size for the convolutions. This could be an int
            (the same kernel size for all convolutions and dimensions),
            or a List where for each convolution you can specify an int
            or a tuple (for different sizes per dimension, in which case
            the length of the tuple must match the dimension of the
            convolution). Defaults to 3.
        activation: nn.Module, optional
            The activation function to use in all layers. The same
            module instance is reused in every layer.
            Defaults to nn.ReLU
        pooling: nn.Module, optional
            The pooling function to use in all layers. The same
            module instance is reused in every layer.
            Defaults to None
        dropout: float, optional
            Amount of dropout to use between CNN layers, defaults to 0
        batch_norm: bool, optional
            Whether to use Batch Normalization or not. Defaults to True
        stride: int, optional
            The stride to use when doing convolutions. Defaults to 1
        padding: int, optional
            The padding to use when doing convolutions. Defaults to 0

        Raises
        ------
        ValueError
            If conv_dim is not 1, 2 or 3, if kernel_size is a list
            whose length differs from the channels list, or if a
            kernel size tuple does not have conv_dim entries.

        """
        super().__init__()

        # Convolution and batch normalization classes per dimension.
        dim2mod = {
            1: (nn.Conv1d, nn.BatchNorm1d),
            2: (nn.Conv2d, nn.BatchNorm2d),
            3: (nn.Conv3d, nn.BatchNorm3d),
        }

        if conv_dim not in dim2mod:
            raise ValueError(f"Invalid conv_dim value {conv_dim}. Values 1, 2, 3 supported")

        if isinstance(kernel_size, list) and len(kernel_size) != len(channels):
            raise ValueError("Kernel size list should have same length as channels list")

        conv, bn = dim2mod[conv_dim]
        activation = activation or nn.ReLU()

        layers = []
        prev_c = input_channels
        for i, c in enumerate(channels):
            k: Union[int, Tuple]
            if isinstance(kernel_size, int):
                k = kernel_size
            else:
                k = kernel_size[i]
                if not isinstance(k, int) and len(k) != conv_dim:
                    raise ValueError("Kernel size tuple should have same length as conv_dim")

            # Honor the batch_norm flag. A no-op nn.Identity stands in
            # for the normalization layer when it is disabled, so that
            # conv_block always receives a real module.
            norm = bn(c) if batch_norm else nn.Identity()

            layer = conv_block(
                conv(prev_c, c, k, stride, padding),
                activation,
                pooling,
                dropout,
                norm
            )
            layers.append(layer)
            # The output channels of this layer feed the next one.
            prev_c = c

        self.cnn = nn.Sequential(*layers)

    def forward(self, data: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
        """Performs a forward pass through the network.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a float tensor

        Returns
        -------
        Union[Tensor, Tuple[Tensor, ...]]
            The encoded output, as a float tensor

        """
        return self.cnn(data)