diff --git a/packages/google-cloud-automl/google/cloud/automl_v1beta1/tables/tables_client.py b/packages/google-cloud-automl/google/cloud/automl_v1beta1/tables/tables_client.py index 48d9fbef9a37..ab4c3d4821b8 100644 --- a/packages/google-cloud-automl/google/cloud/automl_v1beta1/tables/tables_client.py +++ b/packages/google-cloud-automl/google/cloud/automl_v1beta1/tables/tables_client.py @@ -2596,7 +2596,7 @@ def predict( model=None, model_name=None, model_display_name=None, - params=None, + feature_importance=False, project=None, region=None, **kwargs @@ -2643,9 +2643,9 @@ def predict( The `model` instance you want to predict with . This must be supplied if `model_display_name` or `model_name` are not supplied. - params (dict[str, str]): - `feature_importance` can be set as True to enable local - explainability. The default is false. + feature_importance (bool): + True if enable feature importance explainability. The default is + False. Returns: A :class:`~google.cloud.automl_v1beta1.types.PredictResponse` @@ -2687,6 +2687,10 @@ def predict( request = {"row": {"values": values}} + params = None + if feature_importance: + params = {"feature_importance": "true"} + return self.prediction_client.predict(model.name, request, params, **kwargs) def batch_predict( diff --git a/packages/google-cloud-automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/packages/google-cloud-automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index f164a8875787..3f2b6d3de2bd 100644 --- a/packages/google-cloud-automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/packages/google-cloud-automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -1137,6 +1137,25 @@ def test_predict_from_dict(self): None, ) + def test_predict_from_dict_with_feature_importance(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec_a = mock.Mock(display_name="a", data_type=data_type) + column_spec_b = mock.Mock(display_name="b", data_type=data_type) + model_metadata = mock.Mock( + input_feature_column_specs=[column_spec_a, column_spec_b] + ) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict( + {"a": "1", "b": "2"}, model_name="my_model", feature_importance=True + ) + client.prediction_client.predict.assert_called_with( + "my_model", + {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}}, + {"feature_importance": "true"}, + ) + def test_predict_from_dict_missing(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) column_spec_a = mock.Mock(display_name="a", data_type=data_type)