diff --git a/airflow/contrib/operators/postgres_to_gcs_operator.py b/airflow/contrib/operators/postgres_to_gcs_operator.py
index 88b4d00e39790..ddebc05319039 100644
--- a/airflow/contrib/operators/postgres_to_gcs_operator.py
+++ b/airflow/contrib/operators/postgres_to_gcs_operator.py
@@ -133,28 +133,38 @@ def _write_local_data_files(self, cursor):
             contain the data for the GCS objects.
         """
         schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
-        file_no = 0
-        tmp_file_handle = NamedTemporaryFile(delete=True)
-        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
-
-        for row in cursor:
-            # Convert datetime objects to utc seconds, and decimals to floats
-            row = map(self.convert_types, row)
-            row_dict = dict(zip(schema, row))
-
-            s = json.dumps(row_dict, sort_keys=True)
-            if PY3:
-                s = s.encode('utf-8')
-            tmp_file_handle.write(s)
-
-            # Append newline to make dumps BigQuery compatible.
-            tmp_file_handle.write(b'\n')
-
-            # Stop if the file exceeds the file size limit.
-            if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
-                file_no += 1
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
+        tmp_file_handles = {}
+        row_no = 0
+
+        def _create_new_file():
+            handle = NamedTemporaryFile(delete=True)
+            filename = self.filename.format(len(tmp_file_handles))
+            tmp_file_handles[filename] = handle
+            return handle
+
+        # Don't create a file if there is nothing to write
+        if cursor.rowcount > 0:
+            tmp_file_handle = _create_new_file()
+
+            for row in cursor:
+                # Convert datetime objects to utc seconds, and decimals to floats
+                row = map(self.convert_types, row)
+                row_dict = dict(zip(schema, row))
+
+                s = json.dumps(row_dict, sort_keys=True)
+                if PY3:
+                    s = s.encode('utf-8')
+                tmp_file_handle.write(s)
+
+                # Append newline to make dumps BigQuery compatible.
+                tmp_file_handle.write(b'\n')
+
+                # Stop if the file exceeds the file size limit.
+                if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
+                    tmp_file_handle = _create_new_file()
+                row_no += 1
+
+        self.log.info('Received %s rows over %s files', row_no, len(tmp_file_handles))
 
         return tmp_file_handles
diff --git a/tests/contrib/operators/test_postgres_to_gcs_operator.py b/tests/contrib/operators/test_postgres_to_gcs_operator.py
index aaf2b9715c826..608c57d503941 100644
--- a/tests/contrib/operators/test_postgres_to_gcs_operator.py
+++ b/tests/contrib/operators/test_postgres_to_gcs_operator.py
@@ -24,40 +24,66 @@
 import unittest
 
-from airflow.contrib.operators.postgres_to_gcs_operator import PostgresToGoogleCloudStorageOperator
+from airflow.hooks.postgres_hook import PostgresHook
+from airflow.contrib.operators.postgres_to_gcs_operator import \
+    PostgresToGoogleCloudStorageOperator
 
 try:
-    from unittest import mock
+    from unittest.mock import patch
 except ImportError:
     try:
-        import mock
+        from mock import patch
     except ImportError:
         mock = None
 
-PY3 = sys.version_info[0] == 3
+TABLES = {'postgres_to_gcs_operator', 'postgres_to_gcs_operator_empty'}
 
 TASK_ID = 'test-postgres-to-gcs'
-POSTGRES_CONN_ID = 'postgres_conn_test'
-SQL = 'select 1'
+POSTGRES_CONN_ID = 'postgres_default'
+SQL = 'SELECT * FROM postgres_to_gcs_operator'
 BUCKET = 'gs://test'
 FILENAME = 'test_{}.ndjson'
-# we expect the psycopg cursor to return encoded strs in py2 and decoded in py3
-if PY3:
-    ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = (('some_str', 0), ('some_num', 1005))
-else:
-    ROWS = [(b'mock_row_content_1', 42), (b'mock_row_content_2', 43), (b'mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = ((b'some_str', 0), (b'some_num', 1005))
+
 NDJSON_LINES = [
     b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
     b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
     b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
 ]
 SCHEMA_FILENAME = 'schema_test.json'
-SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, {"mode": "REPEATED", "name": "some_num", "type": "INTEGER"}]'
+SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \
+              b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]'
 
 
 class PostgresToGoogleCloudStorageOperatorTest(unittest.TestCase):
+    def setUp(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+                    cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
+                                .format(table))
+
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_1', 42)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_2', 43)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_3', 44)
+                )
+
+    def tearDown(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+
     def test_init(self):
         """Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
         op = PostgresToGoogleCloudStorageOperator(
@@ -67,9 +93,8 @@ def test_init(self):
         self.assertEqual(op.bucket, BUCKET)
         self.assertEqual(op.filename, FILENAME)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_exec_success(self, gcs_hook_mock_class):
         """Test the execute function in case where the run is successful."""
         op = PostgresToGoogleCloudStorageOperator(
             task_id=TASK_ID,
@@ -78,10 +103,6 @@ def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
             bucket=BUCKET,
             filename=FILENAME)
 
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
-
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
         def _assert_upload(bucket, obj, tmp_filename, content_type):
@@ -95,16 +116,9 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
         op.execute(None)
 
-        pg_hook_mock_class.assert_called_once_with(postgres_conn_id=POSTGRES_CONN_ID)
-        pg_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL, None)
-
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_file_splitting(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_file_splitting(self, gcs_hook_mock_class):
         """Test that ndjson is split by approx_max_file_size_bytes param."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
         expected_upload = {
@@ -128,13 +142,23 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
             approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
         op.execute(None)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_schema_file(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_empty_query(self, gcs_hook_mock_class):
+        """If the sql returns no rows, we should not upload any files"""
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        op = PostgresToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql='SELECT * FROM postgres_to_gcs_operator_empty',
+            bucket=BUCKET,
+            filename=FILENAME)
+        op.execute(None)
+
+        assert not gcs_hook_mock.upload.called, 'No data means no files in the bucket'
+
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_schema_file(self, gcs_hook_mock_class):
         """Test writing schema files."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
         gcs_hook_mock = gcs_hook_mock_class.return_value