Skip to content

Commit

Permalink
[AIRFLOW-3288] Add SNS integration (apache#4123)
Browse files Browse the repository at this point in the history
Provides a hook and an operator for publishing Amazon SNS messages.

Useful for integrating various Amazon services (SQS, Lambda) and
sending regular notifications (e-mail, SMS, ...).
  • Loading branch information
sjednac authored and Alice Berard committed Jan 3, 2019
1 parent 824f271 commit 051ed22
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 0 deletions.
60 changes: 60 additions & 0 deletions airflow/contrib/hooks/aws_sns_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json

from airflow.contrib.hooks.aws_hook import AwsHook


class AwsSnsHook(AwsHook):
"""
Interact with Amazon Simple Notification Service.
"""

def __init__(self, *args, **kwargs):
super(AwsSnsHook, self).__init__(*args, **kwargs)

def get_conn(self):
"""
Get an SNS connection
"""
self.conn = self.get_client_type('sns')
return self.conn

def publish_to_target(self, target_arn, message):
"""
Publish a message to a topic or an endpoint.
:param target_arn: either a TopicArn or an EndpointArn
:type target_arn: str
:param message: the default message you want to send
:param message: str
"""

conn = self.get_conn()

messages = {
'default': message
}

return conn.publish(
TargetArn=target_arn,
Message=json.dumps(messages),
MessageStructure='json'
)
65 changes: 65 additions & 0 deletions airflow/contrib/operators/sns_publish_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.aws_sns_hook import AwsSnsHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class SnsPublishOperator(BaseOperator):
"""
Publish a message to Amazon SNS.
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
:param target_arn: either a TopicArn or an EndpointArn
:type target_arn: str
:param message: the default message you want to send (templated)
:type message: str
"""
template_fields = ['message']
template_ext = ()

@apply_defaults
def __init__(
self,
target_arn,
message,
aws_conn_id='aws_default',
*args, **kwargs):
super(SnsPublishOperator, self).__init__(*args, **kwargs)
self.target_arn = target_arn
self.message = message
self.aws_conn_id = aws_conn_id

def execute(self, context):
sns = AwsSnsHook(aws_conn_id=self.aws_conn_id)

self.log.info(
'Sending SNS notification to {} using {}:\n{}'.format(
self.target_arn,
self.aws_conn_id,
self.message
)
)

return sns.publish_to_target(
target_arn=self.target_arn,
message=self.message
)
53 changes: 53 additions & 0 deletions tests/contrib/hooks/test_aws_sns_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import unittest

from airflow.contrib.hooks.aws_sns_hook import AwsSnsHook

try:
from moto import mock_sns
except ImportError:
mock_sns = None


@unittest.skipIf(mock_sns is None, 'moto package not present')
class TestAwsLambdaHook(unittest.TestCase):

@mock_sns
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsSnsHook(aws_conn_id='aws_default')
self.assertIsNotNone(hook.get_conn())

@mock_sns
def test_publish_to_target(self):
hook = AwsSnsHook(aws_conn_id='aws_default')

message = "Hello world"
topic_name = "test-topic"
target = hook.get_conn().create_topic(Name=topic_name).get('TopicArn')

response = hook.publish_to_target(target, message)

self.assertTrue('MessageId' in response)


if __name__ == '__main__':
unittest.main()
72 changes: 72 additions & 0 deletions tests/contrib/operators/test_sns_publish_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import mock
import unittest

from airflow.contrib.operators.sns_publish_operator import SnsPublishOperator

TASK_ID = "sns_publish_job"
AWS_CONN_ID = "custom_aws_conn"
TARGET_ARN = "arn:aws:sns:eu-central-1:1234567890:test-topic"
MESSAGE = "Message to send"


class TestSnsPublishOperator(unittest.TestCase):

def test_init(self):
# Given / When
operator = SnsPublishOperator(
task_id=TASK_ID,
aws_conn_id=AWS_CONN_ID,
target_arn=TARGET_ARN,
message=MESSAGE
)

# Then
self.assertEqual(TASK_ID, operator.task_id)
self.assertEqual(AWS_CONN_ID, operator.aws_conn_id)
self.assertEqual(TARGET_ARN, operator.target_arn)
self.assertEqual(MESSAGE, operator.message)

@mock.patch('airflow.contrib.operators.sns_publish_operator.AwsSnsHook')
def test_execute(self, mock_hook):
# Given
hook_response = {'MessageId': 'foobar'}

hook_instance = mock_hook.return_value
hook_instance.publish_to_target.return_value = hook_response

operator = SnsPublishOperator(
task_id=TASK_ID,
aws_conn_id=AWS_CONN_ID,
target_arn=TARGET_ARN,
message=MESSAGE
)

# When
result = operator.execute(None)

# Then
self.assertEqual(hook_response, result)


if __name__ == '__main__':
unittest.main()

0 comments on commit 051ed22

Please sign in to comment.