Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Adds ability to publish ECSTask block as an ECS work pool (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Dec 11, 2023
1 parent 6865af7 commit 6aca613
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 3 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

## 0.4.6

Released December 11th, 2023.

### Added

Ability to publish `ECSTask` block as an ECS work pool - [#353](https://github.com/PrefectHQ/prefect-aws/pull/353)

## 0.4.5

Released November 30th, 2023.
Expand Down
74 changes: 73 additions & 1 deletion prefect_aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
import json
import logging
import pprint
import shlex
import sys
import time
import warnings
Expand All @@ -116,6 +117,8 @@
import boto3
import yaml
from anyio.abc import TaskStatus
from jsonpointer import JsonPointerException
from prefect.blocks.core import BlockNotSavedError
from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound
from prefect.infrastructure.base import Infrastructure, InfrastructureResult
from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible
Expand All @@ -132,7 +135,7 @@
from typing_extensions import Literal, Self

from prefect_aws import AwsCredentials
from prefect_aws.workers.ecs_worker import _TAG_REGEX
from prefect_aws.workers.ecs_worker import _TAG_REGEX, ECSWorker

# Internal type alias for ECS clients which are generated dynamically in botocore
_ECSClient = Any
Expand Down Expand Up @@ -681,6 +684,75 @@ async def kill(self, identifier: str, grace_seconds: int = 30) -> None:
cluster, task = parse_task_identifier(identifier)
await run_sync_in_worker_thread(self._stop_task, cluster, task)

@staticmethod
def get_corresponding_worker_type() -> str:
    """
    Name the worker type that pairs with this infrastructure block.

    Returns:
        - str: the value of `ECSWorker.type`
    """
    worker_type = ECSWorker.type
    return worker_type

async def generate_work_pool_base_job_template(self) -> dict:
    """
    Generate a base job template for an ECS work pool with the same
    configuration as this block.

    Starts from a deep copy of the default `ECSWorker` base job template and
    copies over every field that was explicitly set on this block to a
    non-default value:

    - `command` is joined into a single shell-quoted string.
    - `aws_credentials` is stored as a `$ref` to the saved block document.
    - `task_definition` is written into the job configuration.
    - any other field with a matching template variable becomes that
      variable's default.
    - fields without a matching template variable are skipped with a warning.

    `task_customizations` are applied as a patch to the template's
    `task_run_request`; if the patch cannot be applied a warning is logged
    and the template is returned without the customizations.

    Returns:
        - dict: a base job template for an ECS work pool

    Raises:
        BlockNotSavedError: if `aws_credentials` is set on this block but the
            credentials block has no block document id (it was never saved).
    """
    # Block bookkeeping fields have no work pool counterpart, and
    # `task_customizations` is handled separately after the loop.
    skipped_keys = {
        "type",
        "block_type_slug",
        "_block_document_id",
        "_block_document_name",
        "_is_anonymous",
        "task_customizations",
    }

    # Deep copy so the worker's shared default template is never mutated.
    base_job_template = copy.deepcopy(ECSWorker.get_default_base_job_template())
    variables = base_job_template["variables"]["properties"]
    job_configuration = base_job_template["job_configuration"]

    for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items():
        if key in skipped_keys:
            continue

        if key == "command":
            # The block stores the command as a list of arguments; the
            # template variable default is written as one string.
            variables["command"]["default"] = shlex.join(value)
        elif key == "aws_credentials":
            credentials_document_id = self.aws_credentials._block_document_id
            if not credentials_document_id:
                raise BlockNotSavedError(
                    "It looks like you are trying to use a block that"
                    " has not been saved. Please call `.save` on your block"
                    " before publishing it as a work pool."
                )
            # Reference the saved block by id instead of inlining its values.
            variables["aws_credentials"]["default"] = {
                "$ref": {"block_document_id": str(credentials_document_id)}
            }
        elif key == "task_definition":
            job_configuration["task_definition"] = value
        elif key in variables:
            variables[key]["default"] = value
        else:
            self.logger.warning(
                f"Variable {key!r} is not supported by ECS work pools."
                " Skipping."
            )

    if self.task_customizations:
        try:
            job_configuration["task_run_request"] = self.task_customizations.apply(
                job_configuration["task_run_request"]
            )
        except JsonPointerException:
            # NOTE: the two literals below concatenate without a separating
            # space. The text is kept as-is because the test suite asserts
            # on this exact message.
            self.logger.warning(
                "Unable to apply task customizations to the base job template."
                "You may need to update the template manually."
            )

    return base_job_template

def _stop_task(self, cluster: str, task: str) -> None:
"""
Stop a running ECS task.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ boto3>=1.24.53
botocore>=1.27.53
mypy_boto3_s3>=1.24.94
mypy_boto3_secretsmanager>=1.26.49
prefect>=2.13.5
prefect>=2.14.10
tenacity>=8.0.0
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ def prefect_db():

@pytest.fixture
def aws_credentials():
return AwsCredentials(
block = AwsCredentials(
aws_access_key_id="access_key_id",
aws_secret_access_key="secret_access_key",
region_name="us-east-1",
)
block.save("test-creds-block", overwrite=True)
return block


@pytest.fixture
Expand Down
189 changes: 189 additions & 0 deletions tests/test_ecs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import textwrap
from copy import deepcopy
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Optional
from unittest.mock import MagicMock
Expand All @@ -18,6 +19,8 @@
from prefect.utilities.dockerutils import get_prefect_image_name
from pydantic import VERSION as PYDANTIC_VERSION

from prefect_aws.workers.ecs_worker import ECSWorker

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import ValidationError
else:
Expand Down Expand Up @@ -2047,3 +2050,189 @@ async def test_kill_with_grace_period(aws_credentials, caplog):

# Logs warning
assert "grace period of 60s requested, but AWS does not support" in caplog.text


@pytest.fixture
def default_base_job_template():
    """Provide a private deep copy of the ECS worker's default base job template."""
    pristine_template = ECSWorker.get_default_base_job_template()
    return deepcopy(pristine_template)


@pytest.fixture
def base_job_template_with_defaults(default_base_job_template, aws_credentials):
    """
    Provide the default base job template with a `default` filled in for each
    variable that the "custom" `ECSTask` scenario is expected to populate.
    """
    role_arn = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"

    # Variable name -> expected default value in the generated template.
    expected_defaults = {
        "command": "python my_script.py",
        "env": {
            "VAR1": "value1",
            "VAR2": "value2",
        },
        "labels": {
            "label1": "value1",
            "label2": "value2",
        },
        "name": "prefect-job",
        "image": "docker.io/my_image:latest",
        # Credentials are referenced by the saved block's document id.
        "aws_credentials": {
            "$ref": {"block_document_id": str(aws_credentials._block_document_id)}
        },
        "launch_type": "FARGATE_SPOT",
        "vpc_id": "vpc-123456",
        "task_role_arn": role_arn,
        "execution_role_arn": role_arn,
        "cluster": "test-cluster",
        "cpu": 2048,
        "memory": 4096,
        "family": "test-family",
        "task_definition_arn": (
            "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1"
        ),
        "cloudwatch_logs_options": {
            "awslogs-group": "prefect",
            "awslogs-region": "us-east-1",
            "awslogs-stream-prefix": "prefect",
        },
        "configure_cloudwatch_logs": True,
        "stream_output": True,
        "task_watch_poll_interval": 5.1,
        "task_start_timeout_seconds": 60,
        "auto_deregister_task_definition": False,
    }

    template = deepcopy(default_base_job_template)
    properties = template["variables"]["properties"]
    for variable_name, default_value in expected_defaults.items():
        properties[variable_name]["default"] = default_value
    return template


@pytest.fixture
def base_job_template_with_task_arn(default_base_job_template, aws_credentials):
    """
    Provide the default base job template with a custom image default and an
    explicit task definition placed in the job configuration.
    """
    task_definition = {
        "containerDefinitions": [
            {"image": "docker.io/my_image:latest", "name": "prefect-job"}
        ],
        "cpu": "2048",
        "family": "test-family",
        "memory": "2024",
        "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
    }

    template = deepcopy(default_base_job_template)
    image_variable = template["variables"]["properties"]["image"]
    image_variable["default"] = "docker.io/my_image:latest"
    template["job_configuration"]["task_definition"] = task_definition
    return template


@pytest.mark.parametrize(
    "job_config",
    ["default", "custom", "task_definition_arn"],
)
async def test_generate_work_pool_base_job_template(
    job_config,
    base_job_template_with_defaults,
    aws_credentials,
    default_base_job_template,
    base_job_template_with_task_arn,
    caplog,
):
    """
    Check that an `ECSTask` block is converted into the expected work pool
    base job template for three scenarios: a block with no arguments, a block
    with many fields customized, and a block with an explicit task definition.
    """
    # The expectation for the no-argument block uses the Prefect image name
    # as the image default.
    image_variable = default_base_job_template["variables"]["properties"]["image"]
    image_variable["default"] = get_prefect_image_name()

    expected_template = default_base_job_template
    block_kwargs = {}

    if job_config == "custom":
        expected_template = base_job_template_with_defaults
        block_kwargs = dict(
            command=["python", "my_script.py"],
            env={"VAR1": "value1", "VAR2": "value2"},
            labels={"label1": "value1", "label2": "value2"},
            name="prefect-job",
            image="docker.io/my_image:latest",
            aws_credentials=aws_credentials,
            launch_type="FARGATE_SPOT",
            vpc_id="vpc-123456",
            task_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
            execution_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
            cluster="test-cluster",
            cpu=2048,
            memory=4096,
            task_customizations=[
                {
                    "op": "add",
                    "path": "/networkConfiguration/awsvpcConfiguration/securityGroups",
                    "value": ["sg-d72e9599956a084f5"],
                },
            ],
            family="test-family",
            task_definition_arn=(
                "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1"
            ),
            cloudwatch_logs_options={
                "awslogs-group": "prefect",
                "awslogs-region": "us-east-1",
                "awslogs-stream-prefix": "prefect",
            },
            configure_cloudwatch_logs=True,
            stream_output=True,
            task_watch_poll_interval=5.1,
            task_start_timeout_seconds=60,
            auto_deregister_task_definition=False,
        )
    elif job_config == "task_definition_arn":
        # NOTE: despite the scenario id, this case sets `task_definition`
        # on the block rather than `task_definition_arn`.
        expected_template = base_job_template_with_task_arn
        block_kwargs = dict(
            image="docker.io/my_image:latest",
            task_definition={
                "containerDefinitions": [
                    {"image": "docker.io/my_image:latest", "name": "prefect-job"}
                ],
                "cpu": "2048",
                "family": "test-family",
                "memory": "2024",
                "executionRoleArn": (
                    "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
                ),
            },
        )

    job = ECSTask(**block_kwargs)
    template = await job.generate_work_pool_base_job_template()

    assert template == expected_template

    if job_config == "custom":
        # The customization in this scenario is expected not to apply to the
        # default template, so the block should only warn about it.
        expected_warning = (
            "Unable to apply task customizations to the base job template."
            "You may need to update the template manually."
        )
        assert expected_warning in caplog.text

0 comments on commit 6aca613

Please sign in to comment.