From 575966d2410f157b8eb9e985e6b782f5b6327968 Mon Sep 17 00:00:00 2001
From: Tim Swast
Date: Mon, 13 Sep 2021 16:09:35 -0500
Subject: [PATCH 01/24] feat: add `api_method` parameter to `Client.query` to
 select `insert` or `query` API

Work in Progress. This commit only refactors to allow jobs.insert to be
selected.

Supporting jobs.query will require more transformations to QueryJobConfig,
QueryJob, and RowIterator.
---
 google/cloud/bigquery/_job_helpers.py | 109 ++++++++++++++++++++++++++
 google/cloud/bigquery/client.py       | 107 ++++++++++---------------
 2 files changed, 152 insertions(+), 64 deletions(-)
 create mode 100644 google/cloud/bigquery/_job_helpers.py

diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py
new file mode 100644
index 000000000..198205da5
--- /dev/null
+++ b/google/cloud/bigquery/_job_helpers.py
@@ -0,0 +1,109 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for interacting with the job REST APIs from the client."""
+
+import copy
+import uuid
+from typing import TYPE_CHECKING
+
+import google.api_core.exceptions as core_exceptions
+from google.api_core import retry as retries
+
+from google.cloud.bigquery import job
+
+# Avoid circular imports
+if TYPE_CHECKING:
+    from google.cloud.bigquery.client import Client
+else:
+    Client = None
+
+
+def make_job_id(job_id, prefix=None):
+    """Construct an ID for a new job.
+
+    Args:
+        job_id (Optional[str]): the user-provided job ID.
+
+        prefix (Optional[str]): the user-provided prefix for a job ID.
+
+    Returns:
+        str: A job ID
+    """
+    if job_id is not None:
+        return job_id
+    elif prefix is not None:
+        return str(prefix) + str(uuid.uuid4())
+    else:
+        return str(uuid.uuid4())
+
+
+def query_jobs_insert(
+    client: Client,
+    query: str,
+    job_config: job.QueryJobConfig,
+    job_id: str,
+    job_id_prefix: str,
+    location: str,
+    project: str,
+    retry: retries.Retry,
+    timeout: float,
+    job_retry: retries.Retry,
+):
+    job_id_given = job_id is not None
+    job_id_save = job_id
+    job_config_save = job_config
+
+    def do_query():
+        # Make a copy now, so that original doesn't get changed by the process
+        # below and to facilitate retry
+        job_config = copy.deepcopy(job_config_save)
+
+        job_id = make_job_id(job_id_save, job_id_prefix)
+        job_ref = job._JobReference(job_id, project=project, location=location)
+        query_job = job.QueryJob(job_ref, query, client=client, job_config=job_config)
+
+        try:
+            query_job._begin(retry=retry, timeout=timeout)
+        except core_exceptions.Conflict as create_exc:
+            # The thought is if someone is providing their own job IDs and they get
+            # their job ID generation wrong, this could end up returning results for
+            # the wrong query. We thus only try to recover if job ID was not given.
+            if job_id_given:
+                raise create_exc
+
+            try:
+                query_job = client.get_job(
+                    job_id,
+                    project=project,
+                    location=location,
+                    retry=retry,
+                    timeout=timeout,
+                )
+            except core_exceptions.GoogleAPIError:  # (includes RetryError)
+                raise create_exc
+            else:
+                return query_job
+        else:
+            return query_job
+
+    future = do_query()
+    # The future might be in a failed state now, but if it's
+    # unrecoverable, we'll find out when we ask for its result, at which
+    # point, we may retry.
+    if not job_id_given:
+        future._retry_do_query = do_query  # in case we have to retry later
+        future._job_retry = job_retry
+
+    return future
diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index a738dd0f3..ca635cbf8 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -49,6 +49,9 @@
     DEFAULT_CLIENT_INFO as DEFAULT_BQSTORAGE_CLIENT_INFO,
 )
 
+from google.cloud.bigquery import _job_helpers
+from google.cloud.bigquery._job_helpers import make_job_id as _make_job_id
+from google.cloud.bigquery._helpers import _del_sub_prop
 from google.cloud.bigquery._helpers import _get_sub_prop
 from google.cloud.bigquery._helpers import _record_field_to_json
 from google.cloud.bigquery._helpers import _str_or_none
@@ -3121,6 +3124,7 @@ def query(
         retry: retries.Retry = DEFAULT_RETRY,
         timeout: float = DEFAULT_TIMEOUT,
         job_retry: retries.Retry = DEFAULT_JOB_RETRY,
+        api_method: str = "insert",
     ) -> job.QueryJob:
         """Run a SQL query.
 
@@ -3172,6 +3176,20 @@ def query(
                 called on the job returned. The ``job_retry``
                 specified here becomes the default ``job_retry`` for
                 ``result()``, where it can also be specified.
+            api_method:
+                One of ``'insert'`` or ``'query'``. Defaults to ``'insert'``.
+
+                When set to ``'insert'``, submit a query job by using the
+                `jobs.insert REST API method
+                <https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert>`_.
+                This supports all job configuration options.
+
+                When set to ``'query'``, submit a query job by using the
+                `jobs.query REST API method
+                <https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query>`_.
+                This API waits up to the specified timeout for the query to
+                finish. The ``job_id`` and ``job_id_prefix`` parameters cannot
+                be used with this API method.
 
         Returns:
             google.cloud.bigquery.job.QueryJob: A new query job instance.
@@ -3195,7 +3213,16 @@ def query(
                 " provided."
             )
 
-        job_id_save = job_id
+        if api_method not in {"insert", "query"}:
+            raise ValueError(
+                f"Got unexpected value for api_method: {repr(api_method)}"
+                " Expected one of {'insert', 'query'}."
+            )
+
+        if job_id_given and api_method == "query":
+            raise TypeError(
+                "`job_id` was provided, but the 'query' `api_method` was requested."
+            )
 
         if project is None:
             project = self.project
@@ -3226,50 +3253,21 @@ def query(
 
         # Note that we haven't modified the original job_config (or
         # _default_query_job_config) up to this point.
-        job_config_save = job_config
-
-        def do_query():
-            # Make a copy now, so that original doesn't get changed by the process
-            # below and to facilitate retry
-            job_config = copy.deepcopy(job_config_save)
-
-            job_id = _make_job_id(job_id_save, job_id_prefix)
-            job_ref = job._JobReference(job_id, project=project, location=location)
-            query_job = job.QueryJob(job_ref, query, client=self, job_config=job_config)
-
-            try:
-                query_job._begin(retry=retry, timeout=timeout)
-            except core_exceptions.Conflict as create_exc:
-                # The thought is if someone is providing their own job IDs and they get
-                # their job ID generation wrong, this could end up returning results for
-                # the wrong query. We thus only try to recover if job ID was not given.
-                if job_id_given:
-                    raise create_exc
-
-                try:
-                    query_job = self.get_job(
-                        job_id,
-                        project=project,
-                        location=location,
-                        retry=retry,
-                        timeout=timeout,
-                    )
-                except core_exceptions.GoogleAPIError:  # (includes RetryError)
-                    raise create_exc
-                else:
-                    return query_job
-            else:
-                return query_job
-
-        future = do_query()
-        # The future might be in a failed state now, but if it's
-        # unrecoverable, we'll find out when we ask for it's result, at which
-        # point, we may retry.
-        if not job_id_given:
-            future._retry_do_query = do_query  # in case we have to retry later
-            future._job_retry = job_retry
-
-        return future
+        if api_method == "query":
+            return None  # TODO
+        else:
+            return _job_helpers.query_jobs_insert(
+                self,
+                query,
+                job_config,
+                job_id,
+                job_id_prefix,
+                location,
+                project,
+                retry,
+                timeout,
+                job_retry,
+            )
 
     def insert_rows(
         self,
@@ -3940,25 +3938,6 @@ def _extract_job_reference(job, project=None, location=None):
     return (project, location, job_id)
 
 
-def _make_job_id(job_id, prefix=None):
-    """Construct an ID for a new job.
-
-    Args:
-        job_id (Optional[str]): the user-provided job ID.
-
-        prefix (Optional[str]): the user-provided prefix for a job ID.
-
-    Returns:
-        str: A job ID
-    """
-    if job_id is not None:
-        return job_id
-    elif prefix is not None:
-        return str(prefix) + str(uuid.uuid4())
-    else:
-        return str(uuid.uuid4())
-
-
 def _check_mode(stream):
     """Check that a stream was opened in read-binary mode.
 

From 094e3cb9b0186e784906750903630eec774cb160 Mon Sep 17 00:00:00 2001
From: Tim Swast
Date: Mon, 13 Sep 2021 17:03:12 -0500
Subject: [PATCH 02/24] WIP: begin implementation of jobs.query usage

---
 google/cloud/bigquery/_job_helpers.py | 64 +++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py
index 198205da5..f0cba9531 100644
--- a/google/cloud/bigquery/_job_helpers.py
+++ b/google/cloud/bigquery/_job_helpers.py
@@ -30,6 +30,9 @@
     Client = None
 
 
+_TIMEOUT_BUFFER_SECS = 0.1
+
+
 def make_job_id(job_id, prefix=None):
     """Construct an ID for a new job.
 
@@ -107,3 +110,64 @@ def do_query():
         future._job_retry = job_retry
 
     return future
+
+
+def query_jobs_query(
+    client: Client,
+    query: str,
+    job_config: job.QueryJobConfig,
+    location: str,
+    project: str,
+    retry: retries.Retry,
+    timeout: float,
+    job_retry: retries.Retry,
+):
+    # TODO: Validate that destination is not set.
+
+    request_body = {}
+    job_config_resource = job_config.to_api_repr()
+
+    # Transform from Job resource to QueryRequest resource.
+    # Most of the keys in job.configuration.query are in common
+    request_body.update(job_config_resource["configuration"]["query"])
+    request_body["location"] = location
+    request_body["labels"] = job_config.labels
+    request_body["dryRun"] = job_config.dry_run
+
+    # Subtract a buffer for context switching, network latency, etc.
+    request_body["timeoutMs"] = max(0, int(1000 * (timeout - _TIMEOUT_BUFFER_SECS)))
+
+    def do_query():
+        request_body["requestId"] = make_job_id(None)
+        # job_ref = job._JobReference(job_id, project=project, location=location)
+        # query_job = job.QueryJob(job_ref, query, client=client, job_config=job_config)
+
+        # query_job._begin(retry=retry, timeout=timeout)
+        client._call_api(retry)
+
+        path = f"/projects/{project}/queries"
+
+        # jobs.insert is idempotent because we ensure that every new
+        # job has an ID.
+ span_attributes = {"path": path} + api_response = client._call_api( + retry, + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=timeout, + ) + # TODO: make query job out of api_response + return api_response + + future = do_query() + + # The future might be in a failed state now, but if it's + # unrecoverable, we'll find out when we ask for it's result, at which + # point, we may retry. + future._retry_do_query = do_query # in case we have to retry later + future._job_retry = job_retry + + return future From 17994f525e0aad953134afc2ee9b0e4c0fc7e7aa Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 5 Oct 2021 14:14:47 -0500 Subject: [PATCH 03/24] remove extra files --- google/cloud/bigquery_v2/types/model.py | 1507 ----------------- .../bigquery_v2/types/table_reference.py | 58 - 2 files changed, 1565 deletions(-) delete mode 100644 google/cloud/bigquery_v2/types/model.py delete mode 100644 google/cloud/bigquery_v2/types/table_reference.py diff --git a/google/cloud/bigquery_v2/types/model.py b/google/cloud/bigquery_v2/types/model.py deleted file mode 100644 index 706418401..000000000 --- a/google/cloud/bigquery_v2/types/model.py +++ /dev/null @@ -1,1507 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import proto # type: ignore - -from google.cloud.bigquery_v2.types import encryption_config -from google.cloud.bigquery_v2.types import model_reference as gcb_model_reference -from google.cloud.bigquery_v2.types import standard_sql -from google.cloud.bigquery_v2.types import table_reference -from google.protobuf import timestamp_pb2 # type: ignore -from google.protobuf import wrappers_pb2 # type: ignore - - -__protobuf__ = proto.module( - package="google.cloud.bigquery.v2", - manifest={ - "Model", - "GetModelRequest", - "PatchModelRequest", - "DeleteModelRequest", - "ListModelsRequest", - "ListModelsResponse", - }, -) - - -class Model(proto.Message): - r""" - Attributes: - etag (str): - Output only. A hash of this resource. - model_reference (google.cloud.bigquery_v2.types.ModelReference): - Required. Unique identifier for this model. - creation_time (int): - Output only. The time when this model was - created, in millisecs since the epoch. - last_modified_time (int): - Output only. The time when this model was - last modified, in millisecs since the epoch. - description (str): - Optional. A user-friendly description of this - model. - friendly_name (str): - Optional. A descriptive name for this model. - labels (Sequence[google.cloud.bigquery_v2.types.Model.LabelsEntry]): - The labels associated with this model. You - can use these to organize and group your models. - Label keys and values can be no longer than 63 - characters, can only contain lowercase letters, - numeric characters, underscores and dashes. - International characters are allowed. Label - values are optional. 
Label keys must start with - a letter and each label in the list must have a - different key. - expiration_time (int): - Optional. The time when this model expires, - in milliseconds since the epoch. If not present, - the model will persist indefinitely. Expired - models will be deleted and their storage - reclaimed. The defaultTableExpirationMs - property of the encapsulating dataset can be - used to set a default expirationTime on newly - created models. - location (str): - Output only. The geographic location where - the model resides. This value is inherited from - the dataset. - encryption_configuration (google.cloud.bigquery_v2.types.EncryptionConfiguration): - Custom encryption configuration (e.g., Cloud - KMS keys). This shows the encryption - configuration of the model data while stored in - BigQuery storage. This field can be used with - PatchModel to update encryption key for an - already encrypted model. - model_type (google.cloud.bigquery_v2.types.Model.ModelType): - Output only. Type of the model resource. - training_runs (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun]): - Output only. Information for all training runs in increasing - order of start_time. - feature_columns (Sequence[google.cloud.bigquery_v2.types.StandardSqlField]): - Output only. Input feature columns that were - used to train this model. - label_columns (Sequence[google.cloud.bigquery_v2.types.StandardSqlField]): - Output only. Label columns that were used to train this - model. The output of the model will have a `predicted_` - prefix to these columns. - best_trial_id (int): - The best trial_id across all training runs. - """ - - class ModelType(proto.Enum): - r"""Indicates the type of the Model.""" - MODEL_TYPE_UNSPECIFIED = 0 - LINEAR_REGRESSION = 1 - LOGISTIC_REGRESSION = 2 - KMEANS = 3 - MATRIX_FACTORIZATION = 4 - DNN_CLASSIFIER = 5 - TENSORFLOW = 6 - DNN_REGRESSOR = 7 - BOOSTED_TREE_REGRESSOR = 9 - BOOSTED_TREE_CLASSIFIER = 10 - ARIMA = 11 - AUTOML_REGRESSOR = 12 - AUTOML_CLASSIFIER = 13 - ARIMA_PLUS = 19 - - class LossType(proto.Enum): - r"""Loss metric to evaluate model training performance.""" - LOSS_TYPE_UNSPECIFIED = 0 - MEAN_SQUARED_LOSS = 1 - MEAN_LOG_LOSS = 2 - - class DistanceType(proto.Enum): - r"""Distance metric used to compute the distance between two - points. - """ - DISTANCE_TYPE_UNSPECIFIED = 0 - EUCLIDEAN = 1 - COSINE = 2 - - class DataSplitMethod(proto.Enum): - r"""Indicates the method to split input data into multiple - tables. - """ - DATA_SPLIT_METHOD_UNSPECIFIED = 0 - RANDOM = 1 - CUSTOM = 2 - SEQUENTIAL = 3 - NO_SPLIT = 4 - AUTO_SPLIT = 5 - - class DataFrequency(proto.Enum): - r"""Type of supported data frequency for time series forecasting - models. - """ - DATA_FREQUENCY_UNSPECIFIED = 0 - AUTO_FREQUENCY = 1 - YEARLY = 2 - QUARTERLY = 3 - MONTHLY = 4 - WEEKLY = 5 - DAILY = 6 - HOURLY = 7 - PER_MINUTE = 8 - - class HolidayRegion(proto.Enum): - r"""Type of supported holiday regions for time series forecasting - models. 
- """ - HOLIDAY_REGION_UNSPECIFIED = 0 - GLOBAL = 1 - NA = 2 - JAPAC = 3 - EMEA = 4 - LAC = 5 - AE = 6 - AR = 7 - AT = 8 - AU = 9 - BE = 10 - BR = 11 - CA = 12 - CH = 13 - CL = 14 - CN = 15 - CO = 16 - CS = 17 - CZ = 18 - DE = 19 - DK = 20 - DZ = 21 - EC = 22 - EE = 23 - EG = 24 - ES = 25 - FI = 26 - FR = 27 - GB = 28 - GR = 29 - HK = 30 - HU = 31 - ID = 32 - IE = 33 - IL = 34 - IN = 35 - IR = 36 - IT = 37 - JP = 38 - KR = 39 - LV = 40 - MA = 41 - MX = 42 - MY = 43 - NG = 44 - NL = 45 - NO = 46 - NZ = 47 - PE = 48 - PH = 49 - PK = 50 - PL = 51 - PT = 52 - RO = 53 - RS = 54 - RU = 55 - SA = 56 - SE = 57 - SG = 58 - SI = 59 - SK = 60 - TH = 61 - TR = 62 - TW = 63 - UA = 64 - US = 65 - VE = 66 - VN = 67 - ZA = 68 - - class LearnRateStrategy(proto.Enum): - r"""Indicates the learning rate optimization strategy to use.""" - LEARN_RATE_STRATEGY_UNSPECIFIED = 0 - LINE_SEARCH = 1 - CONSTANT = 2 - - class OptimizationStrategy(proto.Enum): - r"""Indicates the optimization strategy used for training.""" - OPTIMIZATION_STRATEGY_UNSPECIFIED = 0 - BATCH_GRADIENT_DESCENT = 1 - NORMAL_EQUATION = 2 - - class FeedbackType(proto.Enum): - r"""Indicates the training algorithm to use for matrix - factorization models. - """ - FEEDBACK_TYPE_UNSPECIFIED = 0 - IMPLICIT = 1 - EXPLICIT = 2 - - class SeasonalPeriod(proto.Message): - r""" """ - - class SeasonalPeriodType(proto.Enum): - r"""""" - SEASONAL_PERIOD_TYPE_UNSPECIFIED = 0 - NO_SEASONALITY = 1 - DAILY = 2 - WEEKLY = 3 - MONTHLY = 4 - QUARTERLY = 5 - YEARLY = 6 - - class KmeansEnums(proto.Message): - r""" """ - - class KmeansInitializationMethod(proto.Enum): - r"""Indicates the method used to initialize the centroids for - KMeans clustering algorithm. - """ - KMEANS_INITIALIZATION_METHOD_UNSPECIFIED = 0 - RANDOM = 1 - CUSTOM = 2 - KMEANS_PLUS_PLUS = 3 - - class RegressionMetrics(proto.Message): - r"""Evaluation metrics for regression and explicit feedback type - matrix factorization models. - - Attributes: - mean_absolute_error (google.protobuf.wrappers_pb2.DoubleValue): - Mean absolute error. - mean_squared_error (google.protobuf.wrappers_pb2.DoubleValue): - Mean squared error. - mean_squared_log_error (google.protobuf.wrappers_pb2.DoubleValue): - Mean squared log error. - median_absolute_error (google.protobuf.wrappers_pb2.DoubleValue): - Median absolute error. - r_squared (google.protobuf.wrappers_pb2.DoubleValue): - R^2 score. This corresponds to r2_score in ML.EVALUATE. - """ - - mean_absolute_error = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, - ) - mean_squared_error = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, - ) - mean_squared_log_error = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.DoubleValue, - ) - median_absolute_error = proto.Field( - proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, - ) - r_squared = proto.Field( - proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, - ) - - class AggregateClassificationMetrics(proto.Message): - r"""Aggregate metrics for classification/classifier models. For - multi-class models, the metrics are either macro-averaged or - micro-averaged. When macro-averaged, the metrics are calculated - for each label and then an unweighted average is taken of those - values. When micro-averaged, the metric is calculated globally - by counting the total number of correctly predicted rows. 
- - Attributes: - precision (google.protobuf.wrappers_pb2.DoubleValue): - Precision is the fraction of actual positive - predictions that had positive actual labels. For - multiclass this is a macro-averaged metric - treating each class as a binary classifier. - recall (google.protobuf.wrappers_pb2.DoubleValue): - Recall is the fraction of actual positive - labels that were given a positive prediction. - For multiclass this is a macro-averaged metric. - accuracy (google.protobuf.wrappers_pb2.DoubleValue): - Accuracy is the fraction of predictions given - the correct label. For multiclass this is a - micro-averaged metric. - threshold (google.protobuf.wrappers_pb2.DoubleValue): - Threshold at which the metrics are computed. - For binary classification models this is the - positive class threshold. For multi-class - classfication models this is the confidence - threshold. - f1_score (google.protobuf.wrappers_pb2.DoubleValue): - The F1 score is an average of recall and - precision. For multiclass this is a macro- - averaged metric. - log_loss (google.protobuf.wrappers_pb2.DoubleValue): - Logarithmic Loss. For multiclass this is a - macro-averaged metric. - roc_auc (google.protobuf.wrappers_pb2.DoubleValue): - Area Under a ROC Curve. For multiclass this - is a macro-averaged metric. - """ - - precision = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, - ) - recall = proto.Field(proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue,) - accuracy = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.DoubleValue, - ) - threshold = proto.Field( - proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, - ) - f1_score = proto.Field( - proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, - ) - log_loss = proto.Field( - proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, - ) - roc_auc = proto.Field( - proto.MESSAGE, number=7, message=wrappers_pb2.DoubleValue, - ) - - class BinaryClassificationMetrics(proto.Message): - r"""Evaluation metrics for binary classification/classifier - models. - - Attributes: - aggregate_classification_metrics (google.cloud.bigquery_v2.types.Model.AggregateClassificationMetrics): - Aggregate classification metrics. - binary_confusion_matrix_list (Sequence[google.cloud.bigquery_v2.types.Model.BinaryClassificationMetrics.BinaryConfusionMatrix]): - Binary confusion matrix at multiple - thresholds. - positive_label (str): - Label representing the positive class. - negative_label (str): - Label representing the negative class. - """ - - class BinaryConfusionMatrix(proto.Message): - r"""Confusion matrix for binary classification models. - Attributes: - positive_class_threshold (google.protobuf.wrappers_pb2.DoubleValue): - Threshold value used when computing each of - the following metric. - true_positives (google.protobuf.wrappers_pb2.Int64Value): - Number of true samples predicted as true. - false_positives (google.protobuf.wrappers_pb2.Int64Value): - Number of false samples predicted as true. - true_negatives (google.protobuf.wrappers_pb2.Int64Value): - Number of true samples predicted as false. - false_negatives (google.protobuf.wrappers_pb2.Int64Value): - Number of false samples predicted as false. - precision (google.protobuf.wrappers_pb2.DoubleValue): - The fraction of actual positive predictions - that had positive actual labels. - recall (google.protobuf.wrappers_pb2.DoubleValue): - The fraction of actual positive labels that - were given a positive prediction. 
- f1_score (google.protobuf.wrappers_pb2.DoubleValue): - The equally weighted average of recall and - precision. - accuracy (google.protobuf.wrappers_pb2.DoubleValue): - The fraction of predictions given the correct - label. - """ - - positive_class_threshold = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, - ) - true_positives = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.Int64Value, - ) - false_positives = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.Int64Value, - ) - true_negatives = proto.Field( - proto.MESSAGE, number=4, message=wrappers_pb2.Int64Value, - ) - false_negatives = proto.Field( - proto.MESSAGE, number=5, message=wrappers_pb2.Int64Value, - ) - precision = proto.Field( - proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, - ) - recall = proto.Field( - proto.MESSAGE, number=7, message=wrappers_pb2.DoubleValue, - ) - f1_score = proto.Field( - proto.MESSAGE, number=8, message=wrappers_pb2.DoubleValue, - ) - accuracy = proto.Field( - proto.MESSAGE, number=9, message=wrappers_pb2.DoubleValue, - ) - - aggregate_classification_metrics = proto.Field( - proto.MESSAGE, number=1, message="Model.AggregateClassificationMetrics", - ) - binary_confusion_matrix_list = proto.RepeatedField( - proto.MESSAGE, - number=2, - message="Model.BinaryClassificationMetrics.BinaryConfusionMatrix", - ) - positive_label = proto.Field(proto.STRING, number=3,) - negative_label = proto.Field(proto.STRING, number=4,) - - class MultiClassClassificationMetrics(proto.Message): - r"""Evaluation metrics for multi-class classification/classifier - models. - - Attributes: - aggregate_classification_metrics (google.cloud.bigquery_v2.types.Model.AggregateClassificationMetrics): - Aggregate classification metrics. - confusion_matrix_list (Sequence[google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics.ConfusionMatrix]): - Confusion matrix at different thresholds. - """ - - class ConfusionMatrix(proto.Message): - r"""Confusion matrix for multi-class classification models. - Attributes: - confidence_threshold (google.protobuf.wrappers_pb2.DoubleValue): - Confidence threshold used when computing the - entries of the confusion matrix. - rows (Sequence[google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics.ConfusionMatrix.Row]): - One row per actual label. - """ - - class Entry(proto.Message): - r"""A single entry in the confusion matrix. - Attributes: - predicted_label (str): - The predicted label. For confidence_threshold > 0, we will - also add an entry indicating the number of items under the - confidence threshold. - item_count (google.protobuf.wrappers_pb2.Int64Value): - Number of items being predicted as this - label. - """ - - predicted_label = proto.Field(proto.STRING, number=1,) - item_count = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.Int64Value, - ) - - class Row(proto.Message): - r"""A single row in the confusion matrix. - Attributes: - actual_label (str): - The original label of this row. - entries (Sequence[google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics.ConfusionMatrix.Entry]): - Info describing predicted label distribution. 
- """ - - actual_label = proto.Field(proto.STRING, number=1,) - entries = proto.RepeatedField( - proto.MESSAGE, - number=2, - message="Model.MultiClassClassificationMetrics.ConfusionMatrix.Entry", - ) - - confidence_threshold = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, - ) - rows = proto.RepeatedField( - proto.MESSAGE, - number=2, - message="Model.MultiClassClassificationMetrics.ConfusionMatrix.Row", - ) - - aggregate_classification_metrics = proto.Field( - proto.MESSAGE, number=1, message="Model.AggregateClassificationMetrics", - ) - confusion_matrix_list = proto.RepeatedField( - proto.MESSAGE, - number=2, - message="Model.MultiClassClassificationMetrics.ConfusionMatrix", - ) - - class ClusteringMetrics(proto.Message): - r"""Evaluation metrics for clustering models. - Attributes: - davies_bouldin_index (google.protobuf.wrappers_pb2.DoubleValue): - Davies-Bouldin index. - mean_squared_distance (google.protobuf.wrappers_pb2.DoubleValue): - Mean of squared distances between each sample - to its cluster centroid. - clusters (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster]): - Information for all clusters. - """ - - class Cluster(proto.Message): - r"""Message containing the information about one cluster. - Attributes: - centroid_id (int): - Centroid id. - feature_values (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster.FeatureValue]): - Values of highly variant features for this - cluster. - count (google.protobuf.wrappers_pb2.Int64Value): - Count of training data rows that were - assigned to this cluster. - """ - - class FeatureValue(proto.Message): - r"""Representative value of a single feature within the cluster. - Attributes: - feature_column (str): - The feature column name. - numerical_value (google.protobuf.wrappers_pb2.DoubleValue): - The numerical feature value. This is the - centroid value for this feature. - categorical_value (google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue): - The categorical feature value. - """ - - class CategoricalValue(proto.Message): - r"""Representative value of a categorical feature. - Attributes: - category_counts (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue.CategoryCount]): - Counts of all categories for the categorical feature. If - there are more than ten categories, we return top ten (by - count) and return one more CategoryCount with category - "*OTHER*" and count as aggregate counts of remaining - categories. - """ - - class CategoryCount(proto.Message): - r"""Represents the count of a single category within the cluster. - Attributes: - category (str): - The name of category. - count (google.protobuf.wrappers_pb2.Int64Value): - The count of training samples matching the - category within the cluster. 
- """ - - category = proto.Field(proto.STRING, number=1,) - count = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.Int64Value, - ) - - category_counts = proto.RepeatedField( - proto.MESSAGE, - number=1, - message="Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue.CategoryCount", - ) - - feature_column = proto.Field(proto.STRING, number=1,) - numerical_value = proto.Field( - proto.MESSAGE, - number=2, - oneof="value", - message=wrappers_pb2.DoubleValue, - ) - categorical_value = proto.Field( - proto.MESSAGE, - number=3, - oneof="value", - message="Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue", - ) - - centroid_id = proto.Field(proto.INT64, number=1,) - feature_values = proto.RepeatedField( - proto.MESSAGE, - number=2, - message="Model.ClusteringMetrics.Cluster.FeatureValue", - ) - count = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.Int64Value, - ) - - davies_bouldin_index = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, - ) - mean_squared_distance = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, - ) - clusters = proto.RepeatedField( - proto.MESSAGE, number=3, message="Model.ClusteringMetrics.Cluster", - ) - - class RankingMetrics(proto.Message): - r"""Evaluation metrics used by weighted-ALS models specified by - feedback_type=implicit. - - Attributes: - mean_average_precision (google.protobuf.wrappers_pb2.DoubleValue): - Calculates a precision per user for all the - items by ranking them and then averages all the - precisions across all the users. - mean_squared_error (google.protobuf.wrappers_pb2.DoubleValue): - Similar to the mean squared error computed in - regression and explicit recommendation models - except instead of computing the rating directly, - the output from evaluate is computed against a - preference which is 1 or 0 depending on if the - rating exists or not. - normalized_discounted_cumulative_gain (google.protobuf.wrappers_pb2.DoubleValue): - A metric to determine the goodness of a - ranking calculated from the predicted confidence - by comparing it to an ideal rank measured by the - original ratings. - average_rank (google.protobuf.wrappers_pb2.DoubleValue): - Determines the goodness of a ranking by - computing the percentile rank from the predicted - confidence and dividing it by the original rank. - """ - - mean_average_precision = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, - ) - mean_squared_error = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, - ) - normalized_discounted_cumulative_gain = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.DoubleValue, - ) - average_rank = proto.Field( - proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, - ) - - class ArimaForecastingMetrics(proto.Message): - r"""Model evaluation metrics for ARIMA forecasting models. - Attributes: - non_seasonal_order (Sequence[google.cloud.bigquery_v2.types.Model.ArimaOrder]): - Non-seasonal order. - arima_fitting_metrics (Sequence[google.cloud.bigquery_v2.types.Model.ArimaFittingMetrics]): - Arima model fitting metrics. - seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): - Seasonal periods. Repeated because multiple - periods are supported for one time series. - has_drift (Sequence[bool]): - Whether Arima model fitted with drift or not. - It is always false when d is not 1. 
- time_series_id (Sequence[str]): - Id to differentiate different time series for - the large-scale case. - arima_single_model_forecasting_metrics (Sequence[google.cloud.bigquery_v2.types.Model.ArimaForecastingMetrics.ArimaSingleModelForecastingMetrics]): - Repeated as there can be many metric sets - (one for each model) in auto-arima and the - large-scale case. - """ - - class ArimaSingleModelForecastingMetrics(proto.Message): - r"""Model evaluation metrics for a single ARIMA forecasting - model. - - Attributes: - non_seasonal_order (google.cloud.bigquery_v2.types.Model.ArimaOrder): - Non-seasonal order. - arima_fitting_metrics (google.cloud.bigquery_v2.types.Model.ArimaFittingMetrics): - Arima fitting metrics. - has_drift (bool): - Is arima model fitted with drift or not. It - is always false when d is not 1. - time_series_id (str): - The time_series_id value for this time series. It will be - one of the unique values from the time_series_id_column - specified during ARIMA model training. Only present when - time_series_id_column training option was used. - time_series_ids (Sequence[str]): - The tuple of time_series_ids identifying this time series. - It will be one of the unique tuples of values present in the - time_series_id_columns specified during ARIMA model - training. Only present when time_series_id_columns training - option was used and the order of values here are same as the - order of time_series_id_columns. - seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): - Seasonal periods. Repeated because multiple - periods are supported for one time series. - has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue): - If true, holiday_effect is a part of time series - decomposition result. - has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue): - If true, spikes_and_dips is a part of time series - decomposition result. - has_step_changes (google.protobuf.wrappers_pb2.BoolValue): - If true, step_changes is a part of time series decomposition - result. 
- """ - - non_seasonal_order = proto.Field( - proto.MESSAGE, number=1, message="Model.ArimaOrder", - ) - arima_fitting_metrics = proto.Field( - proto.MESSAGE, number=2, message="Model.ArimaFittingMetrics", - ) - has_drift = proto.Field(proto.BOOL, number=3,) - time_series_id = proto.Field(proto.STRING, number=4,) - time_series_ids = proto.RepeatedField(proto.STRING, number=9,) - seasonal_periods = proto.RepeatedField( - proto.ENUM, number=5, enum="Model.SeasonalPeriod.SeasonalPeriodType", - ) - has_holiday_effect = proto.Field( - proto.MESSAGE, number=6, message=wrappers_pb2.BoolValue, - ) - has_spikes_and_dips = proto.Field( - proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue, - ) - has_step_changes = proto.Field( - proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue, - ) - - non_seasonal_order = proto.RepeatedField( - proto.MESSAGE, number=1, message="Model.ArimaOrder", - ) - arima_fitting_metrics = proto.RepeatedField( - proto.MESSAGE, number=2, message="Model.ArimaFittingMetrics", - ) - seasonal_periods = proto.RepeatedField( - proto.ENUM, number=3, enum="Model.SeasonalPeriod.SeasonalPeriodType", - ) - has_drift = proto.RepeatedField(proto.BOOL, number=4,) - time_series_id = proto.RepeatedField(proto.STRING, number=5,) - arima_single_model_forecasting_metrics = proto.RepeatedField( - proto.MESSAGE, - number=6, - message="Model.ArimaForecastingMetrics.ArimaSingleModelForecastingMetrics", - ) - - class EvaluationMetrics(proto.Message): - r"""Evaluation metrics of a model. These are either computed on - all training data or just the eval data based on whether eval - data was used during training. These are not present for - imported models. - - Attributes: - regression_metrics (google.cloud.bigquery_v2.types.Model.RegressionMetrics): - Populated for regression models and explicit - feedback type matrix factorization models. - binary_classification_metrics (google.cloud.bigquery_v2.types.Model.BinaryClassificationMetrics): - Populated for binary - classification/classifier models. - multi_class_classification_metrics (google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics): - Populated for multi-class - classification/classifier models. - clustering_metrics (google.cloud.bigquery_v2.types.Model.ClusteringMetrics): - Populated for clustering models. - ranking_metrics (google.cloud.bigquery_v2.types.Model.RankingMetrics): - Populated for implicit feedback type matrix - factorization models. - arima_forecasting_metrics (google.cloud.bigquery_v2.types.Model.ArimaForecastingMetrics): - Populated for ARIMA models. - """ - - regression_metrics = proto.Field( - proto.MESSAGE, number=1, oneof="metrics", message="Model.RegressionMetrics", - ) - binary_classification_metrics = proto.Field( - proto.MESSAGE, - number=2, - oneof="metrics", - message="Model.BinaryClassificationMetrics", - ) - multi_class_classification_metrics = proto.Field( - proto.MESSAGE, - number=3, - oneof="metrics", - message="Model.MultiClassClassificationMetrics", - ) - clustering_metrics = proto.Field( - proto.MESSAGE, number=4, oneof="metrics", message="Model.ClusteringMetrics", - ) - ranking_metrics = proto.Field( - proto.MESSAGE, number=5, oneof="metrics", message="Model.RankingMetrics", - ) - arima_forecasting_metrics = proto.Field( - proto.MESSAGE, - number=6, - oneof="metrics", - message="Model.ArimaForecastingMetrics", - ) - - class DataSplitResult(proto.Message): - r"""Data split result. This contains references to the training - and evaluation data tables that were used to train the model. 
- - Attributes: - training_table (google.cloud.bigquery_v2.types.TableReference): - Table reference of the training data after - split. - evaluation_table (google.cloud.bigquery_v2.types.TableReference): - Table reference of the evaluation data after - split. - """ - - training_table = proto.Field( - proto.MESSAGE, number=1, message=table_reference.TableReference, - ) - evaluation_table = proto.Field( - proto.MESSAGE, number=2, message=table_reference.TableReference, - ) - - class ArimaOrder(proto.Message): - r"""Arima order, can be used for both non-seasonal and seasonal - parts. - - Attributes: - p (int): - Order of the autoregressive part. - d (int): - Order of the differencing part. - q (int): - Order of the moving-average part. - """ - - p = proto.Field(proto.INT64, number=1,) - d = proto.Field(proto.INT64, number=2,) - q = proto.Field(proto.INT64, number=3,) - - class ArimaFittingMetrics(proto.Message): - r"""ARIMA model fitting metrics. - Attributes: - log_likelihood (float): - Log-likelihood. - aic (float): - AIC. - variance (float): - Variance. - """ - - log_likelihood = proto.Field(proto.DOUBLE, number=1,) - aic = proto.Field(proto.DOUBLE, number=2,) - variance = proto.Field(proto.DOUBLE, number=3,) - - class GlobalExplanation(proto.Message): - r"""Global explanations containing the top most important - features after training. - - Attributes: - explanations (Sequence[google.cloud.bigquery_v2.types.Model.GlobalExplanation.Explanation]): - A list of the top global explanations. Sorted - by absolute value of attribution in descending - order. - class_label (str): - Class label for this set of global - explanations. Will be empty/null for binary - logistic and linear regression models. Sorted - alphabetically in descending order. - """ - - class Explanation(proto.Message): - r"""Explanation for a single feature. - Attributes: - feature_name (str): - Full name of the feature. For non-numerical features, will - be formatted like .. - Overall size of feature name will always be truncated to - first 120 characters. - attribution (google.protobuf.wrappers_pb2.DoubleValue): - Attribution of feature. - """ - - feature_name = proto.Field(proto.STRING, number=1,) - attribution = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, - ) - - explanations = proto.RepeatedField( - proto.MESSAGE, number=1, message="Model.GlobalExplanation.Explanation", - ) - class_label = proto.Field(proto.STRING, number=2,) - - class TrainingRun(proto.Message): - r"""Information about a single training query run for the model. - Attributes: - training_options (google.cloud.bigquery_v2.types.Model.TrainingRun.TrainingOptions): - Options that were used for this training run, - includes user specified and default options that - were used. - start_time (google.protobuf.timestamp_pb2.Timestamp): - The start time of this training run. - results (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult]): - Output of each iteration run, results.size() <= - max_iterations. - evaluation_metrics (google.cloud.bigquery_v2.types.Model.EvaluationMetrics): - The evaluation metrics over training/eval - data that were computed at the end of training. - data_split_result (google.cloud.bigquery_v2.types.Model.DataSplitResult): - Data split result of the training run. Only - set when the input data is actually split. - global_explanations (Sequence[google.cloud.bigquery_v2.types.Model.GlobalExplanation]): - Global explanations for important features of - the model. 
For multi-class models, there is one - entry for each label class. For other models, - there is only one entry in the list. - """ - - class TrainingOptions(proto.Message): - r"""Options used in model training. - Attributes: - max_iterations (int): - The maximum number of iterations in training. - Used only for iterative training algorithms. - loss_type (google.cloud.bigquery_v2.types.Model.LossType): - Type of loss function used during training - run. - learn_rate (float): - Learning rate in training. Used only for - iterative training algorithms. - l1_regularization (google.protobuf.wrappers_pb2.DoubleValue): - L1 regularization coefficient. - l2_regularization (google.protobuf.wrappers_pb2.DoubleValue): - L2 regularization coefficient. - min_relative_progress (google.protobuf.wrappers_pb2.DoubleValue): - When early_stop is true, stops training when accuracy - improvement is less than 'min_relative_progress'. Used only - for iterative training algorithms. - warm_start (google.protobuf.wrappers_pb2.BoolValue): - Whether to train a model from the last - checkpoint. - early_stop (google.protobuf.wrappers_pb2.BoolValue): - Whether to stop early when the loss doesn't improve - significantly any more (compared to min_relative_progress). - Used only for iterative training algorithms. - input_label_columns (Sequence[str]): - Name of input label columns in training data. - data_split_method (google.cloud.bigquery_v2.types.Model.DataSplitMethod): - The data split type for training and - evaluation, e.g. RANDOM. - data_split_eval_fraction (float): - The fraction of evaluation data over the - whole input data. The rest of data will be used - as training data. The format should be double. - Accurate to two decimal places. - Default value is 0.2. - data_split_column (str): - The column to split data with. This column won't be used as - a feature. - - 1. When data_split_method is CUSTOM, the corresponding - column should be boolean. The rows with true value tag - are eval data, and the false are training data. - 2. When data_split_method is SEQ, the first - DATA_SPLIT_EVAL_FRACTION rows (from smallest to largest) - in the corresponding column are used as training data, - and the rest are eval data. It respects the order in - Orderable data types: - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data-type-properties - learn_rate_strategy (google.cloud.bigquery_v2.types.Model.LearnRateStrategy): - The strategy to determine learn rate for the - current iteration. - initial_learn_rate (float): - Specifies the initial learning rate for the - line search learn rate strategy. - label_class_weights (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.TrainingOptions.LabelClassWeightsEntry]): - Weights associated with each label class, for - rebalancing the training data. Only applicable - for classification models. - user_column (str): - User column specified for matrix - factorization models. - item_column (str): - Item column specified for matrix - factorization models. - distance_type (google.cloud.bigquery_v2.types.Model.DistanceType): - Distance type for clustering models. - num_clusters (int): - Number of clusters for clustering models. - model_uri (str): - Google Cloud Storage URI from which the model - was imported. Only applicable for imported - models. - optimization_strategy (google.cloud.bigquery_v2.types.Model.OptimizationStrategy): - Optimization strategy for training linear - regression models. - hidden_units (Sequence[int]): - Hidden units for dnn models. 
- batch_size (int): - Batch size for dnn models. - dropout (google.protobuf.wrappers_pb2.DoubleValue): - Dropout probability for dnn models. - max_tree_depth (int): - Maximum depth of a tree for boosted tree - models. - subsample (float): - Subsample fraction of the training data to - grow tree to prevent overfitting for boosted - tree models. - min_split_loss (google.protobuf.wrappers_pb2.DoubleValue): - Minimum split loss for boosted tree models. - num_factors (int): - Num factors specified for matrix - factorization models. - feedback_type (google.cloud.bigquery_v2.types.Model.FeedbackType): - Feedback type that specifies which algorithm - to run for matrix factorization. - wals_alpha (google.protobuf.wrappers_pb2.DoubleValue): - Hyperparameter for matrix factoration when - implicit feedback type is specified. - kmeans_initialization_method (google.cloud.bigquery_v2.types.Model.KmeansEnums.KmeansInitializationMethod): - The method used to initialize the centroids - for kmeans algorithm. - kmeans_initialization_column (str): - The column used to provide the initial centroids for kmeans - algorithm when kmeans_initialization_method is CUSTOM. - time_series_timestamp_column (str): - Column to be designated as time series - timestamp for ARIMA model. - time_series_data_column (str): - Column to be designated as time series data - for ARIMA model. - auto_arima (bool): - Whether to enable auto ARIMA or not. - non_seasonal_order (google.cloud.bigquery_v2.types.Model.ArimaOrder): - A specification of the non-seasonal part of - the ARIMA model: the three components (p, d, q) - are the AR order, the degree of differencing, - and the MA order. - data_frequency (google.cloud.bigquery_v2.types.Model.DataFrequency): - The data frequency of a time series. - include_drift (bool): - Include drift when fitting an ARIMA model. - holiday_region (google.cloud.bigquery_v2.types.Model.HolidayRegion): - The geographical region based on which the - holidays are considered in time series modeling. - If a valid value is specified, then holiday - effects modeling is enabled. - time_series_id_column (str): - The time series id column that was used - during ARIMA model training. - time_series_id_columns (Sequence[str]): - The time series id columns that were used - during ARIMA model training. - horizon (int): - The number of periods ahead that need to be - forecasted. - preserve_input_structs (bool): - Whether to preserve the input structs in output feature - names. Suppose there is a struct A with field b. When false - (default), the output feature name is A_b. When true, the - output feature name is A.b. - auto_arima_max_order (int): - The max value of non-seasonal p and q. - decompose_time_series (google.protobuf.wrappers_pb2.BoolValue): - If true, perform decompose time series and - save the results. - clean_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue): - If true, clean spikes and dips in the input - time series. - adjust_step_changes (google.protobuf.wrappers_pb2.BoolValue): - If true, detect step changes and make data - adjustment in the input time series. 
- """ - - max_iterations = proto.Field(proto.INT64, number=1,) - loss_type = proto.Field(proto.ENUM, number=2, enum="Model.LossType",) - learn_rate = proto.Field(proto.DOUBLE, number=3,) - l1_regularization = proto.Field( - proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, - ) - l2_regularization = proto.Field( - proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, - ) - min_relative_progress = proto.Field( - proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, - ) - warm_start = proto.Field( - proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue, - ) - early_stop = proto.Field( - proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue, - ) - input_label_columns = proto.RepeatedField(proto.STRING, number=9,) - data_split_method = proto.Field( - proto.ENUM, number=10, enum="Model.DataSplitMethod", - ) - data_split_eval_fraction = proto.Field(proto.DOUBLE, number=11,) - data_split_column = proto.Field(proto.STRING, number=12,) - learn_rate_strategy = proto.Field( - proto.ENUM, number=13, enum="Model.LearnRateStrategy", - ) - initial_learn_rate = proto.Field(proto.DOUBLE, number=16,) - label_class_weights = proto.MapField(proto.STRING, proto.DOUBLE, number=17,) - user_column = proto.Field(proto.STRING, number=18,) - item_column = proto.Field(proto.STRING, number=19,) - distance_type = proto.Field( - proto.ENUM, number=20, enum="Model.DistanceType", - ) - num_clusters = proto.Field(proto.INT64, number=21,) - model_uri = proto.Field(proto.STRING, number=22,) - optimization_strategy = proto.Field( - proto.ENUM, number=23, enum="Model.OptimizationStrategy", - ) - hidden_units = proto.RepeatedField(proto.INT64, number=24,) - batch_size = proto.Field(proto.INT64, number=25,) - dropout = proto.Field( - proto.MESSAGE, number=26, message=wrappers_pb2.DoubleValue, - ) - max_tree_depth = proto.Field(proto.INT64, number=27,) - subsample = proto.Field(proto.DOUBLE, number=28,) - min_split_loss = proto.Field( - proto.MESSAGE, number=29, message=wrappers_pb2.DoubleValue, - ) - num_factors = proto.Field(proto.INT64, number=30,) - feedback_type = proto.Field( - proto.ENUM, number=31, enum="Model.FeedbackType", - ) - wals_alpha = proto.Field( - proto.MESSAGE, number=32, message=wrappers_pb2.DoubleValue, - ) - kmeans_initialization_method = proto.Field( - proto.ENUM, - number=33, - enum="Model.KmeansEnums.KmeansInitializationMethod", - ) - kmeans_initialization_column = proto.Field(proto.STRING, number=34,) - time_series_timestamp_column = proto.Field(proto.STRING, number=35,) - time_series_data_column = proto.Field(proto.STRING, number=36,) - auto_arima = proto.Field(proto.BOOL, number=37,) - non_seasonal_order = proto.Field( - proto.MESSAGE, number=38, message="Model.ArimaOrder", - ) - data_frequency = proto.Field( - proto.ENUM, number=39, enum="Model.DataFrequency", - ) - include_drift = proto.Field(proto.BOOL, number=41,) - holiday_region = proto.Field( - proto.ENUM, number=42, enum="Model.HolidayRegion", - ) - time_series_id_column = proto.Field(proto.STRING, number=43,) - time_series_id_columns = proto.RepeatedField(proto.STRING, number=51,) - horizon = proto.Field(proto.INT64, number=44,) - preserve_input_structs = proto.Field(proto.BOOL, number=45,) - auto_arima_max_order = proto.Field(proto.INT64, number=46,) - decompose_time_series = proto.Field( - proto.MESSAGE, number=50, message=wrappers_pb2.BoolValue, - ) - clean_spikes_and_dips = proto.Field( - proto.MESSAGE, number=52, message=wrappers_pb2.BoolValue, - ) - adjust_step_changes = proto.Field( - proto.MESSAGE, number=53, 
message=wrappers_pb2.BoolValue, - ) - - class IterationResult(proto.Message): - r"""Information about a single iteration of the training run. - Attributes: - index (google.protobuf.wrappers_pb2.Int32Value): - Index of the iteration, 0 based. - duration_ms (google.protobuf.wrappers_pb2.Int64Value): - Time taken to run the iteration in - milliseconds. - training_loss (google.protobuf.wrappers_pb2.DoubleValue): - Loss computed on the training data at the end - of iteration. - eval_loss (google.protobuf.wrappers_pb2.DoubleValue): - Loss computed on the eval data at the end of - iteration. - learn_rate (float): - Learn rate used for this iteration. - cluster_infos (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ClusterInfo]): - Information about top clusters for clustering - models. - arima_result (google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ArimaResult): - - """ - - class ClusterInfo(proto.Message): - r"""Information about a single cluster for clustering model. - Attributes: - centroid_id (int): - Centroid id. - cluster_radius (google.protobuf.wrappers_pb2.DoubleValue): - Cluster radius, the average distance from - centroid to each point assigned to the cluster. - cluster_size (google.protobuf.wrappers_pb2.Int64Value): - Cluster size, the total number of points - assigned to the cluster. - """ - - centroid_id = proto.Field(proto.INT64, number=1,) - cluster_radius = proto.Field( - proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, - ) - cluster_size = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.Int64Value, - ) - - class ArimaResult(proto.Message): - r"""(Auto-)arima fitting result. Wrap everything in ArimaResult - for easier refactoring if we want to use model-specific - iteration results. - - Attributes: - arima_model_info (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ArimaResult.ArimaModelInfo]): - This message is repeated because there are - multiple arima models fitted in auto-arima. For - non-auto-arima model, its size is one. - seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): - Seasonal periods. Repeated because multiple - periods are supported for one time series. - """ - - class ArimaCoefficients(proto.Message): - r"""Arima coefficients. - Attributes: - auto_regressive_coefficients (Sequence[float]): - Auto-regressive coefficients, an array of - double. - moving_average_coefficients (Sequence[float]): - Moving-average coefficients, an array of - double. - intercept_coefficient (float): - Intercept coefficient, just a double not an - array. - """ - - auto_regressive_coefficients = proto.RepeatedField( - proto.DOUBLE, number=1, - ) - moving_average_coefficients = proto.RepeatedField( - proto.DOUBLE, number=2, - ) - intercept_coefficient = proto.Field(proto.DOUBLE, number=3,) - - class ArimaModelInfo(proto.Message): - r"""Arima model information. - Attributes: - non_seasonal_order (google.cloud.bigquery_v2.types.Model.ArimaOrder): - Non-seasonal order. - arima_coefficients (google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ArimaResult.ArimaCoefficients): - Arima coefficients. - arima_fitting_metrics (google.cloud.bigquery_v2.types.Model.ArimaFittingMetrics): - Arima fitting metrics. - has_drift (bool): - Whether Arima model fitted with drift or not. - It is always false when d is not 1. - time_series_id (str): - The time_series_id value for this time series. 
It will be - one of the unique values from the time_series_id_column - specified during ARIMA model training. Only present when - time_series_id_column training option was used. - time_series_ids (Sequence[str]): - The tuple of time_series_ids identifying this time series. - It will be one of the unique tuples of values present in the - time_series_id_columns specified during ARIMA model - training. Only present when time_series_id_columns training - option was used and the order of values here are same as the - order of time_series_id_columns. - seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): - Seasonal periods. Repeated because multiple - periods are supported for one time series. - has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue): - If true, holiday_effect is a part of time series - decomposition result. - has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue): - If true, spikes_and_dips is a part of time series - decomposition result. - has_step_changes (google.protobuf.wrappers_pb2.BoolValue): - If true, step_changes is a part of time series decomposition - result. - """ - - non_seasonal_order = proto.Field( - proto.MESSAGE, number=1, message="Model.ArimaOrder", - ) - arima_coefficients = proto.Field( - proto.MESSAGE, - number=2, - message="Model.TrainingRun.IterationResult.ArimaResult.ArimaCoefficients", - ) - arima_fitting_metrics = proto.Field( - proto.MESSAGE, number=3, message="Model.ArimaFittingMetrics", - ) - has_drift = proto.Field(proto.BOOL, number=4,) - time_series_id = proto.Field(proto.STRING, number=5,) - time_series_ids = proto.RepeatedField(proto.STRING, number=10,) - seasonal_periods = proto.RepeatedField( - proto.ENUM, - number=6, - enum="Model.SeasonalPeriod.SeasonalPeriodType", - ) - has_holiday_effect = proto.Field( - proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue, - ) - has_spikes_and_dips = proto.Field( - proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue, - ) - has_step_changes = proto.Field( - proto.MESSAGE, number=9, message=wrappers_pb2.BoolValue, - ) - - arima_model_info = proto.RepeatedField( - proto.MESSAGE, - number=1, - message="Model.TrainingRun.IterationResult.ArimaResult.ArimaModelInfo", - ) - seasonal_periods = proto.RepeatedField( - proto.ENUM, - number=2, - enum="Model.SeasonalPeriod.SeasonalPeriodType", - ) - - index = proto.Field( - proto.MESSAGE, number=1, message=wrappers_pb2.Int32Value, - ) - duration_ms = proto.Field( - proto.MESSAGE, number=4, message=wrappers_pb2.Int64Value, - ) - training_loss = proto.Field( - proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, - ) - eval_loss = proto.Field( - proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, - ) - learn_rate = proto.Field(proto.DOUBLE, number=7,) - cluster_infos = proto.RepeatedField( - proto.MESSAGE, - number=8, - message="Model.TrainingRun.IterationResult.ClusterInfo", - ) - arima_result = proto.Field( - proto.MESSAGE, - number=9, - message="Model.TrainingRun.IterationResult.ArimaResult", - ) - - training_options = proto.Field( - proto.MESSAGE, number=1, message="Model.TrainingRun.TrainingOptions", - ) - start_time = proto.Field( - proto.MESSAGE, number=8, message=timestamp_pb2.Timestamp, - ) - results = proto.RepeatedField( - proto.MESSAGE, number=6, message="Model.TrainingRun.IterationResult", - ) - evaluation_metrics = proto.Field( - proto.MESSAGE, number=7, message="Model.EvaluationMetrics", - ) - data_split_result = proto.Field( - proto.MESSAGE, number=9, 
message="Model.DataSplitResult", - ) - global_explanations = proto.RepeatedField( - proto.MESSAGE, number=10, message="Model.GlobalExplanation", - ) - - etag = proto.Field(proto.STRING, number=1,) - model_reference = proto.Field( - proto.MESSAGE, number=2, message=gcb_model_reference.ModelReference, - ) - creation_time = proto.Field(proto.INT64, number=5,) - last_modified_time = proto.Field(proto.INT64, number=6,) - description = proto.Field(proto.STRING, number=12,) - friendly_name = proto.Field(proto.STRING, number=14,) - labels = proto.MapField(proto.STRING, proto.STRING, number=15,) - expiration_time = proto.Field(proto.INT64, number=16,) - location = proto.Field(proto.STRING, number=13,) - encryption_configuration = proto.Field( - proto.MESSAGE, number=17, message=encryption_config.EncryptionConfiguration, - ) - model_type = proto.Field(proto.ENUM, number=7, enum=ModelType,) - training_runs = proto.RepeatedField(proto.MESSAGE, number=9, message=TrainingRun,) - feature_columns = proto.RepeatedField( - proto.MESSAGE, number=10, message=standard_sql.StandardSqlField, - ) - label_columns = proto.RepeatedField( - proto.MESSAGE, number=11, message=standard_sql.StandardSqlField, - ) - best_trial_id = proto.Field(proto.INT64, number=19,) - - -class GetModelRequest(proto.Message): - r""" - Attributes: - project_id (str): - Required. Project ID of the requested model. - dataset_id (str): - Required. Dataset ID of the requested model. - model_id (str): - Required. Model ID of the requested model. - """ - - project_id = proto.Field(proto.STRING, number=1,) - dataset_id = proto.Field(proto.STRING, number=2,) - model_id = proto.Field(proto.STRING, number=3,) - - -class PatchModelRequest(proto.Message): - r""" - Attributes: - project_id (str): - Required. Project ID of the model to patch. - dataset_id (str): - Required. Dataset ID of the model to patch. - model_id (str): - Required. Model ID of the model to patch. - model (google.cloud.bigquery_v2.types.Model): - Required. Patched model. - Follows RFC5789 patch semantics. Missing fields - are not updated. To clear a field, explicitly - set to default value. - """ - - project_id = proto.Field(proto.STRING, number=1,) - dataset_id = proto.Field(proto.STRING, number=2,) - model_id = proto.Field(proto.STRING, number=3,) - model = proto.Field(proto.MESSAGE, number=4, message="Model",) - - -class DeleteModelRequest(proto.Message): - r""" - Attributes: - project_id (str): - Required. Project ID of the model to delete. - dataset_id (str): - Required. Dataset ID of the model to delete. - model_id (str): - Required. Model ID of the model to delete. - """ - - project_id = proto.Field(proto.STRING, number=1,) - dataset_id = proto.Field(proto.STRING, number=2,) - model_id = proto.Field(proto.STRING, number=3,) - - -class ListModelsRequest(proto.Message): - r""" - Attributes: - project_id (str): - Required. Project ID of the models to list. - dataset_id (str): - Required. Dataset ID of the models to list. - max_results (google.protobuf.wrappers_pb2.UInt32Value): - The maximum number of results to return in a - single response page. Leverage the page tokens - to iterate through the entire collection. 
- page_token (str): - Page token, returned by a previous call to - request the next page of results - """ - - project_id = proto.Field(proto.STRING, number=1,) - dataset_id = proto.Field(proto.STRING, number=2,) - max_results = proto.Field( - proto.MESSAGE, number=3, message=wrappers_pb2.UInt32Value, - ) - page_token = proto.Field(proto.STRING, number=4,) - - -class ListModelsResponse(proto.Message): - r""" - Attributes: - models (Sequence[google.cloud.bigquery_v2.types.Model]): - Models in the requested dataset. Only the following fields - are populated: model_reference, model_type, creation_time, - last_modified_time and labels. - next_page_token (str): - A token to request the next page of results. - """ - - @property - def raw_page(self): - return self - - models = proto.RepeatedField(proto.MESSAGE, number=1, message="Model",) - next_page_token = proto.Field(proto.STRING, number=2,) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/bigquery_v2/types/table_reference.py b/google/cloud/bigquery_v2/types/table_reference.py deleted file mode 100644 index d56e5b09f..000000000 --- a/google/cloud/bigquery_v2/types/table_reference.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import proto # type: ignore - - -__protobuf__ = proto.module( - package="google.cloud.bigquery.v2", manifest={"TableReference",}, -) - - -class TableReference(proto.Message): - r""" - Attributes: - project_id (str): - Required. The ID of the project containing - this table. - dataset_id (str): - Required. The ID of the dataset containing - this table. - table_id (str): - Required. The ID of the table. The ID must contain only - letters (a-z, A-Z), numbers (0-9), or underscores (_). The - maximum length is 1,024 characters. Certain operations allow - suffixing of the table ID with a partition decorator, such - as ``sample_table$20190123``. - project_id_alternative (Sequence[str]): - The alternative field that will be used when ESF is not able - to translate the received data to the project_id field. - dataset_id_alternative (Sequence[str]): - The alternative field that will be used when ESF is not able - to translate the received data to the project_id field. - table_id_alternative (Sequence[str]): - The alternative field that will be used when ESF is not able - to translate the received data to the project_id field. 
- """ - - project_id = proto.Field(proto.STRING, number=1,) - dataset_id = proto.Field(proto.STRING, number=2,) - table_id = proto.Field(proto.STRING, number=3,) - project_id_alternative = proto.RepeatedField(proto.STRING, number=4,) - dataset_id_alternative = proto.RepeatedField(proto.STRING, number=5,) - table_id_alternative = proto.RepeatedField(proto.STRING, number=6,) - - -__all__ = tuple(sorted(__protobuf__.manifest)) From c963e25f7d4d0abb4d424ab9ddfb0a7eb3293297 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 6 Oct 2021 14:11:24 -0500 Subject: [PATCH 04/24] insert query with jobs.query --- google/cloud/bigquery/_job_helpers.py | 109 ++++++++++++++++++-------- google/cloud/bigquery/client.py | 5 +- tests/system/test_client.py | 58 -------------- tests/system/test_query.py | 75 ++++++++++++++++++ 4 files changed, 154 insertions(+), 93 deletions(-) create mode 100644 tests/system/test_query.py diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index f0cba9531..b02d054e1 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -16,7 +16,7 @@ import copy import uuid -from typing import TYPE_CHECKING +from typing import Any, Dict, TYPE_CHECKING, Optional import google.api_core.exceptions as core_exceptions from google.api_core import retry as retries @@ -24,16 +24,14 @@ from google.cloud.bigquery import job # Avoid circular imports -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER from google.cloud.bigquery.client import Client -else: - Client = None _TIMEOUT_BUFFER_SECS = 0.1 -def make_job_id(job_id, prefix=None): +def make_job_id(job_id: Optional[str] = None, prefix: Optional[str] = None) -> str: """Construct an ID for a new job. Args: @@ -53,7 +51,7 @@ def make_job_id(job_id, prefix=None): def query_jobs_insert( - client: Client, + client: "Client", query: str, job_config: job.QueryJobConfig, job_id: str, @@ -61,9 +59,13 @@ def query_jobs_insert( location: str, project: str, retry: retries.Retry, - timeout: float, + timeout: Optional[float], job_retry: retries.Retry, ): + """Initiate a query using jobs.insert. + + See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert + """ job_id_given = job_id is not None job_id_save = job_id job_config_save = job_config @@ -112,43 +114,83 @@ def do_query(): return future +def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any]: + """Transform from Job resource to QueryRequest resource. + + Most of the keys in job.configuration.query are in common with + QueryRequest. If any configuration property is set that is not available in + jobs.query, it will result in a server-side error. + """ + request_body = {} + job_config_resource = job_config.to_api_repr() if job_config else {} + query_config_resource = job_config_resource.get("configuration", {}).get( + "query", {} + ) + + request_body.update(query_config_resource) + + # These keys are top level in job resource and query resource. + if "labels" in job_config_resource: + request_body["labels"] = job_config_resource["labels"] + if "dryRun" in job_config_resource: + request_body["dryRun"] = job_config_resource["dryRun"] + + # Default to standard SQL. + request_body.setdefault("useLegacySql", False) + + return request_body + + +def _to_query_job( + client: "Client", query: str, query_response: Dict[str, Any] +) -> job.QueryJob: + # TODO: check for errors? 
+ job_ref_resource = query_response["jobReference"] + job_ref = job._JobReference._from_api_repr(job_ref_resource) + query_job = job.QueryJob(job_ref, query, client=client) + + # Set errors if any were encountered. + query_job._properties.setdefault("status", {}) + if "errors" in query_response: + query_job._properties["status"]["errors"] = query_response["errors"] + query_job._properties["status"]["errorResult"] = query_response["errors"][0] + + # Transform job state so that QueryJob doesn't try to restart the query. + job_complete = query_response.get("jobComplete") + if job_complete: + query_job._properties["status"]["state"] = "DONE" + # TODO: set first page of results if job is "complete" + else: + query_job._properties["status"]["state"] = "PENDING" + + return query_job + + def query_jobs_query( - client: Client, + client: "Client", query: str, - job_config: job.QueryJobConfig, + job_config: Optional[job.QueryJobConfig], location: str, project: str, retry: retries.Retry, - timeout: float, + timeout: Optional[float], job_retry: retries.Retry, ): - # TODO: Validate that destination is not set. + """Initiate a query using jobs.query. - request_body = {} - job_config_resource = job_config.to_api_repr() + See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query + """ + path = f"/projects/{project}/queries" + request_body = _to_query_request(job_config) - # Transform from Job resource to QueryRequest resource. - # Most of the keys in job.configuration.query are in common - request_body.update(job_config_resource["configuration"]["query"]) + if timeout is not None: + # Subtract a buffer for context switching, network latency, etc. + request_body["timeoutMs"] = max(0, int(1000 * (timeout - _TIMEOUT_BUFFER_SECS))) request_body["location"] = location - request_body["labels"] = job_config.labels - request_body["dryRun"] = job_config.dry_run - - # Subtract a buffer for context switching, network latency, etc. - request_body["timeoutMs"] = max(0, int(1000 * (timeout - _TIMEOUT_BUFFER_SECS))) + request_body["query"] = query def do_query(): - request_body["requestId"] = make_job_id(None) - # job_ref = job._JobReference(job_id, project=project, location=location) - # query_job = job.QueryJob(job_ref, query, client=client, job_config=job_config) - - # query_job._begin(retry=retry, timeout=timeout) - client._call_api(retry) - - path = f"/projects/{project}/queries" - - # jobs.insert is idempotent because we ensure that every new - # job has an ID. + request_body["requestId"] = make_job_id() span_attributes = {"path": path} api_response = client._call_api( retry, @@ -159,8 +201,7 @@ def do_query(): data=request_body, timeout=timeout, ) - # TODO: make query job out of api_response - return api_response + return _to_query_job(client, query, api_response) future = do_query() diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index ca635cbf8..ed2527654 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -3254,7 +3254,10 @@ def query( # Note that we haven't modified the original job_config (or # _default_query_job_config) up to this point. 
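For orientation, the new `_to_query_job` helper above adapts a raw jobs.query response into a `QueryJob` without going through jobs.insert. A minimal sketch of that adaptation (the `client` object is assumed to be an existing `google.cloud.bigquery.Client`; the response dict is illustrative, with field names taken from the jobs.query REST response):

    from google.cloud.bigquery import _job_helpers

    response = {
        "jobReference": {"projectId": "my-project", "jobId": "abc123", "location": "US"},
        "jobComplete": True,
    }
    query_job = _job_helpers._to_query_job(client, "SELECT 1", response)
    query_job.job_id  # "abc123"
    query_job.state   # "DONE", so QueryJob does not try to restart the query
    # A response with "jobComplete": False would map to state "PENDING" instead.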
if api_method == "query": - return None # TODO + # TODO: error if job_id or job_id_prefix set + return _job_helpers.query_jobs_query( + self, query, job_config, location, project, retry, timeout, job_retry, + ) else: return _job_helpers.query_jobs_insert( self, diff --git a/tests/system/test_client.py b/tests/system/test_client.py index 4884112ac..a38926f5f 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -700,64 +700,6 @@ def _fetch_single_page(table, selected_fields=None): page = next(iterator.pages) return list(page) - def _create_table_many_columns(self, rowcount): - # Generate a table of maximum width via CREATE TABLE AS SELECT. - # first column is named 'rowval', and has a value from 1..rowcount - # Subsequent column is named col_ and contains the value N*rowval, - # where N is between 1 and 9999 inclusive. - dsname = _make_dataset_id("wide_schema") - dataset = self.temp_dataset(dsname) - table_id = "many_columns" - table_ref = dataset.table(table_id) - self.to_delete.insert(0, table_ref) - colprojections = ",".join( - ["r * {} as col_{}".format(n, n) for n in range(1, 10000)] - ) - sql = """ - CREATE TABLE {}.{} - AS - SELECT - r as rowval, - {} - FROM - UNNEST(GENERATE_ARRAY(1,{},1)) as r - """.format( - dsname, table_id, colprojections, rowcount - ) - query_job = Config.CLIENT.query(sql) - query_job.result() - self.assertEqual(query_job.statement_type, "CREATE_TABLE_AS_SELECT") - self.assertEqual(query_job.ddl_operation_performed, "CREATE") - self.assertEqual(query_job.ddl_target_table, table_ref) - - return table_ref - - def test_query_many_columns(self): - # Test working with the widest schema BigQuery supports, 10k columns. - row_count = 2 - table_ref = self._create_table_many_columns(row_count) - rows = list( - Config.CLIENT.query( - "SELECT * FROM `{}.{}`".format(table_ref.dataset_id, table_ref.table_id) - ) - ) - - self.assertEqual(len(rows), row_count) - - # check field representations adhere to expected values. - correctwidth = 0 - badvals = 0 - for r in rows: - vals = r._xxx_values - rowval = vals[0] - if len(vals) == 10000: - correctwidth = correctwidth + 1 - for n in range(1, 10000): - if vals[n] != rowval * (n): - badvals = badvals + 1 - self.assertEqual(correctwidth, row_count) - self.assertEqual(badvals, 0) - def test_insert_rows_then_dump_table(self): NOW_SECONDS = 1448911495.484366 NOW = datetime.datetime.utcfromtimestamp(NOW_SECONDS).replace(tzinfo=UTC) diff --git a/tests/system/test_query.py b/tests/system/test_query.py new file mode 100644 index 000000000..91a91b7e3 --- /dev/null +++ b/tests/system/test_query.py @@ -0,0 +1,75 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +@pytest.fixture( + params=[ + # None, + # "insert", + "query", + ] +) +def query_api_method(request): + return request.param + + +@pytest.fixture(scope="session") +def table_with_9999_columns_10_rows(bigquery_client, project_id, dataset_id): + """Generate a table of maximum width via CREATE TABLE AS SELECT. 
+ + The first column is named 'rowval', and has a value from 1..rowcount + Subsequent columns are named col_ and contain the value N*rowval, where + N is between 1 and 9999 inclusive. + """ + table_id = "many_columns" + row_count = 10 + col_projections = ",".join([f"r * {n} as col_{n}" for n in range(1, 10000)]) + sql = f""" + CREATE TABLE `{project_id}.{dataset_id}.{table_id}` + AS + SELECT + r as rowval, + {col_projections} + FROM + UNNEST(GENERATE_ARRAY(1,{row_count},1)) as r + """ + query_job = bigquery_client.query(sql) + query_job.result() + + return f"{project_id}.{dataset_id}.{table_id}" + + +def test_query_many_columns( + bigquery_client, table_with_9999_columns_10_rows, query_api_method +): + # Test working with the widest schema BigQuery supports, 10k columns. + if query_api_method is not None: + query_job = bigquery_client.query( + f"SELECT * FROM `{table_with_9999_columns_10_rows}`", + api_method=query_api_method, + ) + else: + query_job = bigquery_client.query( + f"SELECT * FROM `{table_with_9999_columns_10_rows}`" + ) + rows = list(query_job) + assert len(rows) == 10 + + # check field representations adhere to expected values. + for row in rows: + rowval = row["rowval"] + for column in range(1, 10000): + assert row[f"col_{column}"] == rowval * column From 50933786417ce42b9b48b73b3bf22ed833cc02a3 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 6 Oct 2021 15:46:35 -0500 Subject: [PATCH 05/24] fix merge between job config and query request --- google/cloud/bigquery/_job_helpers.py | 5 +- google/cloud/bigquery/client.py | 34 +-- google/cloud/bigquery/enums.py | 28 ++ tests/system/test_client.py | 303 ---------------------- tests/system/test_query.py | 358 ++++++++++++++++++++++++-- tests/unit/test__job_helpers.py | 45 ++++ 6 files changed, 425 insertions(+), 348 deletions(-) create mode 100644 tests/unit/test__job_helpers.py diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index b02d054e1..daad01529 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -123,9 +123,7 @@ def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any """ request_body = {} job_config_resource = job_config.to_api_repr() if job_config else {} - query_config_resource = job_config_resource.get("configuration", {}).get( - "query", {} - ) + query_config_resource = job_config_resource.get("query", {}) request_body.update(query_config_resource) @@ -144,7 +142,6 @@ def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any def _to_query_job( client: "Client", query: str, query_response: Dict[str, Any] ) -> job.QueryJob: - # TODO: check for errors? 
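The one-line fix above reflects the shape of `QueryJobConfig.to_api_repr()`: the query options already sit under a top-level "query" key (alongside "labels" and "dryRun"), with no outer "configuration" wrapper, so the previous two-level lookup always produced an empty mapping. A sketch of the corrected behavior, combining two of the cases covered by the unit tests added in this commit:

    from google.cloud.bigquery import _job_helpers
    from google.cloud.bigquery.job import QueryJobConfig

    config = QueryJobConfig(use_query_cache=False, labels={"abc": "def"})
    _job_helpers._to_query_request(config)
    # -> {"useQueryCache": False, "labels": {"abc": "def"}, "useLegacySql": False}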
job_ref_resource = query_response["jobReference"] job_ref = job._JobReference._from_api_repr(job_ref_resource) query_job = job.QueryJob(job_ref, query, client=client) diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index ed2527654..e5ee2e742 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -61,6 +61,7 @@ from google.cloud.bigquery.dataset import Dataset from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.dataset import DatasetReference +from google.cloud.bigquery import enums from google.cloud.bigquery.enums import AutoRowIDs from google.cloud.bigquery.opentelemetry_tracing import create_span from google.cloud.bigquery import job @@ -3124,7 +3125,7 @@ def query( retry: retries.Retry = DEFAULT_RETRY, timeout: float = DEFAULT_TIMEOUT, job_retry: retries.Retry = DEFAULT_JOB_RETRY, - api_method: str = "insert", + api_method: enums.QueryApiMethod = enums.QueryApiMethod.INSERT, ) -> job.QueryJob: """Run a SQL query. @@ -3177,19 +3178,7 @@ def query( specified here becomes the default ``job_retry`` for ``result()``, where it can also be specified. api_method: - One of ``'insert'`` or ``'query'``. Defaults to ``'insert'``. - - When set to ``'insert'``, submit a query job by using the - `jobs.insert REST API method - _`. - This supports all job configuration options. - - When set to ``'query'``, submit a query job by using the - `jobs.query REST API method - `_. - This API waits up to the specified timeout for the query to - finish. The ``job_id`` and ``job_id_prefix`` parameters cannot - be used with this API method. + Method with which to start the query job. Returns: google.cloud.bigquery.job.QueryJob: A new query job instance. @@ -3213,15 +3202,9 @@ def query( " provided." ) - if api_method not in {"insert", "query"}: - raise ValueError( - f"Got unexpected value for api_method: {repr(api_method)}" - " Expected one of {'insert', 'query'}." - ) - - if job_id_given and api_method == "query": + if job_id_given and api_method == enums.QueryApiMethod.QUERY: raise TypeError( - "`job_id` was provided, but the 'query' `api_method` was requested." + "`job_id` was provided, but the 'QUERY' `api_method` was requested." ) if project is None: @@ -3253,12 +3236,11 @@ def query( # Note that we haven't modified the original job_config (or # _default_query_job_config) up to this point. - if api_method == "query": - # TODO: error if job_id or job_id_prefix set + if api_method == enums.QueryApiMethod.QUERY: return _job_helpers.query_jobs_query( self, query, job_config, location, project, retry, timeout, job_retry, ) - else: + elif api_method == enums.QueryApiMethod.INSERT: return _job_helpers.query_jobs_insert( self, query, @@ -3271,6 +3253,8 @@ def query( timeout, job_retry, ) + else: + raise ValueError(f"Got unexpected value for api_method: {repr(api_method)}") def insert_rows( self, diff --git a/google/cloud/bigquery/enums.py b/google/cloud/bigquery/enums.py index cecdaa503..39d880c2b 100644 --- a/google/cloud/bigquery/enums.py +++ b/google/cloud/bigquery/enums.py @@ -122,6 +122,34 @@ class QueryPriority(object): """Specifies batch priority.""" +class QueryApiMethod(str, enum.Enum): + """API method used to start the query. The default value is + :attr:`INSERT`. + """ + + INSERT = "INSERT" + """Submit a query job by using the `jobs.insert REST API method + _`. + + This supports all job configuration options. + """ + + QUERY = "QUERY" + """Submit a query job by using the `jobs.query REST API method + `_. 
+ + This API blocks for up to a specified timeout for the query to finish. The + full job resource (including job statistics) may not be available if the + query finishes within the timeout. Call + :meth:`~google.cloud.bigquery.job.QueryJob.reload` or + :meth:`~google.cloud.bigquery.client.Client.get_job` to get full job + statistics. + + Many parameters, including destination table and job ID cannot be used with + this API method. + """ + + class SchemaUpdateOption(object): """Specifies an update to the destination table schema as a side effect of a load job. diff --git a/tests/system/test_client.py b/tests/system/test_client.py index a38926f5f..79db6cff5 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -13,7 +13,6 @@ # limitations under the License. import base64 -import concurrent.futures import csv import datetime import decimal @@ -1318,25 +1317,6 @@ def test_query_w_wrong_config(self): with self.assertRaises(Exception): Config.CLIENT.query(good_query, job_config=bad_config).result() - def test_query_w_timeout(self): - job_config = bigquery.QueryJobConfig() - job_config.use_query_cache = False - - query_job = Config.CLIENT.query( - "SELECT * FROM `bigquery-public-data.github_repos.commits`;", - job_id_prefix="test_query_w_timeout_", - location="US", - job_config=job_config, - ) - - with self.assertRaises(concurrent.futures.TimeoutError): - query_job.result(timeout=1) - - # Even though the query takes >1 second, the call to getQueryResults - # should succeed. - self.assertFalse(query_job.done(timeout=1)) - self.assertIsNotNone(Config.CLIENT.cancel_job(query_job)) - def test_query_w_page_size(self): page_size = 45 query_job = Config.CLIENT.query( @@ -1358,83 +1338,6 @@ def test_query_w_start_index(self): self.assertEqual(result1.extra_params["startIndex"], start_index) self.assertEqual(len(list(result1)), total_rows - start_index) - def test_query_statistics(self): - """ - A system test to exercise some of the extended query statistics. - - Note: We construct a query that should need at least three stages by - specifying a JOIN query. Exact plan and stats are effectively - non-deterministic, so we're largely interested in confirming values - are present. - """ - - job_config = bigquery.QueryJobConfig() - job_config.use_query_cache = False - - query_job = Config.CLIENT.query( - """ - SELECT - COUNT(1) - FROM - ( - SELECT - year, - wban_number - FROM `bigquery-public-data.samples.gsod` - LIMIT 1000 - ) lside - INNER JOIN - ( - SELECT - year, - state - FROM `bigquery-public-data.samples.natality` - LIMIT 1000 - ) rside - ON - lside.year = rside.year - """, - location="US", - job_config=job_config, - ) - - # run the job to completion - query_job.result() - - # Assert top-level stats - self.assertFalse(query_job.cache_hit) - self.assertIsNotNone(query_job.destination) - self.assertTrue(query_job.done) - self.assertFalse(query_job.dry_run) - self.assertIsNone(query_job.num_dml_affected_rows) - self.assertEqual(query_job.priority, "INTERACTIVE") - self.assertGreater(query_job.total_bytes_billed, 1) - self.assertGreater(query_job.total_bytes_processed, 1) - self.assertEqual(query_job.statement_type, "SELECT") - self.assertGreater(query_job.slot_millis, 1) - - # Make assertions on the shape of the query plan. 
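Together with the `Client.query` changes earlier in this commit, the enum gives callers an explicit switch between the two REST methods. A short usage sketch (assumes an existing `Client` instance; as the docstring above notes, full job statistics may require a `reload()` when the QUERY method is used, and `job_id` cannot be combined with it):

    from google.cloud import bigquery
    from google.cloud.bigquery.enums import QueryApiMethod

    client = bigquery.Client()
    query_job = client.query(
        "SELECT 1 AS one",
        api_method=QueryApiMethod.QUERY,
    )
    rows = list(query_job.result())
    query_job.reload()  # fetch job statistics that jobs.query does not return
    print(query_job.total_bytes_processed)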
- plan = query_job.query_plan - self.assertGreaterEqual(len(plan), 3) - first_stage = plan[0] - self.assertIsNotNone(first_stage.start) - self.assertIsNotNone(first_stage.end) - self.assertIsNotNone(first_stage.entry_id) - self.assertIsNotNone(first_stage.name) - self.assertGreater(first_stage.parallel_inputs, 0) - self.assertGreater(first_stage.completed_parallel_inputs, 0) - self.assertGreater(first_stage.shuffle_output_bytes, 0) - self.assertEqual(first_stage.status, "COMPLETE") - - # Query plan is a digraph. Ensure it has inter-stage links, - # but not every stage has inputs. - stages_with_inputs = 0 - for entry in plan: - if len(entry.input_stages) > 0: - stages_with_inputs = stages_with_inputs + 1 - self.assertGreater(stages_with_inputs, 0) - self.assertGreater(len(plan), stages_with_inputs) - def test_dml_statistics(self): table_schema = ( bigquery.SchemaField("foo", "STRING"), @@ -1724,212 +1627,6 @@ def test_dbapi_w_dml(self): ) self.assertEqual(Config.CURSOR.rowcount, 1) - def test_query_w_query_params(self): - from google.cloud.bigquery.job import QueryJobConfig - from google.cloud.bigquery.query import ArrayQueryParameter - from google.cloud.bigquery.query import ScalarQueryParameter - from google.cloud.bigquery.query import ScalarQueryParameterType - from google.cloud.bigquery.query import StructQueryParameter - from google.cloud.bigquery.query import StructQueryParameterType - - question = "What is the answer to life, the universe, and everything?" - question_param = ScalarQueryParameter( - name="question", type_="STRING", value=question - ) - answer = 42 - answer_param = ScalarQueryParameter(name="answer", type_="INT64", value=answer) - pi = 3.1415926 - pi_param = ScalarQueryParameter(name="pi", type_="FLOAT64", value=pi) - pi_numeric = decimal.Decimal("3.141592654") - pi_numeric_param = ScalarQueryParameter( - name="pi_numeric_param", type_="NUMERIC", value=pi_numeric - ) - bignum = decimal.Decimal("-{d38}.{d38}".format(d38="9" * 38)) - bignum_param = ScalarQueryParameter( - name="bignum_param", type_="BIGNUMERIC", value=bignum - ) - truthy = True - truthy_param = ScalarQueryParameter(name="truthy", type_="BOOL", value=truthy) - beef = b"DEADBEEF" - beef_param = ScalarQueryParameter(name="beef", type_="BYTES", value=beef) - naive = datetime.datetime(2016, 12, 5, 12, 41, 9) - naive_param = ScalarQueryParameter(name="naive", type_="DATETIME", value=naive) - naive_date_param = ScalarQueryParameter( - name="naive_date", type_="DATE", value=naive.date() - ) - naive_time_param = ScalarQueryParameter( - name="naive_time", type_="TIME", value=naive.time() - ) - zoned = naive.replace(tzinfo=UTC) - zoned_param = ScalarQueryParameter(name="zoned", type_="TIMESTAMP", value=zoned) - array_param = ArrayQueryParameter( - name="array_param", array_type="INT64", values=[1, 2] - ) - struct_param = StructQueryParameter("hitchhiker", question_param, answer_param) - phred_name = "Phred Phlyntstone" - phred_name_param = ScalarQueryParameter( - name="name", type_="STRING", value=phred_name - ) - phred_age = 32 - phred_age_param = ScalarQueryParameter( - name="age", type_="INT64", value=phred_age - ) - phred_param = StructQueryParameter(None, phred_name_param, phred_age_param) - bharney_name = "Bharney Rhubbyl" - bharney_name_param = ScalarQueryParameter( - name="name", type_="STRING", value=bharney_name - ) - bharney_age = 31 - bharney_age_param = ScalarQueryParameter( - name="age", type_="INT64", value=bharney_age - ) - bharney_param = StructQueryParameter( - None, bharney_name_param, 
bharney_age_param - ) - characters_param = ArrayQueryParameter( - name=None, array_type="RECORD", values=[phred_param, bharney_param] - ) - empty_struct_array_param = ArrayQueryParameter( - name="empty_array_param", - values=[], - array_type=StructQueryParameterType( - ScalarQueryParameterType(name="foo", type_="INT64"), - ScalarQueryParameterType(name="bar", type_="STRING"), - ), - ) - hero_param = StructQueryParameter("hero", phred_name_param, phred_age_param) - sidekick_param = StructQueryParameter( - "sidekick", bharney_name_param, bharney_age_param - ) - roles_param = StructQueryParameter("roles", hero_param, sidekick_param) - friends_param = ArrayQueryParameter( - name="friends", array_type="STRING", values=[phred_name, bharney_name] - ) - with_friends_param = StructQueryParameter(None, friends_param) - top_left_param = StructQueryParameter( - "top_left", - ScalarQueryParameter("x", "INT64", 12), - ScalarQueryParameter("y", "INT64", 102), - ) - bottom_right_param = StructQueryParameter( - "bottom_right", - ScalarQueryParameter("x", "INT64", 22), - ScalarQueryParameter("y", "INT64", 92), - ) - rectangle_param = StructQueryParameter( - "rectangle", top_left_param, bottom_right_param - ) - examples = [ - { - "sql": "SELECT @question", - "expected": question, - "query_parameters": [question_param], - }, - { - "sql": "SELECT @answer", - "expected": answer, - "query_parameters": [answer_param], - }, - {"sql": "SELECT @pi", "expected": pi, "query_parameters": [pi_param]}, - { - "sql": "SELECT @pi_numeric_param", - "expected": pi_numeric, - "query_parameters": [pi_numeric_param], - }, - { - "sql": "SELECT @bignum_param", - "expected": bignum, - "query_parameters": [bignum_param], - }, - { - "sql": "SELECT @truthy", - "expected": truthy, - "query_parameters": [truthy_param], - }, - {"sql": "SELECT @beef", "expected": beef, "query_parameters": [beef_param]}, - { - "sql": "SELECT @naive", - "expected": naive, - "query_parameters": [naive_param], - }, - { - "sql": "SELECT @naive_date", - "expected": naive.date(), - "query_parameters": [naive_date_param], - }, - { - "sql": "SELECT @naive_time", - "expected": naive.time(), - "query_parameters": [naive_time_param], - }, - { - "sql": "SELECT @zoned", - "expected": zoned, - "query_parameters": [zoned_param], - }, - { - "sql": "SELECT @array_param", - "expected": [1, 2], - "query_parameters": [array_param], - }, - { - "sql": "SELECT (@hitchhiker.question, @hitchhiker.answer)", - "expected": ({"_field_1": question, "_field_2": answer}), - "query_parameters": [struct_param], - }, - { - "sql": "SELECT " - "((@rectangle.bottom_right.x - @rectangle.top_left.x) " - "* (@rectangle.top_left.y - @rectangle.bottom_right.y))", - "expected": 100, - "query_parameters": [rectangle_param], - }, - { - "sql": "SELECT ?", - "expected": [ - {"name": phred_name, "age": phred_age}, - {"name": bharney_name, "age": bharney_age}, - ], - "query_parameters": [characters_param], - }, - { - "sql": "SELECT @empty_array_param", - "expected": [], - "query_parameters": [empty_struct_array_param], - }, - { - "sql": "SELECT @roles", - "expected": { - "hero": {"name": phred_name, "age": phred_age}, - "sidekick": {"name": bharney_name, "age": bharney_age}, - }, - "query_parameters": [roles_param], - }, - { - "sql": "SELECT ?", - "expected": {"friends": [phred_name, bharney_name]}, - "query_parameters": [with_friends_param], - }, - { - "sql": "SELECT @bignum_param", - "expected": bignum, - "query_parameters": [bignum_param], - }, - ] - - for example in examples: - jconfig = 
QueryJobConfig() - jconfig.query_parameters = example["query_parameters"] - query_job = Config.CLIENT.query( - example["sql"], - job_config=jconfig, - job_id_prefix="test_query_w_query_params", - ) - rows = list(query_job.result()) - self.assertEqual(len(rows), 1) - self.assertEqual(len(rows[0]), 1) - self.assertEqual(rows[0][0], example["expected"]) - def test_dbapi_w_query_parameters(self): examples = [ { diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 91a91b7e3..07b7676fc 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures +import datetime +import decimal + import pytest +from google.cloud import bigquery +from google.cloud.bigquery.query import ScalarQueryParameter -@pytest.fixture( - params=[ - # None, - # "insert", - "query", - ] -) +# from google.cloud.bigquery.query import ArrayQueryParameter +# from google.cloud.bigquery.query import ScalarQueryParameterType +# from google.cloud.bigquery.query import StructQueryParameter +# from google.cloud.bigquery.query import StructQueryParameterType + + +@pytest.fixture(params=["INSERT", "QUERY"]) def query_api_method(request): return request.param @@ -56,15 +62,10 @@ def test_query_many_columns( bigquery_client, table_with_9999_columns_10_rows, query_api_method ): # Test working with the widest schema BigQuery supports, 10k columns. - if query_api_method is not None: - query_job = bigquery_client.query( - f"SELECT * FROM `{table_with_9999_columns_10_rows}`", - api_method=query_api_method, - ) - else: - query_job = bigquery_client.query( - f"SELECT * FROM `{table_with_9999_columns_10_rows}`" - ) + query_job = bigquery_client.query( + f"SELECT * FROM `{table_with_9999_columns_10_rows}`", + api_method=query_api_method, + ) rows = list(query_job) assert len(rows) == 10 @@ -73,3 +74,328 @@ def test_query_many_columns( rowval = row["rowval"] for column in range(1, 10000): assert row[f"col_{column}"] == rowval * column + + +def test_query_w_timeout(bigquery_client, query_api_method): + job_config = bigquery.QueryJobConfig() + job_config.use_query_cache = False + + query_job = bigquery_client.query( + "SELECT * FROM `bigquery-public-data.github_repos.commits`;", + location="US", + job_config=job_config, + api_method=query_api_method, + ) + + with pytest.raises(concurrent.futures.TimeoutError): + query_job.result(timeout=1) + + # Even though the query takes >1 second, the call to getQueryResults + # should succeed. + assert not query_job.done(timeout=1) + assert bigquery_client.cancel_job(query_job) is not None + + +def test_query_statistics(bigquery_client, query_api_method): + """ + A system test to exercise some of the extended query statistics. + + Note: We construct a query that should need at least three stages by + specifying a JOIN query. Exact plan and stats are effectively + non-deterministic, so we're largely interested in confirming values + are present. 
+ """ + + job_config = bigquery.QueryJobConfig() + job_config.use_query_cache = False + + query_job = bigquery_client.query( + """ + SELECT + COUNT(1) + FROM + ( + SELECT + year, + wban_number + FROM `bigquery-public-data.samples.gsod` + LIMIT 1000 + ) lside + INNER JOIN + ( + SELECT + year, + state + FROM `bigquery-public-data.samples.natality` + LIMIT 1000 + ) rside + ON + lside.year = rside.year + """, + location="US", + job_config=job_config, + api_method=query_api_method, + ) + + # run the job to completion + query_job.result() + + # Must reload job to get stats if jobs.query was used. + if query_api_method == "QUERY": + query_job.reload() + + # Assert top-level stats + assert not query_job.cache_hit + assert query_job.destination is not None + assert query_job.done + assert not query_job.dry_run + assert query_job.num_dml_affected_rows is None + assert query_job.priority == "INTERACTIVE" + assert query_job.total_bytes_billed > 1 + assert query_job.total_bytes_processed > 1 + assert query_job.statement_type == "SELECT" + assert query_job.slot_millis > 1 + + # Make assertions on the shape of the query plan. + plan = query_job.query_plan + assert len(plan) >= 3 + first_stage = plan[0] + assert first_stage.start is not None + assert first_stage.end is not None + assert first_stage.entry_id is not None + assert first_stage.name is not None + assert first_stage.parallel_inputs > 0 + assert first_stage.completed_parallel_inputs > 0 + assert first_stage.shuffle_output_bytes > 0 + assert first_stage.status == "COMPLETE" + + # Query plan is a digraph. Ensure it has inter-stage links, + # but not every stage has inputs. + stages_with_inputs = 0 + for entry in plan: + if len(entry.input_stages) > 0: + stages_with_inputs = stages_with_inputs + 1 + assert stages_with_inputs > 0 + assert len(plan) > stages_with_inputs + + +@pytest.mark.parametrize( + ("sql", "expected", "query_parameters"), + ( + ( + "SELECT @question", + "What is the answer to life, the universe, and everything?", + [ + ScalarQueryParameter( + name="question", + type_="STRING", + value="What is the answer to life, the universe, and everything?", + ) + ], + ), + ( + "SELECT @answer", + 42, + [ScalarQueryParameter(name="answer", type_="INT64", value=42)], + ), + ( + "SELECT @pi", + 3.1415926, + [ScalarQueryParameter(name="pi", type_="FLOAT64", value=3.1415926)], + ), + ( + "SELECT @pi_numeric_param", + decimal.Decimal("3.141592654"), + [ + ScalarQueryParameter( + name="pi_numeric_param", + type_="NUMERIC", + value=decimal.Decimal("3.141592654"), + ) + ], + ), + ( + "SELECT @bignum_param", + decimal.Decimal("-{d38}.{d38}".format(d38="9" * 38)), + [ + ScalarQueryParameter( + name="bignum_param", + type_="BIGNUMERIC", + value=decimal.Decimal("-{d38}.{d38}".format(d38="9" * 38)), + ) + ], + ), + ( + "SELECT @truthy", + True, + [ScalarQueryParameter(name="truthy", type_="BOOL", value=True)], + ), + ( + "SELECT @beef", + b"DEADBEEF", + [ScalarQueryParameter(name="beef", type_="BYTES", value=b"DEADBEEF")], + ), + ( + "SELECT @naive", + datetime.datetime(2016, 12, 5, 12, 41, 9), + [ + ScalarQueryParameter( + name="naive", + type_="DATETIME", + value=datetime.datetime(2016, 12, 5, 12, 41, 9), + ) + ], + ), + ( + "SELECT @naive_date", + datetime.date(2016, 12, 5), + [ + ScalarQueryParameter( + name="naive_date", type_="DATE", value=datetime.date(2016, 12, 5) + ) + ], + ), + ( + "SELECT @naive_time", + datetime.time(12, 41, 9, 62500), + [ + ScalarQueryParameter( + name="naive_time", + type_="TIME", + value=datetime.time(12, 41, 9, 62500), + ) 
+ ], + ), + ( + "SELECT @zoned", + datetime.datetime(2016, 12, 5, 12, 41, 9, tzinfo=datetime.timezone.utc), + [ + ScalarQueryParameter( + name="zoned", + type_="TIMESTAMP", + value=datetime.datetime( + 2016, 12, 5, 12, 41, 9, tzinfo=datetime.timezone.utc + ), + ) + ], + ), + # ( + # "SELECT @array_param", + # [1, 2], + # [array_param], + # ), + # ( + # "SELECT (@hitchhiker.question, @hitchhiker.answer)", + # ({"_field_1": question, "_field_2": answer}), + # [struct_param], + # ), + # ( + # "SELECT " + # "((@rectangle.bottom_right.x - @rectangle.top_left.x) " + # "* (@rectangle.top_left.y - @rectangle.bottom_right.y))", + # 100, + # [rectangle_param], + # ), + # ( + # "SELECT ?", + # [ + # {"name": phred_name, "age": phred_age}, + # {"name": bharney_name, "age": bharney_age}, + # ], + # [characters_param], + # ), + # ( + # "SELECT @empty_array_param", + # [], + # [empty_struct_array_param], + # ), + # ( + # "SELECT @roles", + # ( + # "hero": {"name": phred_name, "age": phred_age}, + # "sidekick": {"name": bharney_name, "age": bharney_age}, + # ), + # [roles_param], + # ), + # ( + # "SELECT ?", + # {"friends": [phred_name, bharney_name]}, + # [with_friends_param], + # ), + # ( + # "SELECT @bignum_param", + # bignum, + # [bignum_param], + # ), + ), +) +def test_query_parameters( + bigquery_client, query_api_method, sql, expected, query_parameters +): + # array_param = ArrayQueryParameter( + # name="array_param", array_type="INT64", values=[1, 2] + # ) + # struct_param = StructQueryParameter("hitchhiker", question_param, answer_param) + # phred_name = "Phred Phlyntstone" + # phred_name_param = ScalarQueryParameter( + # name="name", type_="STRING", value=phred_name + # ) + # phred_age = 32 + # phred_age_param = ScalarQueryParameter( + # name="age", type_="INT64", value=phred_age + # ) + # phred_param = StructQueryParameter(None, phred_name_param, phred_age_param) + # bharney_name = "Bharney Rhubbyl" + # bharney_name_param = ScalarQueryParameter( + # name="name", type_="STRING", value=bharney_name + # ) + # bharney_age = 31 + # bharney_age_param = ScalarQueryParameter( + # name="age", type_="INT64", value=bharney_age + # ) + # bharney_param = StructQueryParameter( + # None, bharney_name_param, bharney_age_param + # ) + # characters_param = ArrayQueryParameter( + # name=None, array_type="RECORD", values=[phred_param, bharney_param] + # ) + # empty_struct_array_param = ArrayQueryParameter( + # name="empty_array_param", + # values=[], + # array_type=StructQueryParameterType( + # ScalarQueryParameterType(name="foo", type_="INT64"), + # ScalarQueryParameterType(name="bar", type_="STRING"), + # ), + # ) + # hero_param = StructQueryParameter("hero", phred_name_param, phred_age_param) + # sidekick_param = StructQueryParameter( + # "sidekick", bharney_name_param, bharney_age_param + # ) + # roles_param = StructQueryParameter("roles", hero_param, sidekick_param) + # friends_param = ArrayQueryParameter( + # name="friends", array_type="STRING", values=[phred_name, bharney_name] + # ) + # with_friends_param = StructQueryParameter(None, friends_param) + # top_left_param = StructQueryParameter( + # "top_left", + # ScalarQueryParameter("x", "INT64", 12), + # ScalarQueryParameter("y", "INT64", 102), + # ) + # bottom_right_param = StructQueryParameter( + # "bottom_right", + # ScalarQueryParameter("x", "INT64", 22), + # ScalarQueryParameter("y", "INT64", 92), + # ) + # rectangle_param = StructQueryParameter( + # "rectangle", top_left_param, bottom_right_param + # ) + + jconfig = bigquery.QueryJobConfig() + 
jconfig.query_parameters = query_parameters + query_job = bigquery_client.query( + sql, job_config=jconfig, api_method=query_api_method, + ) + rows = list(query_job.result()) + assert len(rows) == 1 + assert len(rows[0]) == 1 + assert rows[0][0] == expected diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py new file mode 100644 index 000000000..35756fa90 --- /dev/null +++ b/tests/unit/test__job_helpers.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.cloud.bigquery.job.query import QueryJobConfig + + +@pytest.fixture +def module_under_test(): + from google.cloud.bigquery import _job_helpers + + return _job_helpers + + +@pytest.mark.parametrize( + ("job_config", "expected"), + ( + (None, {"useLegacySql": False}), + (QueryJobConfig(), {"useLegacySql": False}), + (QueryJobConfig(dry_run=True), {"useLegacySql": False, "dryRun": True}), + ( + QueryJobConfig(labels={"abc": "def"}), + {"useLegacySql": False, "labels": {"abc": "def"}}, + ), + ( + QueryJobConfig(use_query_cache=False), + {"useLegacySql": False, "useQueryCache": False}, + ), + ), +) +def test__to_query_request(module_under_test, job_config, expected): + result = module_under_test._to_query_request(job_config) + assert result == expected From e2e1c7ddb62fd29317fa1da4aaf17416fa00fe78 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 6 Oct 2021 16:23:13 -0500 Subject: [PATCH 06/24] add tests --- tests/system/test_query.py | 238 ++++++++++++++++++++----------------- 1 file changed, 128 insertions(+), 110 deletions(-) diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 07b7676fc..ccd583cbf 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -19,12 +19,11 @@ import pytest from google.cloud import bigquery +from google.cloud.bigquery.query import ArrayQueryParameter from google.cloud.bigquery.query import ScalarQueryParameter - -# from google.cloud.bigquery.query import ArrayQueryParameter -# from google.cloud.bigquery.query import ScalarQueryParameterType -# from google.cloud.bigquery.query import StructQueryParameter -# from google.cloud.bigquery.query import StructQueryParameterType +from google.cloud.bigquery.query import ScalarQueryParameterType +from google.cloud.bigquery.query import StructQueryParameter +from google.cloud.bigquery.query import StructQueryParameterType @pytest.fixture(params=["INSERT", "QUERY"]) @@ -279,116 +278,135 @@ def test_query_statistics(bigquery_client, query_api_method): ) ], ), - # ( - # "SELECT @array_param", - # [1, 2], - # [array_param], - # ), - # ( - # "SELECT (@hitchhiker.question, @hitchhiker.answer)", - # ({"_field_1": question, "_field_2": answer}), - # [struct_param], - # ), - # ( - # "SELECT " - # "((@rectangle.bottom_right.x - @rectangle.top_left.x) " - # "* (@rectangle.top_left.y - @rectangle.bottom_right.y))", - # 100, - # [rectangle_param], - # ), - # ( - # "SELECT ?", - # [ - # {"name": phred_name, "age": phred_age}, - # {"name": 
bharney_name, "age": bharney_age}, - # ], - # [characters_param], - # ), - # ( - # "SELECT @empty_array_param", - # [], - # [empty_struct_array_param], - # ), - # ( - # "SELECT @roles", - # ( - # "hero": {"name": phred_name, "age": phred_age}, - # "sidekick": {"name": bharney_name, "age": bharney_age}, - # ), - # [roles_param], - # ), - # ( - # "SELECT ?", - # {"friends": [phred_name, bharney_name]}, - # [with_friends_param], - # ), - # ( - # "SELECT @bignum_param", - # bignum, - # [bignum_param], - # ), + ( + "SELECT @array_param", + [1, 2], + [ + ArrayQueryParameter( + name="array_param", array_type="INT64", values=[1, 2] + ) + ], + ), + ( + "SELECT (@hitchhiker.question, @hitchhiker.answer)", + ({"_field_1": "What is the answer?", "_field_2": 42}), + [ + StructQueryParameter( + "hitchhiker", + ScalarQueryParameter( + name="question", type_="STRING", value="What is the answer?", + ), + ScalarQueryParameter(name="answer", type_="INT64", value=42,), + ), + ], + ), + ( + "SELECT " + "((@rectangle.bottom_right.x - @rectangle.top_left.x) " + "* (@rectangle.top_left.y - @rectangle.bottom_right.y))", + 100, + [ + StructQueryParameter( + "rectangle", + StructQueryParameter( + "top_left", + ScalarQueryParameter("x", "INT64", 12), + ScalarQueryParameter("y", "INT64", 102), + ), + StructQueryParameter( + "bottom_right", + ScalarQueryParameter("x", "INT64", 22), + ScalarQueryParameter("y", "INT64", 92), + ), + ) + ], + ), + ( + "SELECT ?", + [ + {"name": "Phred Phlyntstone", "age": 32}, + {"name": "Bharney Rhubbyl", "age": 31}, + ], + [ + ArrayQueryParameter( + name=None, + array_type="RECORD", + values=[ + StructQueryParameter( + None, + ScalarQueryParameter( + name="name", type_="STRING", value="Phred Phlyntstone" + ), + ScalarQueryParameter(name="age", type_="INT64", value=32), + ), + StructQueryParameter( + None, + ScalarQueryParameter( + name="name", type_="STRING", value="Bharney Rhubbyl" + ), + ScalarQueryParameter(name="age", type_="INT64", value=31), + ), + ], + ) + ], + ), + ( + "SELECT @empty_array_param", + [], + [ + ArrayQueryParameter( + name="empty_array_param", + values=[], + array_type=StructQueryParameterType( + ScalarQueryParameterType(name="foo", type_="INT64"), + ScalarQueryParameterType(name="bar", type_="STRING"), + ), + ) + ], + ), + ( + "SELECT @roles", + { + "hero": {"name": "Phred Phlyntstone", "age": 32}, + "sidekick": {"name": "Bharney Rhubbyl", "age": 31}, + }, + [ + StructQueryParameter( + "roles", + StructQueryParameter( + "hero", + ScalarQueryParameter( + name="name", type_="STRING", value="Phred Phlyntstone" + ), + ScalarQueryParameter(name="age", type_="INT64", value=32), + ), + StructQueryParameter( + "sidekick", + ScalarQueryParameter( + name="name", type_="STRING", value="Bharney Rhubbyl" + ), + ScalarQueryParameter(name="age", type_="INT64", value=31), + ), + ), + ], + ), + ( + "SELECT ?", + {"friends": ["Jack", "Jill"]}, + [ + StructQueryParameter( + None, + ArrayQueryParameter( + name="friends", array_type="STRING", values=["Jack", "Jill"] + ), + ) + ], + ), ), ) def test_query_parameters( bigquery_client, query_api_method, sql, expected, query_parameters ): - # array_param = ArrayQueryParameter( - # name="array_param", array_type="INT64", values=[1, 2] - # ) - # struct_param = StructQueryParameter("hitchhiker", question_param, answer_param) - # phred_name = "Phred Phlyntstone" - # phred_name_param = ScalarQueryParameter( - # name="name", type_="STRING", value=phred_name - # ) - # phred_age = 32 - # phred_age_param = ScalarQueryParameter( - # name="age", 
type_="INT64", value=phred_age - # ) - # phred_param = StructQueryParameter(None, phred_name_param, phred_age_param) - # bharney_name = "Bharney Rhubbyl" - # bharney_name_param = ScalarQueryParameter( - # name="name", type_="STRING", value=bharney_name - # ) - # bharney_age = 31 - # bharney_age_param = ScalarQueryParameter( - # name="age", type_="INT64", value=bharney_age - # ) - # bharney_param = StructQueryParameter( - # None, bharney_name_param, bharney_age_param - # ) - # characters_param = ArrayQueryParameter( - # name=None, array_type="RECORD", values=[phred_param, bharney_param] - # ) - # empty_struct_array_param = ArrayQueryParameter( - # name="empty_array_param", - # values=[], - # array_type=StructQueryParameterType( - # ScalarQueryParameterType(name="foo", type_="INT64"), - # ScalarQueryParameterType(name="bar", type_="STRING"), - # ), - # ) - # hero_param = StructQueryParameter("hero", phred_name_param, phred_age_param) - # sidekick_param = StructQueryParameter( - # "sidekick", bharney_name_param, bharney_age_param - # ) - # roles_param = StructQueryParameter("roles", hero_param, sidekick_param) - # friends_param = ArrayQueryParameter( - # name="friends", array_type="STRING", values=[phred_name, bharney_name] - # ) - # with_friends_param = StructQueryParameter(None, friends_param) - # top_left_param = StructQueryParameter( - # "top_left", - # ScalarQueryParameter("x", "INT64", 12), - # ScalarQueryParameter("y", "INT64", 102), - # ) - # bottom_right_param = StructQueryParameter( - # "bottom_right", - # ScalarQueryParameter("x", "INT64", 22), - # ScalarQueryParameter("y", "INT64", 92), - # ) - # rectangle_param = StructQueryParameter( - # "rectangle", top_left_param, bottom_right_param - # ) jconfig = bigquery.QueryJobConfig() jconfig.query_parameters = query_parameters From 2e4af92c0a98b6518905172544e71a56dd6ccd12 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 7 Oct 2021 16:42:28 -0500 Subject: [PATCH 07/24] update todo with thoughts on future perf update --- google/cloud/bigquery/_job_helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index daad01529..187175fd8 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -156,7 +156,9 @@ def _to_query_job( job_complete = query_response.get("jobComplete") if job_complete: query_job._properties["status"]["state"] = "DONE" - # TODO: set first page of results if job is "complete" + # TODO: set first page of results if job is "complete" (and there is + # only 1 page of results? otherwise, need some awkward logic + # for DB API and to_dataframe to get destination table) else: query_job._properties["status"]["state"] = "PENDING" From d700df51ccee4ab5fd2e5be9ce983cd0aeed949e Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 7 Oct 2021 16:49:24 -0500 Subject: [PATCH 08/24] clarify TODO comment --- google/cloud/bigquery/_job_helpers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index 187175fd8..ba4d88522 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -156,9 +156,14 @@ def _to_query_job( job_complete = query_response.get("jobComplete") if job_complete: query_job._properties["status"]["state"] = "DONE" - # TODO: set first page of results if job is "complete" (and there is - # only 1 page of results? 
otherwise, need some awkward logic - # for DB API and to_dataframe to get destination table) + # TODO: https://github.com/googleapis/python-bigquery/issues/589 + # Set the first page of results if job is "complete" and there is + # only 1 page of results. Otherwise, use the existing logic that + # refreshes the job stats. + # + # This also requires updates to `to_dataframe` and the DB API connector + # so that they don't try to read from a destination table if all the + # results are present. else: query_job._properties["status"]["state"] = "PENDING" From 1e73a56072a74e1824fb96b78528d2c8c574cfae Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 8 Oct 2021 15:52:19 -0500 Subject: [PATCH 09/24] add placeholders for needed tests --- google/cloud/bigquery/_job_helpers.py | 6 ++- tests/system/conftest.py | 63 ++++++++++++++++++++------ tests/system/test_query.py | 24 ++++++++++ tests/unit/test__job_helpers.py | 64 +++++++++++++++++++++++---- 4 files changed, 134 insertions(+), 23 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index ba4d88522..08cbce191 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -149,8 +149,10 @@ def _to_query_job( # Set errors if any were encountered. query_job._properties.setdefault("status", {}) if "errors" in query_response: - query_job._properties["status"]["errors"] = query_response["errors"] - query_job._properties["status"]["errorResult"] = query_response["errors"][0] + errors = query_response["errors"] + query_job._properties["status"]["errors"] = errors + if len(errors) > 0: + query_job._properties["status"]["errorResult"] = errors[0] # Transform job state so that QueryJob doesn't try to restart the query. job_complete = query_response.get("jobComplete") diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 7eec76a32..e3c67b537 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -14,6 +14,7 @@ import pathlib import re +from typing import Tuple import pytest import test_utils.prefixer @@ -26,6 +27,7 @@ prefixer = test_utils.prefixer.Prefixer("python-bigquery", "tests/system") DATA_DIR = pathlib.Path(__file__).parent.parent / "data" +TOKYO_LOCATION = "asia-northeast1" @pytest.fixture(scope="session", autouse=True) @@ -62,6 +64,16 @@ def dataset_id(bigquery_client): bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) +@pytest.fixture(scope="session") +def dataset_id_tokyo(bigquery_client: bigquery.Client, project_id: str): + dataset_id = prefixer.create_prefix() + "_tokyo" + dataset = bigquery.Dataset(f"{project_id}.{dataset_id}") + dataset.location = TOKYO_LOCATION + bigquery_client.create_dataset(dataset) + yield dataset_id + bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) + + @pytest.fixture() def dataset_client(bigquery_client, dataset_id): import google.cloud.bigquery.job @@ -78,18 +90,37 @@ def table_id(dataset_id): return f"{dataset_id}.table_{helpers.temp_suffix()}" -@pytest.fixture(scope="session") -def scalars_table(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): +def load_scalars_table( + bigquery_client: bigquery.Client, + project_id: str, + dataset_id: str, + data_path: str = "scalars.jsonl", +) -> str: schema = bigquery_client.schema_from_json(DATA_DIR / "scalars_schema.json") job_config = bigquery.LoadJobConfig() job_config.schema = schema job_config.source_format = enums.SourceFormat.NEWLINE_DELIMITED_JSON full_table_id 
= f"{project_id}.{dataset_id}.scalars" - with open(DATA_DIR / "scalars.jsonl", "rb") as data_file: + with open(DATA_DIR / data_path, "rb") as data_file: job = bigquery_client.load_table_from_file( data_file, full_table_id, job_config=job_config ) job.result() + return full_table_id + + +@pytest.fixture(scope="session") +def scalars_table(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): + full_table_id = load_scalars_table(bigquery_client, project_id, dataset_id) + yield full_table_id + bigquery_client.delete_table(full_table_id) + + +@pytest.fixture(scope="session") +def scalars_table_tokyo( + bigquery_client: bigquery.Client, project_id: str, dataset_id_tokyo: str +): + full_table_id = load_scalars_table(bigquery_client, project_id, dataset_id_tokyo) yield full_table_id bigquery_client.delete_table(full_table_id) @@ -98,20 +129,26 @@ def scalars_table(bigquery_client: bigquery.Client, project_id: str, dataset_id: def scalars_extreme_table( bigquery_client: bigquery.Client, project_id: str, dataset_id: str ): - schema = bigquery_client.schema_from_json(DATA_DIR / "scalars_schema.json") - job_config = bigquery.LoadJobConfig() - job_config.schema = schema - job_config.source_format = enums.SourceFormat.NEWLINE_DELIMITED_JSON - full_table_id = f"{project_id}.{dataset_id}.scalars_extreme" - with open(DATA_DIR / "scalars_extreme.jsonl", "rb") as data_file: - job = bigquery_client.load_table_from_file( - data_file, full_table_id, job_config=job_config - ) - job.result() + full_table_id = load_scalars_table( + bigquery_client, project_id, dataset_id, data_path="scalars_extreme.jsonl" + ) yield full_table_id bigquery_client.delete_table(full_table_id) +@pytest.fixture(scope="session", params=["US", TOKYO_LOCATION]) +def scalars_table_multi_location( + request, scalars_table: str, scalars_table_tokyo: str +) -> Tuple[str, str]: + if request.param == "US": + full_table_id = scalars_table + elif request.param == TOKYO_LOCATION: + full_table_id = scalars_table_tokyo + else: + raise ValueError(f"got unexpected location: {request.param}") + return request.param, full_table_id + + @pytest.fixture def test_table_name(request, replace_non_anum=re.compile(r"[^a-zA-Z0-9_]").sub): return replace_non_anum("_", request.node.name) diff --git a/tests/system/test_query.py b/tests/system/test_query.py index ccd583cbf..a0bffbc6c 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -15,6 +15,7 @@ import concurrent.futures import datetime import decimal +from typing import Tuple import pytest @@ -417,3 +418,26 @@ def test_query_parameters( assert len(rows) == 1 assert len(rows[0]) == 1 assert rows[0][0] == expected + + +def test_dry_run( + bigquery_client: bigquery.Client, + query_api_method: str, + scalars_table_multi_location: Tuple[str, str], +): + location, full_table_id = scalars_table_multi_location + query_config = bigquery.QueryJobConfig() + query_config.dry_run = True + + query_string = f"SELECT * FROM {full_table_id}" + query_job = bigquery_client.query( + query_string, + location=location, + job_config=query_config, + api_method=query_api_method, + ) + + # Note: `query_job.result()` is not necessary on a dry run query. All + # necessary information is returned in the initial response. 
+ assert query_job.dry_run is True + # TODO: check more properties, such as estimated bytes processed, schema diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py index 35756fa90..13e6bdf14 100644 --- a/tests/unit/test__job_helpers.py +++ b/tests/unit/test__job_helpers.py @@ -12,16 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest +from typing import Dict, Any +from unittest import mock -from google.cloud.bigquery.job.query import QueryJobConfig +import pytest +from google.cloud.bigquery.client import Client +from google.cloud.bigquery import _job_helpers +from google.cloud.bigquery.job.query import QueryJob, QueryJobConfig -@pytest.fixture -def module_under_test(): - from google.cloud.bigquery import _job_helpers - return _job_helpers +def make_query_response( + completed: bool = False, + job_id: str = "abcd-efg-hijk-lmnop", + location="US", + project_id="test-project", +) -> Dict[str, Any]: + response = { + "jobReference": { + "projectId": project_id, + "jobId": job_id, + "location": location, + }, + "jobComplete": completed, + } + return response @pytest.mark.parametrize( @@ -40,6 +55,39 @@ def module_under_test(): ), ), ) -def test__to_query_request(module_under_test, job_config, expected): - result = module_under_test._to_query_request(job_config) +def test__to_query_request(job_config, expected): + result = _job_helpers._to_query_request(job_config) assert result == expected + + +def test__to_query_job_defaults(): + mock_client = mock.create_autospec(Client) + response = make_query_response( + job_id="test-job", project_id="some-project", location="asia-northeast1" + ) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", response) + assert job.query == "query-str" + assert job._client is mock_client + assert job.job_id == "test-job" + assert job.project == "some-project" + assert job.location == "asia-northeast1" + assert job.error_result is None + assert job.errors is None + + +def test__to_query_job_dry_run(): + assert False + + +@pytest.mark.parametrize( + ("completed", "expected_state"), ((True, "DONE"), (False, "PENDING"),), +) +def test__to_query_job_sets_state(completed, expected_state): + mock_client = mock.create_autospec(Client) + response = make_query_response(completed=completed) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", response) + assert job.state == expected_state + + +def test__to_query_job_sets_errors(): + assert False From 7435c8d932f00522f02183c1d2c6a89273612eee Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 11 Oct 2021 14:56:43 -0500 Subject: [PATCH 10/24] add schema property --- google/cloud/bigquery/_job_helpers.py | 9 +++++++-- google/cloud/bigquery/job/query.py | 15 ++++++++++++++- tests/system/test_query.py | 5 +++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index 08cbce191..f022cc055 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -140,12 +140,17 @@ def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any def _to_query_job( - client: "Client", query: str, query_response: Dict[str, Any] + client: "Client", + query: str, + request_config: job.QueryJobConfig, + query_response: Dict[str, Any], ) -> job.QueryJob: job_ref_resource = query_response["jobReference"] job_ref = job._JobReference._from_api_repr(job_ref_resource) query_job = 
job.QueryJob(job_ref, query, client=client) + query_job._properties["configuration"] = request_config.to_api_repr() + # Set errors if any were encountered. query_job._properties.setdefault("status", {}) if "errors" in query_response: @@ -207,7 +212,7 @@ def do_query(): data=request_body, timeout=timeout, ) - return _to_query_job(client, query, api_response) + return _to_query_job(client, query, job_config, api_response) future = do_query() diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py index c07daec99..e4e064f68 100644 --- a/google/cloud/bigquery/job/query.py +++ b/google/cloud/bigquery/job/query.py @@ -18,7 +18,7 @@ import copy import re import typing -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from google.api_core import exceptions from google.api_core.future import polling as polling_future @@ -38,6 +38,7 @@ from google.cloud.bigquery.query import UDFResource from google.cloud.bigquery.retry import DEFAULT_RETRY, DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import RoutineReference +from google.cloud.bigquery.schema import SchemaField from google.cloud.bigquery.table import _EmptyRowIterator from google.cloud.bigquery.table import RangePartitioning from google.cloud.bigquery.table import _table_arg_to_table_ref @@ -887,6 +888,18 @@ def query_plan(self): plan_entries = self._job_statistics().get("queryPlan", ()) return [QueryPlanEntry.from_api_repr(entry) for entry in plan_entries] + @property + def schema(self) -> Optional[List[SchemaField]]: + """The schema of the results. + + Present only for successful dry run of non-legacy SQL queries. + """ + resource = self._job_statistics().get("schema") + if resource is None: + return None + fields = resource.get("fields", []) + return [SchemaField.from_api_repr(field) for field in fields] + @property def timeline(self): """List(TimelineEntry): Return the query execution timeline diff --git a/tests/system/test_query.py b/tests/system/test_query.py index a0bffbc6c..f0d6cf651 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -439,5 +439,6 @@ def test_dry_run( # Note: `query_job.result()` is not necessary on a dry run query. All # necessary information is returned in the initial response. - assert query_job.dry_run is True - # TODO: check more properties, such as estimated bytes processed, schema + assert query_job.dry_run is False + assert query_job.total_bytes_processed > 0 + assert len(query_job.schema) > 0 From d3838471b289b67965ef1369a359743ddfbd6e68 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 11 Oct 2021 16:36:41 -0500 Subject: [PATCH 11/24] feat: add `QueryJob.schema` property for dry run queries --- google/cloud/bigquery/job/base.py | 4 ++- google/cloud/bigquery/job/query.py | 9 +++-- tests/system/test_query.py | 2 +- tests/unit/job/test_query.py | 56 ++++++++++++++++++++---------- 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/google/cloud/bigquery/job/base.py b/google/cloud/bigquery/job/base.py index 698181092..23c5aa8db 100644 --- a/google/cloud/bigquery/job/base.py +++ b/google/cloud/bigquery/job/base.py @@ -1005,7 +1005,9 @@ def from_api_repr(cls, resource: dict, client) -> "UnknownJob": Returns: UnknownJob: Job corresponding to the resource. 
""" - job_ref_properties = resource.get("jobReference", {"projectId": client.project}) + job_ref_properties = resource.get( + "jobReference", {"projectId": client.project, "jobId": None} + ) job_ref = _JobReference._from_api_repr(job_ref_properties) job = cls(job_ref, client) # Populate the job reference with the project, even if it has been diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py index e4e064f68..1e6a60c8a 100644 --- a/google/cloud/bigquery/job/query.py +++ b/google/cloud/bigquery/job/query.py @@ -58,6 +58,7 @@ import pyarrow from google.api_core import retry as retries from google.cloud import bigquery_storage + from google.cloud.bigquery.client import Client from google.cloud.bigquery.table import RowIterator @@ -854,7 +855,7 @@ def to_api_repr(self): } @classmethod - def from_api_repr(cls, resource: dict, client) -> "QueryJob": + def from_api_repr(cls, resource: dict, client: "Client") -> "QueryJob": """Factory: construct a job given its API representation Args: @@ -867,8 +868,10 @@ def from_api_repr(cls, resource: dict, client) -> "QueryJob": Returns: google.cloud.bigquery.job.QueryJob: Job parsed from ``resource``. """ - cls._check_resource_config(resource) - job_ref = _JobReference._from_api_repr(resource["jobReference"]) + job_ref_properties = resource.setdefault( + "jobReference", {"projectId": client.project, "jobId": None} + ) + job_ref = _JobReference._from_api_repr(job_ref_properties) job = cls(job_ref, None, client=client) job._set_properties(resource) return job diff --git a/tests/system/test_query.py b/tests/system/test_query.py index f0d6cf651..6d0bda0f0 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -439,6 +439,6 @@ def test_dry_run( # Note: `query_job.result()` is not necessary on a dry run query. All # necessary information is returned in the initial response. 
- assert query_job.dry_run is False + assert query_job.dry_run is True assert query_job.total_bytes_processed > 0 assert len(query_job.schema) > 0 diff --git a/tests/unit/job/test_query.py b/tests/unit/job/test_query.py index 4c598d797..57a1dec32 100644 --- a/tests/unit/job/test_query.py +++ b/tests/unit/job/test_query.py @@ -268,25 +268,6 @@ def test_ctor_w_query_parameters(self): job = self._make_one(self.JOB_ID, self.QUERY, client, job_config=config) self.assertEqual(job.query_parameters, query_parameters) - def test_from_api_repr_missing_identity(self): - self._setUpConstants() - client = _make_client(project=self.PROJECT) - RESOURCE = {} - klass = self._get_target_class() - with self.assertRaises(KeyError): - klass.from_api_repr(RESOURCE, client=client) - - def test_from_api_repr_missing_config(self): - self._setUpConstants() - client = _make_client(project=self.PROJECT) - RESOURCE = { - "id": "%s:%s" % (self.PROJECT, self.DS_ID), - "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, - } - klass = self._get_target_class() - with self.assertRaises(KeyError): - klass.from_api_repr(RESOURCE, client=client) - def test_from_api_repr_bare(self): self._setUpConstants() client = _make_client(project=self.PROJECT) @@ -1391,6 +1372,43 @@ def test_result_transport_timeout_error(self): with call_api_patch, self.assertRaises(concurrent.futures.TimeoutError): job.result(timeout=1) + def test_no_schema(self): + client = _make_client(project=self.PROJECT) + resource = {} + klass = self._get_target_class() + job = klass.from_api_repr(resource, client=client) + assert job.schema is None + + def test_schema(self): + client = _make_client(project=self.PROJECT) + resource = { + "statistics": { + "query": { + "schema": { + "fields": [ + {"mode": "NULLABLE", "name": "bool_col", "type": "BOOLEAN"}, + { + "mode": "NULLABLE", + "name": "string_col", + "type": "STRING", + }, + { + "mode": "NULLABLE", + "name": "timestamp_col", + "type": "TIMESTAMP", + }, + ] + }, + }, + }, + } + klass = self._get_target_class() + job = klass.from_api_repr(resource, client=client) + assert len(job.schema) == 3 + assert job.schema[0].field_type == "BOOLEAN" + assert job.schema[1].field_type == "STRING" + assert job.schema[2].field_type == "TIMESTAMP" + def test__begin_error(self): from google.cloud import exceptions From 8bc2458e492d9c9f1c5f0fff380bf5723e8c49db Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 13 Oct 2021 17:07:06 -0500 Subject: [PATCH 12/24] add more job properties --- google/cloud/bigquery/_job_helpers.py | 22 ++++++++++++++++++++-- tests/unit/test__job_helpers.py | 15 ++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index f022cc055..3020308a0 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -142,14 +142,32 @@ def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any def _to_query_job( client: "Client", query: str, - request_config: job.QueryJobConfig, + request_config: Optional[job.QueryJobConfig], query_response: Dict[str, Any], ) -> job.QueryJob: job_ref_resource = query_response["jobReference"] job_ref = job._JobReference._from_api_repr(job_ref_resource) query_job = job.QueryJob(job_ref, query, client=client) - query_job._properties["configuration"] = request_config.to_api_repr() + # Not all relevant properties are in the jobs.query response, so + if request_config is not None: + 
query_job._properties["configuration"].update(request_config.to_api_repr()) + query_job._properties["configuration"]["query"]["query"] = query + query_job._properties["configuration"]["query"].setdefault( + "useLegacySql", False + ) + + query_job._properties.setdefault("statistics", {}) + query_job._properties["statistics"].setdefault("query", {}) + query_job._properties["statistics"]["query"]["cacheHit"] = query_response.get( + "cacheHit" + ) + query_job._properties["statistics"]["query"]["schema"] = query_response.get( + "schema" + ) + query_job._properties["statistics"]["query"][ + "totalBytesProcessed" + ] = query_response.get("totalBytesProcessed") # Set errors if any were encountered. query_job._properties.setdefault("status", {}) diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py index 13e6bdf14..7e0b69d91 100644 --- a/tests/unit/test__job_helpers.py +++ b/tests/unit/test__job_helpers.py @@ -65,7 +65,7 @@ def test__to_query_job_defaults(): response = make_query_response( job_id="test-job", project_id="some-project", location="asia-northeast1" ) - job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", response) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) assert job.query == "query-str" assert job._client is mock_client assert job.job_id == "test-job" @@ -76,7 +76,16 @@ def test__to_query_job_defaults(): def test__to_query_job_dry_run(): - assert False + mock_client = mock.create_autospec(Client) + response = make_query_response( + job_id="test-job", project_id="some-project", location="asia-northeast1" + ) + job_config: QueryJobConfig = QueryJobConfig() + job_config.dry_run = True + job: QueryJob = _job_helpers._to_query_job( + mock_client, "query-str", job_config, response + ) + assert job.dry_run is True @pytest.mark.parametrize( @@ -85,7 +94,7 @@ def test__to_query_job_dry_run(): def test__to_query_job_sets_state(completed, expected_state): mock_client = mock.create_autospec(Client) response = make_query_response(completed=completed) - job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", response) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) assert job.state == expected_state From a2b4c2bc7e47642abe46b2d7021babf076271610 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 14 Oct 2021 15:22:14 -0500 Subject: [PATCH 13/24] add tests for differences in API error behavior between jobs.query and jobs.insert --- google/cloud/bigquery/_job_helpers.py | 5 +++-- tests/system/test_query.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index 3020308a0..c0ef8d6f5 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -172,10 +172,11 @@ def _to_query_job( # Set errors if any were encountered. query_job._properties.setdefault("status", {}) if "errors" in query_response: + # Set errors but not errorResult. If there was an error that failed + # the job, jobs.query behaves like jobs.getQueryResults and returns a + # non-success HTTP status code. errors = query_response["errors"] query_job._properties["status"]["errors"] = errors - if len(errors) > 0: - query_job._properties["status"]["errorResult"] = errors[0] # Transform job state so that QueryJob doesn't try to restart the query. 
     job_complete = query_response.get("jobComplete")
diff --git a/tests/system/test_query.py b/tests/system/test_query.py
index 6d0bda0f0..9e2565b2c 100644
--- a/tests/system/test_query.py
+++ b/tests/system/test_query.py
@@ -17,6 +17,7 @@
 import decimal
 from typing import Tuple
 
+from google.api_core import exceptions
 import pytest
 
 from google.cloud import bigquery
@@ -408,7 +409,6 @@ def test_query_statistics(bigquery_client, query_api_method):
 def test_query_parameters(
     bigquery_client, query_api_method, sql, expected, query_parameters
 ):
-
     jconfig = bigquery.QueryJobConfig()
     jconfig.query_parameters = query_parameters
     query_job = bigquery_client.query(
@@ -442,3 +442,21 @@
     assert query_job.dry_run is True
     assert query_job.total_bytes_processed > 0
     assert len(query_job.schema) > 0
+
+
+def test_query_error_w_api_method_query(bigquery_client: bigquery.Client):
+    """No job is returned from jobs.query if the query fails."""
+
+    with pytest.raises(exceptions.NotFound, match="not_a_real_dataset"):
+        bigquery_client.query(
+            "SELECT * FROM not_a_real_dataset.doesnt_exist", api_method="QUERY"
+        )
+
+
+def test_query_error_w_api_method_insert(bigquery_client: bigquery.Client):
+    """With jobs.insert, an exception is raised when fetching the results."""
+
+    query_job = bigquery_client.query("SELECT * FROM not_a_real_dataset.doesnt_exist")
+
+    with pytest.raises(exceptions.NotFound, match="not_a_real_dataset"):
+        query_job.result()

From 8b970f2c9d4facf45420d51835bd226233513d8a Mon Sep 17 00:00:00 2001
From: Tim Swast
Date: Thu, 14 Oct 2021 16:33:41 -0500
Subject: [PATCH 14/24] update docs to show differences

---
 google/cloud/bigquery/client.py |  4 +++-
 google/cloud/bigquery/enums.py  | 26 +++++++++++++++++---------
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index e5ee2e742..ccaae47b5 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -51,7 +50,6 @@
 
 from google.cloud.bigquery import _job_helpers
 from google.cloud.bigquery._job_helpers import make_job_id as _make_job_id
-from google.cloud.bigquery._helpers import _del_sub_prop
 from google.cloud.bigquery._helpers import _get_sub_prop
 from google.cloud.bigquery._helpers import _record_field_to_json
 from google.cloud.bigquery._helpers import _str_or_none
@@ -3180,6 +3179,9 @@ def query(
             api_method:
                 Method with which to start the query job.
 
+                See :class:`google.cloud.bigquery.enums.QueryApiMethod` for
+                details on the difference between the query start methods.
+
         Returns:
             google.cloud.bigquery.job.QueryJob: A new query job instance.
 
diff --git a/google/cloud/bigquery/enums.py b/google/cloud/bigquery/enums.py
index 39d880c2b..03f347427 100644
--- a/google/cloud/bigquery/enums.py
+++ b/google/cloud/bigquery/enums.py
@@ -138,15 +138,23 @@ class QueryApiMethod(str, enum.Enum):
     """Submit a query job by using the
     `jobs.query REST API method
     <https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query>`_.
 
-    This API blocks for up to a specified timeout for the query to finish. The
-    full job resource (including job statistics) may not be available if the
-    query finishes within the timeout. Call
-    :meth:`~google.cloud.bigquery.job.QueryJob.reload` or
-    :meth:`~google.cloud.bigquery.client.Client.get_job` to get full job
-    statistics.
-
-    Many parameters, including destination table and job ID cannot be used with
-    this API method.
+    Differences from ``INSERT``:
+
+    * Many parameters, including destination table and job ID, cannot be used
+      with this API method.
+ + * API blocks for up to a specified timeout, waiting for the query to + finish. + + * The full job resource (including job statistics) may not be available. + Call :meth:`~google.cloud.bigquery.job.QueryJob.reload` or + :meth:`~google.cloud.bigquery.client.Client.get_job` to get full job + statistics. + + * :meth:`~google.cloud.bigquery.Client.query` can raise API exceptions if + the query fails, whereas often the same errors don't appear until calling + :meth:`~google.cloud.bigquery.job.QueryJob.reload` when the ``INSERT`` + API method is used. """ From e7e5e179e7e2f911fd11ec6df5b481f494fc7a93 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 14 Oct 2021 16:39:38 -0500 Subject: [PATCH 15/24] cover error conversion --- tests/unit/test__job_helpers.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py index 7e0b69d91..a3211433f 100644 --- a/tests/unit/test__job_helpers.py +++ b/tests/unit/test__job_helpers.py @@ -27,6 +27,7 @@ def make_query_response( job_id: str = "abcd-efg-hijk-lmnop", location="US", project_id="test-project", + errors=None, ) -> Dict[str, Any]: response = { "jobReference": { @@ -36,6 +37,8 @@ def make_query_response( }, "jobComplete": completed, } + if errors is not None: + response["errors"] = errors return response @@ -99,4 +102,13 @@ def test__to_query_job_sets_state(completed, expected_state): def test__to_query_job_sets_errors(): - assert False + mock_client = mock.create_autospec(Client) + response = make_query_response( + errors=[ + # https://cloud.google.com/bigquery/docs/reference/rest/v2/ErrorProto + {"reason": "backendError", "message": "something went wrong"}, + {"message": "something else went wrong"}, + ] + ) + job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) + assert len(job.errors) == 2 From b572188172e1b87bb840eefe05838acf3a8f27a4 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 14 Oct 2021 16:51:06 -0500 Subject: [PATCH 16/24] restore missing modules --- google/cloud/bigquery/enums.py | 15 +- google/cloud/bigquery_v2/types/model.py | 1507 +++++++++++++++++ .../bigquery_v2/types/table_reference.py | 58 + 3 files changed, 1574 insertions(+), 6 deletions(-) create mode 100644 google/cloud/bigquery_v2/types/model.py create mode 100644 google/cloud/bigquery_v2/types/table_reference.py diff --git a/google/cloud/bigquery/enums.py b/google/cloud/bigquery/enums.py index 03f347427..450923700 100644 --- a/google/cloud/bigquery/enums.py +++ b/google/cloud/bigquery/enums.py @@ -140,20 +140,23 @@ class QueryApiMethod(str, enum.Enum): Differences from ``INSERT``: - * Many parameters, including destination table and job ID, cannot be used - with this API method. + * Many parameters and job configuration options, including job ID and + destination table, cannot be used + with this API method. See the `jobs.query REST API documentation + `_ for + the complete list of supported configuration options. - * API blocks for up to a specified timeout, waiting for the query to + * API blocks up to a specified timeout, waiting for the query to finish. * The full job resource (including job statistics) may not be available. Call :meth:`~google.cloud.bigquery.job.QueryJob.reload` or :meth:`~google.cloud.bigquery.client.Client.get_job` to get full job - statistics. + statistics and configuration. 
* :meth:`~google.cloud.bigquery.Client.query` can raise API exceptions if - the query fails, whereas often the same errors don't appear until calling - :meth:`~google.cloud.bigquery.job.QueryJob.reload` when the ``INSERT`` + the query fails, whereas the same errors don't appear until calling + :meth:`~google.cloud.bigquery.job.QueryJob.result` when the ``INSERT`` API method is used. """ diff --git a/google/cloud/bigquery_v2/types/model.py b/google/cloud/bigquery_v2/types/model.py new file mode 100644 index 000000000..706418401 --- /dev/null +++ b/google/cloud/bigquery_v2/types/model.py @@ -0,0 +1,1507 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto # type: ignore + +from google.cloud.bigquery_v2.types import encryption_config +from google.cloud.bigquery_v2.types import model_reference as gcb_model_reference +from google.cloud.bigquery_v2.types import standard_sql +from google.cloud.bigquery_v2.types import table_reference +from google.protobuf import timestamp_pb2 # type: ignore +from google.protobuf import wrappers_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.bigquery.v2", + manifest={ + "Model", + "GetModelRequest", + "PatchModelRequest", + "DeleteModelRequest", + "ListModelsRequest", + "ListModelsResponse", + }, +) + + +class Model(proto.Message): + r""" + Attributes: + etag (str): + Output only. A hash of this resource. + model_reference (google.cloud.bigquery_v2.types.ModelReference): + Required. Unique identifier for this model. + creation_time (int): + Output only. The time when this model was + created, in millisecs since the epoch. + last_modified_time (int): + Output only. The time when this model was + last modified, in millisecs since the epoch. + description (str): + Optional. A user-friendly description of this + model. + friendly_name (str): + Optional. A descriptive name for this model. + labels (Sequence[google.cloud.bigquery_v2.types.Model.LabelsEntry]): + The labels associated with this model. You + can use these to organize and group your models. + Label keys and values can be no longer than 63 + characters, can only contain lowercase letters, + numeric characters, underscores and dashes. + International characters are allowed. Label + values are optional. Label keys must start with + a letter and each label in the list must have a + different key. + expiration_time (int): + Optional. The time when this model expires, + in milliseconds since the epoch. If not present, + the model will persist indefinitely. Expired + models will be deleted and their storage + reclaimed. The defaultTableExpirationMs + property of the encapsulating dataset can be + used to set a default expirationTime on newly + created models. + location (str): + Output only. The geographic location where + the model resides. This value is inherited from + the dataset. 
+ encryption_configuration (google.cloud.bigquery_v2.types.EncryptionConfiguration): + Custom encryption configuration (e.g., Cloud + KMS keys). This shows the encryption + configuration of the model data while stored in + BigQuery storage. This field can be used with + PatchModel to update encryption key for an + already encrypted model. + model_type (google.cloud.bigquery_v2.types.Model.ModelType): + Output only. Type of the model resource. + training_runs (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun]): + Output only. Information for all training runs in increasing + order of start_time. + feature_columns (Sequence[google.cloud.bigquery_v2.types.StandardSqlField]): + Output only. Input feature columns that were + used to train this model. + label_columns (Sequence[google.cloud.bigquery_v2.types.StandardSqlField]): + Output only. Label columns that were used to train this + model. The output of the model will have a `predicted_` + prefix to these columns. + best_trial_id (int): + The best trial_id across all training runs. + """ + + class ModelType(proto.Enum): + r"""Indicates the type of the Model.""" + MODEL_TYPE_UNSPECIFIED = 0 + LINEAR_REGRESSION = 1 + LOGISTIC_REGRESSION = 2 + KMEANS = 3 + MATRIX_FACTORIZATION = 4 + DNN_CLASSIFIER = 5 + TENSORFLOW = 6 + DNN_REGRESSOR = 7 + BOOSTED_TREE_REGRESSOR = 9 + BOOSTED_TREE_CLASSIFIER = 10 + ARIMA = 11 + AUTOML_REGRESSOR = 12 + AUTOML_CLASSIFIER = 13 + ARIMA_PLUS = 19 + + class LossType(proto.Enum): + r"""Loss metric to evaluate model training performance.""" + LOSS_TYPE_UNSPECIFIED = 0 + MEAN_SQUARED_LOSS = 1 + MEAN_LOG_LOSS = 2 + + class DistanceType(proto.Enum): + r"""Distance metric used to compute the distance between two + points. + """ + DISTANCE_TYPE_UNSPECIFIED = 0 + EUCLIDEAN = 1 + COSINE = 2 + + class DataSplitMethod(proto.Enum): + r"""Indicates the method to split input data into multiple + tables. + """ + DATA_SPLIT_METHOD_UNSPECIFIED = 0 + RANDOM = 1 + CUSTOM = 2 + SEQUENTIAL = 3 + NO_SPLIT = 4 + AUTO_SPLIT = 5 + + class DataFrequency(proto.Enum): + r"""Type of supported data frequency for time series forecasting + models. + """ + DATA_FREQUENCY_UNSPECIFIED = 0 + AUTO_FREQUENCY = 1 + YEARLY = 2 + QUARTERLY = 3 + MONTHLY = 4 + WEEKLY = 5 + DAILY = 6 + HOURLY = 7 + PER_MINUTE = 8 + + class HolidayRegion(proto.Enum): + r"""Type of supported holiday regions for time series forecasting + models. 
+ """ + HOLIDAY_REGION_UNSPECIFIED = 0 + GLOBAL = 1 + NA = 2 + JAPAC = 3 + EMEA = 4 + LAC = 5 + AE = 6 + AR = 7 + AT = 8 + AU = 9 + BE = 10 + BR = 11 + CA = 12 + CH = 13 + CL = 14 + CN = 15 + CO = 16 + CS = 17 + CZ = 18 + DE = 19 + DK = 20 + DZ = 21 + EC = 22 + EE = 23 + EG = 24 + ES = 25 + FI = 26 + FR = 27 + GB = 28 + GR = 29 + HK = 30 + HU = 31 + ID = 32 + IE = 33 + IL = 34 + IN = 35 + IR = 36 + IT = 37 + JP = 38 + KR = 39 + LV = 40 + MA = 41 + MX = 42 + MY = 43 + NG = 44 + NL = 45 + NO = 46 + NZ = 47 + PE = 48 + PH = 49 + PK = 50 + PL = 51 + PT = 52 + RO = 53 + RS = 54 + RU = 55 + SA = 56 + SE = 57 + SG = 58 + SI = 59 + SK = 60 + TH = 61 + TR = 62 + TW = 63 + UA = 64 + US = 65 + VE = 66 + VN = 67 + ZA = 68 + + class LearnRateStrategy(proto.Enum): + r"""Indicates the learning rate optimization strategy to use.""" + LEARN_RATE_STRATEGY_UNSPECIFIED = 0 + LINE_SEARCH = 1 + CONSTANT = 2 + + class OptimizationStrategy(proto.Enum): + r"""Indicates the optimization strategy used for training.""" + OPTIMIZATION_STRATEGY_UNSPECIFIED = 0 + BATCH_GRADIENT_DESCENT = 1 + NORMAL_EQUATION = 2 + + class FeedbackType(proto.Enum): + r"""Indicates the training algorithm to use for matrix + factorization models. + """ + FEEDBACK_TYPE_UNSPECIFIED = 0 + IMPLICIT = 1 + EXPLICIT = 2 + + class SeasonalPeriod(proto.Message): + r""" """ + + class SeasonalPeriodType(proto.Enum): + r"""""" + SEASONAL_PERIOD_TYPE_UNSPECIFIED = 0 + NO_SEASONALITY = 1 + DAILY = 2 + WEEKLY = 3 + MONTHLY = 4 + QUARTERLY = 5 + YEARLY = 6 + + class KmeansEnums(proto.Message): + r""" """ + + class KmeansInitializationMethod(proto.Enum): + r"""Indicates the method used to initialize the centroids for + KMeans clustering algorithm. + """ + KMEANS_INITIALIZATION_METHOD_UNSPECIFIED = 0 + RANDOM = 1 + CUSTOM = 2 + KMEANS_PLUS_PLUS = 3 + + class RegressionMetrics(proto.Message): + r"""Evaluation metrics for regression and explicit feedback type + matrix factorization models. + + Attributes: + mean_absolute_error (google.protobuf.wrappers_pb2.DoubleValue): + Mean absolute error. + mean_squared_error (google.protobuf.wrappers_pb2.DoubleValue): + Mean squared error. + mean_squared_log_error (google.protobuf.wrappers_pb2.DoubleValue): + Mean squared log error. + median_absolute_error (google.protobuf.wrappers_pb2.DoubleValue): + Median absolute error. + r_squared (google.protobuf.wrappers_pb2.DoubleValue): + R^2 score. This corresponds to r2_score in ML.EVALUATE. + """ + + mean_absolute_error = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, + ) + mean_squared_error = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, + ) + mean_squared_log_error = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.DoubleValue, + ) + median_absolute_error = proto.Field( + proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, + ) + r_squared = proto.Field( + proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, + ) + + class AggregateClassificationMetrics(proto.Message): + r"""Aggregate metrics for classification/classifier models. For + multi-class models, the metrics are either macro-averaged or + micro-averaged. When macro-averaged, the metrics are calculated + for each label and then an unweighted average is taken of those + values. When micro-averaged, the metric is calculated globally + by counting the total number of correctly predicted rows. 
+ + Attributes: + precision (google.protobuf.wrappers_pb2.DoubleValue): + Precision is the fraction of actual positive + predictions that had positive actual labels. For + multiclass this is a macro-averaged metric + treating each class as a binary classifier. + recall (google.protobuf.wrappers_pb2.DoubleValue): + Recall is the fraction of actual positive + labels that were given a positive prediction. + For multiclass this is a macro-averaged metric. + accuracy (google.protobuf.wrappers_pb2.DoubleValue): + Accuracy is the fraction of predictions given + the correct label. For multiclass this is a + micro-averaged metric. + threshold (google.protobuf.wrappers_pb2.DoubleValue): + Threshold at which the metrics are computed. + For binary classification models this is the + positive class threshold. For multi-class + classfication models this is the confidence + threshold. + f1_score (google.protobuf.wrappers_pb2.DoubleValue): + The F1 score is an average of recall and + precision. For multiclass this is a macro- + averaged metric. + log_loss (google.protobuf.wrappers_pb2.DoubleValue): + Logarithmic Loss. For multiclass this is a + macro-averaged metric. + roc_auc (google.protobuf.wrappers_pb2.DoubleValue): + Area Under a ROC Curve. For multiclass this + is a macro-averaged metric. + """ + + precision = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, + ) + recall = proto.Field(proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue,) + accuracy = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.DoubleValue, + ) + threshold = proto.Field( + proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, + ) + f1_score = proto.Field( + proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, + ) + log_loss = proto.Field( + proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, + ) + roc_auc = proto.Field( + proto.MESSAGE, number=7, message=wrappers_pb2.DoubleValue, + ) + + class BinaryClassificationMetrics(proto.Message): + r"""Evaluation metrics for binary classification/classifier + models. + + Attributes: + aggregate_classification_metrics (google.cloud.bigquery_v2.types.Model.AggregateClassificationMetrics): + Aggregate classification metrics. + binary_confusion_matrix_list (Sequence[google.cloud.bigquery_v2.types.Model.BinaryClassificationMetrics.BinaryConfusionMatrix]): + Binary confusion matrix at multiple + thresholds. + positive_label (str): + Label representing the positive class. + negative_label (str): + Label representing the negative class. + """ + + class BinaryConfusionMatrix(proto.Message): + r"""Confusion matrix for binary classification models. + Attributes: + positive_class_threshold (google.protobuf.wrappers_pb2.DoubleValue): + Threshold value used when computing each of + the following metric. + true_positives (google.protobuf.wrappers_pb2.Int64Value): + Number of true samples predicted as true. + false_positives (google.protobuf.wrappers_pb2.Int64Value): + Number of false samples predicted as true. + true_negatives (google.protobuf.wrappers_pb2.Int64Value): + Number of true samples predicted as false. + false_negatives (google.protobuf.wrappers_pb2.Int64Value): + Number of false samples predicted as false. + precision (google.protobuf.wrappers_pb2.DoubleValue): + The fraction of actual positive predictions + that had positive actual labels. + recall (google.protobuf.wrappers_pb2.DoubleValue): + The fraction of actual positive labels that + were given a positive prediction. 
+ f1_score (google.protobuf.wrappers_pb2.DoubleValue): + The equally weighted average of recall and + precision. + accuracy (google.protobuf.wrappers_pb2.DoubleValue): + The fraction of predictions given the correct + label. + """ + + positive_class_threshold = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, + ) + true_positives = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.Int64Value, + ) + false_positives = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.Int64Value, + ) + true_negatives = proto.Field( + proto.MESSAGE, number=4, message=wrappers_pb2.Int64Value, + ) + false_negatives = proto.Field( + proto.MESSAGE, number=5, message=wrappers_pb2.Int64Value, + ) + precision = proto.Field( + proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, + ) + recall = proto.Field( + proto.MESSAGE, number=7, message=wrappers_pb2.DoubleValue, + ) + f1_score = proto.Field( + proto.MESSAGE, number=8, message=wrappers_pb2.DoubleValue, + ) + accuracy = proto.Field( + proto.MESSAGE, number=9, message=wrappers_pb2.DoubleValue, + ) + + aggregate_classification_metrics = proto.Field( + proto.MESSAGE, number=1, message="Model.AggregateClassificationMetrics", + ) + binary_confusion_matrix_list = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Model.BinaryClassificationMetrics.BinaryConfusionMatrix", + ) + positive_label = proto.Field(proto.STRING, number=3,) + negative_label = proto.Field(proto.STRING, number=4,) + + class MultiClassClassificationMetrics(proto.Message): + r"""Evaluation metrics for multi-class classification/classifier + models. + + Attributes: + aggregate_classification_metrics (google.cloud.bigquery_v2.types.Model.AggregateClassificationMetrics): + Aggregate classification metrics. + confusion_matrix_list (Sequence[google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics.ConfusionMatrix]): + Confusion matrix at different thresholds. + """ + + class ConfusionMatrix(proto.Message): + r"""Confusion matrix for multi-class classification models. + Attributes: + confidence_threshold (google.protobuf.wrappers_pb2.DoubleValue): + Confidence threshold used when computing the + entries of the confusion matrix. + rows (Sequence[google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics.ConfusionMatrix.Row]): + One row per actual label. + """ + + class Entry(proto.Message): + r"""A single entry in the confusion matrix. + Attributes: + predicted_label (str): + The predicted label. For confidence_threshold > 0, we will + also add an entry indicating the number of items under the + confidence threshold. + item_count (google.protobuf.wrappers_pb2.Int64Value): + Number of items being predicted as this + label. + """ + + predicted_label = proto.Field(proto.STRING, number=1,) + item_count = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.Int64Value, + ) + + class Row(proto.Message): + r"""A single row in the confusion matrix. + Attributes: + actual_label (str): + The original label of this row. + entries (Sequence[google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics.ConfusionMatrix.Entry]): + Info describing predicted label distribution. 
+ """ + + actual_label = proto.Field(proto.STRING, number=1,) + entries = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Model.MultiClassClassificationMetrics.ConfusionMatrix.Entry", + ) + + confidence_threshold = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, + ) + rows = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Model.MultiClassClassificationMetrics.ConfusionMatrix.Row", + ) + + aggregate_classification_metrics = proto.Field( + proto.MESSAGE, number=1, message="Model.AggregateClassificationMetrics", + ) + confusion_matrix_list = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Model.MultiClassClassificationMetrics.ConfusionMatrix", + ) + + class ClusteringMetrics(proto.Message): + r"""Evaluation metrics for clustering models. + Attributes: + davies_bouldin_index (google.protobuf.wrappers_pb2.DoubleValue): + Davies-Bouldin index. + mean_squared_distance (google.protobuf.wrappers_pb2.DoubleValue): + Mean of squared distances between each sample + to its cluster centroid. + clusters (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster]): + Information for all clusters. + """ + + class Cluster(proto.Message): + r"""Message containing the information about one cluster. + Attributes: + centroid_id (int): + Centroid id. + feature_values (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster.FeatureValue]): + Values of highly variant features for this + cluster. + count (google.protobuf.wrappers_pb2.Int64Value): + Count of training data rows that were + assigned to this cluster. + """ + + class FeatureValue(proto.Message): + r"""Representative value of a single feature within the cluster. + Attributes: + feature_column (str): + The feature column name. + numerical_value (google.protobuf.wrappers_pb2.DoubleValue): + The numerical feature value. This is the + centroid value for this feature. + categorical_value (google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue): + The categorical feature value. + """ + + class CategoricalValue(proto.Message): + r"""Representative value of a categorical feature. + Attributes: + category_counts (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue.CategoryCount]): + Counts of all categories for the categorical feature. If + there are more than ten categories, we return top ten (by + count) and return one more CategoryCount with category + "*OTHER*" and count as aggregate counts of remaining + categories. + """ + + class CategoryCount(proto.Message): + r"""Represents the count of a single category within the cluster. + Attributes: + category (str): + The name of category. + count (google.protobuf.wrappers_pb2.Int64Value): + The count of training samples matching the + category within the cluster. 
+ """ + + category = proto.Field(proto.STRING, number=1,) + count = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.Int64Value, + ) + + category_counts = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue.CategoryCount", + ) + + feature_column = proto.Field(proto.STRING, number=1,) + numerical_value = proto.Field( + proto.MESSAGE, + number=2, + oneof="value", + message=wrappers_pb2.DoubleValue, + ) + categorical_value = proto.Field( + proto.MESSAGE, + number=3, + oneof="value", + message="Model.ClusteringMetrics.Cluster.FeatureValue.CategoricalValue", + ) + + centroid_id = proto.Field(proto.INT64, number=1,) + feature_values = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Model.ClusteringMetrics.Cluster.FeatureValue", + ) + count = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.Int64Value, + ) + + davies_bouldin_index = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, + ) + mean_squared_distance = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, + ) + clusters = proto.RepeatedField( + proto.MESSAGE, number=3, message="Model.ClusteringMetrics.Cluster", + ) + + class RankingMetrics(proto.Message): + r"""Evaluation metrics used by weighted-ALS models specified by + feedback_type=implicit. + + Attributes: + mean_average_precision (google.protobuf.wrappers_pb2.DoubleValue): + Calculates a precision per user for all the + items by ranking them and then averages all the + precisions across all the users. + mean_squared_error (google.protobuf.wrappers_pb2.DoubleValue): + Similar to the mean squared error computed in + regression and explicit recommendation models + except instead of computing the rating directly, + the output from evaluate is computed against a + preference which is 1 or 0 depending on if the + rating exists or not. + normalized_discounted_cumulative_gain (google.protobuf.wrappers_pb2.DoubleValue): + A metric to determine the goodness of a + ranking calculated from the predicted confidence + by comparing it to an ideal rank measured by the + original ratings. + average_rank (google.protobuf.wrappers_pb2.DoubleValue): + Determines the goodness of a ranking by + computing the percentile rank from the predicted + confidence and dividing it by the original rank. + """ + + mean_average_precision = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.DoubleValue, + ) + mean_squared_error = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, + ) + normalized_discounted_cumulative_gain = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.DoubleValue, + ) + average_rank = proto.Field( + proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, + ) + + class ArimaForecastingMetrics(proto.Message): + r"""Model evaluation metrics for ARIMA forecasting models. + Attributes: + non_seasonal_order (Sequence[google.cloud.bigquery_v2.types.Model.ArimaOrder]): + Non-seasonal order. + arima_fitting_metrics (Sequence[google.cloud.bigquery_v2.types.Model.ArimaFittingMetrics]): + Arima model fitting metrics. + seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): + Seasonal periods. Repeated because multiple + periods are supported for one time series. + has_drift (Sequence[bool]): + Whether Arima model fitted with drift or not. + It is always false when d is not 1. 
+ time_series_id (Sequence[str]): + Id to differentiate different time series for + the large-scale case. + arima_single_model_forecasting_metrics (Sequence[google.cloud.bigquery_v2.types.Model.ArimaForecastingMetrics.ArimaSingleModelForecastingMetrics]): + Repeated as there can be many metric sets + (one for each model) in auto-arima and the + large-scale case. + """ + + class ArimaSingleModelForecastingMetrics(proto.Message): + r"""Model evaluation metrics for a single ARIMA forecasting + model. + + Attributes: + non_seasonal_order (google.cloud.bigquery_v2.types.Model.ArimaOrder): + Non-seasonal order. + arima_fitting_metrics (google.cloud.bigquery_v2.types.Model.ArimaFittingMetrics): + Arima fitting metrics. + has_drift (bool): + Is arima model fitted with drift or not. It + is always false when d is not 1. + time_series_id (str): + The time_series_id value for this time series. It will be + one of the unique values from the time_series_id_column + specified during ARIMA model training. Only present when + time_series_id_column training option was used. + time_series_ids (Sequence[str]): + The tuple of time_series_ids identifying this time series. + It will be one of the unique tuples of values present in the + time_series_id_columns specified during ARIMA model + training. Only present when time_series_id_columns training + option was used and the order of values here are same as the + order of time_series_id_columns. + seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): + Seasonal periods. Repeated because multiple + periods are supported for one time series. + has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue): + If true, holiday_effect is a part of time series + decomposition result. + has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue): + If true, spikes_and_dips is a part of time series + decomposition result. + has_step_changes (google.protobuf.wrappers_pb2.BoolValue): + If true, step_changes is a part of time series decomposition + result. 
+ """ + + non_seasonal_order = proto.Field( + proto.MESSAGE, number=1, message="Model.ArimaOrder", + ) + arima_fitting_metrics = proto.Field( + proto.MESSAGE, number=2, message="Model.ArimaFittingMetrics", + ) + has_drift = proto.Field(proto.BOOL, number=3,) + time_series_id = proto.Field(proto.STRING, number=4,) + time_series_ids = proto.RepeatedField(proto.STRING, number=9,) + seasonal_periods = proto.RepeatedField( + proto.ENUM, number=5, enum="Model.SeasonalPeriod.SeasonalPeriodType", + ) + has_holiday_effect = proto.Field( + proto.MESSAGE, number=6, message=wrappers_pb2.BoolValue, + ) + has_spikes_and_dips = proto.Field( + proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue, + ) + has_step_changes = proto.Field( + proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue, + ) + + non_seasonal_order = proto.RepeatedField( + proto.MESSAGE, number=1, message="Model.ArimaOrder", + ) + arima_fitting_metrics = proto.RepeatedField( + proto.MESSAGE, number=2, message="Model.ArimaFittingMetrics", + ) + seasonal_periods = proto.RepeatedField( + proto.ENUM, number=3, enum="Model.SeasonalPeriod.SeasonalPeriodType", + ) + has_drift = proto.RepeatedField(proto.BOOL, number=4,) + time_series_id = proto.RepeatedField(proto.STRING, number=5,) + arima_single_model_forecasting_metrics = proto.RepeatedField( + proto.MESSAGE, + number=6, + message="Model.ArimaForecastingMetrics.ArimaSingleModelForecastingMetrics", + ) + + class EvaluationMetrics(proto.Message): + r"""Evaluation metrics of a model. These are either computed on + all training data or just the eval data based on whether eval + data was used during training. These are not present for + imported models. + + Attributes: + regression_metrics (google.cloud.bigquery_v2.types.Model.RegressionMetrics): + Populated for regression models and explicit + feedback type matrix factorization models. + binary_classification_metrics (google.cloud.bigquery_v2.types.Model.BinaryClassificationMetrics): + Populated for binary + classification/classifier models. + multi_class_classification_metrics (google.cloud.bigquery_v2.types.Model.MultiClassClassificationMetrics): + Populated for multi-class + classification/classifier models. + clustering_metrics (google.cloud.bigquery_v2.types.Model.ClusteringMetrics): + Populated for clustering models. + ranking_metrics (google.cloud.bigquery_v2.types.Model.RankingMetrics): + Populated for implicit feedback type matrix + factorization models. + arima_forecasting_metrics (google.cloud.bigquery_v2.types.Model.ArimaForecastingMetrics): + Populated for ARIMA models. + """ + + regression_metrics = proto.Field( + proto.MESSAGE, number=1, oneof="metrics", message="Model.RegressionMetrics", + ) + binary_classification_metrics = proto.Field( + proto.MESSAGE, + number=2, + oneof="metrics", + message="Model.BinaryClassificationMetrics", + ) + multi_class_classification_metrics = proto.Field( + proto.MESSAGE, + number=3, + oneof="metrics", + message="Model.MultiClassClassificationMetrics", + ) + clustering_metrics = proto.Field( + proto.MESSAGE, number=4, oneof="metrics", message="Model.ClusteringMetrics", + ) + ranking_metrics = proto.Field( + proto.MESSAGE, number=5, oneof="metrics", message="Model.RankingMetrics", + ) + arima_forecasting_metrics = proto.Field( + proto.MESSAGE, + number=6, + oneof="metrics", + message="Model.ArimaForecastingMetrics", + ) + + class DataSplitResult(proto.Message): + r"""Data split result. This contains references to the training + and evaluation data tables that were used to train the model. 
+ + Attributes: + training_table (google.cloud.bigquery_v2.types.TableReference): + Table reference of the training data after + split. + evaluation_table (google.cloud.bigquery_v2.types.TableReference): + Table reference of the evaluation data after + split. + """ + + training_table = proto.Field( + proto.MESSAGE, number=1, message=table_reference.TableReference, + ) + evaluation_table = proto.Field( + proto.MESSAGE, number=2, message=table_reference.TableReference, + ) + + class ArimaOrder(proto.Message): + r"""Arima order, can be used for both non-seasonal and seasonal + parts. + + Attributes: + p (int): + Order of the autoregressive part. + d (int): + Order of the differencing part. + q (int): + Order of the moving-average part. + """ + + p = proto.Field(proto.INT64, number=1,) + d = proto.Field(proto.INT64, number=2,) + q = proto.Field(proto.INT64, number=3,) + + class ArimaFittingMetrics(proto.Message): + r"""ARIMA model fitting metrics. + Attributes: + log_likelihood (float): + Log-likelihood. + aic (float): + AIC. + variance (float): + Variance. + """ + + log_likelihood = proto.Field(proto.DOUBLE, number=1,) + aic = proto.Field(proto.DOUBLE, number=2,) + variance = proto.Field(proto.DOUBLE, number=3,) + + class GlobalExplanation(proto.Message): + r"""Global explanations containing the top most important + features after training. + + Attributes: + explanations (Sequence[google.cloud.bigquery_v2.types.Model.GlobalExplanation.Explanation]): + A list of the top global explanations. Sorted + by absolute value of attribution in descending + order. + class_label (str): + Class label for this set of global + explanations. Will be empty/null for binary + logistic and linear regression models. Sorted + alphabetically in descending order. + """ + + class Explanation(proto.Message): + r"""Explanation for a single feature. + Attributes: + feature_name (str): + Full name of the feature. For non-numerical features, will + be formatted like .. + Overall size of feature name will always be truncated to + first 120 characters. + attribution (google.protobuf.wrappers_pb2.DoubleValue): + Attribution of feature. + """ + + feature_name = proto.Field(proto.STRING, number=1,) + attribution = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, + ) + + explanations = proto.RepeatedField( + proto.MESSAGE, number=1, message="Model.GlobalExplanation.Explanation", + ) + class_label = proto.Field(proto.STRING, number=2,) + + class TrainingRun(proto.Message): + r"""Information about a single training query run for the model. + Attributes: + training_options (google.cloud.bigquery_v2.types.Model.TrainingRun.TrainingOptions): + Options that were used for this training run, + includes user specified and default options that + were used. + start_time (google.protobuf.timestamp_pb2.Timestamp): + The start time of this training run. + results (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult]): + Output of each iteration run, results.size() <= + max_iterations. + evaluation_metrics (google.cloud.bigquery_v2.types.Model.EvaluationMetrics): + The evaluation metrics over training/eval + data that were computed at the end of training. + data_split_result (google.cloud.bigquery_v2.types.Model.DataSplitResult): + Data split result of the training run. Only + set when the input data is actually split. + global_explanations (Sequence[google.cloud.bigquery_v2.types.Model.GlobalExplanation]): + Global explanations for important features of + the model. 
For multi-class models, there is one + entry for each label class. For other models, + there is only one entry in the list. + """ + + class TrainingOptions(proto.Message): + r"""Options used in model training. + Attributes: + max_iterations (int): + The maximum number of iterations in training. + Used only for iterative training algorithms. + loss_type (google.cloud.bigquery_v2.types.Model.LossType): + Type of loss function used during training + run. + learn_rate (float): + Learning rate in training. Used only for + iterative training algorithms. + l1_regularization (google.protobuf.wrappers_pb2.DoubleValue): + L1 regularization coefficient. + l2_regularization (google.protobuf.wrappers_pb2.DoubleValue): + L2 regularization coefficient. + min_relative_progress (google.protobuf.wrappers_pb2.DoubleValue): + When early_stop is true, stops training when accuracy + improvement is less than 'min_relative_progress'. Used only + for iterative training algorithms. + warm_start (google.protobuf.wrappers_pb2.BoolValue): + Whether to train a model from the last + checkpoint. + early_stop (google.protobuf.wrappers_pb2.BoolValue): + Whether to stop early when the loss doesn't improve + significantly any more (compared to min_relative_progress). + Used only for iterative training algorithms. + input_label_columns (Sequence[str]): + Name of input label columns in training data. + data_split_method (google.cloud.bigquery_v2.types.Model.DataSplitMethod): + The data split type for training and + evaluation, e.g. RANDOM. + data_split_eval_fraction (float): + The fraction of evaluation data over the + whole input data. The rest of data will be used + as training data. The format should be double. + Accurate to two decimal places. + Default value is 0.2. + data_split_column (str): + The column to split data with. This column won't be used as + a feature. + + 1. When data_split_method is CUSTOM, the corresponding + column should be boolean. The rows with true value tag + are eval data, and the false are training data. + 2. When data_split_method is SEQ, the first + DATA_SPLIT_EVAL_FRACTION rows (from smallest to largest) + in the corresponding column are used as training data, + and the rest are eval data. It respects the order in + Orderable data types: + https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data-type-properties + learn_rate_strategy (google.cloud.bigquery_v2.types.Model.LearnRateStrategy): + The strategy to determine learn rate for the + current iteration. + initial_learn_rate (float): + Specifies the initial learning rate for the + line search learn rate strategy. + label_class_weights (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.TrainingOptions.LabelClassWeightsEntry]): + Weights associated with each label class, for + rebalancing the training data. Only applicable + for classification models. + user_column (str): + User column specified for matrix + factorization models. + item_column (str): + Item column specified for matrix + factorization models. + distance_type (google.cloud.bigquery_v2.types.Model.DistanceType): + Distance type for clustering models. + num_clusters (int): + Number of clusters for clustering models. + model_uri (str): + Google Cloud Storage URI from which the model + was imported. Only applicable for imported + models. + optimization_strategy (google.cloud.bigquery_v2.types.Model.OptimizationStrategy): + Optimization strategy for training linear + regression models. + hidden_units (Sequence[int]): + Hidden units for dnn models. 
+ batch_size (int): + Batch size for dnn models. + dropout (google.protobuf.wrappers_pb2.DoubleValue): + Dropout probability for dnn models. + max_tree_depth (int): + Maximum depth of a tree for boosted tree + models. + subsample (float): + Subsample fraction of the training data to + grow tree to prevent overfitting for boosted + tree models. + min_split_loss (google.protobuf.wrappers_pb2.DoubleValue): + Minimum split loss for boosted tree models. + num_factors (int): + Num factors specified for matrix + factorization models. + feedback_type (google.cloud.bigquery_v2.types.Model.FeedbackType): + Feedback type that specifies which algorithm + to run for matrix factorization. + wals_alpha (google.protobuf.wrappers_pb2.DoubleValue): + Hyperparameter for matrix factoration when + implicit feedback type is specified. + kmeans_initialization_method (google.cloud.bigquery_v2.types.Model.KmeansEnums.KmeansInitializationMethod): + The method used to initialize the centroids + for kmeans algorithm. + kmeans_initialization_column (str): + The column used to provide the initial centroids for kmeans + algorithm when kmeans_initialization_method is CUSTOM. + time_series_timestamp_column (str): + Column to be designated as time series + timestamp for ARIMA model. + time_series_data_column (str): + Column to be designated as time series data + for ARIMA model. + auto_arima (bool): + Whether to enable auto ARIMA or not. + non_seasonal_order (google.cloud.bigquery_v2.types.Model.ArimaOrder): + A specification of the non-seasonal part of + the ARIMA model: the three components (p, d, q) + are the AR order, the degree of differencing, + and the MA order. + data_frequency (google.cloud.bigquery_v2.types.Model.DataFrequency): + The data frequency of a time series. + include_drift (bool): + Include drift when fitting an ARIMA model. + holiday_region (google.cloud.bigquery_v2.types.Model.HolidayRegion): + The geographical region based on which the + holidays are considered in time series modeling. + If a valid value is specified, then holiday + effects modeling is enabled. + time_series_id_column (str): + The time series id column that was used + during ARIMA model training. + time_series_id_columns (Sequence[str]): + The time series id columns that were used + during ARIMA model training. + horizon (int): + The number of periods ahead that need to be + forecasted. + preserve_input_structs (bool): + Whether to preserve the input structs in output feature + names. Suppose there is a struct A with field b. When false + (default), the output feature name is A_b. When true, the + output feature name is A.b. + auto_arima_max_order (int): + The max value of non-seasonal p and q. + decompose_time_series (google.protobuf.wrappers_pb2.BoolValue): + If true, perform decompose time series and + save the results. + clean_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue): + If true, clean spikes and dips in the input + time series. + adjust_step_changes (google.protobuf.wrappers_pb2.BoolValue): + If true, detect step changes and make data + adjustment in the input time series. 
+ """ + + max_iterations = proto.Field(proto.INT64, number=1,) + loss_type = proto.Field(proto.ENUM, number=2, enum="Model.LossType",) + learn_rate = proto.Field(proto.DOUBLE, number=3,) + l1_regularization = proto.Field( + proto.MESSAGE, number=4, message=wrappers_pb2.DoubleValue, + ) + l2_regularization = proto.Field( + proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, + ) + min_relative_progress = proto.Field( + proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, + ) + warm_start = proto.Field( + proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue, + ) + early_stop = proto.Field( + proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue, + ) + input_label_columns = proto.RepeatedField(proto.STRING, number=9,) + data_split_method = proto.Field( + proto.ENUM, number=10, enum="Model.DataSplitMethod", + ) + data_split_eval_fraction = proto.Field(proto.DOUBLE, number=11,) + data_split_column = proto.Field(proto.STRING, number=12,) + learn_rate_strategy = proto.Field( + proto.ENUM, number=13, enum="Model.LearnRateStrategy", + ) + initial_learn_rate = proto.Field(proto.DOUBLE, number=16,) + label_class_weights = proto.MapField(proto.STRING, proto.DOUBLE, number=17,) + user_column = proto.Field(proto.STRING, number=18,) + item_column = proto.Field(proto.STRING, number=19,) + distance_type = proto.Field( + proto.ENUM, number=20, enum="Model.DistanceType", + ) + num_clusters = proto.Field(proto.INT64, number=21,) + model_uri = proto.Field(proto.STRING, number=22,) + optimization_strategy = proto.Field( + proto.ENUM, number=23, enum="Model.OptimizationStrategy", + ) + hidden_units = proto.RepeatedField(proto.INT64, number=24,) + batch_size = proto.Field(proto.INT64, number=25,) + dropout = proto.Field( + proto.MESSAGE, number=26, message=wrappers_pb2.DoubleValue, + ) + max_tree_depth = proto.Field(proto.INT64, number=27,) + subsample = proto.Field(proto.DOUBLE, number=28,) + min_split_loss = proto.Field( + proto.MESSAGE, number=29, message=wrappers_pb2.DoubleValue, + ) + num_factors = proto.Field(proto.INT64, number=30,) + feedback_type = proto.Field( + proto.ENUM, number=31, enum="Model.FeedbackType", + ) + wals_alpha = proto.Field( + proto.MESSAGE, number=32, message=wrappers_pb2.DoubleValue, + ) + kmeans_initialization_method = proto.Field( + proto.ENUM, + number=33, + enum="Model.KmeansEnums.KmeansInitializationMethod", + ) + kmeans_initialization_column = proto.Field(proto.STRING, number=34,) + time_series_timestamp_column = proto.Field(proto.STRING, number=35,) + time_series_data_column = proto.Field(proto.STRING, number=36,) + auto_arima = proto.Field(proto.BOOL, number=37,) + non_seasonal_order = proto.Field( + proto.MESSAGE, number=38, message="Model.ArimaOrder", + ) + data_frequency = proto.Field( + proto.ENUM, number=39, enum="Model.DataFrequency", + ) + include_drift = proto.Field(proto.BOOL, number=41,) + holiday_region = proto.Field( + proto.ENUM, number=42, enum="Model.HolidayRegion", + ) + time_series_id_column = proto.Field(proto.STRING, number=43,) + time_series_id_columns = proto.RepeatedField(proto.STRING, number=51,) + horizon = proto.Field(proto.INT64, number=44,) + preserve_input_structs = proto.Field(proto.BOOL, number=45,) + auto_arima_max_order = proto.Field(proto.INT64, number=46,) + decompose_time_series = proto.Field( + proto.MESSAGE, number=50, message=wrappers_pb2.BoolValue, + ) + clean_spikes_and_dips = proto.Field( + proto.MESSAGE, number=52, message=wrappers_pb2.BoolValue, + ) + adjust_step_changes = proto.Field( + proto.MESSAGE, number=53, 
message=wrappers_pb2.BoolValue, + ) + + class IterationResult(proto.Message): + r"""Information about a single iteration of the training run. + Attributes: + index (google.protobuf.wrappers_pb2.Int32Value): + Index of the iteration, 0 based. + duration_ms (google.protobuf.wrappers_pb2.Int64Value): + Time taken to run the iteration in + milliseconds. + training_loss (google.protobuf.wrappers_pb2.DoubleValue): + Loss computed on the training data at the end + of iteration. + eval_loss (google.protobuf.wrappers_pb2.DoubleValue): + Loss computed on the eval data at the end of + iteration. + learn_rate (float): + Learn rate used for this iteration. + cluster_infos (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ClusterInfo]): + Information about top clusters for clustering + models. + arima_result (google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ArimaResult): + + """ + + class ClusterInfo(proto.Message): + r"""Information about a single cluster for clustering model. + Attributes: + centroid_id (int): + Centroid id. + cluster_radius (google.protobuf.wrappers_pb2.DoubleValue): + Cluster radius, the average distance from + centroid to each point assigned to the cluster. + cluster_size (google.protobuf.wrappers_pb2.Int64Value): + Cluster size, the total number of points + assigned to the cluster. + """ + + centroid_id = proto.Field(proto.INT64, number=1,) + cluster_radius = proto.Field( + proto.MESSAGE, number=2, message=wrappers_pb2.DoubleValue, + ) + cluster_size = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.Int64Value, + ) + + class ArimaResult(proto.Message): + r"""(Auto-)arima fitting result. Wrap everything in ArimaResult + for easier refactoring if we want to use model-specific + iteration results. + + Attributes: + arima_model_info (Sequence[google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ArimaResult.ArimaModelInfo]): + This message is repeated because there are + multiple arima models fitted in auto-arima. For + non-auto-arima model, its size is one. + seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): + Seasonal periods. Repeated because multiple + periods are supported for one time series. + """ + + class ArimaCoefficients(proto.Message): + r"""Arima coefficients. + Attributes: + auto_regressive_coefficients (Sequence[float]): + Auto-regressive coefficients, an array of + double. + moving_average_coefficients (Sequence[float]): + Moving-average coefficients, an array of + double. + intercept_coefficient (float): + Intercept coefficient, just a double not an + array. + """ + + auto_regressive_coefficients = proto.RepeatedField( + proto.DOUBLE, number=1, + ) + moving_average_coefficients = proto.RepeatedField( + proto.DOUBLE, number=2, + ) + intercept_coefficient = proto.Field(proto.DOUBLE, number=3,) + + class ArimaModelInfo(proto.Message): + r"""Arima model information. + Attributes: + non_seasonal_order (google.cloud.bigquery_v2.types.Model.ArimaOrder): + Non-seasonal order. + arima_coefficients (google.cloud.bigquery_v2.types.Model.TrainingRun.IterationResult.ArimaResult.ArimaCoefficients): + Arima coefficients. + arima_fitting_metrics (google.cloud.bigquery_v2.types.Model.ArimaFittingMetrics): + Arima fitting metrics. + has_drift (bool): + Whether Arima model fitted with drift or not. + It is always false when d is not 1. + time_series_id (str): + The time_series_id value for this time series. 
It will be + one of the unique values from the time_series_id_column + specified during ARIMA model training. Only present when + time_series_id_column training option was used. + time_series_ids (Sequence[str]): + The tuple of time_series_ids identifying this time series. + It will be one of the unique tuples of values present in the + time_series_id_columns specified during ARIMA model + training. Only present when time_series_id_columns training + option was used and the order of values here are same as the + order of time_series_id_columns. + seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]): + Seasonal periods. Repeated because multiple + periods are supported for one time series. + has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue): + If true, holiday_effect is a part of time series + decomposition result. + has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue): + If true, spikes_and_dips is a part of time series + decomposition result. + has_step_changes (google.protobuf.wrappers_pb2.BoolValue): + If true, step_changes is a part of time series decomposition + result. + """ + + non_seasonal_order = proto.Field( + proto.MESSAGE, number=1, message="Model.ArimaOrder", + ) + arima_coefficients = proto.Field( + proto.MESSAGE, + number=2, + message="Model.TrainingRun.IterationResult.ArimaResult.ArimaCoefficients", + ) + arima_fitting_metrics = proto.Field( + proto.MESSAGE, number=3, message="Model.ArimaFittingMetrics", + ) + has_drift = proto.Field(proto.BOOL, number=4,) + time_series_id = proto.Field(proto.STRING, number=5,) + time_series_ids = proto.RepeatedField(proto.STRING, number=10,) + seasonal_periods = proto.RepeatedField( + proto.ENUM, + number=6, + enum="Model.SeasonalPeriod.SeasonalPeriodType", + ) + has_holiday_effect = proto.Field( + proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue, + ) + has_spikes_and_dips = proto.Field( + proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue, + ) + has_step_changes = proto.Field( + proto.MESSAGE, number=9, message=wrappers_pb2.BoolValue, + ) + + arima_model_info = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="Model.TrainingRun.IterationResult.ArimaResult.ArimaModelInfo", + ) + seasonal_periods = proto.RepeatedField( + proto.ENUM, + number=2, + enum="Model.SeasonalPeriod.SeasonalPeriodType", + ) + + index = proto.Field( + proto.MESSAGE, number=1, message=wrappers_pb2.Int32Value, + ) + duration_ms = proto.Field( + proto.MESSAGE, number=4, message=wrappers_pb2.Int64Value, + ) + training_loss = proto.Field( + proto.MESSAGE, number=5, message=wrappers_pb2.DoubleValue, + ) + eval_loss = proto.Field( + proto.MESSAGE, number=6, message=wrappers_pb2.DoubleValue, + ) + learn_rate = proto.Field(proto.DOUBLE, number=7,) + cluster_infos = proto.RepeatedField( + proto.MESSAGE, + number=8, + message="Model.TrainingRun.IterationResult.ClusterInfo", + ) + arima_result = proto.Field( + proto.MESSAGE, + number=9, + message="Model.TrainingRun.IterationResult.ArimaResult", + ) + + training_options = proto.Field( + proto.MESSAGE, number=1, message="Model.TrainingRun.TrainingOptions", + ) + start_time = proto.Field( + proto.MESSAGE, number=8, message=timestamp_pb2.Timestamp, + ) + results = proto.RepeatedField( + proto.MESSAGE, number=6, message="Model.TrainingRun.IterationResult", + ) + evaluation_metrics = proto.Field( + proto.MESSAGE, number=7, message="Model.EvaluationMetrics", + ) + data_split_result = proto.Field( + proto.MESSAGE, number=9, 
message="Model.DataSplitResult", + ) + global_explanations = proto.RepeatedField( + proto.MESSAGE, number=10, message="Model.GlobalExplanation", + ) + + etag = proto.Field(proto.STRING, number=1,) + model_reference = proto.Field( + proto.MESSAGE, number=2, message=gcb_model_reference.ModelReference, + ) + creation_time = proto.Field(proto.INT64, number=5,) + last_modified_time = proto.Field(proto.INT64, number=6,) + description = proto.Field(proto.STRING, number=12,) + friendly_name = proto.Field(proto.STRING, number=14,) + labels = proto.MapField(proto.STRING, proto.STRING, number=15,) + expiration_time = proto.Field(proto.INT64, number=16,) + location = proto.Field(proto.STRING, number=13,) + encryption_configuration = proto.Field( + proto.MESSAGE, number=17, message=encryption_config.EncryptionConfiguration, + ) + model_type = proto.Field(proto.ENUM, number=7, enum=ModelType,) + training_runs = proto.RepeatedField(proto.MESSAGE, number=9, message=TrainingRun,) + feature_columns = proto.RepeatedField( + proto.MESSAGE, number=10, message=standard_sql.StandardSqlField, + ) + label_columns = proto.RepeatedField( + proto.MESSAGE, number=11, message=standard_sql.StandardSqlField, + ) + best_trial_id = proto.Field(proto.INT64, number=19,) + + +class GetModelRequest(proto.Message): + r""" + Attributes: + project_id (str): + Required. Project ID of the requested model. + dataset_id (str): + Required. Dataset ID of the requested model. + model_id (str): + Required. Model ID of the requested model. + """ + + project_id = proto.Field(proto.STRING, number=1,) + dataset_id = proto.Field(proto.STRING, number=2,) + model_id = proto.Field(proto.STRING, number=3,) + + +class PatchModelRequest(proto.Message): + r""" + Attributes: + project_id (str): + Required. Project ID of the model to patch. + dataset_id (str): + Required. Dataset ID of the model to patch. + model_id (str): + Required. Model ID of the model to patch. + model (google.cloud.bigquery_v2.types.Model): + Required. Patched model. + Follows RFC5789 patch semantics. Missing fields + are not updated. To clear a field, explicitly + set to default value. + """ + + project_id = proto.Field(proto.STRING, number=1,) + dataset_id = proto.Field(proto.STRING, number=2,) + model_id = proto.Field(proto.STRING, number=3,) + model = proto.Field(proto.MESSAGE, number=4, message="Model",) + + +class DeleteModelRequest(proto.Message): + r""" + Attributes: + project_id (str): + Required. Project ID of the model to delete. + dataset_id (str): + Required. Dataset ID of the model to delete. + model_id (str): + Required. Model ID of the model to delete. + """ + + project_id = proto.Field(proto.STRING, number=1,) + dataset_id = proto.Field(proto.STRING, number=2,) + model_id = proto.Field(proto.STRING, number=3,) + + +class ListModelsRequest(proto.Message): + r""" + Attributes: + project_id (str): + Required. Project ID of the models to list. + dataset_id (str): + Required. Dataset ID of the models to list. + max_results (google.protobuf.wrappers_pb2.UInt32Value): + The maximum number of results to return in a + single response page. Leverage the page tokens + to iterate through the entire collection. 
+ page_token (str): + Page token, returned by a previous call to + request the next page of results + """ + + project_id = proto.Field(proto.STRING, number=1,) + dataset_id = proto.Field(proto.STRING, number=2,) + max_results = proto.Field( + proto.MESSAGE, number=3, message=wrappers_pb2.UInt32Value, + ) + page_token = proto.Field(proto.STRING, number=4,) + + +class ListModelsResponse(proto.Message): + r""" + Attributes: + models (Sequence[google.cloud.bigquery_v2.types.Model]): + Models in the requested dataset. Only the following fields + are populated: model_reference, model_type, creation_time, + last_modified_time and labels. + next_page_token (str): + A token to request the next page of results. + """ + + @property + def raw_page(self): + return self + + models = proto.RepeatedField(proto.MESSAGE, number=1, message="Model",) + next_page_token = proto.Field(proto.STRING, number=2,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/bigquery_v2/types/table_reference.py b/google/cloud/bigquery_v2/types/table_reference.py new file mode 100644 index 000000000..d56e5b09f --- /dev/null +++ b/google/cloud/bigquery_v2/types/table_reference.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.bigquery.v2", manifest={"TableReference",}, +) + + +class TableReference(proto.Message): + r""" + Attributes: + project_id (str): + Required. The ID of the project containing + this table. + dataset_id (str): + Required. The ID of the dataset containing + this table. + table_id (str): + Required. The ID of the table. The ID must contain only + letters (a-z, A-Z), numbers (0-9), or underscores (_). The + maximum length is 1,024 characters. Certain operations allow + suffixing of the table ID with a partition decorator, such + as ``sample_table$20190123``. + project_id_alternative (Sequence[str]): + The alternative field that will be used when ESF is not able + to translate the received data to the project_id field. + dataset_id_alternative (Sequence[str]): + The alternative field that will be used when ESF is not able + to translate the received data to the project_id field. + table_id_alternative (Sequence[str]): + The alternative field that will be used when ESF is not able + to translate the received data to the project_id field. 
+ """ + + project_id = proto.Field(proto.STRING, number=1,) + dataset_id = proto.Field(proto.STRING, number=2,) + table_id = proto.Field(proto.STRING, number=3,) + project_id_alternative = proto.RepeatedField(proto.STRING, number=4,) + dataset_id_alternative = proto.RepeatedField(proto.STRING, number=5,) + table_id_alternative = proto.RepeatedField(proto.STRING, number=6,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) From 7bb120061ae0cc83680335c57d571b461c902a82 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 15 Oct 2021 11:24:57 -0500 Subject: [PATCH 17/24] add unit tests --- google/cloud/bigquery/_job_helpers.py | 9 +- tests/unit/test__job_helpers.py | 177 +++++++++++++++++++++++++- 2 files changed, 177 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index c0ef8d6f5..38cad247c 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -28,7 +28,7 @@ from google.cloud.bigquery.client import Client -_TIMEOUT_BUFFER_SECS = 0.1 +_TIMEOUT_BUFFER_MILLIS = 100 def make_job_id(job_id: Optional[str] = None, prefix: Optional[str] = None) -> str: @@ -136,6 +136,11 @@ def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any # Default to standard SQL. request_body.setdefault("useLegacySql", False) + # Since jobs.query can return results, ensure we use the lossless timestamp + # format. See: https://github.com/googleapis/python-bigquery/issues/395 + request_body.setdefault("formatOptions", {}) + request_body["formatOptions"]["useInt64Timestamp"] = True + return request_body @@ -215,7 +220,7 @@ def query_jobs_query( if timeout is not None: # Subtract a buffer for context switching, network latency, etc. - request_body["timeoutMs"] = max(0, int(1000 * (timeout - _TIMEOUT_BUFFER_SECS))) + request_body["timeoutMs"] = max(0, int(1000 * timeout) - _TIMEOUT_BUFFER_MILLIS) request_body["location"] = location request_body["query"] = query diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py index a3211433f..5fd06fe0e 100644 --- a/tests/unit/test__job_helpers.py +++ b/tests/unit/test__job_helpers.py @@ -12,14 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
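For illustration, a minimal sketch of the jobs.query request body the _job_helpers.py changes above produce, assuming a plain "SELECT 1" against the US location with a 2-second client-side timeout (example values only; the keys mirror the expected_resource dictionaries in the tests that follow):

    # Hypothetical QueryRequest body built by query_jobs_query after this commit.
    _TIMEOUT_BUFFER_MILLIS = 100  # value at this point in the series; a later commit raises it to 250
    request_body = {
        "useLegacySql": False,
        "formatOptions": {"useInt64Timestamp": True},  # lossless TIMESTAMP values (issue 395)
        "location": "US",
        "query": "SELECT 1",
        "requestId": "a-uuid4-string",
        # a 2.0 s client timeout becomes a 1900 ms server-side timeout
        "timeoutMs": max(0, int(1000 * 2.0) - _TIMEOUT_BUFFER_MILLIS),
    }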
-from typing import Dict, Any +from typing import Any, Dict, Optional from unittest import mock +from google.api_core import retry as retries import pytest from google.cloud.bigquery.client import Client from google.cloud.bigquery import _job_helpers from google.cloud.bigquery.job.query import QueryJob, QueryJobConfig +from google.cloud.bigquery.query import ScalarQueryParameter + + +def make_query_request(additional_properties: Optional[Dict[str, Any]] = None): + request = {"useLegacySql": False, "formatOptions": {"useInt64Timestamp": True}} + if additional_properties is not None: + request.update(additional_properties) + return request def make_query_response( @@ -45,16 +54,84 @@ def make_query_response( @pytest.mark.parametrize( ("job_config", "expected"), ( - (None, {"useLegacySql": False}), - (QueryJobConfig(), {"useLegacySql": False}), - (QueryJobConfig(dry_run=True), {"useLegacySql": False, "dryRun": True}), + (None, make_query_request()), + (QueryJobConfig(), make_query_request()), ( - QueryJobConfig(labels={"abc": "def"}), - {"useLegacySql": False, "labels": {"abc": "def"}}, + QueryJobConfig(default_dataset="my-project.my_dataset"), + make_query_request( + { + "defaultDataset": { + "projectId": "my-project", + "datasetId": "my_dataset", + } + } + ), ), + (QueryJobConfig(dry_run=True), make_query_request({"dryRun": True})), ( QueryJobConfig(use_query_cache=False), - {"useLegacySql": False, "useQueryCache": False}, + make_query_request({"useQueryCache": False}), + ), + ( + QueryJobConfig(use_legacy_sql=True), + make_query_request({"useLegacySql": True}), + ), + ( + QueryJobConfig( + query_parameters=[ + ScalarQueryParameter("named_param1", "STRING", "param-value"), + ScalarQueryParameter("named_param2", "INT64", 123), + ] + ), + make_query_request( + { + "parameterMode": "NAMED", + "queryParameters": [ + { + "name": "named_param1", + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "param-value"}, + }, + { + "name": "named_param2", + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "123"}, + }, + ], + } + ), + ), + ( + QueryJobConfig( + query_parameters=[ + ScalarQueryParameter(None, "STRING", "param-value"), + ScalarQueryParameter(None, "INT64", 123), + ] + ), + make_query_request( + { + "parameterMode": "POSITIONAL", + "queryParameters": [ + { + "parameterType": {"type": "STRING"}, + "parameterValue": {"value": "param-value"}, + }, + { + "parameterType": {"type": "INT64"}, + "parameterValue": {"value": "123"}, + }, + ], + } + ), + ), + # TODO: connection properties + ( + QueryJobConfig(labels={"abc": "def"}), + make_query_request({"labels": {"abc": "def"}}), + ), + ( + QueryJobConfig(maximum_bytes_billed=987654), + make_query_request({"maximumBytesBilled": "987654"}), ), ), ) @@ -112,3 +189,89 @@ def test__to_query_job_sets_errors(): ) job: QueryJob = _job_helpers._to_query_job(mock_client, "query-str", None, response) assert len(job.errors) == 2 + # If we got back a response instead of an HTTP error status code, most + # likely the job didn't completely fail. 
+ assert job.error_result is None + + +def test_query_jobs_query_defaults(): + mock_client = mock.create_autospec(Client) + mock_retry = mock.create_autospec(retries.Retry) + mock_job_retry = mock.create_autospec(retries.Retry) + _job_helpers.query_jobs_query( + mock_client, + "SELECT * FROM test", + None, + "asia-northeast1", + "test-project", + mock_retry, + None, + mock_job_retry, + ) + + assert mock_client._call_api.call_count == 1 + call_args, call_kwargs = mock_client._call_api.call_args + assert call_args[0] is mock_retry + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query + assert call_kwargs["path"] == "/projects/test-project/queries" + assert call_kwargs["method"] == "POST" + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#QueryRequest + request = call_kwargs["data"] + assert request["requestId"] is not None + assert request["query"] == "SELECT * FROM test" + assert request["location"] == "asia-northeast1" + assert request["formatOptions"]["useInt64Timestamp"] is True + assert "timeoutMs" not in request + + +def test_query_jobs_query_sets_format_options(): + """Since jobs.query can return results, ensure we use the lossless + timestamp format. + + See: https://github.com/googleapis/python-bigquery/issues/395 + """ + mock_client = mock.create_autospec(Client) + mock_retry = mock.create_autospec(retries.Retry) + mock_job_retry = mock.create_autospec(retries.Retry) + _job_helpers.query_jobs_query( + mock_client, + "SELECT * FROM test", + None, + "US", + "test-project", + mock_retry, + None, + mock_job_retry, + ) + + assert mock_client._call_api.call_count == 1 + _, call_kwargs = mock_client._call_api.call_args + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#QueryRequest + request = call_kwargs["data"] + assert request["formatOptions"]["useInt64Timestamp"] is True + + +@pytest.mark.parametrize( + ("timeout", "expected_timeout"), + ((-1, 0), (0, 0), (1, 1000 - _job_helpers._TIMEOUT_BUFFER_MILLIS),), +) +def test_query_jobs_query_sets_timeout(timeout, expected_timeout): + mock_client = mock.create_autospec(Client) + mock_retry = mock.create_autospec(retries.Retry) + mock_job_retry = mock.create_autospec(retries.Retry) + _job_helpers.query_jobs_query( + mock_client, + "SELECT * FROM test", + None, + "US", + "test-project", + mock_retry, + timeout, + mock_job_retry, + ) + + assert mock_client._call_api.call_count == 1 + _, call_kwargs = mock_client._call_api.call_args + # See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#QueryRequest + request = call_kwargs["data"] + assert request["timeoutMs"] == expected_timeout From 0598aceec83c03e5218d03137c1fecfeec9c4bfa Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Nov 2021 09:51:08 -0600 Subject: [PATCH 18/24] adjust query job construction --- google/cloud/bigquery/_job_helpers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index e13becde7..b0b0cd32a 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -152,14 +152,16 @@ def _to_query_job( job_ref_resource = query_response["jobReference"] job_ref = job._JobReference._from_api_repr(job_ref_resource) query_job = job.QueryJob(job_ref, query, client=client) + query_job._properties.setdefault("configuration", {}) - # Not all relevant properties are in the jobs.query response, so + # Not all relevant properties are in the jobs.query 
response. Populate some + # expected properties based on the job configuration. if request_config is not None: query_job._properties["configuration"].update(request_config.to_api_repr()) - query_job._properties["configuration"]["query"]["query"] = query - query_job._properties["configuration"]["query"].setdefault( - "useLegacySql", False - ) + + query_job._properties["configuration"].setdefault("query", {}) + query_job._properties["configuration"]["query"]["query"] = query + query_job._properties["configuration"]["query"].setdefault("useLegacySql", False) query_job._properties.setdefault("statistics", {}) query_job._properties["statistics"].setdefault("query", {}) From 4f36baef4495e4651594f38514c53726e5ae9c38 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Nov 2021 10:03:59 -0600 Subject: [PATCH 19/24] avoid conflicting table IDs --- tests/system/conftest.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index e3c67b537..784a1dd5c 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -13,6 +13,7 @@ # limitations under the License. import pathlib +import random import re from typing import Tuple @@ -97,10 +98,11 @@ def load_scalars_table( data_path: str = "scalars.jsonl", ) -> str: schema = bigquery_client.schema_from_json(DATA_DIR / "scalars_schema.json") + table_id = data_path.replace(".", "_") + hex(random.randrange(1000000)) job_config = bigquery.LoadJobConfig() job_config.schema = schema job_config.source_format = enums.SourceFormat.NEWLINE_DELIMITED_JSON - full_table_id = f"{project_id}.{dataset_id}.scalars" + full_table_id = f"{project_id}.{dataset_id}.{table_id}" with open(DATA_DIR / data_path, "rb") as data_file: job = bigquery_client.load_table_from_file( data_file, full_table_id, job_config=job_config @@ -113,7 +115,7 @@ def load_scalars_table( def scalars_table(bigquery_client: bigquery.Client, project_id: str, dataset_id: str): full_table_id = load_scalars_table(bigquery_client, project_id, dataset_id) yield full_table_id - bigquery_client.delete_table(full_table_id) + bigquery_client.delete_table(full_table_id, not_found_ok=True) @pytest.fixture(scope="session") @@ -122,7 +124,7 @@ def scalars_table_tokyo( ): full_table_id = load_scalars_table(bigquery_client, project_id, dataset_id_tokyo) yield full_table_id - bigquery_client.delete_table(full_table_id) + bigquery_client.delete_table(full_table_id, not_found_ok=True) @pytest.fixture(scope="session") @@ -133,7 +135,7 @@ def scalars_extreme_table( bigquery_client, project_id, dataset_id, data_path="scalars_extreme.jsonl" ) yield full_table_id - bigquery_client.delete_table(full_table_id) + bigquery_client.delete_table(full_table_id, not_found_ok=True) @pytest.fixture(scope="session", params=["US", TOKYO_LOCATION]) From 3058498356f7188721cfdc685cbc82dcce6a600b Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Nov 2021 10:23:47 -0600 Subject: [PATCH 20/24] mock query response --- tests/unit/test__job_helpers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py index 6d4d6f73d..de70b1a96 100644 --- a/tests/unit/test__job_helpers.py +++ b/tests/unit/test__job_helpers.py @@ -198,6 +198,13 @@ def test_query_jobs_query_defaults(): mock_client = mock.create_autospec(Client) mock_retry = mock.create_autospec(retries.Retry) mock_job_retry = mock.create_autospec(retries.Retry) + mock_client._call_api.return_value = { + "jobReference": { + 
"projectId": "test-project", + "jobId": "abc", + "location": "asia-northeast1", + } + } _job_helpers.query_jobs_query( mock_client, "SELECT * FROM test", @@ -233,6 +240,9 @@ def test_query_jobs_query_sets_format_options(): mock_client = mock.create_autospec(Client) mock_retry = mock.create_autospec(retries.Retry) mock_job_retry = mock.create_autospec(retries.Retry) + mock_client._call_api.return_value = { + "jobReference": {"projectId": "test-project", "jobId": "abc", "location": "US"} + } _job_helpers.query_jobs_query( mock_client, "SELECT * FROM test", @@ -259,6 +269,9 @@ def test_query_jobs_query_sets_timeout(timeout, expected_timeout): mock_client = mock.create_autospec(Client) mock_retry = mock.create_autospec(retries.Retry) mock_job_retry = mock.create_autospec(retries.Retry) + mock_client._call_api.return_value = { + "jobReference": {"projectId": "test-project", "jobId": "abc", "location": "US"} + } _job_helpers.query_jobs_query( mock_client, "SELECT * FROM test", From ba785d94eb1e13d31f3b4da92e347f6822723209 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Nov 2021 14:12:38 -0600 Subject: [PATCH 21/24] fix unit test coverage --- tests/unit/test_client.py | 154 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 9be154288..8ebf5137e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -4016,6 +4016,160 @@ def test_query_defaults(self): self.assertEqual(sent_config["query"], QUERY) self.assertFalse(sent_config["useLegacySql"]) + def test_query_w_api_method_query(self): + query = "select count(*) from persons" + response = { + "jobReference": { + "projectId": self.PROJECT, + "location": "EU", + "jobId": "abcd", + }, + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(response) + + job = client.query(query, location="EU", api_method="QUERY") + + self.assertEqual(job.query, query) + self.assertEqual(job.job_id, "abcd") + self.assertEqual(job.location, "EU") + + # Check that query actually starts the job. + expected_resource = { + "query": query, + "useLegacySql": False, + "location": "EU", + "formatOptions": {"useInt64Timestamp": True}, + "requestId": mock.ANY, + } + conn.api_request.assert_called_once_with( + method="POST", + path=f"/projects/{self.PROJECT}/queries", + data=expected_resource, + timeout=None, + ) + + def test_query_w_api_method_query_legacy_sql(self): + from google.cloud.bigquery import QueryJobConfig + + query = "select count(*) from persons" + response = { + "jobReference": { + "projectId": self.PROJECT, + "location": "EU", + "jobId": "abcd", + }, + } + job_config = QueryJobConfig() + job_config.use_legacy_sql = True + job_config.maximum_bytes_billed = 100 + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(response) + + job = client.query( + query, location="EU", job_config=job_config, api_method="QUERY" + ) + + self.assertEqual(job.query, query) + self.assertEqual(job.job_id, "abcd") + self.assertEqual(job.location, "EU") + + # Check that query actually starts the job. 
+ expected_resource = { + "query": query, + "useLegacySql": True, + "location": "EU", + "formatOptions": {"useInt64Timestamp": True}, + "requestId": mock.ANY, + "maximumBytesBilled": "100", + } + conn.api_request.assert_called_once_with( + method="POST", + path=f"/projects/{self.PROJECT}/queries", + data=expected_resource, + timeout=None, + ) + + def test_query_w_api_method_query_parameters(self): + from google.cloud.bigquery import QueryJobConfig, ScalarQueryParameter + + query = "select count(*) from persons" + response = { + "jobReference": { + "projectId": self.PROJECT, + "location": "EU", + "jobId": "abcd", + }, + } + job_config = QueryJobConfig() + job_config.dry_run = True + job_config.query_parameters = [ScalarQueryParameter("param1", "INTEGER", 123)] + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(response) + + job = client.query( + query, location="EU", job_config=job_config, api_method="QUERY" + ) + + self.assertEqual(job.query, query) + self.assertEqual(job.job_id, "abcd") + self.assertEqual(job.location, "EU") + + # Check that query actually starts the job. + expected_resource = { + "query": query, + "dryRun": True, + "useLegacySql": False, + "location": "EU", + "formatOptions": {"useInt64Timestamp": True}, + "requestId": mock.ANY, + "parameterMode": "NAMED", + "queryParameters": [ + { + "name": "param1", + "parameterType": {"type": "INTEGER"}, + "parameterValue": {"value": "123"}, + }, + ], + } + conn.api_request.assert_called_once_with( + method="POST", + path=f"/projects/{self.PROJECT}/queries", + data=expected_resource, + timeout=None, + ) + + def test_query_w_api_method_query_and_job_id_fails(self): + query = "select count(*) from persons" + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client._connection = make_connection({}) + + with self.assertRaises(TypeError) as exc: + client.query(query, job_id="abcd", api_method="QUERY") + self.assertIn( + "`job_id` was provided, but the 'QUERY' `api_method` was requested", + exc.exception.args[0], + ) + + def test_query_w_api_method_unknown(self): + query = "select count(*) from persons" + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client._connection = make_connection({}) + + with self.assertRaises(ValueError) as exc: + client.query(query, api_method="UNKNOWN") + self.assertIn("Got unexpected value for api_method: ", exc.exception.args[0]) + def test_query_w_explicit_timeout(self): query = "select count(*) from persons" resource = { From b67ac5a9473b0fbb389286ac9c93e387d1f8ecfb Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Nov 2021 14:27:08 -0600 Subject: [PATCH 22/24] fix type errors --- google/cloud/bigquery/_job_helpers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index b0b0cd32a..82e7022a8 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -52,9 +52,9 @@ def make_job_id(job_id: Optional[str] = None, prefix: Optional[str] = None) -> s def query_jobs_insert( client: "Client", query: str, - job_config: job.QueryJobConfig, - job_id: str, - job_id_prefix: str, + job_config: Optional[job.QueryJobConfig], + job_id: Optional[str], + job_id_prefix: Optional[str], location: str, project: 
str, retry: retries.Retry, @@ -138,7 +138,7 @@ def _to_query_request(job_config: Optional[job.QueryJobConfig]) -> Dict[str, Any # Since jobs.query can return results, ensure we use the lossless timestamp # format. See: https://github.com/googleapis/python-bigquery/issues/395 request_body.setdefault("formatOptions", {}) - request_body["formatOptions"]["useInt64Timestamp"] = True + request_body["formatOptions"]["useInt64Timestamp"] = True # type: ignore return request_body From a3223b123a614aa8428ae9d09b9aa344829b6dd2 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Nov 2021 15:25:18 -0600 Subject: [PATCH 23/24] fix docs formatting --- google/cloud/bigquery/client.py | 4 ++-- google/cloud/bigquery/enums.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index b33de51fc..76ccafaf4 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -3167,7 +3167,7 @@ def query( retry: retries.Retry = DEFAULT_RETRY, timeout: TimeoutType = DEFAULT_TIMEOUT, job_retry: retries.Retry = DEFAULT_JOB_RETRY, - api_method: enums.QueryApiMethod = enums.QueryApiMethod.INSERT, + api_method: Union[str, enums.QueryApiMethod] = enums.QueryApiMethod.INSERT, ) -> job.QueryJob: """Run a SQL query. @@ -3219,7 +3219,7 @@ def query( called on the job returned. The ``job_retry`` specified here becomes the default ``job_retry`` for ``result()``, where it can also be specified. - api_method: + api_method (Union[str, enums.QueryApiMethod]): Method with which to start the query job. See :class:`google.cloud.bigquery.enums.QueryApiMethod` for diff --git a/google/cloud/bigquery/enums.py b/google/cloud/bigquery/enums.py index d399e4a26..c4a43126a 100644 --- a/google/cloud/bigquery/enums.py +++ b/google/cloud/bigquery/enums.py @@ -129,7 +129,7 @@ class QueryApiMethod(str, enum.Enum): INSERT = "INSERT" """Submit a query job by using the `jobs.insert REST API method - _`. + `_. This supports all job configuration options. """ From b2576090cc967217ba907703063f60408ea7c6a3 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 22 Nov 2021 10:52:34 -0600 Subject: [PATCH 24/24] comments and additional unit tests --- google/cloud/bigquery/_job_helpers.py | 19 ++++++++++++++----- tests/system/test_query.py | 11 ++++++++--- tests/unit/test__job_helpers.py | 19 +++++++++++++++++-- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/google/cloud/bigquery/_job_helpers.py b/google/cloud/bigquery/_job_helpers.py index 82e7022a8..33fc72261 100644 --- a/google/cloud/bigquery/_job_helpers.py +++ b/google/cloud/bigquery/_job_helpers.py @@ -28,15 +28,24 @@ from google.cloud.bigquery.client import Client -_TIMEOUT_BUFFER_MILLIS = 100 +# The purpose of _TIMEOUT_BUFFER_MILLIS is to allow the server-side timeout to +# happen before the client-side timeout. This is not strictly neccessary, as the +# client retries client-side timeouts, but the hope by making the server-side +# timeout slightly shorter is that it can save the server from some unncessary +# processing time. +# +# 250 milliseconds is chosen arbitrarily, though should be about the right +# order of magnitude for network latency and switching delays. It is about the +# amount of time for light to circumnavigate the world twice. +_TIMEOUT_BUFFER_MILLIS = 250 def make_job_id(job_id: Optional[str] = None, prefix: Optional[str] = None) -> str: """Construct an ID for a new job. Args: - job_id (Optional[str]): the user-provided job ID. 
- prefix (Optional[str]): the user-provided prefix for a job ID. + job_id: the user-provided job ID. + prefix: the user-provided prefix for a job ID. Returns: str: A job ID @@ -60,7 +69,7 @@ def query_jobs_insert( retry: retries.Retry, timeout: Optional[float], job_retry: retries.Retry, -): +) -> job.QueryJob: """Initiate a query using jobs.insert. See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert @@ -211,7 +220,7 @@ def query_jobs_query( retry: retries.Retry, timeout: Optional[float], job_retry: retries.Retry, -): +) -> job.QueryJob: """Initiate a query using jobs.query. See: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 7f3a0b676..f76b1e6ca 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -43,7 +43,7 @@ def table_with_9999_columns_10_rows(bigquery_client, project_id, dataset_id): """ table_id = "many_columns" row_count = 10 - col_projections = ",".join([f"r * {n} as col_{n}" for n in range(1, 10000)]) + col_projections = ",".join(f"r * {n} as col_{n}" for n in range(1, 10000)) sql = f""" CREATE TABLE `{project_id}.{dataset_id}.{table_id}` AS @@ -453,8 +453,13 @@ def test_query_error_w_api_method_query(bigquery_client: bigquery.Client): ) -def test_query_error_w_api_method_insert(bigquery_client: bigquery.Client): - """With jobs.insert, an exception is thrown when fetching the results..""" +def test_query_error_w_api_method_default(bigquery_client: bigquery.Client): + """Test that an exception is not thrown until fetching the results. + + For backwards compatibility, jobs.insert is the default API method. With + jobs.insert, a failed query job is "sucessfully" created. An exception is + thrown when fetching the results. + """ query_job = bigquery_client.query("SELECT * FROM not_a_real_dataset.doesnt_exist") diff --git a/tests/unit/test__job_helpers.py b/tests/unit/test__job_helpers.py index de70b1a96..63dde75e7 100644 --- a/tests/unit/test__job_helpers.py +++ b/tests/unit/test__job_helpers.py @@ -21,7 +21,7 @@ from google.cloud.bigquery.client import Client from google.cloud.bigquery import _job_helpers from google.cloud.bigquery.job.query import QueryJob, QueryJobConfig -from google.cloud.bigquery.query import ScalarQueryParameter +from google.cloud.bigquery.query import ConnectionProperty, ScalarQueryParameter def make_query_request(additional_properties: Optional[Dict[str, Any]] = None): @@ -124,7 +124,22 @@ def make_query_response( } ), ), - # TODO: connection properties + ( + QueryJobConfig( + connection_properties=[ + ConnectionProperty(key="time_zone", value="America/Chicago"), + ConnectionProperty(key="session_id", value="abcd-efgh-ijkl-mnop"), + ] + ), + make_query_request( + { + "connectionProperties": [ + {"key": "time_zone", "value": "America/Chicago"}, + {"key": "session_id", "value": "abcd-efgh-ijkl-mnop"}, + ] + } + ), + ), ( QueryJobConfig(labels={"abc": "def"}), make_query_request({"labels": {"abc": "def"}}),