From fe041b9eb08e03c5742d0a6794928dd50c55da9c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C4=81rti=C5=86=C5=A1=20Kalv=C4=81ns?=
Date: Mon, 23 Aug 2021 13:14:06 +0200
Subject: [PATCH] Add configure_job method to BigQuery job tasks

This allows configuring the job with any parameters, not only the ones
exposed via properties.
---
 luigi/contrib/bigquery.py     |  36 ++++
 test/contrib/bigquery_test.py | 360 ++++++++++++----------------------
 2 files changed, 161 insertions(+), 235 deletions(-)

diff --git a/luigi/contrib/bigquery.py b/luigi/contrib/bigquery.py
index 55fc3b72c8..a1ab8486ef 100644
--- a/luigi/contrib/bigquery.py
+++ b/luigi/contrib/bigquery.py
@@ -527,6 +527,16 @@ def allow_quoted_new_lines(self):
         """ Indicates if BigQuery should allow quoted data sections that contain newline characters in a CSV file. The default value is false."""
         return False
 
+    def configure_job(self, configuration):
+        """Set additional job configuration.
+
+        This allows specifying job configuration parameters that are not exposed via Task properties.
+
+        :param configuration: Current configuration.
+        :return: New or updated configuration.
+        """
+        return configuration
+
     def run(self):
         output = self.output()
         assert isinstance(output, BigQueryTarget), 'Output must be a BigQueryTarget, not %s' % (output)
@@ -565,6 +575,8 @@ def run(self):
         else:
             job['configuration']['load']['autodetect'] = True
 
+        job['configuration'] = self.configure_job(job['configuration'])
+
         bq_client.run_job(output.table.project_id, job, dataset=output.table.dataset)
 
 
@@ -610,6 +622,16 @@ def use_legacy_sql(self):
         """
         return True
 
+    def configure_job(self, configuration):
+        """Set additional job configuration.
+
+        This allows specifying job configuration parameters that are not exposed via Task properties.
+
+        :param configuration: Current configuration.
+        :return: New or updated configuration.
+        """
+        return configuration
+
     def run(self):
         output = self.output()
         assert isinstance(output, BigQueryTarget), 'Output must be a BigQueryTarget, not %s' % (output)
@@ -643,6 +665,8 @@ def run(self):
             }
         }
 
+        job['configuration'] = self.configure_job(job['configuration'])
+
         bq_client.run_job(output.table.project_id, job, dataset=output.table.dataset)
 
 
@@ -739,6 +763,16 @@ def compression(self):
         """Whether to use compression."""
         return Compression.NONE
 
+    def configure_job(self, configuration):
+        """Set additional job configuration.
+
+        This allows specifying job configuration parameters that are not exposed via Task properties.
+
+        :param configuration: Current configuration.
+        :return: New or updated configuration.
+        """
+        return configuration
+
     def run(self):
         input = luigi.task.flatten(self.input())[0]
         assert (
@@ -775,6 +809,8 @@ def run(self):
             job['configuration']['extract']['fieldDelimiter'] = \
                 self.field_delimiter
 
+        job['configuration'] = self.configure_job(job['configuration'])
+
         bq_client.run_job(
             input.table.project_id,
             job,
diff --git a/test/contrib/bigquery_test.py b/test/contrib/bigquery_test.py
index c135c09911..61ee2d7cda 100644
--- a/test/contrib/bigquery_test.py
+++ b/test/contrib/bigquery_test.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 #
-# Copyright 2015 Twitter Inc
+# Copyright 2019 Spotify AB
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,244 +16,134 @@
 #
 """
-These are the unit tests for the BigQuery-luigi binding.
+These are the unit tests for the configure_job method of the BigQuery job tasks.
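+
+The configure_job hook is exercised by overriding it in small task subclasses
+and asserting on the job body that is passed to BigQueryClient.run_job.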
""" +import unittest -import luigi -from luigi.contrib import bigquery -from luigi.contrib.gcs import GCSTarget - -from helpers import unittest -from mock import MagicMock +import mock import pytest -PROJECT_ID = 'projectid' -DATASET_ID = 'dataset' - - -class TestRunQueryTask(bigquery.BigQueryRunQueryTask): - client = MagicMock() - query = ''' SELECT 'hello' as field1, 2 as field2 ''' - table = luigi.Parameter() - - def output(self): - return bigquery.BigQueryTarget(PROJECT_ID, DATASET_ID, self.table, client=self.client) - - -class TestRunQueryTaskDontFlattenResults(TestRunQueryTask): - - @property - def flatten_results(self): - return False - - -class TestRunQueryTaskWithRequires(bigquery.BigQueryRunQueryTask): - client = MagicMock() - table = luigi.Parameter() - - def requires(self): - return TestRunQueryTask(table='table1') - - @property - def query(self): - requires = self.requires().output().table - dataset = requires.dataset_id - table = requires.table_id - return 'SELECT * FROM [{dataset}.{table}]'.format(dataset=dataset, table=table) - - def output(self): - return bigquery.BigQueryTarget(PROJECT_ID, DATASET_ID, self.table, client=self.client) - - -class TestRunQueryTaskWithUdf(bigquery.BigqueryRunQueryTask): - client = MagicMock() - table = luigi.Parameter() - - @property - def udf_resource_uris(self): - return ["gs://test/file1.js", "gs://test/file2.js"] - - @property - def query(self): - return 'SELECT 1' - - def output(self): - return bigquery.BigqueryTarget(PROJECT_ID, DATASET_ID, self.table, client=self.client) - - -class TestRunQueryTaskWithoutLegacySql(bigquery.BigqueryRunQueryTask): - client = MagicMock() - table = luigi.Parameter() - - @property - def use_legacy_sql(self): - return False - - @property - def query(self): - return 'SELECT 1' - - def output(self): - return bigquery.BigqueryTarget(PROJECT_ID, DATASET_ID, self.table, client=self.client) - - -class TestExternalBigQueryTask(bigquery.ExternalBigQueryTask): - client = MagicMock() - - def output(self): - return bigquery.BigQueryTarget(PROJECT_ID, DATASET_ID, 'table1', client=self.client) - - -class TestCreateViewTask(bigquery.BigQueryCreateViewTask): - client = MagicMock() - view = '''SELECT * FROM table LIMIT 10''' - - def output(self): - return bigquery.BigQueryTarget(PROJECT_ID, DATASET_ID, 'view1', client=self.client) - - -class TestExtractTask(bigquery.BigQueryExtractTask): - client = MagicMock() - - def output(self): - return GCSTarget('gs://test/unload_file.csv', client=self.client) - - def requires(self): - return TestExternalBigQueryTask() - - -@pytest.mark.contrib -class BigQueryTest(unittest.TestCase): - - def test_bulk_complete(self): - parameters = ['table1', 'table2'] - - client = MagicMock() - client.dataset_exists.return_value = True - client.list_tables.return_value = ['table2', 'table3'] - TestRunQueryTask.client = client - - complete = list(TestRunQueryTask.bulk_complete(parameters)) - self.assertEqual(complete, ['table2']) - - # Test that bulk_complete accepts lazy sequences in addition to lists - def parameters_gen(): - yield 'table1' - yield 'table2' - - complete = list(TestRunQueryTask.bulk_complete(parameters_gen())) - self.assertEqual(complete, ['table2']) - - def test_dataset_doesnt_exist(self): - client = MagicMock() - client.dataset_exists.return_value = False - TestRunQueryTask.client = client - - complete = list(TestRunQueryTask.bulk_complete(['table1'])) - self.assertEqual(complete, []) - - def test_query_property(self): - task = TestRunQueryTask(table='table2') - task.client = 
MagicMock() - task.run() - - (_, job), _ = task.client.run_job.call_args - query = job['configuration']['query']['query'] - self.assertEqual(query, TestRunQueryTask.query) - - def test_override_query_property(self): - task = TestRunQueryTaskWithRequires(table='table2') - task.client = MagicMock() - task.run() - - (_, job), _ = task.client.run_job.call_args - query = job['configuration']['query']['query'] - - expected_table = '[' + DATASET_ID + '.' + task.requires().output().table.table_id + ']' - self.assertIn(expected_table, query) - self.assertEqual(query, task.query) - - def test_query_udf(self): - task = TestRunQueryTaskWithUdf(table='table2') - task.client = MagicMock() - task.run() - - (_, job), _ = task.client.run_job.call_args - - udfs = [ - {'resourceUri': 'gs://test/file1.js'}, - {'resourceUri': 'gs://test/file2.js'}, - ] - - self.assertEqual(job['configuration']['query']['userDefinedFunctionResources'], udfs) - - def test_query_with_legacy_sql(self): - task = TestRunQueryTask(table='table2') - task.client = MagicMock() - task.run() - - (_, job), _ = task.client.run_job.call_args - - self.assertEqual(job['configuration']['query']['useLegacySql'], True) - - def test_query_without_legacy_sql(self): - task = TestRunQueryTaskWithoutLegacySql(table='table2') - task.client = MagicMock() - task.run() - - (_, job), _ = task.client.run_job.call_args - - self.assertEqual(job['configuration']['query']['useLegacySql'], False) - - def test_external_task(self): - task = TestExternalBigQueryTask() - self.assertIsInstance(task, luigi.ExternalTask) - self.assertIsInstance(task, bigquery.MixinBigQueryBulkComplete) - - def test_create_view(self): - task = TestCreateViewTask() - - task.client.get_view.return_value = None - self.assertFalse(task.complete()) - - task.run() - (table, view), _ = task.client.update_view.call_args - self.assertEqual(task.output().table, table) - self.assertEqual(task.view, view) - - def test_update_view(self): - task = TestCreateViewTask() - - task.client.get_view.return_value = 'some other query' - self.assertFalse(task.complete()) - - task.run() - (table, view), _ = task.client.update_view.call_args - self.assertEqual(task.output().table, table) - self.assertEqual(task.view, view) - - def test_view_completed(self): - task = TestCreateViewTask() - - task.client.get_view.return_value = task.view - self.assertTrue(task.complete()) - - def test_flatten_results(self): - task = TestRunQueryTask(table='table3') - self.assertTrue(task.flatten_results) - - def test_dont_flatten_results(self): - task = TestRunQueryTaskDontFlattenResults(table='table3') - self.assertFalse(task.flatten_results) - - def test_extract_table(self): - task = TestExtractTask() - task.run() - - bq_client = luigi.task.flatten(task.input())[0].client - (_, job), _ = bq_client.run_job.call_args +from luigi.contrib.bigquery import BigQueryLoadTask, BigQueryTarget, BQDataset, \ + BigQueryRunQueryTask, BigQueryExtractTask +from luigi.contrib.gcs import GCSTarget - destination_uris = job['configuration']['extract']['destinationUris'] - self.assertEqual(destination_uris, task.destination_uris) +@pytest.mark.gcloud +class BigQueryLoadTaskTest(unittest.TestCase): + + @mock.patch('luigi.contrib.bigquery.BigQueryClient.run_job') + def test_configure_job(self, run_job): + class MyBigQueryLoadTask(BigQueryLoadTask): + def source_uris(self): + return ['gs://_'] + + def configure_job(self, configuration): + configuration['load']['destinationTableProperties'] = { + 'description': 'Nice table' + } + return configuration + 
+            def output(self):
+                return BigQueryTarget(project_id='proj', dataset_id='ds', table_id='t')
+
+        job = MyBigQueryLoadTask()
+        job.run()
+
+        expected_body = {
+            'configuration': {
+                'load': {
+                    'destinationTable': {'projectId': 'proj', 'datasetId': 'ds', 'tableId': 't'},
+                    'encoding': 'UTF-8',
+                    'sourceFormat': 'NEWLINE_DELIMITED_JSON',
+                    'writeDisposition': 'WRITE_EMPTY',
+                    'sourceUris': ['gs://_'],
+                    'maxBadRecords': 0,
+                    'ignoreUnknownValues': False,
+                    'autodetect': True,
+                    'destinationTableProperties': {'description': 'Nice table'}
+                }
+            }
+        }
+        run_job.assert_called_with('proj', expected_body, dataset=BQDataset('proj', 'ds', None))
+
+
+@pytest.mark.gcloud
+class BigQueryRunQueryTaskTest(unittest.TestCase):
+    @mock.patch('luigi.contrib.bigquery.BigQueryClient.run_job')
+    def test_configure_job(self, run_job):
+        class MyBigQueryRunQuery(BigQueryRunQueryTask):
+            query = 'SELECT @thing'
+            use_legacy_sql = False
+
+            def configure_job(self, configuration):
+                configuration['query']['parameterMode'] = 'NAMED'
+                configuration['query']['queryParameters'] = {
+                    'name': 'thing',
+                    'parameterType': {'type': 'STRING'},
+                    'parameterValue': {'value': 'Nice Thing'}
+                }
+                return configuration
+
+            def output(self):
+                return BigQueryTarget(project_id='proj', dataset_id='ds', table_id='t')
+
+        job = MyBigQueryRunQuery()
+        job.run()
+
+        expected_body = {
+            'configuration': {
+                'query': {
+                    'query': 'SELECT @thing',
+                    'priority': 'INTERACTIVE',
+                    'destinationTable': {'projectId': 'proj', 'datasetId': 'ds', 'tableId': 't'},
+                    'allowLargeResults': True,
+                    'createDisposition': 'CREATE_IF_NEEDED',
+                    'writeDisposition': 'WRITE_TRUNCATE',
+                    'flattenResults': True,
+                    'userDefinedFunctionResources': [],
+                    'useLegacySql': False,
+                    'parameterMode': 'NAMED',
+                    'queryParameters': {
+                        'name': 'thing',
+                        'parameterType': {'type': 'STRING'},
+                        'parameterValue': {'value': 'Nice Thing'}
+                    }
+                }
+            }
+        }
+        run_job.assert_called_with('proj', expected_body, dataset=BQDataset('proj', 'ds', None))
+
+
+@pytest.mark.gcloud
+class BigQueryExtractTaskTest(unittest.TestCase):
+    @mock.patch('luigi.contrib.bigquery.BigQueryClient.run_job')
+    def test_configure_job(self, run_job):
+        class MyBigQueryExtractTask(BigQueryExtractTask):
+            destination_format = 'AVRO'
+
+            def configure_job(self, configuration):
+                configuration['extract']['useAvroLogicalTypes'] = True
+                return configuration
+
+            def input(self):
+                return BigQueryTarget(project_id='proj', dataset_id='ds', table_id='t')
+
+            def output(self):
+                return GCSTarget('gs://_')
+
+        job = MyBigQueryExtractTask()
+        job.run()
+
+        expected_body = {
+            'configuration': {
+                'extract': {
+                    'sourceTable': {'projectId': 'proj', 'datasetId': 'ds', 'tableId': 't'},
+                    'destinationUris': ['gs://_'],
+                    'destinationFormat': 'AVRO',
+                    'compression': 'NONE',
+                    'useAvroLogicalTypes': True
+                }
+            }
+        }
+        run_job.assert_called_with('proj', expected_body, dataset=BQDataset('proj', 'ds', None))