"""
This module provides methods to orchestrate all extensions
"""
import os
from shutil import which
import re
from urllib.parse import urlparse
from git import Repo, NoSuchPathError
import subprocess
import importlib
import importlib.util
from typing import Dict, Optional, Iterable, Union
from flambe.logging import coloredlogs as cl
from flambe.compile.utils import _is_url
import logging
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def download_extensions(extensions: Dict[str, str],
                        container_folder: str) -> Dict[str, str]:
    """Resolve every extension to a local resource, downloading remote ones.

    Parameters
    ----------
    extensions: Dict[str, str]
        Mapping from extension name to its resource. A resource may be
        a remote url, a local path or any other string (for example a
        pypi requirement).
    container_folder: str
        The auxiliary folder where remote extensions are downloaded.
        Each one is placed under a sub-folder named after its key.

    Returns
    -------
    Dict[str, str]
        A new dict with the same keys. Remote urls are replaced by the
        local path they were downloaded to, existing local paths by
        their absolute expanded form, and anything else is kept
        unchanged.

    """
    resolved = {}
    for name, resource in extensions.items():
        if _is_url(resource):
            target = os.path.join(container_folder, name)
            resolved[name] = _download_remote_extension(resource, target)
            continue

        # The resource could be a path using '~' or a relative path
        local_path = os.path.abspath(os.path.expanduser(resource))
        resolved[name] = local_path if os.path.exists(local_path) else resource
    return resolved
def _download_remote_extension(extension_url: str,
                               location: str) -> str:
    """Download a remote hosted extension.

    It fully supports github urls only (for now).

    Parameters
    ----------
    extension_url: str
        The github url pointing to an extension. For example:
        https://github.com/user/folder/tree/branch/path/to/ext
    location: str
        The location to download the repo

    Returns
    -------
    str
        The location of the installed package. It may differ from the
        location passed as parameter: when a full repository is cloned
        for a url pointing inside it, the path inside the clone is
        appended.

    Raises
    ------
    ImportError
        If an https url does not point to a Github host.

    """
    # Function-scope import: only needed to wipe a stale clone below.
    import shutil

    url = urlparse(extension_url)
    https_ext = url.scheme == 'https'
    # Non-empty path segments, e.g. [user, repo, 'tree', branch, *path]
    desc = [segment for segment in url.path.split('/') if segment]

    if https_ext and 'github' not in url.netloc:
        raise ImportError("We only support Github hosted extensions for now through https.")

    if https_ext and len(desc) > 4 and _has_svn():
        # Special case: folder inside github repo
        # In this case we download with SVN (if available) as it
        # downloads only the folder instead of full repo
        # Ex: https://github.com/user/some_repo/tree/branch/path/to/ext
        # NOTE(review): GitHub announced the removal of its Subversion
        # bridge (January 2024) -- confirm this path still works.
        user, repository, branch = desc[0], desc[1], desc[3]
        content = desc[4:]
        svn_url = (
            f"{url.scheme}://{url.hostname}/{user}/{repository}/"
            f"branches/{branch}/{'/'.join(content)}"
        )
        _download_svn(svn_url, location)
        logger.debug(f"Downloaded {extension_url} using svn")
    else:
        # Entire git repo (could be github or other)
        original_location = location

        # Add support for branch URLs in github.
        # github URL's path follow this structure:
        # {username}/{repo}/tree/{branch}
        if https_ext and len(desc) >= 4:
            user, repository, branch = desc[0], desc[1], desc[3]
            new_url = f"{url.scheme}://{url.hostname}/{user}/{repository}"
            # The extension lives in a sub-folder of the cloned repo
            location = f"{location}/{'/'.join(desc[4:])}"
            url_path = f"{user}/{repository}"
        else:
            # In case of ssh url, then remove the 'ssh://',
            # if not GitPython fails.
            new_url = extension_url if https_ext else extension_url[6:]
            url_path = url.path
            branch = "master"

        try:
            repo = Repo(original_location)
            logger.debug(f"{extension_url} already exists in {original_location}")
            remote_url = list(repo.remotes[0].urls)[0]  # Pick origin url

            # Previous extensions does not match this one
            if not remote_url.endswith(url_path):
                # Remove the stale clone without going through a shell
                # command, so paths containing whitespace are handled
                # and no external 'rm' binary is required.
                shutil.rmtree(original_location)
                repo = Repo.clone_from(new_url, original_location)
                logger.debug(f"{extension_url} git cloned as it had a different origin")
        except NoSuchPathError:
            # Repo was not downloaded before
            repo = Repo.clone_from(new_url, original_location)
            logger.debug(f"Downloaded {extension_url} using git clone")

        # Make sure the requested branch is checked out and up to date
        repo.remotes.origin.fetch()
        repo.git.checkout(branch)
        repo.remotes.origin.pull()
        logger.debug(f"Pulled latest changes from {extension_url}")

    logger.info(cl.YE(f"Downloaded extension {extension_url}"))
    return location
def _has_svn() -> bool:
    """Tell whether an ``svn`` executable can be located by ``which``."""
    svn_path = which('svn')
    return svn_path is not None
