diff --git a/bigquery/google/cloud/bigquery/client.py b/bigquery/google/cloud/bigquery/client.py index 7fe412478bfa..6ad3227a226b 100644 --- a/bigquery/google/cloud/bigquery/client.py +++ b/bigquery/google/cloud/bigquery/client.py @@ -46,6 +46,8 @@ from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.dataset import DatasetReference from google.cloud.bigquery import job +from google.cloud.bigquery.model import Model +from google.cloud.bigquery.model import ModelReference from google.cloud.bigquery.query import _QueryResults from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.bigquery.table import Table @@ -427,6 +429,33 @@ def get_dataset(self, dataset_ref, retry=DEFAULT_RETRY): api_response = self._call_api(retry, method="GET", path=dataset_ref.path) return Dataset.from_api_repr(api_response) + def get_model(self, model_ref, retry=DEFAULT_RETRY): + """Fetch the model referenced by ``model_ref``. + + Args: + model_ref (Union[ \ + :class:`~google.cloud.bigquery.model.ModelReference`, \ + str, \ + ]): + A reference to the model to fetch from the BigQuery API. + If a string is passed in, this method attempts to create a + model reference from a string using + :func:`google.cloud.bigquery.model.ModelReference.from_string`. + retry (:class:`google.api_core.retry.Retry`): + (Optional) How to retry the RPC. + + Returns: + google.cloud.bigquery.model.Model: + A ``Model`` instance. + """ + if isinstance(model_ref, str): + model_ref = ModelReference.from_string( + model_ref, default_project=self.project + ) + + api_response = self._call_api(retry, method="GET", path=model_ref.path) + return Model.from_api_repr(api_response) + def get_table(self, table_ref, retry=DEFAULT_RETRY): """Fetch the table referenced by ``table_ref``. @@ -490,6 +519,41 @@ def update_dataset(self, dataset, fields, retry=DEFAULT_RETRY): ) return Dataset.from_api_repr(api_response) + def update_model(self, model, fields, retry=DEFAULT_RETRY): + """Change some fields of a model. + + Use ``fields`` to specify which fields to update. At least one field + must be provided. If a field is listed in ``fields`` and is ``None`` + in ``model``, it will be deleted. + + If ``model.etag`` is not ``None``, the update will only succeed if + the model on the server has the same ETag. Thus reading a model with + ``get_model``, changing its fields, and then passing it to + ``update_model`` will ensure that the changes will only be saved if + no modifications to the model occurred since the read. + + Args: + model (google.cloud.bigquery.model.Model): The model to update. + fields (Sequence[str]): + The fields of ``model`` to change, spelled as the Model + properties (e.g. "friendly_name"). + retry (google.api_core.retry.Retry): + (Optional) A description of how to retry the API call. + + Returns: + google.cloud.bigquery.model.Model: + The model resource returned from the API call. + """ + partial = model._build_resource(fields) + if model.etag: + headers = {"If-Match": model.etag} + else: + headers = None + api_response = self._call_api( + retry, method="PATCH", path=model.path, data=partial, headers=headers + ) + return Model.from_api_repr(api_response) + def update_table(self, table, fields, retry=DEFAULT_RETRY): """Change some fields of a table. @@ -525,6 +589,64 @@ def update_table(self, table, fields, retry=DEFAULT_RETRY): ) return Table.from_api_repr(api_response) + def list_models( + self, dataset, max_results=None, page_token=None, retry=DEFAULT_RETRY + ): + """List models in the dataset. 
+ + See + https://cloud.google.com/bigquery/docs/reference/rest/v2/models/list + + Args: + dataset (Union[ \ + :class:`~google.cloud.bigquery.dataset.Dataset`, \ + :class:`~google.cloud.bigquery.dataset.DatasetReference`, \ + str, \ + ]): + A reference to the dataset whose models to list from the + BigQuery API. If a string is passed in, this method attempts + to create a dataset reference from a string using + :func:`google.cloud.bigquery.dataset.DatasetReference.from_string`. + max_results (int): + (Optional) Maximum number of models to return. If not passed, + defaults to a value set by the API. + page_token (str): + (Optional) Token representing a cursor into the models. If + not passed, the API will return the first page of models. The + token marks the beginning of the iterator to be returned and + the value of the ``page_token`` can be accessed at + ``next_page_token`` of the + :class:`~google.api_core.page_iterator.HTTPIterator`. + retry (:class:`google.api_core.retry.Retry`): + (Optional) How to retry the RPC. + + Returns: + google.api_core.page_iterator.Iterator: + Iterator of + :class:`~google.cloud.bigquery.model.Model` contained + within the requested dataset. + """ + if isinstance(dataset, str): + dataset = DatasetReference.from_string( + dataset, default_project=self.project + ) + + if not isinstance(dataset, (Dataset, DatasetReference)): + raise TypeError("dataset must be a Dataset, DatasetReference, or string") + + path = "%s/models" % dataset.path + result = page_iterator.HTTPIterator( + client=self, + api_request=functools.partial(self._call_api, retry), + path=path, + item_to_value=_item_to_model, + items_key="models", + page_token=page_token, + max_results=max_results, + ) + result.dataset = dataset + return result + def list_tables( self, dataset, max_results=None, page_token=None, retry=DEFAULT_RETRY ): @@ -631,6 +753,40 @@ def delete_dataset( if not not_found_ok: raise + def delete_model(self, model, retry=DEFAULT_RETRY, not_found_ok=False): + """Delete a model + + See + https://cloud.google.com/bigquery/docs/reference/rest/v2/models/delete + + Args: + model (Union[ \ + :class:`~google.cloud.bigquery.model.Model`, \ + :class:`~google.cloud.bigquery.model.ModelReference`, \ + str, \ + ]): + A reference to the model to delete. If a string is passed in, + this method attempts to create a model reference from a + string using + :func:`google.cloud.bigquery.model.ModelReference.from_string`. + retry (:class:`google.api_core.retry.Retry`): + (Optional) How to retry the RPC. + not_found_ok (bool): + Defaults to ``False``. If ``True``, ignore "not found" errors + when deleting the model. + """ + if isinstance(model, str): + model = ModelReference.from_string(model, default_project=self.project) + + if not isinstance(model, (Model, ModelReference)): + raise TypeError("model must be a Model or a ModelReference") + + try: + self._call_api(retry, method="DELETE", path=model.path) + except google.api_core.exceptions.NotFound: + if not not_found_ok: + raise + def delete_table(self, table, retry=DEFAULT_RETRY, not_found_ok=False): """Delete a table @@ -1810,6 +1966,21 @@ def _item_to_job(iterator, resource): return iterator.client.job_from_resource(resource) +def _item_to_model(iterator, resource): + """Convert a JSON model to the native object. + + Args: + iterator (google.api_core.page_iterator.Iterator): + The iterator that is currently in use. + resource (dict): + An item to be converted to a model. + + Returns: + google.cloud.bigquery.model.Model: The next model in the page. 
+ """ + return Model.from_api_repr(resource) + + def _item_to_table(iterator, resource): """Convert a JSON table to the native object. diff --git a/bigquery/google/cloud/bigquery/model.py b/bigquery/google/cloud/bigquery/model.py index 5d9ec810b2d0..c89727cf1fb5 100644 --- a/bigquery/google/cloud/bigquery/model.py +++ b/bigquery/google/cloud/bigquery/model.py @@ -16,12 +16,13 @@ """Define resources for the BigQuery ML Models API.""" -import datetime +import copy from google.protobuf import json_format import six import google.cloud._helpers +from google.api_core import datetime_helpers from google.cloud.bigquery import _helpers from google.cloud.bigquery_v2 import types @@ -83,6 +84,26 @@ def reference(self): ref._proto = self._proto.model_reference return ref + @property + def project(self): + """str: Project bound to the model""" + return self.reference.project + + @property + def dataset_id(self): + """str: ID of dataset containing the model.""" + return self.reference.dataset_id + + @property + def model_id(self): + """str: The model ID.""" + return self.reference.model_id + + @property + def path(self): + """str: URL path for the model's APIs.""" + return self.reference.path + @property def location(self): """str: The geographic location where the model resides. This value @@ -192,7 +213,7 @@ def expires(self): @expires.setter def expires(self, value): if value is not None: - value = google.cloud._helpers._millis_from_datetime(value) + value = str(google.cloud._helpers._millis_from_datetime(value)) self._properties["expirationTime"] = value @property @@ -247,6 +268,17 @@ def from_api_repr(cls, resource): google.cloud.bigquery.model.Model: Model parsed from ``resource``. """ this = cls(None) + + # Convert from millis-from-epoch to timestamp well-known type. + # TODO: Remove this hack once CL 238585470 hits prod. + resource = copy.deepcopy(resource) + for training_run in resource.get("trainingRuns", ()): + start_time = training_run.get("startTime") + if not start_time or "-" in start_time: # Already right format? + continue + start_time = datetime_helpers.from_microseconds(1e3 * float(start_time)) + training_run["startTime"] = datetime_helpers.to_rfc3339(start_time) + this._proto = json_format.ParseDict(resource, types.Model()) for key in six.itervalues(cls._PROPERTY_TO_API_FIELD): # Leave missing keys unset. This allows us to use setdefault in the @@ -288,6 +320,15 @@ def model_id(self): """str: The model ID.""" return self._proto.model_id + @property + def path(self): + """str: URL path for the model's APIs.""" + return "/projects/%s/datasets/%s/models/%s" % ( + self._proto.project_id, + self._proto.dataset_id, + self._proto.model_id, + ) + @classmethod def from_api_repr(cls, resource): """Factory: construct a model reference given its API representation diff --git a/bigquery/noxfile.py b/bigquery/noxfile.py index 82846604306e..cb784eae9fa3 100644 --- a/bigquery/noxfile.py +++ b/bigquery/noxfile.py @@ -129,6 +129,8 @@ def snippets(session): # Run py.test against the snippets tests. 
session.run( 'py.test', os.path.join('docs', 'snippets.py'), *session.posargs) + session.run( + 'py.test', os.path.join('samples'), *session.posargs) @nox.session(python='3.6') @@ -178,6 +180,7 @@ def blacken(session): session.run( "black", "google", + "samples", "tests", "docs", ) diff --git a/bigquery/samples/__init__.py b/bigquery/samples/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/bigquery/samples/delete_model.py b/bigquery/samples/delete_model.py new file mode 100644 index 000000000000..dfe23cd7ef29 --- /dev/null +++ b/bigquery/samples/delete_model.py @@ -0,0 +1,31 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def delete_model(client, model_id): + """Sample ID: go/samples-tracker/1534""" + + # [START bigquery_delete_model] + from google.cloud import bigquery + + # TODO(developer): Construct a BigQuery client object. + # client = bigquery.Client() + + # TODO(developer): Set model_id to the ID of the model to fetch. + # model_id = 'your-project.your_dataset.your_model' + + client.delete_model(model_id) + # [END bigquery_delete_model] + + print("Deleted model '{}'.".format(model_id)) diff --git a/bigquery/samples/get_model.py b/bigquery/samples/get_model.py new file mode 100644 index 000000000000..8e43e53ec450 --- /dev/null +++ b/bigquery/samples/get_model.py @@ -0,0 +1,35 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_model(client, model_id): + """Sample ID: go/samples-tracker/1510""" + + # [START bigquery_get_model] + from google.cloud import bigquery + + # TODO(developer): Construct a BigQuery client object. + # client = bigquery.Client() + + # TODO(developer): Set model_id to the ID of the model to fetch. + # model_id = 'your-project.your_dataset.your_model' + + model = client.get_model(model_id) + + full_model_id = "{}.{}.{}".format(model.project, model.dataset_id, model.model_id) + friendly_name = model.friendly_name + print( + "Got model '{}' with friendly_name '{}'.".format(full_model_id, friendly_name) + ) + # [END bigquery_get_model] diff --git a/bigquery/samples/list_models.py b/bigquery/samples/list_models.py new file mode 100644 index 000000000000..cb6e4fb5569f --- /dev/null +++ b/bigquery/samples/list_models.py @@ -0,0 +1,38 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def list_models(client, dataset_id): + """Sample ID: go/samples-tracker/1512""" + + # [START bigquery_list_models] + from google.cloud import bigquery + + # TODO(developer): Construct a BigQuery client object. + # client = bigquery.Client() + + # TODO(developer): Set dataset_id to the ID of the dataset that contains + # the models you are listing. + # dataset_id = 'your-project.your_dataset' + + models = client.list_models(dataset_id) + + print("Models contained in '{}':".format(dataset_id)) + for model in models: + full_model_id = "{}.{}.{}".format( + model.project, model.dataset_id, model.model_id + ) + friendly_name = model.friendly_name + print("{}: friendly_name='{}'".format(full_model_id, friendly_name)) + # [END bigquery_list_models] diff --git a/bigquery/samples/tests/__init__.py b/bigquery/samples/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/bigquery/samples/tests/conftest.py b/bigquery/samples/tests/conftest.py new file mode 100644 index 000000000000..1543e1fdcd0a --- /dev/null +++ b/bigquery/samples/tests/conftest.py @@ -0,0 +1,62 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import uuid + +import pytest + +from google.cloud import bigquery + + +@pytest.fixture(scope="module") +def client(): + return bigquery.Client() + + +@pytest.fixture +def dataset_id(client): + now = datetime.datetime.now() + dataset_id = "python_samples_{}_{}".format( + now.strftime("%Y%m%d%H%M%S"), uuid.uuid4().hex[:8] + ) + dataset = client.create_dataset(dataset_id) + yield "{}.{}".format(dataset.project, dataset.dataset_id) + client.delete_dataset(dataset, delete_contents=True) + + +@pytest.fixture +def model_id(client, dataset_id): + model_id = "{}.{}".format(dataset_id, uuid.uuid4().hex) + + # The only way to create a model resource is via SQL. + # Use a very small dataset (2 points), to train a model quickly. 
+ sql = """ + CREATE MODEL `{}` + OPTIONS ( + model_type='linear_reg', + max_iteration=1, + learn_rate=0.4, + learn_rate_strategy='constant' + ) AS ( + SELECT 'a' AS f1, 2.0 AS label + UNION ALL + SELECT 'b' AS f1, 3.8 AS label + ) + """.format( + model_id + ) + + client.query(sql).result() + return model_id diff --git a/bigquery/samples/tests/test_model_samples.py b/bigquery/samples/tests/test_model_samples.py new file mode 100644 index 000000000000..d7b06a92a3e1 --- /dev/null +++ b/bigquery/samples/tests/test_model_samples.py @@ -0,0 +1,39 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .. import delete_model +from .. import get_model +from .. import list_models +from .. import update_model + + +def test_model_samples(capsys, client, dataset_id, model_id): + """Since creating a model is a long operation, test all model samples in + the same test, following a typical end-to-end flow. + """ + get_model.get_model(client, model_id) + out, err = capsys.readouterr() + assert model_id in out + + list_models.list_models(client, dataset_id) + out, err = capsys.readouterr() + assert "Models contained in '{}':".format(dataset_id) in out + + update_model.update_model(client, model_id) + out, err = capsys.readouterr() + assert "This model was modified from a Python program." in out + + delete_model.delete_model(client, model_id) + out, err = capsys.readouterr() + assert "Deleted model '{}'.".format(model_id) in out diff --git a/bigquery/samples/update_model.py b/bigquery/samples/update_model.py new file mode 100644 index 000000000000..2440066ae1ec --- /dev/null +++ b/bigquery/samples/update_model.py @@ -0,0 +1,38 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def update_model(client, model_id): + """Sample ID: go/samples-tracker/1533""" + + # [START bigquery_update_model_description] + from google.cloud import bigquery + + # TODO(developer): Construct a BigQuery client object. + # client = bigquery.Client() + + # TODO(developer): Set model_id to the ID of the model to fetch. + # model_id = 'your-project.your_dataset.your_model' + + model = client.get_model(model_id) + model.description = "This model was modified from a Python program." 
+ model = client.update_model(model, ["description"]) + + full_model_id = "{}.{}.{}".format(model.project, model.dataset_id, model.model_id) + print( + "Updated model '{}' with description '{}'.".format( + full_model_id, model.description + ) + ) + # [END bigquery_update_model_description] diff --git a/bigquery/tests/unit/model/test_model.py b/bigquery/tests/unit/model/test_model.py index 87da266a76bc..2086c333486d 100644 --- a/bigquery/tests/unit/model/test_model.py +++ b/bigquery/tests/unit/model/test_model.py @@ -91,12 +91,22 @@ def test_from_api_repr(target_class): google.cloud._helpers._datetime_to_rfc3339(modified_time) ), }, + { + "trainingOptions": {"initialLearnRate": 0.25}, + # Allow milliseconds since epoch format. + # TODO: Remove this hack once CL 238585470 hits prod. + "startTime": str(google.cloud._helpers._millis(expiration_time)), + }, ], "featureColumns": [], } got = target_class.from_api_repr(resource) + assert got.project == "my-project" + assert got.dataset_id == "my_dataset" + assert got.model_id == "my_model" assert got.reference == ModelReference.from_string("my-project.my_dataset.my_model") + assert got.path == "/projects/my-project/datasets/my_dataset/models/my_model" assert got.location == "US" assert got.etag == "abcdefg" assert got.created == creation_time @@ -120,6 +130,13 @@ def test_from_api_repr(target_class): .replace(tzinfo=google.cloud._helpers.UTC) == modified_time ) + assert got.training_runs[2].training_options.initial_learn_rate == 0.25 + assert ( + got.training_runs[2] + .start_time.ToDatetime() + .replace(tzinfo=google.cloud._helpers.UTC) + == expiration_time + ) def test_from_api_repr_w_minimal_resource(target_class): diff --git a/bigquery/tests/unit/model/test_model_reference.py b/bigquery/tests/unit/model/test_model_reference.py index ea8a76fb14b6..0145c76f6ad0 100644 --- a/bigquery/tests/unit/model/test_model_reference.py +++ b/bigquery/tests/unit/model/test_model_reference.py @@ -34,6 +34,7 @@ def test_from_api_repr(target_class): assert got.project == "my-project" assert got.dataset_id == "my_dataset" assert got.model_id == "my_model" + assert got.path == "/projects/my-project/datasets/my_dataset/models/my_model" def test_to_api_repr(target_class): @@ -51,6 +52,9 @@ def test_from_string(target_class): assert got.project == "string-project" assert got.dataset_id == "string_dataset" assert got.model_id == "string_model" + assert got.path == ( + "/projects/string-project/datasets/string_dataset/models/string_model" + ) def test_from_string_legacy_string(target_class): diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py index 794b76a0a9f4..08648419393b 100644 --- a/bigquery/tests/unit/test_client.py +++ b/bigquery/tests/unit/test_client.py @@ -37,6 +37,7 @@ pyarrow = None import google.api_core.exceptions +import google.cloud._helpers from google.cloud.bigquery.dataset import DatasetReference @@ -81,6 +82,7 @@ class TestClient(unittest.TestCase): PROJECT = "PROJECT" DS_ID = "DATASET_ID" TABLE_ID = "TABLE_ID" + MODEL_ID = "MODEL_ID" TABLE_REF = DatasetReference(PROJECT, DS_ID).table(TABLE_ID) KMS_KEY_NAME = "projects/1/locations/global/keyRings/1/cryptoKeys/1" LOCATION = "us-central" @@ -1306,6 +1308,54 @@ def test_create_table_alreadyexists_w_exists_ok_true(self): ] ) + def test_get_model(self): + path = "projects/%s/datasets/%s/models/%s" % ( + self.PROJECT, + self.DS_ID, + self.MODEL_ID, + ) + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, 
_http=http) + resource = { + "modelReference": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "modelId": self.MODEL_ID, + } + } + conn = client._connection = _make_connection(resource) + + model_ref = client.dataset(self.DS_ID).model(self.MODEL_ID) + got = client.get_model(model_ref) + + conn.api_request.assert_called_once_with(method="GET", path="/%s" % path) + self.assertEqual(got.model_id, self.MODEL_ID) + + def test_get_model_w_string(self): + path = "projects/%s/datasets/%s/models/%s" % ( + self.PROJECT, + self.DS_ID, + self.MODEL_ID, + ) + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + resource = { + "modelReference": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "modelId": self.MODEL_ID, + } + } + conn = client._connection = _make_connection(resource) + + model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID) + got = client.get_model(model_id) + + conn.api_request.assert_called_once_with(method="GET", path="/%s" % path) + self.assertEqual(got.model_id, self.MODEL_ID) + def test_get_table(self): path = "projects/%s/datasets/%s/tables/%s" % ( self.PROJECT, @@ -1422,6 +1472,66 @@ def test_update_dataset_w_custom_property(self): self.assertEqual(dataset.project, self.PROJECT) self.assertEqual(dataset._properties["newAlphaProperty"], "unreleased property") + def test_update_model(self): + from google.cloud.bigquery.model import Model + + path = "projects/%s/datasets/%s/models/%s" % ( + self.PROJECT, + self.DS_ID, + self.MODEL_ID, + ) + description = "description" + title = "title" + expires = datetime.datetime( + 2012, 12, 21, 16, 0, 0, tzinfo=google.cloud._helpers.UTC + ) + resource = { + "modelReference": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "modelId": self.MODEL_ID, + }, + "description": description, + "etag": "etag", + "expirationTime": str(google.cloud._helpers._millis(expires)), + "friendlyName": title, + "labels": {"x": "y"}, + } + creds = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=creds) + conn = client._connection = _make_connection(resource, resource) + model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID) + model = Model(model_id) + model.description = description + model.friendly_name = title + model.expires = expires + model.labels = {"x": "y"} + + updated_model = client.update_model( + model, ["description", "friendly_name", "labels", "expires"] + ) + + sent = { + "description": description, + "expirationTime": str(google.cloud._helpers._millis(expires)), + "friendlyName": title, + "labels": {"x": "y"}, + } + conn.api_request.assert_called_once_with( + method="PATCH", data=sent, path="/" + path, headers=None + ) + self.assertEqual(updated_model.model_id, model.model_id) + self.assertEqual(updated_model.description, model.description) + self.assertEqual(updated_model.friendly_name, model.friendly_name) + self.assertEqual(updated_model.labels, model.labels) + self.assertEqual(updated_model.expires, model.expires) + + # ETag becomes If-Match header. 
+ model._proto.etag = "etag" + client.update_model(model, []) + req = conn.api_request.call_args + self.assertEqual(req[1]["headers"]["If-Match"], "etag") + def test_update_table(self): from google.cloud.bigquery.table import Table, SchemaField @@ -1773,6 +1883,78 @@ def test_list_tables_empty(self): method="GET", path=path, query_params={} ) + def test_list_models_empty(self): + path = "/projects/{}/datasets/{}/models".format(self.PROJECT, self.DS_ID) + creds = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=creds) + conn = client._connection = _make_connection({}) + + dataset_id = "{}.{}".format(self.PROJECT, self.DS_ID) + iterator = client.list_models(dataset_id) + page = six.next(iterator.pages) + models = list(page) + token = iterator.next_page_token + + self.assertEqual(models, []) + self.assertIsNone(token) + conn.api_request.assert_called_once_with( + method="GET", path=path, query_params={} + ) + + def test_list_models_defaults(self): + from google.cloud.bigquery.model import Model + + MODEL_1 = "model_one" + MODEL_2 = "model_two" + PATH = "projects/%s/datasets/%s/models" % (self.PROJECT, self.DS_ID) + TOKEN = "TOKEN" + DATA = { + "nextPageToken": TOKEN, + "models": [ + { + "modelReference": { + "modelId": MODEL_1, + "datasetId": self.DS_ID, + "projectId": self.PROJECT, + } + }, + { + "modelReference": { + "modelId": MODEL_2, + "datasetId": self.DS_ID, + "projectId": self.PROJECT, + } + }, + ], + } + + creds = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=creds) + conn = client._connection = _make_connection(DATA) + dataset = client.dataset(self.DS_ID) + + iterator = client.list_models(dataset) + self.assertIs(iterator.dataset, dataset) + page = six.next(iterator.pages) + models = list(page) + token = iterator.next_page_token + + self.assertEqual(len(models), len(DATA["models"])) + for found, expected in zip(models, DATA["models"]): + self.assertIsInstance(found, Model) + self.assertEqual(found.model_id, expected["modelReference"]["modelId"]) + self.assertEqual(token, TOKEN) + + conn.api_request.assert_called_once_with( + method="GET", path="/%s" % PATH, query_params={} + ) + + def test_list_models_wrong_type(self): + creds = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=creds) + with self.assertRaises(TypeError): + client.list_models(client.dataset(self.DS_ID).model("foo")) + def test_list_tables_defaults(self): from google.cloud.bigquery.table import TableListItem @@ -1960,6 +2142,68 @@ def test_delete_dataset_w_not_found_ok_true(self): conn.api_request.assert_called_with(method="DELETE", path=path, query_params={}) + def test_delete_model(self): + from google.cloud.bigquery.model import Model + + path = "projects/%s/datasets/%s/models/%s" % ( + self.PROJECT, + self.DS_ID, + self.MODEL_ID, + ) + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID) + models = ( + model_id, + client.dataset(self.DS_ID).model(self.MODEL_ID), + Model(model_id), + ) + conn = client._connection = _make_connection(*([{}] * len(models))) + + for arg in models: + client.delete_model(arg) + conn.api_request.assert_called_with(method="DELETE", path="/%s" % path) + + def test_delete_model_w_wrong_type(self): + creds = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=creds) + with self.assertRaises(TypeError): + 
client.delete_model(client.dataset(self.DS_ID)) + + def test_delete_model_w_not_found_ok_false(self): + path = "/projects/{}/datasets/{}/models/{}".format( + self.PROJECT, self.DS_ID, self.MODEL_ID + ) + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = _make_connection( + google.api_core.exceptions.NotFound("model not found") + ) + + with self.assertRaises(google.api_core.exceptions.NotFound): + client.delete_model("{}.{}".format(self.DS_ID, self.MODEL_ID)) + + conn.api_request.assert_called_with(method="DELETE", path=path) + + def test_delete_model_w_not_found_ok_true(self): + path = "/projects/{}/datasets/{}/models/{}".format( + self.PROJECT, self.DS_ID, self.MODEL_ID + ) + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = _make_connection( + google.api_core.exceptions.NotFound("model not found") + ) + + client.delete_model( + "{}.{}".format(self.DS_ID, self.MODEL_ID), not_found_ok=True + ) + + conn.api_request.assert_called_with(method="DELETE", path=path) + def test_delete_table(self): from google.cloud.bigquery.table import Table
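For reference, a minimal end-to-end usage sketch of the new model methods introduced in client.py, mirroring the samples added under bigquery/samples/. The project, dataset, and model IDs below are placeholders, and the model is assumed to already exist (for example, trained via a CREATE MODEL query as in the conftest.py fixture above):

from google.cloud import bigquery

client = bigquery.Client()

# Placeholder ID for an existing model.
model_id = "your-project.your_dataset.your_model"

# Fetch the model; a string ID is converted via ModelReference.from_string().
model = client.get_model(model_id)
print(model.model_id, model.friendly_name)

# Patch a single field; only the fields listed are sent in the PATCH request.
model.description = "Updated via the usage sketch."
model = client.update_model(model, ["description"])

# List every model in the containing dataset.
for found in client.list_models("your-project.your_dataset"):
    print(found.model_id)

# Delete the model, ignoring a "not found" error if it is already gone.
client.delete_model(model_id, not_found_ok=True)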