Skip to content

Commit

Permalink
Add client methods for Models API. (googleapis#494)
Browse files Browse the repository at this point in the history
* Add client methods for Models API.

* Adds hack to work around the milliseconds format of
  model.trainingRun.startTime.
* Adds code samples for Models API, which double as system tests.
  • Loading branch information
tswast committed Mar 25, 2019
1 parent ecc0f18 commit 8d29dcd
Show file tree
Hide file tree
Showing 14 changed files with 725 additions and 2 deletions.
171 changes: 171 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from google.cloud.bigquery.dataset import DatasetListItem
from google.cloud.bigquery.dataset import DatasetReference
from google.cloud.bigquery import job
from google.cloud.bigquery.model import Model
from google.cloud.bigquery.model import ModelReference
from google.cloud.bigquery.query import _QueryResults
from google.cloud.bigquery.retry import DEFAULT_RETRY
from google.cloud.bigquery.table import Table
Expand Down Expand Up @@ -427,6 +429,33 @@ def get_dataset(self, dataset_ref, retry=DEFAULT_RETRY):
api_response = self._call_api(retry, method="GET", path=dataset_ref.path)
return Dataset.from_api_repr(api_response)

def get_model(self, model_ref, retry=DEFAULT_RETRY):
    """Retrieve a model resource from the BigQuery API.

    Args:
        model_ref (Union[google.cloud.bigquery.model.ModelReference, str]):
            Identifies the model to fetch. A string is turned into a
            reference with
            :func:`google.cloud.bigquery.model.ModelReference.from_string`,
            using this client's project as the default project.
        retry (:class:`google.api_core.retry.Retry`):
            (Optional) How to retry the RPC.

    Returns:
        google.cloud.bigquery.model.Model:
            The ``Model`` built from the API response.
    """
    reference = model_ref
    if isinstance(reference, str):
        # Strings are parsed relative to the client's own project.
        reference = ModelReference.from_string(
            reference, default_project=self.project
        )

    resource = self._call_api(retry, method="GET", path=reference.path)
    return Model.from_api_repr(resource)

def get_table(self, table_ref, retry=DEFAULT_RETRY):
"""Fetch the table referenced by ``table_ref``.
Expand Down Expand Up @@ -490,6 +519,41 @@ def update_dataset(self, dataset, fields, retry=DEFAULT_RETRY):
)
return Dataset.from_api_repr(api_response)

def update_model(self, model, fields, retry=DEFAULT_RETRY):
    """Patch selected fields of a model.

    ``fields`` names the properties to send. At least one field must be
    provided. A listed field whose value on ``model`` is ``None`` will be
    deleted.

    When ``model.etag`` is set, it is sent as an ``If-Match`` header so the
    update only succeeds if the server-side model still has that ETag.
    Reading a model with ``get_model``, modifying it, and passing it here
    therefore only saves the changes if nothing else modified the model
    since the read.

    Args:
        model (google.cloud.bigquery.model.Model): The model to update.
        fields (Sequence[str]):
            The fields of ``model`` to change, spelled as the Model
            properties (e.g. "friendly_name").
        retry (google.api_core.retry.Retry):
            (Optional) A description of how to retry the API call.

    Returns:
        google.cloud.bigquery.model.Model:
            The model resource returned from the API call.
    """
    patch_body = model._build_resource(fields)
    # Only attach the precondition header when an ETag is known.
    headers = {"If-Match": model.etag} if model.etag else None
    resource = self._call_api(
        retry,
        method="PATCH",
        path=model.path,
        data=patch_body,
        headers=headers,
    )
    return Model.from_api_repr(resource)

def update_table(self, table, fields, retry=DEFAULT_RETRY):
"""Change some fields of a table.
Expand Down Expand Up @@ -525,6 +589,64 @@ def update_table(self, table, fields, retry=DEFAULT_RETRY):
)
return Table.from_api_repr(api_response)

