Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-3059] Log how many rows are read from Postgres #3905

Merged
merged 1 commit into from
Sep 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions airflow/contrib/operators/postgres_to_gcs_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,38 @@ def _write_local_data_files(self, cursor):
contain the data for the GCS objects.
"""
schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
file_no = 0
tmp_file_handle = NamedTemporaryFile(delete=True)
tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}

for row in cursor:
# Convert datetime objects to utc seconds, and decimals to floats
row = map(self.convert_types, row)
row_dict = dict(zip(schema, row))

s = json.dumps(row_dict, sort_keys=True)
if PY3:
s = s.encode('utf-8')
tmp_file_handle.write(s)

# Append newline to make dumps BigQuery compatible.
tmp_file_handle.write(b'\n')

# Stop if the file exceeds the file size limit.
if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
file_no += 1
tmp_file_handle = NamedTemporaryFile(delete=True)
tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
tmp_file_handles = {}
row_no = 0

def _create_new_file():
handle = NamedTemporaryFile(delete=True)
filename = self.filename.format(len(tmp_file_handles))
tmp_file_handles[filename] = handle
return handle

# Don't create a file if there is nothing to write
if cursor.rowcount > 0:
tmp_file_handle = _create_new_file()

for row in cursor:
# Convert datetime objects to utc seconds, and decimals to floats
row = map(self.convert_types, row)
row_dict = dict(zip(schema, row))

s = json.dumps(row_dict, sort_keys=True)
if PY3:
s = s.encode('utf-8')
tmp_file_handle.write(s)

# Append newline to make dumps BigQuery compatible.
tmp_file_handle.write(b'\n')

# Stop if the file exceeds the file size limit.
if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
tmp_file_handle = _create_new_file()
row_no += 1

self.log.info('Received %s rows over %s files', row_no, len(tmp_file_handles))

return tmp_file_handles

Expand Down
100 changes: 62 additions & 38 deletions tests/contrib/operators/test_postgres_to_gcs_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -25,40 +25,66 @@
import sys
import unittest

from airflow.contrib.operators.postgres_to_gcs_operator import PostgresToGoogleCloudStorageOperator
from airflow.hooks.postgres_hook import PostgresHook
from airflow.contrib.operators.postgres_to_gcs_operator import \
PostgresToGoogleCloudStorageOperator

try:
from unittest import mock
from unittest.mock import patch
except ImportError:
try:
import mock
from mock import patch
except ImportError:
mock = None

PY3 = sys.version_info[0] == 3
TABLES = {'postgres_to_gcs_operator', 'postgres_to_gcs_operator_empty'}

TASK_ID = 'test-postgres-to-gcs'
POSTGRES_CONN_ID = 'postgres_conn_test'
SQL = 'select 1'
POSTGRES_CONN_ID = 'postgres_default'
SQL = 'SELECT * FROM postgres_to_gcs_operator'
BUCKET = 'gs://test'
FILENAME = 'test_{}.ndjson'
# we expect the psycopg cursor to return encoded strs in py2 and decoded in py3
if PY3:
ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)]
CURSOR_DESCRIPTION = (('some_str', 0), ('some_num', 1005))
else:
ROWS = [(b'mock_row_content_1', 42), (b'mock_row_content_2', 43), (b'mock_row_content_3', 44)]
CURSOR_DESCRIPTION = ((b'some_str', 0), (b'some_num', 1005))

NDJSON_LINES = [
b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
]
SCHEMA_FILENAME = 'schema_test.json'
SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, {"mode": "REPEATED", "name": "some_num", "type": "INTEGER"}]'
SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]'


class PostgresToGoogleCloudStorageOperatorTest(unittest.TestCase):
def setUp(self):
postgres = PostgresHook()
with postgres.get_conn() as conn:
with conn.cursor() as cur:
for table in TABLES:
cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
.format(table))

cur.execute(
"INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
('mock_row_content_1', 42)
)
cur.execute(
"INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
('mock_row_content_2', 43)
)
cur.execute(
"INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
('mock_row_content_3', 44)
)

def tearDown(self):
postgres = PostgresHook()
with postgres.get_conn() as conn:
with conn.cursor() as cur:
for table in TABLES:
cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))

def test_init(self):
"""Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
op = PostgresToGoogleCloudStorageOperator(
Expand All @@ -68,9 +94,8 @@ def test_init(self):
self.assertEqual(op.bucket, BUCKET)
self.assertEqual(op.filename, FILENAME)

@mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
@mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
@patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_exec_success(self, gcs_hook_mock_class):
"""Test the execute function in case where the run is successful."""
op = PostgresToGoogleCloudStorageOperator(
task_id=TASK_ID,
Expand All @@ -79,10 +104,6 @@ def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
bucket=BUCKET,
filename=FILENAME)

pg_hook_mock = pg_hook_mock_class.return_value
pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, content_type):
Expand All @@ -96,16 +117,9 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):

op.execute(None)

pg_hook_mock_class.assert_called_once_with(postgres_conn_id=POSTGRES_CONN_ID)
pg_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL, None)

@mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
@mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_file_splitting(self, gcs_hook_mock_class, pg_hook_mock_class):
@patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_file_splitting(self, gcs_hook_mock_class):
"""Test that ndjson is split by approx_max_file_size_bytes param."""
pg_hook_mock = pg_hook_mock_class.return_value
pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION

gcs_hook_mock = gcs_hook_mock_class.return_value
expected_upload = {
Expand All @@ -129,13 +143,23 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
op.execute(None)

@mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
@mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_schema_file(self, gcs_hook_mock_class, pg_hook_mock_class):
@patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_empty_query(self, gcs_hook_mock_class):
"""If the sql returns no rows, we should not upload any files"""
gcs_hook_mock = gcs_hook_mock_class.return_value

op = PostgresToGoogleCloudStorageOperator(
task_id=TASK_ID,
sql='SELECT * FROM postgres_to_gcs_operator_empty',
bucket=BUCKET,
filename=FILENAME)
op.execute(None)

assert not gcs_hook_mock.upload.called, 'No data means no files in the bucket'

@patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
def test_schema_file(self, gcs_hook_mock_class):
"""Test writing schema files."""
pg_hook_mock = pg_hook_mock_class.return_value
pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION

gcs_hook_mock = gcs_hook_mock_class.return_value

Expand Down