# from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence, Any, Union, Dict
import numpy as np
from ray.tune import grid_search
from flambe.compile import Registrable, alias
# Type alias for a numeric value: either a float or an int.
Number = Union[float, int]
class Options(Registrable, ABC):
    """Abstract base class for a set of search options.

    Concrete subclasses hold the candidate values for a config entry,
    can be built from a plain sequence (``from_sequence``), and know how
    to convert themselves to their Ray Tune representation (``convert``).
    The ``to_yaml`` / ``from_yaml`` classmethods serialize them as YAML
    sequences.
    """

    @classmethod
    @abstractmethod
    def from_sequence(cls, options: Sequence[Any]) -> 'Options':
        """Construct an options object from a sequence of values.

        Parameters
        ----------
        options : Sequence[Any]
            Discrete sequence that defines what values to search over.

        Returns
        -------
        Options
            An instance of the concrete ``Options`` subclass on which this
            classmethod is invoked.

        """

    @abstractmethod
    def convert(self) -> Dict:
        """Convert the options to their Ray Tune representation.

        Returns
        -------
        Dict
            The Ray Tune conversion.

        """

    @classmethod
    def to_yaml(cls, representer: Any, node: Any, tag: str) -> Any:
        """Represent ``node`` as a YAML sequence tagged with ``tag``.

        Parameters
        ----------
        representer : Any
            YAML representer; its ``represent_sequence`` method builds the
            output node.
        node : Any
            The options object being serialized. It must expose an
            ``elements`` attribute holding the values to emit. ``Options``
            itself does not define ``elements``, so subclasses must.
        tag : str
            YAML tag attached to the emitted sequence.

        Returns
        -------
        Any
            Whatever ``representer.represent_sequence`` returns.

        """
        return representer.represent_sequence(tag, node.elements)

    @classmethod
    def from_yaml(cls, constructor: Any, node: Any, factory_name: str) -> 'Options':
        """Build an options object from a YAML sequence node.

        Parameters
        ----------
        constructor : Any
            YAML constructor used to build the sequence of values.
        node : Any
            YAML node holding the sequence.
        factory_name : str
            Name of the classmethod used to build the object (may be
            ``None``). ``None`` or ``'from_sequence'`` selects
            ``from_sequence``. Any other name is looked up on the class
            with ``getattr`` and called with the constructed values.

        Returns
        -------
        Options
            The object returned by the selected factory.

        Raises
        ------
        ValueError
            If ``construct_yaml_seq`` does not produce exactly one item.
        AttributeError
            If ``factory_name`` does not name an attribute of the class.

        """
        # NOTE(review): construct_yaml_seq appears to be a generator that
        # yields a single (lazily filled) list. Draining it with list()
        # completes that list before it is unpacked. Confirm against the
        # YAML library in use.
        args, = list(constructor.construct_yaml_seq(node))
        if factory_name in (None, 'from_sequence'):
            return cls.from_sequence(args)  # type: ignore
        factory = getattr(cls, factory_name)
        return factory(args)
# NOTE(review): presumably registers the short alias 'g' for this class in
# configs. Confirm against flambe.compile.alias.
@alias('g')
class GridSearchOptions(Sequence[Any], Options):
    """Discrete set of values used for grid search.

    Defines a finite, discrete set of values to be substituted at the
    location where the set currently resides in the config. Instances
    behave as a read-only sequence over those values.

    Parameters
    ----------
    elements : Sequence[Any]
        The candidate values. The object is stored as-is (not copied), so
        later mutation of the passed sequence is visible through this
        instance.

    """

    def __init__(self, elements: Sequence[Any]) -> None:
        self.elements = elements

    @classmethod
    def from_sequence(cls, options: Sequence[Any]) -> 'GridSearchOptions':
        """Wrap ``options`` in a new ``GridSearchOptions``."""
        return cls(options)

    def convert(self) -> Dict:
        """Return the Ray Tune grid-search spec for these values.

        The elements are copied into a list before being handed to
        ``ray.tune.grid_search``.
        """
        return grid_search(list(self.elements))

    def __getitem__(self, key: Any) -> Any:
        """Index (or slice) the underlying elements directly.

        A slice returns whatever the underlying sequence returns, not a
        ``GridSearchOptions``.
        """
        return self.elements[key]

    def __len__(self) -> int:
        """Return the number of candidate values."""
        return len(self.elements)

    def __repr__(self) -> str:
        return f'gridoptions({self.elements!r})'