flambe.vision.classification

Package Contents

class flambe.vision.classification.MNISTDataset(train_images: np.ndarray = None, train_labels: np.ndarray = None, test_images: np.ndarray = None, test_labels: np.ndarray = None, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None)

Bases: flambe.dataset.Dataset

The official MNIST dataset.

data_type
URL = http://yann.lecun.com/exdb/mnist/
train: List[Tuple[torch.Tensor, torch.Tensor]]

Returns the training data.

val: List[Tuple[torch.Tensor, torch.Tensor]]

Returns the validation data.

test: List[Tuple[torch.Tensor, torch.Tensor]]

Returns the test data.

classmethod from_path(cls, train_images_path: str, train_labels_path: str, test_images_path: str, test_labels_path: str, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None)

Initialize the MNISTDataset from local files.

Parameters:
  • train_images_path (str) – path to the train images file in the idx format
  • train_labels_path (str) – path to the train labels file in the idx format
  • test_images_path (str) – path to the test images file in the idx format
  • test_labels_path (str) – path to the test labels file in the idx format
  • val_ratio (Optional[float]) – fraction of the training data held out as the validation set. Default: 0.2
  • seed (Optional[int]) – random seed for the validation set split
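
The val_ratio and seed parameters suggest a seeded random split of the training data into training and validation sets. A minimal sketch of such a split, assuming NumPy arrays of images and labels; split_train_val is a hypothetical helper, not part of flambe, and treating val_ratio=None as "no validation set" is an assumption:

```python
from typing import Optional

import numpy as np


def split_train_val(images: np.ndarray, labels: np.ndarray,
                    val_ratio: Optional[float] = 0.2,
                    seed: Optional[int] = None):
    """Hypothetical helper: hold out a random fraction of the training set."""
    n = len(images)
    # Assumption: None means "no validation split".
    ratio = val_ratio if val_ratio is not None else 0.0
    n_val = int(n * ratio)
    # A fixed seed yields the same permutation, hence the same split.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    # Index images and labels with the same indices so pairs stay aligned.
    train = (images[train_idx], labels[train_idx])
    val = (images[val_idx], labels[val_idx])
    return train, val
```

With 60,000 MNIST training images and the default ratio of 0.2, this sketch would keep 48,000 for training and hold out 12,000 for validation.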

classmethod _parse_local_gzipped_idx(cls, path: str)

Parse a local gzipped idx file

classmethod _parse_downloaded_idx(cls, url: str)

Parse a downloaded idx file

classmethod _parse_idx(cls, data: bytes)

Parse an idx file
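
The three parsing methods above revolve around the IDX format described at the MNIST URL: a 4-byte magic number (two zero bytes, an element-type code, and the number of dimensions), one big-endian 32-bit size per dimension, then the raw data. A rough NumPy sketch of parsing it, with hypothetical function names parse_idx and parse_gzipped_idx; flambe's own methods may differ, for example in how they convert the result to tensors:

```python
import gzip
import struct

import numpy as np

# Element-type codes from the IDX format description; multi-byte types are big-endian.
_IDX_DTYPES = {
    0x08: np.dtype(np.uint8),
    0x09: np.dtype(np.int8),
    0x0B: np.dtype(">i2"),
    0x0C: np.dtype(">i4"),
    0x0D: np.dtype(">f4"),
    0x0E: np.dtype(">f8"),
}


def parse_idx(data: bytes) -> np.ndarray:
    """Hypothetical IDX parser: decode raw bytes into a NumPy array."""
    # Magic number: two zero bytes, a type code, and the number of dimensions.
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", data[:4])
    if zero1 != 0 or zero2 != 0 or type_code not in _IDX_DTYPES:
        raise ValueError("not an IDX file")
    # One 4-byte big-endian unsigned size per dimension.
    shape = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim
    count = int(np.prod(shape))
    # frombuffer returns a read-only view over the bytes.
    flat = np.frombuffer(data, dtype=_IDX_DTYPES[type_code], count=count, offset=offset)
    return flat.reshape(shape)


def parse_gzipped_idx(raw: bytes) -> np.ndarray:
    """Hypothetical helper: decompress gzip bytes, then parse them as IDX."""
    return parse_idx(gzip.decompress(raw))
```

The MNIST image files decode to arrays of shape (N, 28, 28) and the label files to shape (N,), both with element type 0x08 (unsigned byte).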

class flambe.vision.classification.ImageClassifier(encoder: Module, output_layer: Module)

Bases: flambe.nn.Module

Implements a simple image classifier.

This classifier consists of an encoder module followed by a fully connected output layer that outputs a probability distribution.
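
To illustrate that data flow, here is a NumPy stand-in, assuming a flatten + dense + ReLU encoder and a dense output layer followed by softmax. In flambe, encoder and output_layer are arbitrary modules supplied by the user; TinyImageClassifier and its layer sizes are hypothetical:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


class TinyImageClassifier:
    """Hypothetical sketch: encoder -> fully connected output layer -> softmax."""

    def __init__(self, in_features: int, hidden: int, n_classes: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Small random weights; biases omitted for brevity.
        self.w_enc = rng.normal(0.0, 0.01, (in_features, hidden))
        self.w_out = rng.normal(0.0, 0.01, (hidden, n_classes))

    def encoder(self, images: np.ndarray) -> np.ndarray:
        # Flatten each image, project it, and apply ReLU.
        flat = images.reshape(len(images), -1)
        return np.maximum(flat @ self.w_enc, 0.0)

    def output_layer(self, features: np.ndarray) -> np.ndarray:
        # Fully connected layer followed by softmax: one distribution per row.
        return softmax(features @ self.w_out)

    def __call__(self, images: np.ndarray) -> np.ndarray:
        return self.output_layer(self.encoder(images))
```

For a batch of four 28x28 images and ten classes, the sketch returns a (4, 10) array whose rows each sum to 1.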

encoder

The encoder layer

Type: Module

output_layer

The output layer, which yields a probability distribution over the targets

Type: Module

forward(self, data: Tensor, target: Optional[Tensor] = None)

Run a forward pass through the network.

Parameters:
  • data (Tensor) – The input data
  • target (Tensor, optional) – The input targets
Returns:

The output predictions, plus the targets when target is given

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]
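
The Union return type mirrors the optional target argument. A sketch of that contract, assuming plain callables in place of the encoder and output_layer modules; forward here is a free function and accuracy a hypothetical helper, neither taken from flambe:

```python
from typing import Callable, Optional, Tuple, Union

import numpy as np

Array = np.ndarray


def forward(encoder: Callable[[Array], Array],
            output_layer: Callable[[Array], Array],
            data: Array,
            target: Optional[Array] = None) -> Union[Array, Tuple[Array, Array]]:
    # Predictions are the output layer applied to the encoded input.
    preds = output_layer(encoder(data))
    # Return the targets alongside the predictions only when they were given.
    return preds if target is None else (preds, target)


def accuracy(preds: Array, target: Array) -> float:
    # Hypothetical metric: fraction of rows whose argmax matches the target.
    return float((preds.argmax(axis=1) == target).mean())
```

Passing the targets back with the predictions lets a caller hand the returned tuple straight to a metric or loss function, which is one plausible reason for this design.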