def _download_svn(svn_url: str, location: str,
                  username: Optional[str] = None, password: Optional[str] = None) -> None:
    """Use svn to download a specific folder inside a git repo.

    This works only with remote Github repositories.

    Parameters
    ----------
    svn_url: str
        The github URL adapted to use the SVN protocol
    location: str
        The location to download the folder
    username: str
        The username
    password: str
        The password

    Raises
    ------
    ImportError
        If the ``svn export`` command fails or cannot be launched.

    """
    cmd = ['svn', 'export']
    if username:
        cmd.extend(['--username', username])
    if password:
        # NOTE(review): the password is passed as a command line
        # argument, so it may be visible to other local processes.
        cmd.extend(['--password', password])
    cmd.extend(['--force', svn_url, location])

    try:
        # check_call raises on a non-zero exit status instead of
        # returning it, so the failure has to be handled as an exception.
        subprocess.check_call(cmd, stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError) as e:
        raise ImportError(f"Could not download folder through svn {svn_url}") from e
def install_extensions(extensions: Dict[str, str],
                       user_flag: bool = False) -> None:
    """Install extensions.

    At this point, all extensions must be either local paths or
    valid pypi packages.

    Remote extensions hosted in Github must have been download first.

    Parameters
    ----------
    extensions: Dict[str, str]
        Dictionary of extensions
    user_flag: bool
        Use --user flag when running pip install

    Raises
    ------
    ImportError
        If pip fails to install one of the extensions.

    """
    # Function-scope import: only needed to locate the interpreter.
    import sys

    # Run pip through the interpreter executing this code, so that the
    # packages are installed in the environment that will import them.
    cmd = [sys.executable or 'python3', '-m', 'pip', 'install', '-U']
    if user_flag:
        cmd.append('--user')

    # pip reports replaced packages as "<name>-<version>". The greedy
    # name group splits on the last hyphen, so names that contain
    # digits or hyphens are reported in full.
    uninstalled = re.compile(
        r'Successfully uninstalled (?P<pkg_name>\S+)-(?P<version>\S+)'
    )

    for ext, resource in extensions.items():
        if os.path.exists(resource) and os.sep not in resource:
            # Package is local: make it explicit that it is a path so
            # pip does not treat it as a pypi package name.
            resource = f"./{resource}"
        # Otherwise the resource is either a path containing a separator
        # or it follows pypi notation: "torch>=0.4.1,<1.1"

        try:
            output = subprocess.check_output(
                cmd + [resource],
                stderr=subprocess.DEVNULL
            )
        except subprocess.CalledProcessError as e:
            raise ImportError(f"Could not install package in {resource}") from e

        for line in output.decode("utf-8", errors="replace").splitlines():
            logger.debug(line)
            match = uninstalled.search(line)
            if match:
                logger.info(cl.RE(f"WARNING: While installing {ext}, " +
                                  f"existing {match.group('pkg_name')}-" +
                                  f"{match.group('version')} was uninstalled."))

        logger.info(cl.GR(f"Successfully installed {ext}"))
def is_installed_module(module_name: str) -> bool:
    """Whether the module is installed.

    Parameters
    ----------
    module_name: str
        The name of the module to check for. Dotted names are
        supported; looking them up imports their parent packages.

    Returns
    -------
    bool
        True if the module is installed locally, False otherwise.

    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # For a dotted name, find_spec imports the parent package and
        # raises when that parent cannot be found or is not a package.
        return False
def import_modules(modules: Iterable[str]) -> None:
    """Dynamically import modules.

    Parameters
    ----------
    modules: Iterable[str]
        An iterable of strings containing the modules
        to import

    Raises
    ------
    ImportError
        If one of the modules (or one of its dependencies) can not
        be found.

    """
    for mod_name in modules:
        # Importing modules adds undesired handlers to
        # the root logger.
        # We back up the handlers and restore them
        # after importing
        backup_handlers = logging.root.handlers[:]
        try:
            importlib.import_module(mod_name)
        except ModuleNotFoundError as e:
            raise ImportError(
                f"Error importing {mod_name}: {e}. Please 'pip install' " +
                "the package manually or use '-i' flag (only applies when running " +
                "flambe as cmd line program)"
            ) from e
        finally:
            # Remove all existing root handlers and re-apply the
            # backed up ones, even when the import failed halfway.
            for x in logging.root.handlers[:]:
                logging.root.removeHandler(x)
            for x in backup_handlers:
                logging.root.addHandler(x)

        logger.info(cl.YE(f"Imported extensions {mod_name}"))
def setup_default_modules():
    """Call ``make_component`` on the default torch and ray base classes.

    The calls cover ``torch.nn.Module`` (restricted to ``torch.nn``,
    excluding the quantized and qat sub-packages), the torch
    optimizers, the torch learning rate schedulers (with flambe's
    ``LRScheduler`` as parent component class) and ray tune's trial
    schedulers and search algorithms.

    The dependencies are imported inside the function, so they are
    only loaded when it is called.

    """
    from flambe.compile.utils import make_component
    from flambe.optim import LRScheduler
    import torch
    import ray

    make_component(
        torch.nn.Module,
        only_module='torch.nn',
        exclude=['torch.nn.quantized', 'torch.nn.qat'],
    )
    make_component(
        torch.optim.Optimizer,
        only_module='torch.optim',
    )
    make_component(
        torch.optim.lr_scheduler._LRScheduler,
        only_module='torch.optim.lr_scheduler',
        parent_component_class=LRScheduler,
    )
    make_component(ray.tune.schedulers.TrialScheduler)
    make_component(ray.tune.suggest.SearchAlgorithm)