import math
from typing import Dict, List, Optional, Any, Tuple, Iterator
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.nn.utils.clip_grad import clip_grad_norm_, clip_grad_value_
from flambe.dataset import Dataset
from flambe.compile import Schema, State, Component, Link
from flambe.learn.utils import select_device
from flambe.nn import Module
from flambe.sampler import Sampler
from flambe.metric import Metric
from flambe.logging import log
[docs]class Trainer(Component):
"""Implement a Trainer block.
A `Trainer` takes as input data, model and optimizer,
and executes training incrementally in `run`.
Note that it is important that a trainer run be long enough
to not increase overhead, so at least a few seconds, and ideally
multiple minutes.
"""
def __init__(self,
dataset: Dataset,
train_sampler: Sampler,
val_sampler: Sampler,
model: Module,
loss_fn: Metric,
metric_fn: Metric,
optimizer: Optimizer,
scheduler: Optional[_LRScheduler] = None,
iter_scheduler: Optional[_LRScheduler] = None,
device: Optional[str] = None,
max_steps: int = 10,
epoch_per_step: float = 1.0,
iter_per_step: Optional[int] = None,
batches_per_iter: int = 1,
lower_is_better: bool = False,
max_grad_norm: Optional[float] = None,
max_grad_abs_val: Optional[float] = None,
extra_validation_metrics: Optional[Dict[str, Metric]] = None) -> None:
"""Initialize an instance of Trainer
Parameters
----------
dataset : Dataset
The dataset to use in training the model
train_sampler : Sampler
The sampler to use over training examples during training
val_sampler : Sampler
The sampler to use over validation examples
model : Module
The model to train
loss_fn: Metric
The loss function to use in training the model
metric_fn: Metric
The metric function to use in evaluation
optimizer : torch.optim.Optimizer
The optimizer to use
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
An optional learning rate scheduler to run after each step
iter_scheduler : torch.optim.lr_scheduler._LRScheduler, optional
An optional learning rate scheduler to run after each batch
(i.e iteration)
device: str, optional
The device to use in the computation.
max_steps : int, optional
The maximum number of training steps to run
epoch_per_step : float, optional
Fraction of an epoch to perform in a single training step
(i.e before a checkpoint.) Defaults to 1.
Overridden by `iter_per_step`, if given.
iter_per_step : int, optional
Number of iterations to perform in a single training step.
Overrides `epoch_per_step` if given.
batches_per_iter : int, optional
Number of batches to pass through the model before
calling optimizer.step. Requires the sampler to have
drop_last set to True. (default set to 1 so optimizer.step
is called after every batch)
lower_is_better : bool, optional
If true, the lowest val metric is considered best,
otherwise the highest. Defaults to False.
max_grad_norm : float, optional
Maximum Euclidean norm of gradient after clipping.
max_grad_abs_val: float, optional
Maximum absolute value of all gradient vector components
after clipping.
extra_validation_metrics: Optional[Dict[str, Metric]]
A dict with extra metrics to show in each step
but which don't guide the training procedures
(i.e model selection through early stopping)
The key of the metric will be used for displaying
the values in tensorboard.
"""
self.dataset = dataset
self.train_sampler = train_sampler
self.val_sampler = val_sampler
self.model = model
self.loss_fn = loss_fn
self.metric_fn = metric_fn
self.optimizer = optimizer
self.scheduler = scheduler
self.iter_scheduler = iter_scheduler
self.lower_is_better = lower_is_better
self.max_grad_norm = max_grad_norm
self.max_grad_abs_val = max_grad_abs_val
self.extra_validation_metrics = extra_validation_metrics or {}
# By default, no prefix applied to tb logs
self.tb_log_prefix = None
# Select right device
self.device = select_device(device)
if (not getattr(self.train_sampler, 'drop_last', False) and batches_per_iter != 1):
raise ValueError(f'batches_per_iter cannot be set to {batches_per_iter} '
'if the sampler does not have `drop_last` set to True')
self.batches_per_iter = batches_per_iter
n_batches = self.train_sampler.length(dataset.train)
if iter_per_step is None:
# Compute epoch per step
if self.batches_per_iter > n_batches:
raise Exception(f'Please set batches_per_iter ({self.batches_per_iter}) '
f'to be ≤ the length of your train_sampler '
f'({n_batches})')
iter_per_epoch = n_batches // self.batches_per_iter
iter_per_step = math.ceil(epoch_per_step * iter_per_epoch)
else:
# Iter per step takes precedent over epoch_per_step
epoch_per_step = iter_per_step / n_batches
self.iter_per_step = iter_per_step
self.max_steps = max_steps
self._step = 0
self._best_metric = None
self._best_model: Dict[str, torch.Tensor] = dict()
self.register_attrs('_step', '_best_metric', '_best_model')
self.n_epochs = math.ceil(epoch_per_step * max_steps)
self._create_train_iterator()
[docs] def _create_train_iterator(self):
self._train_iterator = self.train_sampler.sample(self.dataset.train, self.n_epochs)
[docs] def _batch_to_device(self, batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
"""Move the current batch on the correct device.
Can be overriden if a batch doesn't follow the expected
structure. For example if the batch is a dictionary.
Parameters
----------
batch: Tuple[torch.Tensor, ...]
The batch to train on.
"""
batch = tuple(t.to(self.device) for t in batch)
return batch
[docs] def _compute_loss(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
"""Compute the loss given a single batch
Parameters
----------
batch: Tuple[torch.Tensor, ...]
The batch to train on.
"""
pred, target = self.model(*batch)
loss = self.loss_fn(pred, target)
return loss
[docs] def _train_step(self) -> None:
"""Run a training step over the training data."""
self.model.train()
tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
with torch.enable_grad():
for i in range(self.iter_per_step):
# Zero the gradients and clear the accumulated loss
self.optimizer.zero_grad()
accumulated_loss = 0.0
for _ in range(self.batches_per_iter):
# Get next batch
try:
batch = next(self._train_iterator)
except StopIteration:
self._create_train_iterator()
batch = next(self._train_iterator)
batch = self._batch_to_device(batch)
# Compute loss
loss = self._compute_loss(batch) / self.batches_per_iter
accumulated_loss += loss.item()
loss.backward()
# Log loss
global_step = (self.iter_per_step * self._step) + i
# Clip gradients if necessary
if self.max_grad_norm:
clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
if self.max_grad_abs_val:
clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm,
global_step)
log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm,
global_step)
# Optimize
self.optimizer.step()
# Update iter scheduler
if self.iter_scheduler is not None:
learning_rate = self.iter_scheduler.get_lr()[0] # type: ignore
log(f'{tb_prefix}Training/LR', learning_rate, global_step)
self.iter_scheduler.step() # type: ignore
# Zero the gradients when exiting a train step
self.optimizer.zero_grad()
[docs] def _aggregate_preds(self, data_iterator: Iterator) -> Tuple[torch.Tensor, torch.Tensor]:
"""Aggregate the predicitons and targets for the dataset.
Parameters
----------
data_iterator: Iterator
Batches of data.
Returns
-------
Tuple[torch.tensor, torch.tensor]
The predictions and targets.
"""
preds, targets = [], []
for batch in data_iterator:
batch = self._batch_to_device(batch)
pred, target = self.model(*batch)
preds.append(pred.cpu())
targets.append(target.cpu())
preds = torch.cat(preds, dim=0)
targets = torch.cat(targets, dim=0)
return preds, targets
[docs] def _eval_step(self) -> None:
"""Run an evaluation step over the validation data."""
self.model.eval()
# Initialize a 1-epoch iteration through the validation set
val_iterator = self.val_sampler.sample(self.dataset.val)
with torch.no_grad():
preds, targets = self._aggregate_preds(val_iterator)
val_loss = self.loss_fn(preds, targets).item()
val_metric = self.metric_fn(preds, targets).item()
# Update best model
sign = (-1)**(self.lower_is_better)
if self._best_metric is None or (sign * val_metric > sign * self._best_metric):
self._best_metric = val_metric
best_model_state = self.model.state_dict()
for k, t in best_model_state.items():
best_model_state[k] = t.cpu().detach()
self._best_model = best_model_state
# Update scheduler
if self.scheduler is not None:
if isinstance(self.scheduler, ReduceLROnPlateau):
self.scheduler.step(val_loss)
else:
# torch's _LRScheduler.step DOES have a default value
# so passing in no args is fine; it will automatically
# compute the current epoch
self.scheduler.step() # type: ignore
tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
# Log metrics
log(f'{tb_prefix}Validation/Loss', val_loss, self._step)
log(f'{tb_prefix}Validation/{self.metric_fn}', val_metric, self._step)
log(f'{tb_prefix}Best/{self.metric_fn}', self._best_metric, self._step) # type: ignore
for metric_name, metric in self.extra_validation_metrics.items():
log(f'{tb_prefix}Validation/{metric_name}',
metric(preds, targets).item(), self._step) # type: ignore
[docs] def run(self) -> bool:
"""Evaluate and then train until the next checkpoint
Returns
------
bool
Whether the component should continue running.
"""
self._eval_step()
if self._step < self.max_steps:
self._train_step()
# Simple stopping rule, if we exceed the max number of steps
self._step += 1
continue_ = self._step < self.max_steps
if not continue_:
self._eval_step()
self.model.cpu()
self.model.load_state_dict(self._best_model, strict=False)
return continue_
[docs] def metric(self) -> Optional[float]:
"""Override this method to enable scheduling.
Returns
-------
float
The metric to compare computable variants.
"""
return self._best_metric
[docs] def _state(self,
state_dict: State,
prefix: str,
local_metadata: Dict[str, Any]) -> State:
state_dict[prefix + 'optimizer'] = self.optimizer.state_dict()
if self.scheduler is not None:
state_dict[prefix + 'scheduler'] = self.scheduler.state_dict()
return state_dict
[docs] def _load_state(self,
state_dict: State,
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[Any],
unexpected_keys: List[Any],
error_msgs: List[Any]) -> None:
self.optimizer.load_state_dict(state_dict[prefix + 'optimizer'])
if self.scheduler is not None:
self.scheduler.load_state_dict(state_dict[prefix + 'scheduler'])
# Useful when loading the model after training
done = self._step >= self.max_steps
if done:
self.model.load_state_dict(self._best_model, strict=False)
@classmethod
[docs] def precompile(cls, **kwargs):
"""Override initialization.
Ensure that the model is compiled and pushed to the right
device before its parameters are passed to the optimizer.
"""
# Select right device
device = kwargs.get('device', None)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
def move_to_device(obj: Any):
if isinstance(obj, torch.nn.Module):
obj.to(device)
# Compile all objects and push Modules to the device
for k, obj in kwargs.items():
if isinstance(obj, (Schema, Link)):
obj.post_init_hooks.append(move_to_device)