diff --git a/google/cloud/bigquery/model.py b/google/cloud/bigquery/model.py index 4d2bc346c..45a88ab22 100644 --- a/google/cloud/bigquery/model.py +++ b/google/cloud/bigquery/model.py @@ -16,6 +16,8 @@ """Define resources for the BigQuery ML Models API.""" +from __future__ import annotations # type: ignore + import copy import datetime import typing @@ -184,6 +186,21 @@ def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]: standard_sql.StandardSqlField.from_api_repr(column) for column in resource ] + @property + def transform_columns(self) -> Sequence[TransformColumn]: + """The output transform columns used to train this model. + Each item wraps one entry of ``transformColumns``; empty if the key is absent. + + See REST API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn + + Read-only. + """ + resources: Sequence[Dict[str, Any]] = typing.cast( + Sequence[Dict[str, Any]], self._properties.get("transformColumns", []) + ) + return [TransformColumn(resource) for resource in resources] + @property def label_columns(self) -> Sequence[standard_sql.StandardSqlField]: """Label columns that were used to train this model. @@ -434,6 +451,60 @@ def __repr__(self): ) +class TransformColumn: + """TransformColumn represents a transform column feature. + + See + https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn + + Args: + resource: + A dictionary representing a transform column feature. + """ + + def __init__(self, resource: Dict[str, Any]): + self._properties = resource + + @property + def name(self) -> Optional[str]: + """Name of the column.""" + return self._properties.get("name") + + @property + def type_(self) -> Optional[standard_sql.StandardSqlDataType]: + """Data type of the column after the transform. + + Returns: + Optional[google.cloud.bigquery.standard_sql.StandardSqlDataType]: + Data type of the column. 
+ """ + type_json = self._properties.get("type") + if type_json is None: + return None + return standard_sql.StandardSqlDataType.from_api_repr(type_json) + + @property + def transform_sql(self) -> Optional[str]: + """The SQL expression used in the column transform.""" + return self._properties.get("transformSql") + + @classmethod + def from_api_repr(cls, resource: Dict[str, Any]) -> "TransformColumn": + """Constructs a transform column feature given its API representation. + + Args: + resource: + Transform column feature representation from the API (deep-copied). + + Returns: + Transform column feature parsed from ``resource``. + """ + this = cls({}) + resource = copy.deepcopy(resource) + this._properties = resource + return this + + def _model_arg_to_model_ref(value, default_project=None): """Helper to convert a string or Model to ModelReference. diff --git a/mypy.ini b/mypy.ini index 4505b4854..beaa679a8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.6 +python_version = 3.8 namespace_packages = True diff --git a/tests/unit/model/test_model.py b/tests/unit/model/test_model.py index 1ae988414..279a954c7 100644 --- a/tests/unit/model/test_model.py +++ b/tests/unit/model/test_model.py @@ -18,7 +18,9 @@ import pytest + import google.cloud._helpers +import google.cloud.bigquery.model KMS_KEY_NAME = "projects/1/locations/us/keyRings/1/cryptoKeys/1" @@ -136,6 +138,7 @@ def test_from_api_repr(target_class): google.cloud._helpers._rfc3339_to_datetime(got.training_runs[2]["startTime"]) == expiration_time ) + assert got.transform_columns == [] def test_from_api_repr_w_minimal_resource(target_class): @@ -293,6 +296,71 @@ def test_feature_columns(object_under_test): assert object_under_test.feature_columns == expected +def test_from_api_repr_w_transform_columns(target_class): + resource = { + "modelReference": { + "projectId": "my-project", + "datasetId": "my_dataset", + "modelId": "my_model", + }, + "transformColumns": [ + { + "name": "transform_name", 
+ "type": {"typeKind": "INT64"}, + "transformSql": "transform_sql", + } + ], + } + got = target_class.from_api_repr(resource) + assert len(got.transform_columns) == 1 + transform_column = got.transform_columns[0] + assert isinstance(transform_column, google.cloud.bigquery.model.TransformColumn) + assert transform_column.name == "transform_name" + + +def test_transform_column_name(): + transform_columns = google.cloud.bigquery.model.TransformColumn( + {"name": "is_female"} + ) + assert transform_columns.name == "is_female" + + +def test_transform_column_transform_sql(): + transform_columns = google.cloud.bigquery.model.TransformColumn( + {"transformSql": "is_female"} + ) + assert transform_columns.transform_sql == "is_female" + + +def test_transform_column_type(): + transform_columns = google.cloud.bigquery.model.TransformColumn( + {"type": {"typeKind": "BOOL"}} + ) + assert transform_columns.type_.type_kind == "BOOL" + + +def test_transform_column_type_none(): + transform_columns = google.cloud.bigquery.model.TransformColumn({}) + assert transform_columns.type_ is None + + +def test_transform_column_from_api_repr_with_unknown_properties(): + transform_column = google.cloud.bigquery.model.TransformColumn.from_api_repr( + { + "name": "is_female", + "type": {"typeKind": "BOOL"}, + "transformSql": "is_female", + "test": "one", + } + ) + assert transform_column._properties == { + "name": "is_female", + "type": {"typeKind": "BOOL"}, + "transformSql": "is_female", + "test": "one", + } + + def test_label_columns(object_under_test): from google.cloud.bigquery import standard_sql