import torch
from torch import Tensor
from typing import Optional, Tuple, Union
from flambe.nn import Module # type: ignore[attr-defined]
class ImageClassifier(Module):
    """Implements a simple image classifier.

    This classifier consists of an encoder module, followed by
    an output layer applied to the flattened encoding. The output
    layer is expected to produce the predictions over the targets
    (e.g. a probability distribution); this class does not apply
    any normalization itself.

    Attributes
    ----------
    encoder: Module
        The encoder layer
    output_layer: Module
        The output layer, applied to the flattened encoder output

    """

    def __init__(self,
                 encoder: Module,
                 output_layer: Module) -> None:
        """Initialize the classifier.

        Parameters
        ----------
        encoder: Module
            The module used to encode the input data
        output_layer: Module
            The module applied to the flattened encoding

        """
        super().__init__()

        self.encoder = encoder
        self.output_layer = output_layer

    def forward(self,
                data: Tensor,
                target: Optional[Tensor] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run a forward pass through the network.

        Parameters
        ----------
        data: Tensor
            The input data
        target: Tensor, optional
            The input targets. When given, they are returned
            unchanged alongside the predictions.

        Returns
        -------
        Union[Tensor, Tuple[Tensor, Tensor]]
            The output predictions, and optionally the targets

        """
        encoded = self.encoder(data)

        # Collapse every dimension after the first into one, so the
        # output layer receives a 2D input (first dim is kept as is).
        pred = self.output_layer(torch.flatten(encoded, 1))

        return (pred, target) if target is not None else pred