# Source code for flambe.nlp.language_modeling.sampler

from typing import Optional, Sequence, Tuple, Iterator
import math

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from flambe.sampler.sampler import Sampler


class CorpusSampler(Sampler):
    """Implement a CorpusSampler object.

    This object is useful for iteration over a large corpus of
    text in an ordered way. It takes as input a dataset with a single
    example containing the sequence of tokens.

    """

    def __init__(self,
                 batch_size: int = 128,
                 unroll_size: int = 128,
                 n_workers: int = 0,
                 pin_memory: bool = False,
                 downsample: Optional[float] = None,
                 drop_last: bool = True) -> None:
        """Initialize the CorpusSampler object.

        Parameters
        ----------
        batch_size : int, optional
            The batch size to use. Default ``128``.
        unroll_size : int, optional
            Make every sequence this length. Default ``128``.
        n_workers : int, optional
            Number of workers to pass to the DataLoader
            (the default is 0, which means the main process)
        pin_memory : bool, optional
            Pin the memory when using cuda (the default is False)
        downsample : float, optional
            Fraction of the data to keep, in the range (0, 1].
            By default the full corpus is used.
        drop_last : bool, optional
            Set to True to drop the last incomplete batch if the
            dataset size is not divisible by the batch size
            (the default is True)

        """
        self.unroll_size = unroll_size
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.n_workers = n_workers
        self.pin_memory = pin_memory
        self.downsample = downsample

    @staticmethod
    def collate_fn(data: Sequence[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]:
        """Create a batch from data.

        Parameters
        ----------
        data : Sequence[Tuple[Tensor, Tensor]]
            List of (source, target) tuples.

        Returns
        -------
        Tuple[Tensor, Tensor]
            Source and target Tensors.

        """
        x, y = zip(*data)
        # Stack to (batch, seq), then transpose so time is the leading
        # dimension, the layout recurrent language models expect.
        return torch.stack(x).t(), torch.stack(y).t()

    def sample(self,
               data: Sequence[Sequence[Tensor]],
               n_epochs: int = 1) -> Iterator[Tuple[Tensor, ...]]:
        """Sample from the list of features and yields batches.

        Parameters
        ----------
        data : Sequence[Sequence[Tensor, ...]]
            The input data to sample from
        n_epochs : int, optional
            The number of epochs to run in the output iterator.
            Use -1 to run infinitely.

        Yields
        ------
        Iterator[Tuple[Tensor]]
            A batch of data, as a tuple of Tensors

        Raises
        ------
        ValueError
            If ``data`` is empty, contains more than one example, or
            if ``downsample`` lies outside the range (0, 1].

        """
        if len(data) == 0:
            raise ValueError("No examples provided")
        elif len(data) > 1:
            raise ValueError("Expected a single input example")

        tensor = data[0][0]  # First example, first column

        # Explicit None check so that an invalid value such as 0 (falsy)
        # is rejected by the range check instead of being silently skipped.
        if self.downsample is not None:
            if not (0 < self.downsample <= 1):
                raise ValueError("Downsample value should be in the range (0, 1]")
            tensor = tensor[:int(self.downsample * tensor.size(0))]

        # Organize batch-wise: trim so the corpus splits evenly into
        # batch_size streams; targets are the inputs shifted by one token.
        final_length = (tensor.size(0) - 1) // self.batch_size * self.batch_size
        x = torch.reshape(tensor[:final_length], (self.batch_size, -1)).t()
        y = torch.reshape(tensor[1:final_length + 1], (self.batch_size, -1)).t()

        loader = DataLoader(dataset=torch.utils.data.TensorDataset(x, y),
                            collate_fn=self.collate_fn,
                            shuffle=False,  # ordered iteration over the corpus
                            batch_size=self.unroll_size,
                            num_workers=self.n_workers,
                            pin_memory=self.pin_memory,
                            drop_last=self.drop_last)

        if n_epochs == -1:
            # Run forever: the caller controls termination.
            while True:
                yield from loader
        else:
            for _ in range(n_epochs):
                yield from loader

    def length(self, data: Sequence[Sequence[torch.Tensor]]) -> int:
        """Return the number of batches in the sampler.

        Parameters
        ----------
        data : Sequence[Sequence[torch.Tensor, ...]]
            The input data to sample from

        Returns
        -------
        int
            The number of batches that would be created per epoch

        """
        tensor = data[0][0]
        # Mirrors the trimming done in `sample`: number of rows after
        # reshaping into batch_size streams.
        n_rows = (tensor.size(0) - 1) // self.batch_size
        if self.drop_last:
            return n_rows // self.unroll_size
        return math.ceil(n_rows / self.unroll_size)