flambe.learn

Package Contents

class flambe.learn.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, iter_scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics_log_interval: Optional[int] = None)

Bases: flambe.compile.Component

Implement a Trainer block.

A Trainer takes a dataset, a model and an optimizer as input, and executes training incrementally: each call to run performs one step of training.

Note that each run should last long enough for the per-step overhead to be negligible: at least a few seconds, and ideally several minutes.
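The incremental pattern described above can be illustrated with a hypothetical, dependency-free stand-in. This is not flambe's implementation. Each call to run evaluates, then trains for a fixed number of iterations, then reports whether another step remains, loosely mirroring the max_steps and iter_per_step arguments. ToySteppedTrainer, its learning rate and the data are all made up. The example fits a single scalar weight by gradient descent.

```python
from typing import List, Tuple


class ToySteppedTrainer:
    """Hypothetical stand-in for a step-based trainer (not flambe's code).

    Each call to run() performs one evaluation followed by
    `iter_per_step` gradient updates, and reports whether it should
    be called again.
    """

    def __init__(self, data: List[Tuple[float, float]], max_steps: int = 10,
                 iter_per_step: int = 5, lr: float = 0.05) -> None:
        self.data = data
        self.max_steps = max_steps
        self.iter_per_step = iter_per_step
        self.lr = lr
        self.weight = 0.0
        self._step = 0
        self.history: List[float] = []

    def _loss(self) -> float:
        # Mean squared error of the model y = weight * x
        return sum((self.weight * x - y) ** 2 for x, y in self.data) / len(self.data)

    def run(self) -> bool:
        # Evaluate first, then train until the next "checkpoint"
        self.history.append(self._loss())
        for _ in range(self.iter_per_step):
            grad = sum(2 * (self.weight * x - y) * x for x, y in self.data) / len(self.data)
            self.weight -= self.lr * grad
        self._step += 1
        # Continue only while steps remain
        return self._step < self.max_steps


trainer = ToySteppedTrainer(data=[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)])
while trainer.run():
    pass  # an outer scheduler could checkpoint or compare variants here
print(round(trainer.weight, 3))
```

The while loop plays the role of the outer scheduler. Because the component returns control after every step, the scheduler can checkpoint or stop it between steps.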

validation_metrics

Property kept for backwards compatibility.

_create_train_iterator(self)
_batch_to_device(self, batch: Tuple[torch.Tensor, ...])

Move the current batch to the correct device.

Can be overridden if a batch doesn’t follow the expected structure, for example if the batch is a dictionary.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
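A batch stored as a dictionary, for example, could be handled by an override that walks the structure recursively. The helper below is a hypothetical sketch. It assumes that every leaf to be moved exposes a .to(device) method, as torch.Tensor does, and it returns all other values unchanged. The FakeTensor class exists only so that the snippet runs without PyTorch.

```python
from typing import Any


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move every object with a .to() method onto `device`.

    Handles dicts, lists, tuples and namedtuples; other values are returned unchanged.
    """
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, tuple):
        moved = [move_to_device(item, device) for item in batch]
        # namedtuples take positional fields, plain tuples take an iterable
        return type(batch)(*moved) if hasattr(batch, "_fields") else tuple(moved)
    if isinstance(batch, list):
        return [move_to_device(item, device) for item in batch]
    return batch


class FakeTensor:
    """Minimal stand-in for torch.Tensor that records its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


batch = {"tokens": FakeTensor(), "lengths": [FakeTensor(), 3], "label": (FakeTensor(),)}
moved = move_to_device(batch, "cuda:0")
print(moved["tokens"].device, moved["lengths"][1])
```

Code that later consumes the batch may also need adjusting if it expects a plain tuple.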
_compute_loss(self, batch: Tuple[torch.Tensor, ...])

Compute the loss for a single batch. DEPRECATED: only exists for legacy compatibility with custom trainers.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
_compute_batch(self, batch: Tuple[torch.Tensor, ...], metrics: List[Tuple] = [])

Compute a batch.

Run a forward pass of the model over a batch, and return the predictions, targets and loss.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
static _log_metrics(log_prefix: str, metrics_with_states: List[Tuple], global_step: int)

Log all provided metrics.

Iterate through the provided list of metric-state tuples, finalize each metric, and log it.

Parameters:
  • log_prefix (str) – A string, such as a TensorBoard prefix
  • metrics_with_states (List[Tuple[Metric, Dict]]) – A list of metric-state tuples
  • global_step (int) – The global step for logging
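The finalize-then-log loop can be sketched as follows. This illustration is hypothetical and self-contained. MeanMetric and its aggregate/finalize methods are stand-ins, not flambe's Metric API. log_fn replaces whatever logging backend (such as TensorBoard) the trainer actually uses. The "<prefix>/<name>" tag format is an assumption.

```python
from typing import Callable, Dict, List, Tuple


class MeanMetric:
    """Hypothetical metric that tracks a running mean in a state dict."""

    name = "mean"

    def aggregate(self, state: Dict[str, float], values: List[float]) -> None:
        # Update the state in place with one batch of values
        state["total"] = state.get("total", 0.0) + sum(values)
        state["count"] = state.get("count", 0) + len(values)

    def finalize(self, state: Dict[str, float]) -> float:
        # Turn the accumulated state into the final metric value
        return state["total"] / state["count"] if state.get("count") else 0.0


def log_metrics(log_prefix: str,
                metrics_with_states: List[Tuple[MeanMetric, Dict[str, float]]],
                global_step: int,
                log_fn: Callable[[str, float, int], None]) -> None:
    """Finalize each metric and log it under '<prefix>/<metric name>'."""
    for metric, state in metrics_with_states:
        log_fn(f"{log_prefix}/{metric.name}", metric.finalize(state), global_step)


logged = []
metric, state = MeanMetric(), {}
metric.aggregate(state, [1.0, 2.0])
metric.aggregate(state, [3.0])
log_metrics("Validation", [(metric, state)], global_step=7,
            log_fn=lambda tag, value, step: logged.append((tag, value, step)))
print(logged)
```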
_train_step(self)

Run a training step over the training data.

_aggregate_preds(self, data_iterator: Iterator)

DEPRECATED. Aggregate the predictions, targets and mean loss for the dataset. This method only existed to aggregate outputs for the metric functions, which now aggregate in place.

Parameters: data_iterator (Iterator) – Batches of data.
Returns: The predictions, targets and mean loss.
Return type: Tuple[torch.Tensor, torch.Tensor, float]
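Accuracy shows why whole-dataset aggregation is no longer needed. Accumulating per-batch counts gives the same result as concatenating all predictions first, while keeping only two numbers in memory. The following minimal sketch uses made-up batches, with plain lists in place of tensors.

```python
from typing import List, Tuple

Batch = Tuple[List[int], List[int]]  # (predictions, targets)


def accuracy_concatenated(batches: List[Batch]) -> float:
    """Old approach: gather every prediction and target, then score once."""
    preds = [p for batch_preds, _ in batches for p in batch_preds]
    targets = [t for _, batch_targets in batches for t in batch_targets]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def accuracy_streaming(batches: List[Batch]) -> float:
    """In-place approach: update a small running state after each batch."""
    state = {"correct": 0, "seen": 0}
    for batch_preds, batch_targets in batches:
        state["correct"] += sum(p == t for p, t in zip(batch_preds, batch_targets))
        state["seen"] += len(batch_targets)
    return state["correct"] / state["seen"]


batches = [([1, 0, 1], [1, 1, 1]), ([0, 0], [0, 1])]
print(accuracy_concatenated(batches), accuracy_streaming(batches))
```

The equivalence holds because accuracy is a ratio of sums. Metrics that do not decompose this way need a richer per-batch state.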
_eval_step(self)

Run an evaluation step over the validation data.

run(self)

Evaluate, then train until the next checkpoint.

Returns: Whether the component should continue running.
Return type: bool
metric(self)

Override this method to enable scheduling.

Returns: The metric used to compare computable variants.
Return type: float
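metric() returns a single float. The lower_is_better argument in the Trainer signature suggests that the direction in which this number improves is configurable. As an illustration only, a scheduler comparing variants might pick the best one as shown below. The best_variant helper, the variant names and the scores are hypothetical and not part of flambe.

```python
from typing import Dict


def best_variant(scores: Dict[str, float], lower_is_better: bool = False) -> str:
    """Return the name of the variant whose metric() value is best."""
    if not scores:
        raise ValueError("no variants to compare")
    pick = min if lower_is_better else max
    return pick(scores, key=scores.__getitem__)


scores = {"lr=0.1": 0.82, "lr=0.01": 0.87, "lr=0.001": 0.79}
print(best_variant(scores))                        # higher is better, e.g. accuracy
print(best_variant({"a": 0.31, "b": 0.27}, True))  # lower is better, e.g. loss
```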
_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any])
_load_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any], strict: bool, missing_keys: List[Any], unexpected_keys: List[Any], error_msgs: List[Any])
classmethod precompile(cls, **kwargs)

Override initialization.

Ensure that the model is compiled and pushed to the right device before its parameters are passed to the optimizer.

class flambe.learn.Evaluator(dataset: Dataset, model: Module, metric_fn: Metric, eval_sampler: Optional[Sampler] = None, eval_data: str = 'test', device: Optional[str] = None)

Bases: flambe.compile.Component

Implement an Evaluator block.

An Evaluator takes a dataset and a model as input, and runs the evaluation. It is a single-step Component.

run(self, block_name: str = None)

Run the evaluation.

Returns: Whether the component should continue running.
Return type: bool
metric(self)

Override this method to enable scheduling.

Returns: The metric used to compare computable variants.
Return type: float
class flambe.learn.Script(script: str, args: List[Any], kwargs: Optional[Dict[str, Any]] = None, output_dir_arg: Optional[str] = None)

Bases: flambe.compile.Component

Implement a Script computable.

The object can be used to turn any script into a Flambé computable, which is useful when you want to integrate existing code quickly. Note, however, that this computable does not support checkpointing or linking to internal components, because it has no attributes.

To use this object, your script must be part of a pip-installable package that declares all of its dependencies. The script is run with the following command:

python -m script --arg1 value1 --arg2 value2
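The command above can be approximated in-process with the standard library's runpy module. The sketch below assumes one possible mapping of args and kwargs onto the command line: positional args first, then each kwargs entry as --key value. It is not flambe's implementation, and build_argv and run_as_main are hypothetical names.

```python
import runpy
import sys
from typing import Any, Dict, List, Optional


def build_argv(args: List[Any], kwargs: Optional[Dict[str, Any]] = None) -> List[str]:
    """Flatten positional args and --key value pairs into argv strings."""
    argv = [str(arg) for arg in args]
    for key, value in (kwargs or {}).items():
        argv += [f"--{key}", str(value)]
    return argv


def run_as_main(module: str, args: List[Any],
                kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Run `module` roughly like `python -m module ...` and return its globals."""
    saved_argv = sys.argv
    sys.argv = [module] + build_argv(args, kwargs)
    try:
        # run_name="__main__" triggers the script's `if __name__ == "__main__":` block
        return runpy.run_module(module, run_name="__main__")
    finally:
        # Restore the caller's argv even if the script raises
        sys.argv = saved_argv
```

For a hypothetical module my_pkg.train, run_as_main("my_pkg.train", ["data.csv"], {"lr": 0.01}) would show the script sys.argv[1:] == ["data.csv", "--lr", "0.01"].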
run(self)

Run the script.

Returns: Report dictionary to use for logging.
Return type: Dict[str, float]
class flambe.learn.DistillationTrainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, teacher_model: Module, student_model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, iter_scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[List[Metric]] = None, teacher_columns: Optional[Tuple[int, ...]] = None, student_columns: Optional[Tuple[int, ...]] = None, alpha_kl: float = 0.5, temperature: int = 1, unlabel_dataset: Optional[Dataset] = None, unlabel_sampler: Optional[Sampler] = None)

Bases: flambe.learn.Trainer

Implement a Distillation Trainer.

Perform knowledge distillation between a teacher and a student model. Note that the model outputs are expected to be raw logits, so make sure that you are not applying a softmax after the decoder. You can replace the traditional Decoder with an MLPEncoder.
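Why is a softmax after the decoder harmful? The distillation loss presumably applies its own temperature-scaled softmax to the model outputs, as the temperature argument suggests. If those outputs are already probabilities, the second softmax squashes them into a much flatter distribution. The following dependency-free sketch uses made-up logits.

```python
import math
from typing import List


def softmax(values: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with an optional temperature."""
    scaled = [v / temperature for v in values]
    peak = max(scaled)
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, 1.0, 0.0]
once = softmax(logits)            # what the loss expects to compute
twice = softmax(softmax(logits))  # what happens if the decoder already applied softmax
print([round(p, 3) for p in once])
print([round(p, 3) for p in twice])
```

The confident top class (above 0.9 after one softmax) drops to roughly 0.55 after two, which distorts the soft targets the student learns from.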

_compute_loss(self, batch: Tuple[torch.Tensor, ...])

Compute the loss for a single batch.

Important: the student and teacher output predictions must be the raw logits, so ensure that your decoder object is set up with take_log=False.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
Returns: The computed loss.
Return type: torch.Tensor
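For reference, a widely used soft-target formulation of knowledge distillation mixes two terms. The first is a temperature-softened KL divergence between teacher and student, scaled by T². The second is the usual cross-entropy on the hard label. Whether DistillationTrainer combines alpha_kl and temperature exactly this way is an assumption. The sketch below only shows how such parameters typically interact, for a single example on plain Python lists.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Stable log-softmax of logits divided by `temperature`."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_total = peak + math.log(sum(math.exp(z - peak) for z in scaled))
    return [z - log_total for z in scaled]


def distillation_loss(student_logits: List[float], teacher_logits: List[float],
                      label: int, alpha_kl: float = 0.5,
                      temperature: float = 1.0) -> float:
    """Soft-target KL term mixed with hard-label cross-entropy (one example).

    Assumed formulation, not necessarily flambe's exact loss.
    """
    log_p_student = log_softmax(student_logits, temperature)
    log_p_teacher = log_softmax(teacher_logits, temperature)
    # KL(teacher || student) on the temperature-softened distributions
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student))
    # Scaling by T^2 keeps soft-target gradients comparable across temperatures
    soft = kl * temperature ** 2
    # Standard cross-entropy against the hard label, at temperature 1
    hard = -log_softmax(student_logits)[label]
    return alpha_kl * soft + (1.0 - alpha_kl) * hard


student = [2.0, 0.5, -1.0]
teacher = [3.0, 0.0, -2.0]
print(distillation_loss(student, teacher, label=0, alpha_kl=0.5, temperature=2.0))
```

Both terms take raw logits and apply the softmax internally, which is why the documentation above insists on take_log=False.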
_aggregate_preds(self, data_iterator)

Aggregate the predictions and targets for the dataset.

Parameters: data_iterator (Iterator) – Batches of data.
Returns: The predictions and targets.
Return type: Tuple[torch.Tensor, torch.Tensor]