"""Local run script for flambe.
"""
import argparse
import logging
import torch
import ray
import os
from flambe.const import FLAMBE_GLOBAL_FOLDER
from flambe.compile import make_component
from flambe.compile.extensions import download_extensions
from flambe.runnable.utils import is_dev_mode, get_flambe_repo_location
from flambe.logging import setup_global_logging
from flambe.logging import coloredlogs as cl
from flambe.runnable.context import SafeExecutionContext
from flambe.runnable import ClusterRunnable
from flambe.cluster import Cluster
from flambe.runner.utils import check_system_reqs
from flambe.logo import ASCII_LOGO, ASCII_LOGO_DEV
import flambe
from typing import cast
[docs]TORCH_TAG_PREFIX = "torch"
[docs]TUNE_TAG_PREFIX = "tune"
[docs]def main(args: argparse.Namespace) -> None:
"""Execute command based on given config"""
if is_dev_mode():
print(cl.RA(ASCII_LOGO_DEV))
print(cl.BL(f"Location: {get_flambe_repo_location()}\n"))
else:
print(cl.RA(ASCII_LOGO))
print(cl.BL(f"VERSION: {flambe.__version__}\n"))
# Pass original module for ray / pickle
make_component(torch.nn.Module, TORCH_TAG_PREFIX, only_module='torch.nn')
# torch.optim.Optimizer exists, ignore mypy
make_component(torch.optim.Optimizer, TORCH_TAG_PREFIX, # type: ignore
only_module='torch.optim')
make_component(torch.optim.lr_scheduler._LRScheduler,
TORCH_TAG_PREFIX, only_module='torch.optim.lr_scheduler')
make_component(ray.tune.schedulers.TrialScheduler, TUNE_TAG_PREFIX)
make_component(ray.tune.suggest.SearchAlgorithm, TUNE_TAG_PREFIX)
# TODO check first if there is cluster as if there is there
# is no need to install extensions
check_system_reqs()
with SafeExecutionContext(args.config) as ex:
if args.cluster is not None:
with SafeExecutionContext(args.cluster) as ex_cluster:
cluster, _ = ex_cluster.preprocess(secrets=args.secrets,
install_ext=args.install_extensions)
runnable, extensions = ex.preprocess(import_ext=False,
secrets=args.secrets)
cluster.run(force=args.force)
if isinstance(runnable, ClusterRunnable):
cluster = cast(Cluster, cluster)
# This is independant to the type of ClusterRunnable
destiny = os.path.join(cluster.get_orch_home_path(), "extensions")
# Before sending the extensions, they need to be
# downloaded (locally).
t = os.path.join(FLAMBE_GLOBAL_FOLDER, "extensions")
extensions = download_extensions(extensions, t)
# At this point, all remote extensions
# (except pypi extensions)
# have local paths.
new_extensions = cluster.send_local_content(extensions,
destiny, all_hosts=True)
new_secrets = cluster.send_secrets(whitelist=["GITHUB", "PIP"])
runnable.setup_inject_env(cluster=cluster,
extensions=new_extensions,
force=args.force)
cluster.execute(runnable, new_extensions, new_secrets, args.force)
else:
raise ValueError("Only ClusterRunnables can be executed in a cluster.")
else:
runnable, _ = ex.preprocess(secrets=args.secrets, install_ext=args.install_extensions)
runnable.run(force=args.force, verbose=args.verbose)
if __name__ == '__main__':
# torch.multiprocessing exists, ignore mypy
torch.multiprocessing.set_start_method('fork', force=True) # type: ignore
[docs] parser = argparse.ArgumentParser(description='Run a flambé!')
parser.add_argument('config', type=str, help='Path to config file')
parser.add_argument('-i', '--install-extensions', action='store_true', default=False,
help='Install extensions automatically using pip. WARNING: ' +
'This could potentially override already installed packages.')
parser.add_argument('-c', '--cluster', type=str, default=None,
help='Specify the cluster that will run the experiment. This option ' +
'works if the main config is an Experiment')
parser.add_argument('-f', '--force', action='store_true', default=False,
help='Override existing runnables. Be careful ' +
'when using this flag as it could have undesired effects.')
parser.add_argument('--secrets',
type=str, default=os.path.join(FLAMBE_GLOBAL_FOLDER, "secrets.ini"))
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose console output')
args = parser.parse_args()
setup_global_logging(logging.INFO if not args.verbose else logging.DEBUG)
logger = logging.getLogger(__name__)
try:
main(args)
logger.info(cl.GR("------------------- Done -------------------"))
except KeyboardInterrupt:
logger.info(cl.RE("---- Exiting early (Keyboard Interrupt) ----"))