From a0d0131a877c1cc2deae9d10cd879688944dad53 Mon Sep 17 00:00:00 2001
From: Tom Miller
Date: Thu, 6 Dec 2018 10:18:44 -0800
Subject: [PATCH] [AIRFLOW-3406] Implement an Azure CosmosDB operator (#4265)

Add an operator and hook to manipulate and use Azure CosmosDB documents,
including creating, deleting, and updating documents and collections.
Includes a sensor to detect documents being added to a collection.
---
 .../example_dags/example_cosmosdb_sensor.py   |  64 ++++
 airflow/contrib/hooks/azure_cosmos_hook.py    | 287 ++++++++++++++++++
 .../azure_cosmos_insertdocument_operator.py   |  69 +++++
 .../contrib/sensors/azure_cosmos_sensor.py    |  67 ++++
 airflow/models.py                             |   4 +
 airflow/utils/db.py                           |   4 +
 docs/integration.rst                          |  29 ++
 setup.py                                      |   6 +-
 tests/contrib/hooks/test_azure_cosmos_hook.py | 202 ++++++++++++
 ...st_azure_cosmos_insertdocument_operator.py |  84 +++++
 10 files changed, 815 insertions(+), 1 deletion(-)
 create mode 100644 airflow/contrib/example_dags/example_cosmosdb_sensor.py
 create mode 100644 airflow/contrib/hooks/azure_cosmos_hook.py
 create mode 100644 airflow/contrib/operators/azure_cosmos_insertdocument_operator.py
 create mode 100644 airflow/contrib/sensors/azure_cosmos_sensor.py
 create mode 100644 tests/contrib/hooks/test_azure_cosmos_hook.py
 create mode 100644 tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py

diff --git a/airflow/contrib/example_dags/example_cosmosdb_sensor.py b/airflow/contrib/example_dags/example_cosmosdb_sensor.py
new file mode 100644
index 0000000000000..a801d9f41ba23
--- /dev/null
+++ b/airflow/contrib/example_dags/example_cosmosdb_sensor.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This is only an example DAG to highlight usage of the AzureCosmosDocumentSensor to
+detect whether a document exists.
+
+You can trigger this manually with `airflow trigger_dag example_cosmosdb_sensor`.
+
+*Note: Make sure that connection `azure_cosmos_default` is properly set before running
+this example.*
+"""
+
+from airflow import DAG
+from airflow.contrib.sensors.azure_cosmos_sensor import AzureCosmosDocumentSensor
+from airflow.contrib.operators.azure_cosmos_insertdocument_operator import AzureCosmosInsertDocumentOperator
+from airflow.utils import dates
+
+default_args = {
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'start_date': dates.days_ago(2),
+    'email': ['airflow@example.com'],
+    'email_on_failure': False,
+    'email_on_retry': False
+}
+
+dag = DAG('example_cosmosdb_sensor', default_args=default_args)
+
+dag.doc_md = __doc__
+
+t1 = AzureCosmosDocumentSensor(
+    task_id='check_cosmos_file',
+    database_name='airflow_example_db',
+    collection_name='airflow_example_coll',
+    document_id='airflow_checkid',
+    azure_cosmos_conn_id='azure_cosmos_default',
+    dag=dag)
+
+t2 = AzureCosmosInsertDocumentOperator(
+    task_id='insert_cosmos_file',
+    dag=dag,
+    database_name='airflow_example_db',
+    collection_name='new-collection',
+    document={"id": "someuniqueid", "param1": "value1", "param2": "value2"},
+    azure_cosmos_conn_id='azure_cosmos_default')
+
+t2.set_upstream(t1)
diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py
new file mode 100644
index 0000000000000..01b4007b0308f
--- /dev/null
+++ b/airflow/contrib/hooks/azure_cosmos_hook.py
@@ -0,0 +1,287 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import azure.cosmos.cosmos_client as cosmos_client
+from azure.cosmos.errors import HTTPFailure
+import uuid
+
+from airflow.exceptions import AirflowBadRequest
+from airflow.hooks.base_hook import BaseHook
+
+
+class AzureCosmosDBHook(BaseHook):
+    """
+    Interacts with Azure CosmosDB.
+
+    Login should be the endpoint URI and password should be the master key.
+    Optionally, you can use the following extras to default these values:
+    {"database_name": "DATABASE_NAME", "collection_name": "COLLECTION_NAME"}.
+
+    :param azure_cosmos_conn_id: Reference to the Azure CosmosDB connection.
+    :type azure_cosmos_conn_id: str
+    """
+
+    def __init__(self, azure_cosmos_conn_id='azure_cosmos_default'):
+        self.conn_id = azure_cosmos_conn_id
+        self.connection = self.get_connection(self.conn_id)
+        self.extras = self.connection.extra_dejson
+
+        self.endpoint_uri = self.connection.login
+        self.master_key = self.connection.password
+        self.default_database_name = self.extras.get('database_name')
+        self.default_collection_name = self.extras.get('collection_name')
+        self.cosmos_client = None
+
+    def get_conn(self):
+        """
+        Return a CosmosDB client.
+ """ + if self.cosmos_client is not None: + return self.cosmos_client + + # Initialize the Python Azure Cosmos DB client + self.cosmos_client = cosmos_client.CosmosClient(self.endpoint_uri, {'masterKey': self.master_key}) + + return self.cosmos_client + + def __get_database_name(self, database_name=None): + db_name = database_name + if db_name is None: + db_name = self.default_database_name + + if db_name is None: + raise AirflowBadRequest("Database name must be specified") + + return db_name + + def __get_collection_name(self, collection_name=None): + coll_name = collection_name + if coll_name is None: + coll_name = self.default_collection_name + + if coll_name is None: + raise AirflowBadRequest("Collection name must be specified") + + return coll_name + + def does_collection_exist(self, collection_name, database_name=None): + """ + Checks if a collection exists in CosmosDB. + """ + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + existing_container = list(self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [ + {"name": "@id", "value": collection_name} + ] + })) + if len(existing_container) == 0: + return False + + return True + + def create_collection(self, collection_name, database_name=None): + """ + Creates a new collection in the CosmosDB database. + """ + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + # We need to check to see if this container already exists so we don't try + # to create it twice + existing_container = list(self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [ + {"name": "@id", "value": collection_name} + ] + })) + + # Only create if we did not find it already existing + if len(existing_container) == 0: + self.get_conn().CreateContainer( + get_database_link(self.__get_database_name(database_name)), + {"id": collection_name}) + + def does_database_exist(self, database_name): + """ + Checks if a database exists in CosmosDB. + """ + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + existing_database = list(self.get_conn().QueryDatabases({ + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [ + {"name": "@id", "value": database_name} + ] + })) + if len(existing_database) == 0: + return False + + return True + + def create_database(self, database_name): + """ + Creates a new database in CosmosDB. + """ + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + # We need to check to see if this database already exists so we don't try + # to create it twice + existing_database = list(self.get_conn().QueryDatabases({ + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [ + {"name": "@id", "value": database_name} + ] + })) + + # Only create if we did not find it already existing + if len(existing_database) == 0: + self.get_conn().CreateDatabase({"id": database_name}) + + def delete_database(self, database_name): + """ + Deletes an existing database in CosmosDB. + """ + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + self.get_conn().DeleteDatabase(get_database_link(database_name)) + + def delete_collection(self, collection_name, database_name=None): + """ + Deletes an existing collection in the CosmosDB database. 
+        """
+        if collection_name is None:
+            raise AirflowBadRequest("Collection name cannot be None.")
+
+        self.get_conn().DeleteContainer(
+            get_collection_link(self.__get_database_name(database_name), collection_name))
+
+    def upsert_document(self, document, database_name=None, collection_name=None, document_id=None):
+        """
+        Inserts a new document (or updates an existing one) into an existing
+        collection in the CosmosDB database.
+        """
+        # Assign a unique ID if one isn't provided
+        if document_id is None:
+            document_id = str(uuid.uuid4())
+
+        if document is None:
+            raise AirflowBadRequest("You cannot insert a None document")
+
+        # Add the document id if it isn't set already
+        if 'id' in document:
+            if document['id'] is None:
+                document['id'] = document_id
+        else:
+            document['id'] = document_id
+
+        created_document = self.get_conn().CreateItem(
+            get_collection_link(
+                self.__get_database_name(database_name),
+                self.__get_collection_name(collection_name)),
+            document)
+
+        return created_document
+
+    def insert_documents(self, documents, database_name=None, collection_name=None):
+        """
+        Insert a list of new documents into an existing collection in the CosmosDB database.
+        """
+        if documents is None:
+            raise AirflowBadRequest("You cannot insert empty documents")
+
+        created_documents = []
+        for single_document in documents:
+            created_documents.append(
+                self.get_conn().CreateItem(
+                    get_collection_link(
+                        self.__get_database_name(database_name),
+                        self.__get_collection_name(collection_name)),
+                    single_document))
+
+        return created_documents
+
+    def delete_document(self, document_id, database_name=None, collection_name=None):
+        """
+        Delete an existing document from a collection in the CosmosDB database.
+        """
+        if document_id is None:
+            raise AirflowBadRequest("Cannot delete a document without an id")
+
+        self.get_conn().DeleteItem(
+            get_document_link(
+                self.__get_database_name(database_name),
+                self.__get_collection_name(collection_name),
+                document_id))
+
+    def get_document(self, document_id, database_name=None, collection_name=None):
+        """
+        Get a document from an existing collection in the CosmosDB database.
+        """
+        if document_id is None:
+            raise AirflowBadRequest("Cannot get a document without an id")
+
+        try:
+            return self.get_conn().ReadItem(
+                get_document_link(
+                    self.__get_database_name(database_name),
+                    self.__get_collection_name(collection_name),
+                    document_id))
+        except HTTPFailure:
+            return None
+
+    def get_documents(self, sql_string, database_name=None, collection_name=None, partition_key=None):
+        """
+        Get a list of documents from an existing collection in the CosmosDB database via a SQL query.
+        """
+        if sql_string is None:
+            raise AirflowBadRequest("SQL query string cannot be None")
+
+        # Wrap the SQL string in the query dict the client expects
+        query = {'query': sql_string}
+
+        try:
+            result_iterable = self.get_conn().QueryItems(
+                get_collection_link(
+                    self.__get_database_name(database_name),
+                    self.__get_collection_name(collection_name)),
+                query,
+                partition_key)
+
+            return list(result_iterable)
+        except HTTPFailure:
+            return None
+
+
+def get_database_link(database_id):
+    return "dbs/" + database_id
+
+
+def get_collection_link(database_id, collection_id):
+    return get_database_link(database_id) + "/colls/" + collection_id
+
+
+def get_document_link(database_id, collection_id, document_id):
+    return get_collection_link(database_id, collection_id) + "/docs/" + document_id
diff --git a/airflow/contrib/operators/azure_cosmos_insertdocument_operator.py b/airflow/contrib/operators/azure_cosmos_insertdocument_operator.py
new file mode 100644
index 0000000000000..930ff402d0507
--- /dev/null
+++ b/airflow/contrib/operators/azure_cosmos_insertdocument_operator.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class AzureCosmosInsertDocumentOperator(BaseOperator):
+    """
+    Inserts a new document into the specified Cosmos database and collection.
+    It will create both the database and collection if they do not already exist.
+
+    :param database_name: The name of the database. (templated)
+    :type database_name: str
+    :param collection_name: The name of the collection. (templated)
+    :type collection_name: str
+    :param document: The document to insert.
+    :type document: dict
+    :param azure_cosmos_conn_id: Reference to the Azure CosmosDB connection.
+    :type azure_cosmos_conn_id: str
+    """
+    template_fields = ('database_name', 'collection_name')
+    ui_color = '#e4f0e8'
+
+    @apply_defaults
+    def __init__(self,
+                 database_name,
+                 collection_name,
+                 document,
+                 azure_cosmos_conn_id='azure_cosmos_default',
+                 *args,
+                 **kwargs):
+        super(AzureCosmosInsertDocumentOperator, self).__init__(*args, **kwargs)
+        self.database_name = database_name
+        self.collection_name = collection_name
+        self.document = document
+        self.azure_cosmos_conn_id = azure_cosmos_conn_id
+
+    def execute(self, context):
+        # Create the hook
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id)
+
+        # Create the database if it doesn't already exist
+        if not hook.does_database_exist(self.database_name):
+            hook.create_database(self.database_name)
+
+        # Create the collection as well
+        if not hook.does_collection_exist(self.collection_name, self.database_name):
+            hook.create_collection(self.collection_name, self.database_name)
+
+        # Finally, insert the document
+        hook.upsert_document(self.document, self.database_name, self.collection_name)
diff --git a/airflow/contrib/sensors/azure_cosmos_sensor.py b/airflow/contrib/sensors/azure_cosmos_sensor.py
new file mode 100644
index 0000000000000..78b340d4efe53
--- /dev/null
+++ b/airflow/contrib/sensors/azure_cosmos_sensor.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class AzureCosmosDocumentSensor(BaseSensorOperator):
+    """
+    Checks for the existence of a document with the given ID
+    in the given CosmosDB collection. Example:
+
+    >>> azure_cosmos_sensor = AzureCosmosDocumentSensor(database_name="somedatabase_name",
+    ...                                                 collection_name="somecollection_name",
+    ...                                                 document_id="unique-doc-id",
+    ...                                                 azure_cosmos_conn_id="azure_cosmos_default",
+    ...                                                 task_id="azure_cosmos_sensor")
+    """
+    template_fields = ('database_name', 'collection_name', 'document_id')
+
+    @apply_defaults
+    def __init__(
+            self,
+            database_name,
+            collection_name,
+            document_id,
+            azure_cosmos_conn_id="azure_cosmos_default",
+            *args,
+            **kwargs):
+        """
+        Create a new AzureCosmosDocumentSensor.
+
+        :param database_name: Target CosmosDB database name.
+        :type database_name: str
+        :param collection_name: Target CosmosDB collection name.
+        :type collection_name: str
+        :param document_id: The ID of the target document.
+        :type document_id: str
+        :param azure_cosmos_conn_id: Reference to the Azure CosmosDB connection.
+        :type azure_cosmos_conn_id: str
+        """
+        super(AzureCosmosDocumentSensor, self).__init__(*args, **kwargs)
+        self.azure_cosmos_conn_id = azure_cosmos_conn_id
+        self.database_name = database_name
+        self.collection_name = collection_name
+        self.document_id = document_id
+
+    def poke(self, context):
+        self.log.info("*** Entering poke")
+        hook = AzureCosmosDBHook(self.azure_cosmos_conn_id)
+        return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None
diff --git a/airflow/models.py b/airflow/models.py
index bb30c9523d0ee..b4b4d83587882 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -652,6 +652,7 @@ class Connection(Base, LoggingMixin):
         ('snowflake', 'Snowflake',),
         ('segment', 'Segment',),
         ('azure_data_lake', 'Azure Data Lake'),
+        ('azure_cosmos', 'Azure CosmosDB'),
         ('cassandra', 'Cassandra',),
         ('qubole', 'Qubole'),
         ('mongo', 'MongoDB'),
@@ -793,6 +794,9 @@ def get_hook(self):
         elif self.conn_type == 'azure_data_lake':
             from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
             return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
+        elif self.conn_type == 'azure_cosmos':
+            from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
+            return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
         elif self.conn_type == 'cassandra':
             from airflow.contrib.hooks.cassandra_hook import CassandraHook
             return CassandraHook(cassandra_conn_id=self.conn_id)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 4697605815682..74317a59c7a1f 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -280,6 +280,10 @@ def initdb(rbac=False):
         models.Connection(
             conn_id='azure_data_lake_default', conn_type='azure_data_lake',
             extra='{"tenant": "", "account_name": "" }'))
+    merge_conn(
+        models.Connection(
+            conn_id='azure_cosmos_default', conn_type='azure_cosmos',
+            extra='{"database_name": "", "collection_name": "" }'))
     merge_conn(
         models.Connection(
             conn_id='cassandra_default', conn_type='cassandra',
diff --git a/docs/integration.rst b/docs/integration.rst
index 5f5fc89bde4fe..e98d047a58f0c 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -161,6 +161,35 @@ Logging
 Airflow can be configured to read and write task logs in Azure Blob Storage.
 See :ref:`write-logs-azure`.
 
+Azure CosmosDB
+''''''''''''''
+
+AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that an
+Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a
+login (the endpoint URI), password (the master key) and extra fields database_name and collection_name to specify the
+default database and collection to use (see connection `azure_cosmos_default` for an example).
+
+- :ref:`AzureCosmosDBHook`: Interface with Azure CosmosDB.
+- :ref:`AzureCosmosInsertDocumentOperator`: Simple operator to insert a document into CosmosDB.
+- :ref:`AzureCosmosDocumentSensor`: Simple sensor to detect document existence in CosmosDB.
+
+.. _AzureCosmosDBHook:
+
+AzureCosmosDBHook
+"""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_cosmos_hook.AzureCosmosDBHook
+
+AzureCosmosInsertDocumentOperator
+"""""""""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.operators.azure_cosmos_insertdocument_operator.AzureCosmosInsertDocumentOperator
+
+AzureCosmosDocumentSensor
+"""""""""""""""""""""""""
+
+.. 
autoclass:: airflow.contrib.sensors.azure_cosmos_sensor.AzureCosmosDocumentSensor + Azure Data Lake ''''''''''''''' diff --git a/setup.py b/setup.py index f34985c943c04..8a223ee6cb6a1 100644 --- a/setup.py +++ b/setup.py @@ -152,6 +152,7 @@ def write_version(filename=os.path.join(*['airflow', 'azure-mgmt-datalake-store==0.4.0', 'azure-datalake-store==0.0.19' ] +azure_cosmos = ['azure-cosmos>=3.0.1'] cassandra = ['cassandra-driver>=3.13.0'] celery = [ 'celery>=4.1.1, <4.2.0', @@ -260,10 +261,11 @@ def write_version(filename=os.path.join(*['airflow', ] devel_minreq = devel + kubernetes + mysql + doc + password + s3 + cgroups devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos +devel_azure = devel_minreq + azure_data_lake + azure_cosmos devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker + ssh + kubernetes + celery + azure_blob_storage + redis + gcp_api + datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + - druid + pinot + segment + snowflake + elasticsearch + azure_data_lake + + druid + pinot + segment + snowflake + elasticsearch + azure_data_lake + azure_cosmos + atlas) # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'( @@ -338,6 +340,7 @@ def do_setup(): 'async': async_packages, 'azure_blob_storage': azure_blob_storage, 'azure_data_lake': azure_data_lake, + 'azure_cosmos': azure_cosmos, 'cassandra': cassandra, 'celery': celery, 'cgroups': cgroups, @@ -348,6 +351,7 @@ def do_setup(): 'datadog': datadog, 'devel': devel_minreq, 'devel_hadoop': devel_hadoop, + 'devel_azure': devel_azure, 'doc': doc, 'docker': docker, 'druid': druid, diff --git a/tests/contrib/hooks/test_azure_cosmos_hook.py b/tests/contrib/hooks/test_azure_cosmos_hook.py new file mode 100644 index 0000000000000..653242a34b48f --- /dev/null +++ b/tests/contrib/hooks/test_azure_cosmos_hook.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
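+#
+# These tests patch azure.cosmos.cosmos_client.CosmosClient, so they run
+# without a live CosmosDB account and only assert on the calls the hook
+# issues against the mocked client.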
+# + + +import json +import unittest +import uuid + +from airflow.exceptions import AirflowException +from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook + +from airflow import configuration +from airflow import models +from airflow.utils import db + +import logging + +try: + from unittest import mock + +except ImportError: + try: + import mock + except ImportError: + mock = None + + +class TestAzureCosmosDbHook(unittest.TestCase): + + # Set up an environment to test with + def setUp(self): + # set up some test variables + self.test_end_point = 'https://test_endpoint:443' + self.test_master_key = 'magic_test_key' + self.test_database_name = 'test_database_name' + self.test_collection_name = 'test_collection_name' + self.test_database_default = 'test_database_default' + self.test_collection_default = 'test_collection_default' + configuration.load_test_config() + db.merge_conn( + models.Connection( + conn_id='azure_cosmos_test_key_id', + conn_type='azure_cosmos', + login=self.test_end_point, + password=self.test_master_key, + extra=json.dumps({'database_name': self.test_database_default, + 'collection_name': self.test_collection_default}) + ) + ) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_database(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.create_database(self.test_database_name) + expected_calls = [mock.call().CreateDatabase({'id': self.test_database_name})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_database_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.create_database, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_container_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.create_collection, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_container(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.create_collection(self.test_collection_name, self.test_database_name) + expected_calls = [mock.call().CreateContainer( + 'dbs/test_database_name', + {'id': self.test_collection_name})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_container_default(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.create_collection(self.test_collection_name) + expected_calls = [mock.call().CreateContainer( + 'dbs/test_database_default', + {'id': self.test_collection_name})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_upsert_document_default(self, cosmos_mock): + test_id = str(uuid.uuid4()) + cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + returned_item = 
self.cosmos.upsert_document({'id': test_id}) + expected_calls = [mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'id': test_id})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + logging.getLogger().info(returned_item) + self.assertEqual(returned_item['id'], test_id) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_upsert_document(self, cosmos_mock): + test_id = str(uuid.uuid4()) + cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + returned_item = self.cosmos.upsert_document( + {'data1': 'somedata'}, + database_name=self.test_database_name, + collection_name=self.test_collection_name, + document_id=test_id) + + expected_calls = [mock.call().CreateItem( + 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, + {'data1': 'somedata', 'id': test_id})] + + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + logging.getLogger().info(returned_item) + self.assertEqual(returned_item['id'], test_id) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_insert_documents(self, cosmos_mock): + test_id1 = str(uuid.uuid4()) + test_id2 = str(uuid.uuid4()) + test_id3 = str(uuid.uuid4()) + documents = [ + {'id': test_id1, 'data': 'data1'}, + {'id': test_id2, 'data': 'data2'}, + {'id': test_id3, 'data': 'data3'}] + + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + returned_item = self.cosmos.insert_documents(documents) + expected_calls = [ + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'data': 'data1', 'id': test_id1}), + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'data': 'data2', 'id': test_id2}), + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'data': 'data3', 'id': test_id3})] + logging.getLogger().info(returned_item) + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_database(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.delete_database(self.test_database_name) + expected_calls = [mock.call().DeleteDatabase('dbs/test_database_name')] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_database_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.delete_database, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_container_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.delete_collection, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_container(self, cosmos_mock): + self.cosmos = 
AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.delete_collection(self.test_collection_name, self.test_database_name) + expected_calls = [mock.call().DeleteContainer('dbs/test_database_name/colls/test_collection_name')] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_container_default(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.delete_collection(self.test_collection_name) + expected_calls = [mock.call().DeleteContainer('dbs/test_database_default/colls/test_collection_name')] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py b/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py new file mode 100644 index 0000000000000..26099d0cb3b57 --- /dev/null +++ b/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+
+
+import json
+import unittest
+import uuid
+
+from airflow.contrib.operators.azure_cosmos_insertdocument_operator import AzureCosmosInsertDocumentOperator
+
+from airflow import configuration
+from airflow import models
+from airflow.utils import db
+
+try:
+    from unittest import mock
+
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class TestAzureCosmosInsertDocumentOperator(unittest.TestCase):
+
+    # Set up an environment to test with
+    def setUp(self):
+        # set up some test variables
+        self.test_end_point = 'https://test_endpoint:443'
+        self.test_master_key = 'magic_test_key'
+        self.test_database_name = 'test_database_name'
+        self.test_collection_name = 'test_collection_name'
+        configuration.load_test_config()
+        db.merge_conn(
+            models.Connection(
+                conn_id='azure_cosmos_test_key_id',
+                conn_type='azure_cosmos',
+                login=self.test_end_point,
+                password=self.test_master_key,
+                extra=json.dumps({'database_name': self.test_database_name,
+                                  'collection_name': self.test_collection_name})
+            )
+        )
+
+    @mock.patch('azure.cosmos.cosmos_client.CosmosClient')
+    def test_insert_document(self, cosmos_mock):
+        test_id = str(uuid.uuid4())
+        cosmos_mock.return_value.CreateItem.return_value = {'id': test_id}
+        self.cosmos = AzureCosmosInsertDocumentOperator(
+            database_name=self.test_database_name,
+            collection_name=self.test_collection_name,
+            document={'id': test_id, 'data': 'sometestdata'},
+            azure_cosmos_conn_id='azure_cosmos_test_key_id',
+            task_id='azure_cosmos_sensor')
+
+        expected_calls = [mock.call().CreateItem(
+            'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name,
+            {'data': 'sometestdata', 'id': test_id})]
+
+        self.cosmos.execute(None)
+        cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        cosmos_mock.assert_has_calls(expected_calls)
+
+
+if __name__ == '__main__':
+    unittest.main()
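
---

A minimal sketch of exercising the new hook end to end, assuming the
`azure_cosmos_default` connection is configured (login = endpoint URI,
password = master key); the database, collection, and document names below
are placeholders:

    from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook

    hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_default')

    # Create the database and collection if they don't exist yet.
    hook.create_database('example_db')
    hook.create_collection('example_coll', 'example_db')

    # Round-trip a document: upsert, read back, then delete.
    hook.upsert_document({'id': 'example_doc', 'value': 1}, 'example_db', 'example_coll')
    print(hook.get_document('example_doc', 'example_db', 'example_coll'))
    hook.delete_document('example_doc', 'example_db', 'example_coll')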