Skip to content

Commit

Permalink
feat: Update proto definitions for bigquery/v2 to support new proto f…
Browse files Browse the repository at this point in the history
…ields for BQML. (#817)

PiperOrigin-RevId: 387137741

Source-Link: googleapis/googleapis@8962c92

Source-Link: googleapis/googleapis-gen@102f1b4
  • Loading branch information
gcf-owl-bot[bot] authored Jul 27, 2021
1 parent 3c1be14 commit fe7a902
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 9 deletions.
104 changes: 95 additions & 9 deletions google/cloud/bigquery_v2/types/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class Model(proto.Message):
Output only. Label columns that were used to train this
model. The output of the model will have a `predicted_`
prefix to these columns.
best_trial_id (int):
The best trial_id across all training runs.
"""

class ModelType(proto.Enum):
Expand All @@ -113,6 +115,7 @@ class ModelType(proto.Enum):
ARIMA = 11
AUTOML_REGRESSOR = 12
AUTOML_CLASSIFIER = 13
ARIMA_PLUS = 19

class LossType(proto.Enum):
r"""Loss metric to evaluate model training performance."""
Expand Down Expand Up @@ -151,6 +154,7 @@ class DataFrequency(proto.Enum):
WEEKLY = 5
DAILY = 6
HOURLY = 7
PER_MINUTE = 8

class HolidayRegion(proto.Enum):
r"""Type of supported holiday regions for time series forecasting
Expand Down Expand Up @@ -285,7 +289,7 @@ class RegressionMetrics(proto.Message):
median_absolute_error (google.protobuf.wrappers_pb2.DoubleValue):
Median absolute error.
r_squared (google.protobuf.wrappers_pb2.DoubleValue):
R^2 score.
R^2 score. This corresponds to r2_score in ML.EVALUATE.
"""

mean_absolute_error = proto.Field(
Expand Down Expand Up @@ -528,7 +532,7 @@ class ClusteringMetrics(proto.Message):
Mean of squared distances between each sample
to its cluster centroid.
clusters (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster]):
[Beta] Information for all clusters.
Information for all clusters.
"""

class Cluster(proto.Message):
Expand Down Expand Up @@ -697,10 +701,29 @@ class ArimaSingleModelForecastingMetrics(proto.Message):
Is arima model fitted with drift or not. It
is always false when d is not 1.
time_series_id (str):
The id to indicate different time series.
The time_series_id value for this time series. It will be
one of the unique values from the time_series_id_column
specified during ARIMA model training. Only present when
time_series_id_column training option was used.
time_series_ids (Sequence[str]):
The tuple of time_series_ids identifying this time series.
It will be one of the unique tuples of values present in the
time_series_id_columns specified during ARIMA model
training. Only present when time_series_id_columns training
option was used and the order of values here are same as the
order of time_series_id_columns.
seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]):
Seasonal periods. Repeated because multiple
periods are supported for one time series.
has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue):
If true, holiday_effect is a part of time series
decomposition result.
has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
If true, spikes_and_dips is a part of time series
decomposition result.
has_step_changes (google.protobuf.wrappers_pb2.BoolValue):
If true, step_changes is a part of time series decomposition
result.
"""

non_seasonal_order = proto.Field(
Expand All @@ -711,9 +734,19 @@ class ArimaSingleModelForecastingMetrics(proto.Message):
)
has_drift = proto.Field(proto.BOOL, number=3,)
time_series_id = proto.Field(proto.STRING, number=4,)
time_series_ids = proto.RepeatedField(proto.STRING, number=9,)
seasonal_periods = proto.RepeatedField(
proto.ENUM, number=5, enum="Model.SeasonalPeriod.SeasonalPeriodType",
)
has_holiday_effect = proto.Field(
proto.MESSAGE, number=6, message=wrappers_pb2.BoolValue,
)
has_spikes_and_dips = proto.Field(
proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue,
)
has_step_changes = proto.Field(
proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue,
)

non_seasonal_order = proto.RepeatedField(
proto.MESSAGE, number=1, message="Model.ArimaOrder",
Expand Down Expand Up @@ -901,7 +934,7 @@ class TrainingRun(proto.Message):
"""

