diff --git a/bigquery/google/cloud/bigquery/magics.py b/bigquery/google/cloud/bigquery/magics.py index 6bd1c45dfcd5..0acde4f21b5f 100644 --- a/bigquery/google/cloud/bigquery/magics.py +++ b/bigquery/google/cloud/bigquery/magics.py @@ -161,6 +161,7 @@ def __init__(self): self._project = None self._connection = None self._use_bqstorage_api = None + self._default_query_job_config = bigquery.QueryJobConfig() @property def credentials(self): @@ -237,6 +238,28 @@ def use_bqstorage_api(self): def use_bqstorage_api(self, value): self._use_bqstorage_api = value + @property + def default_query_job_config(self): + """google.cloud.bigquery.job.QueryJobConfig: Default job + configuration for queries. + + The context's :class:`~google.cloud.bigquery.job.QueryJobConfig` is + used for queries. Some properties can be overridden with arguments to + the magics. + + Example: + Manually setting the default value for ``maximum_bytes_billed`` + to 100 MB: + + >>> from google.cloud.bigquery import magics + >>> magics.context.default_query_job_config.maximum_bytes_billed = 100000000 + """ + return self._default_query_job_config + + @default_query_job_config.setter + def default_query_job_config(self, value): + self._default_query_job_config = value + context = Context() @@ -291,6 +314,14 @@ def _run_query(client, query, job_config=None): default=None, help=("Project to use for executing this query. Defaults to the context project."), ) +@magic_arguments.argument( + "--maximum_bytes_billed", + default=None, + help=( + "maximum_bytes_billed to use for executing this query. Defaults to " + "the context default_query_job_config.maximum_bytes_billed." + ), +) @magic_arguments.argument( "--use_legacy_sql", action="store_true", @@ -363,7 +394,11 @@ def _cell_magic(line, query): ) project = args.project or context.project - client = bigquery.Client(project=project, credentials=context.credentials) + client = bigquery.Client( + project=project, + credentials=context.credentials, + default_query_job_config=context.default_query_job_config, + ) if context._connection: client._connection = context._connection bqstorage_client = _make_bqstorage_client( @@ -372,6 +407,12 @@ def _cell_magic(line, query): job_config = bigquery.job.QueryJobConfig() job_config.query_parameters = params job_config.use_legacy_sql = args.use_legacy_sql + + if args.maximum_bytes_billed == "None": + job_config.maximum_bytes_billed = 0 + elif args.maximum_bytes_billed is not None: + value = int(args.maximum_bytes_billed) + job_config.maximum_bytes_billed = value query_job = _run_query(client, query, job_config) if not args.verbose: diff --git a/bigquery/tests/unit/test_magics.py b/bigquery/tests/unit/test_magics.py index 70848cbcae64..f3e64a46faca 100644 --- a/bigquery/tests/unit/test_magics.py +++ b/bigquery/tests/unit/test_magics.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import re -import mock -import six from concurrent import futures +import mock import pytest +import six try: import pandas @@ -37,6 +38,7 @@ from google.cloud import bigquery_storage_v1beta1 except ImportError: # pragma: NO COVER bigquery_storage_v1beta1 = None +from google.cloud.bigquery import job from google.cloud.bigquery import table from google.cloud.bigquery import magics from tests.unit.helpers import make_connection @@ -63,6 +65,26 @@ def ipython_interactive(request, ipython): yield ipython +JOB_REFERENCE_RESOURCE = {"projectId": "its-a-project-eh", "jobId": "some-random-id"} +TABLE_REFERENCE_RESOURCE = { + "projectId": "its-a-project-eh", + "datasetId": "ds", + "tableId": "persons", +} +QUERY_RESOURCE = { + "jobReference": JOB_REFERENCE_RESOURCE, + "configuration": { + "query": { + "destinationTable": TABLE_REFERENCE_RESOURCE, + "query": "SELECT 42 FROM `life.the_universe.and_everything`;", + "queryParameters": [], + "useLegacySql": False, + } + }, + "status": {"state": "DONE"}, +} + + def test_context_credentials_auto_set_w_application_default_credentials(): """When Application Default Credentials are set, the context credentials will be created the first time it is called @@ -117,22 +139,13 @@ def test_context_connection_can_be_overriden(): default_patch = mock.patch( "google.auth.default", return_value=(credentials_mock, project) ) + job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) + job_reference["projectId"] = project query = "select * from persons" - job_reference = {"projectId": project, "jobId": "some-random-id"} - table = {"projectId": project, "datasetId": "ds", "tableId": "persons"} - resource = { - "jobReference": job_reference, - "configuration": { - "query": { - "destinationTable": table, - "query": query, - "queryParameters": [], - "useLegacySql": False, - } - }, - "status": {"state": "DONE"}, - } + resource = copy.deepcopy(QUERY_RESOURCE) + resource["jobReference"] = job_reference + resource["configuration"]["query"]["query"] = query data = {"jobReference": job_reference, "totalRows": 0, "rows": []} conn = magics.context._connection = make_connection(resource, data) @@ -170,22 +183,13 @@ def test_context_no_connection(): default_patch = mock.patch( "google.auth.default", return_value=(credentials_mock, project) ) + job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) + job_reference["projectId"] = project query = "select * from persons" - job_reference = {"projectId": project, "jobId": "some-random-id"} - table = {"projectId": project, "datasetId": "ds", "tableId": "persons"} - resource = { - "jobReference": job_reference, - "configuration": { - "query": { - "destinationTable": table, - "query": query, - "queryParameters": [], - "useLegacySql": False, - } - }, - "status": {"state": "DONE"}, - } + resource = copy.deepcopy(QUERY_RESOURCE) + resource["jobReference"] = job_reference + resource["configuration"]["query"]["query"] = query data = {"jobReference": job_reference, "totalRows": 0, "rows": []} conn_mock = make_connection(resource, data, data, data) @@ -239,7 +243,8 @@ def test__run_query(): assert updates[0] == expected_first_line execution_updates = updates[1:-1] assert len(execution_updates) == 3 # one update per API response - assert all(re.match("Query executing: .*s", line) for line in execution_updates) + for line in execution_updates: + assert re.match("Query executing: .*s", line) assert re.match("Query complete after .*s", updates[-1]) @@ -548,6 +553,131 @@ def test_bigquery_magic_without_bqstorage(monkeypatch): assert isinstance(return_value, pandas.DataFrame) +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_maximum_bytes_billed_invalid(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + sql = "SELECT 17 AS num" + + with pytest.raises(ValueError): + ip.run_cell_magic("bigquery", "--maximum_bytes_billed=abc", sql) + + +@pytest.mark.parametrize( + "param_value,expected", [("987654321", "987654321"), ("None", "0")] +) +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_maximum_bytes_billed_overrides_context(param_value, expected): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + # Set the default maximum bytes billed, so we know it's overridable by the param. + magics.context.default_query_job_config.maximum_bytes_billed = 1234567 + + project = "test-project" + job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) + job_reference["projectId"] = project + query = "SELECT 17 AS num" + resource = copy.deepcopy(QUERY_RESOURCE) + resource["jobReference"] = job_reference + resource["configuration"]["query"]["query"] = query + data = {"jobReference": job_reference, "totalRows": 0, "rows": []} + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + conn = magics.context._connection = make_connection(resource, data) + list_rows_patch = mock.patch( + "google.cloud.bigquery.client.Client.list_rows", + return_value=google.cloud.bigquery.table._EmptyRowIterator(), + ) + with list_rows_patch, default_patch: + ip.run_cell_magic( + "bigquery", "--maximum_bytes_billed={}".format(param_value), query + ) + + _, req = conn.api_request.call_args_list[0] + sent_config = req["data"]["configuration"]["query"] + assert sent_config["maximumBytesBilled"] == expected + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_maximum_bytes_billed_w_context_inplace(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + magics.context.default_query_job_config.maximum_bytes_billed = 1337 + + project = "test-project" + job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) + job_reference["projectId"] = project + query = "SELECT 17 AS num" + resource = copy.deepcopy(QUERY_RESOURCE) + resource["jobReference"] = job_reference + resource["configuration"]["query"]["query"] = query + data = {"jobReference": job_reference, "totalRows": 0, "rows": []} + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + conn = magics.context._connection = make_connection(resource, data) + list_rows_patch = mock.patch( + "google.cloud.bigquery.client.Client.list_rows", + return_value=google.cloud.bigquery.table._EmptyRowIterator(), + ) + with list_rows_patch, default_patch: + ip.run_cell_magic("bigquery", "", query) + + _, req = conn.api_request.call_args_list[0] + sent_config = req["data"]["configuration"]["query"] + assert sent_config["maximumBytesBilled"] == "1337" + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_maximum_bytes_billed_w_context_setter(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + magics.context.default_query_job_config = job.QueryJobConfig( + maximum_bytes_billed=10203 + ) + + project = "test-project" + job_reference = copy.deepcopy(JOB_REFERENCE_RESOURCE) + job_reference["projectId"] = project + query = "SELECT 17 AS num" + resource = copy.deepcopy(QUERY_RESOURCE) + resource["jobReference"] = job_reference + resource["configuration"]["query"]["query"] = query + data = {"jobReference": job_reference, "totalRows": 0, "rows": []} + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + conn = magics.context._connection = make_connection(resource, data) + list_rows_patch = mock.patch( + "google.cloud.bigquery.client.Client.list_rows", + return_value=google.cloud.bigquery.table._EmptyRowIterator(), + ) + with list_rows_patch, default_patch: + ip.run_cell_magic("bigquery", "", query) + + _, req = conn.api_request.call_args_list[0] + sent_config = req["data"]["configuration"]["query"] + assert sent_config["maximumBytesBilled"] == "10203" + + @pytest.mark.usefixtures("ipython_interactive") def test_bigquery_magic_with_project(): ip = IPython.get_ipython()