Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/monash-emu/AuTuMN
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 8, 2023
2 parents 95072fa + fd5ea8b commit 4c94037
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 31 deletions.
28 changes: 16 additions & 12 deletions autumn/core/runs/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,28 @@
from autumn import settings
from autumn.core.utils import s3


class RemoteRunData:
def __init__(self, run_id: str, client=None):
def __init__(self, run_id: str, client=None, local_path_base=None):
"""Remote (S3) wrapper for a given run_id
Args:
run_id (str): AuTuMN run_id string
client (optional): S3 client object (will be created if not supplied)
"""
self.run_id = run_id

if client is None:
client = s3.get_s3_client()
self.client = client

self.local_path_base = Path(settings.DATA_PATH) / 'outputs/runs'

if local_path_base is None:
local_path_base = Path(settings.DATA_PATH) / "outputs/runs"

self.local_path_base = local_path_base
self.local_path_run = self.local_path_base / run_id

def list_contents(self, suffix:str =None) -> List[str]:
def list_contents(self, suffix: str = None) -> List[str]:
"""Return a list of all files for this run
These can be passed directly into the download method
Expand All @@ -37,31 +41,31 @@ def list_contents(self, suffix:str =None) -> List[str]:
[List[str]]: List of files
"""
return s3.list_s3(self.client, self.run_id, suffix)

def _get_full_metadata(self):
"""Complete S3 metadata for all objects in this run
Returns:
[dict]: Metadata
"""
return self.client.list_objects_v2(Bucket=settings.S3_BUCKET, Prefix=self.run_id)

def download(self, remote_path: str):
"""Download a remote file and place it in the corresponding local path
Args:
remote_path (str): Full string of remote file path
"""
# Strip the filename from the end of the path
split_path = remote_path.split('/')
split_path = remote_path.split("/")
filename = split_path[-1]
dir_only = '/'.join(split_path[:-1])
dir_only = "/".join(split_path[:-1])

local_path = self.local_path_base / dir_only
local_path.mkdir(parents=True, exist_ok=True)

full_local = local_path.joinpath(filename)
s3.download_s3(self.client, remote_path, str(full_local))

def __repr__(self):
return f"RemoteRunData: {self.run_id}"
14 changes: 11 additions & 3 deletions autumn/infrastructure/remote/aws/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

def get_instance_type(
min_cores: int, min_ram: int, category: str = settings.EC2InstanceCategory.GENERAL
) -> dict:
) -> str:
specs = settings.EC2_INSTANCE_SPECS[category]

matching_specs = {k: v for k, v in specs.items() if v.cores >= min_cores and v.ram >= min_ram}
Expand Down Expand Up @@ -135,7 +135,9 @@ def run_instance(job_id: str, instance_type: str, is_spot: bool, ami_name=None):
return client.run_instances(**kwargs)


def run_multiple_instances(job_id: str, instance_type: str, n_instances: int, ami_name=None):
def run_multiple_instances(
job_id: str, instance_type: str, n_instances: int, run_group: str, ami_name=None
):
logger.info(f"Creating EC2 instance {instance_type} for job {job_id}... ")
ami_name = ami_name or settings.EC2_AMI["310conda"]
kwargs = {
Expand All @@ -148,7 +150,13 @@ def run_multiple_instances(job_id: str, instance_type: str, n_instances: int, am
"KeyName": "autumn",
"InstanceInitiatedShutdownBehavior": "terminate",
"TagSpecifications": [
{"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": job_id}]}
{
"ResourceType": "instance",
"Tags": [
{"Key": "Name", "Value": job_id},
{"Key": "run_group", "Value": run_group},
],
}
],
}

Expand Down
55 changes: 49 additions & 6 deletions autumn/infrastructure/remote/springboard/aws.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional
from typing import Optional, List, Union

from dataclasses import dataclass, asdict
from enum import Enum
from time import time

import boto3
from cloudpickle import instance
import numpy as np

# Keep everything we 'borrow' from autumn in this block
Expand Down Expand Up @@ -74,7 +75,7 @@ def wait_instance(instance_id: str, timeout: Optional[float] = None) -> dict:
raise InstanceStateError("Instance failed to launch with state", state)


