Skip to content

Commit

Permalink
fix!: delete workers from non-autoscaling fleets (#124)
Browse files Browse the repository at this point in the history
Signed-off-by: Charles Moore <122481442+moorec-aws@users.noreply.github.com>
  • Loading branch information
moorec-aws committed Jul 4, 2024
1 parent 8a046ad commit 1217b4b
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 22 deletions.
2 changes: 1 addition & 1 deletion requirements-testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ black == 24.4.*
moto[cloudformation,s3] == 4.2.*
mypy == 1.10.*
ruff == 0.4.*
twine == 5.0.*
twine == 5.1.*
1 change: 1 addition & 0 deletions src/deadline_test_fixtures/deadline/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> N
class Fleet:
id: str
farm: Farm
autoscaling: bool = True

@staticmethod
def create(
Expand Down
93 changes: 81 additions & 12 deletions src/deadline_test_fixtures/deadline/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
PosixSessionUser,
OperatingSystem,
)
from .resources import Fleet
from ..util import call_api, wait_for

LOG = logging.getLogger(__name__)
Expand All @@ -43,7 +44,7 @@ def linux_worker_command(config: DeadlineWorkerConfiguration) -> str: # pragma:
"install-deadline-worker "
+ "-y "
+ f"--farm-id {config.farm_id} "
+ f"--fleet-id {config.fleet_id} "
+ f"--fleet-id {config.fleet.id} "
+ f"--region {config.region} "
+ f"--user {config.user} "
+ f"--group {config.group} "
Expand Down Expand Up @@ -80,7 +81,7 @@ def windows_worker_command(config: DeadlineWorkerConfiguration) -> str: # pragm
"install-deadline-worker "
+ "-y "
+ f"--farm-id {config.farm_id} "
+ f"--fleet-id {config.fleet_id} "
+ f"--fleet-id {config.fleet.id} "
+ f"--region {config.region} "
+ f"--user {config.user} "
+ f"{'--allow-shutdown ' if config.allow_shutdown else ''}"
Expand Down Expand Up @@ -130,10 +131,6 @@ def stop(self) -> None:
def send_command(self, command: str) -> CommandResult:
pass

@abc.abstractproperty
def worker_id(self) -> str:
pass


@dataclass(frozen=True)
class CommandResult: # pragma: no cover
Expand Down Expand Up @@ -173,7 +170,7 @@ def __str__(self) -> str:
class DeadlineWorkerConfiguration:
operating_system: OperatingSystem
farm_id: str
fleet_id: str
fleet: Fleet
region: str
user: str
group: str
Expand Down Expand Up @@ -203,11 +200,14 @@ class EC2InstanceWorker(DeadlineWorker):
s3_client: botocore.client.BaseClient
ec2_client: botocore.client.BaseClient
ssm_client: botocore.client.BaseClient
deadline_client: botocore.client.BaseClient
configuration: DeadlineWorkerConfiguration

instance_id: Optional[str] = field(init=False, default=None)

override_ami_id: InitVar[Optional[str]] = None
worker_id: Optional[str] = None

"""
Option to override the AMI ID for the EC2 instance. The latest AL2023 is used by default.
Note that the scripting to configure the EC2 instance is only verified to work on AL2023.
Expand All @@ -225,8 +225,66 @@ def start(self) -> None:
def stop(self) -> None:
    """Terminate the EC2 instance and, for non-autoscaling fleets, delete the worker.

    The instance is always terminated and ``instance_id`` is cleared. When the
    configured fleet has ``autoscaling`` set, nothing further is done here.
    Otherwise the worker is given a chance to reach STOPPED on its own, is
    forced to STOPPED if that wait times out, and is then deleted from the fleet.

    Raises:
        botocore.exceptions.ClientError: if forcing the STOPPED status or
            deleting the worker fails.
    """
    LOG.info(f"Terminating EC2 instance {self.instance_id}")
    self.ec2_client.terminate_instances(InstanceIds=[self.instance_id])

    self.instance_id = None

    if self.configuration.fleet.autoscaling:
        return

    if self.worker_id is None:
        # worker_id defaults to None; without an ID there is no worker to
        # stop or delete, and the service calls would only fail on workerId=None.
        LOG.warning("No worker ID is recorded for this instance, skipping worker deletion")
        return

    try:
        self.wait_until_stopped()
    except TimeoutError:
        LOG.warning(
            f"{self.worker_id} did not transition to a STOPPED status, forcibly stopping..."
        )
        self.set_stopped_status()

    # delete() logs and re-raises ClientError itself, so it is not wrapped in
    # another try/except here; doing so would log the same failure twice.
    self.delete()

def delete(self):
    """Delete this worker from its fleet through the Deadline client.

    Raises:
        botocore.exceptions.ClientError: if the ``delete_worker`` call fails.
            The error is logged with its traceback and then re-raised.
    """
    config = self.configuration
    try:
        self.deadline_client.delete_worker(
            farmId=config.farm_id,
            fleetId=config.fleet.id,
            workerId=self.worker_id,
        )
    except botocore.exceptions.ClientError as error:
        LOG.exception(f"Failed to delete worker: {error}")
        raise
    else:
        # Only reached when the service call succeeded.
        LOG.info(f"{self.worker_id} has been deleted from {self.configuration.fleet.id}")

def wait_until_stopped(
    self, *, max_checks: int = 25, seconds_between_checks: float = 5
) -> None:
    """Poll the worker's status until it reports STOPPED.

    Args:
        max_checks: Maximum number of ``get_worker`` calls to make.
        seconds_between_checks: Delay between consecutive checks.

    Raises:
        TimeoutError: if the worker has not reported STOPPED after
            ``max_checks`` checks (including when ``max_checks`` is not positive).
    """
    for attempt in range(1, max_checks + 1):
        response = self.deadline_client.get_worker(
            farmId=self.configuration.farm_id,
            fleetId=self.configuration.fleet.id,
            workerId=self.worker_id,
        )
        if response["status"] == "STOPPED":
            LOG.info(f"{self.worker_id} is STOPPED")
            return
        # Sleep only when another check will follow; pausing after the last
        # check would just delay the timeout.
        if attempt < max_checks:
            LOG.info(f"Waiting for {self.worker_id} to transition to STOPPED status")
            time.sleep(seconds_between_checks)

    raise TimeoutError(
        f"{self.worker_id} did not transition to STOPPED status after {max_checks} checks"
    )

def set_stopped_status(self):
    """Ask the Deadline client to set this worker's status to STOPPED.

    Raises:
        botocore.exceptions.ClientError: if the ``update_worker`` call fails.
            The error is logged with its traceback and then re-raised.
    """
    LOG.info(f"Setting {self.worker_id} to STOPPED status")
    update = self.deadline_client.update_worker
    try:
        update(
            farmId=self.configuration.farm_id,
            fleetId=self.configuration.fleet.id,
            workerId=self.worker_id,
            status="STOPPED",
        )
    except botocore.exceptions.ClientError as error:
        LOG.exception(f"Failed to update worker status: {error}")
        raise

def send_command(self, command: str) -> CommandResult:
"""Send a command via SSM to a shell on a launched EC2 instance. Once the command has fully
finished the result of the invocation is returned.
Expand All @@ -240,7 +298,7 @@ def send_command(self, command: str) -> CommandResult:
#
# If we send an SSM command then we will get an InvalidInstanceId error
# if the instance isn't in that state.
NUM_RETRIES = 20
NUM_RETRIES = 30
SLEEP_INTERVAL_S = 10
for i in range(0, NUM_RETRIES):
LOG.info(f"Sending SSM command to instance {self.instance_id}")
Expand Down Expand Up @@ -491,9 +549,20 @@ def _start_worker_agent(self) -> None: # pragma: no cover
else:
self.start_windows_worker()

@property
def worker_id(self) -> str:
cmd_result = self.send_command("cat /var/lib/deadline/worker.json | jq -r '.worker_id'")
self.worker_id = self.get_worker_id()

def get_worker_id(self) -> str:
if self.configuration.operating_system.name == "AL2023":
cmd_result = self.send_command("jq -r '.worker_id' /var/lib/deadline/worker.json")
else:
cmd_result = self.send_command(
" ; ".join(
[
"$worker=Get-Content -Raw C:\ProgramData\Amazon\Deadline\Cache\worker.json | ConvertFrom-Json",
"echo $worker.worker_id",
]
)
)
assert cmd_result.exit_code == 0, f"Failed to get Worker ID: {cmd_result}"

worker_id = cmd_result.stdout.rstrip("\n\r")
Expand Down Expand Up @@ -553,7 +622,7 @@ def start(self) -> None:
run_container_env = {
**os.environ,
"FARM_ID": self.configuration.farm_id,
"FLEET_ID": self.configuration.fleet_id,
"FLEET_ID": self.configuration.fleet.id,
"AGENT_USER": self.configuration.user,
"SHARED_GROUP": self.configuration.group,
"JOB_USER": self.configuration.job_users[0].user,
Expand Down
4 changes: 3 additions & 1 deletion src/deadline_test_fixtures/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def worker_config(

yield DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
fleet=deadline_resources.fleet,
region=region,
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
Expand Down Expand Up @@ -514,10 +514,12 @@ def worker(
ec2_client = boto3.client("ec2")
s3_client = boto3.client("s3")
ssm_client = boto3.client("ssm")
deadline_client = boto3.client("deadline")

worker = EC2InstanceWorker(
ec2_client=ec2_client,
s3_client=s3_client,
deadline_client=deadline_client,
bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name,
ssm_client=ssm_client,
override_ami_id=ami_id,
Expand Down
42 changes: 34 additions & 8 deletions test/unit/deadline/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

from deadline_test_fixtures.deadline import worker as mod
from deadline_test_fixtures import (
CodeArtifactRepositoryInfo,
CommandResult,
DeadlineWorkerConfiguration,
DockerContainerWorker,
EC2InstanceWorker,
PipInstall,
CodeArtifactRepositoryInfo,
OperatingSystem,
S3Object,
Fleet,
Farm,
)


Expand Down Expand Up @@ -62,7 +64,7 @@ def region(boto_config: dict[str, str]) -> str:
def worker_config(region: str) -> DeadlineWorkerConfiguration:
return DeadlineWorkerConfiguration(
farm_id="farm-123",
fleet_id="fleet-123",
fleet=Fleet(id="fleet_123", farm=Farm(id="farm-123")),
region=region,
user="test-user",
group="test-group",
Expand Down Expand Up @@ -157,7 +159,9 @@ def worker(
s3_client=boto3.client("s3"),
ec2_client=boto3.client("ec2"),
ssm_client=boto3.client("ssm"),
deadline_client=boto3.client("deadline"),
configuration=worker_config,
worker_id="worker-7c3377ec9eba444bb51cc7da18463081",
)

@patch.object(mod, "open", mock_open(read_data="mock data".encode()))
Expand All @@ -171,6 +175,13 @@ def test_start(self, worker: EC2InstanceWorker) -> None:
patch.object(worker, "_stage_s3_bucket", return_value=s3_files) as mock_stage_s3_bucket,
patch.object(worker, "_launch_instance") as mock_launch_instance,
patch.object(worker, "_start_worker_agent") as mock_start_worker_agent,
patch.object(
worker,
"get_worker_id",
return_value=CommandResult(
exit_code=0, stdout="worker-7c3377ec9eba444bb51cc7da18463081"
),
),
):
# WHEN
worker.start()
Expand Down Expand Up @@ -240,14 +251,17 @@ def test_start_worker_agent(self) -> None:

def test_stop(self, worker: EC2InstanceWorker) -> None:
# GIVEN
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()
instance_id = worker.instance_id
assert instance_id is not None

instance = TestEC2InstanceWorker.describe_instance(instance_id)
assert instance["State"]["Name"] == "running"

# WHEN
worker.stop()

# THEN
Expand All @@ -259,7 +273,11 @@ class TestSendCommand:
def test_sends_command(self, worker: EC2InstanceWorker) -> None:
# GIVEN
cmd = 'echo "Hello world"'
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()

# WHEN
with patch.object(
Expand All @@ -277,7 +295,11 @@ def test_sends_command(self, worker: EC2InstanceWorker) -> None:
def test_retries_when_instance_not_ready(self, worker: EC2InstanceWorker) -> None:
# GIVEN
cmd = 'echo "Hello world"'
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()
real_send_command = worker.ssm_client.send_command

call_count = 0
Expand Down Expand Up @@ -311,7 +333,11 @@ def side_effect(*args, **kwargs):
def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None:
# GIVEN
cmd = 'echo "Hello world"'
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()
err = ClientError({"Error": {"Code": "SomethingWentWrong"}}, "SendCommand")

# WHEN
Expand All @@ -337,7 +363,7 @@ def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None:
"worker-7c3377ec9eba444bb51cc7da18463081\r\n",
],
)
def test_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None:
def test_get_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None:
# GIVEN
with patch.object(
worker, "send_command", return_value=CommandResult(exit_code=0, stdout=worker_id)
Expand Down

0 comments on commit 1217b4b

Please sign in to comment.