From 63aa3db88f8824efe79622301efd9f8ba75b991c Mon Sep 17 00:00:00 2001 From: Aviem Zur Date: Sun, 2 Feb 2020 11:27:39 +0200 Subject: [PATCH] [AIRFLOW-6258] Add CloudFormation operators to AWS providers (#6824) --- .../amazon/aws/hooks/cloud_formation.py | 87 ++++++++++++ .../amazon/aws/operators/cloud_formation.py | 102 ++++++++++++++ .../amazon/aws/sensors/cloud_formation.py | 96 +++++++++++++ docs/operators-and-hooks-ref.rst | 6 + .../amazon/aws/hooks/test_cloud_formation.py | 105 +++++++++++++++ .../aws/operators/test_cloud_formation.py | 106 +++++++++++++++ .../aws/sensors/test_cloud_formation.py | 127 ++++++++++++++++++ 7 files changed, 629 insertions(+) create mode 100644 airflow/providers/amazon/aws/hooks/cloud_formation.py create mode 100644 airflow/providers/amazon/aws/operators/cloud_formation.py create mode 100644 airflow/providers/amazon/aws/sensors/cloud_formation.py create mode 100644 tests/providers/amazon/aws/hooks/test_cloud_formation.py create mode 100644 tests/providers/amazon/aws/operators/test_cloud_formation.py create mode 100644 tests/providers/amazon/aws/sensors/test_cloud_formation.py diff --git a/airflow/providers/amazon/aws/hooks/cloud_formation.py b/airflow/providers/amazon/aws/hooks/cloud_formation.py new file mode 100644 index 0000000000000..062fc184524ce --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/cloud_formation.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains AWS CloudFormation Hook +""" +from botocore.exceptions import ClientError + +from airflow.contrib.hooks.aws_hook import AwsHook + + +class AWSCloudFormationHook(AwsHook): + """ + Interact with AWS CloudFormation. + """ + + def __init__(self, region_name=None, *args, **kwargs): + self.region_name = region_name + self.conn = None + super().__init__(*args, **kwargs) + + def get_conn(self): + if not self.conn: + self.conn = self.get_client_type('cloudformation', self.region_name) + return self.conn + + def get_stack_status(self, stack_name): + """ + Get stack status from CloudFormation. + """ + cloudformation = self.get_conn() + + self.log.info('Poking for stack %s', stack_name) + + try: + stacks = cloudformation.describe_stacks(StackName=stack_name)['Stacks'] + return stacks[0]['StackStatus'] + except ClientError as e: + if 'does not exist' in str(e): + return None + else: + raise e + + def create_stack(self, stack_name, params): + """ + Create stack in CloudFormation. + + :param stack_name: stack_name. + :type stack_name: str + :param params: parameters to be passed to CloudFormation. + :type params: dict + """ + + if 'StackName' not in params: + params['StackName'] = stack_name + self.get_conn().create_stack(**params) + + def delete_stack(self, stack_name, params=None): + """ + Delete stack in CloudFormation. + + :param stack_name: stack_name. + :type stack_name: str + :param params: parameters to be passed to CloudFormation (optional). + :type params: dict + """ + + params = params or {} + if 'StackName' not in params: + params['StackName'] = stack_name + self.get_conn().delete_stack(**params) diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py new file mode 100644 index 0000000000000..a3bce702a6cf3 --- /dev/null +++ b/airflow/providers/amazon/aws/operators/cloud_formation.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains CloudFormation create/delete stack operators. +""" +from typing import List + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook +from airflow.utils.decorators import apply_defaults + + +class CloudFormationCreateStackOperator(BaseOperator): + """ + An operator that creates a CloudFormation stack. + + :param stack_name: stack name (templated) + :type stack_name: str + :param params: parameters to be passed to CloudFormation. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cloudformation.html#CloudFormation.Client.create_stack + :type params: dict + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + template_fields: List[str] = ['stack_name'] + template_ext = () + ui_color = '#6b9659' + + @apply_defaults + def __init__( + self, + stack_name, + params, + aws_conn_id='aws_default', + *args, **kwargs): + super().__init__(*args, **kwargs) + self.stack_name = stack_name + self.params = params + self.aws_conn_id = aws_conn_id + + def execute(self, context): + self.log.info('Parameters: %s', self.params) + + cloudformation_hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id) + cloudformation_hook.create_stack(self.stack_name, self.params) + + +class CloudFormationDeleteStackOperator(BaseOperator): + """ + An operator that deletes a CloudFormation stack. + + :param stack_name: stack name (templated) + :type stack_name: str + :param params: parameters to be passed to CloudFormation. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cloudformation.html#CloudFormation.Client.delete_stack + :type params: dict + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + template_fields: List[str] = ['stack_name'] + template_ext = () + ui_color = '#1d472b' + ui_fgcolor = '#FFF' + + @apply_defaults + def __init__( + self, + stack_name, + params=None, + aws_conn_id='aws_default', + *args, **kwargs): + super().__init__(*args, **kwargs) + self.params = params or {} + self.stack_name = stack_name + self.params = params + self.aws_conn_id = aws_conn_id + + def execute(self, context): + self.log.info('Parameters: %s', self.params) + + cloudformation_hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id) + cloudformation_hook.delete_stack(self.stack_name, self.params) diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py new file mode 100644 index 0000000000000..7f99b1b6e82d4 --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains sensors for AWS CloudFormation. +""" +from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class CloudFormationCreateStackSensor(BaseSensorOperator): + """ + Waits for a stack to be created successfully on AWS CloudFormation. + + :param stack_name: The name of the stack to wait for (templated) + :type stack_name: str + :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are + stored + :type aws_conn_id: str + :param poke_interval: Time in seconds that the job should wait between each try + :type poke_interval: int + """ + + template_fields = ['stack_name'] + ui_color = '#C5CAE9' + + @apply_defaults + def __init__(self, + stack_name, + aws_conn_id='aws_default', + region_name=None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.stack_name = stack_name + self.hook = AWSCloudFormationHook(aws_conn_id=aws_conn_id, region_name=region_name) + + def poke(self, context): + stack_status = self.hook.get_stack_status(self.stack_name) + if stack_status == 'CREATE_COMPLETE': + return True + if stack_status in ('CREATE_IN_PROGRESS', None): + return False + raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}') + + +class CloudFormationDeleteStackSensor(BaseSensorOperator): + """ + Waits for a stack to be deleted successfully on AWS CloudFormation. + + :param stack_name: The name of the stack to wait for (templated) + :type stack_name: str + :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are + stored + :type aws_conn_id: str + :param poke_interval: Time in seconds that the job should wait between each try + :type poke_interval: int + """ + + template_fields = ['stack_name'] + ui_color = '#C5CAE9' + + @apply_defaults + def __init__(self, + stack_name, + aws_conn_id='aws_default', + region_name=None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.stack_name = stack_name + self.hook = AWSCloudFormationHook(aws_conn_id=aws_conn_id, region_name=region_name) + + def poke(self, context): + stack_status = self.hook.get_stack_status(self.stack_name) + if stack_status in ('DELETE_COMPLETE', None): + return True + if stack_status == 'DELETE_IN_PROGRESS': + return False + raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}') diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst index 6b1cd1c87431c..3735f7af0a2a0 100644 --- a/docs/operators-and-hooks-ref.rst +++ b/docs/operators-and-hooks-ref.rst @@ -331,6 +331,12 @@ These integrations allow you to perform various operations within the Amazon Web - :mod:`airflow.providers.amazon.aws.operators.athena` - :mod:`airflow.providers.amazon.aws.sensors.athena` + * - `Amazon CloudFormation `__ + - + - :mod:`airflow.providers.amazon.aws.hooks.cloud_formation` + - :mod:`airflow.providers.amazon.aws.operators.cloud_formation` + - :mod:`airflow.providers.amazon.aws.sensors.cloud_formation` + * - `Amazon CloudWatch Logs `__ - - :mod:`airflow.providers.amazon.aws.hooks.logs` diff --git a/tests/providers/amazon/aws/hooks/test_cloud_formation.py b/tests/providers/amazon/aws/hooks/test_cloud_formation.py new file mode 100644 index 0000000000000..ae7c45aac8567 --- /dev/null +++ b/tests/providers/amazon/aws/hooks/test_cloud_formation.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json +import unittest + +from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook + +try: + from moto import mock_cloudformation +except ImportError: + mock_cloudformation = None + + +@unittest.skipIf(mock_cloudformation is None, 'moto package not present') +class TestAWSCloudFormationHook(unittest.TestCase): + + def setUp(self): + self.hook = AWSCloudFormationHook(aws_conn_id='aws_default') + + def create_stack(self, stack_name): + timeout = 15 + template_body = json.dumps({ + 'Resources': { + "myResource": { + "Type": "emr", + "Properties": { + "myProperty": "myPropertyValue" + } + } + } + }) + + self.hook.create_stack( + stack_name=stack_name, + params={ + 'TimeoutInMinutes': timeout, + 'TemplateBody': template_body, + 'Parameters': [{'ParameterKey': 'myParam', 'ParameterValue': 'myParamValue'}] + } + ) + + @mock_cloudformation + def test_get_conn_returns_a_boto3_connection(self): + self.assertIsNotNone(self.hook.get_conn().describe_stacks()) + + @mock_cloudformation + def test_get_stack_status(self): + stack_name = 'my_test_get_stack_status_stack' + + stack_status = self.hook.get_stack_status(stack_name=stack_name) + self.assertIsNone(stack_status) + + self.create_stack(stack_name) + stack_status = self.hook.get_stack_status(stack_name=stack_name) + self.assertEqual(stack_status, 'CREATE_COMPLETE', 'Incorrect stack status returned.') + + @mock_cloudformation + def test_create_stack(self): + stack_name = 'my_test_create_stack_stack' + self.create_stack(stack_name) + + stacks = self.hook.get_conn().describe_stacks()['Stacks'] + self.assertGreater(len(stacks), 0, 'CloudFormation should have stacks') + + matching_stacks = [x for x in stacks if x['StackName'] == stack_name] + self.assertEqual(len(matching_stacks), 1, f'stack with name {stack_name} should exist') + + stack = matching_stacks[0] + self.assertEqual( + stack['StackStatus'], + 'CREATE_COMPLETE', + 'Stack should be in status CREATE_COMPLETE' + ) + + @mock_cloudformation + def test_delete_stack(self): + stack_name = 'my_test_delete_stack_stack' + self.create_stack(stack_name) + + self.hook.delete_stack(stack_name=stack_name) + + stacks = self.hook.get_conn().describe_stacks()['Stacks'] + matching_stacks = [x for x in stacks if x['StackName'] == stack_name] + self.assertEqual(len(matching_stacks), 0, f'stack with name {stack_name} should not exist') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py new file mode 100644 index 0000000000000..abc8c333bf9a8 --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import unittest +from unittest.mock import MagicMock, patch + +from airflow import DAG +from airflow.providers.amazon.aws.operators.cloud_formation import ( + CloudFormationCreateStackOperator, CloudFormationDeleteStackOperator, +) +from airflow.utils import timezone + +DEFAULT_DATE = timezone.datetime(2019, 1, 1) + + +class TestCloudFormationCreateStackOperator(unittest.TestCase): + + def setUp(self): + self.args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + + # Mock out the cloudformation_client (moto fails with an exception). + self.cloudformation_client_mock = MagicMock() + + # Mock out the emr_client creator + cloudformation_session_mock = MagicMock() + cloudformation_session_mock.client.return_value = self.cloudformation_client_mock + self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock) + + self.mock_context = MagicMock() + + def test_create_stack(self): + stack_name = "myStack" + timeout = 15 + template_body = "My stack body" + + operator = CloudFormationCreateStackOperator( + task_id='test_task', + stack_name=stack_name, + params={ + 'TimeoutInMinutes': timeout, + 'TemplateBody': template_body + }, + dag=DAG('test_dag_id', default_args=self.args), + ) + + with patch('boto3.session.Session', self.boto3_session_mock): + operator.execute(self.mock_context) + + self.cloudformation_client_mock.create_stack.assert_any_call(StackName=stack_name, + TemplateBody=template_body, + TimeoutInMinutes=timeout) + + +class TestCloudFormationDeleteStackOperator(unittest.TestCase): + + def setUp(self): + self.args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + + # Mock out the cloudformation_client (moto fails with an exception). + self.cloudformation_client_mock = MagicMock() + + # Mock out the emr_client creator + cloudformation_session_mock = MagicMock() + cloudformation_session_mock.client.return_value = self.cloudformation_client_mock + self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock) + + self.mock_context = MagicMock() + + def test_delete_stack(self): + stack_name = "myStackToBeDeleted" + + operator = CloudFormationDeleteStackOperator( + task_id='test_task', + stack_name=stack_name, + dag=DAG('test_dag_id', default_args=self.args), + ) + + with patch('boto3.session.Session', self.boto3_session_mock): + operator.execute(self.mock_context) + + self.cloudformation_client_mock.delete_stack.assert_any_call(StackName=stack_name) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py new file mode 100644 index 0000000000000..7d93a601fd263 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import unittest + +import boto3 +from mock import MagicMock, patch + +from airflow.providers.amazon.aws.sensors.cloud_formation import ( + CloudFormationCreateStackSensor, CloudFormationDeleteStackSensor, +) + +try: + from moto import mock_cloudformation +except ImportError: + mock_cloudformation = None + + +@unittest.skipIf(mock_cloudformation is None, + "Skipping test because moto.mock_cloudformation is not available") +class TestCloudFormationCreateStackSensor(unittest.TestCase): + task_id = 'test_cloudformation_cluster_create_sensor' + + @mock_cloudformation + def setUp(self): + self.client = boto3.client('cloudformation', region_name='us-east-1') + + self.cloudformation_client_mock = MagicMock() + + cloudformation_session_mock = MagicMock() + cloudformation_session_mock.client.return_value = self.cloudformation_client_mock + + self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock) + + @mock_cloudformation + def test_poke(self): + stack_name = 'foobar' + self.client.create_stack(StackName=stack_name, TemplateBody='{"Resources": {}}') + op = CloudFormationCreateStackSensor(task_id='task', stack_name='foobar') + self.assertTrue(op.poke({})) + + def test_poke_false(self): + with patch('boto3.session.Session', self.boto3_session_mock): + self.cloudformation_client_mock.describe_stacks.return_value = { + 'Stacks': [{'StackStatus': 'CREATE_IN_PROGRESS'}] + } + op = CloudFormationCreateStackSensor(task_id='task', stack_name='foo') + self.assertFalse(op.poke({})) + + def test_poke_stack_in_unsuccessful_state(self): + with patch('boto3.session.Session', self.boto3_session_mock): + self.cloudformation_client_mock.describe_stacks.return_value = { + 'Stacks': [{'StackStatus': 'bar'}] + } + with self.assertRaises(ValueError) as error: + op = CloudFormationCreateStackSensor(task_id='task', stack_name='foo') + op.poke({}) + + self.assertEqual('Stack foo in bad state: bar', str(error.exception)) + + +@unittest.skipIf(mock_cloudformation is None, + "Skipping test because moto.mock_cloudformation is not available") +class TestCloudFormationDeleteStackSensor(unittest.TestCase): + task_id = 'test_cloudformation_cluster_delete_sensor' + + @mock_cloudformation + def setUp(self): + self.client = boto3.client('cloudformation', region_name='us-east-1') + + self.cloudformation_client_mock = MagicMock() + + cloudformation_session_mock = MagicMock() + cloudformation_session_mock.client.return_value = self.cloudformation_client_mock + + self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock) + + @mock_cloudformation + def test_poke(self): + stack_name = 'foobar' + self.client.create_stack(StackName=stack_name, TemplateBody='{"Resources": {}}') + self.client.delete_stack(StackName=stack_name) + op = CloudFormationDeleteStackSensor(task_id='task', stack_name=stack_name) + self.assertTrue(op.poke({})) + + def test_poke_false(self): + with patch('boto3.session.Session', self.boto3_session_mock): + self.cloudformation_client_mock.describe_stacks.return_value = { + 'Stacks': [{'StackStatus': 'DELETE_IN_PROGRESS'}] + } + op = CloudFormationDeleteStackSensor(task_id='task', stack_name='foo') + self.assertFalse(op.poke({})) + + def test_poke_stack_in_unsuccessful_state(self): + with patch('boto3.session.Session', self.boto3_session_mock): + self.cloudformation_client_mock.describe_stacks.return_value = { + 'Stacks': [{'StackStatus': 'bar'}] + } + with self.assertRaises(ValueError) as error: + op = CloudFormationDeleteStackSensor(task_id='task', stack_name='foo') + op.poke({}) + + self.assertEqual('Stack foo in bad state: bar', str(error.exception)) + + @mock_cloudformation + def test_poke_stack_does_not_exist(self): + op = CloudFormationDeleteStackSensor(task_id='task', stack_name='foo') + self.assertTrue(op.poke({})) + + +if __name__ == '__main__': + unittest.main()