From 0df06131fa51aa5f5c714472d9bc7153e4d1818f Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Fri, 13 Sep 2024 14:19:31 -0700 Subject: [PATCH] Adds Unit tests for NativeEngines990KnnVectorsWriter (#2097) Had to separate out the common code to make it easy to write tests. Mocking was difficult to do with the functional interfaces, and it was throwing an NPE in the tests, especially with the mock of NativeIndexWriter. Signed-off-by: Tejas Shah --- .../NativeEngines990KnnVectorsWriter.java | 152 ++++----- ...eEngines990KnnVectorsWriterFlushTests.java | 288 ++++++++++++++++++ ...eEngines990KnnVectorsWriterMergeTests.java | 240 +++++++++++++++ .../index/vectorvalues/TestVectorValues.java | 6 +- 4 files changed, 585 insertions(+), 101 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 0d016a60b..3f32003ac 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -25,11 +25,10 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; -import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import
org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -39,6 +38,7 @@ import java.util.List; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues; /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. @@ -47,15 +47,11 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); - private static final String FLUSH_OPERATION = "flush"; - private static final String MERGE_OPERATION = "merge"; - private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; - private final QuantizationService quantizationService = QuantizationService.getInstance(); public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { this.segmentWriteState = segmentWriteState; @@ -84,14 +80,27 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); for (final NativeEngineFieldVectorsWriter field : fields) { - trainAndIndex( - field.getFieldInfo(), - (vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter), - NativeIndexWriter::flushIndex, - field, - KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS, - FLUSH_OPERATION - ); + final FieldInfo fieldInfo = field.getFieldInfo(); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + int totalLiveDocs = 
getLiveDocs(getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors())); + if (totalLiveDocs > 0) { + KNNVectorValues knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValues, totalLiveDocs); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + + knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + + StopWatch stopWatch = new StopWatch().start(); + + writer.flushIndex(knnVectorValues, totalLiveDocs); + + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + } else { + log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); + } } } @@ -100,15 +109,26 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // For merge, pick values from flat vector and reindex again. 
This will use the flush operation to create graphs - trainAndIndex( - fieldInfo, - this::getKNNVectorValuesForMerge, - NativeIndexWriter::mergeIndex, - mergeState, - KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS, - MERGE_OPERATION - ); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + int totalLiveDocs = getLiveDocs(getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState)); + if (totalLiveDocs == 0) { + log.debug("[Merge] No live docs for field {}", fieldInfo.getName()); + return; + } + + KNNVectorValues knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); + final QuantizationState quantizationState = train(fieldInfo, knnVectorValues, totalLiveDocs); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + + knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); + + StopWatch stopWatch = new StopWatch().start(); + + writer.mergeIndex(knnVectorValues, totalLiveDocs); + + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); } /** @@ -157,18 +177,6 @@ public long ramBytesUsed() { .sum(); } - /** - * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer. - * - * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. - * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors. - * @param The type of vectors being processed. - * @return The {@link KNNVectorValues} associated with the field. 
- */ - private KNNVectorValues getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter field) { - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); - } - /** * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. * @@ -187,84 +195,28 @@ private KNNVectorValues getKNNVectorValuesForMerge( switch (fieldInfo.getVectorEncoding()) { case FLOAT32: FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + return getVectorValues(vectorDataType, mergedFloats); case BYTE: ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + return getVectorValues(vectorDataType, mergedBytes); default: throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); } } - /** - * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}. - * - * @param The type of vectors being processed. - */ - @FunctionalInterface - private interface IndexOperation { - void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; - } - - /** - * Functional interface representing a method that retrieves {@link KNNVectorValues} based on - * the vector data type, field information, and the merge state. - * - * @param The type of the data representing the vector (e.g., {@link VectorDataType}). - * @param The metadata about the field. - * @param The state of the merge operation. - * @param The result of the retrieval, typically {@link KNNVectorValues}. 
- */ - @FunctionalInterface - private interface VectorValuesRetriever { - Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException; - } + private QuantizationState train(final FieldInfo fieldInfo, final KNNVectorValues knnVectorValues, final int totalLiveDocs) + throws IOException { - /** - * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values - * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed. - * - * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. - * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type, - * field information, and additional context (e.g., merge state or field writer). - * @param indexOperation A functional interface that performs the indexing operation using the retrieved - * {@link KNNVectorValues}. - * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). - * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information - * @param The type of vectors being processed. - * @param The type of the context needed for retrieving the vector values. - * @throws IOException If an I/O error occurs during the processing. 
- */ - private void trainAndIndex( - final FieldInfo fieldInfo, - final VectorValuesRetriever> vectorValuesRetriever, - final IndexOperation indexOperation, - final C VectorProcessingContext, - final KNNGraphValue graphBuildTime, - final String operationName - ) throws IOException { - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - // Count the docIds - int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext)); - if (totalLiveDocs == 0) { - log.debug("No live docs for field " + fieldInfo.name); - return; - } + final QuantizationService quantizationService = QuantizationService.getInstance(); + final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; - QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); - if (quantizationParams != null) { + if (quantizationParams != null && totalLiveDocs > 0) { initQuantizationStateWriterIfNecessary(); - KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } - NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - KNNVectorValues knnVectors = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); - StopWatch stopWatch = new StopWatch().start(); - indexOperation.buildAndWrite(writer, knnVectors, totalLiveDocs); - long time_in_millis = stopWatch.stop().totalTime().millis(); - graphBuildTime.incrementBy(time_in_millis); - log.warn("Graph build took " + time_in_millis + " ms for " + operationName); + + return quantizationState; } /** diff --git 
a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java new file mode 100644 index 000000000..ad72f5b24 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -0,0 +1,288 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static 
com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + + private NativeEngines990KnnVectorsWriter objectUnderTest; + + private final String description; + private final List> vectorsPerField; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Single field", List.of(Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }))), + $("Single field, no total live docs", List.of()), + $( + "Multi Field", + List.of( + Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), + Map.of( + 0, + new float[] { 1, 2, 3, 4 }, + 1, + new float[] { 2, 3, 4, 5 }, + 2, + new float[] { 3, 4, 5, 6 }, + 3, + new float[] { 4, 5, 6, 7 } + ) + ) + ) + ) + ); + } + + @SneakyThrows + public void testFlush() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, 
vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> 
NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + @SneakyThrows + public void testFlush_WithQuantization() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> 
QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + assertTrue(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + 
+ IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java new file mode 100644 index 000000000..440e8bbc5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -0,0 +1,240 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; 
+import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + 
+@RequiredArgsConstructor +public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + @Mock + private FloatVectorValues floatVectorValues; + @Mock + private MergeState mergeState; + + private NativeEngines990KnnVectorsWriter objectUnderTest; + + private final String description; + private final Map mergedVectors; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Merge one field", Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 })), + $("Merge, no live docs", Map.of()) + ) + ); + } + + @SneakyThrows + public void testMerge() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction 
knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + + @SneakyThrows + public void testMerge_WithQuantization() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new 
TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> 
NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + if (!mergedVectors.isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } 
+} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java index 3bf79b004..0f15d5240 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java @@ -184,7 +184,11 @@ public static class PreDefinedFloatVectorValues extends FloatVectorValues { public PreDefinedFloatVectorValues(final List vectors) { super(); this.count = vectors.size(); - this.dimension = vectors.get(0).length; + if (!vectors.isEmpty()) { + this.dimension = vectors.get(0).length; + } else { + this.dimension = 0; + } this.vectors = vectors; this.current = -1; vector = new float[dimension];