Update automl_tables_predict.py with batch_predict_bq sample #4142

Merged 10 commits on Jul 17, 2020
51 changes: 51 additions & 0 deletions tables/automl/automl_tables_predict.py
@@ -81,6 +81,49 @@ def predict(
    # [END automl_tables_predict]


def batch_predict_bq(
    project_id,
    compute_region,
    model_display_name,
    bq_input_uri,
    bq_output_uri,
):
    """Make a batch of predictions using BigQuery input and output."""
    # [START automl_tables_batch_predict_bq]
    # TODO(developer): Uncomment and set the following variables
    # project_id = 'PROJECT_ID_HERE'
    # compute_region = 'COMPUTE_REGION_HERE'
    # model_display_name = 'MODEL_DISPLAY_NAME_HERE'
    # bq_input_uri = 'bq://my-project.my-dataset.my-table'
    # bq_output_uri = 'bq://my-project'

    from google.cloud import automl_v1beta1 as automl

    client = automl.TablesClient(project=project_id, region=compute_region)

    # Query model
    response = client.batch_predict(
        bigquery_input_uri=bq_input_uri,
        bigquery_output_uri=bq_output_uri,
        model_display_name=model_display_name,
    )
    print("Making batch prediction... ")
    # `response` is an async operation descriptor.
    # You can register a callback to run when the operation completes
    # via `add_done_callback`:
    # def callback(operation_future):
    #     result = operation_future.result()
    # response.add_done_callback(callback)
    #
    # Or block the thread while polling for the operation's result:
    response.result()
    # AutoML writes the predictions to a newly created BigQuery dataset
    # whose name follows the pattern "prediction_" + model_id + "_" + timestamp.
    # Here is how to get that dataset name:
    dataset_name = response.metadata.batch_predict_details.output_info.bigquery_output_dataset

    print("Batch prediction complete.\nResults are in '{}' dataset.\n{}".format(
        dataset_name, response.metadata))

    # [END automl_tables_batch_predict_bq]
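
A minimal usage sketch of the new sample, assuming a deployed model; the project, region, model name, and BigQuery URIs below are placeholder values mirroring the TODO comments above, not values from this PR:

# Hypothetical invocation; replace the placeholders with real resource names.
batch_predict_bq(
    project_id="my-project",
    compute_region="us-central1",
    model_display_name="my_model",
    bq_input_uri="bq://my-project.my-dataset.my-table",
    bq_output_uri="bq://my-project",
)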


def batch_predict(
    project_id,
    compute_region,
@@ -108,7 +151,15 @@ def batch_predict(
        model_display_name=model_display_name,
    )
    print("Making batch prediction... ")
    # `response` is an async operation descriptor.
    # You can register a callback to run when the operation completes
    # via `add_done_callback`:
    # def callback(operation_future):
    #     result = operation_future.result()
    # response.add_done_callback(callback)
    #
    # Or block the thread while polling for the operation's result:
    response.result()

    print("Batch prediction complete.\n{}".format(response.metadata))

    # [END automl_tables_batch_predict]
12 changes: 12 additions & 0 deletions tables/automl/batch_predict_test.py
@@ -30,6 +30,8 @@
STATIC_MODEL = model_test.STATIC_MODEL
GCS_INPUT = "gs://{}-automl-tables-test/bank-marketing.csv".format(PROJECT)
GCS_OUTPUT = "gs://{}-automl-tables-test/TABLE_TEST_OUTPUT/".format(PROJECT)
BQ_INPUT = "bq://{}.automl_test.bank_marketing".format(PROJECT)
BQ_OUTPUT = "bq://{}".format(PROJECT)


@pytest.mark.slow
@@ -42,6 +44,16 @@ def test_batch_predict(capsys):
    assert "Batch prediction complete" in out


@pytest.mark.slow
def test_batch_predict_bq(capsys):
    ensure_model_online()
    automl_tables_predict.batch_predict_bq(
        PROJECT, REGION, STATIC_MODEL, BQ_INPUT, BQ_OUTPUT
    )
    out, _ = capsys.readouterr()
    assert "Batch prediction complete" in out


def ensure_model_online():
    model = model_test.ensure_model_ready()
    if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: