flambe.learn

Package Contents

class flambe.learn.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[List[Metric]] = None)

Bases: flambe.compile.Component

Implement a Trainer block.

A Trainer takes as input a dataset, a model and an optimizer, and executes training incrementally in run.

Note that each trainer run should last long enough for its fixed overhead to be negligible: at least a few seconds, and ideally several minutes.

_batch_to_device(self, batch: Tuple[torch.Tensor, ...])

Move the current batch onto the correct device.

Can be overridden if a batch does not follow the expected structure, for example if the batch is a dictionary.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
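
A minimal sketch of the logic such an override might contain, assuming every leaf of the batch exposes a .to(device) method as torch.Tensor does. The helper move_to_device and the stand-in FakeTensor class are hypothetical and exist only so the example runs without PyTorch; in a real subclass, _batch_to_device would call the helper and return its result.

```python
from typing import Any


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move every tensor-like leaf of a batch onto `device`.

    Dictionaries, tuples and lists are traversed; any leaf with a `.to`
    method (such as torch.Tensor) is moved, other leaves are returned as-is.
    """
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (tuple, list)):
        moved = [move_to_device(item, device) for item in batch]
        # Preserve the container kind (namedtuples become plain tuples).
        return tuple(moved) if isinstance(batch, tuple) else moved
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to, return a new object rather than mutating in place.
        return FakeTensor(device)


if __name__ == "__main__":
    batch = {"text": FakeTensor(), "label": FakeTensor(), "lengths": [3, 5]}
    moved = move_to_device(batch, "cuda:0")
    print(moved["text"].device, moved["lengths"])  # cuda:0 [3, 5]
```

Returning new containers instead of mutating the input keeps the original batch untouched, which mirrors how Tensor.to returns a new tensor.
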
_compute_loss(self, batch: Tuple[torch.Tensor, ...])

Compute the loss given a single batch.

Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
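
As a sketch, a loss computation of this shape might look like the following, assuming the model returns a (prediction, target) pair for a batch and loss_fn compares the two. The toy model and the mean-squared-error loss are hypothetical, written in plain Python so the example runs without PyTorch.

```python
from typing import Callable, List, Sequence, Tuple

Batch = Tuple[List[float], List[float]]
Model = Callable[..., Tuple[List[float], List[float]]]
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def toy_model(inputs: Sequence[float], targets: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Hypothetical model: predicts 2 * x and passes the targets through."""
    return [2.0 * x for x in inputs], list(targets)


def mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error between two equally sized sequences."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def compute_loss(model: Model, loss_fn: LossFn, batch: Batch) -> float:
    # Unpack the batch into the model's positional arguments,
    # then score the predictions against the targets.
    preds, targets = model(*batch)
    return loss_fn(preds, targets)


if __name__ == "__main__":
    batch = ([1.0, 2.0], [2.0, 5.0])
    print(compute_loss(toy_model, mse, batch))  # 0.5
```
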
_train_step(self)

Run a training step over the training data.
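
The constructor arguments batches_per_iter, max_grad_norm and max_grad_abs_val suggest that a training step may accumulate gradients over several batches and clip them before the optimizer update. The sketch below illustrates both ideas on plain lists of gradient values. It assumes clipping behaves similarly to PyTorch's torch.nn.utils.clip_grad_norm_ and clip_grad_value_; the order (value first, then norm) is an arbitrary choice here, and the helper names accumulate and clip are hypothetical.

```python
import math
from typing import List, Optional, Sequence


def accumulate(grads_per_batch: Sequence[Sequence[float]]) -> List[float]:
    """Average gradients over batches.

    Equivalent to dividing each batch loss by batches_per_iter before
    backpropagation and summing the resulting gradients.
    """
    n = len(grads_per_batch)
    return [sum(column) / n for column in zip(*grads_per_batch)]


def clip(grads: List[float],
         max_grad_norm: Optional[float] = None,
         max_grad_abs_val: Optional[float] = None) -> List[float]:
    """Clamp each value to [-max_grad_abs_val, max_grad_abs_val], then
    rescale so that the L2 norm is at most max_grad_norm."""
    if max_grad_abs_val is not None:
        grads = [max(-max_grad_abs_val, min(max_grad_abs_val, g)) for g in grads]
    if max_grad_norm is not None:
        norm = math.sqrt(sum(g * g for g in grads))
        if norm > max_grad_norm:
            scale = max_grad_norm / norm
            grads = [g * scale for g in grads]
    return grads


if __name__ == "__main__":
    # Two batches accumulated into one optimizer iteration (batches_per_iter=2).
    grads = accumulate([[3.0, 0.0], [5.0, 8.0]])
    print(grads)                                # [4.0, 4.0]
    print(clip(grads, max_grad_abs_val=3.0))    # [3.0, 3.0]
    print(clip([3.0, 4.0], max_grad_norm=1.0))  # roughly [0.6, 0.8]
```
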

_aggregate_preds(self, data_iterator: Iterator)

Aggregate the predictions and targets for the dataset.

Parameters: data_iterator (Iterator) – Batches of data.
Returns: The predictions and targets.
Return type: Tuple[torch.Tensor, torch.Tensor]
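
A minimal sketch of the aggregation, assuming the model maps each batch of data to a (predictions, targets) pair. With PyTorch one would typically collect the per-batch tensors and call torch.cat on each list; plain lists stand in for tensors here, and aggregate_preds and doubling_model are hypothetical names.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]
Model = Callable[..., Tuple[Sequence[float], Sequence[float]]]


def aggregate_preds(model: Model, data_iterator: Iterable[Batch]) -> Tuple[List[float], List[float]]:
    """Run the model on every batch and concatenate predictions and targets."""
    preds: List[float] = []
    targets: List[float] = []
    for batch in data_iterator:
        batch_preds, batch_targets = model(*batch)
        # Extending flat lists plays the role of torch.cat(..., dim=0).
        preds.extend(batch_preds)
        targets.extend(batch_targets)
    return preds, targets


def doubling_model(inputs: Sequence[float], targets: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Hypothetical model: predicts 2 * x and passes the targets through."""
    return [2.0 * x for x in inputs], list(targets)


if __name__ == "__main__":
    batches = iter([([1.0], [2.0]), ([2.0, 3.0], [4.0, 7.0])])
    print(aggregate_preds(doubling_model, batches))
    # ([2.0, 4.0, 6.0], [2.0, 4.0, 7.0])
```
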
_eval_step(self)

Run an evaluation step over the validation data.
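
The lower_is_better constructor argument determines which direction counts as an improvement when validation results are compared. A hypothetical sketch of that comparison, assuming the trainer remembers the best validation metric seen so far; the BestTracker class is illustrative only and is not part of flambe.learn.

```python
from typing import Optional


class BestTracker:
    """Hypothetical helper that remembers the best validation metric so far."""

    def __init__(self, lower_is_better: bool = False) -> None:
        self.lower_is_better = lower_is_better
        self.best: Optional[float] = None

    def update(self, value: float) -> bool:
        """Record `value`; return True if it improves on the best seen so far."""
        improved = (
            self.best is None
            or (value < self.best if self.lower_is_better else value > self.best)
        )
        if improved:
            self.best = value
        return improved


if __name__ == "__main__":
    tracker = BestTracker(lower_is_better=True)  # e.g. a validation loss
    print([tracker.update(v) for v in [0.9, 0.7, 0.8, 0.6]])  # [True, True, False, True]
    print(tracker.best)  # 0.6
```
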

run(self)

Train until the next checkpoint and evaluate.

Returns: Whether the computable is not yet complete.
Return type: bool
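
Because run reports whether the computable is not yet complete, a caller can keep invoking it until it returns False. The minimal sketch below shows that contract. ToyTrainer is a hypothetical stand-in that only counts steps up to max_steps instead of training a model, and run_to_completion plays the role of whatever outer loop drives the component.

```python
from typing import List


class ToyTrainer:
    """Hypothetical stand-in mimicking the run() contract of a Trainer."""

    def __init__(self, max_steps: int = 3) -> None:
        self.max_steps = max_steps
        self.step = 0
        self.history: List[int] = []

    def run(self) -> bool:
        # One call stands for: train until the next checkpoint, then evaluate.
        self.step += 1
        self.history.append(self.step)
        # True means "not yet complete": the caller should call run() again.
        return self.step < self.max_steps


def run_to_completion(trainer: ToyTrainer) -> int:
    """Call run() until it reports completion; return the number of calls."""
    calls = 0
    continue_ = True
    while continue_:
        continue_ = trainer.run()
        calls += 1
    return calls


if __name__ == "__main__":
    print(run_to_completion(ToyTrainer(max_steps=3)))  # 3
```

Checkpointing between calls is what makes the length of each run matter: the shorter each run, the larger the share of time spent outside training.
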
metric(self)

Override this method to enable scheduling.

Returns: The metric to compare computable variants.
Return type: float

_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any])

_load_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any], strict: bool, missing_keys: List[Any], unexpected_keys: List[Any], error_msgs: List[Any])

classmethod precompile(cls, **kwargs)

Override initialization.

Ensure that the model is compiled and pushed to the right device before its parameters are passed to the optimizer.

class flambe.learn.Evaluator(dataset: Dataset, eval_sampler: Sampler, model: Module, metric_fn: Metric, eval_data: str = 'test', device: Optional[str] = None)

Bases: flambe.compile.Component

Implement an Evaluator block.

An Evaluator takes as input a dataset and a model, and executes the evaluation. This is a single-step Component object.

Parameters:
  • dataset (Dataset) – The dataset to run evaluation on.
  • eval_sampler (Sampler) – The sampler to use over the evaluation examples.
  • model (Module) – The model to evaluate.
  • metric_fn (Metric) – The metric to use for evaluation.
  • eval_data (str) – The data split to evaluate on: one of train, val or test.
  • device (str, optional) – The device to use in the computation.
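
The eval_data argument names the split to evaluate. One plausible way to resolve it is sketched below, assuming the dataset exposes train, val and test attributes matching the allowed values listed above; ToyDataset, threshold, accuracy and evaluate are hypothetical names, and the metric is computed once in a single pass, matching the single-step nature of the Evaluator.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Sequence, Tuple

Example = Tuple[float, int]


@dataclass
class ToyDataset:
    """Hypothetical dataset with one list of (input, label) pairs per split."""
    train: List[Example] = field(default_factory=list)
    val: List[Example] = field(default_factory=list)
    test: List[Example] = field(default_factory=list)


def threshold(x: float) -> int:
    """Hypothetical model: predict class 1 when the input exceeds 0.5."""
    return int(x > 0.5)


def accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def evaluate(dataset: ToyDataset,
             model: Callable[[float], int],
             metric_fn: Callable[[Sequence[int], Sequence[int]], float],
             eval_data: str = "test") -> float:
    if eval_data not in ("train", "val", "test"):
        raise ValueError(f"eval_data must be train, val or test, got {eval_data!r}")
    split = getattr(dataset, eval_data)  # select the split by name
    preds = [model(x) for x, _ in split]
    targets = [y for _, y in split]
    # A single pass over the split: the metric is computed exactly once.
    return metric_fn(preds, targets)


if __name__ == "__main__":
    data = ToyDataset(val=[(0.2, 0), (0.8, 1)], test=[(0.1, 0), (0.6, 0), (0.9, 1)])
    print(evaluate(data, threshold, accuracy, eval_data="val"))  # 1.0
```
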
run(self, block_name: str = None)

Run the evaluation.

Returns: Whether the computable has completed.
Return type: bool

metric(self)

Override this method to enable scheduling.

Returns: The metric to compare computable variants.
Return type: float

class flambe.learn.Script(script: str, args: Dict[str, Any])

Bases: flambe.compile.Component

Implement a Script computable.

This object can be used to turn any script into a Flambé computable, which is useful when you want to integrate existing code rapidly. Note, however, that this computable does not enable checkpointing or linking to internal components, as it does not have any attributes.

To use this object, your script needs to be in a pip-installable package that declares all of its dependencies. The script is run as a module with the following command:

python -m script --arg1 value1 --arg2 value2
Parameters:
  • script (str) – The script module to run.
  • args (Dict[str, Any]) – Argument dictionary.
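
The python -m command corresponds to running a module by name. A minimal sketch of doing the same from inside a Python process with the standard library's runpy.run_module, temporarily replacing sys.argv so the script's argument parser sees the flags. The --key value flattening in build_argv is an assumption about how the args dictionary maps to flags, and both helper names (and the demo module) are hypothetical.

```python
import importlib
import os
import runpy
import sys
import tempfile
from typing import Any, Dict, List


def build_argv(args: Dict[str, Any]) -> List[str]:
    """Flatten {"arg1": "value1"} into ["--arg1", "value1"]."""
    argv: List[str] = []
    for key, value in args.items():
        argv.extend([f"--{key}", str(value)])
    return argv


def run_script_module(module: str, args: Dict[str, Any]) -> Dict[str, Any]:
    """Run `module` like `python -m module --arg1 value1 ...`.

    Returns the module's resulting globals; sys.argv is restored afterwards.
    """
    saved_argv = sys.argv
    sys.argv = [module] + build_argv(args)
    try:
        # run_name="__main__" makes `if __name__ == "__main__":` blocks execute.
        return runpy.run_module(module, run_name="__main__")
    finally:
        sys.argv = saved_argv


# Source of a throwaway script that parses the two flags from the command above.
SCRIPT = (
    "import argparse\n"
    "parser = argparse.ArgumentParser()\n"
    "parser.add_argument('--arg1')\n"
    "parser.add_argument('--arg2', type=int)\n"
    "parsed = parser.parse_args()\n"
)

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Write the script as a module and make its directory importable.
        with open(os.path.join(tmp, "demo_script.py"), "w") as f:
            f.write(SCRIPT)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            result = run_script_module("demo_script", {"arg1": "value1", "arg2": 2})
        finally:
            sys.path.remove(tmp)
    print(result["parsed"].arg1, result["parsed"].arg2)  # value1 2
```

Restoring sys.argv in a finally block keeps the host process unaffected even if the script raises.
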
run(self)

Run the script.

Returns: Report dictionary to use for logging.
Return type: Dict[str, float]