Commit

feat(bigquery): add --destination_table parameter to IPython magic (#…
vlasenkoalexey authored and plamut committed Nov 6, 2019
1 parent 41ae858 commit 9405db9
Showing 2 changed files with 143 additions and 0 deletions.
52 changes: 52 additions & 0 deletions bigquery/google/cloud/bigquery/magics.py
@@ -31,6 +31,10 @@
        this parameter is used. If an error occurs during the query execution,
        the corresponding ``QueryJob`` instance (if available) is stored in
        the variable instead.
    * ``--destination_table`` (optional, line argument):
        A dataset and table to store the query results. If the table does not
        exist, it will be created. If the table already exists, its data will
        be overwritten. The value should be in the format
        ``<dataset_id>.<table_id>``.
    * ``--project <project>`` (optional, line argument):
        Project to use for running the query. Defaults to the context
        :attr:`~google.cloud.bigquery.magics.Context.project`.
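For context, a notebook cell using the new flag would look like the minimal sketch below (``my_dataset.my_table`` and the trivial query are illustrative placeholders, not part of this commit):

    %%bigquery --destination_table my_dataset.my_table
    SELECT 17 AS answer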
@@ -145,6 +149,7 @@
    raise ImportError("This module can only be loaded in IPython.")

from google.api_core import client_info
from google.api_core.exceptions import NotFound
import google.auth
from google.cloud import bigquery
from google.cloud.bigquery.dbapi import _helpers
@@ -336,12 +341,44 @@ def _run_query(client, query, job_config=None):
    return query_job


def _create_dataset_if_necessary(client, dataset_id):
    """Create a dataset in the current project if it doesn't exist.

    Args:
        client (google.cloud.bigquery.client.Client):
            Client to bundle configuration needed for API requests.
        dataset_id (str):
            Dataset id.
    """
    dataset_reference = bigquery.dataset.DatasetReference(client.project, dataset_id)
    try:
        client.get_dataset(dataset_reference)
        return
    except NotFound:
        pass
    dataset = bigquery.Dataset(dataset_reference)
    dataset.location = client.location
    print("Creating dataset: {}".format(dataset_id))
    client.create_dataset(dataset)
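A minimal usage sketch of this helper, assuming default application credentials and a hypothetical dataset id (``scratch_dataset``); an existing dataset is left untouched, a missing one is created in the client's project and location:

    from google.cloud import bigquery
    from google.cloud.bigquery import magics

    client = bigquery.Client()
    # Only creates "scratch_dataset" if client.get_dataset() raises NotFound.
    magics._create_dataset_if_necessary(client, "scratch_dataset")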


@magic_arguments.magic_arguments()
@magic_arguments.argument(
    "destination_var",
    nargs="?",
    help=("If provided, save the output to this variable instead of displaying it."),
)
@magic_arguments.argument(
    "--destination_table",
    type=str,
    default=None,
    help=(
        "If provided, save the output of the query to a new BigQuery table. "
        "The value should be in the format <dataset_id>.<table_id>. "
        "If the table does not exist, it will be created. "
        "If the table already exists, its data will be overwritten."
    ),
)
@magic_arguments.argument(
    "--project",
    type=str,
@@ -485,6 +522,21 @@ def _cell_magic(line, query):
    job_config.use_legacy_sql = args.use_legacy_sql
    job_config.dry_run = args.dry_run

    if args.destination_table:
        split = args.destination_table.split(".")
        if len(split) != 2:
            raise ValueError(
                "--destination_table should be in a <dataset_id>.<table_id> format."
            )
        dataset_id, table_id = split
        job_config.allow_large_results = True
        dataset_ref = client.dataset(dataset_id)
        destination_table_ref = dataset_ref.table(table_id)
        job_config.destination = destination_table_ref
        job_config.create_disposition = "CREATE_IF_NEEDED"
        job_config.write_disposition = "WRITE_TRUNCATE"
        _create_dataset_if_necessary(client, dataset_id)

    if args.maximum_bytes_billed == "None":
        job_config.maximum_bytes_billed = 0
    elif args.maximum_bytes_billed is not None:
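The branch above is equivalent to configuring a query job by hand. A standalone sketch under assumed names (``my_dataset``/``my_table`` are placeholders; ``Client.dataset()`` is the table-reference style the commit itself uses):

    from google.cloud import bigquery

    client = bigquery.Client()
    job_config = bigquery.QueryJobConfig()
    job_config.allow_large_results = True
    job_config.destination = client.dataset("my_dataset").table("my_table")
    job_config.create_disposition = "CREATE_IF_NEEDED"  # create the table if absent
    job_config.write_disposition = "WRITE_TRUNCATE"  # overwrite any existing rows
    query_job = client.query("SELECT 17 AS answer", job_config=job_config)
    query_job.result()  # block until the job finishes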
91 changes: 91 additions & 0 deletions bigquery/tests/unit/test_magics.py
@@ -39,6 +39,7 @@
    from google.cloud import bigquery_storage_v1beta1
except ImportError:  # pragma: NO COVER
    bigquery_storage_v1beta1 = None
from google.cloud import bigquery
from google.cloud.bigquery import job
from google.cloud.bigquery import table
from google.cloud.bigquery import magics
@@ -336,6 +337,37 @@ def test__make_bqstorage_client_true_missing_gapic(missing_grpcio_lib):
assert "grpcio" in str(exc_context.value)


def test__create_dataset_if_necessary_exists():
    project = "project_id"
    dataset_id = "dataset_id"
    dataset_reference = bigquery.dataset.DatasetReference(project, dataset_id)
    dataset = bigquery.Dataset(dataset_reference)
    client_patch = mock.patch(
        "google.cloud.bigquery.magics.bigquery.Client", autospec=True
    )
    with client_patch as client_mock:
        client = client_mock()
        client.project = project
        client.get_dataset.return_value = dataset
        magics._create_dataset_if_necessary(client, dataset_id)
        client.create_dataset.assert_not_called()


def test__create_dataset_if_necessary_not_exist():
    project = "project_id"
    dataset_id = "dataset_id"
    client_patch = mock.patch(
        "google.cloud.bigquery.magics.bigquery.Client", autospec=True
    )
    with client_patch as client_mock:
        client = client_mock()
        client.location = "us"
        client.project = project
        client.get_dataset.side_effect = exceptions.NotFound("dataset not found")
        magics._create_dataset_if_necessary(client, dataset_id)
        client.create_dataset.assert_called_once()


@pytest.mark.usefixtures("ipython_interactive")
def test_extension_load():
    ip = IPython.get_ipython()
@@ -1199,3 +1231,62 @@ def test_bigquery_magic_omits_tracebacks_from_error_message():
assert "400 Syntax error in SQL query" in output
assert "Traceback (most recent call last)" not in output
assert "Syntax error" not in captured_io.stdout


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_w_destination_table_invalid_format():
    ip = IPython.get_ipython()
    ip.extension_manager.load_extension("google.cloud.bigquery")
    magics.context._project = None

    credentials_mock = mock.create_autospec(
        google.auth.credentials.Credentials, instance=True
    )
    default_patch = mock.patch(
        "google.auth.default", return_value=(credentials_mock, "general-project")
    )

    client_patch = mock.patch(
        "google.cloud.bigquery.magics.bigquery.Client", autospec=True
    )

    with client_patch, default_patch, pytest.raises(ValueError) as exc_context:
        ip.run_cell_magic(
            "bigquery", "--destination_table dataset", "SELECT foo FROM WHERE LIMIT bar"
        )
    error_msg = str(exc_context.value)
    assert (
        "--destination_table should be in a "
        "<dataset_id>.<table_id> format." in error_msg
    )


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_w_destination_table():
    ip = IPython.get_ipython()
    ip.extension_manager.load_extension("google.cloud.bigquery")
    magics.context.credentials = mock.create_autospec(
        google.auth.credentials.Credentials, instance=True
    )

    create_dataset_if_necessary_patch = mock.patch(
        "google.cloud.bigquery.magics._create_dataset_if_necessary", autospec=True
    )

    run_query_patch = mock.patch(
        "google.cloud.bigquery.magics._run_query", autospec=True
    )

    with create_dataset_if_necessary_patch, run_query_patch as run_query_mock:
        ip.run_cell_magic(
            "bigquery",
            "--destination_table dataset_id.table_id",
            "SELECT foo FROM WHERE LIMIT bar",
        )

        job_config_used = run_query_mock.call_args_list[0][1]["job_config"]
        assert job_config_used.allow_large_results is True
        assert job_config_used.create_disposition == "CREATE_IF_NEEDED"
        assert job_config_used.write_disposition == "WRITE_TRUNCATE"
        assert job_config_used.destination.dataset_id == "dataset_id"
        assert job_config_used.destination.table_id == "table_id"
