diff --git a/aiplatform/pom.xml b/aiplatform/pom.xml new file mode 100644 index 00000000000..0ba47527ea9 --- /dev/null +++ b/aiplatform/pom.xml @@ -0,0 +1,71 @@ + + + 4.0.0 + com.example.aiplatform + aiplatform-snippets + jar + Google Cloud Vertex AI Snippets + https://github.com/GoogleCloudPlatform/java-docs-samples/tree/main/aiplatform + + + + com.google.cloud.samples + shared-configuration + 1.2.0 + + + + 1.8 + 1.8 + UTF-8 + + + + + com.google.cloud + google-cloud-aiplatform + 3.4.1 + + + + com.google.cloud + google-cloud-storage + 2.13.0 + + + com.google.protobuf + protobuf-java-util + 4.0.0-rc-2 + + + com.google.code.gson + gson + 2.9.1 + + + junit + junit + 4.13.2 + test + + + com.google.truth + truth + 1.1.3 + test + + + com.google.api.grpc + proto-google-cloud-aiplatform-v1beta1 + 0.20.1 + + + com.google.cloud + google-cloud-bigquery + 2.18.0 + + + diff --git a/aiplatform/resources/daisy.jpg b/aiplatform/resources/daisy.jpg new file mode 100644 index 00000000000..ae01cae9183 Binary files /dev/null and b/aiplatform/resources/daisy.jpg differ diff --git a/aiplatform/resources/image_flower_daisy.jpg b/aiplatform/resources/image_flower_daisy.jpg new file mode 100644 index 00000000000..3ba1d67705a Binary files /dev/null and b/aiplatform/resources/image_flower_daisy.jpg differ diff --git a/aiplatform/resources/iod_caprese_salad.jpg b/aiplatform/resources/iod_caprese_salad.jpg new file mode 100644 index 00000000000..100ad677a91 Binary files /dev/null and b/aiplatform/resources/iod_caprese_salad.jpg differ diff --git a/aiplatform/src/main/java/aiplatform/BatchCreateFeaturesSample.java b/aiplatform/src/main/java/aiplatform/BatchCreateFeaturesSample.java new file mode 100644 index 00000000000..8b948092798 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/BatchCreateFeaturesSample.java @@ -0,0 +1,128 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Create features in bulk for an existing entity type. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup + * before running the code snippet + */ + +package aiplatform; + +// [START aiplatform_batch_create_features_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.BatchCreateFeaturesOperationMetadata; +import com.google.cloud.aiplatform.v1.BatchCreateFeaturesRequest; +import com.google.cloud.aiplatform.v1.BatchCreateFeaturesResponse; +import com.google.cloud.aiplatform.v1.CreateFeatureRequest; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.Feature.ValueType; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class BatchCreateFeaturesSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + batchCreateFeaturesSample(project, featurestoreId, entityTypeId, location, endpoint, timeout); + } + + static void batchCreateFeaturesSample( + String project, + String featurestoreId, + String entityTypeId, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + List createFeatureRequests = new ArrayList<>(); + + Feature titleFeature = + Feature.newBuilder() + .setDescription("The title of the movie") + .setValueType(ValueType.STRING) + .build(); + Feature genresFeature = + Feature.newBuilder() + .setDescription("The genres of the movie") + .setValueType(ValueType.STRING) + .build(); + Feature averageRatingFeature = + Feature.newBuilder() + .setDescription("The average rating for the movie, range is [1.0-5.0]") + .setValueType(ValueType.DOUBLE) + .build(); + + createFeatureRequests.add( + CreateFeatureRequest.newBuilder().setFeature(titleFeature).setFeatureId("title").build()); + + createFeatureRequests.add( + CreateFeatureRequest.newBuilder() + .setFeature(genresFeature) + .setFeatureId("genres") + .build()); + + createFeatureRequests.add( + CreateFeatureRequest.newBuilder() + .setFeature(averageRatingFeature) + .setFeatureId("average_rating") + .build()); + + BatchCreateFeaturesRequest batchCreateFeaturesRequest = + BatchCreateFeaturesRequest.newBuilder() + .setParent( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .addAllRequests(createFeatureRequests) + .build(); + + OperationFuture + batchCreateFeaturesFuture = + featurestoreServiceClient.batchCreateFeaturesAsync(batchCreateFeaturesRequest); + System.out.format( + "Operation name: %s%n", batchCreateFeaturesFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + BatchCreateFeaturesResponse batchCreateFeaturesResponse = + batchCreateFeaturesFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Batch Create Features Response"); + System.out.println(batchCreateFeaturesResponse); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_batch_create_features_sample] diff --git a/aiplatform/src/main/java/aiplatform/BatchReadFeatureValuesSample.java b/aiplatform/src/main/java/aiplatform/BatchReadFeatureValuesSample.java new file mode 100644 index 00000000000..a76c3388d1e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/BatchReadFeatureValuesSample.java @@ -0,0 +1,135 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Batch read feature values from a featurestore, as determined by your + * read instances list file, to export data. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_batch_read_feature_values_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.BatchReadFeatureValuesOperationMetadata; +import com.google.cloud.aiplatform.v1.BatchReadFeatureValuesRequest; +import com.google.cloud.aiplatform.v1.BatchReadFeatureValuesRequest.EntityTypeSpec; +import com.google.cloud.aiplatform.v1.BatchReadFeatureValuesResponse; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.CsvSource; +import com.google.cloud.aiplatform.v1.FeatureSelector; +import com.google.cloud.aiplatform.v1.FeatureValueDestination; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.IdMatcher; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class BatchReadFeatureValuesSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String inputCsvFile = "YOU_INPUT_CSV_FILE"; + String destinationTableUri = "YOUR_DESTINATION_TABLE_URI"; + List featureSelectorIds = Arrays.asList("title", "genres", "average_rating"); + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + batchReadFeatureValuesSample( + project, + featurestoreId, + entityTypeId, + inputCsvFile, + destinationTableUri, + featureSelectorIds, + location, + endpoint, + timeout); + } + + static void batchReadFeatureValuesSample( + String project, + String featurestoreId, + String entityTypeId, + String inputCsvFile, + String destinationTableUri, + List featureSelectorIds, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + List entityTypeSpecs = new ArrayList<>(); + + FeatureSelector featureSelector = + FeatureSelector.newBuilder() + .setIdMatcher(IdMatcher.newBuilder().addAllIds(featureSelectorIds).build()) + .build(); + EntityTypeSpec entityTypeSpec = + EntityTypeSpec.newBuilder() + .setEntityTypeId(entityTypeId) + .setFeatureSelector(featureSelector) + .build(); + + entityTypeSpecs.add(entityTypeSpec); + + BigQueryDestination bigQueryDestination = + BigQueryDestination.newBuilder().setOutputUri(destinationTableUri).build(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(inputCsvFile).build(); + BatchReadFeatureValuesRequest batchReadFeatureValuesRequest = + BatchReadFeatureValuesRequest.newBuilder() + .setFeaturestore(FeaturestoreName.of(project, location, featurestoreId).toString()) + .setCsvReadInstances(CsvSource.newBuilder().setGcsSource(gcsSource)) + .setDestination( + FeatureValueDestination.newBuilder().setBigqueryDestination(bigQueryDestination)) + .addAllEntityTypeSpecs(entityTypeSpecs) + .build(); + + OperationFuture + batchReadFeatureValuesFuture = + featurestoreServiceClient.batchReadFeatureValuesAsync(batchReadFeatureValuesRequest); + System.out.format( + "Operation name: %s%n", batchReadFeatureValuesFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + BatchReadFeatureValuesResponse batchReadFeatureValuesResponse = + batchReadFeatureValuesFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Batch Read Feature Values Response"); + System.out.println(batchReadFeatureValuesResponse); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_batch_read_feature_values_sample] diff --git a/aiplatform/src/main/java/aiplatform/CancelBatchPredictionJobSample.java b/aiplatform/src/main/java/aiplatform/CancelBatchPredictionJobSample.java new file mode 100644 index 00000000000..495f0f88598 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CancelBatchPredictionJobSample.java @@ -0,0 +1,56 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_cancel_batch_prediction_job_sample] + +import com.google.cloud.aiplatform.v1.BatchPredictionJobName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import java.io.IOException; + +public class CancelBatchPredictionJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID"; + cancelBatchPredictionJobSample(project, batchPredictionJobId); + } + + static void cancelBatchPredictionJobSample(String project, String batchPredictionJobId) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + BatchPredictionJobName batchPredictionJobName = + BatchPredictionJobName.of(project, location, batchPredictionJobId); + + jobServiceClient.cancelBatchPredictionJob(batchPredictionJobName); + + System.out.println("Cancelled the Batch Prediction Job"); + } + } +} +// [END aiplatform_cancel_batch_prediction_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/CancelDataLabelingJobSample.java b/aiplatform/src/main/java/aiplatform/CancelDataLabelingJobSample.java new file mode 100644 index 00000000000..eb540687edf --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CancelDataLabelingJobSample.java @@ -0,0 +1,53 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_cancel_data_labeling_job_sample] + +import com.google.cloud.aiplatform.v1.DataLabelingJobName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import java.io.IOException; + +public class CancelDataLabelingJobSample { + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String dataLabelingJobId = "YOUR_DATA_LABELING_JOB_ID"; + cancelDataLabelingJob(project, dataLabelingJobId); + } + + static void cancelDataLabelingJob(String project, String dataLabelingJobId) throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + + DataLabelingJobName dataLabelingJobName = + DataLabelingJobName.of(project, location, dataLabelingJobId); + jobServiceClient.cancelDataLabelingJob(dataLabelingJobName); + System.out.println("Cancelled Data labeling job"); + } + } +} +// [END aiplatform_cancel_data_labeling_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/CancelTrainingPipelineSample.java b/aiplatform/src/main/java/aiplatform/CancelTrainingPipelineSample.java new file mode 100644 index 00000000000..a689ae24625 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CancelTrainingPipelineSample.java @@ -0,0 +1,57 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_cancel_training_pipeline_sample] + +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipelineName; +import java.io.IOException; + +public class CancelTrainingPipelineSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID"; + String project = "YOUR_PROJECT_ID"; + cancelTrainingPipelineSample(project, trainingPipelineId); + } + + static void cancelTrainingPipelineSample(String project, String trainingPipelineId) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + TrainingPipelineName trainingPipelineName = + TrainingPipelineName.of(project, location, trainingPipelineId); + + pipelineServiceClient.cancelTrainingPipeline(trainingPipelineName); + + System.out.println("Cancelled the Training Pipeline"); + } + } +} +// [END aiplatform_cancel_training_pipeline_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java new file mode 100644 index 00000000000..105268f2e8b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java @@ -0,0 +1,107 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_bigquery_sample] +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateBatchPredictionJobBigquerySample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelName = "MODEL_NAME"; + String instancesFormat = "INSTANCES_FORMAT"; + String bigquerySourceInputUri = "BIGQUERY_SOURCE_INPUT_URI"; + String predictionsFormat = "PREDICTIONS_FORMAT"; + String bigqueryDestinationOutputUri = "BIGQUERY_DESTINATION_OUTPUT_URI"; + createBatchPredictionJobBigquerySample( + project, + displayName, + modelName, + instancesFormat, + bigquerySourceInputUri, + predictionsFormat, + bigqueryDestinationOutputUri); + } + + static void createBatchPredictionJobBigquerySample( + String project, + String displayName, + String model, + String instancesFormat, + String bigquerySourceInputUri, + String predictionsFormat, + String bigqueryDestinationOutputUri) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonObject jsonModelParameters = new JsonObject(); + Value.Builder modelParametersBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder); + Value modelParameters = modelParametersBuilder.build(); + BigQuerySource bigquerySource = + BigQuerySource.newBuilder().setInputUri(bigquerySourceInputUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat(instancesFormat) + .setBigquerySource(bigquerySource) + .build(); + BigQueryDestination bigqueryDestination = + BigQueryDestination.newBuilder().setOutputUri(bigqueryDestinationOutputUri).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat(predictionsFormat) + .setBigqueryDestination(bigqueryDestination) + .build(); + String modelName = ModelName.of(project, location, model).toString(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tName: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_batch_prediction_job_bigquery_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobSample.java new file mode 100644 index 00000000000..12bab04e13b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobSample.java @@ -0,0 +1,121 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_sample] +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.AcceleratorType; +import com.google.cloud.aiplatform.v1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.protobuf.Value; +import java.io.IOException; + +public class CreateBatchPredictionJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelName = "MODEL_NAME"; + String instancesFormat = "INSTANCES_FORMAT"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String predictionsFormat = "PREDICTIONS_FORMAT"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobSample( + project, + displayName, + modelName, + instancesFormat, + gcsSourceUri, + predictionsFormat, + gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobSample( + String project, + String displayName, + String model, + String instancesFormat, + String gcsSourceUri, + String predictionsFormat, + String gcsDestinationOutputUriPrefix) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + + // Passing in an empty Value object for model parameters + Value modelParameters = ValueConverter.EMPTY_VALUE; + + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat(instancesFormat) + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat(predictionsFormat) + .setGcsDestination(gcsDestination) + .build(); + MachineSpec machineSpec = + MachineSpec.newBuilder() + .setMachineType("n1-standard-2") + .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80) + .setAcceleratorCount(1) + .build(); + BatchDedicatedResources dedicatedResources = + BatchDedicatedResources.newBuilder() + .setMachineSpec(machineSpec) + .setStartingReplicaCount(1) + .setMaxReplicaCount(1) + .build(); + String modelName = ModelName.of(project, location, model).toString(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .setDedicatedResources(dedicatedResources) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tName: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_batch_prediction_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextClassificationSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextClassificationSample.java new file mode 100644 index 00000000000..ba79bf14b02 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextClassificationSample.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_text_classification_sample] +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; + +public class CreateBatchPredictionJobTextClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String location = "us-central1"; + String displayName = "DISPLAY_NAME"; + String modelId = "MODEL_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobTextClassificationSample( + project, location, displayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobTextClassificationSample( + String project, + String location, + String displayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + // The AI Platform services require regional API endpoints. + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + try { + String modelName = ModelName.of(project, location, modelId).toString(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + } catch (ApiException ex) { + System.out.format("Exception: %s\n", ex.getLocalizedMessage()); + } + } + } +} + +// [END aiplatform_create_batch_prediction_job_text_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSample.java new file mode 100644 index 00000000000..e753da2ed04 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSample.java @@ -0,0 +1,95 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_text_entity_extraction_sample] +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; + +public class CreateBatchPredictionJobTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String location = "us-central1"; + String displayName = "DISPLAY_NAME"; + String modelId = "MODEL_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobTextEntityExtractionSample( + project, location, displayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobTextEntityExtractionSample( + String project, + String location, + String displayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + // The AI Platform services require regional API endpoints. + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + try { + String modelName = ModelName.of(project, location, modelId).toString(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tname:%s\n", response.getName()); + } catch (ApiException ex) { + System.out.format("Exception: %s\n", ex.getLocalizedMessage()); + } + } + } +} + +// [END aiplatform_create_batch_prediction_job_text_entity_extraction_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..8191618c9fe --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSample.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample] +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; + +public class CreateBatchPredictionJobTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String location = "us-central1"; + String displayName = "DISPLAY_NAME"; + String modelId = "MODEL_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobTextSentimentAnalysisSample( + project, location, displayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobTextSentimentAnalysisSample( + String project, + String location, + String displayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + // The AI Platform services require regional API endpoints. + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + try { + String modelName = ModelName.of(project, location, modelId).toString(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + } catch (ApiException ex) { + System.out.format("Exception: %s\n", ex.getLocalizedMessage()); + } + } + } +} + +// [END aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java new file mode 100644 index 00000000000..0d0f68e5418 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_video_action_recognition_sample] +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.protobuf.Value; +import java.io.IOException; + +public class CreateBatchPredictionJobVideoActionRecognitionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String model = "MODEL"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobVideoActionRecognitionSample( + project, displayName, model, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobVideoActionRecognitionSample( + String project, + String displayName, + String model, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + Value modelParameters = ValueConverter.EMPTY_VALUE; + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + + String modelName = ModelName.of(project, location, model).toString(); + + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tName: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_batch_prediction_job_video_action_recognition_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java new file mode 100644 index 00000000000..905ab46b7c5 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java @@ -0,0 +1,204 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_video_classification_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.CompletionStats; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ManualBatchTuningParameters; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ResourcesConsumed; +import com.google.cloud.aiplatform.v1.schema.predict.params.VideoClassificationPredictionParams; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateBatchPredictionJobVideoClassificationSample { + + public static void main(String[] args) throws IOException { + String batchPredictionDisplayName = "YOUR_VIDEO_CLASSIFICATION_DISPLAY_NAME"; + String modelId = "YOUR_MODEL_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String gcsDestinationOutputUriPrefix = + "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/"; + String project = "YOUR_PROJECT_ID"; + createBatchPredictionJobVideoClassification( + batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project); + } + + static void createBatchPredictionJobVideoClassification( + String batchPredictionDisplayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix, + String project) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + VideoClassificationPredictionParams modelParamsObj = + VideoClassificationPredictionParams.newBuilder() + .setConfidenceThreshold(((float) 0.5)) + .setMaxPredictions(10000) + .setSegmentClassification(true) + .setShotClassification(true) + .setOneSecIntervalClassification(true) + .build(); + + Value modelParameters = ValueConverter.toValue(modelParamsObj); + + ModelName modelName = ModelName.of(project, location, modelId); + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + InputConfig inputConfig = + InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build(); + + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + OutputConfig outputConfig = + OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(batchPredictionDisplayName) + .setModel(modelName.toString()) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + BatchPredictionJob batchPredictionJobResponse = + jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob); + + System.out.println("Create Batch Prediction Job Video Classification Response"); + System.out.format("\tName: %s\n", batchPredictionJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName()); + System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel()); + System.out.format( + "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters()); + + System.out.format("\tState: %s\n", batchPredictionJobResponse.getState()); + System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap()); + + InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig(); + System.out.println("\tInput Config"); + System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat()); + + GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource(); + System.out.println("\t\tGcs Source"); + System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList()); + + BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource(); + System.out.println("\t\tBigquery Source"); + System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri()); + + OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig(); + System.out.println("\tOutput Config"); + System.out.format( + "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat()); + + GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination(); + System.out.println("\t\tGcs Destination"); + System.out.format( + "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix()); + + BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination(); + System.out.println("\t\tBig Query Destination"); + System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri()); + + BatchDedicatedResources batchDedicatedResources = + batchPredictionJobResponse.getDedicatedResources(); + System.out.println("\tBatch Dedicated Resources"); + System.out.format( + "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount()); + System.out.format( + "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount()); + + MachineSpec machineSpec = batchDedicatedResources.getMachineSpec(); + System.out.println("\t\tMachine Spec"); + System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType()); + System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType()); + System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount()); + + ManualBatchTuningParameters manualBatchTuningParameters = + batchPredictionJobResponse.getManualBatchTuningParameters(); + System.out.println("\tManual Batch Tuning Parameters"); + System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize()); + + OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo(); + System.out.println("\tOutput Info"); + System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory()); + System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset()); + + Status status = batchPredictionJobResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + List details = status.getDetailsList(); + + for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) { + System.out.println("\tPartial Failure"); + System.out.format("\t\tCode: %s\n", partialFailure.getCode()); + System.out.format("\t\tMessage: %s\n", partialFailure.getMessage()); + List partialFailureDetailsList = partialFailure.getDetailsList(); + } + + ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed(); + System.out.println("\tResources Consumed"); + System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours()); + + CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats(); + System.out.println("\tCompletion Stats"); + System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount()); + System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount()); + System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount()); + } + } +} +// [END aiplatform_create_batch_prediction_job_video_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java new file mode 100644 index 00000000000..860bc8da82a --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java @@ -0,0 +1,201 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_video_object_tracking_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.CompletionStats; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ManualBatchTuningParameters; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ResourcesConsumed; +import com.google.cloud.aiplatform.v1.schema.predict.params.VideoObjectTrackingPredictionParams; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateBatchPredictionJobVideoObjectTrackingSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String batchPredictionDisplayName = "YOUR_VIDEO_OBJECT_TRACKING_DISPLAY_NAME"; + String modelId = "YOUR_MODEL_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String gcsDestinationOutputUriPrefix = + "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/"; + String project = "YOUR_PROJECT_ID"; + batchPredictionJobVideoObjectTracking( + batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project); + } + + static void batchPredictionJobVideoObjectTracking( + String batchPredictionDisplayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix, + String project) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + ModelName modelName = ModelName.of(project, location, modelId); + + VideoObjectTrackingPredictionParams modelParamsObj = + VideoObjectTrackingPredictionParams.newBuilder() + .setConfidenceThreshold(((float) 0.5)) + .build(); + + Value modelParameters = ValueConverter.toValue(modelParamsObj); + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + InputConfig inputConfig = + InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build(); + + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + OutputConfig outputConfig = + OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(batchPredictionDisplayName) + .setModel(modelName.toString()) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + BatchPredictionJob batchPredictionJobResponse = + jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob); + + System.out.println("Create Batch Prediction Job Video Object Tracking Response"); + System.out.format("\tName: %s\n", batchPredictionJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName()); + System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel()); + System.out.format( + "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters()); + + System.out.format("\tState: %s\n", batchPredictionJobResponse.getState()); + System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap()); + + InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig(); + System.out.println("\tInput Config"); + System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat()); + + GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource(); + System.out.println("\t\tGcs Source"); + System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList()); + + BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource(); + System.out.println("\t\tBigquery Source"); + System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri()); + + OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig(); + System.out.println("\tOutput Config"); + System.out.format( + "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat()); + + GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination(); + System.out.println("\t\tGcs Destination"); + System.out.format( + "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix()); + + BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination(); + System.out.println("\t\tBig Query Destination"); + System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri()); + + BatchDedicatedResources batchDedicatedResources = + batchPredictionJobResponse.getDedicatedResources(); + System.out.println("\tBatch Dedicated Resources"); + System.out.format( + "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount()); + System.out.format( + "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount()); + + MachineSpec machineSpec = batchDedicatedResources.getMachineSpec(); + System.out.println("\t\tMachine Spec"); + System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType()); + System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType()); + System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount()); + + ManualBatchTuningParameters manualBatchTuningParameters = + batchPredictionJobResponse.getManualBatchTuningParameters(); + System.out.println("\tManual Batch Tuning Parameters"); + System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize()); + + OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo(); + System.out.println("\tOutput Info"); + System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory()); + System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset()); + + Status status = batchPredictionJobResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + List details = status.getDetailsList(); + + for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) { + System.out.println("\tPartial Failure"); + System.out.format("\t\tCode: %s\n", partialFailure.getCode()); + System.out.format("\t\tMessage: %s\n", partialFailure.getMessage()); + List partialFailureDetailsList = partialFailure.getDetailsList(); + } + + ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed(); + System.out.println("\tResources Consumed"); + System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours()); + + CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats(); + System.out.println("\tCompletion Stats"); + System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount()); + System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount()); + System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount()); + } + } +} +// [END aiplatform_create_batch_prediction_job_video_object_tracking_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java new file mode 100644 index 00000000000..1a0076fbc4b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java @@ -0,0 +1,97 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_active_learning_sample] +import com.google.cloud.aiplatform.v1.ActiveLearningConfig; +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateDataLabelingJobActiveLearningSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String dataset = "DATASET"; + String instructionUri = "INSTRUCTION_URI"; + String inputsSchemaUri = "INPUTS_SCHEMA_URI"; + String annotationSpec = "ANNOTATION_SPEC"; + createDataLabelingJobActiveLearningSample( + project, displayName, dataset, instructionUri, inputsSchemaUri, annotationSpec); + } + + static void createDataLabelingJobActiveLearningSample( + String project, + String displayName, + String dataset, + String instructionUri, + String inputsSchemaUri, + String annotationSpec) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonArray jsonAnnotationSpecs = new JsonArray(); + jsonAnnotationSpecs.add(annotationSpec); + JsonObject jsonInputs = new JsonObject(); + jsonInputs.add("annotation_specs", jsonAnnotationSpecs); + Value.Builder inputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonInputs.toString(), inputsBuilder); + Value inputs = inputsBuilder.build(); + ActiveLearningConfig activeLearningConfig = + ActiveLearningConfig.newBuilder().setMaxDataItemCount(1).build(); + + String datasetName = DatasetName.of(project, location, dataset).toString(); + + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .addDatasets(datasetName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri(inputsSchemaUri) + .setInputs(inputs) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", + "data_labeling_job_active_learning") + .setActiveLearningConfig(activeLearningConfig) + .build(); + LocationName parent = LocationName.of(project, location); + DataLabelingJob response = client.createDataLabelingJob(parent, dataLabelingJob); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_data_labeling_job_active_learning_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java new file mode 100644 index 00000000000..8d9dced5ec7 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java @@ -0,0 +1,115 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_image_sample] + +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.type.Money; +import java.io.IOException; +import java.util.Map; + +public class CreateDataLabelingJobImageSample { + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String displayName = "YOUR_DATA_LABELING_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String instructionUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_data_labeling_source/file.pdf"; + String annotationSpec = "YOUR_ANNOTATION_SPEC"; + createDataLabelingJobImage(project, displayName, datasetId, instructionUri, annotationSpec); + } + + static void createDataLabelingJobImage( + String project, + String displayName, + String datasetId, + String instructionUri, + String annotationSpec) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = "{\"annotation_specs\": [ " + annotationSpec + "]}"; + Value.Builder annotationSpecValue = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, annotationSpecValue); + + DatasetName datasetName = DatasetName.of(project, location, datasetId); + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri( + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/" + + "image_classification.yaml") + .addDatasets(datasetName.toString()) + .setInputs(annotationSpecValue) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", "my_test_saved_query") + .build(); + + DataLabelingJob dataLabelingJobResponse = + jobServiceClient.createDataLabelingJob(locationName, dataLabelingJob); + + System.out.println("Create Data Labeling Job Image Response"); + System.out.format("\tName: %s\n", dataLabelingJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", dataLabelingJobResponse.getDisplayName()); + System.out.format("\tDatasets: %s\n", dataLabelingJobResponse.getDatasetsList()); + System.out.format("\tLabeler Count: %s\n", dataLabelingJobResponse.getLabelerCount()); + System.out.format("\tInstruction Uri: %s\n", dataLabelingJobResponse.getInstructionUri()); + System.out.format("\tInputs Schema Uri: %s\n", dataLabelingJobResponse.getInputsSchemaUri()); + System.out.format("\tInputs: %s\n", dataLabelingJobResponse.getInputs()); + System.out.format("\tState: %s\n", dataLabelingJobResponse.getState()); + System.out.format("\tLabeling Progress: %s\n", dataLabelingJobResponse.getLabelingProgress()); + System.out.format("\tCreate Time: %s\n", dataLabelingJobResponse.getCreateTime()); + System.out.format("\tUpdate Time: %s\n", dataLabelingJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", dataLabelingJobResponse.getLabelsMap()); + System.out.format( + "\tSpecialist Pools: %s\n", dataLabelingJobResponse.getSpecialistPoolsList()); + for (Map.Entry annotationLabelMap : + dataLabelingJobResponse.getAnnotationLabelsMap().entrySet()) { + System.out.println("\tAnnotation Level"); + System.out.format("\t\tkey: %s\n", annotationLabelMap.getKey()); + System.out.format("\t\tvalue: %s\n", annotationLabelMap.getValue()); + } + Money money = dataLabelingJobResponse.getCurrentSpend(); + + System.out.println("\tCurrent Spend"); + System.out.format("\t\tCurrency Code: %s\n", money.getCurrencyCode()); + System.out.format("\t\tUnits: %s\n", money.getUnits()); + System.out.format("\t\tNanos: %s\n", money.getNanos()); + } + } +} +// [END aiplatform_create_data_labeling_job_image_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobSample.java b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobSample.java new file mode 100644 index 00000000000..a677169d7bc --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobSample.java @@ -0,0 +1,117 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_sample] + +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.type.Money; +import java.io.IOException; +import java.util.Map; + +public class CreateDataLabelingJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String displayName = "YOUR_DATA_LABELING_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String instructionUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_data_labeling_source/file.pdf"; + String inputsSchemaUri = "YOUR_INPUT_SCHEMA_URI"; + String annotationSpec = "YOUR_ANNOTATION_SPEC"; + createDataLabelingJob( + project, displayName, datasetId, instructionUri, inputsSchemaUri, annotationSpec); + } + + static void createDataLabelingJob( + String project, + String displayName, + String datasetId, + String instructionUri, + String inputsSchemaUri, + String annotationSpec) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = "{\"annotation_specs\": [ " + annotationSpec + "]}"; + Value.Builder annotationSpecValue = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, annotationSpecValue); + + DatasetName datasetName = DatasetName.of(project, location, datasetId); + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri(inputsSchemaUri) + .addDatasets(datasetName.toString()) + .setInputs(annotationSpecValue) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", "my_test_saved_query") + .build(); + + DataLabelingJob dataLabelingJobResponse = + jobServiceClient.createDataLabelingJob(locationName, dataLabelingJob); + + System.out.println("Create Data Labeling Job Response"); + System.out.format("\tName: %s\n", dataLabelingJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", dataLabelingJobResponse.getDisplayName()); + System.out.format("\tDatasets: %s\n", dataLabelingJobResponse.getDatasetsList()); + System.out.format("\tLabeler Count: %s\n", dataLabelingJobResponse.getLabelerCount()); + System.out.format("\tInstruction Uri: %s\n", dataLabelingJobResponse.getInstructionUri()); + System.out.format("\tInputs Schema Uri: %s\n", dataLabelingJobResponse.getInputsSchemaUri()); + System.out.format("\tInputs: %s\n", dataLabelingJobResponse.getInputs()); + System.out.format("\tState: %s\n", dataLabelingJobResponse.getState()); + System.out.format("\tLabeling Progress: %s\n", dataLabelingJobResponse.getLabelingProgress()); + System.out.format("\tCreate Time: %s\n", dataLabelingJobResponse.getCreateTime()); + System.out.format("\tUpdate Time: %s\n", dataLabelingJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", dataLabelingJobResponse.getLabelsMap()); + System.out.format( + "\tSpecialist Pools: %s\n", dataLabelingJobResponse.getSpecialistPoolsList()); + for (Map.Entry annotationLabelMap : + dataLabelingJobResponse.getAnnotationLabelsMap().entrySet()) { + System.out.println("\tAnnotation Level"); + System.out.format("\t\tkey: %s\n", annotationLabelMap.getKey()); + System.out.format("\t\tvalue: %s\n", annotationLabelMap.getValue()); + } + Money money = dataLabelingJobResponse.getCurrentSpend(); + + System.out.println("\tCurrent Spend"); + System.out.format("\t\tCurrency Code: %s\n", money.getCurrencyCode()); + System.out.format("\t\tUnits: %s\n", money.getUnits()); + System.out.format("\t\tNanos: %s\n", money.getNanos()); + } + } +} +// [END aiplatform_create_data_labeling_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java new file mode 100644 index 00000000000..528e4b2d0f5 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java @@ -0,0 +1,104 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_specialist_pool_sample] +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.SpecialistPoolName; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateDataLabelingJobSpecialistPoolSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String dataset = "DATASET"; + String specialistPool = "SPECIALIST_POOL"; + String instructionUri = "INSTRUCTION_URI"; + String inputsSchemaUri = "INPUTS_SCHEMA_URI"; + String annotationSpec = "ANNOTATION_SPEC"; + createDataLabelingJobSpecialistPoolSample( + project, + displayName, + dataset, + specialistPool, + instructionUri, + inputsSchemaUri, + annotationSpec); + } + + static void createDataLabelingJobSpecialistPoolSample( + String project, + String displayName, + String dataset, + String specialistPool, + String instructionUri, + String inputsSchemaUri, + String annotationSpec) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + JsonArray jsonAnnotationSpecs = new JsonArray(); + jsonAnnotationSpecs.add(annotationSpec); + JsonObject jsonInputs = new JsonObject(); + jsonInputs.add("annotation_specs", jsonAnnotationSpecs); + Value.Builder inputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonInputs.toString(), inputsBuilder); + Value inputs = inputsBuilder.build(); + + String datasetName = DatasetName.of(project, location, dataset).toString(); + String specialistPoolName = + SpecialistPoolName.of(project, location, specialistPool).toString(); + + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .addDatasets(datasetName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri(inputsSchemaUri) + .setInputs(inputs) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", + "data_labeling_job_specialist_pool") + .addSpecialistPools(specialistPoolName) + .build(); + LocationName parent = LocationName.of(project, location); + DataLabelingJob response = client.createDataLabelingJob(parent, dataLabelingJob); + System.out.format("response: %s\n", response); + } + } +} + +// [END aiplatform_create_data_labeling_job_specialist_pool_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java new file mode 100644 index 00000000000..cabf2399735 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java @@ -0,0 +1,115 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_data_labeling_job_video_sample] + +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.type.Money; +import java.io.IOException; +import java.util.Map; + +public class CreateDataLabelingJobVideoSample { + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String displayName = "YOUR_DATA_LABELING_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String instructionUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_data_labeling_source/file.pdf"; + String annotationSpec = "YOUR_ANNOTATION_SPEC"; + createDataLabelingJobVideo(project, displayName, datasetId, instructionUri, annotationSpec); + } + + static void createDataLabelingJobVideo( + String project, + String displayName, + String datasetId, + String instructionUri, + String annotationSpec) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = "{\"annotation_specs\": [ " + annotationSpec + "]}"; + Value.Builder annotationSpecValue = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, annotationSpecValue); + + DatasetName datasetName = DatasetName.of(project, location, datasetId); + DataLabelingJob dataLabelingJob = + DataLabelingJob.newBuilder() + .setDisplayName(displayName) + .setLabelerCount(1) + .setInstructionUri(instructionUri) + .setInputsSchemaUri( + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/" + + "video_classification.yaml") + .addDatasets(datasetName.toString()) + .setInputs(annotationSpecValue) + .putAnnotationLabels( + "aiplatform.googleapis.com/annotation_set_name", "my_test_saved_query") + .build(); + + DataLabelingJob dataLabelingJobResponse = + jobServiceClient.createDataLabelingJob(locationName, dataLabelingJob); + + System.out.println("Create Data Labeling Job Video Response"); + System.out.format("\tName: %s\n", dataLabelingJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", dataLabelingJobResponse.getDisplayName()); + System.out.format("\tDatasets: %s\n", dataLabelingJobResponse.getDatasetsList()); + System.out.format("\tLabeler Count: %s\n", dataLabelingJobResponse.getLabelerCount()); + System.out.format("\tInstruction Uri: %s\n", dataLabelingJobResponse.getInstructionUri()); + System.out.format("\tInputs Schema Uri: %s\n", dataLabelingJobResponse.getInputsSchemaUri()); + System.out.format("\tInputs: %s\n", dataLabelingJobResponse.getInputs()); + System.out.format("\tState: %s\n", dataLabelingJobResponse.getState()); + System.out.format("\tLabeling Progress: %s\n", dataLabelingJobResponse.getLabelingProgress()); + System.out.format("\tCreate Time: %s\n", dataLabelingJobResponse.getCreateTime()); + System.out.format("\tUpdate Time: %s\n", dataLabelingJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", dataLabelingJobResponse.getLabelsMap()); + System.out.format( + "\tSpecialist Pools: %s\n", dataLabelingJobResponse.getSpecialistPoolsList()); + for (Map.Entry annotationLabelMap : + dataLabelingJobResponse.getAnnotationLabelsMap().entrySet()) { + System.out.println("\tAnnotation Level"); + System.out.format("\t\tkey: %s\n", annotationLabelMap.getKey()); + System.out.format("\t\tvalue: %s\n", annotationLabelMap.getValue()); + } + + Money money = dataLabelingJobResponse.getCurrentSpend(); + System.out.println("\tCurrent Spend"); + System.out.format("\t\tCurrency Code: %s\n", money.getCurrencyCode()); + System.out.format("\t\tUnits: %s\n", money.getUnits()); + System.out.format("\t\tNanos: %s\n", money.getNanos()); + } + } +} +// [END aiplatform_create_data_labeling_job_video_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDatasetImageSample.java b/aiplatform/src/main/java/aiplatform/CreateDatasetImageSample.java new file mode 100644 index 00000000000..6fcb27157ef --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDatasetImageSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_image_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1.Dataset; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetImageSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + createDatasetImageSample(project, datasetDisplayName); + } + + static void createDatasetImageSample(String project, String datasetDisplayName) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + + System.out.println("Create Image Dataset Response"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + System.out.format("Create Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("Labels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_image_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDatasetSample.java b/aiplatform/src/main/java/aiplatform/CreateDatasetSample.java new file mode 100644 index 00000000000..0b0817f6904 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDatasetSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1.Dataset; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String metadataSchemaUri = "YOUR_METADATA_SCHEMA_URI"; + createDatasetSample(project, datasetDisplayName, metadataSchemaUri); + } + + static void createDatasetSample( + String project, String datasetDisplayName, String metadataSchemaUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Response"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + System.out.format("Create Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("Labels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java b/aiplatform/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java new file mode 100644 index 00000000000..fd7628be2fa --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_tabular_bigquery_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1.Dataset; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetTabularBigquerySample { + + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String bigqueryDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String bigqueryUri = + "bq://YOUR_GOOGLE_CLOUD_PROJECT_ID.BIGQUERY_DATASET_ID.BIGQUERY_TABLE_OR_VIEW_ID"; + createDatasetTableBigquery(project, bigqueryDisplayName, bigqueryUri); + } + + static void createDatasetTableBigquery( + String project, String bigqueryDisplayName, String bigqueryUri) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings settings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = DatasetServiceClient.create(settings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/tables_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"input_config\": {\"bigquery_source\": {\"uri\": \"" + bigqueryUri + "\"}}}"; + Value.Builder metaData = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, metaData); + + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(bigqueryDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .setMetadata(metaData) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Table Bigquery sample"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + } + } +} +// [END aiplatform_create_dataset_tabular_bigquery_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java b/aiplatform/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java new file mode 100644 index 00000000000..87bb139c9e2 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_tabular_gcs_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1.Dataset; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetTabularGcsSample { + + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String gcsSourceUri = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_gcs_table/file.csv"; + ; + createDatasetTableGcs(project, datasetDisplayName, gcsSourceUri); + } + + static void createDatasetTableGcs(String project, String datasetDisplayName, String gcsSourceUri) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings settings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = DatasetServiceClient.create(settings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/tables_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"input_config\": {\"gcs_source\": {\"uri\": [\"" + gcsSourceUri + "\"]}}}"; + Value.Builder metaData = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, metaData); + + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .setMetadata(metaData) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Table GCS sample"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + } + } +} +// [END aiplatform_create_dataset_tabular_gcs_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDatasetTextSample.java b/aiplatform/src/main/java/aiplatform/CreateDatasetTextSample.java new file mode 100644 index 00000000000..f919467e930 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDatasetTextSample.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_text_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1.Dataset; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetTextSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + + createDatasetTextSample(project, datasetDisplayName); + } + + static void createDatasetTextSample(String project, String datasetDisplayName) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml"; + + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(180, TimeUnit.SECONDS); + + System.out.println("Create Text Dataset Response"); + System.out.format("\tName: %s\n", datasetResponse.getName()); + System.out.format("\tDisplay Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("\tMetadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("\tMetadata: %s\n", datasetResponse.getMetadata()); + System.out.format("\tCreate Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("\tUpdate Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_text_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateDatasetVideoSample.java b/aiplatform/src/main/java/aiplatform/CreateDatasetVideoSample.java new file mode 100644 index 00000000000..65e96a7c8b7 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateDatasetVideoSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_video_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1.Dataset; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetVideoSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetVideoDisplayName = "YOUR_DATASET_VIDEO_DISPLAY_NAME"; + createDatasetSample(datasetVideoDisplayName, project); + } + + static void createDatasetSample(String datasetVideoDisplayName, String project) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetVideoDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Video Response"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + System.out.format("Create Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("Labels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_video_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateEndpointSample.java b/aiplatform/src/main/java/aiplatform/CreateEndpointSample.java new file mode 100644 index 00000000000..e0d9214342c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateEndpointSample.java @@ -0,0 +1,74 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_endpoint_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateEndpointOperationMetadata; +import com.google.cloud.aiplatform.v1.Endpoint; +import com.google.cloud.aiplatform.v1.EndpointServiceClient; +import com.google.cloud.aiplatform.v1.EndpointServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateEndpointSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String endpointDisplayName = "YOUR_ENDPOINT_DISPLAY_NAME"; + createEndpointSample(project, endpointDisplayName); + } + + static void createEndpointSample(String project, String endpointDisplayName) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + EndpointServiceSettings endpointServiceSettings = + EndpointServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (EndpointServiceClient endpointServiceClient = + EndpointServiceClient.create(endpointServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + Endpoint endpoint = Endpoint.newBuilder().setDisplayName(endpointDisplayName).build(); + + OperationFuture endpointFuture = + endpointServiceClient.createEndpointAsync(locationName, endpoint); + System.out.format("Operation name: %s\n", endpointFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Endpoint endpointResponse = endpointFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Endpoint Response"); + System.out.format("Name: %s\n", endpointResponse.getName()); + System.out.format("Display Name: %s\n", endpointResponse.getDisplayName()); + System.out.format("Description: %s\n", endpointResponse.getDescription()); + System.out.format("Labels: %s\n", endpointResponse.getLabelsMap()); + System.out.format("Create Time: %s\n", endpointResponse.getCreateTime()); + System.out.format("Update Time: %s\n", endpointResponse.getUpdateTime()); + } + } +} +// [END aiplatform_create_endpoint_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateEntityTypeMonitoringSample.java b/aiplatform/src/main/java/aiplatform/CreateEntityTypeMonitoringSample.java new file mode 100644 index 00000000000..b234d032497 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateEntityTypeMonitoringSample.java @@ -0,0 +1,114 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Create an entity type so that you can create its related features. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_create_entity_type_monitoring_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateEntityTypeOperationMetadata; +import com.google.cloud.aiplatform.v1.CreateEntityTypeRequest; +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.FeaturestoreMonitoringConfig; +import com.google.cloud.aiplatform.v1.FeaturestoreMonitoringConfig.SnapshotAnalysis; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateEntityTypeMonitoringSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String description = "YOUR_ENTITY_TYPE_DESCRIPTION"; + int monitoringIntervalDays = 1; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + createEntityTypeMonitoringSample( + project, + featurestoreId, + entityTypeId, + description, + monitoringIntervalDays, + location, + endpoint, + timeout); + } + + static void createEntityTypeMonitoringSample( + String project, + String featurestoreId, + String entityTypeId, + String description, + int monitoringIntervalDays, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + FeaturestoreMonitoringConfig featurestoreMonitoringConfig = + FeaturestoreMonitoringConfig.newBuilder() + .setSnapshotAnalysis( + SnapshotAnalysis.newBuilder().setMonitoringIntervalDays(monitoringIntervalDays)) + .build(); + + EntityType entityType = + EntityType.newBuilder() + .setDescription(description) + .setMonitoringConfig(featurestoreMonitoringConfig) + .build(); + + CreateEntityTypeRequest createEntityTypeRequest = + CreateEntityTypeRequest.newBuilder() + .setParent(FeaturestoreName.of(project, location, featurestoreId).toString()) + .setEntityType(entityType) + .setEntityTypeId(entityTypeId) + .build(); + + OperationFuture entityTypeFuture = + featurestoreServiceClient.createEntityTypeAsync(createEntityTypeRequest); + System.out.format( + "Operation name: %s%n", entityTypeFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + EntityType entityTypeResponse = entityTypeFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Create Entity Type Monitoring Response"); + System.out.format("Name: %s%n", entityTypeResponse.getName()); + } + } +} +// [END aiplatform_create_entity_type_monitoring_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateEntityTypeSample.java b/aiplatform/src/main/java/aiplatform/CreateEntityTypeSample.java new file mode 100644 index 00000000000..012ac19615e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateEntityTypeSample.java @@ -0,0 +1,93 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Create an entity type so that you can create its related features. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_create_entity_type_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateEntityTypeOperationMetadata; +import com.google.cloud.aiplatform.v1.CreateEntityTypeRequest; +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateEntityTypeSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String description = "YOUR_ENTITY_TYPE_DESCRIPTION"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + createEntityTypeSample( + project, featurestoreId, entityTypeId, description, location, endpoint, timeout); + } + + static void createEntityTypeSample( + String project, + String featurestoreId, + String entityTypeId, + String description, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + EntityType entityType = EntityType.newBuilder().setDescription(description).build(); + + CreateEntityTypeRequest createEntityTypeRequest = + CreateEntityTypeRequest.newBuilder() + .setParent(FeaturestoreName.of(project, location, featurestoreId).toString()) + .setEntityType(entityType) + .setEntityTypeId(entityTypeId) + .build(); + + OperationFuture entityTypeFuture = + featurestoreServiceClient.createEntityTypeAsync(createEntityTypeRequest); + System.out.format( + "Operation name: %s%n", entityTypeFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + EntityType entityTypeResponse = entityTypeFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Create Entity Type Response"); + System.out.format("Name: %s%n", entityTypeResponse.getName()); + } + } +} +// [END aiplatform_create_entity_type_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateFeatureSample.java b/aiplatform/src/main/java/aiplatform/CreateFeatureSample.java new file mode 100644 index 00000000000..10c18736f20 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateFeatureSample.java @@ -0,0 +1,108 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Create a single feature for an existing entity type. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_create_feature_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.CreateFeatureOperationMetadata; +import com.google.cloud.aiplatform.v1.CreateFeatureRequest; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.Feature.ValueType; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateFeatureSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String featureId = "YOUR_FEATURE_ID"; + String description = "YOUR_FEATURE_DESCRIPTION"; + ValueType valueType = ValueType.STRING; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 900; + createFeatureSample( + project, + featurestoreId, + entityTypeId, + featureId, + description, + valueType, + location, + endpoint, + timeout); + } + + static void createFeatureSample( + String project, + String featurestoreId, + String entityTypeId, + String featureId, + String description, + ValueType valueType, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + Feature feature = + Feature.newBuilder().setDescription(description).setValueType(valueType).build(); + + CreateFeatureRequest createFeatureRequest = + CreateFeatureRequest.newBuilder() + .setParent( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setFeature(feature) + .setFeatureId(featureId) + .build(); + + OperationFuture featureFuture = + featurestoreServiceClient.createFeatureAsync(createFeatureRequest); + System.out.format("Operation name: %s%n", featureFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Feature featureResponse = featureFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Create Feature Response"); + System.out.format("Name: %s%n", featureResponse.getName()); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_create_feature_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateFeaturestoreFixedNodesSample.java b/aiplatform/src/main/java/aiplatform/CreateFeaturestoreFixedNodesSample.java new file mode 100644 index 00000000000..69add3ff170 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateFeaturestoreFixedNodesSample.java @@ -0,0 +1,95 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Create a featurestore resource to contain entity types and features. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_create_featurestore_fixed_nodes_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateFeaturestoreOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.CreateFeaturestoreRequest; +import com.google.cloud.aiplatform.v1beta1.Featurestore; +import com.google.cloud.aiplatform.v1beta1.Featurestore.OnlineServingConfig; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateFeaturestoreFixedNodesSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + int fixedNodeCount = 1; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 900; + createFeaturestoreFixedNodesSample( + project, featurestoreId, fixedNodeCount, location, endpoint, timeout); + } + + static void createFeaturestoreFixedNodesSample( + String project, + String featurestoreId, + int fixedNodeCount, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + OnlineServingConfig.Builder builderValue = + OnlineServingConfig.newBuilder().setFixedNodeCount(fixedNodeCount); + Featurestore featurestore = + Featurestore.newBuilder().setOnlineServingConfig(builderValue).build(); + + CreateFeaturestoreRequest createFeaturestoreRequest = + CreateFeaturestoreRequest.newBuilder() + .setParent(LocationName.of(project, location).toString()) + .setFeaturestore(featurestore) + .setFeaturestoreId(featurestoreId) + .build(); + + OperationFuture featurestoreFuture = + featurestoreServiceClient.createFeaturestoreAsync(createFeaturestoreRequest); + System.out.format( + "Operation name: %s%n", featurestoreFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Featurestore featurestoreResponse = featurestoreFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Create Featurestore Response"); + System.out.format("Name: %s%n", featurestoreResponse.getName()); + } + } +} +// [END aiplatform_create_featurestore_fixed_nodes_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateFeaturestoreSample.java b/aiplatform/src/main/java/aiplatform/CreateFeaturestoreSample.java new file mode 100644 index 00000000000..50e558fbb14 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateFeaturestoreSample.java @@ -0,0 +1,101 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Create a featurestore resource to contain entity types and features. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_create_featurestore_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateFeaturestoreOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.CreateFeaturestoreRequest; +import com.google.cloud.aiplatform.v1beta1.Featurestore; +import com.google.cloud.aiplatform.v1beta1.Featurestore.OnlineServingConfig; +import com.google.cloud.aiplatform.v1beta1.Featurestore.OnlineServingConfig.Scaling; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateFeaturestoreSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + int minNodeCount = 1; + int maxNodeCount = 5; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 900; + createFeaturestoreSample( + project, featurestoreId, minNodeCount, maxNodeCount, location, endpoint, timeout); + } + + static void createFeaturestoreSample( + String project, + String featurestoreId, + int minNodeCount, + int maxNodeCount, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + OnlineServingConfig.Builder builderValue = + OnlineServingConfig.newBuilder() + .setScaling( + Scaling.newBuilder().setMinNodeCount(minNodeCount).setMaxNodeCount(maxNodeCount)); + Featurestore featurestore = + Featurestore.newBuilder().setOnlineServingConfig(builderValue).build(); + String parent = LocationName.of(project, location).toString(); + + CreateFeaturestoreRequest createFeaturestoreRequest = + CreateFeaturestoreRequest.newBuilder() + .setParent(parent) + .setFeaturestore(featurestore) + .setFeaturestoreId(featurestoreId) + .build(); + + OperationFuture featurestoreFuture = + featurestoreServiceClient.createFeaturestoreAsync(createFeaturestoreRequest); + System.out.format( + "Operation name: %s%n", featurestoreFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Featurestore featurestoreResponse = featurestoreFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Create Featurestore Response"); + System.out.format("Name: %s%n", featurestoreResponse.getName()); + } + } +} +// [END aiplatform_create_featurestore_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java b/aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java new file mode 100644 index 00000000000..0d86232e283 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java @@ -0,0 +1,174 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_hyperparameter_tuning_job_python_package_sample] +import com.google.cloud.aiplatform.v1.AcceleratorType; +import com.google.cloud.aiplatform.v1.CustomJobSpec; +import com.google.cloud.aiplatform.v1.HyperparameterTuningJob; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.PythonPackageSpec; +import com.google.cloud.aiplatform.v1.StudySpec; +import com.google.cloud.aiplatform.v1.StudySpec.MetricSpec; +import com.google.cloud.aiplatform.v1.StudySpec.MetricSpec.GoalType; +import com.google.cloud.aiplatform.v1.StudySpec.ParameterSpec; +import com.google.cloud.aiplatform.v1.StudySpec.ParameterSpec.ConditionalParameterSpec; +import com.google.cloud.aiplatform.v1.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition; +import com.google.cloud.aiplatform.v1.StudySpec.ParameterSpec.DiscreteValueSpec; +import com.google.cloud.aiplatform.v1.StudySpec.ParameterSpec.DoubleValueSpec; +import com.google.cloud.aiplatform.v1.StudySpec.ParameterSpec.ScaleType; +import com.google.cloud.aiplatform.v1.WorkerPoolSpec; +import java.io.IOException; +import java.util.Arrays; + +public class CreateHyperparameterTuningJobPythonPackageSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String executorImageUri = "EXECUTOR_IMAGE_URI"; + String packageUri = "PACKAGE_URI"; + String pythonModule = "PYTHON_MODULE"; + createHyperparameterTuningJobPythonPackageSample( + project, displayName, executorImageUri, packageUri, pythonModule); + } + + static void createHyperparameterTuningJobPythonPackageSample( + String project, + String displayName, + String executorImageUri, + String packageUri, + String pythonModule) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + // study spec + MetricSpec metric = + MetricSpec.newBuilder().setMetricId("val_rmse").setGoal(GoalType.MINIMIZE).build(); + + // decay + DoubleValueSpec doubleValueSpec = + DoubleValueSpec.newBuilder().setMinValue(1e-07).setMaxValue(1).build(); + ParameterSpec parameterDecaySpec = + ParameterSpec.newBuilder() + .setParameterId("decay") + .setDoubleValueSpec(doubleValueSpec) + .setScaleType(ScaleType.UNIT_LINEAR_SCALE) + .build(); + Double[] decayValues = {32.0, 64.0}; + DiscreteValueCondition discreteValueDecay = + DiscreteValueCondition.newBuilder().addAllValues(Arrays.asList(decayValues)).build(); + ConditionalParameterSpec conditionalParameterDecay = + ConditionalParameterSpec.newBuilder() + .setParameterSpec(parameterDecaySpec) + .setParentDiscreteValues(discreteValueDecay) + .build(); + + // learning rate + ParameterSpec parameterLearningSpec = + ParameterSpec.newBuilder() + .setParameterId("learning_rate") + .setDoubleValueSpec(doubleValueSpec) // Use the same min/max as for decay + .setScaleType(ScaleType.UNIT_LINEAR_SCALE) + .build(); + + Double[] learningRateValues = {4.0, 8.0, 16.0}; + DiscreteValueCondition discreteValueLearning = + DiscreteValueCondition.newBuilder() + .addAllValues(Arrays.asList(learningRateValues)) + .build(); + ConditionalParameterSpec conditionalParameterLearning = + ConditionalParameterSpec.newBuilder() + .setParameterSpec(parameterLearningSpec) + .setParentDiscreteValues(discreteValueLearning) + .build(); + + // batch size + Double[] batchSizeValues = {4.0, 8.0, 16.0, 32.0, 64.0, 128.0}; + + DiscreteValueSpec discreteValueSpec = + DiscreteValueSpec.newBuilder().addAllValues(Arrays.asList(batchSizeValues)).build(); + ParameterSpec parameter = + ParameterSpec.newBuilder() + .setParameterId("batch_size") + .setDiscreteValueSpec(discreteValueSpec) + .setScaleType(ScaleType.UNIT_LINEAR_SCALE) + .addConditionalParameterSpecs(conditionalParameterDecay) + .addConditionalParameterSpecs(conditionalParameterLearning) + .build(); + + // trial_job_spec + MachineSpec machineSpec = + MachineSpec.newBuilder() + .setMachineType("n1-standard-4") + .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80) + .setAcceleratorCount(1) + .build(); + + PythonPackageSpec pythonPackageSpec = + PythonPackageSpec.newBuilder() + .setExecutorImageUri(executorImageUri) + .addPackageUris(packageUri) + .setPythonModule(pythonModule) + .build(); + + WorkerPoolSpec workerPoolSpec = + WorkerPoolSpec.newBuilder() + .setMachineSpec(machineSpec) + .setReplicaCount(1) + .setPythonPackageSpec(pythonPackageSpec) + .build(); + + StudySpec studySpec = + StudySpec.newBuilder() + .addMetrics(metric) + .addParameters(parameter) + .setAlgorithm(StudySpec.Algorithm.RANDOM_SEARCH) + .build(); + CustomJobSpec trialJobSpec = + CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec).build(); + // hyperparameter_tuning_job + HyperparameterTuningJob hyperparameterTuningJob = + HyperparameterTuningJob.newBuilder() + .setDisplayName(displayName) + .setMaxTrialCount(4) + .setParallelTrialCount(2) + .setStudySpec(studySpec) + .setTrialJobSpec(trialJobSpec) + .build(); + LocationName parent = LocationName.of(project, location); + HyperparameterTuningJob response = + client.createHyperparameterTuningJob(parent, hyperparameterTuningJob); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_hyperparameter_tuning_job_python_package_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java b/aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java new file mode 100644 index 00000000000..b2295270a46 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobSample.java @@ -0,0 +1,106 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_hyperparameter_tuning_job_sample] +import com.google.cloud.aiplatform.v1.AcceleratorType; +import com.google.cloud.aiplatform.v1.ContainerSpec; +import com.google.cloud.aiplatform.v1.CustomJobSpec; +import com.google.cloud.aiplatform.v1.HyperparameterTuningJob; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.StudySpec; +import com.google.cloud.aiplatform.v1.WorkerPoolSpec; +import java.io.IOException; + +public class CreateHyperparameterTuningJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String containerImageUri = "CONTAINER_IMAGE_URI"; + createHyperparameterTuningJobSample(project, displayName, containerImageUri); + } + + static void createHyperparameterTuningJobSample( + String project, String displayName, String containerImageUri) throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + StudySpec.MetricSpec metric0 = + StudySpec.MetricSpec.newBuilder() + .setMetricId("accuracy") + .setGoal(StudySpec.MetricSpec.GoalType.MAXIMIZE) + .build(); + StudySpec.ParameterSpec.DoubleValueSpec doubleValueSpec = + StudySpec.ParameterSpec.DoubleValueSpec.newBuilder() + .setMinValue(0.001) + .setMaxValue(0.1) + .build(); + StudySpec.ParameterSpec parameter0 = + StudySpec.ParameterSpec.newBuilder() + // Learning rate. + .setParameterId("lr") + .setDoubleValueSpec(doubleValueSpec) + .build(); + StudySpec studySpec = + StudySpec.newBuilder().addMetrics(metric0).addParameters(parameter0).build(); + MachineSpec machineSpec = + MachineSpec.newBuilder() + .setMachineType("n1-standard-4") + .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_K80) + .setAcceleratorCount(1) + .build(); + ContainerSpec containerSpec = + ContainerSpec.newBuilder().setImageUri(containerImageUri).build(); + WorkerPoolSpec workerPoolSpec0 = + WorkerPoolSpec.newBuilder() + .setMachineSpec(machineSpec) + .setReplicaCount(1) + .setContainerSpec(containerSpec) + .build(); + CustomJobSpec trialJobSpec = + CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec0).build(); + HyperparameterTuningJob hyperparameterTuningJob = + HyperparameterTuningJob.newBuilder() + .setDisplayName(displayName) + .setMaxTrialCount(2) + .setParallelTrialCount(1) + .setMaxFailedTrialCount(1) + .setStudySpec(studySpec) + .setTrialJobSpec(trialJobSpec) + .build(); + LocationName parent = LocationName.of(project, location); + HyperparameterTuningJob response = + client.createHyperparameterTuningJob(parent, hyperparameterTuningJob); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_hyperparameter_tuning_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java new file mode 100644 index 00000000000..53e9867a6ff --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineCustomJobSample.java @@ -0,0 +1,119 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_custom_job_sample] +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateTrainingPipelineCustomJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelDisplayName = "MODEL_DISPLAY_NAME"; + String containerImageUri = "CONTAINER_IMAGE_URI"; + String baseOutputDirectoryPrefix = "BASE_OUTPUT_DIRECTORY_PREFIX"; + createTrainingPipelineCustomJobSample( + project, displayName, modelDisplayName, containerImageUri, baseOutputDirectoryPrefix); + } + + static void createTrainingPipelineCustomJobSample( + String project, + String displayName, + String modelDisplayName, + String containerImageUri, + String baseOutputDirectoryPrefix) + throws IOException { + PipelineServiceSettings settings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { + JsonObject jsonMachineSpec = new JsonObject(); + jsonMachineSpec.addProperty("machineType", "n1-standard-4"); + + JsonArray jsonArgs = new JsonArray(); + jsonArgs.add("--model_dir=$(AIP_MODEL_DIR)"); + + // A working docker image can be found at + // gs://cloud-samples-data/ai-platform/mnist_tfrecord/custom_job + JsonObject jsonContainerSpec = new JsonObject(); + jsonContainerSpec.addProperty("imageUri", containerImageUri); + jsonContainerSpec.add("args", jsonArgs); + + JsonObject jsonJsonWorkerPoolSpec0 = new JsonObject(); + jsonJsonWorkerPoolSpec0.addProperty("replicaCount", 1); + jsonJsonWorkerPoolSpec0.add("machineSpec", jsonMachineSpec); + jsonJsonWorkerPoolSpec0.add("containerSpec", jsonContainerSpec); + + JsonArray jsonWorkerPoolSpecs = new JsonArray(); + jsonWorkerPoolSpecs.add(jsonJsonWorkerPoolSpec0); + + JsonObject jsonBaseOutputDirectory = new JsonObject(); + // The GCS location for outputs must be accessible by the project's AI Platform + // service account. + jsonBaseOutputDirectory.addProperty("output_uri_prefix", baseOutputDirectoryPrefix); + + JsonObject jsonTrainingTaskInputs = new JsonObject(); + jsonTrainingTaskInputs.add("workerPoolSpecs", jsonWorkerPoolSpecs); + jsonTrainingTaskInputs.add("baseOutputDirectory", jsonBaseOutputDirectory); + + Value.Builder trainingTaskInputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder); + Value trainingTaskInputs = trainingTaskInputsBuilder.build(); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"; + String imageUri = "gcr.io/cloud-aiplatform/prediction/tf-cpu.1-15:latest"; + ModelContainerSpec containerSpec = + ModelContainerSpec.newBuilder().setImageUri(imageUri).build(); + Model modelToUpload = + Model.newBuilder() + .setDisplayName(modelDisplayName) + .setContainerSpec(containerSpec) + .build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(displayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setModelToUpload(modelToUpload) + .build(); + LocationName parent = LocationName.of(project, location); + TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_training_pipeline_custom_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java new file mode 100644 index 00000000000..8fad236877c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java @@ -0,0 +1,145 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; + +public class CreateTrainingPipelineCustomTrainingManagedDatasetSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String modelDisplayName = "MODEL_DISPLAY_NAME"; + String datasetId = "DATASET_ID"; + String annotationSchemaUri = "ANNOTATION_SCHEMA_URI"; + String trainingContainerSpecImageUri = "TRAINING_CONTAINER_SPEC_IMAGE_URI"; + String modelContainerSpecImageUri = "MODEL_CONTAINER_SPEC_IMAGE_URI"; + String baseOutputUriPrefix = "BASE_OUTPUT_URI_PREFIX"; + createTrainingPipelineCustomTrainingManagedDatasetSample( + project, + displayName, + modelDisplayName, + datasetId, + annotationSchemaUri, + trainingContainerSpecImageUri, + modelContainerSpecImageUri, + baseOutputUriPrefix); + } + + static void createTrainingPipelineCustomTrainingManagedDatasetSample( + String project, + String displayName, + String modelDisplayName, + String datasetId, + String annotationSchemaUri, + String trainingContainerSpecImageUri, + String modelContainerSpecImageUri, + String baseOutputUriPrefix) + throws IOException { + PipelineServiceSettings settings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { + JsonArray jsonArgs = new JsonArray(); + jsonArgs.add("--model-dir=$(AIP_MODEL_DIR)"); + // training_task_inputs + JsonObject jsonTrainingContainerSpec = new JsonObject(); + jsonTrainingContainerSpec.addProperty("imageUri", trainingContainerSpecImageUri); + // AIP_MODEL_DIR is set by the service according to baseOutputDirectory. + jsonTrainingContainerSpec.add("args", jsonArgs); + + JsonObject jsonMachineSpec = new JsonObject(); + jsonMachineSpec.addProperty("machineType", "n1-standard-8"); + + JsonObject jsonTrainingWorkerPoolSpec = new JsonObject(); + jsonTrainingWorkerPoolSpec.addProperty("replicaCount", 1); + jsonTrainingWorkerPoolSpec.add("machineSpec", jsonMachineSpec); + jsonTrainingWorkerPoolSpec.add("containerSpec", jsonTrainingContainerSpec); + + JsonArray jsonWorkerPoolSpecs = new JsonArray(); + jsonWorkerPoolSpecs.add(jsonTrainingWorkerPoolSpec); + + JsonObject jsonBaseOutputDirectory = new JsonObject(); + jsonBaseOutputDirectory.addProperty("outputUriPrefix", baseOutputUriPrefix); + + JsonObject jsonTrainingTaskInputs = new JsonObject(); + jsonTrainingTaskInputs.add("workerPoolSpecs", jsonWorkerPoolSpecs); + jsonTrainingTaskInputs.add("baseOutputDirectory", jsonBaseOutputDirectory); + + Value.Builder trainingTaskInputsBuilder = Value.newBuilder(); + JsonFormat.parser().merge(jsonTrainingTaskInputs.toString(), trainingTaskInputsBuilder); + Value trainingTaskInputs = trainingTaskInputsBuilder.build(); + // model_to_upload + ModelContainerSpec modelContainerSpec = + ModelContainerSpec.newBuilder().setImageUri(modelContainerSpecImageUri).build(); + Model model = + Model.newBuilder() + .setDisplayName(modelDisplayName) + .setContainerSpec(modelContainerSpec) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(baseOutputUriPrefix).build(); + + // input_data_config + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setAnnotationSchemaUri(annotationSchemaUri) + .setGcsDestination(gcsDestination) + .build(); + + // training_task_definition + String customTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"; + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(displayName) + .setInputDataConfig(inputDataConfig) + .setTrainingTaskDefinition(customTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setModelToUpload(model) + .build(); + LocationName parent = LocationName.of(project, location); + TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java new file mode 100644 index 00000000000..4f9c1e2c57a --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java @@ -0,0 +1,210 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_image_classification_sample] +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineImageClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + createTrainingPipelineImageClassificationSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineImageClassificationSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_image_classification_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + AutoMlImageClassificationInputs autoMlImageClassificationInputs = + AutoMlImageClassificationInputs.newBuilder() + .setModelType(ModelType.CLOUD) + .setMultiLabel(false) + .setBudgetMilliNodeHours(8000) + .setDisableEarlyStopping(false) + .build(); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs)) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Image Classification Response"); + System.out.format("Name: %s\n", trainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data Config"); + System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("Fraction Split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("Predict Schemata"); + System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("Supported Export Format"); + System.out.format("Id: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("Container Spec"); + System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("Command: %s\n", modelContainerSpec.getCommandList()); + System.out.format("Args: %s\n", modelContainerSpec.getArgsList()); + System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("Env"); + System.out.format("Name: %s\n", envVar.getName()); + System.out.format("Value: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("Port"); + System.out.format("Container Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("Deployed Model"); + System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_image_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java new file mode 100644 index 00000000000..65ade6ea4ad --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java @@ -0,0 +1,210 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_image_object_detection_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs.ModelType; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineImageObjectDetectionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + createTrainingPipelineImageObjectDetectionSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineImageObjectDetectionSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_image_object_detection_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + AutoMlImageObjectDetectionInputs autoMlImageObjectDetectionInputs = + AutoMlImageObjectDetectionInputs.newBuilder() + .setModelType(ModelType.CLOUD_HIGH_ACCURACY_1) + .setBudgetMilliNodeHours(20000) + .setDisableEarlyStopping(false) + .build(); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageObjectDetectionInputs)) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Image Object Detection Response"); + System.out.format("Name: %s\n", trainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data Config"); + System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("Fraction Split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("Predict Schemata"); + System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("Supported Export Format"); + System.out.format("Id: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("Container Spec"); + System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("Command: %s\n", modelContainerSpec.getCommandList()); + System.out.format("Args: %s\n", modelContainerSpec.getArgsList()); + System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("Env"); + System.out.format("Name: %s\n", envVar.getName()); + System.out.format("Value: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("Port"); + System.out.format("Container Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("Deployed Model"); + System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_image_object_detection_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineSample.java new file mode 100644 index 00000000000..33f94753e54 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineSample.java @@ -0,0 +1,210 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_sample] + +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String trainingTaskDefinition = "YOUR_TRAINING_TASK_DEFINITION"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + createTrainingPipelineSample( + project, trainingPipelineDisplayName, datasetId, trainingTaskDefinition, modelDisplayName); + } + + static void createTrainingPipelineSample( + String project, + String trainingPipelineDisplayName, + String datasetId, + String trainingTaskDefinition, + String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000," + + " \"disableEarlyStopping\": false}"; + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Response"); + System.out.format("Name: %s\n", trainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data Config"); + System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("Fraction Split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("Predict Schemata"); + System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("Supported Export Format"); + System.out.format("Id: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("Container Spec"); + System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("Command: %s\n", modelContainerSpec.getCommandList()); + System.out.format("Args: %s\n", modelContainerSpec.getArgsList()); + System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("Env"); + System.out.format("Name: %s\n", envVar.getName()); + System.out.format("Value: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("Port"); + System.out.format("Container Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("Deployed Model"); + System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java new file mode 100644 index 00000000000..107e8c01a4c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java @@ -0,0 +1,249 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_tabular_classification_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.ArrayList; + +public class CreateTrainingPipelineTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String targetColumn = "TARGET_COLUMN"; + createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn); + } + + static void createTrainingPipelineTableClassification( + String project, String modelDisplayName, String datasetId, String targetColumn) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; + + // Set the columns used for training and their data types + Transformation transformation1 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build()) + .build(); + Transformation transformation2 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build()) + .build(); + Transformation transformation3 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build()) + .build(); + Transformation transformation4 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build()) + .build(); + + ArrayList transformationArrayList = new ArrayList<>(); + transformationArrayList.add(transformation1); + transformationArrayList.add(transformation2); + transformationArrayList.add(transformation3); + transformationArrayList.add(transformation4); + + AutoMlTablesInputs autoMlTablesInputs = + AutoMlTablesInputs.newBuilder() + .setTargetColumn(targetColumn) + .setPredictionType("classification") + .addAllTransformations(transformationArrayList) + .setTrainBudgetMilliNodeHours(8000) + .build(); + + FractionSplit fractionSplit = + FractionSplit.newBuilder() + .setTrainingFraction(0.8) + .setValidationFraction(0.1) + .setTestFraction(0.1) + .build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setFractionSplit(fractionSplit) + .build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(modelDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs)) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Tabular Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format( + "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); + System.out.format( + "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_tabular_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java new file mode 100644 index 00000000000..427dae0c0cd --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java @@ -0,0 +1,321 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_tabular_regression_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.ArrayList; + +public class CreateTrainingPipelineTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String targetColumn = "TARGET_COLUMN"; + createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn); + } + + static void createTrainingPipelineTableRegression( + String project, String modelDisplayName, String datasetId, String targetColumn) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; + + // Set the columns used for training and their data types + ArrayList tranformations = new ArrayList<>(); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setTimestamp( + TimestampTransformation.newBuilder() + .setColumnName("TIMESTAMP_1unique_NULLABLE") + .setInvalidValuesAllowed(true)) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setTimestamp( + TimestampTransformation.newBuilder() + .setColumnName("DATETIME_1unique_NULLABLE") + .setInvalidValuesAllowed(true)) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE")) + .build()); + + AutoMlTablesInputs trainingTaskInputs = + AutoMlTablesInputs.newBuilder() + .addAllTransformations(tranformations) + .setTargetColumn(targetColumn) + .setPredictionType("regression") + .setTrainBudgetMilliNodeHours(8000) + .setDisableEarlyStopping(false) + // supported regression optimisation objectives: minimize-rmse, + // minimize-mae, minimize-rmsle + .setOptimizationObjective("minimize-rmse") + .build(); + + FractionSplit fractionSplit = + FractionSplit.newBuilder() + .setTrainingFraction(0.8) + .setValidationFraction(0.1) + .setTestFraction(0.1) + .build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setFractionSplit(fractionSplit) + .build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(modelDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Tabular Regression Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format( + "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); + System.out.format( + "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_tabular_regression_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java new file mode 100644 index 00000000000..ac338beb37c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java @@ -0,0 +1,209 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_text_classification_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTextClassificationInputs; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineTextClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + + createTrainingPipelineTextClassificationSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineTextClassificationSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_text_classification_1.0.0.yaml"; + + LocationName locationName = LocationName.of(project, location); + + AutoMlTextClassificationInputs trainingTaskInputs = + AutoMlTextClassificationInputs.newBuilder().setMultiLabel(false).build(); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Text Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\t\tPredict Schemata"); + System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\t\tSupported Export Format"); + System.out.format("\t\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("\t\tContainer Spec"); + System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList()); + System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList()); + System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("\t\t\tEnv"); + System.out.format("\t\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("\t\t\tPort"); + System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\t\tDeployed Model"); + System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_text_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java new file mode 100644 index 00000000000..63dc1348461 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java @@ -0,0 +1,205 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_text_entity_extraction_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + + createTrainingPipelineTextEntityExtractionSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineTextEntityExtractionSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_text_extraction_1.0.0.yaml"; + + LocationName locationName = LocationName.of(project, location); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.EMPTY_VALUE) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Text Entity Extraction Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\t\tPredict Schemata"); + System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\t\tSupported Export Format"); + System.out.format("\t\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("\t\tContainer Spec"); + System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList()); + System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList()); + System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("\t\t\tEnv"); + System.out.format("\t\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("\t\t\tPort"); + System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\t\tDeployed Model"); + System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_text_entity_extraction_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..ef87a9bfd2a --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java @@ -0,0 +1,213 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_text_sentiment_analysis_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTextSentimentInputs; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + + createTrainingPipelineTextSentimentAnalysisSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineTextSentimentAnalysisSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_text_sentiment_1.0.0.yaml"; + + LocationName locationName = LocationName.of(project, location); + + AutoMlTextSentimentInputs trainingTaskInputs = + AutoMlTextSentimentInputs.newBuilder() + // Sentiment max must be between 1 and 10 inclusive. + // Higher value means positive sentiment. + .setSentimentMax(4) + .build(); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Text Sentiment Analysis Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\t\tPredict Schemata"); + System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\t\tSupported Export Format"); + System.out.format("\t\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("\t\tContainer Spec"); + System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList()); + System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList()); + System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("\t\t\tEnv"); + System.out.format("\t\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("\t\t\tPort"); + System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\t\tDeployed Model"); + System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_text_sentiment_analysis_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java new file mode 100644 index 00000000000..02e15fb5dac --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_video_action_recognition_sample] +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs.ModelType; +import java.io.IOException; + +public class CreateTrainingPipelineVideoActionRecognitionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String displayName = "DISPLAY_NAME"; + String datasetId = "DATASET_ID"; + String modelDisplayName = "MODEL_DISPLAY_NAME"; + createTrainingPipelineVideoActionRecognitionSample( + project, displayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineVideoActionRecognitionSample( + String project, String displayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings settings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { + AutoMlVideoActionRecognitionInputs trainingTaskInputs = + AutoMlVideoActionRecognitionInputs.newBuilder().setModelType(ModelType.CLOUD).build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(displayName) + .setTrainingTaskDefinition( + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_video_action_recognition_1.0.0.yaml") + .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + LocationName parent = LocationName.of(project, location); + TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline); + System.out.format("response: %s\n", response); + System.out.format("Name: %s\n", response.getName()); + } + } +} + +// [END aiplatform_create_training_pipeline_video_action_recognition_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java new file mode 100644 index 00000000000..403476b24b9 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java @@ -0,0 +1,160 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_video_classification_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineVideoClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String videoClassificationDisplayName = + "YOUR_TRAINING_PIPELINE_VIDEO_CLASSIFICATION_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + createTrainingPipelineVideoClassification( + videoClassificationDisplayName, datasetId, modelDisplayName, project); + } + + static void createTrainingPipelineVideoClassification( + String videoClassificationDisplayName, + String datasetId, + String modelDisplayName, + String project) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_video_classification_1.0.0.yaml"; + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(videoClassificationDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.EMPTY_VALUE) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Video Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_video_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java new file mode 100644 index 00000000000..3bd30b4b9d5 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java @@ -0,0 +1,172 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_video_object_tracking_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs; +import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs.ModelType; +import com.google.rpc.Status; +import java.io.IOException; + +public class CreateTrainingPipelineVideoObjectTrackingSample { + + public static void main(String[] args) throws IOException { + String trainingPipelineVideoObjectTracking = + "YOUR_TRAINING_PIPELINE_VIDEO_OBJECT_TRACKING_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + createTrainingPipelineVideoObjectTracking( + trainingPipelineVideoObjectTracking, datasetId, modelDisplayName, project); + } + + static void createTrainingPipelineVideoObjectTracking( + String trainingPipelineVideoObjectTracking, + String datasetId, + String modelDisplayName, + String project) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_video_object_tracking_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + AutoMlVideoObjectTrackingInputs trainingTaskInputs = + AutoMlVideoObjectTrackingInputs.newBuilder().setModelType(ModelType.CLOUD).build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineVideoObjectTracking) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs)) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline createTrainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Video Object Tracking Response"); + System.out.format("Name: %s\n", createTrainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", createTrainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", + createTrainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", + createTrainingPipelineResponse.getTrainingTaskInputs().toString()); + System.out.format( + "Training Task Metadata: %s\n", + createTrainingPipelineResponse.getTrainingTaskMetadata().toString()); + + System.out.format("State: %s\n", createTrainingPipelineResponse.getState().toString()); + System.out.format( + "Create Time: %s\n", createTrainingPipelineResponse.getCreateTime().toString()); + System.out.format("StartTime %s\n", createTrainingPipelineResponse.getStartTime().toString()); + System.out.format("End Time: %s\n", createTrainingPipelineResponse.getEndTime().toString()); + System.out.format( + "Update Time: %s\n", createTrainingPipelineResponse.getUpdateTime().toString()); + System.out.format("Labels: %s\n", createTrainingPipelineResponse.getLabelsMap().toString()); + + InputDataConfig inputDataConfigResponse = createTrainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data config"); + System.out.format("Dataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit(); + System.out.println("Fraction split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = createTrainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %s\n", modelResponse.getLabelsMap()); + + Status status = createTrainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_video_object_tracking_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java b/aiplatform/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java new file mode 100644 index 00000000000..e0675190da6 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_batch_prediction_job_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.BatchPredictionJobName; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteBatchPredictionJobSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID"; + deleteBatchPredictionJobSample(project, batchPredictionJobId); + } + + static void deleteBatchPredictionJobSample(String project, String batchPredictionJobId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + + BatchPredictionJobName batchPredictionJobName = + BatchPredictionJobName.of(project, location, batchPredictionJobId); + + OperationFuture operationFuture = + jobServiceClient.deleteBatchPredictionJobAsync(batchPredictionJobName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Deleted Batch Prediction Job."); + } + } +} +// [END aiplatform_delete_batch_prediction_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteDataLabelingJobSample.java b/aiplatform/src/main/java/aiplatform/DeleteDataLabelingJobSample.java new file mode 100644 index 00000000000..b8c6b969b4a --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteDataLabelingJobSample.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_data_labeling_job_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DataLabelingJobName; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteDataLabelingJobSample { + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String dataLabelingJobId = "YOUR_DATA_LABELING_JOB_ID"; + deleteDataLabelingJob(project, dataLabelingJobId); + } + + static void deleteDataLabelingJob(String project, String dataLabelingJobId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + + DataLabelingJobName dataLabelingJobName = + DataLabelingJobName.of(project, location, dataLabelingJobId); + + OperationFuture operationFuture = + jobServiceClient.deleteDataLabelingJobAsync(dataLabelingJobName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Deleted Data Labeling Job."); + } + } +} +// [END aiplatform_delete_data_labeling_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteDatasetSample.java b/aiplatform/src/main/java/aiplatform/DeleteDatasetSample.java new file mode 100644 index 00000000000..30af542d339 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteDatasetSample.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_dataset_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteDatasetSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + deleteDatasetSample(project, datasetId); + } + + static void deleteDatasetSample(String project, String datasetId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Deleted Dataset."); + } + } +} +// [END aiplatform_delete_dataset_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteEndpointSample.java b/aiplatform/src/main/java/aiplatform/DeleteEndpointSample.java new file mode 100644 index 00000000000..5767b588809 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteEndpointSample.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_endpoint_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.EndpointServiceClient; +import com.google.cloud.aiplatform.v1.EndpointServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteEndpointSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String endpointId = "YOUR_ENDPOINT_ID"; + deleteEndpointSample(project, endpointId); + } + + static void deleteEndpointSample(String project, String endpointId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + EndpointServiceSettings endpointServiceSettings = + EndpointServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (EndpointServiceClient endpointServiceClient = + EndpointServiceClient.create(endpointServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + OperationFuture operationFuture = + endpointServiceClient.deleteEndpointAsync(endpointName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Empty deleteResponse = operationFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Delete Endpoint Response: %s\n", deleteResponse); + } + } +} +// [END aiplatform_delete_endpoint_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteEntityTypeSample.java b/aiplatform/src/main/java/aiplatform/DeleteEntityTypeSample.java new file mode 100644 index 00000000000..00e7c5e36af --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteEntityTypeSample.java @@ -0,0 +1,87 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Delete an entity type from featurestore resource. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_delete_entity_type_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DeleteEntityTypeRequest; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteEntityTypeSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + deleteEntityTypeSample(project, featurestoreId, entityTypeId, location, endpoint, timeout); + } + + static void deleteEntityTypeSample( + String project, + String featurestoreId, + String entityTypeId, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + DeleteEntityTypeRequest deleteEntityTypeRequest = + DeleteEntityTypeRequest.newBuilder() + .setName( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setForce(true) + .build(); + + OperationFuture operationFuture = + featurestoreServiceClient.deleteEntityTypeAsync(deleteEntityTypeRequest); + System.out.format("Operation name: %s%n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(timeout, TimeUnit.SECONDS); + + System.out.format("Deleted Entity Type."); + } + } +} +// [END aiplatform_delete_entity_type_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteExportModelSample.java b/aiplatform/src/main/java/aiplatform/DeleteExportModelSample.java new file mode 100644 index 00000000000..d6ed1995714 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteExportModelSample.java @@ -0,0 +1,45 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_export_model_sample] + +import com.google.cloud.storage.Blob; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.StorageOptions; + +public class DeleteExportModelSample { + + public static void main(String[] args) { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String bucketName = "YOUR_BUCKET_NAME"; + String folderName = "YOUR_FOLDER_NAME"; + deleteExportModelSample(project, bucketName, folderName); + } + + static void deleteExportModelSample(String project, String bucketName, String folderName) { + Storage storage = StorageOptions.newBuilder().setProjectId(project).build().getService(); + Iterable blobs = + storage.list(bucketName, Storage.BlobListOption.prefix(folderName)).iterateAll(); + for (Blob blob : blobs) { + blob.delete(Blob.BlobSourceOption.generationMatch()); + } + System.out.println("Export Model Deleted"); + } +} +// [END aiplatform_delete_export_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteFeatureSample.java b/aiplatform/src/main/java/aiplatform/DeleteFeatureSample.java new file mode 100644 index 00000000000..bc77d5c804e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteFeatureSample.java @@ -0,0 +1,90 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Delete a single feature from an existing entity type. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_delete_feature_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DeleteFeatureRequest; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.FeatureName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteFeatureSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String featureId = "YOUR_FEATURE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + + deleteFeatureSample( + project, featurestoreId, entityTypeId, featureId, location, endpoint, timeout); + } + + static void deleteFeatureSample( + String project, + String featurestoreId, + String entityTypeId, + String featureId, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + DeleteFeatureRequest deleteFeatureRequest = + DeleteFeatureRequest.newBuilder() + .setName( + FeatureName.of(project, location, featurestoreId, entityTypeId, featureId) + .toString()) + .build(); + + OperationFuture operationFuture = + featurestoreServiceClient.deleteFeatureAsync(deleteFeatureRequest); + System.out.format("Operation name: %s%n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(timeout, TimeUnit.SECONDS); + System.out.format("Deleted Feature."); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_delete_feature_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteFeaturestoreSample.java b/aiplatform/src/main/java/aiplatform/DeleteFeaturestoreSample.java new file mode 100644 index 00000000000..eb69ad35020 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteFeaturestoreSample.java @@ -0,0 +1,86 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Delete a featurestore. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_delete_featurestore_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DeleteFeaturestoreRequest; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteFeaturestoreSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + boolean useForce = true; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 60; + deleteFeaturestoreSample(project, featurestoreId, useForce, location, endpoint, timeout); + } + + static void deleteFeaturestoreSample( + String project, + String featurestoreId, + boolean useForce, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + DeleteFeaturestoreRequest deleteFeaturestoreRequest = + DeleteFeaturestoreRequest.newBuilder() + .setName(FeaturestoreName.of(project, location, featurestoreId).toString()) + .setForce(useForce) + .build(); + + OperationFuture operationFuture = + featurestoreServiceClient.deleteFeaturestoreAsync(deleteFeaturestoreRequest); + System.out.format("Operation name: %s%n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(timeout, TimeUnit.SECONDS); + + System.out.format("Deleted Featurestore."); + } + } +} +// [END aiplatform_delete_featurestore_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteModelSample.java b/aiplatform/src/main/java/aiplatform/DeleteModelSample.java new file mode 100644 index 00000000000..f3ee72260c6 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteModelSample.java @@ -0,0 +1,63 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_model_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteModelSample { + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + deleteModel(project, modelId); + } + + static void deleteModel(String project, String modelId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelName modelName = ModelName.of(project, location, modelId); + OperationFuture operationFuture = + modelServiceClient.deleteModelAsync(modelName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + System.out.format("Deleted Model."); + } + } +} +// [END aiplatform_delete_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeleteTrainingPipelineSample.java b/aiplatform/src/main/java/aiplatform/DeleteTrainingPipelineSample.java new file mode 100644 index 00000000000..e6256c6b633 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeleteTrainingPipelineSample.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_training_pipeline_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipelineName; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteTrainingPipelineSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID"; + String project = "YOUR_PROJECT_ID"; + deleteTrainingPipelineSample(project, trainingPipelineId); + } + + static void deleteTrainingPipelineSample(String project, String trainingPipelineId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + TrainingPipelineName trainingPipelineName = + TrainingPipelineName.of(project, location, trainingPipelineId); + + OperationFuture operationFuture = + pipelineServiceClient.deleteTrainingPipelineAsync(trainingPipelineName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Deleted Training Pipeline."); + } + } +} +// [END aiplatform_delete_training_pipeline_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java b/aiplatform/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java new file mode 100644 index 00000000000..2548637635e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeployModelCustomTrainedModelSample.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_deploy_model_custom_trained_model_sample] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DedicatedResources; +import com.google.cloud.aiplatform.v1.DeployModelOperationMetadata; +import com.google.cloud.aiplatform.v1.DeployModelResponse; +import com.google.cloud.aiplatform.v1.DeployedModel; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.EndpointServiceClient; +import com.google.cloud.aiplatform.v1.EndpointServiceSettings; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +public class DeployModelCustomTrainedModelSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String endpointId = "ENDPOINT_ID"; + String modelName = "MODEL_NAME"; + String deployedModelDisplayName = "DEPLOYED_MODEL_DISPLAY_NAME"; + deployModelCustomTrainedModelSample(project, endpointId, modelName, deployedModelDisplayName); + } + + static void deployModelCustomTrainedModelSample( + String project, String endpointId, String model, String deployedModelDisplayName) + throws IOException, ExecutionException, InterruptedException { + EndpointServiceSettings settings = + EndpointServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (EndpointServiceClient client = EndpointServiceClient.create(settings)) { + MachineSpec machineSpec = MachineSpec.newBuilder().setMachineType("n1-standard-2").build(); + DedicatedResources dedicatedResources = + DedicatedResources.newBuilder().setMinReplicaCount(1).setMachineSpec(machineSpec).build(); + + String modelName = ModelName.of(project, location, model).toString(); + DeployedModel deployedModel = + DeployedModel.newBuilder() + .setModel(modelName) + .setDisplayName(deployedModelDisplayName) + // `dedicated_resources` must be used for non-AutoML models + .setDedicatedResources(dedicatedResources) + .build(); + // key '0' assigns traffic for the newly deployed model + // Traffic percentage values must add up to 100 + // Leave dictionary empty if endpoint should not accept any traffic + Map trafficSplit = new HashMap<>(); + trafficSplit.put("0", 100); + EndpointName endpoint = EndpointName.of(project, location, endpointId); + OperationFuture response = + client.deployModelAsync(endpoint, deployedModel, trafficSplit); + + // You can use OperationFuture.getInitialFuture to get a future representing the initial + // response to the request, which contains information while the operation is in progress. + System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName()); + + // OperationFuture.get() will block until the operation is finished. + DeployModelResponse deployModelResponse = response.get(); + System.out.format("deployModelResponse: %s\n", deployModelResponse); + } + } +} + +// [END aiplatform_deploy_model_custom_trained_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/DeployModelSample.java b/aiplatform/src/main/java/aiplatform/DeployModelSample.java new file mode 100644 index 00000000000..f950afd9656 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/DeployModelSample.java @@ -0,0 +1,113 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_deploy_model_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.AutomaticResources; +import com.google.cloud.aiplatform.v1.DedicatedResources; +import com.google.cloud.aiplatform.v1.DeployModelOperationMetadata; +import com.google.cloud.aiplatform.v1.DeployModelResponse; +import com.google.cloud.aiplatform.v1.DeployedModel; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.EndpointServiceClient; +import com.google.cloud.aiplatform.v1.EndpointServiceSettings; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeployModelSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String deployedModelDisplayName = "YOUR_DEPLOYED_MODEL_DISPLAY_NAME"; + String endpointId = "YOUR_ENDPOINT_NAME"; + String modelId = "YOUR_MODEL_ID"; + deployModelSample(project, deployedModelDisplayName, endpointId, modelId); + } + + static void deployModelSample( + String project, String deployedModelDisplayName, String endpointId, String modelId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + EndpointServiceSettings endpointServiceSettings = + EndpointServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (EndpointServiceClient endpointServiceClient = + EndpointServiceClient.create(endpointServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + // key '0' assigns traffic for the newly deployed model + // Traffic percentage values must add up to 100 + // Leave dictionary empty if endpoint should not accept any traffic + Map trafficSplit = new HashMap<>(); + trafficSplit.put("0", 100); + ModelName modelName = ModelName.of(project, location, modelId); + AutomaticResources automaticResourcesInput = + AutomaticResources.newBuilder().setMinReplicaCount(1).setMaxReplicaCount(1).build(); + DeployedModel deployedModelInput = + DeployedModel.newBuilder() + .setModel(modelName.toString()) + .setDisplayName(deployedModelDisplayName) + .setAutomaticResources(automaticResourcesInput) + .build(); + + OperationFuture deployModelResponseFuture = + endpointServiceClient.deployModelAsync(endpointName, deployedModelInput, trafficSplit); + System.out.format( + "Operation name: %s\n", deployModelResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + DeployModelResponse deployModelResponse = deployModelResponseFuture.get(20, TimeUnit.MINUTES); + + System.out.println("Deploy Model Response"); + DeployedModel deployedModel = deployModelResponse.getDeployedModel(); + System.out.println("\tDeployed Model"); + System.out.format("\t\tid: %s\n", deployedModel.getId()); + System.out.format("\t\tmodel: %s\n", deployedModel.getModel()); + System.out.format("\t\tDisplay Name: %s\n", deployedModel.getDisplayName()); + System.out.format("\t\tCreate Time: %s\n", deployedModel.getCreateTime()); + + DedicatedResources dedicatedResources = deployedModel.getDedicatedResources(); + System.out.println("\t\tDedicated Resources"); + System.out.format("\t\t\tMin Replica Count: %s\n", dedicatedResources.getMinReplicaCount()); + + MachineSpec machineSpec = dedicatedResources.getMachineSpec(); + System.out.println("\t\t\tMachine Spec"); + System.out.format("\t\t\t\tMachine Type: %s\n", machineSpec.getMachineType()); + System.out.format("\t\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType()); + System.out.format("\t\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount()); + + AutomaticResources automaticResources = deployedModel.getAutomaticResources(); + System.out.println("\t\tAutomatic Resources"); + System.out.format("\t\t\tMin Replica Count: %s\n", automaticResources.getMinReplicaCount()); + System.out.format("\t\t\tMax Replica Count: %s\n", automaticResources.getMaxReplicaCount()); + } + } +} +// [END aiplatform_deploy_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/ExportFeatureValuesSample.java b/aiplatform/src/main/java/aiplatform/ExportFeatureValuesSample.java new file mode 100644 index 00000000000..6bb7b00d66e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ExportFeatureValuesSample.java @@ -0,0 +1,119 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Bulk export feature values from a featurestore. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_export_feature_values_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesOperationMetadata; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesRequest; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesRequest.FullExport; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesResponse; +import com.google.cloud.aiplatform.v1.FeatureSelector; +import com.google.cloud.aiplatform.v1.FeatureValueDestination; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.IdMatcher; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ExportFeatureValuesSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String destinationTableUri = "YOUR_DESTINATION_TABLE_URI"; + List featureSelectorIds = Arrays.asList("title", "genres", "average_rating"); + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + exportFeatureValuesSample( + project, + featurestoreId, + entityTypeId, + destinationTableUri, + featureSelectorIds, + location, + endpoint, + timeout); + } + + static void exportFeatureValuesSample( + String project, + String featurestoreId, + String entityTypeId, + String destinationTableUri, + List featureSelectorIds, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + FeatureSelector featureSelector = + FeatureSelector.newBuilder() + .setIdMatcher(IdMatcher.newBuilder().addAllIds(featureSelectorIds).build()) + .build(); + + ExportFeatureValuesRequest exportFeatureValuesRequest = + ExportFeatureValuesRequest.newBuilder() + .setEntityType( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setDestination( + FeatureValueDestination.newBuilder() + .setBigqueryDestination( + BigQueryDestination.newBuilder().setOutputUri(destinationTableUri))) + .setFeatureSelector(featureSelector) + .setFullExport(FullExport.newBuilder()) + .build(); + + OperationFuture + exportFeatureValuesFuture = + featurestoreServiceClient.exportFeatureValuesAsync(exportFeatureValuesRequest); + System.out.format( + "Operation name: %s%n", exportFeatureValuesFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ExportFeatureValuesResponse exportFeatureValuesResponse = + exportFeatureValuesFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Export Feature Values Response"); + System.out.println(exportFeatureValuesResponse); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_export_feature_values_sample] diff --git a/aiplatform/src/main/java/aiplatform/ExportFeatureValuesSnapshotSample.java b/aiplatform/src/main/java/aiplatform/ExportFeatureValuesSnapshotSample.java new file mode 100644 index 00000000000..6d48d34d06c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ExportFeatureValuesSnapshotSample.java @@ -0,0 +1,119 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Bulk export feature values from a featurestore. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_export_feature_values_snapshot_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesOperationMetadata; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesRequest; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesRequest.SnapshotExport; +import com.google.cloud.aiplatform.v1.ExportFeatureValuesResponse; +import com.google.cloud.aiplatform.v1.FeatureSelector; +import com.google.cloud.aiplatform.v1.FeatureValueDestination; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.IdMatcher; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ExportFeatureValuesSnapshotSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String destinationTableUri = "YOUR_DESTINATION_TABLE_URI"; + List featureSelectorIds = Arrays.asList("title", "genres", "average_rating"); + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + exportFeatureValuesSnapshotSample( + project, + featurestoreId, + entityTypeId, + destinationTableUri, + featureSelectorIds, + location, + endpoint, + timeout); + } + + static void exportFeatureValuesSnapshotSample( + String project, + String featurestoreId, + String entityTypeId, + String destinationTableUri, + List featureSelectorIds, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + FeatureSelector featureSelector = + FeatureSelector.newBuilder() + .setIdMatcher(IdMatcher.newBuilder().addAllIds(featureSelectorIds).build()) + .build(); + + ExportFeatureValuesRequest exportFeatureValuesRequest = + ExportFeatureValuesRequest.newBuilder() + .setEntityType( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setDestination( + FeatureValueDestination.newBuilder() + .setBigqueryDestination( + BigQueryDestination.newBuilder().setOutputUri(destinationTableUri))) + .setFeatureSelector(featureSelector) + .setSnapshotExport(SnapshotExport.newBuilder()) + .build(); + + OperationFuture + exportFeatureValuesFuture = + featurestoreServiceClient.exportFeatureValuesAsync(exportFeatureValuesRequest); + System.out.format( + "Operation name: %s%n", exportFeatureValuesFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ExportFeatureValuesResponse exportFeatureValuesResponse = + exportFeatureValuesFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Snapshot Export Feature Values Response"); + System.out.println(exportFeatureValuesResponse); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_export_feature_values_snapshot_sample] diff --git a/aiplatform/src/main/java/aiplatform/ExportModelSample.java b/aiplatform/src/main/java/aiplatform/ExportModelSample.java new file mode 100644 index 00000000000..1979c7ce116 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ExportModelSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_export_model_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.ExportModelOperationMetadata; +import com.google.cloud.aiplatform.v1.ExportModelRequest; +import com.google.cloud.aiplatform.v1.ExportModelResponse; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ExportModelSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String gcsDestinationOutputUriPrefix = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_destination/"; + String exportFormat = "YOUR_EXPORT_FORMAT"; + exportModelSample(project, modelId, gcsDestinationOutputUriPrefix, exportFormat); + } + + static void exportModelSample( + String project, String modelId, String gcsDestinationOutputUriPrefix, String exportFormat) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + GcsDestination.Builder gcsDestination = GcsDestination.newBuilder(); + gcsDestination.setOutputUriPrefix(gcsDestinationOutputUriPrefix); + + ModelName modelName = ModelName.of(project, location, modelId); + ExportModelRequest.OutputConfig outputConfig = + ExportModelRequest.OutputConfig.newBuilder() + .setExportFormatId(exportFormat) + .setArtifactDestination(gcsDestination) + .build(); + + OperationFuture exportModelResponseFuture = + modelServiceClient.exportModelAsync(modelName, outputConfig); + System.out.format( + "Operation name: %s\n", exportModelResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ExportModelResponse exportModelResponse = + exportModelResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Export Model Response: %s\n", exportModelResponse); + } + } +} +// [END aiplatform_export_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/ExportModelTabularClassificationSample.java b/aiplatform/src/main/java/aiplatform/ExportModelTabularClassificationSample.java new file mode 100644 index 00000000000..9a722790eb6 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ExportModelTabularClassificationSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_export_model_tabular_classification_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.ExportModelOperationMetadata; +import com.google.cloud.aiplatform.v1.ExportModelRequest; +import com.google.cloud.aiplatform.v1.ExportModelResponse; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ExportModelTabularClassificationSample { + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String gcsDestinationOutputUriPrefix = "gs://your-gcs-bucket/destination_path"; + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + exportModelTableClassification(gcsDestinationOutputUriPrefix, project, modelId); + } + + static void exportModelTableClassification( + String gcsDestinationOutputUriPrefix, String project, String modelId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelName modelName = ModelName.of(project, location, modelId); + + GcsDestination.Builder gcsDestination = GcsDestination.newBuilder(); + gcsDestination.setOutputUriPrefix(gcsDestinationOutputUriPrefix); + ExportModelRequest.OutputConfig outputConfig = + ExportModelRequest.OutputConfig.newBuilder() + .setExportFormatId("tf-saved-model") + .setArtifactDestination(gcsDestination) + .build(); + + OperationFuture exportModelResponseFuture = + modelServiceClient.exportModelAsync(modelName, outputConfig); + System.out.format( + "Operation name: %s\n", exportModelResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ExportModelResponse exportModelResponse = + exportModelResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format( + "Export Model Tabular Classification Response: %s", exportModelResponse.toString()); + } + } +} +// [END aiplatform_export_model_tabular_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java b/aiplatform/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java new file mode 100644 index 00000000000..54e590085cb --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ExportModelVideoActionRecognitionSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_export_model_video_action_recognition_sample] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.ExportModelOperationMetadata; +import com.google.cloud.aiplatform.v1.ExportModelRequest; +import com.google.cloud.aiplatform.v1.ExportModelResponse; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class ExportModelVideoActionRecognitionSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String modelId = "MODEL_ID"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + String exportFormat = "EXPORT_FORMAT"; + exportModelVideoActionRecognitionSample( + project, modelId, gcsDestinationOutputUriPrefix, exportFormat); + } + + static void exportModelVideoActionRecognitionSample( + String project, String modelId, String gcsDestinationOutputUriPrefix, String exportFormat) + throws IOException, ExecutionException, InterruptedException { + ModelServiceSettings settings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient client = ModelServiceClient.create(settings)) { + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + ExportModelRequest.OutputConfig outputConfig = + ExportModelRequest.OutputConfig.newBuilder() + .setArtifactDestination(gcsDestination) + .setExportFormatId(exportFormat) + .build(); + ModelName name = ModelName.of(project, location, modelId); + OperationFuture response = + client.exportModelAsync(name, outputConfig); + + // You can use OperationFuture.getInitialFuture to get a future representing the initial + // response to the request, which contains information while the operation is in progress. + System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName()); + + // OperationFuture.get() will block until the operation is finished. + ExportModelResponse exportModelResponse = response.get(); + System.out.format("exportModelResponse: %s\n", exportModelResponse); + } + } +} + +// [END aiplatform_export_model_video_action_recognition_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetBatchPredictionJobSample.java b/aiplatform/src/main/java/aiplatform/GetBatchPredictionJobSample.java new file mode 100644 index 00000000000..4e4ba6b3ebe --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetBatchPredictionJobSample.java @@ -0,0 +1,135 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_batch_prediction_job_sample] + +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1.BatchPredictionJobName; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.CompletionStats; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.ResourcesConsumed; +import com.google.protobuf.Any; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class GetBatchPredictionJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID"; + getBatchPredictionJobSample(project, batchPredictionJobId); + } + + static void getBatchPredictionJobSample(String project, String batchPredictionJobId) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + BatchPredictionJobName batchPredictionJobName = + BatchPredictionJobName.of(project, location, batchPredictionJobId); + + BatchPredictionJob batchPredictionJob = + jobServiceClient.getBatchPredictionJob(batchPredictionJobName); + + System.out.println("Get Batch Prediction Job Response"); + System.out.format("\tName: %s\n", batchPredictionJob.getName()); + System.out.format("\tDisplay Name: %s\n", batchPredictionJob.getDisplayName()); + System.out.format("\tModel: %s\n", batchPredictionJob.getModel()); + + System.out.format("\tModel Parameters: %s\n", batchPredictionJob.getModelParameters()); + System.out.format("\tState: %s\n", batchPredictionJob.getState()); + + System.out.format("\tCreate Time: %s\n", batchPredictionJob.getCreateTime()); + System.out.format("\tStart Time: %s\n", batchPredictionJob.getStartTime()); + System.out.format("\tEnd Time: %s\n", batchPredictionJob.getEndTime()); + System.out.format("\tUpdate Time: %s\n", batchPredictionJob.getUpdateTime()); + System.out.format("\tLabels: %s\n", batchPredictionJob.getLabelsMap()); + + InputConfig inputConfig = batchPredictionJob.getInputConfig(); + System.out.println("\tInput Config"); + System.out.format("\t\tInstances Format: %s\n", inputConfig.getInstancesFormat()); + + GcsSource gcsSource = inputConfig.getGcsSource(); + System.out.println("\t\tGcs Source"); + System.out.format("\t\t\tUris: %s\n", gcsSource.getUrisList()); + + BigQuerySource bigquerySource = inputConfig.getBigquerySource(); + System.out.println("\t\tBigquery Source"); + System.out.format("\t\t\tInput Uri: %s\n", bigquerySource.getInputUri()); + + OutputConfig outputConfig = batchPredictionJob.getOutputConfig(); + System.out.println("\tOutput Config"); + System.out.format("\t\tPredictions Format: %s\n", outputConfig.getPredictionsFormat()); + + GcsDestination gcsDestination = outputConfig.getGcsDestination(); + System.out.println("\t\tGcs Destination"); + System.out.format("\t\t\tOutput Uri Prefix: %s\n", gcsDestination.getOutputUriPrefix()); + + BigQueryDestination bigqueryDestination = outputConfig.getBigqueryDestination(); + System.out.println("\t\tBigquery Destination"); + System.out.format("\t\t\tOutput Uri: %s\n", bigqueryDestination.getOutputUri()); + + OutputInfo outputInfo = batchPredictionJob.getOutputInfo(); + System.out.println("\tOutput Info"); + System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory()); + System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset()); + + Status status = batchPredictionJob.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + + List detailsList = status.getDetailsList(); + + for (Status partialFailure : batchPredictionJob.getPartialFailuresList()) { + System.out.println("\tPartial Failure"); + System.out.format("\t\tCode: %s\n", partialFailure.getCode()); + System.out.format("\t\tMessage: %s\n", partialFailure.getMessage()); + List details = partialFailure.getDetailsList(); + } + + ResourcesConsumed resourcesConsumed = batchPredictionJob.getResourcesConsumed(); + System.out.println("\tResources Consumed"); + System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours()); + + CompletionStats completionStats = batchPredictionJob.getCompletionStats(); + System.out.println("\tCompletion Stats"); + System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount()); + System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount()); + System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount()); + } + } +} +// [END aiplatform_get_batch_prediction_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetEntityTypeSample.java b/aiplatform/src/main/java/aiplatform/GetEntityTypeSample.java new file mode 100644 index 00000000000..f9e83f223ba --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetEntityTypeSample.java @@ -0,0 +1,70 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Get entity type details. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_get_entity_type_sample] + +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.GetEntityTypeRequest; +import java.io.IOException; + +public class GetEntityTypeSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + getEntityTypeSample(project, featurestoreId, entityTypeId, location, endpoint); + } + + static void getEntityTypeSample( + String project, String featurestoreId, String entityTypeId, String location, String endpoint) + throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + GetEntityTypeRequest getEntityTypeRequest = + GetEntityTypeRequest.newBuilder() + .setName( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .build(); + + EntityType entityType = featurestoreServiceClient.getEntityType(getEntityTypeRequest); + System.out.println("Get Entity Type Response"); + System.out.println(entityType); + } + } +} +// [END aiplatform_get_entity_type_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetFeatureSample.java b/aiplatform/src/main/java/aiplatform/GetFeatureSample.java new file mode 100644 index 00000000000..f7e38adf1a9 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetFeatureSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Get feature details. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_get_feature_sample] + +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.FeatureName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.GetFeatureRequest; +import java.io.IOException; + +public class GetFeatureSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String featureId = "YOUR_FEATURE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + + getFeatureSample(project, featurestoreId, entityTypeId, featureId, location, endpoint); + } + + static void getFeatureSample( + String project, + String featurestoreId, + String entityTypeId, + String featureId, + String location, + String endpoint) + throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + GetFeatureRequest getFeatureRequest = + GetFeatureRequest.newBuilder() + .setName( + FeatureName.of(project, location, featurestoreId, entityTypeId, featureId) + .toString()) + .build(); + + Feature feature = featurestoreServiceClient.getFeature(getFeatureRequest); + System.out.println("Get Feature Response"); + System.out.println(feature); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_get_feature_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetFeaturestoreSample.java b/aiplatform/src/main/java/aiplatform/GetFeaturestoreSample.java new file mode 100644 index 00000000000..1d8c4c77c98 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetFeaturestoreSample.java @@ -0,0 +1,67 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Gets details of a single featurestore. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_get_featurestore_sample] + +import com.google.cloud.aiplatform.v1beta1.Featurestore; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreName; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GetFeaturestoreRequest; +import java.io.IOException; + +public class GetFeaturestoreSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + getFeaturestoreSample(project, featurestoreId, location, endpoint); + } + + static void getFeaturestoreSample( + String project, String featurestoreId, String location, String endpoint) throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + GetFeaturestoreRequest getFeaturestoreRequest = + GetFeaturestoreRequest.newBuilder() + .setName(FeaturestoreName.of(project, location, featurestoreId).toString()) + .build(); + + Featurestore featurestore = featurestoreServiceClient.getFeaturestore(getFeaturestoreRequest); + System.out.println("Get Featurestore Response"); + System.out.println(featurestore); + } + } +} +// [END aiplatform_get_featurestore_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java b/aiplatform/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java new file mode 100644 index 00000000000..f886bc3325b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetHyperparameterTuningJobSample.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_hyperparameter_tuning_job_sample] +import com.google.cloud.aiplatform.v1.HyperparameterTuningJob; +import com.google.cloud.aiplatform.v1.HyperparameterTuningJobName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import java.io.IOException; + +public class GetHyperparameterTuningJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String hyperparameterTuningJobId = "HYPERPARAMETER_TUNING_JOB_ID"; + getHyperparameterTuningJobSample(project, hyperparameterTuningJobId); + } + + static void getHyperparameterTuningJobSample(String project, String hyperparameterTuningJobId) + throws IOException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + HyperparameterTuningJobName name = + HyperparameterTuningJobName.of(project, location, hyperparameterTuningJobId); + HyperparameterTuningJob response = client.getHyperparameterTuningJob(name); + System.out.format("response: %s\n", response); + } + } +} + +// [END aiplatform_get_hyperparameter_tuning_job_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java new file mode 100644 index 00000000000..abcc2ec9f58 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_image_classification_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationImageClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationImageClassificationSample(project, modelId, evaluationId); + } + + static void getModelEvaluationImageClassificationSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Image Classification Response"); + System.out.format("Model Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_image_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java new file mode 100644 index 00000000000..fc85324116f --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_image_object_detection_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationImageObjectDetectionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationImageObjectDetectionSample(project, modelId, evaluationId); + } + + static void getModelEvaluationImageObjectDetectionSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Image Object Detection Response"); + System.out.format("\tName: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_image_object_detection_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationSample.java new file mode 100644 index 00000000000..4944dda1c1d --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationSample.java @@ -0,0 +1,63 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationSample(project, modelId, evaluationId); + } + + static void getModelEvaluationSample(String project, String modelId, String evaluationId) + throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Response"); + System.out.format("Model Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationSliceSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationSliceSample.java new file mode 100644 index 00000000000..1de771c185f --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationSliceSample.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_slice_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluationSlice; +import com.google.cloud.aiplatform.v1.ModelEvaluationSlice.Slice; +import com.google.cloud.aiplatform.v1.ModelEvaluationSliceName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationSliceSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + String sliceId = "YOUR_SLICE_ID"; + getModelEvaluationSliceSample(project, modelId, evaluationId, sliceId); + } + + static void getModelEvaluationSliceSample( + String project, String modelId, String evaluationId, String sliceId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationSliceName modelEvaluationSliceName = + ModelEvaluationSliceName.of(project, location, modelId, evaluationId, sliceId); + + ModelEvaluationSlice modelEvaluationSlice = + modelServiceClient.getModelEvaluationSlice(modelEvaluationSliceName); + + System.out.println("Get Model Evaluation Slice Response"); + System.out.format("Model Evaluation Slice Name: %s\n", modelEvaluationSlice.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluationSlice.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluationSlice.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluationSlice.getCreateTime()); + + Slice slice = modelEvaluationSlice.getSlice(); + System.out.format("Slice Dimensions: %s\n", slice.getDimension()); + System.out.format("Slice Value: %s\n", slice.getValue()); + } + } +} +// [END aiplatform_get_model_evaluation_slice_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java new file mode 100644 index 00000000000..dc38eaede76 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_tabular_classification_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationTabularClassification(project, modelId, evaluationId); + } + + static void getModelEvaluationTabularClassification( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Tabular Classification Response"); + System.out.format("\tName: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_tabular_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java new file mode 100644 index 00000000000..908f9a47859 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_tabular_regression_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationTabularRegression(project, modelId, evaluationId); + } + + static void getModelEvaluationTabularRegression( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Tabular Regression Response"); + System.out.format("\tName: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_tabular_regression_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java new file mode 100644 index 00000000000..912f4c6766b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_text_classification_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTextClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + + getModelEvaluationTextClassificationSample(project, modelId, evaluationId); + } + + static void getModelEvaluationTextClassificationSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Text Classification Response"); + System.out.format("\tModel Name: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_text_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java new file mode 100644 index 00000000000..ac9164b9267 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_text_entity_extraction_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + + getModelEvaluationTextEntityExtractionSample(project, modelId, evaluationId); + } + + static void getModelEvaluationTextEntityExtractionSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Text Entity Extraction Response"); + System.out.format("\tModel Name: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_text_entity_extraction_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..81d686e2186 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_text_sentiment_analysis_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + + getModelEvaluationTextSentimentAnalysisSample(project, modelId, evaluationId); + } + + static void getModelEvaluationTextSentimentAnalysisSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Text Sentiment Analysis Response"); + System.out.format("\tModel Name: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_text_sentiment_analysis_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java new file mode 100644 index 00000000000..01748a85ea7 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoActionRecognitionSample.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_video_action_recognition_sample] +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationVideoActionRecognitionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "PROJECT"; + String modelId = "MODEL_ID"; + String evaluationId = "EVALUATION_ID"; + getModelEvaluationVideoActionRecognitionSample(project, modelId, evaluationId); + } + + static void getModelEvaluationVideoActionRecognitionSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings settings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient client = ModelServiceClient.create(settings)) { + ModelEvaluationName name = ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation response = client.getModelEvaluation(name); + System.out.format("response: %s\n", response); + } + } +} + +// [END aiplatform_get_model_evaluation_video_action_recognition_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java new file mode 100644 index 00000000000..4e4babc5e6f --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_video_classification_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationVideoClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationVideoClassification(project, modelId, evaluationId); + } + + static void getModelEvaluationVideoClassification( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Video Classification Response"); + System.out.format("Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_video_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java b/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java new file mode 100644 index 00000000000..a095c9a262e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_object_tracking_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluation; +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationVideoObjectTrackingSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationVideoObjectTracking(project, modelId, evaluationId); + } + + static void getModelEvaluationVideoObjectTracking( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Video Object Tracking Response"); + System.out.format("Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_object_tracking_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetModelSample.java b/aiplatform/src/main/java/aiplatform/GetModelSample.java new file mode 100644 index 00000000000..5222db4b86b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetModelSample.java @@ -0,0 +1,120 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_sample] + +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import java.io.IOException; + +public class GetModelSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + getModelSample(project, modelId); + } + + static void getModelSample(String project, String modelId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelName modelName = ModelName.of(project, location, modelId); + + Model modelResponse = modelServiceClient.getModel(modelName); + System.out.println("Get Model response"); + System.out.format("\tName: %s\n", modelResponse.getName()); + System.out.format("\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + } + } +} +// [END aiplatform_get_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/GetTrainingPipelineSample.java b/aiplatform/src/main/java/aiplatform/GetTrainingPipelineSample.java new file mode 100644 index 00000000000..11850291b5f --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/GetTrainingPipelineSample.java @@ -0,0 +1,177 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_training_pipeline_sample] + +import com.google.cloud.aiplatform.v1.DeployedModelRef; +import com.google.cloud.aiplatform.v1.EnvVar; +import com.google.cloud.aiplatform.v1.FilterSplit; +import com.google.cloud.aiplatform.v1.FractionSplit; +import com.google.cloud.aiplatform.v1.InputDataConfig; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.Port; +import com.google.cloud.aiplatform.v1.PredefinedSplit; +import com.google.cloud.aiplatform.v1.PredictSchemata; +import com.google.cloud.aiplatform.v1.TimestampSplit; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.TrainingPipelineName; +import com.google.rpc.Status; +import java.io.IOException; + +public class GetTrainingPipelineSample { + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID"; + getTrainingPipeline(project, trainingPipelineId); + } + + static void getTrainingPipeline(String project, String trainingPipelineId) throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + TrainingPipelineName trainingPipelineName = + TrainingPipelineName.of(project, location, trainingPipelineId); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.getTrainingPipeline(trainingPipelineName); + + System.out.println("Get Training Pipeline Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + + System.out.println("\t\tModel to upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_get_training_pipeline_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataImageClassificationSample.java b/aiplatform/src/main/java/aiplatform/ImportDataImageClassificationSample.java new file mode 100644 index 00000000000..f3c4e3ed03d --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataImageClassificationSample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_image_classification_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataImageClassificationSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_image_source/[file.csv/file.jsonl]"; + importDataImageClassificationSample(project, datasetId, gcsSourceUri); + } + + static void importDataImageClassificationSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "image_classification_single_label_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format( + "Import Data Image Classification Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_image_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java b/aiplatform/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java new file mode 100644 index 00000000000..78f7551945f --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_image_object_detection_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataImageObjectDetectionSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_image_source/[file.csv/file.jsonl]"; + importDataImageObjectDetectionSample(project, datasetId, gcsSourceUri); + } + + static void importDataImageObjectDetectionSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "image_bounding_box_io_format_1.0.0.yaml"; + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format( + "Import Data Image Object Detection Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_image_object_detection_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java b/aiplatform/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java new file mode 100644 index 00000000000..696fdeb5842 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_text_classification_single_label_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataTextClassificationSingleLabelSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.csv/file.jsonl]"; + + importDataTextClassificationSingleLabelSample(project, datasetId, gcsSourceUri); + } + + static void importDataTextClassificationSingleLabelSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "text_classification_single_label_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format( + "Import Data Text Classification Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_text_classification_single_label_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java b/aiplatform/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java new file mode 100644 index 00000000000..2a8ee01a886 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_text_entity_extraction_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataTextEntityExtractionSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.jsonl]"; + + importDataTextEntityExtractionSample(project, datasetId, gcsSourceUri); + } + + static void importDataTextEntityExtractionSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "text_extraction_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format( + "Import Data Text Entity Extraction Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_text_entity_extraction_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java b/aiplatform/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..064fb6eb207 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_text_sentiment_analysis_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataTextSentimentAnalysisSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.csv/file.jsonl]"; + + importDataTextSentimentAnalysisSample(project, datasetId, gcsSourceUri); + } + + static void importDataTextSentimentAnalysisSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "text_sentiment_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format( + "Import Data Text Sentiment Analysis Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_text_sentiment_analysis_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java b/aiplatform/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java new file mode 100644 index 00000000000..7bede6dfc4c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataVideoActionRecognitionSample.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_video_action_recognition_sample] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; + +public class ImportDataVideoActionRecognitionSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String datasetId = "DATASET_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + importDataVideoActionRecognitionSample(project, datasetId, gcsSourceUri); + } + + static void importDataVideoActionRecognitionSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, ExecutionException, InterruptedException { + DatasetServiceSettings settings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + String location = "us-central1"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient client = DatasetServiceClient.create(settings)) { + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + ImportDataConfig importConfig0 = + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri( + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "video_action_recognition_io_format_1.0.0.yaml") + .build(); + List importConfigs = new ArrayList<>(); + importConfigs.add(importConfig0); + DatasetName name = DatasetName.of(project, location, datasetId); + OperationFuture response = + client.importDataAsync(name, importConfigs); + + // You can use OperationFuture.getInitialFuture to get a future representing the initial + // response to the request, which contains information while the operation is in progress. + System.out.format("Operation name: %s\n", response.getInitialFuture().get().getName()); + + // OperationFuture.get() will block until the operation is finished. + ImportDataResponse importDataResponse = response.get(); + System.out.format("importDataResponse: %s\n", importDataResponse); + } + } +} + +// [END aiplatform_import_data_video_action_recognition_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataVideoClassificationSample.java b/aiplatform/src/main/java/aiplatform/ImportDataVideoClassificationSample.java new file mode 100644 index 00000000000..16cbc79e9a8 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataVideoClassificationSample.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_video_classification_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataVideoClassificationSample { + + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + importDataVideoClassification(gcsSourceUri, project, datasetId); + } + + static void importDataVideoClassification(String gcsSourceUri, String project, String datasetId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "video_classification_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + + DatasetName datasetName = DatasetName.of(project, location, datasetId); + List importDataConfigs = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigs); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(1800, TimeUnit.SECONDS); + + System.out.format( + "Import Data Video Classification Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_video_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java b/aiplatform/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java new file mode 100644 index 00000000000..ce099b95845 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java @@ -0,0 +1,86 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_video_object_tracking_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportDataConfig; +import com.google.cloud.aiplatform.v1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataVideoObjectTrackingSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + importDataVideObjectTracking(gcsSourceUri, project, datasetId); + } + + static void importDataVideObjectTracking(String gcsSourceUri, String project, String datasetId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "video_object_tracking_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + List importDataConfigs = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigs); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format( + "Import Data Video Object Tracking Response: %s\n", importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_video_object_tracking_sample] diff --git a/aiplatform/src/main/java/aiplatform/ImportFeatureValuesSample.java b/aiplatform/src/main/java/aiplatform/ImportFeatureValuesSample.java new file mode 100644 index 00000000000..405b05f54fb --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ImportFeatureValuesSample.java @@ -0,0 +1,122 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Bulk import values into a featurestore for existing features. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_import_feature_values_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.AvroSource; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.ImportFeatureValuesOperationMetadata; +import com.google.cloud.aiplatform.v1.ImportFeatureValuesRequest; +import com.google.cloud.aiplatform.v1.ImportFeatureValuesRequest.FeatureSpec; +import com.google.cloud.aiplatform.v1.ImportFeatureValuesResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportFeatureValuesSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String entityIdField = "YOUR_ENTITY_FIELD_ID"; + String featureTimeField = "YOUR_FEATURE_TIME_FIELD"; + String gcsSourceUri = "YOUR_GCS_SOURCE_URI"; + int workerCount = 2; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + importFeatureValuesSample( + project, + featurestoreId, + entityTypeId, + gcsSourceUri, + entityIdField, + featureTimeField, + workerCount, + location, + endpoint, + timeout); + } + + static void importFeatureValuesSample( + String project, + String featurestoreId, + String entityTypeId, + String gcsSourceUri, + String entityIdField, + String featureTimeField, + int workerCount, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + List featureSpecs = new ArrayList<>(); + + featureSpecs.add(FeatureSpec.newBuilder().setId("title").build()); + featureSpecs.add(FeatureSpec.newBuilder().setId("genres").build()); + featureSpecs.add(FeatureSpec.newBuilder().setId("average_rating").build()); + ImportFeatureValuesRequest importFeatureValuesRequest = + ImportFeatureValuesRequest.newBuilder() + .setEntityType( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setEntityIdField(entityIdField) + .setFeatureTimeField(featureTimeField) + .addAllFeatureSpecs(featureSpecs) + .setWorkerCount(workerCount) + .setAvroSource( + AvroSource.newBuilder() + .setGcsSource(GcsSource.newBuilder().addUris(gcsSourceUri))) + .build(); + OperationFuture + importFeatureValuesFuture = + featurestoreServiceClient.importFeatureValuesAsync(importFeatureValuesRequest); + System.out.format( + "Operation name: %s%n", importFeatureValuesFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportFeatureValuesResponse importFeatureValuesResponse = + importFeatureValuesFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Import Feature Values Response"); + System.out.println(importFeatureValuesResponse); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_import_feature_values_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListEntityTypesAsyncSample.java b/aiplatform/src/main/java/aiplatform/ListEntityTypesAsyncSample.java new file mode 100644 index 00000000000..b429a642c53 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListEntityTypesAsyncSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * List available entity type details of an existing featurestore resource. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_list_entity_types_async_sample] + +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.ListEntityTypesRequest; +import com.google.cloud.aiplatform.v1.ListEntityTypesResponse; +import com.google.common.base.Strings; +import java.io.IOException; + +public class ListEntityTypesAsyncSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + listEntityTypesAsyncSample(project, featurestoreId, location, endpoint); + } + + static void listEntityTypesAsyncSample( + String project, String featurestoreId, String location, String endpoint) throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + ListEntityTypesRequest listEntityTypeRequest = + ListEntityTypesRequest.newBuilder() + .setParent(FeaturestoreName.of(project, location, featurestoreId).toString()) + .build(); + System.out.println("List Entity Types Async Response"); + while (true) { + ListEntityTypesResponse listEntityTypesResponse = + featurestoreServiceClient.listEntityTypesCallable().call(listEntityTypeRequest); + for (EntityType element : listEntityTypesResponse.getEntityTypesList()) { + System.out.println(element); + } + String nextPageToken = listEntityTypesResponse.getNextPageToken(); + if (!Strings.isNullOrEmpty(nextPageToken)) { + listEntityTypeRequest = + listEntityTypeRequest.toBuilder().setPageToken(nextPageToken).build(); + } else { + break; + } + } + } + } +} +// [END aiplatform_list_entity_types_async_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListEntityTypesSample.java b/aiplatform/src/main/java/aiplatform/ListEntityTypesSample.java new file mode 100644 index 00000000000..1160216c4b8 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListEntityTypesSample.java @@ -0,0 +1,68 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * List available entity type details of an existing featurestore resource. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_list_entity_types_sample] + +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.ListEntityTypesRequest; +import java.io.IOException; + +public class ListEntityTypesSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + listEntityTypesSample(project, featurestoreId, location, endpoint); + } + + static void listEntityTypesSample( + String project, String featurestoreId, String location, String endpoint) throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + ListEntityTypesRequest listEntityTypeRequest = + ListEntityTypesRequest.newBuilder() + .setParent(FeaturestoreName.of(project, location, featurestoreId).toString()) + .build(); + System.out.println("List Entity Types Response"); + for (EntityType element : + featurestoreServiceClient.listEntityTypes(listEntityTypeRequest).iterateAll()) { + System.out.println(element); + } + } + } +} +// [END aiplatform_list_entity_types_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListFeaturesAsyncSample.java b/aiplatform/src/main/java/aiplatform/ListFeaturesAsyncSample.java new file mode 100644 index 00000000000..5cc41ec8cd7 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListFeaturesAsyncSample.java @@ -0,0 +1,83 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * List available feature details. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_list_features_async_sample] + +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.ListFeaturesRequest; +import com.google.cloud.aiplatform.v1.ListFeaturesResponse; +import com.google.common.base.Strings; +import java.io.IOException; + +public class ListFeaturesAsyncSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + + listFeaturesAsyncSample(project, featurestoreId, entityTypeId, location, endpoint); + } + + static void listFeaturesAsyncSample( + String project, String featurestoreId, String entityTypeId, String location, String endpoint) + throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + ListFeaturesRequest listFeaturesRequest = + ListFeaturesRequest.newBuilder() + .setParent( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .build(); + System.out.println("List Features Async Response"); + while (true) { + ListFeaturesResponse listFeaturesResponse = + featurestoreServiceClient.listFeaturesCallable().call(listFeaturesRequest); + for (Feature element : listFeaturesResponse.getFeaturesList()) { + System.out.println(element); + } + String nextPageToken = listFeaturesResponse.getNextPageToken(); + if (!Strings.isNullOrEmpty(nextPageToken)) { + listFeaturesRequest = listFeaturesRequest.toBuilder().setPageToken(nextPageToken).build(); + } else { + break; + } + } + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_list_features_async_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListFeaturesSample.java b/aiplatform/src/main/java/aiplatform/ListFeaturesSample.java new file mode 100644 index 00000000000..b17eeb35e48 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListFeaturesSample.java @@ -0,0 +1,72 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * List available feature details. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_list_features_sample] + +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.ListFeaturesRequest; +import java.io.IOException; + +public class ListFeaturesSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + + listFeaturesSample(project, featurestoreId, entityTypeId, location, endpoint); + } + + static void listFeaturesSample( + String project, String featurestoreId, String entityTypeId, String location, String endpoint) + throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + ListFeaturesRequest listFeaturesRequest = + ListFeaturesRequest.newBuilder() + .setParent( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .build(); + System.out.println("List Features Response"); + for (Feature element : + featurestoreServiceClient.listFeatures(listFeaturesRequest).iterateAll()) { + System.out.println(element); + } + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_list_features_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListFeaturestoresAsyncSample.java b/aiplatform/src/main/java/aiplatform/ListFeaturestoresAsyncSample.java new file mode 100644 index 00000000000..16ce54f407e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListFeaturestoresAsyncSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * List available featurestore details. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_list_featurestores_async_sample] + +import com.google.cloud.aiplatform.v1.Featurestore; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.ListFeaturestoresRequest; +import com.google.cloud.aiplatform.v1.ListFeaturestoresResponse; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.common.base.Strings; +import java.io.IOException; + +public class ListFeaturestoresAsyncSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + listFeaturestoresAsyncSample(project, location, endpoint); + } + + static void listFeaturestoresAsyncSample(String project, String location, String endpoint) + throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + ListFeaturestoresRequest listFeaturestoresRequest = + ListFeaturestoresRequest.newBuilder() + .setParent(LocationName.of(project, location).toString()) + .build(); + System.out.println("List Featurestores Async Response"); + while (true) { + ListFeaturestoresResponse listFeaturestoresResponse = + featurestoreServiceClient.listFeaturestoresCallable().call(listFeaturestoresRequest); + for (Featurestore element : listFeaturestoresResponse.getFeaturestoresList()) { + System.out.println(element); + } + String nextPageToken = listFeaturestoresResponse.getNextPageToken(); + if (!Strings.isNullOrEmpty(nextPageToken)) { + listFeaturestoresRequest = + listFeaturestoresRequest.toBuilder().setPageToken(nextPageToken).build(); + } else { + break; + } + } + } + } +} +// [END aiplatform_list_featurestores_async_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListFeaturestoresSample.java b/aiplatform/src/main/java/aiplatform/ListFeaturestoresSample.java new file mode 100644 index 00000000000..db4e5d7aab5 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListFeaturestoresSample.java @@ -0,0 +1,67 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * List available featurestore details. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_list_featurestores_sample] + +import com.google.cloud.aiplatform.v1.Featurestore; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.ListFeaturestoresRequest; +import com.google.cloud.aiplatform.v1.LocationName; +import java.io.IOException; + +public class ListFeaturestoresSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + listFeaturestoresSample(project, location, endpoint); + } + + static void listFeaturestoresSample(String project, String location, String endpoint) + throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + ListFeaturestoresRequest listFeaturestoresRequest = + ListFeaturestoresRequest.newBuilder() + .setParent(LocationName.of(project, location).toString()) + .build(); + + System.out.println("List Featurestores Response"); + for (Featurestore element : + featurestoreServiceClient.listFeaturestores(listFeaturestoresRequest).iterateAll()) { + System.out.println(element); + } + } + } +} +// [END aiplatform_list_featurestores_sample] diff --git a/aiplatform/src/main/java/aiplatform/ListModelEvaluationSliceSample.java b/aiplatform/src/main/java/aiplatform/ListModelEvaluationSliceSample.java new file mode 100644 index 00000000000..09cf36e0a60 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/ListModelEvaluationSliceSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_list_model_evaluation_slice_sample] + +import com.google.cloud.aiplatform.v1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1.ModelEvaluationSlice; +import com.google.cloud.aiplatform.v1.ModelEvaluationSlice.Slice; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import java.io.IOException; + +public class ListModelEvaluationSliceSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + // To obtain evaluationId run the code block below after setting modelServiceSettings. + // + // try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) + // { + // String location = "us-central1"; + // ModelName modelFullId = ModelName.of(project, location, modelId); + // ListModelEvaluationsRequest modelEvaluationsrequest = + // ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build(); + // for (ModelEvaluation modelEvaluation : + // modelServiceClient.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + // System.out.format("Model Evaluation Name: %s%n", modelEvaluation.getName()); + // } + // } + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + listModelEvaluationSliceSample(project, modelId, evaluationId); + } + + static void listModelEvaluationSliceSample(String project, String modelId, String evaluationId) + throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + for (ModelEvaluationSlice modelEvaluationSlice : + modelServiceClient.listModelEvaluationSlices(modelEvaluationName).iterateAll()) { + System.out.format("Model Evaluation Slice Name: %s\n", modelEvaluationSlice.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluationSlice.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluationSlice.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluationSlice.getCreateTime()); + + Slice slice = modelEvaluationSlice.getSlice(); + System.out.format("Slice Dimensions: %s\n", slice.getDimension()); + System.out.format("Slice Value: %s\n\n", slice.getValue()); + } + } + } +} +// [END aiplatform_list_model_evaluation_slice_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictCustomTrainedModelSample.java b/aiplatform/src/main/java/aiplatform/PredictCustomTrainedModelSample.java new file mode 100644 index 00000000000..40b2b2e8e5c --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictCustomTrainedModelSample.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_custom_trained_model_sample] + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.List; + +public class PredictCustomTrainedModelSample { + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]"; + String project = "YOUR_PROJECT_ID"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictCustomTrainedModel(project, endpointId, instance); + } + + static void predictCustomTrainedModel(String project, String endpointId, String instance) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + ListValue.Builder listValue = ListValue.newBuilder(); + JsonFormat.parser().merge(instance, listValue); + List instanceList = listValue.getValuesList(); + + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()) + .addAllInstances(instanceList) + .build(); + PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + + System.out.println("Predict Custom Trained model Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_custom_trained_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictImageClassificationSample.java b/aiplatform/src/main/java/aiplatform/PredictImageClassificationSample.java new file mode 100644 index 00000000000..c2d3ed60158 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictImageClassificationSample.java @@ -0,0 +1,104 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_image_classification_sample] + +import com.google.api.client.util.Base64; +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.cloud.aiplatform.v1.schema.predict.instance.ImageClassificationPredictionInstance; +import com.google.cloud.aiplatform.v1.schema.predict.params.ImageClassificationPredictionParams; +import com.google.cloud.aiplatform.v1.schema.predict.prediction.ClassificationPredictionResult; +import com.google.protobuf.Value; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public class PredictImageClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String fileName = "YOUR_IMAGE_FILE_PATH"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictImageClassification(project, fileName, endpointId); + } + + static void predictImageClassification(String project, String fileName, String endpointId) + throws IOException { + PredictionServiceSettings settings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(settings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName))); + String content = new String(contents, StandardCharsets.UTF_8); + + ImageClassificationPredictionInstance predictionInstance = + ImageClassificationPredictionInstance.newBuilder().setContent(content).build(); + + List instances = new ArrayList<>(); + instances.add(ValueConverter.toValue(predictionInstance)); + + ImageClassificationPredictionParams predictionParams = + ImageClassificationPredictionParams.newBuilder() + .setConfidenceThreshold((float) 0.5) + .setMaxPredictions(5) + .build(); + + PredictResponse predictResponse = + predictionServiceClient.predict( + endpointName, instances, ValueConverter.toValue(predictionParams)); + System.out.println("Predict Image Classification Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + + ClassificationPredictionResult.Builder resultBuilder = + ClassificationPredictionResult.newBuilder(); + // Display names and confidences values correspond to + // IDs in the ID list. + ClassificationPredictionResult result = + (ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction); + int counter = 0; + for (Long id : result.getIdsList()) { + System.out.printf("Label ID: %d\n", id); + System.out.printf("Label: %s\n", result.getDisplayNames(counter)); + System.out.printf("Confidence: %.4f\n", result.getConfidences(counter)); + counter++; + } + } + } + } +} +// [END aiplatform_predict_image_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictImageObjectDetectionSample.java b/aiplatform/src/main/java/aiplatform/PredictImageObjectDetectionSample.java new file mode 100644 index 00000000000..16e2ac60585 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictImageObjectDetectionSample.java @@ -0,0 +1,103 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_image_object_detection_sample] + +import com.google.api.client.util.Base64; +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.cloud.aiplatform.v1.schema.predict.instance.ImageObjectDetectionPredictionInstance; +import com.google.cloud.aiplatform.v1.schema.predict.params.ImageObjectDetectionPredictionParams; +import com.google.cloud.aiplatform.v1.schema.predict.prediction.ImageObjectDetectionPredictionResult; +import com.google.protobuf.Value; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public class PredictImageObjectDetectionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String fileName = "YOUR_IMAGE_FILE_PATH"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictImageObjectDetection(project, fileName, endpointId); + } + + static void predictImageObjectDetection(String project, String fileName, String endpointId) + throws IOException { + PredictionServiceSettings settings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(settings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName))); + String content = new String(contents, StandardCharsets.UTF_8); + + ImageObjectDetectionPredictionParams params = + ImageObjectDetectionPredictionParams.newBuilder() + .setConfidenceThreshold((float) (0.5)) + .setMaxPredictions(5) + .build(); + + ImageObjectDetectionPredictionInstance instance = + ImageObjectDetectionPredictionInstance.newBuilder().setContent(content).build(); + + List instances = new ArrayList<>(); + instances.add(ValueConverter.toValue(instance)); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, ValueConverter.toValue(params)); + System.out.println("Predict Image Object Detection Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + + ImageObjectDetectionPredictionResult.Builder resultBuilder = + ImageObjectDetectionPredictionResult.newBuilder(); + + ImageObjectDetectionPredictionResult result = + (ImageObjectDetectionPredictionResult) + ValueConverter.fromValue(resultBuilder, prediction); + + for (int i = 0; i < result.getIdsCount(); i++) { + System.out.printf("\tDisplay name: %s\n", result.getDisplayNames(i)); + System.out.printf("\tConfidences: %f\n", result.getConfidences(i)); + System.out.printf("\tIDs: %d\n", result.getIds(i)); + System.out.printf("\tBounding boxes: %s\n", result.getBboxes(i)); + } + } + } + } +} +// [END aiplatform_predict_image_object_detection_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictTabularClassificationSample.java b/aiplatform/src/main/java/aiplatform/PredictTabularClassificationSample.java new file mode 100644 index 00000000000..59adf1885d6 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictTabularClassificationSample.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_tabular_classification_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.cloud.aiplatform.v1.schema.predict.prediction.TabularClassificationPredictionResult; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.List; + +public class PredictTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictTabularClassification(instance, project, endpointId); + } + + static void predictTabularClassification(String instance, String project, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + ListValue.Builder listValue = ListValue.newBuilder(); + JsonFormat.parser().merge(instance, listValue); + List instanceList = listValue.getValuesList(); + + Value parameters = Value.newBuilder().setListValue(listValue).build(); + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instanceList, parameters); + System.out.println("Predict Tabular Classification Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + TabularClassificationPredictionResult.Builder resultBuilder = + TabularClassificationPredictionResult.newBuilder(); + TabularClassificationPredictionResult result = + (TabularClassificationPredictionResult) + ValueConverter.fromValue(resultBuilder, prediction); + + for (int i = 0; i < result.getClassesCount(); i++) { + System.out.printf("\tClass: %s", result.getClasses(i)); + System.out.printf("\tScore: %f", result.getScores(i)); + } + } + } + } +} +// [END aiplatform_predict_tabular_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictTabularRegressionSample.java b/aiplatform/src/main/java/aiplatform/PredictTabularRegressionSample.java new file mode 100644 index 00000000000..9520c958783 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictTabularRegressionSample.java @@ -0,0 +1,83 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_tabular_regression_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.cloud.aiplatform.v1.schema.predict.prediction.TabularRegressionPredictionResult; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.List; + +public class PredictTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictTabularRegression(instance, project, endpointId); + } + + static void predictTabularRegression(String instance, String project, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + ListValue.Builder listValue = ListValue.newBuilder(); + JsonFormat.parser().merge(instance, listValue); + List instanceList = listValue.getValuesList(); + + Value parameters = Value.newBuilder().setListValue(listValue).build(); + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instanceList, parameters); + System.out.println("Predict Tabular Regression Response"); + System.out.format("\tDisplay Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + TabularRegressionPredictionResult.Builder resultBuilder = + TabularRegressionPredictionResult.newBuilder(); + + TabularRegressionPredictionResult result = + (TabularRegressionPredictionResult) ValueConverter.fromValue(resultBuilder, prediction); + + System.out.printf("\tUpper bound: %f\n", result.getUpperBound()); + System.out.printf("\tLower bound: %f\n", result.getLowerBound()); + System.out.printf("\tValue: %f\n", result.getValue()); + } + } + } +} +// [END aiplatform_predict_tabular_regression_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java b/aiplatform/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java new file mode 100644 index 00000000000..3b66819d2bc --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_text_classification_sample] +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.cloud.aiplatform.v1.schema.predict.instance.TextClassificationPredictionInstance; +import com.google.cloud.aiplatform.v1.schema.predict.prediction.ClassificationPredictionResult; +import com.google.protobuf.Value; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class PredictTextClassificationSingleLabelSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String content = "YOUR_TEXT_CONTENT"; + String endpointId = "YOUR_ENDPOINT_ID"; + + predictTextClassificationSingleLabel(project, content, endpointId); + } + + static void predictTextClassificationSingleLabel( + String project, String content, String endpointId) throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + TextClassificationPredictionInstance predictionInstance = + TextClassificationPredictionInstance.newBuilder().setContent(content).build(); + + List instances = new ArrayList<>(); + instances.add(ValueConverter.toValue(predictionInstance)); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE); + System.out.println("Predict Text Classification Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions:\n\n"); + for (Value prediction : predictResponse.getPredictionsList()) { + + ClassificationPredictionResult.Builder resultBuilder = + ClassificationPredictionResult.newBuilder(); + + // Display names and confidences values correspond to + // IDs in the ID list. + ClassificationPredictionResult result = + (ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction); + int counter = 0; + for (Long id : result.getIdsList()) { + System.out.printf("Label ID: %d\n", id); + System.out.printf("Label: %s\n", result.getDisplayNames(counter)); + System.out.printf("Confidence: %.4f\n", result.getConfidences(counter)); + counter++; + } + } + } + } +} +// [END aiplatform_predict_text_classification_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictTextEntityExtractionSample.java b/aiplatform/src/main/java/aiplatform/PredictTextEntityExtractionSample.java new file mode 100644 index 00000000000..b7f10df4970 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictTextEntityExtractionSample.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_text_entity_extraction_sample] + +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.cloud.aiplatform.v1.schema.predict.instance.TextExtractionPredictionInstance; +import com.google.cloud.aiplatform.v1.schema.predict.prediction.TextExtractionPredictionResult; +import com.google.protobuf.Value; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class PredictTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String content = "YOUR_TEXT_CONTENT"; + String endpointId = "YOUR_ENDPOINT_ID"; + + predictTextEntityExtraction(project, content, endpointId); + } + + static void predictTextEntityExtraction(String project, String content, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + String jsonString = "{\"content\": \"" + content + "\"}"; + + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + TextExtractionPredictionInstance instance = + TextExtractionPredictionInstance.newBuilder().setContent(content).build(); + + List instances = new ArrayList<>(); + instances.add(ValueConverter.toValue(instance)); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE); + System.out.println("Predict Text Entity Extraction Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + TextExtractionPredictionResult.Builder resultBuilder = + TextExtractionPredictionResult.newBuilder(); + + TextExtractionPredictionResult result = + (TextExtractionPredictionResult) ValueConverter.fromValue(resultBuilder, prediction); + + for (int i = 0; i < result.getIdsCount(); i++) { + long textStartOffset = result.getTextSegmentStartOffsets(i); + long textEndOffset = result.getTextSegmentEndOffsets(i); + String entity = content.substring((int) textStartOffset, (int) textEndOffset); + + System.out.format("\tEntity: %s\n", entity); + System.out.format("\tEntity type: %s\n", result.getDisplayNames(i)); + System.out.format("\tConfidences: %f\n", result.getConfidences(i)); + System.out.format("\tIDs: %d\n", result.getIds(i)); + } + } + } + } +} +// [END aiplatform_predict_text_entity_extraction_sample] diff --git a/aiplatform/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java b/aiplatform/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..1d57a65dd7f --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_text_sentiment_analysis_sample] + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class PredictTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String content = "YOUR_TEXT_CONTENT"; + String endpointId = "YOUR_ENDPOINT_ID"; + + predictTextSentimentAnalysis(project, content, endpointId); + } + + static void predictTextSentimentAnalysis(String project, String content, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + String jsonString = "{\"content\": \"" + content + "\"}"; + + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); + Value.Builder instance = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, instance); + + List instances = new ArrayList<>(); + instances.add(instance.build()); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameter); + System.out.println("Predict Text Sentiment Analysis Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_text_sentiment_analysis_sample] diff --git a/aiplatform/src/main/java/aiplatform/SearchFeaturesAsyncSample.java b/aiplatform/src/main/java/aiplatform/SearchFeaturesAsyncSample.java new file mode 100644 index 00000000000..595fe18c533 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/SearchFeaturesAsyncSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Search for featurestore resources. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_search_features_async_sample] + +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.SearchFeaturesRequest; +import com.google.cloud.aiplatform.v1.SearchFeaturesResponse; +import com.google.common.base.Strings; +import java.io.IOException; + +public class SearchFeaturesAsyncSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String query = "YOUR_QUERY"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + searchFeaturesAsyncSample(project, query, location, endpoint); + } + + static void searchFeaturesAsyncSample( + String project, String query, String location, String endpoint) throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + SearchFeaturesRequest searchFeaturesRequest = + SearchFeaturesRequest.newBuilder() + .setLocation(LocationName.of(project, location).toString()) + .setQuery(query) + .build(); + System.out.println("Search Features Async Response"); + while (true) { + SearchFeaturesResponse response = + featurestoreServiceClient.searchFeaturesCallable().call(searchFeaturesRequest); + for (Feature element : response.getFeaturesList()) { + System.out.println(element); + } + String nextPageToken = response.getNextPageToken(); + if (!Strings.isNullOrEmpty(nextPageToken)) { + searchFeaturesRequest = + searchFeaturesRequest.toBuilder().setPageToken(nextPageToken).build(); + } else { + break; + } + } + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_search_features_async_sample] diff --git a/aiplatform/src/main/java/aiplatform/SearchFeaturesSample.java b/aiplatform/src/main/java/aiplatform/SearchFeaturesSample.java new file mode 100644 index 00000000000..62309a5a99e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/SearchFeaturesSample.java @@ -0,0 +1,69 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Search for featurestore resources. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_search_features_sample] + +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.SearchFeaturesRequest; +import java.io.IOException; + +public class SearchFeaturesSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String query = "YOUR_QUERY"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + searchFeaturesSample(project, query, location, endpoint); + } + + static void searchFeaturesSample(String project, String query, String location, String endpoint) + throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + SearchFeaturesRequest searchFeaturesRequest = + SearchFeaturesRequest.newBuilder() + .setLocation(LocationName.of(project, location).toString()) + .setQuery(query) + .build(); + System.out.println("Search Features Response"); + for (Feature element : + featurestoreServiceClient.searchFeatures(searchFeaturesRequest).iterateAll()) { + System.out.println(element); + } + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_search_features_sample] diff --git a/aiplatform/src/main/java/aiplatform/UndeployModelSample.java b/aiplatform/src/main/java/aiplatform/UndeployModelSample.java new file mode 100644 index 00000000000..db11f300166 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UndeployModelSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_undeploy_model_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.EndpointServiceClient; +import com.google.cloud.aiplatform.v1.EndpointServiceSettings; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.UndeployModelOperationMetadata; +import com.google.cloud.aiplatform.v1.UndeployModelResponse; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class UndeployModelSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String endpointId = "YOUR_ENDPOINT_ID"; + String modelId = "YOUR_MODEL_ID"; + undeployModelSample(project, endpointId, modelId); + } + + static void undeployModelSample(String project, String endpointId, String modelId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + EndpointServiceSettings endpointServiceSettings = + EndpointServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (EndpointServiceClient endpointServiceClient = + EndpointServiceClient.create(endpointServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + ModelName modelName = ModelName.of(project, location, modelId); + + // key '0' assigns traffic for the newly deployed model + // Traffic percentage values must add up to 100 + // Leave dictionary empty if endpoint should not accept any traffic + Map trafficSplit = new HashMap<>(); + trafficSplit.put("0", 100); + + OperationFuture operation = + endpointServiceClient.undeployModelAsync( + endpointName.toString(), modelName.toString(), trafficSplit); + System.out.format("Operation name: %s\n", operation.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + UndeployModelResponse undeployModelResponse = operation.get(180, TimeUnit.SECONDS); + + System.out.format("Undeploy Model Response: %s\n", undeployModelResponse); + } + } +} +// [END aiplatform_undeploy_model_sample] diff --git a/aiplatform/src/main/java/aiplatform/UpdateEntityTypeMonitoringSample.java b/aiplatform/src/main/java/aiplatform/UpdateEntityTypeMonitoringSample.java new file mode 100644 index 00000000000..3133b146f8b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UpdateEntityTypeMonitoringSample.java @@ -0,0 +1,87 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Update entity type. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_update_entity_type_monitoring_sample] + +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.FeaturestoreMonitoringConfig; +import com.google.cloud.aiplatform.v1.FeaturestoreMonitoringConfig.SnapshotAnalysis; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.UpdateEntityTypeRequest; +import java.io.IOException; + +public class UpdateEntityTypeMonitoringSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + int monitoringIntervalDays = 1; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + updateEntityTypeMonitoringSample( + project, featurestoreId, entityTypeId, monitoringIntervalDays, location, endpoint); + } + + static void updateEntityTypeMonitoringSample( + String project, + String featurestoreId, + String entityTypeId, + int monitoringIntervalDays, + String location, + String endpoint) + throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + FeaturestoreMonitoringConfig featurestoreMonitoringConfig = + FeaturestoreMonitoringConfig.newBuilder() + .setSnapshotAnalysis( + SnapshotAnalysis.newBuilder().setMonitoringIntervalDays(monitoringIntervalDays)) + .build(); + EntityType entityType = + EntityType.newBuilder() + .setName( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setMonitoringConfig(featurestoreMonitoringConfig) + .build(); + + UpdateEntityTypeRequest updateEntityTypeRequest = + UpdateEntityTypeRequest.newBuilder().setEntityType(entityType).build(); + EntityType entityTypeResponse = + featurestoreServiceClient.updateEntityType(updateEntityTypeRequest); + System.out.println("Update Entity Type Monitoring Response"); + System.out.println(entityTypeResponse); + } + } +} +// [END aiplatform_update_entity_type_monitoring_sample] diff --git a/aiplatform/src/main/java/aiplatform/UpdateEntityTypeSample.java b/aiplatform/src/main/java/aiplatform/UpdateEntityTypeSample.java new file mode 100644 index 00000000000..bd7af265020 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UpdateEntityTypeSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Update entity type. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_update_entity_type_sample] + +import com.google.cloud.aiplatform.v1.EntityType; +import com.google.cloud.aiplatform.v1.EntityTypeName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.UpdateEntityTypeRequest; +import java.io.IOException; + +public class UpdateEntityTypeSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String description = "Update Description"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + updateEntityTypeSample(project, featurestoreId, entityTypeId, description, location, endpoint); + } + + static void updateEntityTypeSample( + String project, + String featurestoreId, + String entityTypeId, + String description, + String location, + String endpoint) + throws IOException { + + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + EntityType entityType = + EntityType.newBuilder() + .setName( + EntityTypeName.of(project, location, featurestoreId, entityTypeId).toString()) + .setDescription(description) + .build(); + + UpdateEntityTypeRequest updateEntityTypeRequest = + UpdateEntityTypeRequest.newBuilder().setEntityType(entityType).build(); + EntityType entityTypeResponse = + featurestoreServiceClient.updateEntityType(updateEntityTypeRequest); + System.out.println("Update Entity Type Response"); + System.out.println(entityTypeResponse); + } + } +} +// [END aiplatform_update_entity_type_sample] diff --git a/aiplatform/src/main/java/aiplatform/UpdateFeatureSample.java b/aiplatform/src/main/java/aiplatform/UpdateFeatureSample.java new file mode 100644 index 00000000000..a68ada038ac --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UpdateFeatureSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Update feature. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_update_feature_sample] + +import com.google.cloud.aiplatform.v1.Feature; +import com.google.cloud.aiplatform.v1.FeatureName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.UpdateFeatureRequest; +import java.io.IOException; + +public class UpdateFeatureSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + String entityTypeId = "YOUR_ENTITY_TYPE_ID"; + String featureId = "YOUR_FEATURE_ID"; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + updateFeatureSample(project, featurestoreId, entityTypeId, featureId, location, endpoint); + } + + static void updateFeatureSample( + String project, + String featurestoreId, + String entityTypeId, + String featureId, + String location, + String endpoint) + throws IOException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + Feature feature = + Feature.newBuilder() + .setName( + FeatureName.of(project, location, featurestoreId, entityTypeId, featureId) + .toString()) + .setDescription("sample feature title updated") + .build(); + + UpdateFeatureRequest request = UpdateFeatureRequest.newBuilder().setFeature(feature).build(); + Feature featureResponse = featurestoreServiceClient.updateFeature(request); + System.out.println("Update Feature Response"); + System.out.format("Name: %s%n", featureResponse.getName()); + featurestoreServiceClient.close(); + } + } +} +// [END aiplatform_update_feature_sample] diff --git a/aiplatform/src/main/java/aiplatform/UpdateFeaturestoreFixedNodesSample.java b/aiplatform/src/main/java/aiplatform/UpdateFeaturestoreFixedNodesSample.java new file mode 100644 index 00000000000..71ef51edcc9 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UpdateFeaturestoreFixedNodesSample.java @@ -0,0 +1,93 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Update featurestore. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_update_featurestore_fixed_nodes_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.Featurestore; +import com.google.cloud.aiplatform.v1.Featurestore.OnlineServingConfig; +import com.google.cloud.aiplatform.v1.FeaturestoreName; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1.UpdateFeaturestoreOperationMetadata; +import com.google.cloud.aiplatform.v1.UpdateFeaturestoreRequest; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class UpdateFeaturestoreFixedNodesSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + int fixedNodeCount = 1; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + updateFeaturestoreFixedNodesSample( + project, featurestoreId, fixedNodeCount, location, endpoint, timeout); + } + + static void updateFeaturestoreFixedNodesSample( + String project, + String featurestoreId, + int fixedNodeCount, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + OnlineServingConfig.Builder builderValue = + OnlineServingConfig.newBuilder().setFixedNodeCount(fixedNodeCount); + Featurestore featurestore = + Featurestore.newBuilder() + .setName(FeaturestoreName.of(project, location, featurestoreId).toString()) + .setOnlineServingConfig(builderValue) + .build(); + + UpdateFeaturestoreRequest request = + UpdateFeaturestoreRequest.newBuilder().setFeaturestore(featurestore).build(); + + OperationFuture updateFeaturestoreFuture = + featurestoreServiceClient.updateFeaturestoreAsync(request); + System.out.format( + "Operation name: %s%n", updateFeaturestoreFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Featurestore featurestoreResponse = updateFeaturestoreFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Update Featurestore Fixed Nodes Response"); + System.out.format("Name: %s%n", featurestoreResponse.getName()); + } + } +} +// [END aiplatform_update_featurestore_fixed_nodes_sample] diff --git a/aiplatform/src/main/java/aiplatform/UpdateFeaturestoreSample.java b/aiplatform/src/main/java/aiplatform/UpdateFeaturestoreSample.java new file mode 100644 index 00000000000..7ccb0b0a18e --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UpdateFeaturestoreSample.java @@ -0,0 +1,98 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + * Updates the parameters of a single featurestore. See + * https://cloud.google.com/vertex-ai/docs/featurestore/setup before running + * the code snippet + */ + +package aiplatform; + +// [START aiplatform_update_featurestore_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.Featurestore; +import com.google.cloud.aiplatform.v1beta1.Featurestore.OnlineServingConfig; +import com.google.cloud.aiplatform.v1beta1.Featurestore.OnlineServingConfig.Scaling; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreName; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceClient; +import com.google.cloud.aiplatform.v1beta1.FeaturestoreServiceSettings; +import com.google.cloud.aiplatform.v1beta1.UpdateFeaturestoreOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.UpdateFeaturestoreRequest; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class UpdateFeaturestoreSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String featurestoreId = "YOUR_FEATURESTORE_ID"; + int minNodeCount = 2; + int maxNodeCount = 4; + String location = "us-central1"; + String endpoint = "us-central1-aiplatform.googleapis.com:443"; + int timeout = 300; + updateFeaturestoreSample( + project, featurestoreId, minNodeCount, maxNodeCount, location, endpoint, timeout); + } + + static void updateFeaturestoreSample( + String project, + String featurestoreId, + int minNodeCount, + int maxNodeCount, + String location, + String endpoint, + int timeout) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + FeaturestoreServiceSettings featurestoreServiceSettings = + FeaturestoreServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (FeaturestoreServiceClient featurestoreServiceClient = + FeaturestoreServiceClient.create(featurestoreServiceSettings)) { + + OnlineServingConfig.Builder builderValue = + OnlineServingConfig.newBuilder() + .setScaling( + Scaling.newBuilder().setMinNodeCount(minNodeCount).setMaxNodeCount(maxNodeCount)); + Featurestore featurestore = + Featurestore.newBuilder() + .setName(FeaturestoreName.of(project, location, featurestoreId).toString()) + .setOnlineServingConfig(builderValue) + .build(); + + UpdateFeaturestoreRequest request = + UpdateFeaturestoreRequest.newBuilder().setFeaturestore(featurestore).build(); + + OperationFuture updateFeaturestoreFuture = + featurestoreServiceClient.updateFeaturestoreAsync(request); + System.out.format( + "Operation name: %s%n", updateFeaturestoreFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Featurestore featurestoreResponse = updateFeaturestoreFuture.get(timeout, TimeUnit.SECONDS); + System.out.println("Update Featurestore Response"); + System.out.format("Name: %s%n", featurestoreResponse.getName()); + } + } +} +// [END aiplatform_update_featurestore_sample] diff --git a/aiplatform/src/main/java/aiplatform/UploadModelSample.java b/aiplatform/src/main/java/aiplatform/UploadModelSample.java new file mode 100644 index 00000000000..f6b2fecec8a --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/UploadModelSample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_upload_model_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.Model; +import com.google.cloud.aiplatform.v1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1.ModelServiceClient; +import com.google.cloud.aiplatform.v1.ModelServiceSettings; +import com.google.cloud.aiplatform.v1.UploadModelOperationMetadata; +import com.google.cloud.aiplatform.v1.UploadModelResponse; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class UploadModelSample { + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"; + String imageUri = "YOUR_IMAGE_URI"; + String artifactUri = "gs://your-gcs-bucket/artifact_path"; + uploadModel(project, modelDisplayName, metadataSchemaUri, imageUri, artifactUri); + } + + static void uploadModel( + String project, + String modelDisplayName, + String metadataSchemaUri, + String imageUri, + String artifactUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + ModelContainerSpec modelContainerSpec = + ModelContainerSpec.newBuilder().setImageUri(imageUri).build(); + + Model model = + Model.newBuilder() + .setDisplayName(modelDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .setArtifactUri(artifactUri) + .setContainerSpec(modelContainerSpec) + .build(); + + OperationFuture uploadModelResponseFuture = + modelServiceClient.uploadModelAsync(locationName, model); + System.out.format( + "Operation name: %s\n", uploadModelResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + UploadModelResponse uploadModelResponse = uploadModelResponseFuture.get(5, TimeUnit.MINUTES); + + System.out.println("Upload Model Response"); + System.out.format("Model: %s\n", uploadModelResponse.getModel()); + } + } +} +// [END aiplatform_upload_model_sample] diff --git a/aiplatform/src/test/java/aiplatform/CancelDataLabelingJobSampleTest.java b/aiplatform/src/test/java/aiplatform/CancelDataLabelingJobSampleTest.java new file mode 100644 index 00000000000..6ea3303aa68 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CancelDataLabelingJobSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CancelDataLabelingJobSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("DATA_LABELING_DATASET_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String INPUT_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification.yaml"; + private static final String ANNOTATION_SPEC = "daisy"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created data labeling + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore + public void testCancelDataLabelingJob() throws IOException, InterruptedException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobSample.createDataLabelingJob( + PROJECT, + dataLabelingDisplayName, + DATASET_ID, + INSTRUCTION_URI, + INPUT_SCHEMA_URI, + ANNOTATION_SPEC); + + String got = bout.toString(); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + } +} diff --git a/aiplatform/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java b/aiplatform/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java new file mode 100644 index 00000000000..a95073f9dc6 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CancelTrainingPipelineSampleTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CancelTrainingPipelineSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_DATASET_ID"); + private static final String TRAINING_TASK_DEFINITION = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_image_classification_1.0.0.yaml"; + private static String TRAINING_PIPELINE_ID = null; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, TRAINING_PIPELINE_ID); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void cancelTrainingPipeline() throws IOException, InterruptedException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineSample.createTrainingPipelineSample( + PROJECT, + trainingPipelineDisplayName, + DATASET_ID, + TRAINING_TASK_DEFINITION, + modelDisplayName); + + // Assert + String createTrainingPipelineResponse = bout.toString(); + assertThat(createTrainingPipelineResponse).contains(DATASET_ID); + assertThat(createTrainingPipelineResponse).contains("Create Training Pipeline Response"); + TRAINING_PIPELINE_ID = + createTrainingPipelineResponse + .split("Name: ")[1] + .split("trainingPipelines/")[1] + .split("\n")[0]; + + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, TRAINING_PIPELINE_ID); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(1); + + // Get TrainingPipeline + GetTrainingPipelineSample.getTrainingPipeline(PROJECT, TRAINING_PIPELINE_ID); + String trainingPipelineResponse = bout.toString(); + assertThat(trainingPipelineResponse).contains("Message: CANCELED"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java new file mode 100644 index 00000000000..25114e60731 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobBigquerySampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobBigquerySampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_TABULAR_BQ_MODEL_ID"); + private static final String BIGQUERY_SOURCE_URI = + "bq://ucaip-sample-tests.table_test.all_bq_types"; + private static final String BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX = "bq://ucaip-sample-tests"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_TABULAR_BQ_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobBigquerySample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_bigquery_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobBigquerySample.createBatchPredictionJobBigquerySample( + PROJECT, + batchPredictionDisplayName, + MODEL_ID, + "bigquery", + BIGQUERY_SOURCE_URI, + "bigquery", + BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java new file mode 100644 index 00000000000..1def01b3ddc --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_bigquery_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobSample.createBatchPredictionJobSample( + PROJECT, + batchPredictionDisplayName, + MODEL_ID, + "jsonl", + GCS_SOURCE_URI, + "jsonl", + GCS_OUTPUT_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextClassificationSampleTest.java new file mode 100644 index 00000000000..2d5d4d10baa --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextClassificationSampleTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateBatchPredictionJobTextClassificationSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String MODEL_ID = System.getenv("TEXT_CLASS_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String got; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_CLASS_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + + String batchPredictionJobId = + got.split("name:")[1].split("batchPredictionJobs/")[1].split("\"\n")[0]; + + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateBatchPredictionJobTextClassificationSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "temp_java_create_batch_prediction_TCN_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobTextClassificationSample + .createBatchPredictionJobTextClassificationSample( + PROJECT, + LOCATION, + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_OUTPUT_URI); + + // Assert + got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..22b7a85dd8b --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSampleTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateBatchPredictionJobTextEntityExtractionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String MODEL_ID = System.getenv("TEXT_ENTITY_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/batch_predict_TEN/ten_inputs.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String got; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_ENTITY_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + String batchPredictionJobId = + got.split("name:")[1].split("batchPredictionJobs/")[1].split("\"\n")[0]; + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateBatchPredictionJobTextEntityExtractionSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "temp_java_create_batch_prediction_TEN_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobTextEntityExtractionSample + .createBatchPredictionJobTextEntityExtractionSample( + PROJECT, + LOCATION, + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_OUTPUT_URI); + + // Assert + got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..73b65b8fdb4 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSampleTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateBatchPredictionJobTextSentimentAnalysisSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String MODEL_ID = System.getenv("TEXT_SENTI_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/batch_predict_TSN/tsn_inputs.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String got; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_SENTI_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + String batchPredictionJobId = + got.split("name:")[1].split("batchPredictionJobs/")[1].split("\"\n")[0]; + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateBatchPredictionJobTextSentimentAnalysisSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "temp_java_create_batch_prediction_TSN_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobTextSentimentAnalysisSample + .createBatchPredictionJobTextSentimentAnalysisSample( + PROJECT, + LOCATION, + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_OUTPUT_URI); + + // Assert + got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..90072c1c7ad --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSampleTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateBatchPredictionJobVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("BATCH_PREDICTION_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateBatchPredictionJobVideoActionRecognitionSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_video_action_recognition_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobVideoActionRecognitionSample + .createBatchPredictionJobVideoActionRecognitionSample( + PROJECT, batchPredictionDisplayName, MODEL_ID, GCS_SOURCE_URI, GCS_OUTPUT_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java new file mode 100644 index 00000000000..1f64dbca4e0 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateBatchPredictionJobVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("VIDEO_CLASS_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/vcn_40_batch_prediction_input.jsonl"; + private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("VIDEO_CLASS_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Batch Prediction Job + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateBatchPredictionJobVideoClassificationSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_video_classification_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobVideoClassificationSample.createBatchPredictionJobVideoClassification( + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_DESTINATION_OUTPUT_URI_PREFIX, + PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("Create Batch Prediction Job Video Classification Response"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..f4306e3d737 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateBatchPredictionJobVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("VIDEO_OBJECT_DETECT_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/vot_batch_prediction_input.jsonl"; + private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Batch Prediction Job + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateBatchPredictionJobVideoObjectTrackingSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_video_object_tracking_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobVideoObjectTrackingSample.batchPredictionJobVideoObjectTracking( + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_DESTINATION_OUTPUT_URI_PREFIX, + PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("Create Batch Prediction Job Video Object Tracking Response"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java new file mode 100644 index 00000000000..5280476f333 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobActiveLearningSampleTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobActiveLearningSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String INPUTS_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification_1.0.0.yaml"; + private static final String ANNOTATION_SPEC = "roses"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("Avoid creating actual data labeling job for humans") + public void testCreateDataLabelingJobActiveLearningSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_job_active_learning_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobActiveLearningSample.createDataLabelingJobActiveLearningSample( + PROJECT, + dataLabelingDisplayName, + DATASET_ID, + INSTRUCTION_URI, + INPUTS_SCHEMA_URI, + ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Image Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobImageSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobImageSampleTest.java new file mode 100644 index 00000000000..27dc9164002 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobImageSampleTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobImageSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("DATA_LABELING_IMAGE_DATASET_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String ANNOTATION_SPEC = "roses"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_IMAGE_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore + public void testCreateDataLabelingJobImageSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_job_image_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobImageSample.createDataLabelingJobImage( + PROJECT, dataLabelingDisplayName, DATASET_ID, INSTRUCTION_URI, ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Image Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobSampleTest.java new file mode 100644 index 00000000000..6f939353040 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobSampleTest.java @@ -0,0 +1,114 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("DATA_LABELING_DATASET_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String INPUT_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification.yaml"; + private static final String ANNOTATION_SPEC = "daisy"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore + public void testCreateDataLabelingJobSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobSample.createDataLabelingJob( + PROJECT, + dataLabelingDisplayName, + DATASET_ID, + INSTRUCTION_URI, + INPUT_SCHEMA_URI, + ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java new file mode 100644 index 00000000000..7c41c5d844a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobSpecialistPoolSampleTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobSpecialistPoolSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + private static final String SPECIALIST_POOL_ID = + System.getenv("DATA_LABELING_SPECIALIST_POOL_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String INPUTS_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/datalabelingjob/inputs/image_classification_1.0.0.yaml"; + private static final String ANNOTATION_SPEC = "roses"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_ACTIVE_LEARNING_DATASET_ID"); + requireEnvVar("DATA_LABELING_SPECIALIST_POOL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("Avoid creating actual data labeling job for humans") + public void testCreateDataLabelingJobSpecialistPoolSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_job_specialist_pool_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobSpecialistPoolSample.createDataLabelingJobSpecialistPoolSample( + PROJECT, + dataLabelingDisplayName, + DATASET_ID, + SPECIALIST_POOL_ID, + INSTRUCTION_URI, + INPUTS_SCHEMA_URI, + ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Image Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java new file mode 100644 index 00000000000..2c6ee822278 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDataLabelingJobVideoSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("DATA_LABELING_VIDEO_DATASET_ID"); + private static final String INSTRUCTION_URI = + "gs://ucaip-sample-resources/images/datalabeling_instructions.pdf"; + private static final String ANNOTATION_SPEC = "cars"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String dataLabelingJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("DATA_LABELING_VIDEO_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel data labeling job + CancelDataLabelingJobSample.cancelDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled Data labeling job"); + TimeUnit.MINUTES.sleep(1); + + // Delete the created dataset + DeleteDataLabelingJobSample.deleteDataLabelingJob(PROJECT, dataLabelingJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Data Labeling Job."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("Avoid creating actual data labeling job for humans") + public void testCreateDataLabelingJobVideoSample() throws IOException { + // Act + String dataLabelingDisplayName = + String.format( + "temp_data_labeling_job_video_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDataLabelingJobVideoSample.createDataLabelingJobVideo( + PROJECT, dataLabelingDisplayName, DATASET_ID, INSTRUCTION_URI, ANNOTATION_SPEC); + + // Assert + String got = bout.toString(); + assertThat(got).contains(dataLabelingDisplayName); + assertThat(got).contains("Create Data Labeling Job Video Response"); + dataLabelingJobId = got.split("Name: ")[1].split("dataLabelingJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDatasetImageSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDatasetImageSampleTest.java new file mode 100644 index 00000000000..d4667e6111c --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDatasetImageSampleTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetImageSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateDatasetSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetImageSample.createDatasetImageSample(PROJECT, datasetDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Image Dataset Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDatasetSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDatasetSampleTest.java new file mode 100644 index 00000000000..408bead923d --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDatasetSampleTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String METADATA_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT_ID, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String displayName = + String.format( + "temp_create_dataset_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetSample.createDatasetSample(PROJECT_ID, displayName, METADATA_SCHEMA_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(displayName); + assertThat(got).contains("Create Dataset Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java new file mode 100644 index 00000000000..42b002514a5 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java @@ -0,0 +1,93 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateDatasetTabularBigquerySampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String GCS_SOURCE_URI = "bq://ucaip-sample-tests.table_test.all_bq_types"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetTabularBigquerySample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_table_bigquery_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetTabularBigquerySample.createDatasetTableBigquery( + PROJECT, datasetDisplayName, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Dataset Table Bigquery sample"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java new file mode 100644 index 00000000000..3d9c5bba225 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateDatasetTabularGcsSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String GCS_SOURCE_URI = "gs://cloud-ml-tables-data/bank-marketing.csv"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateDatasetTabularGcsSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_table_gcs_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetTabularGcsSample.createDatasetTableGcs( + PROJECT, datasetDisplayName, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Dataset Table GCS sample"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDatasetTextSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDatasetTextSampleTest.java new file mode 100644 index 00000000000..ba3c98df9ab --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDatasetTextSampleTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetTextSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateDatasetSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetTextSample.createDatasetTextSample(PROJECT, datasetDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Text Dataset Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java new file mode 100644 index 00000000000..a983f079224 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetVideoSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateDatasetVideoSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String displayName = + String.format( + "temp_create_dataset_video_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetVideoSample.createDatasetSample(displayName, PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(displayName); + assertThat(got).contains("Create Dataset Video Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateEndpointSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateEndpointSampleTest.java new file mode 100644 index 00000000000..f301710da38 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateEndpointSampleTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateEndpointSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String endpointId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created endpoint + DeleteEndpointSample.deleteEndpointSample(PROJECT_ID, endpointId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Delete Endpoint Response: "); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateEndpointSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String displayName = + String.format( + "temp_create_endpoint_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateEndpointSample.createEndpointSample(PROJECT_ID, displayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains("us-central1"); + assertThat(got).contains("Create Endpoint Response"); + endpointId = got.split("Name: ")[1].split("endpoints/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java new file mode 100644 index 00000000000..93f04e9e065 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSampleTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateHyperparameterTuningJobPythonPackageSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String EXECUTOR_IMAGE_URI = + "us.gcr.io/cloud-aiplatform/training/tf-gpu.2-1:latest"; + private static final String PACKAGE_URI = + "gs://cloud-samples-data-us-central1/ai-platform-unified/training/python-packages/" + + "trainer.tar.bz2"; + private static final String PYTHON_MODULE = "trainer.hptuning_trainer"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String hyperparameterJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (JobServiceClient client = JobServiceClient.create(settings)) { + // Cancel hyper parameter job + String hyperparameterJobName = + String.format( + "projects/%s/locations/us-central1/hyperparameterTuningJobs/%s", + PROJECT, hyperparameterJobId); + client.cancelHyperparameterTuningJob(hyperparameterJobName); + + TimeUnit.MINUTES.sleep(1); + + // Delete the created job + client.deleteHyperparameterTuningJobAsync(hyperparameterJobName); + System.out.flush(); + System.setOut(originalPrintStream); + } + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateHyperparameterTuningJobPythonPackageSample() throws IOException { + String hyperparameterTuningJobDisplayName = + String.format( + "temp_hyperparameter_tuning_job_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + CreateHyperparameterTuningJobPythonPackageSample + .createHyperparameterTuningJobPythonPackageSample( + PROJECT, + hyperparameterTuningJobDisplayName, + EXECUTOR_IMAGE_URI, + PACKAGE_URI, + PYTHON_MODULE); + + // Assert + String got = bout.toString(); + assertThat(got).contains(hyperparameterTuningJobDisplayName); + assertThat(got).contains("response:"); + hyperparameterJobId = + got.split("Name: ")[1].split("hyperparameterTuningJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java new file mode 100644 index 00000000000..48343412a6f --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateHyperparameterTuningJobSampleTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateHyperparameterTuningJobSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String CONTAINER_IMAGE_URI = + "gcr.io/ucaip-sample-tests/ucaip-training-test:latest"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String hyperparameterJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (JobServiceClient client = JobServiceClient.create(settings)) { + // Cancel hyper parameter job + String hyperparameterJobName = + String.format( + "projects/%s/locations/us-central1/hyperparameterTuningJobs/%s", + PROJECT, hyperparameterJobId); + client.cancelHyperparameterTuningJob(hyperparameterJobName); + + TimeUnit.MINUTES.sleep(1); + + // Delete the created job + client.deleteHyperparameterTuningJobAsync(hyperparameterJobName); + System.out.flush(); + System.setOut(originalPrintStream); + } + } + + @Test + public void testCreateHyperparameterTuningJobSample() throws IOException { + String hyperparameterTuningJobDisplayName = + String.format( + "temp_hyperparameter_tuning_job_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateHyperparameterTuningJobSample.createHyperparameterTuningJobSample( + PROJECT, hyperparameterTuningJobDisplayName, CONTAINER_IMAGE_URI); + + String got = bout.toString(); + assertThat(got).contains(hyperparameterTuningJobDisplayName); + assertThat(got).contains("response:"); + hyperparameterJobId = + got.split("Name: ")[1].split("hyperparameterTuningJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java new file mode 100644 index 00000000000..3762eb6bb33 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineCustomJobSampleTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import io.grpc.StatusRuntimeException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineCustomJobSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String CONTAINER_IMAGE_URI = + "gcr.io/ucaip-sample-tests/mnist-custom-job:latest"; + private static final String GCS_OUTPUT_DIRECTORY = + "gs://ucaip-samples-us-central1/training_pipeline_output"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + int retryCount = 3; + while (retryCount > 0) { + retryCount--; + try { + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + // if delete operation is successful, break out of the loop and continue + break; + } catch (StatusRuntimeException | ExecutionException ex) { + // wait for another 1 minute, then retry + System.out.println("Retrying (due to unfinished cancellation operation)..."); + TimeUnit.MINUTES.sleep(1); + } catch (Exception otherExceptions) { + // other exception, let them throw + throw otherExceptions; + } + } + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineCustomJobSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineCustomJobSample.createTrainingPipelineCustomJobSample( + PROJECT, + trainingPipelineDisplayName, + modelDisplayName, + CONTAINER_IMAGE_URI, + GCS_OUTPUT_DIRECTORY); + + // Assert + String got = bout.toString(); + assertThat(got).contains(trainingPipelineDisplayName); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java new file mode 100644 index 00000000000..11cb9b8f1cd --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineCustomTrainingManagedDatasetSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("CUSTOM_MANAGED_DATASET"); + private static final String ANNOTATION_SCHEMA_URI = + "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml"; + private static final String TRAINING_CONTAINER_IMAGE_URI = + "gcr.io/ucaip-sample-tests/custom-container-managed-dataset:latest"; + private static final String MODEL_CONTAIN_SPEC_IMAGE_URI = + "gcr.io/cloud-aiplatform/prediction/tf-gpu.1-15:latest"; + private static final String GCS_OUTPUT_DIRECTORY = + "gs://ucaip-samples-us-central1/training_pipeline_output/custom_training_managed_dataset"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("CUSTOM_MANAGED_DATASET"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineCustomTrainingManagedDatasetSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineCustomTrainingManagedDatasetSample + .createTrainingPipelineCustomTrainingManagedDatasetSample( + PROJECT, + trainingPipelineDisplayName, + modelDisplayName, + DATASET_ID, + ANNOTATION_SCHEMA_URI, + TRAINING_CONTAINER_IMAGE_URI, + MODEL_CONTAIN_SPEC_IMAGE_URI, + GCS_OUTPUT_DIRECTORY); + + // Assert + String got = bout.toString(); + assertThat(got).contains(trainingPipelineDisplayName); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java new file mode 100644 index 00000000000..e77cf0e2873 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineImageClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateTrainingPipelineImageClassificationSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineImageClassificationSample.createTrainingPipelineImageClassificationSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Image Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..c4295cb9440 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_IMAGE_OBJECT_DETECT_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_IMAGE_OBJECT_DETECT_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineImageObjectDetectionSample() throws IOException { + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26); + // Act + String trainingPipelineDisplayName = + String.format("temp_create_training_pipeline_test_%s", tempUuid); + + String modelDisplayName = + String.format("temp_create_training_pipeline_model_test_%s", tempUuid); + + CreateTrainingPipelineImageObjectDetectionSample + .createTrainingPipelineImageObjectDetectionSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Image Object Detection Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineSampleTest.java new file mode 100644 index 00000000000..fe31cb03fa2 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineSampleTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = "1084241610289446912"; + private static final String TRAINING_TASK_DEFINITION = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_image_classification_1.0.0.yaml"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT_ID, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT_ID, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineSample() + throws IOException, InterruptedException, ExecutionException { + // Act + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26); + String trainingPipelineDisplayName = + String.format("temp_create_training_pipeline_test_%s", tempUuid); + + String modelDisplayName = + String.format("temp_create_training_pipeline_model_test_%s", tempUuid); + + CreateTrainingPipelineSample.createTrainingPipelineSample( + PROJECT_ID, + trainingPipelineDisplayName, + DATASET_ID, + TRAINING_TASK_DEFINITION, + modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Create Training Pipeline Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java new file mode 100644 index 00000000000..50ba9a264a3 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java @@ -0,0 +1,126 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipeline; +import com.google.cloud.aiplatform.v1.TrainingPipelineName; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_TABLES_CLASSIFICATION_DATASET_ID"); + private static final String TARGET_COLUMN = "species"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TABLES_CLASSIFICATION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + TrainingPipelineName trainingPipelineName = + TrainingPipelineName.of(PROJECT, location, trainingPipelineId); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.getTrainingPipeline(trainingPipelineName); + while (!trainingPipelineResponse.getState().name().contains("STATE_CANCELLED")) { + TimeUnit.SECONDS.sleep(30); + trainingPipelineResponse = pipelineServiceClient.getTrainingPipeline(trainingPipelineName); + } + } + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void createTrainingPipelineTabularClassification() throws IOException { + // Act + String modelDisplayName = + String.format( + "temp_create_training_pipelinetabularclassification_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTabularClassificationSample.createTrainingPipelineTableClassification( + PROJECT, modelDisplayName, DATASET_ID, TARGET_COLUMN); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Tabular Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java new file mode 100644 index 00000000000..c36ab9bb9c0 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateTrainingPipelineTabularRegressionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID"); + private static final String TARGET_COLUMN = "FLOAT_5000unique_REQUIRED"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(3); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void createTrainingPipelineTabularRegression() throws IOException { + // Act + String modelDisplayName = + String.format( + "temp_create_training_pipelinetabularregression_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTabularRegressionSample.createTrainingPipelineTableRegression( + PROJECT, modelDisplayName, DATASET_ID, TARGET_COLUMN); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Tabular Regression Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java new file mode 100644 index 00000000000..5b68dab26f6 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineTextClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_TEXT_CLASS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TEXT_CLASS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineTextClassificationSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTextClassificationSample.createTrainingPipelineTextClassificationSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Text Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..fc93ccb06da --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java @@ -0,0 +1,114 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineTextEntityExtractionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_TEXT_ENTITY_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TEXT_ENTITY_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateTrainingPipelineTextEntityExtractionSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTextEntityExtractionSample + .createTrainingPipelineTextEntityExtractionSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Text Entity Extraction Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..b84598e54d9 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineTextSentimentAnalysisSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_TEXT_SENTI_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TEXT_SENTI_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateTrainingPipelineTextSentimentAnalysisSample() throws IOException { + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26); + // Act + String trainingPipelineDisplayName = + String.format("temp_create_training_pipeline_test_%s", tempUuid); + + String modelDisplayName = + String.format("temp_create_training_pipeline_model_test_%s", tempUuid); + + CreateTrainingPipelineTextSentimentAnalysisSample + .createTrainingPipelineTextSentimentAnalysisSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Text Sentiment Analysis Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..11fb905febf --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSampleTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_VIDEO_ACTION_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_VIDEO_ACTION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineVideoActionRecognitionSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_video_action_recognition_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_video_action_recognition_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineVideoActionRecognitionSample + .createTrainingPipelineVideoActionRecognitionSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java new file mode 100644 index 00000000000..7a54fc65f8a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateTrainingPipelineVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_VIDEO_CLASS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_VIDEO_CLASS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateTrainingPipelineVideoClassificationSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_video_classification_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_video_classification_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineVideoClassificationSample.createTrainingPipelineVideoClassification( + trainingPipelineDisplayName, DATASET_ID, modelDisplayName, PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Video Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..359cde9976a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class CreateTrainingPipelineVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_VIDEO_OBJECT_DETECT_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_VIDEO_OBJECT_DETECT_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testCreateTrainingPipelineVideoObjectTrackingSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_video_object_tracking_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_video_object_tracking_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineVideoObjectTrackingSample.createTrainingPipelineVideoObjectTracking( + trainingPipelineDisplayName, DATASET_ID, modelDisplayName, PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Video Object Tracking Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java b/aiplatform/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java new file mode 100644 index 00000000000..71a0d53fa0a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/DeployModelCustomTrainedModelSampleTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import io.grpc.StatusRuntimeException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class DeployModelCustomTrainedModelSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "4992732768149438464"; + private static final String ENDPOINT_ID = "4366591682456584192"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + + // Undeploy the model + try { + UndeployModelSample.undeployModelSample(PROJECT_ID, ENDPOINT_ID, MODEL_ID); + } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) { + e.printStackTrace(); + } + } + + @Ignore("Issues with undeploy") + @Test + public void testDeployModelCustomTrainedModelSample() throws TimeoutException { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + String deployedModelDisplayName = + String.format( + "temp_deploy_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + try { + DeployModelCustomTrainedModelSample.deployModelCustomTrainedModelSample( + PROJECT_ID, ENDPOINT_ID, MODEL_ID, deployedModelDisplayName); + // Assert + String got = bout.toString(); + assertThat(got).contains("deployModelResponse"); + } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) { + assertThat(e.getMessage()).contains("is not found."); + } + } +} diff --git a/aiplatform/src/test/java/aiplatform/DeployModelSampleTest.java b/aiplatform/src/test/java/aiplatform/DeployModelSampleTest.java new file mode 100644 index 00000000000..e8878a9658f --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/DeployModelSampleTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import io.grpc.StatusRuntimeException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class DeployModelSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "00000000000000000"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testDeployModelSample() throws TimeoutException { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + String deployedModelDisplayName = + String.format( + "temp_deploy_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + try { + DeployModelSample.deployModelSample( + PROJECT_ID, deployedModelDisplayName, "4366591682456584192", MODEL_ID); + // Assert + String got = bout.toString(); + assertThat(got).contains("is not found."); + } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) { + assertThat(e.getMessage()).contains("is not found."); + } + } +} diff --git a/aiplatform/src/test/java/aiplatform/ExportModelSampleTest.java b/aiplatform/src/test/java/aiplatform/ExportModelSampleTest.java new file mode 100644 index 00000000000..20e9725f1e5 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ExportModelSampleTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExportModelSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3422489426196955136"; + private static final String GCS_DESTINATION_URI_PREFIX = + "gs://ucaip-samples-test-output/tmp/export_model_test"; + private static final String EXPORT_FORMAT = "tf-saved-model"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the export model + String bucketName; + String objectName; + bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2]; + objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID); + DeleteExportModelSample.deleteExportModelSample(PROJECT_ID, bucketName, objectName); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Export Model Deleted"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testExportModelSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ExportModelSample.exportModelSample( + PROJECT_ID, MODEL_ID, GCS_DESTINATION_URI_PREFIX, EXPORT_FORMAT); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Export Model Response: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java new file mode 100644 index 00000000000..967efab654f --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class ExportModelTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "6036688272397172736"; + private static final String GCS_DESTINATION_URI_PREFIX = + "gs://ucaip-samples-test-output/tmp/export_model_test"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + // Delete the export model + String bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2]; + String objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID); + DeleteExportModelSample.deleteExportModelSample(PROJECT, bucketName, objectName); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Export Model Deleted"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void exportModelTabularClassification() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // Act + ExportModelTabularClassificationSample.exportModelTableClassification( + GCS_DESTINATION_URI_PREFIX, PROJECT, MODEL_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Export Model Tabular Classification Response: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java b/aiplatform/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..c622eaf154d --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ExportModelVideoActionRecognitionSampleTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class ExportModelVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("EXPORT_MODEL_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + private static final String GCS_DESTINATION_URI_PREFIX = + "gs://ucaip-samples-test-output/tmp/export_model_video_action_recognition_sample"; + private static final String EXPORT_FORMAT = "tf-saved-model"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("EXPORT_MODEL_VIDEO_ACTION_RECOGNITION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + // Delete the export model + String bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2]; + String objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID); + DeleteExportModelSample.deleteExportModelSample(PROJECT, bucketName, objectName); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Export Model Deleted"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testExportModelVideoActionRecognitionSample() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // Act + ExportModelVideoActionRecognitionSample.exportModelVideoActionRecognitionSample( + PROJECT, MODEL_ID, GCS_DESTINATION_URI_PREFIX, EXPORT_FORMAT); + + // Assert + String got = bout.toString(); + assertThat(got).contains("exportModelResponse: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/FeatureValuesSamplesTest.java b/aiplatform/src/test/java/aiplatform/FeatureValuesSamplesTest.java new file mode 100644 index 00000000000..b4e6bba320d --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/FeatureValuesSamplesTest.java @@ -0,0 +1,345 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1.Feature.ValueType; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQuery.DatasetDeleteOption; +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Dataset; +import com.google.cloud.bigquery.DatasetId; +import com.google.cloud.bigquery.DatasetInfo; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class FeatureValuesSamplesTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final int MIN_NODE_COUNT = 1; + private static final int MAX_NODE_COUNT = 2; + private static final String DESCRIPTION = "Test Description"; + private static final boolean USE_FORCE = true; + private static final ValueType VALUE_TYPE = ValueType.STRING; + private static final String QUERY = "value_type=STRING"; + private static final String ENTITY_ID_FIELD = "movie_id"; + private static final String FEATURE_TIME_FIELD = "update_time"; + private static final String GCS_SOURCE_URI = + "gs://cloud-samples-data-us-central1/vertex-ai/feature-store/datasets/movies.avro"; + private static final int WORKER_COUNT = 2; + private static final String INPUT_CSV_FILE = + "gs://cloud-samples-data-us-central1/vertex-ai/feature-store/datasets/movie_prediction.csv"; + private static final List FEATURE_SELECTOR_IDS = + Arrays.asList("title", "genres", "average_rating"); + private static final String LOCATION = "us-central1"; + private static final String ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; + private static final int TIMEOUT = 900; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String featurestoreId; + private String destinationTableUri; + private Date date; + private SimpleDateFormat dateFormat; + private String datasetName; + private String destinationTableName; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + date = new Date(); + dateFormat = new SimpleDateFormat("yyyyMMddHHmmSSS"); + datasetName = "movie_predictions" + dateFormat.format(date); + destinationTableName = "training_data"; + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + static void createBigQueryDataset(String projectId, String datasetName, String location) { + try { + // Initialize client that will be used to send requests. This client only needs + // to be created + // once, and can be reused for multiple requests. + BigQuery bigquery = + BigQueryOptions.newBuilder() + .setLocation(location) + .setProjectId(projectId) + .build() + .getService(); + DatasetInfo datasetInfo = DatasetInfo.newBuilder(datasetName).build(); + + Dataset newDataset = bigquery.create(datasetInfo); + String newDatasetName = newDataset.getDatasetId().getDataset(); + System.out.println(newDatasetName + " created successfully"); + } catch (BigQueryException e) { + System.out.format("Dataset was not created. %n%s", e.toString()); + } + } + + static void deleteBigQueryDataset(String projectId, String datasetName, String location) { + try { + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + BigQuery bigquery = + BigQueryOptions.newBuilder() + .setLocation(location) + .setProjectId(projectId) + .build() + .getService(); + + DatasetId datasetId = DatasetId.of(projectId, datasetName); + boolean success = bigquery.delete(datasetId, DatasetDeleteOption.deleteContents()); + if (success) { + System.out.println("Dataset deleted successfully"); + } else { + System.out.println("Dataset was not found"); + } + } catch (BigQueryException e) { + System.out.format("Dataset was not deleted. %n%s", e.toString()); + } + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + + // Delete the featurestore + DeleteFeaturestoreSample.deleteFeaturestoreSample( + PROJECT_ID, featurestoreId, USE_FORCE, LOCATION, ENDPOINT, 300); + + // Assert + String deleteFeaturestoreResponse = bout.toString(); + assertThat(deleteFeaturestoreResponse).contains("Deleted Featurestore"); + + // Delete the big query dataset + deleteBigQueryDataset(PROJECT_ID, datasetName, LOCATION); + + // Assert + String deleteBigQueryResponse = bout.toString(); + assertThat(deleteBigQueryResponse).contains("Dataset deleted successfully"); + + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testFeatureValuesSamples() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Create the featurestore + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 23); + String id = String.format("temp_feature_values_samples_test_%s", tempUuid); + CreateFeaturestoreSample.createFeaturestoreSample( + PROJECT_ID, id, MIN_NODE_COUNT, MAX_NODE_COUNT, LOCATION, ENDPOINT, 900); + + // Assert + String createFeaturestoreResponse = bout.toString(); + assertThat(createFeaturestoreResponse).contains("Create Featurestore Response"); + featurestoreId = + createFeaturestoreResponse.split("Name: ")[1].split("featurestores/")[1].split("\n")[0] + .trim(); + + // Create the entity type + String entityTypeId = "movies"; + CreateEntityTypeSample.createEntityTypeSample( + PROJECT_ID, featurestoreId, entityTypeId, DESCRIPTION, LOCATION, ENDPOINT, 900); + + // Assert + String createEntityTypeResponse = bout.toString(); + assertThat(createEntityTypeResponse).contains("Create Entity Type Response"); + + // Create the feature + String featureTempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 25); + String featureId = String.format("temp_feature_feature_test_%s", featureTempUuid); + CreateFeatureSample.createFeatureSample( + PROJECT_ID, + featurestoreId, + entityTypeId, + featureId, + DESCRIPTION, + VALUE_TYPE, + LOCATION, + ENDPOINT, + 900); + + // Assert + String createFeatureResponse = bout.toString(); + assertThat(createFeatureResponse).contains("Create Feature Response"); + + // Get the feature + GetFeatureSample.getFeatureSample( + PROJECT_ID, featurestoreId, entityTypeId, featureId, LOCATION, ENDPOINT); + + // Assert + String getFeatureResponse = bout.toString(); + assertThat(getFeatureResponse).contains("Get Feature Response"); + + // List features + ListFeaturesSample.listFeaturesSample( + PROJECT_ID, featurestoreId, entityTypeId, LOCATION, ENDPOINT); + + // Assert + String listfeatureResponse = bout.toString(); + assertThat(listfeatureResponse).contains("List Features Response"); + + // List features + ListFeaturesAsyncSample.listFeaturesAsyncSample( + PROJECT_ID, featurestoreId, entityTypeId, LOCATION, ENDPOINT); + + // Assert + String listfeatureAsyncResponse = bout.toString(); + assertThat(listfeatureAsyncResponse).contains("List Features Async Response"); + + // Search features + SearchFeaturesSample.searchFeaturesSample(PROJECT_ID, QUERY, LOCATION, ENDPOINT); + + // Assert + String searchFeaturesResponse = bout.toString(); + assertThat(searchFeaturesResponse).contains("Search Features Response"); + + // Search features + SearchFeaturesAsyncSample.searchFeaturesAsyncSample(PROJECT_ID, QUERY, LOCATION, ENDPOINT); + + // Assert + String searchFeaturesAsyncResponse = bout.toString(); + assertThat(searchFeaturesAsyncResponse).contains("Search Features Async Response"); + + // Delete the feature + DeleteFeatureSample.deleteFeatureSample( + PROJECT_ID, featurestoreId, entityTypeId, featureId, LOCATION, ENDPOINT, 300); + + // Assert + String deleteFeatureResponse = bout.toString(); + assertThat(deleteFeatureResponse).contains("Deleted Feature"); + + // Batch create features + BatchCreateFeaturesSample.batchCreateFeaturesSample( + PROJECT_ID, featurestoreId, entityTypeId, LOCATION, ENDPOINT, TIMEOUT); + + // Assert + String batchCreateFeaturesResponse = bout.toString(); + assertThat(batchCreateFeaturesResponse).contains("Batch Create Features Response"); + + // Import feature values + ImportFeatureValuesSample.importFeatureValuesSample( + PROJECT_ID, + featurestoreId, + entityTypeId, + GCS_SOURCE_URI, + ENTITY_ID_FIELD, + FEATURE_TIME_FIELD, + WORKER_COUNT, + LOCATION, + ENDPOINT, + TIMEOUT); + + // Assert + String importFeatureValuesResponse = bout.toString(); + assertThat(importFeatureValuesResponse).contains("Import Feature Values Response"); + + // Create the big query dataset + createBigQueryDataset(PROJECT_ID, datasetName, LOCATION); + destinationTableUri = + String.format("bq://%s.%s.%s_full", PROJECT_ID, datasetName, destinationTableName); + + // Assert + String createBigQueryDatasetResponse = bout.toString(); + assertThat(createBigQueryDatasetResponse).contains(datasetName + " created successfully"); + + // Export feature values + ExportFeatureValuesSample.exportFeatureValuesSample( + PROJECT_ID, + featurestoreId, + entityTypeId, + destinationTableUri, + FEATURE_SELECTOR_IDS, + LOCATION, + ENDPOINT, + TIMEOUT); + + // Assert + String exportFeatureValuesResponse = bout.toString(); + assertThat(exportFeatureValuesResponse).contains("Export Feature Values Response"); + + destinationTableUri = + String.format("bq://%s.%s.%s_snapshot", PROJECT_ID, datasetName, destinationTableName); + + // Snapshot export feature values + ExportFeatureValuesSnapshotSample.exportFeatureValuesSnapshotSample( + PROJECT_ID, + featurestoreId, + entityTypeId, + destinationTableUri, + FEATURE_SELECTOR_IDS, + LOCATION, + ENDPOINT, + TIMEOUT); + + // Assert + String snapshotResponse = bout.toString(); + assertThat(snapshotResponse).contains("Snapshot Export Feature Values Response"); + + destinationTableUri = + String.format("bq://%s.%s.%s_batchRead", PROJECT_ID, datasetName, destinationTableName); + + // Batch read feature values + BatchReadFeatureValuesSample.batchReadFeatureValuesSample( + PROJECT_ID, + featurestoreId, + entityTypeId, + INPUT_CSV_FILE, + destinationTableUri, + FEATURE_SELECTOR_IDS, + LOCATION, + ENDPOINT, + TIMEOUT); + + // Assert + String batchReadFeatureValuesResponse = bout.toString(); + assertThat(batchReadFeatureValuesResponse).contains("Batch Read Feature Values Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/FeaturestoreSamplesTest.java b/aiplatform/src/test/java/aiplatform/FeaturestoreSamplesTest.java new file mode 100644 index 00000000000..74fcbde5b01 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/FeaturestoreSamplesTest.java @@ -0,0 +1,220 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class FeaturestoreSamplesTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final int MIN_NODE_COUNT = 1; + private static final int MAX_NODE_COUNT = 2; + private static final int FIXED_NODE_COUNT = 2; + private static final String DESCRIPTION = "Test Description"; + private static final int MONITORING_INTERVAL_DAYS = 1; + private static final boolean USE_FORCE = true; + private static final String LOCATION = "us-central1"; + private static final String ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; + private static final int TIMEOUT = 900; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String featurestoreId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + + // Delete the featurestore + DeleteFeaturestoreSample.deleteFeaturestoreSample( + PROJECT_ID, featurestoreId, USE_FORCE, LOCATION, ENDPOINT, 60); + + // Assert + String deleteFeaturestoreResponse = bout.toString(); + assertThat(deleteFeaturestoreResponse).contains("Deleted Featurestore"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateFeaturestoreSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Create the featurestore + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 25); + String id = String.format("temp_featurestore_samples_test_%s", tempUuid); + CreateFeaturestoreFixedNodesSample.createFeaturestoreFixedNodesSample( + PROJECT_ID, id, FIXED_NODE_COUNT, LOCATION, ENDPOINT, 900); + + // Assert + String createFeaturestoreResponse = bout.toString(); + assertThat(createFeaturestoreResponse).contains("Create Featurestore Response"); + featurestoreId = + createFeaturestoreResponse.split("Name: ")[1].split("featurestores/")[1].split("\n")[0] + .trim(); + + // Get the featurestore + GetFeaturestoreSample.getFeaturestoreSample(PROJECT_ID, featurestoreId, LOCATION, ENDPOINT); + + // Assert + String getFeaturestoreResponse = bout.toString(); + assertThat(getFeaturestoreResponse).contains("Get Featurestore Response"); + + // Update the featurestore with autoscaling + UpdateFeaturestoreSample.updateFeaturestoreSample( + PROJECT_ID, featurestoreId, MIN_NODE_COUNT, MAX_NODE_COUNT, LOCATION, ENDPOINT, TIMEOUT); + + // Assert + String updateFeaturestoreResponse = bout.toString(); + assertThat(updateFeaturestoreResponse).contains("Update Featurestore Response"); + + // List featurestores + ListFeaturestoresSample.listFeaturestoresSample(PROJECT_ID, LOCATION, ENDPOINT); + + // Assert + String listFeaturestoresResponse = bout.toString(); + assertThat(listFeaturestoresResponse).contains("List Featurestores Response"); + + // Update the featurestore with fixed nodes + UpdateFeaturestoreFixedNodesSample.updateFeaturestoreFixedNodesSample( + PROJECT_ID, featurestoreId, FIXED_NODE_COUNT, LOCATION, ENDPOINT, TIMEOUT); + + // Assert + String updateFeaturestoreFixedNodesResponse = bout.toString(); + assertThat(updateFeaturestoreFixedNodesResponse) + .contains("Update Featurestore Fixed Nodes Response"); + + // List featurestores + ListFeaturestoresAsyncSample.listFeaturestoresAsyncSample(PROJECT_ID, LOCATION, ENDPOINT); + + // Assert + String listFeaturestoresAsyncResponse = bout.toString(); + assertThat(listFeaturestoresAsyncResponse).contains("List Featurestores Async Response"); + + // Create the entity type + String entityTypeTempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 14); + String entityTypeId = String.format("temp_featurestore_samples_test_%s", entityTypeTempUuid); + CreateEntityTypeSample.createEntityTypeSample( + PROJECT_ID, featurestoreId, entityTypeId, DESCRIPTION, LOCATION, ENDPOINT, TIMEOUT); + + // Assert + String createEntityTypeResponse = bout.toString(); + assertThat(createEntityTypeResponse).contains("Create Entity Type Response"); + + // Get the entity type + GetEntityTypeSample.getEntityTypeSample( + PROJECT_ID, featurestoreId, entityTypeId, LOCATION, ENDPOINT); + + // Assert + String getEntityTypeResponse = bout.toString(); + assertThat(getEntityTypeResponse).contains("Get Entity Type Response"); + + // Create the entity type + String entityTypeMonitoringTempUuid = + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 14); + String entityTypeMonitoringId = + String.format("temp_featurestore_samples_test_%s", entityTypeMonitoringTempUuid); + CreateEntityTypeMonitoringSample.createEntityTypeMonitoringSample( + PROJECT_ID, + featurestoreId, + entityTypeMonitoringId, + DESCRIPTION, + MONITORING_INTERVAL_DAYS, + LOCATION, + ENDPOINT, + TIMEOUT); + + // Assert + String createEntityTypeMonitoringResponse = bout.toString(); + assertThat(createEntityTypeMonitoringResponse) + .contains("Create Entity Type Monitoring Response"); + + // List entity types + ListEntityTypesSample.listEntityTypesSample(PROJECT_ID, featurestoreId, LOCATION, ENDPOINT); + + // Assert + String listEntityTypeResponse = bout.toString(); + assertThat(listEntityTypeResponse).contains("List Entity Types Response"); + + // Update the entity type + UpdateEntityTypeSample.updateEntityTypeSample( + PROJECT_ID, featurestoreId, entityTypeId, DESCRIPTION, LOCATION, ENDPOINT); + + // Assert + String updateEntityTypeResponse = bout.toString(); + assertThat(updateEntityTypeResponse).contains("Update Entity Type Response"); + + // Update the entity type + UpdateEntityTypeMonitoringSample.updateEntityTypeMonitoringSample( + PROJECT_ID, featurestoreId, entityTypeId, MONITORING_INTERVAL_DAYS, LOCATION, ENDPOINT); + + // Assert + String updateEntityTypeMonitoringResponse = bout.toString(); + assertThat(updateEntityTypeMonitoringResponse) + .contains("Update Entity Type Monitoring Response"); + + // List entity types + ListEntityTypesAsyncSample.listEntityTypesAsyncSample( + PROJECT_ID, featurestoreId, LOCATION, ENDPOINT); + + // Assert + String listEntityTypeAsyncResponse = bout.toString(); + assertThat(listEntityTypeAsyncResponse).contains("List Entity Types Async Response"); + + // Delete the entity type + DeleteEntityTypeSample.deleteEntityTypeSample( + PROJECT_ID, featurestoreId, entityTypeId, LOCATION, ENDPOINT, TIMEOUT); + + // Assert + String deleteEntityTypeResponse = bout.toString(); + assertThat(deleteEntityTypeResponse).contains("Deleted Entity Type"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java b/aiplatform/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java new file mode 100644 index 00000000000..685768000d1 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetHyperparameterTuningJobSampleTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetHyperparameterTuningJobSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String HYPERPARAMETER_TUNING_JOB_ID = System.getenv("GET_HP_TUNING_JOB_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("GET_HP_TUNING_JOB_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetHyperparameterTuningJobSample() throws IOException { + GetHyperparameterTuningJobSample.getHyperparameterTuningJobSample( + PROJECT, HYPERPARAMETER_TUNING_JOB_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(HYPERPARAMETER_TUNING_JOB_ID); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java new file mode 100644 index 00000000000..27174f55f32 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationImageClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3512561418744365056"; + private static final String EVALUATION_ID = "9035588644970168320"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationImageClassificationSample() throws IOException { + // Act + GetModelEvaluationImageClassificationSample.getModelEvaluationImageClassificationSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Image Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..946482f6fc3 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3512561418744365056"; + private static final String EVALUATION_ID = "9035588644970168320"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationImageObjectDetectionSample() throws IOException { + // Act + GetModelEvaluationImageObjectDetectionSample.getModelEvaluationImageObjectDetectionSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Image Object Detection Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationSampleTest.java new file mode 100644 index 00000000000..d2f4e03d0e9 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationSampleTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3512561418744365056"; + private static final String EVALUATION_ID = "9035588644970168320"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationSample() throws IOException { + // Act + GetModelEvaluationSample.getModelEvaluationSample(PROJECT_ID, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationSliceSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationSliceSampleTest.java new file mode 100644 index 00000000000..b02103e89ed --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationSliceSampleTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationSliceSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3512561418744365056"; + private static final String EVALUATION_ID = "9035588644970168320"; + private static final String SLICE_ID = "6481571820677004173"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationSliceSample() throws IOException { + // Act + GetModelEvaluationSliceSample.getModelEvaluationSliceSample( + PROJECT_ID, MODEL_ID, EVALUATION_ID, SLICE_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(EVALUATION_ID); + assertThat(got).contains("Get Model Evaluation Slice Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java new file mode 100644 index 00000000000..23b0f28780e --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "6036688272397172736"; + private static final String EVALUATION_ID = "1866113044163962838"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void getModelEvaluationTabularClassification() throws IOException { + // Act + GetModelEvaluationTabularClassificationSample.getModelEvaluationTabularClassification( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Tabular Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java new file mode 100644 index 00000000000..bb5ec79b12a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationTabularRegressionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "8842430840248991744"; + private static final String EVALUATION_ID = "4944816689650806017"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void getModelEvaluationTabularRegression() throws IOException { + // Act + GetModelEvaluationTabularRegressionSample.getModelEvaluationTabularRegression( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Tabular Regression Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java new file mode 100644 index 00000000000..4e13470cd5a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationTextClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "7827432074230366208"; + private static final String EVALUATION_ID = "5064258198559522816"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationTextClassificationSample() throws IOException { + // Act + GetModelEvaluationTextClassificationSample.getModelEvaluationTextClassificationSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Text Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..5881a34296b --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationTextEntityExtractionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "6305215400179138560"; + private static final String EVALUATION_ID = "1754112472442208256"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationTextEntityExtractionSample() throws IOException { + // Act + GetModelEvaluationTextEntityExtractionSample.getModelEvaluationTextEntityExtractionSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Text Entity Extraction Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..cca27c67a86 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationTextSentimentAnalysisSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "4792568875336073216"; + private static final String EVALUATION_ID = "3347225656252432384"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationTextSentimentAnalysisSample() throws IOException { + // Act + GetModelEvaluationTextSentimentAnalysisSample.getModelEvaluationTextSentimentAnalysisSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Text Sentiment Analysis Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..549f7172c9d --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoActionRecognitionSampleTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("VIDEO_ACTION_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("VIDEO_ACTION_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("VIDEO_ACTION_MODEL_ID"); + requireEnvVar("VIDEO_ACTION_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationVideoActionRecognitionSample() throws IOException { + // Act + GetModelEvaluationVideoActionRecognitionSample.getModelEvaluationVideoActionRecognitionSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java new file mode 100644 index 00000000000..26a4628fb8e --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "8596984660557299712"; + private static final String EVALUATION_ID = "7092045712224944128"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationVideoClassificationSample() throws IOException { + // Act + GetModelEvaluationVideoClassificationSample.getModelEvaluationVideoClassification( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Video Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..7657b725537 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "8609932509485989888"; + private static final String EVALUATION_ID = "6016811301190238208"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationVideoObjectTrackingSample() throws IOException { + // Act + GetModelEvaluationVideoObjectTrackingSample.getModelEvaluationVideoObjectTracking( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Video Object Tracking Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetModelSampleTest.java b/aiplatform/src/test/java/aiplatform/GetModelSampleTest.java new file mode 100644 index 00000000000..59507901243 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetModelSampleTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3512561418744365056"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelSample() throws IOException { + // Act + GetModelSample.getModelSample(PROJECT_ID, MODEL_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/GetTrainingPipelineSampleTest.java b/aiplatform/src/test/java/aiplatform/GetTrainingPipelineSampleTest.java new file mode 100644 index 00000000000..41d5c09169c --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/GetTrainingPipelineSampleTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetTrainingPipelineSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TRAINING_PIPELINE_ID = System.getenv("GET_TRAINING_PIPELINE_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("GET_TRAINING_PIPELINE_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetTrainingPipelineSample() throws IOException { + // Act + GetTrainingPipelineSample.getTrainingPipeline(PROJECT, TRAINING_PIPELINE_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(TRAINING_PIPELINE_ID); + assertThat(got).contains("Get Training Pipeline Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java new file mode 100644 index 00000000000..a490fa4d8ee --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java @@ -0,0 +1,133 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ImportDataImageClassificationSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + + private static final String GCS_SOURCE_URI = "gs://ucaip-sample-resources/input.jsonl"; + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testImportDataSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataImageClassificationSample.importDataImageClassificationSample( + PROJECT, datasetId, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Image Classification Response: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java b/aiplatform/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..5c599cbc720 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java @@ -0,0 +1,133 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ImportDataImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = "gs://ucaip-sample-resources/input.jsonl"; + + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testImportDataSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataImageObjectDetectionSample.importDataImageObjectDetectionSample( + PROJECT, datasetId, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Image Object Detection Response: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ImportDataSampleTest.java b/aiplatform/src/test/java/aiplatform/ImportDataSampleTest.java new file mode 100644 index 00000000000..dcfaceb9f55 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ImportDataSampleTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import io.grpc.StatusRuntimeException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ImportDataSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = "000000000000000000000"; + + private static final String GCS_SOURCE_URI = + "gs://automl-cloud-dataset/SMSSpamCollection_train_dataset_2.csv"; + + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataSample() throws TimeoutException { + // As import data into dataset can take a long time, instead try to import data into a + // nonexistent dataset and confirm that the model was not found, but other + // elements of the request were valid. + try { + ImportDataTextClassificationSingleLabelSample.importDataTextClassificationSingleLabelSample( + PROJECT, DATASET_ID, GCS_SOURCE_URI); + // Assert + String got = bout.toString(); + assertThat(got).contains("The Dataset does not exist."); + } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) { + assertThat(e.getMessage()).contains("The Dataset does not exist."); + } + } +} diff --git a/aiplatform/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java b/aiplatform/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java new file mode 100644 index 00000000000..34102d77946 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ImportDataVideoActionRecognitionSampleTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class ImportDataVideoActionRecognitionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = + "gs://automl-video-demo-data/ucaip-var/swimrun.jsonl"; + + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testImportDataVideoActionRecognitionSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataVideoActionRecognitionSample.importDataVideoActionRecognitionSample( + PROJECT, datasetId, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("importDataResponse:"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java new file mode 100644 index 00000000000..53d639f1d64 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class ImportDataVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = + "gs://automl-video-demo-data/traffic_videos/traffic_videos_train.csv"; + + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + @Ignore("https://github.com/googleapis/java-aiplatform/issues/420") + public void testImportDataVideoClassificationSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataVideoClassificationSample.importDataVideoClassification( + GCS_SOURCE_URI, PROJECT, datasetId); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Video Classification Response: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java b/aiplatform/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..4052c91941c --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ImportDataVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = + "gs://automl-video-demo-data/traffic_videos/traffic_videos_train.csv"; + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + if (datasetId != null) { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataVideoObjectTrackingSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataVideoObjectTrackingSample.importDataVideObjectTracking( + GCS_SOURCE_URI, PROJECT, datasetId); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Video Object Tracking Response: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/ListModelEvaluationSliceSampleTest.java b/aiplatform/src/test/java/aiplatform/ListModelEvaluationSliceSampleTest.java new file mode 100644 index 00000000000..3ea5f26bc78 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/ListModelEvaluationSliceSampleTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ListModelEvaluationSliceSampleTest { + + private static final String PROJECT_ID = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = "3512561418744365056"; + private static final String EVALUATION_ID = "9035588644970168320"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testListModelEvaluationSliceSample() throws IOException { + // Act + ListModelEvaluationSliceSample.listModelEvaluationSliceSample( + PROJECT_ID, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(EVALUATION_ID); + assertThat(got).contains("Model Evaluation Slice Name: "); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictCustomTrainedModelSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictCustomTrainedModelSampleTest.java new file mode 100644 index 00000000000..eeeae9766d9 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictCustomTrainedModelSampleTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Base64; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictCustomTrainedModelSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String ENDPOINT_ID = + System.getenv("PREDICT_CUSTOM_TRAINED_MODEL_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("PREDICT_CUSTOM_TRAINED_MODEL_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictCustomTrainedModelSample() throws IOException { + // Act + ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get("resources/daisy.jpg"))); + String encoded = Base64.getEncoder().encodeToString(content.toByteArray()); + String instance = "[{'image_bytes': {'b64': '" + encoded + "'}, 'key':'0'}]"; + PredictCustomTrainedModelSample.predictCustomTrainedModel(PROJECT, ENDPOINT_ID, instance); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Custom Trained model Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictImageClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictImageClassificationSampleTest.java new file mode 100644 index 00000000000..8ca3fd95c5e --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictImageClassificationSampleTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictImageClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String FILE_NAME = "resources/image_flower_daisy.jpg"; + private static final String ENDPOINT_ID = System.getenv("IMAGE_CLASS_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("IMAGE_CLASS_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictImageClassification() throws IOException { + // Act + PredictImageClassificationSample.predictImageClassification(PROJECT, FILE_NAME, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Image Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..a7d3a16ee8a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class PredictImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String FILE_NAME = "resources/iod_caprese_salad.jpg"; + private static final String ENDPOINT_ID = System.getenv("IMAGE_OBJECT_DETECTION_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("IMAGE_OBJECT_DETECTION_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Ignore("See https://github.com/googleapis/java-aiplatform/issues/178") + @Test + public void testPredictImageObjectDetection() throws IOException { + // Act + PredictImageObjectDetectionSample.predictImageObjectDetection(PROJECT, FILE_NAME, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Image Object Detection Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java new file mode 100644 index 00000000000..25cbf12d5e7 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String INSTANCE = + "[{\"petal_length\": '1.4'," + + " \"petal_width\": '1.3'," + + " \"sepal_length\": '5.1'," + + " \"sepal_width\": '2.8'}]"; + + private static final String ENDPOINT_ID = + System.getenv("PREDICT_TABLES_CLASSIFCATION_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("PREDICT_TABLES_CLASSIFCATION_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTabularClassification() throws IOException { + // Act + PredictTabularClassificationSample.predictTabularClassification(INSTANCE, PROJECT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Tabular Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java new file mode 100644 index 00000000000..44f5bfdfa21 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTabularRegressionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String INSTANCE = + "[{\n" + + " \"BOOLEAN_2unique_NULLABLE\": False,\n" + + " \"DATETIME_1unique_NULLABLE\": '2019-01-01 00:00:00',\n" + + " \"DATE_1unique_NULLABLE\": '2019-01-01',\n" + + " \"FLOAT_5000unique_NULLABLE\": 1611,\n" + + " \"FLOAT_5000unique_REPEATED\": [2320,1192],\n" + + " \"INTEGER_5000unique_NULLABLE\": '8',\n" + + " \"NUMERIC_5000unique_NULLABLE\": 16,\n" + + " \"STRING_5000unique_NULLABLE\": 'str-2',\n" + + " \"STRUCT_NULLABLE\": {\n" + + " 'BOOLEAN_2unique_NULLABLE': False,\n" + + " 'DATE_1unique_NULLABLE': '2019-01-01',\n" + + " 'DATETIME_1unique_NULLABLE': '2019-01-01 00:00:00',\n" + + " 'FLOAT_5000unique_NULLABLE': 1308,\n" + + " 'FLOAT_5000unique_REPEATED': [2323, 1178],\n" + + " 'FLOAT_5000unique_REQUIRED': 3089,\n" + + " 'INTEGER_5000unique_NULLABLE': '1777',\n" + + " 'NUMERIC_5000unique_NULLABLE': 3323,\n" + + " 'TIME_1unique_NULLABLE': '23:59:59.999999',\n" + + " 'STRING_5000unique_NULLABLE': 'str-49',\n" + + " 'TIMESTAMP_1unique_NULLABLE': '1546387199999999'\n" + + " },\n" + + " \"TIMESTAMP_1unique_NULLABLE\": '1546387199999999',\n" + + " \"TIME_1unique_NULLABLE\": '23:59:59.999999'\n" + + "}]"; + private static final String ENDPOINT_ID = System.getenv("PREDICT_TABLES_REGRESSION_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("PREDICT_TABLES_REGRESSION_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTabularRegression() throws IOException { + // Act + PredictTabularRegressionSample.predictTabularRegression(INSTANCE, PROJECT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Tabular Regression Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java new file mode 100644 index 00000000000..a47674098a9 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTextClassificationSingleLabelSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TEXT_CONTENT = "This is the test String!"; + private static final String ENDPOINT_ID = System.getenv("TEXT_CLASS_SINGLE_LABEL_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_CLASS_SINGLE_LABEL_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTextClassification() throws IOException { + // Act + PredictTextClassificationSingleLabelSample.predictTextClassificationSingleLabel( + PROJECT, TEXT_CONTENT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Text Classification Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..71fd8e8ba26 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTextEntityExtractionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TEXT_CONTENT = + "1127526\\tAnalbuminemia in a neonate.\\tA small-for-gestational-age infant , found to have" + + " analbuminemia in the neonatal period , is reported and the twelve cases recorded in" + + " the world literature are reviewed . Patients lacking this serum protein are" + + " essentially asymptomatic , apart from minimal ankle edema and ease of fatigue ." + + " Apparent compensatory mechanisms which come into play when serum albumin is low" + + " include prolonged half-life of albumin and transferrin , an increase in serum" + + " globulins , beta lipoprotein , and glycoproteins , arterial hypotension with reduced" + + " capillary hydrostatic pressure , and the ability to respond with rapid sodium and" + + " chloride diuresis in response to small volume changes . Examination of plasma amino" + + " acids , an investigation not previously reported , revealed an extremely low plasma" + + " tryptophan level , a finding which may be important in view of the role of" + + " tryptophan in albumin synthesis."; + private static final String ENDPOINT_ID = System.getenv("TEXT_ENTITY_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_ENTITY_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTextEntityExtraction() throws IOException { + // Act + PredictTextEntityExtractionSample.predictTextEntityExtraction( + PROJECT, TEXT_CONTENT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Text Entity Extraction Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java b/aiplatform/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..d452dc94574 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTextSentimentAnalysisSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TEXT_CONTENT = + "I was excited at the concept of my favorite comic book hero being on television... and" + + " sorely disappointed at the end result.

The only amazing thing was the" + + " wall crawling (despite the visibility of the cable). I didn't think Nick Hammond was" + + " Peter Parker... and he was visibly of a different build than the guy who did the" + + " stunts in the spider suit. You could tell they were two different actors.
Granted, I can also spot in the modern Spider-Man movies when I am looking at" + + " Tobey Macguire and when I am looking at CGI. But that is from a trained eye and" + + " experience working with CGI. Still, the 70's version could have been better despite" + + " lack of Special FX.

The webs were hokey and looked like ropes that seemed" + + " to wrap around things rather than stick to them. And what was up with giving him a" + + " spider mobile to ride around in. Hello? He's the web slinger people.
Sorry... didn't mean to get so worked up, but our beloved wall crawler deserved" + + " better."; + private static final String ENDPOINT_ID = System.getenv("TEXT_SENTI_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_SENTI_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTextSentimentAnalysis() throws IOException { + // Act + PredictTextSentimentAnalysisSample.predictTextSentimentAnalysis( + PROJECT, TEXT_CONTENT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Text Sentiment Analysis Response"); + } +} diff --git a/aiplatform/src/test/java/aiplatform/UploadModelSampleTest.java b/aiplatform/src/test/java/aiplatform/UploadModelSampleTest.java new file mode 100644 index 00000000000..c085f8c6776 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/UploadModelSampleTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class UploadModelSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String METADATASCHEMA_URI = ""; + private static final String IMAGE_URI = + "gcr.io/cloud-ml-service-public/" + + "cloud-ml-online-prediction-model-server-cpu:" + + "v1_15py3cmle_op_images_20200229_0210_RC00"; + private static final String ARTIFACT_URI = "gs://ucaip-samples-us-central1/model/explain/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String uploadedModelId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // Cancel the Training Pipeline + DeleteModelSample.deleteModel(PROJECT, uploadedModelId); + + // Assert + String deleteModelResponse = bout.toString(); + assertThat(deleteModelResponse).contains("Deleted Model."); + TimeUnit.MINUTES.sleep(1); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void uploadModelSampleTest() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // Act + String modelDisplayName = + String.format( + "temp_upload_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + UploadModelSample.uploadModel( + PROJECT, modelDisplayName, METADATASCHEMA_URI, IMAGE_URI, ARTIFACT_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Upload Model Response"); + uploadedModelId = got.split("Model:")[1].split("models/")[1].split("\n")[0]; + } +}