Source code for flambe.learn.eval

from typing import Optional

import torch

from flambe.compile import Component
from flambe.dataset import Dataset
from flambe.nn import Module
from flambe.metric import Metric
from flambe.sampler import Sampler
from flambe.logging import log


class Evaluator(Component):
    """Implement an Evaluator block.

    An `Evaluator` takes a dataset and a model as input, and executes
    the evaluation. This is a single step `Component` object.

    Parameters
    ----------
    dataset : Dataset
        The dataset to run evaluation on
    eval_sampler : Sampler
        The sampler to use over validation examples
    model : Module
        The model to evaluate
    metric_fn : Metric
        The metric to use for evaluation
    eval_data : str
        The data split to evaluate on: one of train, val or test
    device : str, optional
        The device to use in the computation.

    """

    def __init__(self,
                 dataset: Dataset,
                 eval_sampler: Sampler,
                 model: Module,
                 metric_fn: Metric,
                 eval_data: str = 'test',
                 device: Optional[str] = None) -> None:
        self.eval_sampler = eval_sampler
        self.model = model
        self.metric_fn = metric_fn
        self.eval_metric = None
        self.dataset = dataset

        # Select the right device
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        data = getattr(dataset, eval_data)
        self._eval_iterator = self.eval_sampler.sample(data)

        # By default, no prefix applied to tb logs
        self.tb_log_prefix = None
    def run(self, block_name: Optional[str] = None) -> bool:
        """Run the evaluation.

        Returns
        -------
        bool
            Whether the computable has completed.

        """
        self.model.to(self.device)
        self.model.eval()

        with torch.no_grad():
            preds, targets = [], []
            for batch in self._eval_iterator:
                pred, target = self.model(*[t.to(self.device) for t in batch])
                preds.append(pred.cpu())
                targets.append(target.cpu())

            preds = torch.cat(preds, dim=0)  # type: ignore
            targets = torch.cat(targets, dim=0)  # type: ignore
            self.eval_metric = self.metric_fn(preds, targets).item()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        log(f'{tb_prefix}Eval {self.metric_fn}',  # type: ignore
            self.eval_metric, global_step=0)

        continue_ = False  # Single step, so don't continue
        return continue_
    def metric(self) -> Optional[float]:
        """Override this method to enable scheduling.

        Returns
        -------
        float
            The metric to compare computable variants

        """
        return self.eval_metric
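
A minimal usage sketch, assuming concrete implementations of the interfaces above; the names MyDataset, MySampler, MyModel and MyMetric below are hypothetical placeholders, not part of this module or the flambe API.

# Hypothetical stand-ins: substitute real Dataset, Sampler, Module and
# Metric implementations from your own pipeline.
evaluator = Evaluator(
    dataset=MyDataset(),        # a flambe.dataset.Dataset
    eval_sampler=MySampler(),   # a flambe.sampler.Sampler
    model=MyModel(),            # a flambe.nn.Module
    metric_fn=MyMetric(),       # a flambe.metric.Metric
    eval_data='test',           # evaluate on the test split
)

evaluator.run()                 # returns False: single-step component
print(evaluator.metric())       # the computed evaluation metric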