diff --git a/bigquery/google/cloud/bigquery/magics.py b/bigquery/google/cloud/bigquery/magics.py index 4c93d1307a42..c238bb50317a 100644 --- a/bigquery/google/cloud/bigquery/magics.py +++ b/bigquery/google/cloud/bigquery/magics.py @@ -129,6 +129,7 @@ from __future__ import print_function +import re import ast import sys import time @@ -266,6 +267,15 @@ def default_query_job_config(self, value): context = Context() +def _print_error(error, destination_var=None): + if destination_var: + print( + "Could not save output to variable '{}'.".format(destination_var), + file=sys.stderr, + ) + print("\nERROR:\n", error, file=sys.stderr) + + def _run_query(client, query, job_config=None): """Runs a query while printing status updates @@ -434,6 +444,24 @@ def _cell_magic(line, query): else: max_results = None + query = query.strip() + + # Any query that does not contain whitespace (aside from leading and trailing whitespace) + # is assumed to be a table id + if not re.search(r"\s", query): + try: + rows = client.list_rows(query, max_results=max_results) + except Exception as ex: + _print_error(str(ex), args.destination_var) + return + + result = rows.to_dataframe(bqstorage_client=bqstorage_client) + if args.destination_var: + IPython.get_ipython().push({args.destination_var: result}) + return + else: + return result + job_config = bigquery.job.QueryJobConfig() job_config.query_parameters = params job_config.use_legacy_sql = args.use_legacy_sql @@ -445,24 +473,15 @@ def _cell_magic(line, query): value = int(args.maximum_bytes_billed) job_config.maximum_bytes_billed = value - error = None try: query_job = _run_query(client, query, job_config=job_config) except Exception as ex: - error = str(ex) + _print_error(str(ex), args.destination_var) + return if not args.verbose: display.clear_output() - if error: - if args.destination_var: - print( - "Could not save output to variable '{}'.".format(args.destination_var), - file=sys.stderr, - ) - print("\nERROR:\n", error, file=sys.stderr) - return - if args.dry_run and args.destination_var: IPython.get_ipython().push({args.destination_var: query_job}) return diff --git a/bigquery/tests/unit/test_magics.py b/bigquery/tests/unit/test_magics.py index ed748d2dd5e3..ec642ff384e1 100644 --- a/bigquery/tests/unit/test_magics.py +++ b/bigquery/tests/unit/test_magics.py @@ -696,6 +696,114 @@ def test_bigquery_magic_w_max_results_valid_calls_queryjob_result(): query_job_mock.result.assert_called_with(max_results=5) +def test_bigquery_magic_w_table_id_invalid(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + + list_rows_patch = mock.patch( + "google.cloud.bigquery.magics.bigquery.Client.list_rows", + autospec=True, + side_effect=exceptions.BadRequest("Not a valid table ID"), + ) + + table_id = "not-a-real-table" + + with list_rows_patch, default_patch, io.capture_output() as captured_io: + ip.run_cell_magic("bigquery", "df", table_id) + + output = captured_io.stderr + assert "Could not save output to variable" in output + assert "400 Not a valid table ID" in output + assert "Traceback (most recent call last)" not in output + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_table_id_and_destination_var(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + + row_iterator_mock = mock.create_autospec( + google.cloud.bigquery.table.RowIterator, instance=True + ) + + client_patch = mock.patch( + "google.cloud.bigquery.magics.bigquery.Client", autospec=True + ) + + table_id = "bigquery-public-data.samples.shakespeare" + result = pandas.DataFrame([17], columns=["num"]) + + with client_patch as client_mock, default_patch: + client_mock().list_rows.return_value = row_iterator_mock + row_iterator_mock.to_dataframe.return_value = result + + ip.run_cell_magic("bigquery", "df", table_id) + + assert "df" in ip.user_ns + df = ip.user_ns["df"] + + assert isinstance(df, pandas.DataFrame) + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_table_id_and_bqstorage_client(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + + row_iterator_mock = mock.create_autospec( + google.cloud.bigquery.table.RowIterator, instance=True + ) + + client_patch = mock.patch( + "google.cloud.bigquery.magics.bigquery.Client", autospec=True + ) + + bqstorage_mock = mock.create_autospec( + bigquery_storage_v1beta1.BigQueryStorageClient + ) + bqstorage_instance_mock = mock.create_autospec( + bigquery_storage_v1beta1.BigQueryStorageClient, instance=True + ) + bqstorage_mock.return_value = bqstorage_instance_mock + bqstorage_client_patch = mock.patch( + "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock + ) + + table_id = "bigquery-public-data.samples.shakespeare" + + with default_patch, client_patch as client_mock, bqstorage_client_patch: + client_mock().list_rows.return_value = row_iterator_mock + + ip.run_cell_magic("bigquery", "--use_bqstorage_api --max_results=5", table_id) + row_iterator_mock.to_dataframe.assert_called_once_with( + bqstorage_client=bqstorage_instance_mock + ) + + @pytest.mark.usefixtures("ipython_interactive") def test_bigquery_magic_dryrun_option_sets_job_config(): ip = IPython.get_ipython()