class TrainingOptions(proto.Message):
r"""
r"""Options used in model training.
Attributes:
max_iterations (int):
The maximum number of iterations in training.
Expand Down Expand Up @@ -972,8 +1005,9 @@ class TrainingOptions(proto.Message):
num_clusters (int):
Number of clusters for clustering models.
model_uri (str):
[Beta] Google Cloud Storage URI from which the model was
imported. Only applicable for imported models.
Google Cloud Storage URI from which the model
was imported. Only applicable for imported
models.
optimization_strategy (google.cloud.bigquery_v2.types.Model.OptimizationStrategy):
Optimization strategy for training linear
regression models.
Expand Down Expand Up @@ -1030,8 +1064,11 @@ class TrainingOptions(proto.Message):
If a valid value is specified, then holiday
effects modeling is enabled.
time_series_id_column (str):
The id column that will be used to indicate
different time series to forecast in parallel.
The time series id column that was used
during ARIMA model training.
time_series_id_columns (Sequence[str]):
The time series id columns that were used
during ARIMA model training.
horizon (int):
The number of periods ahead that need to be
forecasted.
Expand All @@ -1042,6 +1079,15 @@ class TrainingOptions(proto.Message):
output feature name is A.b.
auto_arima_max_order (int):
The max value of non-seasonal p and q.
decompose_time_series (google.protobuf.wrappers_pb2.BoolValue):
If true, perform decompose time series and
save the results.
clean_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
If true, clean spikes and dips in the input
time series.
adjust_step_changes (google.protobuf.wrappers_pb2.BoolValue):
If true, detect step changes and make data
adjustment in the input time series.
"""

max_iterations = proto.Field(proto.INT64, number=1,)
Expand Down Expand Up @@ -1120,9 +1166,19 @@ class TrainingOptions(proto.Message):
proto.ENUM, number=42, enum="Model.HolidayRegion",
)
time_series_id_column = proto.Field(proto.STRING, number=43,)
time_series_id_columns = proto.RepeatedField(proto.STRING, number=51,)
horizon = proto.Field(proto.INT64, number=44,)
preserve_input_structs = proto.Field(proto.BOOL, number=45,)
auto_arima_max_order = proto.Field(proto.INT64, number=46,)
decompose_time_series = proto.Field(
proto.MESSAGE, number=50, message=wrappers_pb2.BoolValue,
)
clean_spikes_and_dips = proto.Field(
proto.MESSAGE, number=52, message=wrappers_pb2.BoolValue,
)
adjust_step_changes = proto.Field(
proto.MESSAGE, number=53, message=wrappers_pb2.BoolValue,
)

class IterationResult(proto.Message):
r"""Information about a single iteration of the training run.
Expand Down Expand Up @@ -1218,10 +1274,29 @@ class ArimaModelInfo(proto.Message):
Whether Arima model fitted with drift or not.
It is always false when d is not 1.
time_series_id (str):
The id to indicate different time series.
The time_series_id value for this time series. It will be
one of the unique values from the time_series_id_column
specified during ARIMA model training. Only present when
time_series_id_column training option was used.
time_series_ids (Sequence[str]):
The tuple of time_series_ids identifying this time series.
It will be one of the unique tuples of values present in the
time_series_id_columns specified during ARIMA model
training. Only present when time_series_id_columns training
option was used and the order of values here are same as the
order of time_series_id_columns.
seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]):
Seasonal periods. Repeated because multiple
periods are supported for one time series.
has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue):
If true, holiday_effect is a part of time series
decomposition result.
has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
If true, spikes_and_dips is a part of time series
decomposition result.
has_step_changes (google.protobuf.wrappers_pb2.BoolValue):
If true, step_changes is a part of time series decomposition
result.
"""

non_seasonal_order = proto.Field(
Expand All @@ -1237,11 +1312,21 @@ class ArimaModelInfo(proto.Message):
)
has_drift = proto.Field(proto.BOOL, number=4,)
time_series_id = proto.Field(proto.STRING, number=5,)
time_series_ids = proto.RepeatedField(proto.STRING, number=10,)
seasonal_periods = proto.RepeatedField(
proto.ENUM,
number=6,
enum="Model.SeasonalPeriod.SeasonalPeriodType",
)
has_holiday_effect = proto.Field(
proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue,
)
has_spikes_and_dips = proto.Field(
proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue,
)
has_step_changes = proto.Field(
proto.MESSAGE, number=9, message=wrappers_pb2.BoolValue,
)

arima_model_info = proto.RepeatedField(
proto.MESSAGE,
Expand Down Expand Up @@ -1319,6 +1404,7 @@ class ArimaModelInfo(proto.Message):
label_columns = proto.RepeatedField(
proto.MESSAGE, number=11, message=standard_sql.StandardSqlField,
)
best_trial_id = proto.Field(proto.INT64, number=19,)


class GetModelRequest(proto.Message):
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/bigquery_v2/types/table_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,23 @@ class TableReference(proto.Message):
maximum length is 1,024 characters. Certain operations allow
suffixing of the table ID with a partition decorator, such
as ``sample_table$20190123``.
project_id_alternative (Sequence[str]):
The alternative field that will be used when ESF is not able
to translate the received data to the project_id field.
dataset_id_alternative (Sequence[str]):
The alternative field that will be used when ESF is not able
to translate the received data to the project_id field.
table_id_alternative (Sequence[str]):
The alternative field that will be used when ESF is not able
to translate the received data to the project_id field.
"""

project_id = proto.Field(proto.STRING, number=1,)
dataset_id = proto.Field(proto.STRING, number=2,)
table_id = proto.Field(proto.STRING, number=3,)
project_id_alternative = proto.RepeatedField(proto.STRING, number=4,)
dataset_id_alternative = proto.RepeatedField(proto.STRING, number=5,)
table_id_alternative = proto.RepeatedField(proto.STRING, number=6,)


__all__ = tuple(sorted(__protobuf__.manifest))

0 comments on commit fe7a902

Please sign in to comment.