From 2a0bbdc152b3de07ec7d1909580193c6a22752ac Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Wed, 21 Apr 2021 09:06:15 -0700 Subject: [PATCH] samples: updates samples to v1 (3 of 8) (#215) * samples: updates the samples to v1 (3 of 10) * samples: more updates to v1 --- ...CreateTrainingPipelineCustomJobSample.java | 14 +- ...ineCustomTrainingManagedDatasetSample.java | 16 +- ...ningPipelineImageClassificationSample.java | 59 +- ...ingPipelineImageObjectDetectionSample.java | 55 +- .../CreateTrainingPipelineSample.java | 55 +- ...ngPipelineTabularClassificationSample.java | 521 +++++++------- ...ainingPipelineTabularRegressionSample.java | 665 +++++++++--------- ...iningPipelineTextClassificationSample.java | 55 +- ...ingPipelineTextEntityExtractionSample.java | 55 +- ...ngPipelineTextSentimentAnalysisSample.java | 57 +- ...gPipelineVideoActionRecognitionSample.java | 16 +- ...ningPipelineVideoClassificationSample.java | 20 +- ...ningPipelineVideoObjectTrackingSample.java | 24 +- ...eateHyperparameterTuningJobSampleTest.java | 4 +- 14 files changed, 716 insertions(+), 900 deletions(-) diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java index 7b40d0e8d05..53e9867a6ff 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,12 @@ package aiplatform; // [START aiplatform_create_training_pipeline_custom_job_sample] -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.protobuf.Value; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java index ea624de5b9d..8fad236877c 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java @@ -17,14 +17,14 @@ package aiplatform; // [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] -import com.google.cloud.aiplatform.v1beta1.GcsDestination; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.protobuf.Value; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java index 44570725367..4f9c1e2c57a 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java @@ -18,28 +18,24 @@ // [START aiplatform_create_training_pipeline_image_classification_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType; import com.google.rpc.Status; import java.io.IOException; @@ -204,25 +200,6 @@ static void createTrainingPipelineImageClassificationSample( System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); } - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("Explanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("Parameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("Sampled Shapley Attribution"); - System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("Metadata"); - System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap()); - System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "Feature Attributions Schema_uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - Status status = trainingPipelineResponse.getError(); System.out.println("Error"); System.out.format("Code: %s\n", status.getCode()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java index 78181e448ba..65ade6ea4ad 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java @@ -19,26 +19,22 @@ // [START aiplatform_create_training_pipeline_image_object_detection_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs.ModelType; import com.google.rpc.Status; @@ -204,25 +200,6 @@ static void createTrainingPipelineImageObjectDetectionSample( System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); } - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("Explanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("Parameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("Sampled Shapley Attribution"); - System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("Metadata"); - System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap()); - System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "Feature Attributions Schema_uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - Status status = trainingPipelineResponse.getError(); System.out.println("Error"); System.out.format("Code: %s\n", status.getCode()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java index 2dcb6e88cd2..33f94753e54 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineSample.java @@ -18,26 +18,22 @@ // [START aiplatform_create_training_pipeline_sample] -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; @@ -204,25 +200,6 @@ static void createTrainingPipelineSample( System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); } - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("Explanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("Parameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("Sampled Shapley Attribution"); - System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("Metadata"); - System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap()); - System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "Feature Attributions Schema_uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - Status status = trainingPipelineResponse.getError(); System.out.println("Error"); System.out.format("Code: %s\n", status.getCode()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java index 0ee0392dbea..107e8c01a4c 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java @@ -1,272 +1,249 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package aiplatform; - -// [START aiplatform_create_training_pipeline_tabular_classification_sample] - -import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; -import com.google.rpc.Status; -import java.io.IOException; -import java.util.ArrayList; - -public class CreateTrainingPipelineTabularClassificationSample { - - public static void main(String[] args) throws IOException { - // TODO(developer): Replace these variables before running the sample. - String project = "YOUR_PROJECT_ID"; - String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; - String datasetId = "YOUR_DATASET_ID"; - String targetColumn = "TARGET_COLUMN"; - createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn); - } - - static void createTrainingPipelineTableClassification( - String project, String modelDisplayName, String datasetId, String targetColumn) - throws IOException { - PipelineServiceSettings pipelineServiceSettings = - PipelineServiceSettings.newBuilder() - .setEndpoint("us-central1-aiplatform.googleapis.com:443") - .build(); - - // Initialize client that will be used to send requests. This client only needs to be created - // once, and can be reused for multiple requests. After completing all of your requests, call - // the "close" method on the client to safely clean up any remaining background resources. - try (PipelineServiceClient pipelineServiceClient = - PipelineServiceClient.create(pipelineServiceSettings)) { - String location = "us-central1"; - LocationName locationName = LocationName.of(project, location); - String trainingTaskDefinition = - "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; - - // Set the columns used for training and their data types - Transformation transformation1 = - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build()) - .build(); - Transformation transformation2 = - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build()) - .build(); - Transformation transformation3 = - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build()) - .build(); - Transformation transformation4 = - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build()) - .build(); - - ArrayList transformationArrayList = new ArrayList<>(); - transformationArrayList.add(transformation1); - transformationArrayList.add(transformation2); - transformationArrayList.add(transformation3); - transformationArrayList.add(transformation4); - - AutoMlTablesInputs autoMlTablesInputs = - AutoMlTablesInputs.newBuilder() - .setTargetColumn(targetColumn) - .setPredictionType("classification") - .addAllTransformations(transformationArrayList) - .setTrainBudgetMilliNodeHours(8000) - .build(); - - FractionSplit fractionSplit = - FractionSplit.newBuilder() - .setTrainingFraction(0.8) - .setValidationFraction(0.1) - .setTestFraction(0.1) - .build(); - - InputDataConfig inputDataConfig = - InputDataConfig.newBuilder() - .setDatasetId(datasetId) - .setFractionSplit(fractionSplit) - .build(); - Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); - - TrainingPipeline trainingPipeline = - TrainingPipeline.newBuilder() - .setDisplayName(modelDisplayName) - .setTrainingTaskDefinition(trainingTaskDefinition) - .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs)) - .setInputDataConfig(inputDataConfig) - .setModelToUpload(modelToUpload) - .build(); - - TrainingPipeline trainingPipelineResponse = - pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); - - System.out.println("Create Training Pipeline Tabular Classification Response"); - System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); - System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); - System.out.format( - "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); - System.out.format( - "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); - System.out.format( - "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); - - System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); - System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); - System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); - System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); - System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); - System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); - - InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); - System.out.println("\tInput Data Config"); - System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); - System.out.format( - "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); - - FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); - System.out.println("\t\tFraction Split"); - System.out.format( - "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); - System.out.format( - "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); - System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); - - FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); - System.out.println("\t\tFilter Split"); - System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); - System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); - System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); - - PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); - System.out.println("\t\tPredefined Split"); - System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); - - TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); - System.out.println("\t\tTimestamp Split"); - System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); - System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); - System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); - System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); - - Model modelResponse = trainingPipelineResponse.getModelToUpload(); - System.out.println("\tModel To Upload"); - System.out.format("\t\tName: %s\n", modelResponse.getName()); - System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); - System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); - System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); - System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); - System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); - System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); - - System.out.format( - "\t\tSupported Deployment Resources Types: %s\n", - modelResponse.getSupportedDeploymentResourcesTypesList().toString()); - System.out.format( - "\t\tSupported Input Storage Formats: %s\n", - modelResponse.getSupportedInputStorageFormatsList().toString()); - System.out.format( - "\t\tSupported Output Storage Formats: %s\n", - modelResponse.getSupportedOutputStorageFormatsList().toString()); - - System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); - System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); - System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); - PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); - - System.out.println("\tPredict Schemata"); - System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); - System.out.format( - "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); - System.out.format( - "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); - - for (Model.ExportFormat supportedExportFormat : - modelResponse.getSupportedExportFormatsList()) { - System.out.println("\tSupported Export Format"); - System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); - } - ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); - - System.out.println("\tContainer Spec"); - System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); - System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); - System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); - System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); - System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); - - for (EnvVar envVar : containerSpec.getEnvList()) { - System.out.println("\t\tEnv"); - System.out.format("\t\t\tName: %s\n", envVar.getName()); - System.out.format("\t\t\tValue: %s\n", envVar.getValue()); - } - - for (Port port : containerSpec.getPortsList()) { - System.out.println("\t\tPort"); - System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); - } - - for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { - System.out.println("\tDeployed Model"); - System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); - System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); - } - - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("\tExplanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("\t\tParameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("\t\tSampled Shapley Attribution"); - System.out.format("\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("\t\tMetadata"); - System.out.format("\t\t\tInput: %s\n", explanationMetadata.getInputsMap()); - System.out.format("\t\t\tOutput: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "\t\t\tFeature Attributions Schema Uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - - Status status = trainingPipelineResponse.getError(); - System.out.println("\tError"); - System.out.format("\t\tCode: %s\n", status.getCode()); - System.out.format("\t\tMessage: %s\n", status.getMessage()); - } - } -} -// [END aiplatform_create_training_pipeline_tabular_classification_sample] +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_tabular_classification_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.ArrayList; + +public class CreateTrainingPipelineTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String targetColumn = "TARGET_COLUMN"; + createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn); + } + + static void createTrainingPipelineTableClassification( + String project, String modelDisplayName, String datasetId, String targetColumn) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; + + // Set the columns used for training and their data types + Transformation transformation1 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build()) + .build(); + Transformation transformation2 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build()) + .build(); + Transformation transformation3 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build()) + .build(); + Transformation transformation4 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build()) + .build(); + + ArrayList transformationArrayList = new ArrayList<>(); + transformationArrayList.add(transformation1); + transformationArrayList.add(transformation2); + transformationArrayList.add(transformation3); + transformationArrayList.add(transformation4); + + AutoMlTablesInputs autoMlTablesInputs = + AutoMlTablesInputs.newBuilder() + .setTargetColumn(targetColumn) + .setPredictionType("classification") + .addAllTransformations(transformationArrayList) + .setTrainBudgetMilliNodeHours(8000) + .build(); + + FractionSplit fractionSplit = + FractionSplit.newBuilder() + .setTrainingFraction(0.8) + .setValidationFraction(0.1) + .setTestFraction(0.1) + .build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setFractionSplit(fractionSplit) + .build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(modelDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs)) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Tabular Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format( + "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); + System.out.format( + "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_tabular_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java index f9f6ade398d..427dae0c0cd 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java @@ -1,344 +1,321 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package aiplatform; - -// [START aiplatform_create_training_pipeline_tabular_regression_sample] - -import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation; -import com.google.rpc.Status; -import java.io.IOException; -import java.util.ArrayList; - -public class CreateTrainingPipelineTabularRegressionSample { - - public static void main(String[] args) throws IOException { - // TODO(developer): Replace these variables before running the sample. - String project = "YOUR_PROJECT_ID"; - String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; - String datasetId = "YOUR_DATASET_ID"; - String targetColumn = "TARGET_COLUMN"; - createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn); - } - - static void createTrainingPipelineTableRegression( - String project, String modelDisplayName, String datasetId, String targetColumn) - throws IOException { - PipelineServiceSettings pipelineServiceSettings = - PipelineServiceSettings.newBuilder() - .setEndpoint("us-central1-aiplatform.googleapis.com:443") - .build(); - - // Initialize client that will be used to send requests. This client only needs to be created - // once, and can be reused for multiple requests. After completing all of your requests, call - // the "close" method on the client to safely clean up any remaining background resources. - try (PipelineServiceClient pipelineServiceClient = - PipelineServiceClient.create(pipelineServiceSettings)) { - String location = "us-central1"; - LocationName locationName = LocationName.of(project, location); - String trainingTaskDefinition = - "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; - - // Set the columns used for training and their data types - ArrayList tranformations = new ArrayList<>(); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setTimestamp( - TimestampTransformation.newBuilder() - .setColumnName("TIMESTAMP_1unique_NULLABLE") - .setInvalidValuesAllowed(true)) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setTimestamp( - TimestampTransformation.newBuilder() - .setColumnName("DATETIME_1unique_NULLABLE") - .setInvalidValuesAllowed(true)) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE")) - .build()); - tranformations.add( - Transformation.newBuilder() - .setAuto( - AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE")) - .build()); - - AutoMlTablesInputs trainingTaskInputs = - AutoMlTablesInputs.newBuilder() - .addAllTransformations(tranformations) - .setTargetColumn(targetColumn) - .setPredictionType("regression") - .setTrainBudgetMilliNodeHours(8000) - .setDisableEarlyStopping(false) - // supported regression optimisation objectives: minimize-rmse, - // minimize-mae, minimize-rmsle - .setOptimizationObjective("minimize-rmse") - .build(); - - FractionSplit fractionSplit = - FractionSplit.newBuilder() - .setTrainingFraction(0.8) - .setValidationFraction(0.1) - .setTestFraction(0.1) - .build(); - - InputDataConfig inputDataConfig = - InputDataConfig.newBuilder() - .setDatasetId(datasetId) - .setFractionSplit(fractionSplit) - .build(); - Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); - - TrainingPipeline trainingPipeline = - TrainingPipeline.newBuilder() - .setDisplayName(modelDisplayName) - .setTrainingTaskDefinition(trainingTaskDefinition) - .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) - .setInputDataConfig(inputDataConfig) - .setModelToUpload(modelToUpload) - .build(); - - TrainingPipeline trainingPipelineResponse = - pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); - - System.out.println("Create Training Pipeline Tabular Regression Response"); - System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); - System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); - System.out.format( - "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); - System.out.format( - "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); - System.out.format( - "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); - - System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); - System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); - System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); - System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); - System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); - System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); - - InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); - System.out.println("\tInput Data Config"); - System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); - System.out.format( - "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); - - FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); - System.out.println("\t\tFraction Split"); - System.out.format( - "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); - System.out.format( - "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); - System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); - - FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); - System.out.println("\t\tFilter Split"); - System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); - System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); - System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); - - PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); - System.out.println("\t\tPredefined Split"); - System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); - - TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); - System.out.println("\t\tTimestamp Split"); - System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); - System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); - System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); - System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); - - Model modelResponse = trainingPipelineResponse.getModelToUpload(); - System.out.println("\tModel To Upload"); - System.out.format("\t\tName: %s\n", modelResponse.getName()); - System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); - System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); - System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); - System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); - System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); - System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); - - System.out.format( - "\t\tSupported Deployment Resources Types: %s\n", - modelResponse.getSupportedDeploymentResourcesTypesList().toString()); - System.out.format( - "\t\tSupported Input Storage Formats: %s\n", - modelResponse.getSupportedInputStorageFormatsList().toString()); - System.out.format( - "\t\tSupported Output Storage Formats: %s\n", - modelResponse.getSupportedOutputStorageFormatsList().toString()); - - System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); - System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); - System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); - PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); - - System.out.println("\tPredict Schemata"); - System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); - System.out.format( - "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); - System.out.format( - "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); - - for (Model.ExportFormat supportedExportFormat : - modelResponse.getSupportedExportFormatsList()) { - System.out.println("\tSupported Export Format"); - System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); - } - ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); - - System.out.println("\tContainer Spec"); - System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); - System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); - System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); - System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); - System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); - - for (EnvVar envVar : containerSpec.getEnvList()) { - System.out.println("\t\tEnv"); - System.out.format("\t\t\tName: %s\n", envVar.getName()); - System.out.format("\t\t\tValue: %s\n", envVar.getValue()); - } - - for (Port port : containerSpec.getPortsList()) { - System.out.println("\t\tPort"); - System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); - } - - for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { - System.out.println("\tDeployed Model"); - System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); - System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); - } - - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("\tExplanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("\t\tParameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("\t\tSampled Shapley Attribution"); - System.out.format("\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("\t\tMetadata"); - System.out.format("\t\t\tInput: %s\n", explanationMetadata.getInputsMap()); - System.out.format("\t\t\tOutput: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "\t\t\tFeature Attributions Schema Uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - - Status status = trainingPipelineResponse.getError(); - System.out.println("\tError"); - System.out.format("\t\tCode: %s\n", status.getCode()); - System.out.format("\t\tMessage: %s\n", status.getMessage()); - } - } -} -// [END aiplatform_create_training_pipeline_tabular_regression_sample] +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_tabular_regression_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.ArrayList; + +public class CreateTrainingPipelineTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String targetColumn = "TARGET_COLUMN"; + createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn); + } + + static void createTrainingPipelineTableRegression( + String project, String modelDisplayName, String datasetId, String targetColumn) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; + + // Set the columns used for training and their data types + ArrayList tranformations = new ArrayList<>(); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setTimestamp( + TimestampTransformation.newBuilder() + .setColumnName("TIMESTAMP_1unique_NULLABLE") + .setInvalidValuesAllowed(true)) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setTimestamp( + TimestampTransformation.newBuilder() + .setColumnName("DATETIME_1unique_NULLABLE") + .setInvalidValuesAllowed(true)) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE")) + .build()); + + AutoMlTablesInputs trainingTaskInputs = + AutoMlTablesInputs.newBuilder() + .addAllTransformations(tranformations) + .setTargetColumn(targetColumn) + .setPredictionType("regression") + .setTrainBudgetMilliNodeHours(8000) + .setDisableEarlyStopping(false) + // supported regression optimisation objectives: minimize-rmse, + // minimize-mae, minimize-rmsle + .setOptimizationObjective("minimize-rmse") + .build(); + + FractionSplit fractionSplit = + FractionSplit.newBuilder() + .setTrainingFraction(0.8) + .setValidationFraction(0.1) + .setTestFraction(0.1) + .build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setFractionSplit(fractionSplit) + .build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(modelDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Tabular Regression Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format( + "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); + System.out.format( + "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_tabular_regression_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java index dadd642c26a..ac338beb37c 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java @@ -19,26 +19,22 @@ // [START aiplatform_create_training_pipeline_text_classification_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTextClassificationInputs; import com.google.rpc.Status; import java.io.IOException; @@ -203,25 +199,6 @@ static void createTrainingPipelineTextClassificationSample( System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); } - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("\t\tExplanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("\t\t\tParameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("\t\t\t\tSampled Shapley Attribution"); - System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("\t\t\tMetadata"); - System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap()); - System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "\t\t\t\tFeature Attributions Schema_uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - Status status = trainingPipelineResponse.getError(); System.out.println("\tError"); System.out.format("\t\tCode: %s\n", status.getCode()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java index c62606c9886..63dc1348461 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java @@ -19,26 +19,22 @@ // [START aiplatform_create_training_pipeline_text_entity_extraction_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.rpc.Status; import java.io.IOException; @@ -199,25 +195,6 @@ static void createTrainingPipelineTextEntityExtractionSample( System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); } - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("\t\tExplanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("\t\t\tParameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("\t\t\t\tSampled Shapley Attribution"); - System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("\t\t\tMetadata"); - System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap()); - System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "\t\t\t\tFeature Attributions Schema_uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - Status status = trainingPipelineResponse.getError(); System.out.println("\tError"); System.out.format("\t\tCode: %s\n", status.getCode()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java index 0a5903ae219..ef87a9bfd2a 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java @@ -19,27 +19,23 @@ // [START aiplatform_create_training_pipeline_text_sentiment_analysis_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; -import com.google.cloud.aiplatform.v1beta1.EnvVar; -import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; -import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; -import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; -import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.Port; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.PredictSchemata; -import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTextSentimentInputs; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTextSentimentInputs; import com.google.rpc.Status; import java.io.IOException; @@ -207,25 +203,6 @@ static void createTrainingPipelineTextSentimentAnalysisSample( System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); } - ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); - System.out.println("\t\tExplanation Spec"); - - ExplanationParameters explanationParameters = explanationSpec.getParameters(); - System.out.println("\t\t\tParameters"); - - SampledShapleyAttribution sampledShapleyAttribution = - explanationParameters.getSampledShapleyAttribution(); - System.out.println("\t\t\t\tSampled Shapley Attribution"); - System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); - - ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); - System.out.println("\t\t\tMetadata"); - System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap()); - System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap()); - System.out.format( - "\t\t\t\tFeature Attributions Schema_uri: %s\n", - explanationMetadata.getFeatureAttributionsSchemaUri()); - Status status = trainingPipelineResponse.getError(); System.out.println("\tError"); System.out.format("\t\tCode: %s\n", status.getCode()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java index 9b3d83e7738..02e15fb5dac 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java @@ -18,14 +18,14 @@ // [START aiplatform_create_training_pipeline_video_action_recognition_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs.ModelType; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs.ModelType; import java.io.IOException; public class CreateTrainingPipelineVideoActionRecognitionSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java index 7bb27c5ae59..403476b24b9 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java @@ -19,16 +19,16 @@ // [START aiplatform_create_training_pipeline_video_classification_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; import com.google.rpc.Status; import java.io.IOException; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java index 03cf2a522c4..3bd30b4b9d5 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java @@ -19,18 +19,18 @@ // [START aiplatform_create_training_pipeline_video_object_tracking_sample] import com.google.cloud.aiplatform.util.ValueConverter; -import com.google.cloud.aiplatform.v1beta1.FilterSplit; -import com.google.cloud.aiplatform.v1beta1.FractionSplit; -import com.google.cloud.aiplatform.v1beta1.InputDataConfig; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.Model; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; -import com.google.cloud.aiplatform.v1beta1.TimestampSplit; -import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs.ModelType; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs.ModelType; import com.google.rpc.Status; import java.io.IOException; diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java index c5da6fc48df..48343412a6f 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java @@ -35,8 +35,8 @@ public class CreateHyperparameterTuningJobSampleTest { private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); - private static final String CONTAINER_IMAGE_URI = "gcr.io/ucaip-sample-tests/ucaip-training-test:" - + "latest"; + private static final String CONTAINER_IMAGE_URI = + "gcr.io/ucaip-sample-tests/ucaip-training-test:latest"; private ByteArrayOutputStream bout; private PrintStream out; private PrintStream originalPrintStream;