Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support for capacity provider #407

Merged
merged 12 commits into from
Apr 11, 2024
6 changes: 5 additions & 1 deletion prefect_aws/workers/ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,11 @@ def _prepare_task_run_request(
task_run_request.setdefault("taskDefinition", task_definition_arn)
assert task_run_request["taskDefinition"] == task_definition_arn

if task_run_request.get("launchType") == "FARGATE_SPOT":
if "capacityProviderStrategy" in task_run_request:
# Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa
task_run_request.pop("launchType", None)
jeanluciano marked this conversation as resolved.
Show resolved Hide resolved

elif task_run_request.get("launchType") == "FARGATE_SPOT":
# Should not be provided at all for FARGATE SPOT
task_run_request.pop("launchType", None)

Expand Down
46 changes: 44 additions & 2 deletions tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import anyio
import pytest
import yaml
from moto import mock_ec2, mock_ecs, mock_logs
from moto import mock_autoscaling, mock_ec2, mock_ecs, mock_logs
from moto.ec2.utils import generate_instance_identity_document
from prefect.server.schemas.core import FlowRun
from prefect.utilities.asyncutils import run_sync_in_worker_thread
Expand Down Expand Up @@ -275,7 +275,7 @@ def ecs_mocks(
aws_credentials: AwsCredentials, flow_run: FlowRun, container_status_code
):
with mock_ecs() as ecs:
with mock_ec2():
with mock_ec2(), mock_autoscaling():
with mock_logs():
session = aws_credentials.get_boto3_session()

Expand Down Expand Up @@ -2016,6 +2016,48 @@ async def test_user_defined_environment_variables_in_task_definition_template(
]


@pytest.mark.usefixtures("ecs_mocks")
async def test_user_defined_capacity_provider_strategy(
aws_credentials: AwsCredentials, flow_run: FlowRun
):
configuration = await construct_configuration_with_job_template(
template_overrides=dict(
task_run_request={
"capacityProviderStrategy": [
{"base": 0, "weight": 1, "capacityProvider": "r6i.large"},
]
},
),
aws_credentials=aws_credentials,
)

assert "launchType" not in configuration.task_run_request
jeanluciano marked this conversation as resolved.
Show resolved Hide resolved

session = aws_credentials.get_boto3_session()
ecs_client = session.client("ecs")

async with ECSWorker(work_pool_name="test") as worker:
# Capture the task run call because moto does not track
# 'capacityProviderStrategy'
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

result = await run_then_stop_task(worker, configuration, flow_run)

assert result.status_code == 0
_, task_arn = parse_identifier(result.identifier)

task = describe_task(ecs_client, task_arn)
assert not task.get("launchType")
# Instead, it requires a capacity provider strategy but this is not supported
# by moto and is not present on the task even when provided so we assert on the
# mock call to ensure it is sent
assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [
{"base": 0, "weight": 1, "capacityProvider": "r6i.large"},
]


@pytest.mark.usefixtures("ecs_mocks")
async def test_user_defined_environment_variables_in_task_run_request_template(
aws_credentials: AwsCredentials, flow_run: FlowRun
Expand Down
Loading