Source code for flambe.vision.classification.datasets

import gzip
import torch
import requests
from typing import List, Tuple, Optional

from sklearn.model_selection import train_test_split

import numpy as np

from flambe.dataset import Dataset
from flambe.compile import registrable_factory


[docs]class MNISTDataset(Dataset): """The official MNIST dataset."""
[docs] data_type = { 0x08: np.uint8, 0x09: np.int8, 0x0b: np.dtype('>i2'), 0x0c: np.dtype('>i4'), 0x0d: np.dtype('>f4'), 0x0e: np.dtype('>f8'),
}
[docs] URL = "http://yann.lecun.com/exdb/mnist/"
def __init__(self, train_images: np.ndarray = None, train_labels: np.ndarray = None, test_images: np.ndarray = None, test_labels: np.ndarray = None, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None) -> None: """Initialize the MNISTDataset. Parameters ---------- train_images: np.ndarray parsed train images as a numpy array train_labels: np.ndarray parsed train labels as a numpy array test_images: np.ndarray parsed test images as a numpy array test_labels: np.ndarray parsed test labels as a numpy array val_ratio: Optional[float] validation set ratio. Default 0.2 seed: Optional[int] random seed for the validation set split """ if train_images is None: train_images = self._parse_downloaded_idx(self.URL + "train-images-idx3-ubyte.gz") if train_labels is None: train_labels = self._parse_downloaded_idx(self.URL + "train-labels-idx1-ubyte.gz") if test_images is None: test_images = self._parse_downloaded_idx(self.URL + "t10k-images-idx3-ubyte.gz") if test_labels is None: test_labels = self._parse_downloaded_idx(self.URL + "t10k-labels-idx1-ubyte.gz") self.train_images, self.val_images, self.train_labels, self.val_labels = \ train_test_split(train_images, train_labels, test_size=val_ratio, random_state=seed) self.test_images = test_images self.test_labels = test_labels self._train = get_dataset(self.train_images, self.train_labels) self._val = get_dataset(self.val_images, self.val_labels) self._test = get_dataset(self.test_images, self.test_labels) @registrable_factory @classmethod
[docs] def from_path(cls, train_images_path: str, train_labels_path: str, test_images_path: str, test_labels_path: str, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None) -> 'MNISTDataset': """Initialize the MNISTDataset from local files. Parameters ---------- train_images_path: str path to the train images file in the idx format train_labels_path: str path to the train labels file in the idx format test_images_path: str path to the test images file in the idx format test_labels_path: str path to the test labels file in the idx format val_ratio: Optional[float] validation set ratio. Default 0.2 seed: Optional[int] random seed for the validation set split """ return cls( cls._parse_local_gzipped_idx(train_images_path), cls._parse_local_gzipped_idx(train_labels_path), cls._parse_local_gzipped_idx(test_images_path), cls._parse_local_gzipped_idx(test_labels_path), val_ratio=val_ratio, seed=seed,
) @property
[docs] def train(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Returns the training data""" return self._train
@property
[docs] def val(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Returns the validation data""" return self._val
@property
[docs] def test(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Returns the test data""" return self._test
@classmethod
[docs] def _parse_local_gzipped_idx(cls, path: str) -> np.ndarray: """Parse a local gzipped idx file""" with gzip.open(path) as f: return cls._parse_idx(f.read())
@classmethod
[docs] def _parse_downloaded_idx(cls, url: str) -> np.ndarray: """Parse a downloaded idx file""" r = requests.get(url) return cls._parse_idx(gzip.decompress(r.content))
@classmethod
[docs] def _parse_idx(cls, data: bytes) -> np.ndarray: """Parse an idx filie""" # parse the magic number that contains the dimension # and the data type magic = int.from_bytes(data[0:4], 'big') dim = magic % 256 data_type = magic // 256 shape = [int.from_bytes(data[4 * (i + 1): 4 * (i + 2)], 'big') for i in range(dim)] return np.frombuffer( data, dtype=cls.data_type[data_type], offset=4 * (dim + 1) ).reshape(shape)
[docs]def get_dataset(images: np.ndarray, labels: np.ndarray) -> List[Tuple[torch.Tensor, torch.Tensor]]: return [( torch.from_numpy(image).float().unsqueeze(0), torch.tensor(label).long() ) for image, label in zip(images, labels)]