From 77353512c1f15e0dc996428a982941a7ee3036fb Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Fri, 3 Jun 2022 13:11:25 -0400 Subject: [PATCH] Change VectorReaderListener to expect number array (#416) Refactors VectorReaderListener onResponse to expect arrays of Number type from search result instead of Double type. Adds test case to confirm that it can handle Integer type. Cleans up tests in VectorReaderTests class. Signed-off-by: John Mazanec --- .../opensearch/knn/training/VectorReader.java | 2 +- .../knn/training/VectorReaderTests.java | 273 +++++++++++------- 2 files changed, 176 insertions(+), 99 deletions(-) diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index 3f4d53778..9b7db6d99 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -184,7 +184,7 @@ public void onResponse(SearchResponse searchResponse) { for (int i = 0; i < vectorsToAdd; i++) { trainingData.add( - ((List) hits[i].getSourceAsMap().get(fieldName)).stream().map(Double::floatValue).toArray(Float[]::new) + ((List) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new) ); } diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java index f8a2feeb9..0b1e15289 100644 --- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java +++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java @@ -11,9 +11,8 @@ package org.opensearch.knn.training; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.ValidationException; import org.opensearch.knn.KNNSingleNodeTestCase; @@ -32,32 
+31,65 @@ public class VectorReaderTests extends KNNSingleNodeTestCase { - public static Logger logger = LogManager.getLogger(VectorReaderTests.class); + private final static int DEFAULT_LATCH_TIMEOUT = 100; + private final static String DEFAULT_INDEX_NAME = "test-index"; + private final static String DEFAULT_FIELD_NAME = "test-field"; + private final static int DEFAULT_DIMENSION = 16; + private final static int DEFAULT_NUM_VECTORS = 100; + private final static int DEFAULT_MAX_VECTOR_COUNT = 10000; + private final static int DEFAULT_SEARCH_SIZE = 10; public void testRead_valid_completeIndex() throws InterruptedException, ExecutionException, IOException { - // Create an index with knn disabled - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - int numVectors = 100; - createIndex(indexName); - - // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createIndex(DEFAULT_INDEX_NAME); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Create list of random vectors and ingest Random random = new Random(); List vectors = new ArrayList<>(); - for (int i = 0; i < numVectors; i++) { - Float[] vector = new Float[dim]; + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); + vectors.add(vector); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); + } - for (int j = 0; j < dim; j++) { - vector[j] = random.nextFloat(); - } + // Configure VectorReader + ClusterService clusterService = node().injector().getInstance(ClusterService.class); + VectorReader vectorReader = new VectorReader(client()); - vectors.add(vector); + // Read all vectors and confirm they match vectors + TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + 
DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); + + List consumedVectors = testVectorConsumer.getVectorsConsumed(); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); + + List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); + } + + public void testRead_valid_trainVectorsIngestedAsIntegers() throws IOException, ExecutionException, InterruptedException { + createIndex(DEFAULT_INDEX_NAME); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); - addKnnDoc(indexName, Integer.toString(i), fieldName, vector); + // Create list of random vectors and ingest + Random random = new Random(); + List vectors = new ArrayList<>(); + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Integer[] vector = random.ints(DEFAULT_DIMENSION).boxed().toArray(Integer[]::new); + vectors.add(vector); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); } // Configure VectorReader @@ -66,23 +98,23 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio // Read all vectors and confirm they match vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); - final CountDownLatch inProgressLatch1 = new CountDownLatch(1); + final CountDownLatch inProgressLatch = new CountDownLatch(1); vectorReader.read( clusterService, - indexName, - fieldName, - 10000, - 10, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, testVectorConsumer, - ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString())) + 
createOnSearchResponseCountDownListener(inProgressLatch) ); - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + assertLatchDecremented(inProgressLatch); List consumedVectors = testVectorConsumer.getVectorsConsumed(); - assertEquals(numVectors, consumedVectors.size()); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); - List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + List flatVectors = vectors.stream().flatMap(Arrays::stream).map(Integer::floatValue).collect(Collectors.toList()); List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); } @@ -90,35 +122,25 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio public void testRead_valid_incompleteIndex() throws InterruptedException, ExecutionException, IOException { // Check if we get the right number of vectors if the index contains docs that are missing fields // Create an index with knn disabled - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - int numVectors = 100; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Create list of random vectors and ingest Random random = new Random(); List vectors = new ArrayList<>(); - for (int i = 0; i < numVectors; i++) { - Float[] vector = new Float[dim]; - - for (int j = 0; j < dim; j++) { - vector[j] = random.nextFloat(); - } - + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); vectors.add(vector); - - addKnnDoc(indexName, Integer.toString(i), fieldName, vector); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), 
DEFAULT_FIELD_NAME, vector); } // Create documents that do not have fieldName for training int docsWithoutKNN = 100; String fieldNameWithoutKnn = "test-field-2"; for (int i = 0; i < docsWithoutKNN; i++) { - addDoc(indexName, Integer.toString(i + numVectors), fieldNameWithoutKnn, "dummyValue"); + addDoc(DEFAULT_INDEX_NAME, Integer.toString(i + DEFAULT_NUM_VECTORS), fieldNameWithoutKnn, "dummyValue"); } // Configure VectorReader @@ -127,21 +149,21 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut // Read all vectors and confirm they match vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); - final CountDownLatch inProgressLatch1 = new CountDownLatch(1); + final CountDownLatch inProgressLatch = new CountDownLatch(1); vectorReader.read( clusterService, - indexName, - fieldName, - 10000, - 10, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, testVectorConsumer, - ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString())) + createOnSearchResponseCountDownListener(inProgressLatch) ); - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + assertLatchDecremented(inProgressLatch); List consumedVectors = testVectorConsumer.getVectorsConsumed(); - assertEquals(numVectors, consumedVectors.size()); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); @@ -151,26 +173,17 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, ExecutionException, IOException { // Check if we can limit the number of docs via max operation // Create an index with knn disabled - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - 
int numVectorsIndex = 100; int maxNumVectorsRead = 20; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Create list of random vectors and ingest Random random = new Random(); - for (int i = 0; i < numVectorsIndex; i++) { - Float[] vector = new Float[dim]; - - for (int j = 0; j < dim; j++) { - vector[j] = random.nextFloat(); - } - - addKnnDoc(indexName, Integer.toString(i), fieldName, vector); + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); } // Configure VectorReader @@ -179,18 +192,18 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec // Read maxNumVectorsRead vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); - final CountDownLatch inProgressLatch1 = new CountDownLatch(1); + final CountDownLatch inProgressLatch = new CountDownLatch(1); vectorReader.read( clusterService, - indexName, - fieldName, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, maxNumVectorsRead, - 10, + DEFAULT_SEARCH_SIZE, testVectorConsumer, - ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString())) + createOnSearchResponseCountDownListener(inProgressLatch) ); - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + assertLatchDecremented(inProgressLatch); List consumedVectors = testVectorConsumer.getVectorsConsumed(); assertEquals(maxNumVectorsRead, consumedVectors.size()); @@ -198,82 +211,138 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec public void testRead_invalid_maxVectorCount() { // Create the index - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - 
createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, -10, 10, null, null)); + int invalidMaxVectorCount = -10; + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + invalidMaxVectorCount, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } public void testRead_invalid_searchSize() { // Create the index - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Search size is negative - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, -10, null, null)); + int invalidSearchSize1 = -10; + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + invalidSearchSize1, + null, + null + ) + ); // Search size is greater than 10000 - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, 20000, null, null)); + int invalidSearchSize2 = 20000; + expectThrows( + ValidationException.class, + () -> vectorReader.read( + 
clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + invalidSearchSize2, + null, + null + ) + ); } public void testRead_invalid_indexDoesNotExist() { // Check that read throws a validation exception when the index does not exist - String indexName = "test-index"; - String fieldName = "test-field"; - // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because index does not exist - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } public void testRead_invalid_fieldDoesNotExist() { // Check that read throws a validation exception when the field does not exist - String indexName = "test-index"; - String fieldName = "test-field"; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because field is not k-NN - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, ExecutionException, IOException { // Check that read throws a validation exception when the field does not exist - String indexName = "test-index"; - String fieldName = 
"test-field"; - createIndex(indexName); - addDoc(indexName, "test-id", fieldName, "dummy"); + createIndex(DEFAULT_INDEX_NAME); + addDoc(DEFAULT_INDEX_NAME, "test-id", DEFAULT_FIELD_NAME, "dummy"); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because field does not exist - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } private static class TestVectorConsumer implements Consumer> { @@ -293,4 +362,12 @@ public List getVectorsConsumed() { return vectorsConsumed; } } + + private void assertLatchDecremented(CountDownLatch countDownLatch) throws InterruptedException { + assertTrue(countDownLatch.await(DEFAULT_LATCH_TIMEOUT, TimeUnit.SECONDS)); + } + + private ActionListener createOnSearchResponseCountDownListener(CountDownLatch countDownLatch) { + return ActionListener.wrap(response -> countDownLatch.countDown(), Throwable::printStackTrace); + } }