def start_ec2_instance(mspec: EC2MachineSpec, name: str, ami: str = None) -> dict:
def start_ec2_instance(mspec: EC2MachineSpec, name: str, ami: Optional[str] = None) -> dict:
"""Request and launch an EC2 instance for the given machine specifications
Args:
Expand All @@ -99,8 +100,8 @@ def start_ec2_instance(mspec: EC2MachineSpec, name: str, ami: str = None) -> dic


def start_ec2_multi_instance(
mspec: EC2MachineSpec, name: str, n_instances: int, ami: str = None
) -> dict:
mspec: EC2MachineSpec, name: str, n_instances: int, run_group: str, ami: Optional[str] = None
) -> list:
"""Request and launch an EC2 instance for the given machine specifications
Args:
Expand All @@ -116,7 +117,9 @@ def start_ec2_multi_instance(
# +++: Borrow default from AuTuMN; move to springboard rcparams?
ami = ami or aws_settings.EC2_AMI["springboardtest"]

inst_req = autumn_aws.run_multiple_instances(name, instance_type, n_instances, ami_name=ami)
inst_req = autumn_aws.run_multiple_instances(
name, instance_type, n_instances, run_group=run_group, ami_name=ami
)
iid = inst_req["Instances"][0]["InstanceId"]

req_instances = inst_req["Instances"]
Expand All @@ -130,7 +133,7 @@ def start_ec2_multi_instance(


def set_cpu_termination_alarm(
instance_id: str, time_minutes: int = 30, min_cpu=1.0, region=aws_settings.AWS_REGION
instance_id: str, time_minutes: int = 15, min_cpu=5.0, region=aws_settings.AWS_REGION
):
# Create alarm

Expand All @@ -153,3 +156,43 @@ def set_cpu_termination_alarm(
],
)
return alarm


def get_instances_by_tag(name, value, verbose=False) -> List[dict]:
import boto3

ec2client = boto3.client("ec2")
custom_filter = [{"Name": f"tag:{name}", "Values": [value]}]

response = ec2client.describe_instances(Filters=custom_filter)
instances = [[i for i in r["Instances"]] for r in response["Reservations"]]
instances = [i for r in instances for i in r]
if verbose:
return instances
else:
concise_instances = [
{
"InstanceId": inst["InstanceId"],
"InstanceType": inst["InstanceType"],
"LaunchTime": inst["LaunchTime"],
"State": inst["State"],
"ip": inst["PublicIpAddress"],
"tags": {tag["Key"]: tag["Value"] for tag in inst["Tags"]},
}
for inst in instances
]
for inst in concise_instances:
inst["name"] = inst["tags"]["Name"]
return concise_instances


def set_tag(instance_ids: Union[str, List[str]], key: str, value):
import boto3

ec2client = boto3.client("ec2")

if isinstance(instance_ids, str):
instance_ids = [instance_ids]

resp = ec2client.create_tags(Resources=instance_ids, Tags=[{"Key": key, "Value": value}])
return resp
21 changes: 17 additions & 4 deletions autumn/infrastructure/remote/springboard/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def launch_synced_autumn_task(

rinst = aws.start_ec2_instance(mspec, job_id)

aws.set_tag(rinst["InstanceId"], "run_path", run_path)
rinst["tags"] = {"Name": run_path, "run_path": run_path}

try:
s3t.set_instance(rinst)

Expand Down Expand Up @@ -90,27 +93,37 @@ def launch_synced_autumn_task(


def launch_synced_multiple_autumn_task(
task_dict, mspec, branch="master", job_id=None
task_dict: Dict[str, TaskSpec],
mspec: EC2MachineSpec,
run_group: str,
branch="master",
auto_shutdown_time: int = 4 * 60, # Time in minutes
) -> Dict[str, SpringboardTaskRunner]:
for run_path in task_dict.keys():
s3t = task.S3TaskManager(run_path)
if s3t.exists():
raise Exception("Task already exists", run_path)
s3t.set_status(TaskStatus.LAUNCHING)

if job_id is None:
job_id = gen_run_name("autumntask")
job_id = gen_run_name("autumntask")

instances = aws.start_ec2_multi_instance(mspec, job_id, len(task_dict))
instances = aws.start_ec2_multi_instance(mspec, job_id, len(task_dict), run_group)

runners = {}

for rinst, (run_path, task_spec) in zip(instances, task_dict.items()):
aws.set_tag(rinst["InstanceId"], "Name", run_path)
aws.set_tag(rinst["InstanceId"], "run_path", run_path)
rinst["name"] = run_path
rinst["tags"] = {"Name": run_path, "run_path": run_path, "run_group": run_group}

s3t = task.S3TaskManager(run_path)
s3t.set_instance(rinst)
aws.set_cpu_termination_alarm(rinst["InstanceId"])
srunner = task.SpringboardTaskRunner(rinst, run_path)

srunner.sshr.run(f"sudo shutdown -P +{auto_shutdown_time}")

script = scripting.gen_autumn_run_bash(run_path, branch)

s3t._write_taskdata("taskscript.sh", script)
Expand Down
Loading

0 comments on commit 4c94037

Please sign in to comment.