flambe.vision.classification.datasets

Module Contents

class flambe.vision.classification.datasets.MNISTDataset(train_images: np.ndarray = None, train_labels: np.ndarray = None, test_images: np.ndarray = None, test_labels: np.ndarray = None, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None)

Bases: flambe.dataset.Dataset

The official MNIST dataset.

data_type
URL = http://yann.lecun.com/exdb/mnist/
train: List[Tuple[torch.Tensor, torch.Tensor]]

Returns the training data.

val: List[Tuple[torch.Tensor, torch.Tensor]]

Returns the validation data.

test: List[Tuple[torch.Tensor, torch.Tensor]]

Returns the test data.

classmethod from_path(cls, train_images_path: str, train_labels_path: str, test_images_path: str, test_labels_path: str, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None)

Initialize the MNISTDataset from local files.

Parameters:
  • train_images_path (str) – path to the train images file in the idx format
  • train_labels_path (str) – path to the train labels file in the idx format
  • test_images_path (str) – path to the test images file in the idx format
  • test_labels_path (str) – path to the test labels file in the idx format
  • val_ratio (Optional[float]) – fraction of the training data held out as the validation set. Defaults to 0.2.
  • seed (Optional[int]) – random seed for the validation set split
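
The documentation does not say how val_ratio and seed carve the validation set out of the training data. A minimal sketch, assuming a seeded random permutation of the training examples with the first val_ratio fraction held out; split_train_val is a hypothetical helper, not part of flambe:

```python
from typing import Optional, Tuple

import numpy as np


def split_train_val(
    images: np.ndarray,
    labels: np.ndarray,
    val_ratio: Optional[float] = 0.2,
    seed: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical helper: hold out a shuffled val_ratio fraction as validation."""
    n = len(images)
    # Assumption: a val_ratio of None (or 0) means "no validation split"
    n_val = int(n * val_ratio) if val_ratio else 0
    # A seeded generator makes the split reproducible across runs
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    return images[train_idx], labels[train_idx], images[val_idx], labels[val_idx]


images = np.zeros((100, 28, 28), dtype=np.uint8)
labels = np.arange(100) % 10
train_x, train_y, val_x, val_y = split_train_val(images, labels, val_ratio=0.2, seed=0)
print(len(train_x), len(val_x))  # 80 20
```

Passing the same seed yields the same partition, which keeps validation scores comparable between runs.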

classmethod _parse_local_gzipped_idx(cls, path: str)

Parse a local gzipped idx file.

classmethod _parse_downloaded_idx(cls, url: str)

Download an idx file from url and parse it.

classmethod _parse_idx(cls, data: bytes)

Parse an idx file.
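The idx layout these helpers decode is described on the MNIST page at the URL above: two zero bytes, a type code, and a dimension count, followed by one big-endian 32-bit size per dimension and then the raw big-endian data. A minimal sketch of all three helpers as plain functions, assuming numpy output and gzip-compressed inputs for the local and downloaded variants; parse_idx, parse_local_gzipped_idx and parse_downloaded_idx are hypothetical stand-ins for the classmethods, not flambe's code:

```python
import gzip
import struct
import urllib.request

import numpy as np

# Element type codes from the idx format description on the MNIST page
_IDX_DTYPES = {
    0x08: np.dtype(">u1"),  # unsigned byte
    0x09: np.dtype(">i1"),  # signed byte
    0x0B: np.dtype(">i2"),  # short (2 bytes)
    0x0C: np.dtype(">i4"),  # int (4 bytes)
    0x0D: np.dtype(">f4"),  # float (4 bytes)
    0x0E: np.dtype(">f8"),  # double (8 bytes)
}


def parse_idx(data: bytes) -> np.ndarray:
    # Magic number: two zero bytes, a type code, and the number of dimensions
    zero1, zero2, type_code, ndim = struct.unpack_from(">BBBB", data, 0)
    if zero1 != 0 or zero2 != 0 or type_code not in _IDX_DTYPES:
        raise ValueError("not an idx file")
    # Each dimension size is a big-endian 32-bit unsigned integer
    shape = struct.unpack_from(">" + "I" * ndim, data, 4)
    offset = 4 + 4 * ndim
    count = int(np.prod(shape))
    arr = np.frombuffer(data, dtype=_IDX_DTYPES[type_code], count=count, offset=offset)
    return arr.reshape(shape)


def parse_local_gzipped_idx(path: str) -> np.ndarray:
    # Decompress the whole file, then hand the raw bytes to the idx parser
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())


def parse_downloaded_idx(url: str) -> np.ndarray:
    # Assumes the remote file is gzip-compressed, as the MNIST files are
    with urllib.request.urlopen(url) as response:
        return parse_idx(gzip.decompress(response.read()))


# A tiny 2x3 unsigned-byte idx payload built by hand
header = struct.pack(">BBBBII", 0, 0, 0x08, 2, 2, 3)
raw = header + bytes(range(6))
print(parse_idx(raw).tolist())  # [[0, 1, 2], [3, 4, 5]]
```

Because urlopen also accepts file:// URLs, the download path can be exercised offline against a local .gz file.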

flambe.vision.classification.datasets.get_dataset(images: np.ndarray, labels: np.ndarray) → List[Tuple[torch.Tensor, torch.Tensor]]
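
get_dataset turns parallel image and label arrays into the List[Tuple[torch.Tensor, torch.Tensor]] shape that train, val and test expose. A minimal sketch named after the documented function, assuming float32 pixel tensors and int64 label tensors with no normalization or added channel dimension; the actual dtypes and preprocessing in flambe are not documented here:

```python
from typing import List, Tuple

import numpy as np
import torch


def get_dataset(images: np.ndarray, labels: np.ndarray) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # Pair each image with its label; torch.tensor copies, so read-only
    # numpy buffers (e.g. from np.frombuffer) are safe to convert
    return [
        (torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.long))
        for image, label in zip(images, labels)
    ]


images = np.zeros((2, 28, 28), dtype=np.uint8)
labels = np.array([3, 7], dtype=np.uint8)
dataset = get_dataset(images, labels)
print(len(dataset), dataset[1][1].item())  # 2 7
```

Returning a plain list of pairs keeps the result directly indexable, which is enough for a standard torch DataLoader.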