Source code for flambe.cluster.ssh

"""Implementation of the Manager for SSH hosts"""

import logging

from typing import List, TypeVar, Union, Optional

from flambe.cluster import instance
from flambe.cluster.cluster import Cluster, FactoryInsT

import os


# Module-level logger; SSHCluster.rsync_hosts reports errors through it.
logger: logging.Logger = logging.getLogger(__name__)
# Type variable constrained to the CPU and GPU factory instance classes from
# flambe.cluster.instance.
# NOTE(review): not referenced anywhere in this module's visible code —
# presumably kept for external importers; confirm before removing.
FactoryT = TypeVar("FactoryT", instance.CPUFactoryInstance, instance.GPUFactoryInstance)
class SSHCluster(Cluster):
    """The SSH Manager needs to be used when having running instances.

    For example when having on-prem hardware or just a couple
    of AWS EC2 instances running.

    When using this cluster, the user needs to specify the IPs of
    the machines to use, both the public one and private one.

    """

    def __init__(self,
                 name: str,
                 orchestrator_ip: Union[str, List[str]],
                 factories_ips: Union[List[str], List[List[str]]],
                 key: str,
                 username: str,
                 remote_context=None,
                 use_public: bool = True,
                 setup_cmds: Optional[List[str]] = None) -> None:
        """Initialize the SSHCluster.

        Parameters
        ----------
        name: str
            The name of the cluster (forwarded to ``Cluster``).
        orchestrator_ip: Union[str, List[str]]
            A single IP, or a list whose first two items are forwarded
            to ``get_orchestrator`` (presumably the public and private
            IPs, per the class docstring).
        factories_ips: Union[List[str], List[List[str]]]
            One IP spec per factory, each in the same form as
            ``orchestrator_ip``. Its length is forwarded to ``Cluster``
            (presumably as the number of factories — confirm in base).
        key: str
            Forwarded to ``Cluster``.
        username: str
            Forwarded to ``Cluster``.
        remote_context: optional
            When truthy, its ``cluster_id`` is adopted as this cluster's
            id; ``rsync_hosts`` also reads its ``save_folder``.
        use_public: bool
            Forwarded as ``use_public`` when building every instance.
        setup_cmds: Optional[List[str]]
            Forwarded to ``Cluster``.

        """
        super().__init__(name, len(factories_ips), key, username, setup_cmds)
        self.orchestrator_ip = orchestrator_ip
        self.factories_ips = factories_ips
        self.remote_context = remote_context
        self.use_public = use_public

        if remote_context:
            self.cluster_id = self.remote_context.cluster_id

    @staticmethod
    def _split_ips(ips: Union[str, List[str]]) -> tuple:
        """Turn one IP spec into the positional args for the instance getters.

        A list contributes only its first two items (an IndexError is
        raised if it has fewer); anything else is passed through alone.
        """
        if isinstance(ips, list):
            return ips[0], ips[1]
        return (ips,)

    def load_all_instances(self, exp_name: Optional[str] = None, force: bool = False) -> None:
        """Load the already-running instances into this manager.

        This manager assumes the instances are running; it only builds
        the Python objects for them, storing the orchestrator on
        ``self.orchestrator`` and appending every factory to
        ``self.factories`` (existing entries are not cleared).

        Parameters
        ----------
        exp_name: Optional[str]
            The name of the experiment. Not used by this implementation.
        force: bool
            Whether to override the current experiment of the same name.
            Not used by this implementation.

        """
        self.orchestrator = self.get_orchestrator(
            *self._split_ips(self.orchestrator_ip), use_public=self.use_public)

        for entry in self.factories_ips:
            ips = self._split_ips(entry)
            # A generic factory is built first so its hardware can be
            # probed; GPU hosts are then rebuilt as GPU factories.
            factory = self.get_factory(*ips, use_public=self.use_public)
            if factory.contains_gpu():
                factory = self.get_gpu_factory(*ips, use_public=self.use_public)
            self.factories.append(factory)

    def rollback_env(self) -> None:
        """Intentionally a no-op for the SSH cluster."""
        pass

    def rsync_hosts(self):
        """Rsyncs the host's result folders.

        First, it rsyncs all worker folders to the orchestrator main
        folder. After that, so that every worker gets the last changes,
        the orchestrator rsync with all of them.

        Every transfer is driven from the orchestrator and excludes
        ``state.pkl``. Each host's folder is
        ``<that host's home>/<remote_context.save_folder>``. Logs an
        error and returns without syncing when there is no remote
        context.

        """
        if not self.remote_context:
            logger.error("Can't rsync without a remote context")
            return

        exclude = ["state.pkl"]
        save_folder = self.remote_context.save_folder

        orch = self.orchestrator
        orch_loc = os.path.join(str(orch.get_home_path()), save_folder)

        # Each factory's remote location must be built from *its own* home
        # directory; using the orchestrator's home here would point pulls at
        # the wrong folder whenever the home directories differ.
        factory_locs = [
            f"{f.username}@{f.private_host}:"
            f"{os.path.join(str(f.get_home_path()), save_folder)}"
            for f in self.factories
        ]

        # Pull: gather every factory's results into the orchestrator.
        for f_loc in factory_locs:
            orch.rsync_folder(f_loc, orch_loc, exclude)

        # Push: propagate the merged folder back to every factory.
        for f_loc in factory_locs:
            orch.rsync_folder(orch_loc, f_loc, exclude)