def list_models(
    self, dataset, max_results=None, page_token=None, retry=DEFAULT_RETRY
):
    """Iterate over the models contained in a dataset.

    See
    https://cloud.google.com/bigquery/docs/reference/rest/v2/models/list

    Args:
        dataset (Union[ \
            :class:`~google.cloud.bigquery.dataset.Dataset`, \
            :class:`~google.cloud.bigquery.dataset.DatasetReference`, \
            str, \
        ]):
            The dataset whose models are listed. A string is turned into
            a dataset reference with
            :func:`google.cloud.bigquery.dataset.DatasetReference.from_string`,
            using this client's project as the default project.
        max_results (int):
            (Optional) Maximum number of models to return. If not passed,
            defaults to a value set by the API.
        page_token (str):
            (Optional) Token representing a cursor into the models. If
            not passed, the API will return the first page of models. The
            token marks the beginning of the iterator to be returned and
            the value of the ``page_token`` can be accessed at
            ``next_page_token`` of the
            :class:`~google.api_core.page_iterator.HTTPIterator`.
        retry (:class:`google.api_core.retry.Retry`):
            (Optional) How to retry the RPC.

    Returns:
        google.api_core.page_iterator.Iterator:
            Iterator of :class:`~google.cloud.bigquery.model.Model`
            contained within the requested dataset. The resolved dataset
            is attached to the iterator as its ``dataset`` attribute.

    Raises:
        TypeError: If ``dataset`` is not a ``Dataset``, a
            ``DatasetReference``, or a string.
    """
    dataset_ref = dataset
    if isinstance(dataset_ref, str):
        dataset_ref = DatasetReference.from_string(
            dataset_ref, default_project=self.project
        )

    if not isinstance(dataset_ref, (Dataset, DatasetReference)):
        raise TypeError("dataset must be a Dataset, DatasetReference, or string")

    # Bind the retry policy once so every page request reuses it.
    api_request = functools.partial(self._call_api, retry)
    models_path = "%s/models" % dataset_ref.path
    iterator = page_iterator.HTTPIterator(
        client=self,
        api_request=api_request,
        path=models_path,
        item_to_value=_item_to_model,
        items_key="models",
        page_token=page_token,
        max_results=max_results,
    )
    iterator.dataset = dataset_ref
    return iterator

