diff --git a/requirements-testing.txt b/requirements-testing.txt index a9d8d5b..4a881c2 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -7,4 +7,4 @@ black == 24.4.* moto[cloudformation,s3] == 4.2.* mypy == 1.10.* ruff == 0.4.* -twine == 5.0.* +twine == 5.1.* diff --git a/src/deadline_test_fixtures/deadline/resources.py b/src/deadline_test_fixtures/deadline/resources.py index 1a71dd7..bc8f916 100644 --- a/src/deadline_test_fixtures/deadline/resources.py +++ b/src/deadline_test_fixtures/deadline/resources.py @@ -103,6 +103,7 @@ def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> N class Fleet: id: str farm: Farm + autoscaling: bool = True @staticmethod def create( diff --git a/src/deadline_test_fixtures/deadline/worker.py b/src/deadline_test_fixtures/deadline/worker.py index 524f03f..b266aad 100644 --- a/src/deadline_test_fixtures/deadline/worker.py +++ b/src/deadline_test_fixtures/deadline/worker.py @@ -23,6 +23,7 @@ PosixSessionUser, OperatingSystem, ) +from .resources import Fleet from ..util import call_api, wait_for LOG = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def linux_worker_command(config: DeadlineWorkerConfiguration) -> str: # pragma: "install-deadline-worker " + "-y " + f"--farm-id {config.farm_id} " - + f"--fleet-id {config.fleet_id} " + + f"--fleet-id {config.fleet.id} " + f"--region {config.region} " + f"--user {config.user} " + f"--group {config.group} " @@ -80,7 +81,7 @@ def windows_worker_command(config: DeadlineWorkerConfiguration) -> str: # pragm "install-deadline-worker " + "-y " + f"--farm-id {config.farm_id} " - + f"--fleet-id {config.fleet_id} " + + f"--fleet-id {config.fleet.id} " + f"--region {config.region} " + f"--user {config.user} " + f"{'--allow-shutdown ' if config.allow_shutdown else ''}" @@ -130,10 +131,6 @@ def stop(self) -> None: def send_command(self, command: str) -> CommandResult: pass - @abc.abstractproperty - def worker_id(self) -> str: - pass - @dataclass(frozen=True) class CommandResult: # pragma: no cover @@ -173,7 +170,7 @@ def __str__(self) -> str: class DeadlineWorkerConfiguration: operating_system: OperatingSystem farm_id: str - fleet_id: str + fleet: Fleet region: str user: str group: str @@ -203,11 +200,14 @@ class EC2InstanceWorker(DeadlineWorker): s3_client: botocore.client.BaseClient ec2_client: botocore.client.BaseClient ssm_client: botocore.client.BaseClient + deadline_client: botocore.client.BaseClient configuration: DeadlineWorkerConfiguration instance_id: Optional[str] = field(init=False, default=None) override_ami_id: InitVar[Optional[str]] = None + worker_id: Optional[str] = None + """ Option to override the AMI ID for the EC2 instance. The latest AL2023 is used by default. Note that the scripting to configure the EC2 instance is only verified to work on AL2023. @@ -225,8 +225,66 @@ def start(self) -> None: def stop(self) -> None: LOG.info(f"Terminating EC2 instance {self.instance_id}") self.ec2_client.terminate_instances(InstanceIds=[self.instance_id]) + self.instance_id = None + if not self.configuration.fleet.autoscaling: + try: + self.wait_until_stopped() + except TimeoutError: + LOG.warning( + f"{self.worker_id} did not transition to a STOPPED status, forcibly stopping..." + ) + self.set_stopped_status() + + try: + self.delete() + except botocore.exceptions.ClientError as error: + LOG.exception(f"Failed to delete worker: {error}") + raise + + def delete(self): + try: + self.deadline_client.delete_worker( + farmId=self.configuration.farm_id, + fleetId=self.configuration.fleet.id, + workerId=self.worker_id, + ) + LOG.info(f"{self.worker_id} has been deleted from {self.configuration.fleet.id}") + except botocore.exceptions.ClientError as error: + LOG.exception(f"Failed to delete worker: {error}") + raise + + def wait_until_stopped( + self, *, max_checks: int = 25, seconds_between_checks: float = 5 + ) -> None: + for _ in range(max_checks): + response = self.deadline_client.get_worker( + farmId=self.configuration.farm_id, + fleetId=self.configuration.fleet.id, + workerId=self.worker_id, + ) + if response["status"] == "STOPPED": + LOG.info(f"{self.worker_id} is STOPPED") + break + time.sleep(seconds_between_checks) + LOG.info(f"Waiting for {self.worker_id} to transition to STOPPED status") + else: + raise TimeoutError + + def set_stopped_status(self): + LOG.info(f"Setting {self.worker_id} to STOPPED status") + try: + self.deadline_client.update_worker( + farmId=self.configuration.farm_id, + fleetId=self.configuration.fleet.id, + workerId=self.worker_id, + status="STOPPED", + ) + except botocore.exceptions.ClientError as error: + LOG.exception(f"Failed to update worker status: {error}") + raise + def send_command(self, command: str) -> CommandResult: """Send a command via SSM to a shell on a launched EC2 instance. Once the command has fully finished the result of the invocation is returned. @@ -240,7 +298,7 @@ def send_command(self, command: str) -> CommandResult: # # If we send an SSM command then we will get an InvalidInstanceId error # if the instance isn't in that state. - NUM_RETRIES = 20 + NUM_RETRIES = 30 SLEEP_INTERVAL_S = 10 for i in range(0, NUM_RETRIES): LOG.info(f"Sending SSM command to instance {self.instance_id}") @@ -491,9 +549,20 @@ def _start_worker_agent(self) -> None: # pragma: no cover else: self.start_windows_worker() - @property - def worker_id(self) -> str: - cmd_result = self.send_command("cat /var/lib/deadline/worker.json | jq -r '.worker_id'") + self.worker_id = self.get_worker_id() + + def get_worker_id(self) -> str: + if self.configuration.operating_system.name == "AL2023": + cmd_result = self.send_command("jq -r '.worker_id' /var/lib/deadline/worker.json") + else: + cmd_result = self.send_command( + " ; ".join( + [ + "$worker=Get-Content -Raw C:\ProgramData\Amazon\Deadline\Cache\worker.json | ConvertFrom-Json", + "echo $worker.worker_id", + ] + ) + ) assert cmd_result.exit_code == 0, f"Failed to get Worker ID: {cmd_result}" worker_id = cmd_result.stdout.rstrip("\n\r") @@ -553,7 +622,7 @@ def start(self) -> None: run_container_env = { **os.environ, "FARM_ID": self.configuration.farm_id, - "FLEET_ID": self.configuration.fleet_id, + "FLEET_ID": self.configuration.fleet.id, "AGENT_USER": self.configuration.user, "SHARED_GROUP": self.configuration.group, "JOB_USER": self.configuration.job_users[0].user, diff --git a/src/deadline_test_fixtures/fixtures.py b/src/deadline_test_fixtures/fixtures.py index 78337ff..1024121 100644 --- a/src/deadline_test_fixtures/fixtures.py +++ b/src/deadline_test_fixtures/fixtures.py @@ -454,7 +454,7 @@ def worker_config( yield DeadlineWorkerConfiguration( farm_id=deadline_resources.farm.id, - fleet_id=deadline_resources.fleet.id, + fleet=deadline_resources.fleet, region=region, user=os.getenv("WORKER_POSIX_USER", "deadline-worker"), group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"), @@ -514,10 +514,12 @@ def worker( ec2_client = boto3.client("ec2") s3_client = boto3.client("s3") ssm_client = boto3.client("ssm") + deadline_client = boto3.client("deadline") worker = EC2InstanceWorker( ec2_client=ec2_client, s3_client=s3_client, + deadline_client=deadline_client, bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name, ssm_client=ssm_client, override_ami_id=ami_id, diff --git a/test/unit/deadline/test_worker.py b/test/unit/deadline/test_worker.py index 5ab0a4c..939119a 100644 --- a/test/unit/deadline/test_worker.py +++ b/test/unit/deadline/test_worker.py @@ -14,14 +14,16 @@ from deadline_test_fixtures.deadline import worker as mod from deadline_test_fixtures import ( - CodeArtifactRepositoryInfo, CommandResult, DeadlineWorkerConfiguration, DockerContainerWorker, EC2InstanceWorker, PipInstall, + CodeArtifactRepositoryInfo, OperatingSystem, S3Object, + Fleet, + Farm, ) @@ -62,7 +64,7 @@ def region(boto_config: dict[str, str]) -> str: def worker_config(region: str) -> DeadlineWorkerConfiguration: return DeadlineWorkerConfiguration( farm_id="farm-123", - fleet_id="fleet-123", + fleet=Fleet(id="fleet_123", farm=Farm(id="farm-123")), region=region, user="test-user", group="test-group", @@ -157,7 +159,9 @@ def worker( s3_client=boto3.client("s3"), ec2_client=boto3.client("ec2"), ssm_client=boto3.client("ssm"), + deadline_client=boto3.client("deadline"), configuration=worker_config, + worker_id="worker-7c3377ec9eba444bb51cc7da18463081", ) @patch.object(mod, "open", mock_open(read_data="mock data".encode())) @@ -171,6 +175,13 @@ def test_start(self, worker: EC2InstanceWorker) -> None: patch.object(worker, "_stage_s3_bucket", return_value=s3_files) as mock_stage_s3_bucket, patch.object(worker, "_launch_instance") as mock_launch_instance, patch.object(worker, "_start_worker_agent") as mock_start_worker_agent, + patch.object( + worker, + "get_worker_id", + return_value=CommandResult( + exit_code=0, stdout="worker-7c3377ec9eba444bb51cc7da18463081" + ), + ), ): # WHEN worker.start() @@ -240,14 +251,17 @@ def test_start_worker_agent(self) -> None: def test_stop(self, worker: EC2InstanceWorker) -> None: # GIVEN - worker.start() + # WHEN + with patch.object( + worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081" + ): + worker.start() instance_id = worker.instance_id assert instance_id is not None instance = TestEC2InstanceWorker.describe_instance(instance_id) assert instance["State"]["Name"] == "running" - # WHEN worker.stop() # THEN @@ -259,7 +273,11 @@ class TestSendCommand: def test_sends_command(self, worker: EC2InstanceWorker) -> None: # GIVEN cmd = 'echo "Hello world"' - worker.start() + # WHEN + with patch.object( + worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081" + ): + worker.start() # WHEN with patch.object( @@ -277,7 +295,11 @@ def test_sends_command(self, worker: EC2InstanceWorker) -> None: def test_retries_when_instance_not_ready(self, worker: EC2InstanceWorker) -> None: # GIVEN cmd = 'echo "Hello world"' - worker.start() + # WHEN + with patch.object( + worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081" + ): + worker.start() real_send_command = worker.ssm_client.send_command call_count = 0 @@ -311,7 +333,11 @@ def side_effect(*args, **kwargs): def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None: # GIVEN cmd = 'echo "Hello world"' - worker.start() + # WHEN + with patch.object( + worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081" + ): + worker.start() err = ClientError({"Error": {"Code": "SomethingWentWrong"}}, "SendCommand") # WHEN @@ -337,7 +363,7 @@ def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None: "worker-7c3377ec9eba444bb51cc7da18463081\r\n", ], ) - def test_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None: + def test_get_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None: # GIVEN with patch.object( worker, "send_command", return_value=CommandResult(exit_code=0, stdout=worker_id)