flambe.learn.distillation

Module Contents

class flambe.learn.distillation.DistillationTrainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, teacher_model: Module, student_model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, iter_scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[List[Metric]] = None, teacher_columns: Optional[Tuple[int, ...]] = None, student_columns: Optional[Tuple[int, ...]] = None, alpha_kl: float = 0.5, temperature: int = 1, unlabel_dataset: Optional[Dataset] = None, unlabel_sampler: Optional[Sampler] = None)[source]

Bases: flambe.learn.Trainer

Implement a Distillation Trainer.

Perform knowledge distillation between a teacher and a student model. Note that both models are expected to output raw logits; make sure you are not applying a softmax after the decoder. You can replace the traditional Decoder with an MLPEncoder.
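
Both models should therefore end in a decoder that returns unnormalized scores. The snippet below is a minimal, hypothetical sketch in plain PyTorch (not a flambe component) of a student whose decoder is a plain linear layer, so the forward pass yields raw logits::

    import torch
    from torch import nn

    class ToyStudent(nn.Module):
        """Toy classifier: an encoder followed by a linear decoder.

        The decoder is a plain ``nn.Linear``, so ``forward`` returns raw
        logits; no softmax or log-softmax is applied here.
        """

        def __init__(self, input_size: int, hidden_size: int, n_classes: int):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU())
            self.decoder = nn.Linear(hidden_size, n_classes)  # raw logits out

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.decoder(self.encoder(x))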

_compute_loss(self, batch: Tuple[torch.Tensor, ...])[source]

Compute the loss for a single batch

Important: the student and teacher output predictions must be the raw logits, so ensure that your decoder object is set up with take_log=False.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on
Returns: The computed loss
Return type: torch.Tensor
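
For reference, the standard knowledge-distillation objective (Hinton et al., 2015) blends a KL-divergence term over temperature-softened logits with the ordinary cross-entropy on the hard labels, weighted by alpha_kl. The sketch below illustrates that formula in plain PyTorch using the alpha_kl and temperature arguments from the constructor; it is an illustration of the technique, not flambe's exact implementation::

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits: torch.Tensor,
                          teacher_logits: torch.Tensor,
                          targets: torch.Tensor,
                          alpha_kl: float = 0.5,
                          temperature: float = 1.0) -> torch.Tensor:
        """Blend the soft-target KL term with the hard-label loss."""
        # KL divergence between temperature-softened distributions;
        # 'batchmean' matches the mathematical definition of KL divergence.
        soft_student = F.log_softmax(student_logits / temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
        kl = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * temperature ** 2
        # Ordinary cross-entropy against the ground-truth labels (raw logits in).
        ce = F.cross_entropy(student_logits, targets)
        return alpha_kl * kl + (1 - alpha_kl) * ce
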
_aggregate_preds(self, data_iterator)[source]

Aggregate the predictions and targets for the dataset.

Parameters: data_iterator (Iterator) – Batches of data
Returns: The predictions and targets
Return type: Tuple[torch.Tensor, torch.Tensor]
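
In practice this means iterating over the batches once, running the model, and concatenating the per-batch outputs along the batch dimension. A hedged sketch of that general pattern (plain PyTorch, with a hypothetical model and iterator; not the method's actual body)::

    from typing import Iterator, Tuple

    import torch

    def aggregate_preds(model: torch.nn.Module,
                        data_iterator: Iterator) -> Tuple[torch.Tensor, torch.Tensor]:
        """Collect predictions and targets over a whole dataset."""
        preds, targets = [], []
        with torch.no_grad():
            for *inputs, target in data_iterator:  # assumes the last column holds the labels
                preds.append(model(*inputs).cpu())
                targets.append(target.cpu())
        # Concatenate the per-batch tensors along the batch dimension.
        return torch.cat(preds, dim=0), torch.cat(targets, dim=0)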