Skip to content

Commit

Permalink
Add reattach flag to ECSOperator (#10643)
Browse files Browse the repository at this point in the history
..so that whenever the Airflow server restarts, it does not leave rogue ECS Tasks. Instead the operator will seek for any running instance and attach to it.

GitOrigin-RevId: 0df60b773671ecf8d4e5f582ac2be200cf2a2edd
  • Loading branch information
darwinyip authored and Cloud Composer Team committed Sep 15, 2021
1 parent 897d2cf commit 2b13a2f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 11 deletions.
65 changes: 54 additions & 11 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re
import sys
from datetime import datetime
from typing import Optional
from typing import Optional, Dict

from botocore.waiter import Waiter

Expand All @@ -42,22 +42,33 @@ class ECSProtocol(Protocol):
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html
"""

def run_task(self, **kwargs) -> dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501 # pylint: disable=line-too-long
# pylint: disable=C0103, line-too-long
def run_task(self, **kwargs) -> Dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501
...

def get_waiter(self, x: str) -> Waiter:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501 # pylint: disable=line-too-long
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501
...

def describe_tasks(self, cluster: str, tasks) -> dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501 # pylint: disable=line-too-long
def describe_tasks(self, cluster: str, tasks) -> Dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501
...

def stop_task(self, cluster, task, reason: str) -> dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501 # pylint: disable=line-too-long
def stop_task(self, cluster, task, reason: str) -> Dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501
...

def describe_task_definition(self, taskDefinition: str) -> Dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_task_definition""" # noqa: E501
...

def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family: str) -> Dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.list_tasks""" # noqa: E501
...

# pylint: enable=C0103, line-too-long


class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
"""
Expand Down Expand Up @@ -110,6 +121,9 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
Only required if you want logs to be shown in the Airflow UI after your job has
finished.
:type awslogs_stream_prefix: str
:param reattach: If set to True, will check if a task from the same family is already running.
If so, the operator will attach to it instead of starting a new task.
:type reattach: bool
"""

ui_color = '#f0ede4'
Expand All @@ -135,6 +149,7 @@ def __init__(
awslogs_region: Optional[str] = None,
awslogs_stream_prefix: Optional[str] = None,
propagate_tags: Optional[str] = None,
reattach: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -156,6 +171,7 @@ def __init__(
self.awslogs_stream_prefix = awslogs_stream_prefix
self.awslogs_region = awslogs_region
self.propagate_tags = propagate_tags
self.reattach = reattach

if self.awslogs_region is None:
self.awslogs_region = region_name
Expand All @@ -172,6 +188,18 @@ def execute(self, context):

self.client = self.get_hook().get_conn()

if self.reattach:
self._try_reattach_task()

if not self.arn:
self._start_task()

self._wait_for_task_ended()

self._check_success_task()
self.log.info('ECS Task has been successfully executed')

def _start_task(self):
run_opts = {
'cluster': self.cluster,
'taskDefinition': self.task_definition,
Expand Down Expand Up @@ -204,10 +232,25 @@ def execute(self, context):
self.log.info('ECS Task started: %s', response)

self.arn = response['tasks'][0]['taskArn']
self._wait_for_task_ended()

self._check_success_task()
self.log.info('ECS Task has been successfully executed: %s', response)
def _try_reattach_task(self):
task_def_resp = self.client.describe_task_definition(self.task_definition)
ecs_task_family = task_def_resp['taskDefinition']['family']

list_tasks_resp = self.client.list_tasks(
cluster=self.cluster, launchType=self.launch_type, desiredStatus='RUNNING', family=ecs_task_family
)
running_tasks = list_tasks_resp['taskArns']

running_tasks_count = len(running_tasks)
if running_tasks_count > 1:
self.arn = running_tasks[0]
self.log.warning('More than 1 ECS Task found. Reattaching to %s', self.arn)
elif running_tasks_count == 1:
self.arn = running_tasks[0]
self.log.info('Reattaching task: %s', self.arn)
else:
self.log.info('No active tasks found to reattach')

def _wait_for_task_ended(self) -> None:
if not self.client or not self.arn:
Expand Down
45 changes: 45 additions & 0 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,48 @@ def test_check_success_task_not_raises(self):
}
self.ecs._check_success_task()
client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

@parameterized.expand(
[
['EC2', None],
['FARGATE', None],
['EC2', {'testTagKey': 'testTagValue'}],
['', {'testTagKey': 'testTagValue'}],
]
)
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(ECSOperator, '_start_task')
def test_reattach_successful(self, launch_type, tags, start_mock, check_mock, wait_mock):

self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter
client_mock = self.aws_hook_mock.return_value.get_conn.return_value
client_mock.describe_task_definition.return_value = {'taskDefinition': {'family': 'f'}}
client_mock.list_tasks.return_value = {
'taskArns': ['arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55']
}

self.ecs.reattach = True
self.ecs.execute(None)

self.aws_hook_mock.return_value.get_conn.assert_called_once()
extend_args = {}
if launch_type:
extend_args['launchType'] = launch_type
if launch_type == 'FARGATE':
extend_args['platformVersion'] = 'LATEST'
if tags:
extend_args['tags'] = [{'key': k, 'value': v} for (k, v) in tags.items()]

client_mock.describe_task_definition.assert_called_once_with('t')

client_mock.list_tasks.assert_called_once_with(
cluster='c', launchType=launch_type, desiredStatus='RUNNING', family='f'
)

start_mock.assert_not_called()
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
self.assertEqual(
self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
)

0 comments on commit 2b13a2f

Please sign in to comment.