flambe.learn.train

Module Contents

class flambe.learn.train.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, extra_validation_metrics: Optional[List[Metric]] = None)[source]

Bases: flambe.compile.Component

Implement a Trainer block.

A Trainer takes data, a model, and an optimizer as input, and executes training incrementally in run.

Note that each call to run should be long enough to amortize its overhead: at least a few seconds, and ideally several minutes.
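
As a rough construction sketch (the my_* objects are hypothetical placeholders for Dataset, Sampler, Module, and Metric components built elsewhere in a pipeline, and the meaning of max_steps and iter_per_step is inferred from the signature above):

    import torch

    from flambe.learn.train import Trainer

    # All `my_*` objects are hypothetical placeholders for components
    # built elsewhere in the pipeline.
    trainer = Trainer(
        dataset=my_dataset,            # a flambe Dataset
        train_sampler=my_train_sampler,
        val_sampler=my_val_sampler,
        model=my_model,                # a torch / flambe Module
        loss_fn=my_loss_metric,        # Metric used as the training loss
        metric_fn=my_val_metric,       # Metric reported on validation
        optimizer=torch.optim.Adam(my_model.parameters(), lr=1e-3),
        max_steps=20,                  # number of train/eval steps
        iter_per_step=500,             # training iterations per step
    )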

_batch_to_device(self, batch: Tuple[torch.Tensor, ...])[source]

Move the current batch on the correct device.

Can be overridden if a batch does not follow the expected structure, for example if the batch is a dictionary.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to move to the device.
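
As a sketch of the kind of override described above, a hypothetical subclass whose batches are dictionaries of tensors might do the following (assuming the resolved device is stored on self.device, as the constructor's device argument suggests):

    from typing import Dict

    import torch

    from flambe.learn.train import Trainer


    class DictBatchTrainer(Trainer):
        """Hypothetical subclass whose batches are dictionaries of tensors."""

        def _batch_to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            # Move every tensor in the dictionary onto the trainer's device.
            return {name: tensor.to(self.device) for name, tensor in batch.items()}
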
_compute_loss(self, batch: Tuple[torch.Tensor, ...])[source]

Compute the loss given a single batch.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
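
For orientation only, a loss computation over a batch of this shape often reduces to the sketch below; this is illustrative, not flambe's implementation, and it assumes the last element of the tuple is the target and that the trainer exposes self.model and self.loss_fn:

    def _compute_loss(self, batch):
        # Illustrative sketch: split the batch into inputs and target,
        # run the model, and score the predictions with the loss metric.
        *inputs, target = batch
        pred = self.model(*inputs)
        return self.loss_fn(pred, target)
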
_train_step(self)[source]

Run a training step over the training data.

_aggregate_preds(self, data_iterator: Iterator)[source]

Aggregate the predictions and targets for the dataset.

Parameters: data_iterator (Iterator) – Batches of data.
Returns: The predictions and targets.
Return type: Tuple[torch.Tensor, torch.Tensor]
_eval_step(self)[source]

Run an evaluation step over the validation data.

run(self)[source]

Train until the next checkpoint, and evaluate.

Returns: Whether the computable is not yet complete, i.e. whether run should be called again.
Return type: bool
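
Outside of a full flambe pipeline, the method can be driven manually; reusing the hypothetical trainer from the construction sketch above:

    # Each call trains until the next checkpoint and evaluates,
    # returning True while more work remains.
    while trainer.run():
        print(f"current validation metric: {trainer.metric():.4f}")
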
metric(self)[source]

Override this method to enable scheduling.

Returns: The metric to compare computable variants.
Return type: float
_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any])[source]
_load_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any], strict: bool, missing_keys: List[Any], unexpected_keys: List[Any], error_msgs: List[Any])[source]
classmethod precompile(cls, **kwargs)[source]

Override initialization.

Ensure that the model is compiled and pushed to the right device before its parameters are passed to the optimizer.
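
This ordering mirrors PyTorch's own guidance that a model should be moved to its device before its parameters are handed to an optimizer; a generic illustration in plain PyTorch (not flambe code):

    import torch
    from torch import nn

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = nn.Linear(16, 4)
    model.to(device)  # move parameters to the target device first...

    # ...and only then build the optimizer, so it references the moved
    # parameters; this is the ordering precompile guarantees.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)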