Module Contents

class flambe.learn.train.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, iter_scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics_log_interval: Optional[int] = None)[source]

Bases: flambe.compile.Component

Implement a Trainer block.

A Trainer takes a dataset, a model and an optimizer as input, and executes training incrementally in run.

Note that each trainer run should be long enough to keep overhead low: at least a few seconds, and ideally multiple minutes.


Adding property for backwards compatibility

_batch_to_device(self, batch: Tuple[torch.Tensor, ...])[source]

Move the current batch to the correct device.

Can be overridden if a batch doesn’t follow the expected structure, for example if the batch is a dictionary.

Parameters:batch (Tuple[torch.Tensor, ...]) – The batch to train on.
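
As an illustration of the dictionary case mentioned above, here is a minimal sketch of a device-moving helper. The name move_to_device is hypothetical and not part of flambe; the sketch assumes every leaf that should be moved exposes a .to(device) method, as torch.Tensor does, and it does not handle namedtuples.

```python
from typing import Any


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move a batch to `device`.

    Hypothetical helper, not part of flambe. Handles dicts, tuples and
    lists, and calls `.to(device)` on any leaf that provides it (as
    torch.Tensor does). Other leaves are returned unchanged.
    """
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (tuple, list)):
        # Rebuild the same container type (plain tuple or list only).
        return type(batch)(move_to_device(item, device) for item in batch)
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch
```

A Trainer subclass could override _batch_to_device and delegate to a helper like this one, passing in the device the trainer was configured with.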
_compute_loss(self, batch: Tuple[torch.Tensor, ...])[source]

Compute the loss given a single batch. DEPRECATED: only exists for legacy compatibility with custom trainers.

Parameters:batch (Tuple[torch.Tensor, ...]) – The batch to train on.
_compute_batch(self, batch: Tuple[torch.Tensor, ...], metrics: List[Tuple] = [])[source]

Computes a batch.

Does a model forward pass over a batch, and returns prediction, target and loss.

Parameters:batch (Tuple[torch.Tensor, ...]) – The batch to train on.
static _log_metrics(log_prefix: str, metrics_with_states: List[Tuple], global_step: int)[source]

Logs all provided metrics.

Iterates through the provided list of metrics with states, finalizes each metric, and logs it.

Parameters:
  • log_prefix (str) – A string, such as a tensorboard prefix.
  • metrics_with_states (List[Tuple[Metric, Dict]]) – A list of metric-state tuples.
  • global_step (int) – The global step for logging.
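
The finalize-then-log pattern described for _log_metrics can be sketched as follows. This is an illustration only, not flambe's implementation: MeanMetric, its state layout, the log_fn callback and the "prefix/metric" tag format are all assumptions made for the example.

```python
from typing import Any, Callable, Dict, List, Tuple


class MeanMetric:
    """Hypothetical stand-in for a metric that keeps a running state."""

    def __str__(self) -> str:
        return "Mean"

    def finalize(self, state: Dict[str, float]) -> float:
        # Turn the accumulated state into a single value.
        return state["total"] / state["count"] if state["count"] else 0.0


def log_metrics(log_prefix: str,
                metrics_with_states: List[Tuple[Any, Dict]],
                global_step: int,
                log_fn: Callable[[str, float, int], None]) -> None:
    """Finalize each metric from its state and pass the value to log_fn."""
    for metric, state in metrics_with_states:
        # The tag format is an assumption made for this sketch.
        log_fn(f"{log_prefix}/{metric}", metric.finalize(state), global_step)
```

Passing the logging function as an argument keeps the sketch independent of any particular logging backend.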

Run a training step over the training data.

_aggregate_preds(self, data_iterator: Iterator)[source]

DEPRECATED. Aggregate the predictions, targets and mean loss for the dataset.

Parameters:data_iterator (Iterator) – Batches of data.
Returns:Tuple[torch.Tensor, torch.Tensor, float] – The predictions, targets and mean loss. DEPRECATED; only existed to aggregate for the metric functions, which now do this in place.

Run an evaluation step over the validation data.


Evaluate and then train until the next checkpoint

Returns:Whether the component should continue running.
Return type:bool
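
The boolean return value suggests a simple driving loop: call run repeatedly until it reports that no more work remains. The sketch below shows that contract with a hypothetical CountingComponent; the stopping rule based on a step budget is an assumption for the example, not a statement of how the Trainer decides to stop.

```python
class CountingComponent:
    """Hypothetical component that follows the same run() contract."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.step = 0

    def run(self) -> bool:
        # One increment of work (for a trainer: evaluate, then train).
        self.step += 1
        # Report whether the caller should invoke run() again.
        return self.step < self.max_steps


def run_to_completion(component) -> int:
    """Call run() until it returns False; return the number of calls made."""
    calls = 1
    while component.run():
        calls += 1
    return calls
```

Between calls, the caller is free to checkpoint, report progress, or stop early.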

Override this method to enable scheduling.

Returns:The metric to compare computable variants.
Return type:float
_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any])[source]
_load_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any], strict: bool, missing_keys: List[Any], unexpected_keys: List[Any], error_msgs: List[Any])[source]
classmethod precompile(cls, **kwargs)[source]

Override initialization.

Ensure that the model is compiled and pushed to the right device before its parameters are passed to the optimizer.