from typing import Any, Dict, Iterator, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from flambe.sampler import Sampler
class EpisodicSampler(Sampler):
    """Sampler that yields few-shot (episodic) batches.

    Each episode is built by grouping the input ``(source, target)``
    pairs by target label, then drawing a support set and a query set
    from the selected classes. Currently only supports sequence inputs
    (variable-length 1-D tensors, padded per batch).
    """

    def __init__(self,
                 n_support: int,
                 n_query: int,
                 n_episodes: int,
                 n_classes: Optional[int] = None,
                 pad_index: int = 0,
                 balance_query: bool = False) -> None:
        """Initialize the EpisodicSampler.

        Parameters
        ----------
        n_support : int
            The number of support points per class
        n_query : int
            If balance_query is True, this should be the number
            of query points per class, otherwise, this is the total
            number of query points for the episode
        n_episodes : int
            Number of episodes to run in one "epoch"
        n_classes : Optional[int]
            The number of classes to sample per episode, defaults to all
        pad_index : int, optional
            The padding index used on sequences.
        balance_query : bool, optional
            If True, the same number of query points are sampled per
            class, otherwise query points are sampled uniformly
            from the input data.
        """
        self.pad = pad_index
        self.n_support = n_support
        self.n_query = n_query
        self.n_classes = n_classes
        self.n_episodes = n_episodes
        self.balance_query = balance_query

    def sample(self,
               data: Sequence[Sequence[torch.Tensor]],
               n_epochs: int = 1) -> Iterator[Tuple[torch.Tensor, ...]]:
        """Sample episodes from the list of features and yield batches.

        Parameters
        ----------
        data: Sequence[Sequence[torch.Tensor, torch.Tensor]]
            The input data as a list of (source, target) pairs
        n_epochs: int, optional
            The number of epochs to run in the output iterator.
            For this object, the total number of batches will be
            (n_episodes * n_epochs)

        Yields
        ------
        Iterator[Tuple[Tensor, Tensor, Tensor, Tensor]]
            In order: the query_source, the query_target
            the support_source, and the support_target tensors.
            For sequences, the batch is used as first dimension.

        Raises
        ------
        ValueError
            If ``data`` is empty, or if ``n_support`` consumes every
            example of every sampled class so no query set remains.
        """
        if len(data) == 0:
            raise ValueError("No examples provided")

        # Group examples by their (integer) target label.
        target_to_examples: Dict[int, Any] = dict()
        for source, target in data:
            target_to_examples.setdefault(int(target), []).append((source, target))

        all_classes = list(target_to_examples.keys())

        for _ in range(n_epochs):
            for _ in range(self.n_episodes):
                # Sample n_classes to run a training episode over
                classes = all_classes
                if self.n_classes is not None:
                    classes = list(np.random.permutation(all_classes))[:self.n_classes]

                # Sample n_support (and possibly n_query) points per class.
                # Episode labels are re-indexed to 0..len(classes)-1 via `i`,
                # so downstream losses see a contiguous label space.
                supports, queries = [], []
                for i, target_class in enumerate(classes):
                    examples = target_to_examples[target_class]
                    indices = np.random.permutation(len(examples))
                    supports.extend([(examples[j][0], i) for j in indices[:self.n_support]])
                    if self.balance_query:
                        query_indices = indices[self.n_support:self.n_support + self.n_query]
                        queries.extend([(examples[j][0], i) for j in query_indices])
                    else:
                        # Keep all leftovers for now; subsample below.
                        queries.extend([(examples[j][0], i) for j in indices[self.n_support:]])

                if not self.balance_query:
                    # Subsample n_query points uniformly over all classes
                    indices = np.random.permutation(len(queries))
                    queries = [queries[i] for i in indices[:self.n_query]]

                if not queries:
                    # Without this check, zip(*queries) below fails with an
                    # opaque "not enough values to unpack" error.
                    raise ValueError(
                        "Empty query set: n_support consumed all available "
                        "examples for the sampled classes")

                query_source, query_target = list(zip(*queries))
                support_source, support_target = list(zip(*supports))

                # Pad variable-length sequences into (batch, max_len) tensors.
                query_source = pad_sequence(query_source,
                                            batch_first=True,
                                            padding_value=self.pad)
                query_target = torch.tensor(query_target)
                support_source = pad_sequence(support_source,
                                              batch_first=True,
                                              padding_value=self.pad)
                support_target = torch.tensor(support_target)

                # NOTE(review): targets built above are python ints, so these
                # tensors are 1-D and the squeeze branches look unreachable;
                # kept for safety in case callers pass tensor targets.
                if len(query_target.size()) == 2:
                    query_target = query_target.squeeze()
                if len(support_target.size()) == 2:
                    support_target = support_target.squeeze()

                yield (query_source.long(),
                       query_target.long(),
                       support_source.long(),
                       support_target.long())

    def length(self, data: Sequence[Sequence[torch.Tensor]]) -> int:
        """Return the number of batches in the sampler.

        Parameters
        ----------
        data: Sequence[Sequence[torch.Tensor, ...]]
            The input data to sample from (unused: episode count is fixed)

        Returns
        -------
        int
            The number of batches that would be created per epoch
        """
        return self.n_episodes