def list_tables(
self, dataset, max_results=None, page_token=None, retry=DEFAULT_RETRY
):
Expand Down Expand Up @@ -631,6 +753,40 @@ def delete_dataset(
if not not_found_ok:
raise

def delete_model(self, model, retry=DEFAULT_RETRY, not_found_ok=False):
    """Delete a model.

    See
    https://cloud.google.com/bigquery/docs/reference/rest/v2/models/delete

    Args:
        model (Union[ \
            :class:`~google.cloud.bigquery.model.Model`, \
            :class:`~google.cloud.bigquery.model.ModelReference`, \
            str, \
        ]):
            The model to delete. A string is turned into a model
            reference with
            :func:`google.cloud.bigquery.model.ModelReference.from_string`,
            using this client's project as the default project.
        retry (:class:`google.api_core.retry.Retry`):
            (Optional) How to retry the RPC.
        not_found_ok (bool):
            Defaults to ``False``. If ``True``, ignore "not found" errors
            when deleting the model.

    Raises:
        TypeError: If ``model`` is neither a string, a ``Model``, nor a
            ``ModelReference``.
        google.api_core.exceptions.NotFound: If the API reports the model
            as missing and ``not_found_ok`` is false.
    """
    target = model
    if isinstance(target, str):
        target = ModelReference.from_string(target, default_project=self.project)

    if not isinstance(target, (Model, ModelReference)):
        raise TypeError("model must be a Model or a ModelReference")

    try:
        self._call_api(retry, method="DELETE", path=target.path)
    except google.api_core.exceptions.NotFound:
        # A missing model is only tolerated when the caller opted in.
        if not_found_ok:
            return
        raise

def delete_table(self, table, retry=DEFAULT_RETRY, not_found_ok=False):
"""Delete a table
Expand Down Expand Up @@ -1810,6 +1966,21 @@ def _item_to_job(iterator, resource):
return iterator.client.job_from_resource(resource)


def _item_to_model(iterator, resource):
    """Build a ``Model`` from one JSON item of a models listing.

    Args:
        iterator (google.api_core.page_iterator.Iterator):
            The iterator that is currently in use. Not used by this
            converter.
        resource (dict):
            An item to be converted to a model.

    Returns:
        google.cloud.bigquery.model.Model: The next model in the page.
    """
    model = Model.from_api_repr(resource)
    return model


def _item_to_table(iterator, resource):
"""Convert a JSON table to the native object.
Expand Down
45 changes: 43 additions & 2 deletions bigquery/google/cloud/bigquery/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

"""Define resources for the BigQuery ML Models API."""

import datetime
import copy

from google.protobuf import json_format
import six

import google.cloud._helpers
from google.api_core import datetime_helpers
from google.cloud.bigquery import _helpers
from google.cloud.bigquery_v2 import types

Expand Down Expand Up @@ -83,6 +84,26 @@ def reference(self):
ref._proto = self._proto.model_reference
return ref

@property
def project(self):
    """str: Project of the model, read from its reference."""
    model_ref = self.reference
    return model_ref.project

@property
def dataset_id(self):
    """str: ID of the dataset containing the model, read from its reference."""
    model_ref = self.reference
    return model_ref.dataset_id

@property
def model_id(self):
    """str: The model ID, read from the model's reference."""
    model_ref = self.reference
    return model_ref.model_id

@property
def path(self):
    """str: URL path for the model's APIs, as given by its reference."""
    model_ref = self.reference
    return model_ref.path

@property
def location(self):
"""str: The geographic location where the model resides. This value
Expand Down Expand Up @@ -192,7 +213,7 @@ def expires(self):
@expires.setter
def expires(self, value):
    """Set the model's expiration time.

    Args:
        value (Optional[datetime.datetime]):
            The new expiration time. ``None`` is stored as-is under the
            ``"expirationTime"`` key.
    """
    if value is not None:
        # Convert exactly once: datetime -> milliseconds since epoch ->
        # string. Converting to millis first and then passing that integer
        # to ``_millis_from_datetime`` again would hand it a non-datetime.
        # NOTE(review): the string form presumably matches how the API
        # represents 64-bit integers in JSON -- confirm.
        value = str(google.cloud._helpers._millis_from_datetime(value))
    self._properties["expirationTime"] = value

@property
Expand Down Expand Up @@ -247,6 +268,17 @@ def from_api_repr(cls, resource):
google.cloud.bigquery.model.Model: Model parsed from ``resource``.
"""
this = cls(None)

# Convert from millis-from-epoch to timestamp well-known type.
# TODO: Remove this hack once CL 238585470 hits prod.
resource = copy.deepcopy(resource)
for training_run in resource.get("trainingRuns", ()):
start_time = training_run.get("startTime")
if not start_time or "-" in start_time: # Already right format?
continue
start_time = datetime_helpers.from_microseconds(1e3 * float(start_time))
training_run["startTime"] = datetime_helpers.to_rfc3339(start_time)

this._proto = json_format.ParseDict(resource, types.Model())
for key in six.itervalues(cls._PROPERTY_TO_API_FIELD):
# Leave missing keys unset. This allows us to use setdefault in the
Expand Down Expand Up @@ -288,6 +320,15 @@ def model_id(self):
"""str: The model ID."""
return self._proto.model_id

@property
def path(self):
    """str: URL path for the model's APIs."""
    proto = self._proto
    id_parts = (proto.project_id, proto.dataset_id, proto.model_id)
    return "/projects/%s/datasets/%s/models/%s" % id_parts

@classmethod
def from_api_repr(cls, resource):
"""Factory: construct a model reference given its API representation
Expand Down
3 changes: 3 additions & 0 deletions bigquery/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def snippets(session):
# Run py.test against the snippets tests.
session.run(
'py.test', os.path.join('docs', 'snippets.py'), *session.posargs)
session.run(
'py.test', os.path.join('samples'), *session.posargs)


@nox.session(python='3.6')
Expand Down Expand Up @@ -178,6 +180,7 @@ def blacken(session):
session.run(
"black",
"google",
"samples",
"tests",
"docs",
)
Expand Down
Empty file added bigquery/samples/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions bigquery/samples/delete_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def delete_model(client, model_id):
    """Delete a model through ``client`` and print a confirmation.

    Sample ID: go/samples-tracker/1534

    Args:
        client: Object whose ``delete_model`` method is called with
            ``model_id`` (a ``bigquery.Client`` in the published sample).
        model_id (str): ID of the model to delete, e.g.
            ``'your-project.your_dataset.your_model'``.
    """

    # [START bigquery_delete_model]
    from google.cloud import bigquery

    # TODO(developer): Construct a BigQuery client object.
    # client = bigquery.Client()

    # TODO(developer): Set model_id to the ID of the model to delete.
    # model_id = 'your-project.your_dataset.your_model'

    client.delete_model(model_id)
    # [END bigquery_delete_model]

    print("Deleted model '{}'.".format(model_id))
35 changes: 35 additions & 0 deletions bigquery/samples/get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_model(client, model_id):
    """Fetch a model through ``client`` and print its ID and friendly name.

    Sample ID: go/samples-tracker/1510
    """

    # [START bigquery_get_model]
    from google.cloud import bigquery

    # TODO(developer): Construct a BigQuery client object.
    # client = bigquery.Client()

    # TODO(developer): Set model_id to the ID of the model to fetch.
    # model_id = 'your-project.your_dataset.your_model'

    model = client.get_model(model_id)

    id_parts = (model.project, model.dataset_id, model.model_id)
    full_model_id = "{}.{}.{}".format(*id_parts)
    message = "Got model '{}' with friendly_name '{}'.".format(
        full_model_id, model.friendly_name
    )
    print(message)
    # [END bigquery_get_model]
Loading

0 comments on commit 8d29dcd

Please sign in to comment.