from typing import Optional, Tuple, List
import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from flambe.dataset import Dataset
from flambe.metric import Metric
from flambe.sampler import Sampler
from flambe.learn import Trainer
from flambe.nn import Module
class DistillationTrainer(Trainer):
"""Implement a Distillation Trainer.
Perform knowledge distillation between a teacher and a student
model. Note that the model outputs are expected to be raw logits.
    Make sure that you are not applying a softmax after the decoder.
    You can replace the traditional Decoder with an MLPEncoder.
"""
def __init__(self,
dataset: Dataset,
train_sampler: Sampler,
dev_sampler: Sampler,
teacher_model: Module,
student_model: Module,
loss_fn: Metric,
metric_fn: Metric,
optimizer: Optimizer,
scheduler: Optional[_LRScheduler] = None,
device: Optional[str] = None,
max_steps: int = 10,
epoch_per_step: float = 1.0,
iter_per_step: Optional[int] = None,
batches_per_iter: int = 1,
lower_is_better: bool = False,
max_grad_norm: Optional[float] = None,
max_grad_abs_val: Optional[float] = None,
extra_validation_metrics: Optional[List[Metric]] = None,
teacher_columns: Optional[Tuple[int, ...]] = None,
student_columns: Optional[Tuple[int, ...]] = None,
alpha_kl: float = 0.5,
temperature: int = 1,
unlabel_dataset: Optional[Dataset] = None,
unlabel_sampler: Optional[Sampler] = None) -> None:
"""Initialize the Trainer.
Parameters
----------
        dataset : Dataset
            The dataset to iterate over. Its columns contain the input
            data followed by the targets; `student_columns` and
            `teacher_columns` control which columns each model receives.
train_sampler : Sampler
The sampler to use over training examples
dev_sampler : Sampler
The sampler to use over validation examples
        teacher_model : Module
            The pretrained teacher model, used to produce the logits
            that the student is distilled towards
        student_model : Module
            The student model to train
        loss_fn : Metric
            The loss function to use on the labelled data
        metric_fn : Metric
            The metric to use for validation
optimizer : torch.optim.Optimizer
The optimizer to use
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
An optional learning rate scheduler
device: str, optional
The device to use in the computation. Only used by compile.
max_steps : int, optional
The maximum number of training steps to run
epoch_per_step : float, optional
Fraction of an epoch to perform in a single training step
            (i.e. before a checkpoint). Defaults to 1.
            Overridden by `iter_per_step`, if given.
iter_per_step : int, optional
Number of iterations to perform in a single training step.
Overrides `epoch_per_step` if given.
batches_per_iter : int, optional
Number of batches to pass through the model before
calling optimizer.step. Requires the sampler to have
drop_last set to True. (default set to 1 so optimizer.step
is called after every batch)
lower_is_better : bool, optional
If true, the lowest dev metric is considered best,
otherwise the highest. Defaults to False.
max_grad_norm : float, optional
Maximum Euclidean norm of gradient after clipping.
max_grad_abs_val: float, optional
Maximum absolute value of all gradient vector components
after clipping.
        extra_validation_metrics : List[Metric], optional
            A list of extra metrics to report at each step, which do
            not guide the training procedure
            (i.e. model selection through early stopping)
        teacher_columns : Tuple[int, ...], optional
            The batch columns to pass to the teacher model.
            Defaults to all columns.
        student_columns : Tuple[int, ...], optional
            The batch columns to pass to the student model.
            Defaults to all columns.
        alpha_kl : float, optional
            Weight applied to the distillation (KL) loss; the
            supervised loss is weighted by `1 - alpha_kl`.
            Defaults to 0.5.
        temperature : int, optional
            The temperature applied to both models' logits before
            computing the distillation loss. Defaults to 1.
        unlabel_dataset : Dataset, optional
            An optional dataset of unlabelled examples, used only for
            the distillation loss
        unlabel_sampler : Sampler, optional
            An optional sampler over the unlabelled examples
"""
super().__init__(dataset,
train_sampler, # type: ignore
dev_sampler,
student_model,
loss_fn,
metric_fn,
optimizer,
scheduler,
device,
max_steps,
epoch_per_step,
iter_per_step,
batches_per_iter,
lower_is_better,
max_grad_norm,
max_grad_abs_val,
extra_validation_metrics)
self.student_model = self.model
self.teacher_model = teacher_model
self.teacher_columns = teacher_columns
self.student_columns = student_columns
self.alpha_kl = alpha_kl
self.temp = temperature
        self.unlabel_dataset = None
        self.unlabel_sampler = None
        if unlabel_sampler is not None and unlabel_dataset is not None:
            self.unlabel_dataset = unlabel_dataset
            self.unlabel_sampler = unlabel_sampler
            self._unlabel_iterator = unlabel_sampler.sample(unlabel_dataset.train, -1)
    def _compute_loss(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
"""Compute the loss for a single batch
        Important: the student and teacher output predictions must
        be the raw logits, so ensure that your decoder object is set
        up with `take_log=False`.
Parameters
----------
batch: Tuple[torch.Tensor, ...]
The batch to train on
Returns
-------
torch.Tensor
The computed loss
"""
student_columns = self.student_columns or range(len(batch))
teacher_columns = self.teacher_columns or range(len(batch))
student_batch = [batch[i] for i in student_columns]
teacher_batch = [batch[i].detach() for i in teacher_columns]
student_logits, student_target = self.student_model(*student_batch)
teacher_logits, _ = self.teacher_model(*teacher_batch)
# Compute losses
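        # Supervised loss on the labelled batch, weighted by (1 - alpha_kl)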
student_pred = F.log_softmax(student_logits, dim=-1)
loss = (1 - self.alpha_kl) * self.loss_fn(student_pred, student_target)
# Add unlabelled batch
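        # Unlabelled examples contribute only to the distillation term:
        # their logits are appended to the labelled ones before the KL
        # divergence is computed below.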
if self.unlabel_sampler is not None:
# Get next batch
unlabelled, = next(self._unlabel_iterator)
student_unlabel_logits = self.student_model(unlabelled)
teacher_unlabel_logits = self.teacher_model(unlabelled.detach())
student_logits = torch.cat((student_logits, student_unlabel_logits))
teacher_logits = torch.cat((teacher_logits, teacher_unlabel_logits))
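        # Distillation loss: soften both distributions with the temperature,
        # take the KL divergence of the student from the teacher, average it
        # over the batch, and scale by temperature**2 so gradient magnitudes
        # stay comparable across temperatures (as in Hinton et al., 2015).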
student_pred = F.log_softmax(student_logits / self.temp, dim=1)
teacher_pred = F.softmax(teacher_logits / self.temp, dim=1)
        kl_loss = F.kl_div(student_pred, teacher_pred, reduction='sum') / teacher_pred.shape[0]
loss += (self.alpha_kl * self.temp**2) * kl_loss
return loss
    def _aggregate_preds(self, data_iterator) -> Tuple[torch.Tensor, torch.Tensor]:
"""Aggregate the predicitons and targets for the dataset.
Parameters
----------
data_iterator: Iterator
Batches of data
Returns
-------
        Tuple[torch.Tensor, torch.Tensor]
The predictions, and targets
"""
preds, targets = [], []
for batch in data_iterator:
student_columns = self.student_columns or range(len(batch))
student_batch = [batch[i] for i in student_columns]
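            # Move the inputs to the configured device for the forward pass;
            # predictions and targets are collected back on the CPU.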
            pred, target = self.model(*[t.to(self.device) for t in student_batch])
pred = F.log_softmax(pred, dim=-1)
preds.append(pred.cpu())
targets.append(target.cpu())
preds = torch.cat(preds, dim=0)
targets = torch.cat(targets, dim=0)
return preds, targets