diff --git a/bigquery/google/cloud/bigquery/magics.py b/bigquery/google/cloud/bigquery/magics.py index 2a174cefeea3..59265ed6b0c5 100644 --- a/bigquery/google/cloud/bigquery/magics.py +++ b/bigquery/google/cloud/bigquery/magics.py @@ -31,6 +31,10 @@ this parameter is used. If an error occurs during the query execution, the corresponding ``QueryJob`` instance (if available) is stored in the variable instead. + * ``--destination_table`` (optional, line argument): + A dataset and table to store the query results. If table does not exists, + it will be created. If table already exists, its data will be overwritten. + Variable should be in a format .. * ``--project `` (optional, line argument): Project to use for running the query. Defaults to the context :attr:`~google.cloud.bigquery.magics.Context.project`. @@ -145,6 +149,7 @@ raise ImportError("This module can only be loaded in IPython.") from google.api_core import client_info +from google.api_core.exceptions import NotFound import google.auth from google.cloud import bigquery from google.cloud.bigquery.dbapi import _helpers @@ -336,12 +341,44 @@ def _run_query(client, query, job_config=None): return query_job +def _create_dataset_if_necessary(client, dataset_id): + """Create a dataset in the current project if it doesn't exist. + + Args: + client (google.cloud.bigquery.client.Client): + Client to bundle configuration needed for API requests. + dataset_id (str): + Dataset id. + """ + dataset_reference = bigquery.dataset.DatasetReference(client.project, dataset_id) + try: + dataset = client.get_dataset(dataset_reference) + return + except NotFound: + pass + dataset = bigquery.Dataset(dataset_reference) + dataset.location = client.location + print("Creating dataset: {}".format(dataset_id)) + dataset = client.create_dataset(dataset) + + @magic_arguments.magic_arguments() @magic_arguments.argument( "destination_var", nargs="?", help=("If provided, save the output to this variable instead of displaying it."), ) +@magic_arguments.argument( + "--destination_table", + type=str, + default=None, + help=( + "If provided, save the output of the query to a new BigQuery table. " + "Variable should be in a format .. " + "If table does not exists, it will be created. " + "If table already exists, its data will be overwritten." + ), +) @magic_arguments.argument( "--project", type=str, @@ -485,6 +522,21 @@ def _cell_magic(line, query): job_config.use_legacy_sql = args.use_legacy_sql job_config.dry_run = args.dry_run + if args.destination_table: + split = args.destination_table.split(".") + if len(split) != 2: + raise ValueError( + "--destination_table should be in a . format." + ) + dataset_id, table_id = split + job_config.allow_large_results = True + dataset_ref = client.dataset(dataset_id) + destination_table_ref = dataset_ref.table(table_id) + job_config.destination = destination_table_ref + job_config.create_disposition = "CREATE_IF_NEEDED" + job_config.write_disposition = "WRITE_TRUNCATE" + _create_dataset_if_necessary(client, dataset_id) + if args.maximum_bytes_billed == "None": job_config.maximum_bytes_billed = 0 elif args.maximum_bytes_billed is not None: diff --git a/bigquery/tests/unit/test_magics.py b/bigquery/tests/unit/test_magics.py index ed253636c468..6ff9819854a8 100644 --- a/bigquery/tests/unit/test_magics.py +++ b/bigquery/tests/unit/test_magics.py @@ -39,6 +39,7 @@ from google.cloud import bigquery_storage_v1beta1 except ImportError: # pragma: NO COVER bigquery_storage_v1beta1 = None +from google.cloud import bigquery from google.cloud.bigquery import job from google.cloud.bigquery import table from google.cloud.bigquery import magics @@ -336,6 +337,37 @@ def test__make_bqstorage_client_true_missing_gapic(missing_grpcio_lib): assert "grpcio" in str(exc_context.value) +def test__create_dataset_if_necessary_exists(): + project = "project_id" + dataset_id = "dataset_id" + dataset_reference = bigquery.dataset.DatasetReference(project, dataset_id) + dataset = bigquery.Dataset(dataset_reference) + client_patch = mock.patch( + "google.cloud.bigquery.magics.bigquery.Client", autospec=True + ) + with client_patch as client_mock: + client = client_mock() + client.project = project + client.get_dataset.result_value = dataset + magics._create_dataset_if_necessary(client, dataset_id) + client.create_dataset.assert_not_called() + + +def test__create_dataset_if_necessary_not_exist(): + project = "project_id" + dataset_id = "dataset_id" + client_patch = mock.patch( + "google.cloud.bigquery.magics.bigquery.Client", autospec=True + ) + with client_patch as client_mock: + client = client_mock() + client.location = "us" + client.project = project + client.get_dataset.side_effect = exceptions.NotFound("dataset not found") + magics._create_dataset_if_necessary(client, dataset_id) + client.create_dataset.assert_called_once() + + @pytest.mark.usefixtures("ipython_interactive") def test_extension_load(): ip = IPython.get_ipython() @@ -1199,3 +1231,62 @@ def test_bigquery_magic_omits_tracebacks_from_error_message(): assert "400 Syntax error in SQL query" in output assert "Traceback (most recent call last)" not in output assert "Syntax error" not in captured_io.stdout + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_destination_table_invalid_format(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + credentials_mock = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + default_patch = mock.patch( + "google.auth.default", return_value=(credentials_mock, "general-project") + ) + + client_patch = mock.patch( + "google.cloud.bigquery.magics.bigquery.Client", autospec=True + ) + + with client_patch, default_patch, pytest.raises(ValueError) as exc_context: + ip.run_cell_magic( + "bigquery", "--destination_table dataset", "SELECT foo FROM WHERE LIMIT bar" + ) + error_msg = str(exc_context.value) + assert ( + "--destination_table should be in a " + ". format." in error_msg + ) + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_w_destination_table(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + create_dataset_if_necessary_patch = mock.patch( + "google.cloud.bigquery.magics._create_dataset_if_necessary", autospec=True + ) + + run_query_patch = mock.patch( + "google.cloud.bigquery.magics._run_query", autospec=True + ) + + with create_dataset_if_necessary_patch, run_query_patch as run_query_mock: + ip.run_cell_magic( + "bigquery", + "--destination_table dataset_id.table_id", + "SELECT foo FROM WHERE LIMIT bar", + ) + + job_config_used = run_query_mock.call_args_list[0][1]["job_config"] + assert job_config_used.allow_large_results is True + assert job_config_used.create_disposition == "CREATE_IF_NEEDED" + assert job_config_used.write_disposition == "WRITE_TRUNCATE" + assert job_config_used.destination.dataset_id == "dataset_id" + assert job_config_used.destination.table_id == "table_id"