from typing import Optional
import torch
from flambe.compile import Component
from flambe.dataset import Dataset
from flambe.nn import Module
from flambe.metric import Metric
from flambe.sampler import Sampler, BaseSampler
from flambe.logging import log


class Evaluator(Component):
    """Implement an Evaluator block.

    An `Evaluator` takes as input data and a model, and executes
    the evaluation. This is a single step `Component` object.
    """
    def __init__(self,
                 dataset: Dataset,
                 model: Module,
                 metric_fn: Metric,
                 eval_sampler: Optional[Sampler] = None,
                 eval_data: str = 'test',
                 device: Optional[str] = None) -> None:
"""Initialize the evaluator.
Parameters
----------
dataset : Dataset
The dataset to run evaluation on
model : Module
The model to train
metric_fn: Metric
The metric to use for evaluation
eval_sampler : Optional[Sampler]
The sampler to use over validation examples. By default
it will use `BaseSampler` with batch size 16 and without
shuffling.
eval_data: str
The data split to evaluate on: one of train, val or test
device: str, optional
The device to use in the computation.
"""
        self.eval_sampler = eval_sampler or BaseSampler(batch_size=16, shuffle=False)
        self.model = model
        self.metric_fn = metric_fn
        self.eval_metric: Optional[float] = None
        self.dataset = dataset

        # Select the right device: honor an explicit choice, otherwise
        # fall back to GPU when available
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        data = getattr(dataset, eval_data)
        self._eval_iterator = self.eval_sampler.sample(data)

        # By default, no prefix applied to tb logs
        self.tb_log_prefix = None
    def run(self, block_name: Optional[str] = None) -> bool:
        """Run the evaluation.

        Returns
        -------
        bool
            Whether the component should continue running.
        """
        self.model.to(self.device)
        self.model.eval()

        with torch.no_grad():
            preds, targets = [], []
            for batch in self._eval_iterator:
                pred, target = self.model(*[t.to(self.device) for t in batch])
                preds.append(pred.cpu())
                targets.append(target.cpu())

            preds = torch.cat(preds, dim=0)  # type: ignore
            targets = torch.cat(targets, dim=0)  # type: ignore
            self.eval_metric = self.metric_fn(preds, targets).item()

            tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
            log(f'{tb_prefix}Eval {self.metric_fn}',  # type: ignore
                self.eval_metric, global_step=0)

        continue_ = False  # Single step so don't continue
        return continue_
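
    # Note: `run` returns False because evaluation is a single step. A
    # generic driver (a sketch of an assumption, not flambe's documented
    # runner) could simply step any block until it reports completion:
    #
    #     while block.run():
    #         pass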
    def metric(self) -> Optional[float]:
        """Override this method to enable scheduling.

        Returns
        -------
        float
            The metric to compare computable variants.
        """
        return self.eval_metric
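

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). `my_dataset` and `my_model`
# are assumptions: any flambe `Dataset` exposing train/val/test splits and
# any flambe `Module` whose forward returns a (predictions, targets) pair per
# batch would do; `Accuracy` is assumed here as a representative flambe
# Metric.
#
#     from flambe.metric import Accuracy
#
#     evaluator = Evaluator(dataset=my_dataset,
#                           model=my_model,
#                           metric_fn=Accuracy(),
#                           eval_sampler=BaseSampler(batch_size=32,
#                                                    shuffle=False),
#                           eval_data='test')
#     evaluator.run()              # single evaluation pass
#     score = evaluator.metric()   # the computed metric, as a float
# ---------------------------------------------------------------------------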