diff --git a/automl/snippets/src/test/java/com/example/automl/ImportDatasetTest.java b/automl/snippets/src/test/java/com/example/automl/ImportDatasetTest.java index 9933a5d7d1a..e516523117b 100644 --- a/automl/snippets/src/test/java/com/example/automl/ImportDatasetTest.java +++ b/automl/snippets/src/test/java/com/example/automl/ImportDatasetTest.java @@ -19,11 +19,20 @@ import static com.google.common.truth.Truth.assertThat; import static junit.framework.TestCase.assertNotNull; +import com.google.api.core.ApiFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.CreateDatasetRequest; +import com.google.cloud.automl.v1beta1.Dataset; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.TextExtractionDatasetMetadata; +import com.google.longrunning.Operation; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; import java.util.UUID; +import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.junit.After; import org.junit.Before; @@ -56,20 +65,34 @@ public static void checkRequirements() { } @Before - public void setUp() throws InterruptedException, ExecutionException, IOException { - bout = new ByteArrayOutputStream(); - out = new PrintStream(bout); - System.setOut(out); - - // Create a dataset that can be used for the import test + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Create a fake dataset to be deleted // Create a random dataset name with a length of 32 characters (max allowed by AutoML) // To prevent name collisions when running tests in multiple java versions at once. // AutoML doesn't allow "-", but accepts "_" String datasetName = String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); - LanguageEntityExtractionCreateDataset.createDataset(PROJECT_ID, datasetName); - String got = bout.toString(); - datasetId = got.split("Dataset id: ")[1].split("\n")[0]; + try (AutoMlClient client = AutoMlClient.create()) { + + LocationName projectLocation = LocationName.of(PROJECT_ID, "us-central1"); + TextExtractionDatasetMetadata metadata = TextExtractionDatasetMetadata.newBuilder().build(); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetName) + .setTextExtractionDatasetMetadata(metadata) + .build(); + + CreateDatasetRequest request = + CreateDatasetRequest.newBuilder() + .setParent(projectLocation.toString()) + .setDataset(dataset) + .build(); + ApiFuture future = client.createDatasetCallable().futureCall(request); + Dataset createdDataset = future.get(5, TimeUnit.MINUTES); + String[] names = createdDataset.getName().split("/"); + datasetId = names[names.length - 1]; + } bout = new ByteArrayOutputStream(); out = new PrintStream(bout); @@ -85,9 +108,23 @@ public void tearDown() throws InterruptedException, ExecutionException, IOExcept @Test public void testImportDataset() - throws IOException, ExecutionException, InterruptedException, TimeoutException { - ImportDataset.importDataset(PROJECT_ID, datasetId, BUCKET + "/entity-extraction/dataset.csv"); + throws InterruptedException, ExecutionException, TimeoutException, IOException { + + try { + ImportDataset.importDataset(PROJECT_ID, datasetId, BUCKET + "/entity-extraction/dataset.csv"); + } catch (CancellationException ex) { + // capture operation ID from output and wait for that operation to be finished. + String fullOperationId = ex.getMessage().split("Operation name: ")[1].trim(); + AutoMlClient client = AutoMlClient.create(); + Operation importDatasetLro = client.getOperationsClient().getOperation(fullOperationId); + while (!importDatasetLro.getDone()) { + Thread.sleep(3000); + } + // retry the import. + ImportDataset.importDataset(PROJECT_ID, datasetId, BUCKET + "/entity-extraction/dataset.csv"); + } String got = bout.toString(); + assertThat(got).contains("Dataset imported."); } }