diff --git a/autumn/core/runs/remote.py b/autumn/core/runs/remote.py index 23a8e00e5..54f4bd15d 100644 --- a/autumn/core/runs/remote.py +++ b/autumn/core/runs/remote.py @@ -9,8 +9,9 @@ from autumn import settings from autumn.core.utils import s3 + class RemoteRunData: - def __init__(self, run_id: str, client=None): + def __init__(self, run_id: str, client=None, local_path_base=None): """Remote (S3) wrapper for a given run_id Args: @@ -18,15 +19,18 @@ def __init__(self, run_id: str, client=None): client (optional): S3 client object (will be created if not supplied) """ self.run_id = run_id - + if client is None: client = s3.get_s3_client() self.client = client - - self.local_path_base = Path(settings.DATA_PATH) / 'outputs/runs' + + if local_path_base is None: + local_path_base = Path(settings.DATA_PATH) / "outputs/runs" + + self.local_path_base = local_path_base self.local_path_run = self.local_path_base / run_id - def list_contents(self, suffix:str =None) -> List[str]: + def list_contents(self, suffix: str = None) -> List[str]: """Return a list of all files for this run These can be passed directly into the download method @@ -37,7 +41,7 @@ def list_contents(self, suffix:str =None) -> List[str]: [List[str]]: List of files """ return s3.list_s3(self.client, self.run_id, suffix) - + def _get_full_metadata(self): """Complete S3 metadata for all objects in this run @@ -45,7 +49,7 @@ def _get_full_metadata(self): [dict]: Metadata """ return self.client.list_objects_v2(Bucket=settings.S3_BUCKET, Prefix=self.run_id) - + def download(self, remote_path: str): """Download a remote file and place it in the corresponding local path @@ -53,15 +57,15 @@ def download(self, remote_path: str): remote_path (str): Full string of remote file path """ # Strip the filename from the end of the path - split_path = remote_path.split('/') + split_path = remote_path.split("/") filename = split_path[-1] - dir_only = '/'.join(split_path[:-1]) - + dir_only = "/".join(split_path[:-1]) + local_path = self.local_path_base / dir_only local_path.mkdir(parents=True, exist_ok=True) - + full_local = local_path.joinpath(filename) s3.download_s3(self.client, remote_path, str(full_local)) - + def __repr__(self): return f"RemoteRunData: {self.run_id}" diff --git a/autumn/infrastructure/remote/aws/aws.py b/autumn/infrastructure/remote/aws/aws.py index 6d2237772..4ae898ab3 100644 --- a/autumn/infrastructure/remote/aws/aws.py +++ b/autumn/infrastructure/remote/aws/aws.py @@ -32,7 +32,7 @@ def get_instance_type( min_cores: int, min_ram: int, category: str = settings.EC2InstanceCategory.GENERAL -) -> dict: +) -> str: specs = settings.EC2_INSTANCE_SPECS[category] matching_specs = {k: v for k, v in specs.items() if v.cores >= min_cores and v.ram >= min_ram} @@ -135,7 +135,9 @@ def run_instance(job_id: str, instance_type: str, is_spot: bool, ami_name=None): return client.run_instances(**kwargs) -def run_multiple_instances(job_id: str, instance_type: str, n_instances: int, ami_name=None): +def run_multiple_instances( + job_id: str, instance_type: str, n_instances: int, run_group: str, ami_name=None +): logger.info(f"Creating EC2 instance {instance_type} for job {job_id}... ") ami_name = ami_name or settings.EC2_AMI["310conda"] kwargs = { @@ -148,7 +150,13 @@ def run_multiple_instances(job_id: str, instance_type: str, n_instances: int, am "KeyName": "autumn", "InstanceInitiatedShutdownBehavior": "terminate", "TagSpecifications": [ - {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": job_id}]} + { + "ResourceType": "instance", + "Tags": [ + {"Key": "Name", "Value": job_id}, + {"Key": "run_group", "Value": run_group}, + ], + } ], } diff --git a/autumn/infrastructure/remote/springboard/aws.py b/autumn/infrastructure/remote/springboard/aws.py index 132cd0e17..97a6da102 100644 --- a/autumn/infrastructure/remote/springboard/aws.py +++ b/autumn/infrastructure/remote/springboard/aws.py @@ -1,10 +1,11 @@ -from typing import Optional +from typing import Optional, List, Union from dataclasses import dataclass, asdict from enum import Enum from time import time import boto3 +from cloudpickle import instance import numpy as np # Keep everything we 'borrow' from autumn in this block @@ -74,7 +75,7 @@ def wait_instance(instance_id: str, timeout: Optional[float] = None) -> dict: raise InstanceStateError("Instance failed to launch with state", state) -def start_ec2_instance(mspec: EC2MachineSpec, name: str, ami: str = None) -> dict: +def start_ec2_instance(mspec: EC2MachineSpec, name: str, ami: Optional[str] = None) -> dict: """Request and launch an EC2 instance for the given machine specifications Args: @@ -99,8 +100,8 @@ def start_ec2_instance(mspec: EC2MachineSpec, name: str, ami: str = None) -> dic def start_ec2_multi_instance( - mspec: EC2MachineSpec, name: str, n_instances: int, ami: str = None -) -> dict: + mspec: EC2MachineSpec, name: str, n_instances: int, run_group: str, ami: Optional[str] = None +) -> list: """Request and launch an EC2 instance for the given machine specifications Args: @@ -116,7 +117,9 @@ def start_ec2_multi_instance( # +++: Borrow default from AuTuMN; move to springboard rcparams? ami = ami or aws_settings.EC2_AMI["springboardtest"] - inst_req = autumn_aws.run_multiple_instances(name, instance_type, n_instances, ami_name=ami) + inst_req = autumn_aws.run_multiple_instances( + name, instance_type, n_instances, run_group=run_group, ami_name=ami + ) iid = inst_req["Instances"][0]["InstanceId"] req_instances = inst_req["Instances"] @@ -130,7 +133,7 @@ def start_ec2_multi_instance( def set_cpu_termination_alarm( - instance_id: str, time_minutes: int = 30, min_cpu=1.0, region=aws_settings.AWS_REGION + instance_id: str, time_minutes: int = 15, min_cpu=5.0, region=aws_settings.AWS_REGION ): # Create alarm @@ -153,3 +156,43 @@ def set_cpu_termination_alarm( ], ) return alarm + + +def get_instances_by_tag(name, value, verbose=False) -> List[dict]: + import boto3 + + ec2client = boto3.client("ec2") + custom_filter = [{"Name": f"tag:{name}", "Values": [value]}] + + response = ec2client.describe_instances(Filters=custom_filter) + instances = [[i for i in r["Instances"]] for r in response["Reservations"]] + instances = [i for r in instances for i in r] + if verbose: + return instances + else: + concise_instances = [ + { + "InstanceId": inst["InstanceId"], + "InstanceType": inst["InstanceType"], + "LaunchTime": inst["LaunchTime"], + "State": inst["State"], + "ip": inst["PublicIpAddress"], + "tags": {tag["Key"]: tag["Value"] for tag in inst["Tags"]}, + } + for inst in instances + ] + for inst in concise_instances: + inst["name"] = inst["tags"]["Name"] + return concise_instances + + +def set_tag(instance_ids: Union[str, List[str]], key: str, value): + import boto3 + + ec2client = boto3.client("ec2") + + if isinstance(instance_ids, str): + instance_ids = [instance_ids] + + resp = ec2client.create_tags(Resources=instance_ids, Tags=[{"Key": key, "Value": value}]) + return resp diff --git a/autumn/infrastructure/remote/springboard/launch.py b/autumn/infrastructure/remote/springboard/launch.py index 73df101be..c6c5478e4 100644 --- a/autumn/infrastructure/remote/springboard/launch.py +++ b/autumn/infrastructure/remote/springboard/launch.py @@ -55,6 +55,9 @@ def launch_synced_autumn_task( rinst = aws.start_ec2_instance(mspec, job_id) + aws.set_tag(rinst["InstanceId"], "run_path", run_path) + rinst["tags"] = {"Name": run_path, "run_path": run_path} + try: s3t.set_instance(rinst) @@ -90,7 +93,11 @@ def launch_synced_autumn_task( def launch_synced_multiple_autumn_task( - task_dict, mspec, branch="master", job_id=None + task_dict: Dict[str, TaskSpec], + mspec: EC2MachineSpec, + run_group: str, + branch="master", + auto_shutdown_time: int = 4 * 60, # Time in minutes ) -> Dict[str, SpringboardTaskRunner]: for run_path in task_dict.keys(): s3t = task.S3TaskManager(run_path) @@ -98,19 +105,25 @@ def launch_synced_multiple_autumn_task( raise Exception("Task already exists", run_path) s3t.set_status(TaskStatus.LAUNCHING) - if job_id is None: - job_id = gen_run_name("autumntask") + job_id = gen_run_name("autumntask") - instances = aws.start_ec2_multi_instance(mspec, job_id, len(task_dict)) + instances = aws.start_ec2_multi_instance(mspec, job_id, len(task_dict), run_group) runners = {} for rinst, (run_path, task_spec) in zip(instances, task_dict.items()): + aws.set_tag(rinst["InstanceId"], "Name", run_path) + aws.set_tag(rinst["InstanceId"], "run_path", run_path) + rinst["name"] = run_path + rinst["tags"] = {"Name": run_path, "run_path": run_path, "run_group": run_group} + s3t = task.S3TaskManager(run_path) s3t.set_instance(rinst) aws.set_cpu_termination_alarm(rinst["InstanceId"]) srunner = task.SpringboardTaskRunner(rinst, run_path) + srunner.sshr.run(f"sudo shutdown -P +{auto_shutdown_time}") + script = scripting.gen_autumn_run_bash(run_path, branch) s3t._write_taskdata("taskscript.sh", script) diff --git a/autumn/infrastructure/remote/springboard/task.py b/autumn/infrastructure/remote/springboard/task.py index 3df54ebf8..f5a39568a 100644 --- a/autumn/infrastructure/remote/springboard/task.py +++ b/autumn/infrastructure/remote/springboard/task.py @@ -1,8 +1,11 @@ +from typing import Optional + from pathlib import PurePosixPath, Path from enum import Enum import json import sys from time import time, sleep +from functools import wraps import logging import logging.config @@ -17,6 +20,7 @@ from autumn.infrastructure.tasks.storage import S3Storage from autumn.core.utils.s3 import get_s3_client from autumn.core.utils.parallel import gather_exc_plus +from autumn.core.runs.remote import RemoteRunData # Multi-library SSH wrapper from .clients import CommandResult, SSHRunner @@ -36,7 +40,7 @@ class TaskStatus(str, Enum): class TaskSpec: - def __init__(self, run_func, func_kwargs: dict = None): + def __init__(self, run_func, func_kwargs: Optional[dict] = None): """Used to specify wrapped tasks for springboard runners Args: run_func: Any function taking a TaskBridge as its first argument @@ -191,10 +195,12 @@ def autumn_task_entry(run_path: str) -> int: bridge._storage.store(local_base / "log") logging.shutdown() + print("Exiting autumn task runner") + if success: - return 0 + sys.exit(0) else: - return 255 + sys.exit(255) class SpringboardTaskRunner: @@ -331,7 +337,7 @@ class S3TaskManager: def __init__( self, project_path: str, - fs: s3fs.S3FileSystem = None, + fs: Optional[s3fs.S3FileSystem] = None, bucket: PurePosixPath = PurePosixPath(aws_settings.S3_BUCKET), ): # s3fs seems to have intermittent trouble accessing files created remotely @@ -392,7 +398,7 @@ class ManagedTask(S3TaskManager): def __init__( self, project_path: str, - fs: s3fs.S3FileSystem = None, + fs: Optional[s3fs.S3FileSystem] = None, bucket: PurePosixPath = PurePosixPath(aws_settings.S3_BUCKET), ): """ @@ -404,6 +410,167 @@ def __init__( """ super().__init__(project_path, fs, bucket) + self.remote = RemoteTaskStore(project_path, fs, bucket) + from autumn.settings import DATA_PATH + + local_path_base = Path(DATA_PATH) / "managed" / str(bucket) / project_path + + self.local = LocalStore(local_path_base) + self._remotedata = RemoteRunData(project_path, local_path_base=local_path_base) + + def download(self, remote_path, recursive=False): + full_remote = self.remote._ensure_full_path(remote_path) + rel_path = full_remote.relative_to(self.remote_path) + return self.fs.get(str(full_remote), str(self.local.path / rel_path), recursive=recursive) + + def download_all(self): + return self.download(None, recursive=True) def get_runner(self): return SpringboardTaskRunner(self.get_instance(), self.project_path) + + +class LocalStore: + def __init__(self, base_path): + self.path = base_path + + def open(self, file, mode="r"): + file = self._ensure_full_path(file) + return open(file, mode) + + def _using_root(self, path=None): + if path is None: + return False + if isinstance(path, str): + path = Path(path) + if isinstance(path, Path): + if path.parts[0] == self.path.parts[0]: + return True + return False + + def _ensure_full_path(self, path=None): + if path is None: + return self.path + if isinstance(path, str): + path = Path(path) + if isinstance(path, Path): + if path.parts[0] == self.path.parts[0]: + return path + else: + return self.path / path + else: + raise TypeError("Path must be str or Path", path) + + def ls(self, path=None, full=False, recursive=False, **kwargs): + using_root = self._using_root(path) + + path = self._ensure_full_path(path) + if recursive: + results = path.rglob("*", **kwargs) + else: + results = path.glob("*", **kwargs) + + if using_root: + ref_path = path + else: + ref_path = self.path + + if not full: + results = [str(Path(res).relative_to(ref_path)) for res in results] + else: + results = [res for res in results] + + return results + + def __truediv__(self, divisor): + return self.path / divisor + + +class RemoteTaskStore: + def __init__( + self, base_path: Optional[str] = None, fs=None, bucket=PurePosixPath("autumn-data") + ): + self.fs = fs or s3fs.S3FileSystem(use_listings_cache=False) + self.bucket = bucket + self.cwd = bucket + if base_path is not None: + self._set_cwd(self.bucket / base_path) + + self.glob = self._wrap_ensure_path(self.fs.glob, True) + self.read_text = self._wrap_ensure_path(self.fs.read_text) + + def _set_cwd(self, path): + if self.fs.exists(path): + self.cwd = path + else: + raise FileNotFoundError(path) + + def _validate_path(self, path=None, as_str=False): + path = self._ensure_full_path(path) + if as_str: + return str(path) + else: + return path + + def _ensure_full_path(self, path=None): + if path is None: + return self.cwd + if isinstance(path, str): + path = PurePosixPath(path) + if isinstance(path, PurePosixPath): + if path.parts[0] == str(self.bucket): + return path + else: + return self.cwd / path + else: + raise TypeError("Path must be str or PurePosixPath", path) + + def _using_root(self, path=None): + if path is None: + return False + if isinstance(path, str): + path = PurePosixPath(path) + if isinstance(path, PurePosixPath): + if path.parts[0] == str(self.bucket): + return True + return False + + def _wrap_ensure_path(self, func, as_str=False): + @wraps(func) + def wrapper(path=None, *args, **kwargs): + path = self._validate_path(path, as_str) + return func(path, *args, **kwargs) + + return wrapper + + def cd(self, path): + path = self._ensure_full_path(path) + self._set_cwd(path) + + def ls(self, path=None, full=False, recursive=False, **kwargs): + using_root = self._using_root(path) + + path = self._ensure_full_path(path) + if recursive: + results = self.fs.glob(str(path / "**"), **kwargs) + else: + results = self.fs.ls(path, **kwargs) + + if using_root: + ref_path = path + else: + ref_path = self.cwd + + if not full: + results = [str(PurePosixPath(res).relative_to(ref_path)) for res in results] + + return results + + def get_managed_task(self, path=None): + path = self._ensure_full_path(path) + run_path = "/".join(path.parts[1:]) + mt = ManagedTask(run_path, self.fs, self.bucket) + if mt.exists(): + return mt + else: + raise FileNotFoundError("No task exists", run_path) diff --git a/requirements/requirements310.txt b/requirements/requirements310.txt index 98bb4a9d0..548452cfd 100644 --- a/requirements/requirements310.txt +++ b/requirements/requirements310.txt @@ -44,9 +44,11 @@ scipy==1.8.1 networkx numpyro==0.10.1 -# Database access +# Database and file access SQLAlchemy==1.4.46 pyarrow +tables==3.8.0 +h5py==3.8.0 # Utility pyyaml