diff --git a/tables/automl/automl_tables_predict.py b/tables/automl/automl_tables_predict.py
index 786f80fcb856..e965427258af 100644
--- a/tables/automl/automl_tables_predict.py
+++ b/tables/automl/automl_tables_predict.py
@@ -81,6 +81,49 @@ def predict(
     # [END automl_tables_predict]
 
 
+def batch_predict_bq(
+    project_id,
+    compute_region,
+    model_display_name,
+    bq_input_uri,
+    bq_output_uri,
+):
+    """Make a batch of predictions."""
+    # [START automl_tables_batch_predict_bq]
+    # TODO(developer): Uncomment and set the following variables
+    # project_id = 'PROJECT_ID_HERE'
+    # compute_region = 'COMPUTE_REGION_HERE'
+    # model_display_name = 'MODEL_DISPLAY_NAME_HERE'
+    # bq_input_uri = 'bq://my-project.my-dataset.my-table'
+    # bq_output_uri = 'bq://my-project'
+
+    from google.cloud import automl_v1beta1 as automl
+
+    client = automl.TablesClient(project=project_id, region=compute_region)
+
+    # Query model
+    response = client.batch_predict(bigquery_input_uri=bq_input_uri,
+                                    bigquery_output_uri=bq_output_uri,
+                                    model_display_name=model_display_name)
+    print("Making batch prediction... ")
+    # `response` is an async operation descriptor,
+    # you can register a callback for the operation to complete via `add_done_callback`:
+    # def callback(operation_future):
+    #     result = operation_future.result()
+    # response.add_done_callback(callback)
+    #
+    # or block the thread polling for the operation's results:
+    response.result()
+    # AutoML puts predictions in a newly generated dataset with a name by a mask "prediction_" + model_id + "_" + timestamp
+    # here's how to get the dataset name:
+    dataset_name = response.metadata.batch_predict_details.output_info.bigquery_output_dataset
+
+    print("Batch prediction complete.\nResults are in '{}' dataset.\n{}".format(
+        dataset_name, response.metadata))
+
+    # [END automl_tables_batch_predict_bq]
+
+
 def batch_predict(
     project_id,
     compute_region,
@@ -108,7 +151,15 @@ def batch_predict(
         model_display_name=model_display_name,
     )
     print("Making batch prediction... ")
+    # `response` is an async operation descriptor,
+    # you can register a callback for the operation to complete via `add_done_callback`:
+    # def callback(operation_future):
+    #     result = operation_future.result()
+    # response.add_done_callback(callback)
+    #
+    # or block the thread polling for the operation's results:
     response.result()
+    print("Batch prediction complete.\n{}".format(response.metadata))
 
     # [END automl_tables_batch_predict]
 
diff --git a/tables/automl/batch_predict_test.py b/tables/automl/batch_predict_test.py
index 37b5f0e09c34..f77404deefd2 100644
--- a/tables/automl/batch_predict_test.py
+++ b/tables/automl/batch_predict_test.py
@@ -30,6 +30,8 @@
 STATIC_MODEL = model_test.STATIC_MODEL
 GCS_INPUT = "gs://{}-automl-tables-test/bank-marketing.csv".format(PROJECT)
 GCS_OUTPUT = "gs://{}-automl-tables-test/TABLE_TEST_OUTPUT/".format(PROJECT)
+BQ_INPUT = "bq://{}.automl_test.bank_marketing".format(PROJECT)
+BQ_OUTPUT = "bq://{}".format(PROJECT)
 
 
 @pytest.mark.slow
@@ -42,6 +44,16 @@ def test_batch_predict(capsys):
     assert "Batch prediction complete" in out
 
 
+@pytest.mark.slow
+def test_batch_predict_bq(capsys):
+    ensure_model_online()
+    automl_tables_predict.batch_predict_bq(
+        PROJECT, REGION, STATIC_MODEL, BQ_INPUT, BQ_OUTPUT
+    )
+    out, _ = capsys.readouterr()
+    assert "Batch prediction complete" in out
+
+
 def ensure_model_online():
     model = model_test.ensure_model_ready()
     if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: