From 213633dddacd175b0b757b6a4d92c43b3461d1f4 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Tue, 2 Apr 2024 15:41:33 -0500 Subject: [PATCH 1/8] checks for capacity provider --- prefect_aws/workers/ecs_worker.py | 6 +++- tests/workers/test_ecs_worker.py | 46 +++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 1a5c3d28..f862914a 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -1436,7 +1436,11 @@ def _prepare_task_run_request( task_run_request.setdefault("taskDefinition", task_definition_arn) assert task_run_request["taskDefinition"] == task_definition_arn - if task_run_request.get("launchType") == "FARGATE_SPOT": + if "capacityProviderStrategy" in task_run_request: + # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa + task_run_request.pop("launchType", None) + + elif task_run_request.get("launchType") == "FARGATE_SPOT": # Should not be provided at all for FARGATE SPOT task_run_request.pop("launchType", None) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index e4dab38c..a6b91820 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -8,7 +8,7 @@ import anyio import pytest import yaml -from moto import mock_ec2, mock_ecs, mock_logs +from moto import mock_autoscaling, mock_ec2, mock_ecs, mock_logs from moto.ec2.utils import generate_instance_identity_document from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread @@ -275,7 +275,7 @@ def ecs_mocks( aws_credentials: AwsCredentials, flow_run: FlowRun, container_status_code ): with mock_ecs() as ecs: - with mock_ec2(): + with mock_ec2(), mock_autoscaling(): with mock_logs(): session = aws_credentials.get_boto3_session() @@ -2003,6 +2003,48 @@ async def test_user_defined_environment_variables_in_task_definition_template( ] +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_capacity_provider_strategy( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_run_request={ + "capacityProviderStrategy": [ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"}, + ] + }, + ), + aws_credentials=aws_credentials, + ) + + assert "launchType" not in configuration.task_run_request + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track + # 'capacityProviderStrategy' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + assert not task.get("launchType") + # Instead, it requires a capacity provider strategy but this is not supported + # by moto and is not present on the task even when provided so we assert on the + # mock call to ensure it is sent + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"}, + ] + + @pytest.mark.usefixtures("ecs_mocks") async def test_user_defined_environment_variables_in_task_run_request_template( aws_credentials: AwsCredentials, flow_run: FlowRun From a6d37d93cdbbf875e12599bc37d5661dc3d0f43f Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Fri, 5 Apr 2024 13:44:10 -0500 Subject: [PATCH 2/8] logs warning when removing launch type --- prefect_aws/workers/ecs_worker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index f862914a..956fc4b6 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -1438,6 +1438,10 @@ def _prepare_task_run_request( if "capacityProviderStrategy" in task_run_request: # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa + self._logger.warning( + "Removing launchType from task run request. Due to finding" + " capacityProviderStrategy in the request." + ) task_run_request.pop("launchType", None) elif task_run_request.get("launchType") == "FARGATE_SPOT": From 0bea3e7aa17cb7d937f2d387e323d03c61aa8fce Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Mon, 8 Apr 2024 10:45:56 -0500 Subject: [PATCH 3/8] passes ecs test --- tests/workers/test_ecs_worker.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 39329d8d..dd96d600 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -522,15 +522,23 @@ async def test_launch_types( @pytest.mark.parametrize( "cpu,memory", [(None, None), (1024, None), (None, 2048), (2048, 4096)] ) +@pytest.mark.parametrize("container_cpu,container_memory", [(None, None), (1024, 2048)]) async def test_cpu_and_memory( aws_credentials: AwsCredentials, launch_type: str, flow_run: FlowRun, - cpu: int, - memory: int, + task_cpu: int, + task_memory: int, + container_cpu: int, + container_memory: int, ): configuration = await construct_configuration( - aws_credentials=aws_credentials, launch_type=launch_type, cpu=cpu, memory=memory + aws_credentials=aws_credentials, + launch_type=launch_type, + task_cpu=task_cpu, + task_memory=task_memory, + container_cpu=container_cpu, + container_memory=container_memory, ) session = aws_credentials.get_boto3_session() @@ -553,19 +561,19 @@ async def test_cpu_and_memory( if launch_type == "EC2": # EC2 requires CPU and memory to be defined at the container level - assert container_definition["cpu"] == cpu or ECS_DEFAULT_CPU - assert container_definition["memory"] == memory or ECS_DEFAULT_MEMORY + assert container_definition["cpu"] == container_cpu or ECS_DEFAULT_CPU + assert container_definition["memory"] == container_memory or ECS_DEFAULT_MEMORY else: # Fargate requires CPU and memory to be defined at the task definition level - assert task_definition["cpu"] == str(cpu or ECS_DEFAULT_CPU) - assert task_definition["memory"] == str(memory or ECS_DEFAULT_MEMORY) + assert task_definition["cpu"] == str(task_cpu or ECS_DEFAULT_CPU) + assert task_definition["memory"] == str(task_memory or ECS_DEFAULT_MEMORY) # We always provide non-null values as overrides on the task run - assert overrides.get("cpu") == (str(cpu) if cpu else None) - assert overrides.get("memory") == (str(memory) if memory else None) + assert overrides.get("cpu") == (str(task_cpu) if task_cpu else None) + assert overrides.get("memory") == (str(task_memory) if task_memory else None) # And as overrides for the Prefect container - assert container_overrides.get("cpu") == cpu - assert container_overrides.get("memory") == memory + assert container_overrides.get("cpu") == task_cpu + assert container_overrides.get("memory") == container_memory @pytest.mark.usefixtures("ecs_mocks") From c1042d35edb912586a9b4b0fd7a7e3d2d28201a5 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Mon, 8 Apr 2024 11:20:32 -0500 Subject: [PATCH 4/8] Revert "passes ecs test" This reverts commit 0bea3e7aa17cb7d937f2d387e323d03c61aa8fce. --- tests/workers/test_ecs_worker.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index dd96d600..39329d8d 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -522,23 +522,15 @@ async def test_launch_types( @pytest.mark.parametrize( "cpu,memory", [(None, None), (1024, None), (None, 2048), (2048, 4096)] ) -@pytest.mark.parametrize("container_cpu,container_memory", [(None, None), (1024, 2048)]) async def test_cpu_and_memory( aws_credentials: AwsCredentials, launch_type: str, flow_run: FlowRun, - task_cpu: int, - task_memory: int, - container_cpu: int, - container_memory: int, + cpu: int, + memory: int, ): configuration = await construct_configuration( - aws_credentials=aws_credentials, - launch_type=launch_type, - task_cpu=task_cpu, - task_memory=task_memory, - container_cpu=container_cpu, - container_memory=container_memory, + aws_credentials=aws_credentials, launch_type=launch_type, cpu=cpu, memory=memory ) session = aws_credentials.get_boto3_session() @@ -561,19 +553,19 @@ async def test_cpu_and_memory( if launch_type == "EC2": # EC2 requires CPU and memory to be defined at the container level - assert container_definition["cpu"] == container_cpu or ECS_DEFAULT_CPU - assert container_definition["memory"] == container_memory or ECS_DEFAULT_MEMORY + assert container_definition["cpu"] == cpu or ECS_DEFAULT_CPU + assert container_definition["memory"] == memory or ECS_DEFAULT_MEMORY else: # Fargate requires CPU and memory to be defined at the task definition level - assert task_definition["cpu"] == str(task_cpu or ECS_DEFAULT_CPU) - assert task_definition["memory"] == str(task_memory or ECS_DEFAULT_MEMORY) + assert task_definition["cpu"] == str(cpu or ECS_DEFAULT_CPU) + assert task_definition["memory"] == str(memory or ECS_DEFAULT_MEMORY) # We always provide non-null values as overrides on the task run - assert overrides.get("cpu") == (str(task_cpu) if task_cpu else None) - assert overrides.get("memory") == (str(task_memory) if task_memory else None) + assert overrides.get("cpu") == (str(cpu) if cpu else None) + assert overrides.get("memory") == (str(memory) if memory else None) # And as overrides for the Prefect container - assert container_overrides.get("cpu") == task_cpu - assert container_overrides.get("memory") == container_memory + assert container_overrides.get("cpu") == cpu + assert container_overrides.get("memory") == memory @pytest.mark.usefixtures("ecs_mocks") From daa9a244137d8404345a29ece7590722d8d72c62 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Tue, 9 Apr 2024 16:17:01 -0500 Subject: [PATCH 5/8] added capacity_provider_strategy to ecs variables --- prefect_aws/workers/ecs_worker.py | 21 +++++++++++++++++++-- tests/workers/test_ecs_worker.py | 12 ++++-------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 450cf185..b3b54495 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -70,9 +70,9 @@ from pydantic import VERSION as PYDANTIC_VERSION if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field, root_validator + from pydantic.v1 import BaseModel, Field, root_validator else: - from pydantic import Field, root_validator + from pydantic import Field, root_validator, BaseModel from slugify import slugify from tenacity import retry, stop_after_attempt, wait_fixed, wait_random @@ -367,6 +367,16 @@ def network_configuration_requires_vpc_id(cls, values: dict) -> dict: return values +class CapacityProvider(BaseModel): + """ + The capacity provider strategy to use when running the task. + """ + + capacity_provider: str + weight: int + base: int + + class ECSVariables(BaseVariables): """ Variables for templating an ECS job. @@ -425,6 +435,13 @@ class ECSVariables(BaseVariables): ), ) ) + capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( + default=None, + description=( + "The capacity provider strategy to use when running the task. This is only" + "If a capacityProviderStrategy is specified, we will omit the launchType" + ), + ) image: Optional[str] = Field( default=None, description=( diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 39329d8d..9db3882e 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -2020,15 +2020,11 @@ async def test_user_defined_environment_variables_in_task_definition_template( async def test_user_defined_capacity_provider_strategy( aws_credentials: AwsCredentials, flow_run: FlowRun ): - configuration = await construct_configuration_with_job_template( - template_overrides=dict( - task_run_request={ - "capacityProviderStrategy": [ - {"base": 0, "weight": 1, "capacityProvider": "r6i.large"}, - ] - }, - ), + configuration = await construct_configuration( aws_credentials=aws_credentials, + capacity_provider_strategy=[ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"} + ], ) assert "launchType" not in configuration.task_run_request From bbdae802a1859c39ca2ca9d1842c693df1fbba6c Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Wed, 10 Apr 2024 09:22:28 -0500 Subject: [PATCH 6/8] changed field formatting --- prefect_aws/workers/ecs_worker.py | 11 ++++++++--- tests/workers/test_ecs_worker.py | 3 --- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index b3b54495..ec0b47af 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -126,6 +126,7 @@ taskRoleArn: "{{ task_role_arn }}" tags: "{{ labels }}" taskDefinition: "{{ task_definition_arn }}" +capacityProviderStrategy: "{{ capacity_provider_strategy }}" """ # Create task run retry settings @@ -372,7 +373,7 @@ class CapacityProvider(BaseModel): The capacity provider strategy to use when running the task. """ - capacity_provider: str + capacityProvider: str weight: int base: int @@ -436,7 +437,7 @@ class ECSVariables(BaseVariables): ) ) capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( - default=None, + default_factory=List[CapacityProvider], description=( "The capacity provider strategy to use when running the task. This is only" "If a capacityProviderStrategy is specified, we will omit the launchType" @@ -1466,8 +1467,12 @@ def _prepare_task_run_request( task_run_request.setdefault("taskDefinition", task_definition_arn) assert task_run_request["taskDefinition"] == task_definition_arn + capacityProviderStrategy = ( + task_run_request.get("capacityProviderStrategy") + or configuration.capacity_provider_strategy + ) - if "capacityProviderStrategy" in task_run_request: + if capacityProviderStrategy: # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa self._logger.warning( "Removing launchType from task run request. Due to finding" diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 9db3882e..a6112773 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -2026,9 +2026,6 @@ async def test_user_defined_capacity_provider_strategy( {"base": 0, "weight": 1, "capacityProvider": "r6i.large"} ], ) - - assert "launchType" not in configuration.task_run_request - session = aws_credentials.get_boto3_session() ecs_client = session.client("ecs") From 9107e278b850d74cbf81847c07d4db97f18a820f Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Wed, 10 Apr 2024 11:48:37 -0500 Subject: [PATCH 7/8] sets fargate spot capacity provider --- prefect_aws/workers/ecs_worker.py | 36 ++++++++++++++----------------- tests/workers/test_ecs_worker.py | 5 +++-- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index ec0b47af..c65e594a 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -246,6 +246,16 @@ def mask_api_key(task_run_request): ) +class CapacityProvider(BaseModel): + """ + The capacity provider strategy to use when running the task. + """ + + capacityProvider: str + weight: int + base: int + + class ECSJobConfiguration(BaseJobConfiguration): """ Job configuration for an ECS worker. @@ -268,6 +278,7 @@ class ECSJobConfiguration(BaseJobConfiguration): auto_deregister_task_definition: bool = Field(default=False) vpc_id: Optional[str] = Field(default=None) container_name: Optional[str] = Field(default=None) + cluster: Optional[str] = Field(default=None) match_latest_revision_in_family: bool = Field(default=False) @@ -368,16 +379,6 @@ def network_configuration_requires_vpc_id(cls, values: dict) -> dict: return values -class CapacityProvider(BaseModel): - """ - The capacity provider strategy to use when running the task. - """ - - capacityProvider: str - weight: int - base: int - - class ECSVariables(BaseVariables): """ Variables for templating an ECS job. @@ -437,7 +438,7 @@ class ECSVariables(BaseVariables): ) ) capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( - default_factory=List[CapacityProvider], + default_factory=list, description=( "The capacity provider strategy to use when running the task. This is only" "If a capacityProviderStrategy is specified, we will omit the launchType" @@ -1467,10 +1468,7 @@ def _prepare_task_run_request( task_run_request.setdefault("taskDefinition", task_definition_arn) assert task_run_request["taskDefinition"] == task_definition_arn - capacityProviderStrategy = ( - task_run_request.get("capacityProviderStrategy") - or configuration.capacity_provider_strategy - ) + capacityProviderStrategy = task_run_request.get("capacityProviderStrategy") if capacityProviderStrategy: # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa @@ -1485,11 +1483,9 @@ def _prepare_task_run_request( task_run_request.pop("launchType", None) # A capacity provider strategy is required for FARGATE SPOT - task_run_request.setdefault( - "capacityProviderStrategy", - [{"capacityProvider": "FARGATE_SPOT", "weight": 1}], - ) - + task_run_request["capacityProviderStrategy"] = [ + {"capacityProvider": "FARGATE_SPOT", "weight": 1} + ] overrides = task_run_request.get("overrides", {}) container_overrides = overrides.get("containerOverrides", []) diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index a6112773..386ebf10 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -8,7 +8,7 @@ import anyio import pytest import yaml -from moto import mock_autoscaling, mock_ec2, mock_ecs, mock_logs +from moto import mock_ec2, mock_ecs, mock_logs from moto.ec2.utils import generate_instance_identity_document from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread @@ -275,7 +275,7 @@ def ecs_mocks( aws_credentials: AwsCredentials, flow_run: FlowRun, container_status_code ): with mock_ecs() as ecs: - with mock_ec2(), mock_autoscaling(): + with mock_ec2(): with mock_logs(): session = aws_credentials.get_boto3_session() @@ -506,6 +506,7 @@ async def test_launch_types( # Instead, it requires a capacity provider strategy but this is not supported # by moto and is not present on the task even when provided so we assert on the # mock call to ensure it is sent + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ {"capacityProvider": "FARGATE_SPOT", "weight": 1} ] From 4b912cd4abb7e5ffe3e4edb54ebcefb7e948ba0f Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Thu, 11 Apr 2024 11:20:51 -0500 Subject: [PATCH 8/8] description formatting --- prefect_aws/workers/ecs_worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index fec9fc41..d862fddd 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -278,7 +278,6 @@ class ECSJobConfiguration(BaseJobConfiguration): auto_deregister_task_definition: bool = Field(default=False) vpc_id: Optional[str] = Field(default=None) container_name: Optional[str] = Field(default=None) - cluster: Optional[str] = Field(default=None) match_latest_revision_in_family: bool = Field(default=False) @@ -440,8 +439,9 @@ class ECSVariables(BaseVariables): capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( default_factory=list, description=( - "The capacity provider strategy to use when running the task. This is only" - "If a capacityProviderStrategy is specified, we will omit the launchType" + "The capacity provider strategy to use when running the task. " + "If a capacity provider strategy is specified, the selected launch" + " type will be ignored." ), ) image: Optional[str] = Field( @@ -1473,8 +1473,8 @@ def _prepare_task_run_request( if capacityProviderStrategy: # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa self._logger.warning( - "Removing launchType from task run request. Due to finding" - " capacityProviderStrategy in the request." + "Found capacityProviderStrategy. " + "Removing launchType from task run request." ) task_run_request.pop("launchType", None)