diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java
index aa3fdf6ae..684185a4d 100644
--- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java
+++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java
@@ -11,11 +11,8 @@
 import org.opensearch.knn.index.KNNSettings;
 import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
 import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
-import org.opensearch.knn.index.quantizationService.QuantizationService;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
 import org.opensearch.knn.jni.JNIService;
-import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
-import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
 
 import java.io.IOException;
 import java.security.AccessController;
@@ -57,34 +54,14 @@ public static DefaultIndexBuildStrategy getInstance() {
     public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
         // Needed to make sure we don't get 0 dimensions while initializing index
         iterateVectorValuesOnce(knnVectorValues);
-        QuantizationService quantizationHandler = QuantizationService.getInstance();
-        QuantizationState quantizationState = indexInfo.getQuantizationState();
-        QuantizationOutput quantizationOutput = null;
+        IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);
 
-        int bytesPerVector;
-        int dimensions;
-
-        // Handle quantization state if present
-        if (quantizationState != null) {
-            bytesPerVector = quantizationState.getBytesPerVector();
-            dimensions = quantizationState.getDimensions();
-            quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
-        } else {
-            bytesPerVector = knnVectorValues.bytesPerVector();
-            dimensions = knnVectorValues.dimension();
-        }
-
-        int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
+        int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
         try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
             final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());
 
             while (knnVectorValues.docId() != NO_MORE_DOCS) {
-                if (quantizationState != null && quantizationOutput != null) {
-                    quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
-                    vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true);
-                } else {
-                    vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
-                }
+                IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, true);
                 // append is true here so off heap memory buffer isn't overwritten
                 transferredDocIds.add(knnVectorValues.docId());
                 knnVectorValues.nextDoc();
@@ -100,7 +77,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
                 JNIService.createIndexFromTemplate(
                     intListToArray(transferredDocIds),
                     vectorAddress,
-                    dimensions,
+                    indexBuildSetup.getDimensions(),
                     indexInfo.getIndexPath(),
                     (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
                     params,
@@ -113,7 +90,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
                 JNIService.createIndex(
                     intListToArray(transferredDocIds),
                     vectorAddress,
-                    dimensions,
+                    indexBuildSetup.getDimensions(),
                     indexInfo.getIndexPath(),
                     params,
                     indexInfo.getKnnEngine()
diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildHelper.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildHelper.java
new file mode 100644
index 000000000..593247de6
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildHelper.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.nativeindex;
+
+import lombok.experimental.UtilityClass;
+import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
+import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
+import org.opensearch.knn.index.quantizationService.QuantizationService;
+import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+
+import java.io.IOException;
+
+@UtilityClass
+class IndexBuildHelper {
+
+    // private static final QuantizationService quantizationService = QuantizationService.getInstance();
+
+    /**
+     * Processes and transfers the vector based on whether quantization is applied or not.
+     *
+     * @param knnVectorValues the KNN vector values to be processed.
+     * @param indexBuildSetup the setup containing quantization state and output, along with other parameters.
+     * @param vectorTransfer the off-heap vector transfer utility.
+     * @param append flag indicating whether to append or overwrite the transfer buffer.
+     * @return boolean indicating whether the transfer was successful.
+     * @throws IOException if an I/O error occurs during vector transfer.
+     */
+    static boolean processAndTransferVector(
+        KNNVectorValues<?> knnVectorValues,
+        IndexBuildSetup indexBuildSetup,
+        OffHeapVectorTransfer vectorTransfer,
+        boolean append
+    ) throws IOException {
+        QuantizationService quantizationService = QuantizationService.getInstance();
+        if (indexBuildSetup.getQuantizationState() != null && indexBuildSetup.getQuantizationOutput() != null) {
+            quantizationService.quantize(
+                indexBuildSetup.getQuantizationState(),
+                knnVectorValues.getVector(),
+                indexBuildSetup.getQuantizationOutput()
+            );
+            return vectorTransfer.transfer(indexBuildSetup.getQuantizationOutput().getQuantizedVector(), append);
+        } else {
+            return vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), append);
+        }
+    }
+
+    /**
+     * Prepares the index build setup, including bytes per vector and dimensions.
+     *
+     * @param knnVectorValues the KNN vector values.
+     * @param indexInfo the index build parameters.
+     * @return an instance of IndexBuildSetup containing relevant information.
+     */
+    static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, BuildIndexParams indexInfo) {
+        QuantizationState quantizationState = indexInfo.getQuantizationState();
+        QuantizationOutput quantizationOutput = null;
+        QuantizationService quantizationService = QuantizationService.getInstance();
+
+        int bytesPerVector;
+        int dimensions;
+
+        if (quantizationState != null) {
+            bytesPerVector = quantizationState.getBytesPerVector();
+            dimensions = quantizationState.getDimensions();
+            quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams());
+        } else {
+            bytesPerVector = knnVectorValues.bytesPerVector();
+            dimensions = knnVectorValues.dimension();
+        }
+
+        return new IndexBuildSetup(bytesPerVector, dimensions, quantizationOutput, quantizationState);
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java
new file mode 100644
index 000000000..856140f85
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.nativeindex;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+
+/**
+ * IndexBuildSetup encapsulates the configuration and parameters required for building an index.
+ * This includes the size of each vector, the dimensions of the vectors, and any quantization-related
+ * settings such as the output and state of quantization.
+ */
+@Getter
+@AllArgsConstructor
+class IndexBuildSetup {
+    /**
+     * The number of bytes per vector.
+     */
+    private final int bytesPerVector;
+
+    /**
+     * The number of dimensions in the vector.
+     */
+    private final int dimensions;
+
+    /**
+     * The quantization output that will hold the quantized vector.
+     */
+    private final QuantizationOutput quantizationOutput;
+
+    /**
+     * The state of quantization, which may include parameters and trained models.
+     */
+    private final QuantizationState quantizationState;
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java
index 1660a1996..3355ffebf 100644
--- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java
+++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java
@@ -11,11 +11,8 @@
 import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
 import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
 import org.opensearch.knn.index.engine.KNNEngine;
-import org.opensearch.knn.index.quantizationService.QuantizationService;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
 import org.opensearch.knn.jni.JNIService;
-import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
-import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
 
 import java.io.IOException;
 import java.security.AccessController;
@@ -60,47 +57,26 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
         iterateVectorValuesOnce(knnVectorValues);
         KNNEngine engine = indexInfo.getKnnEngine();
         Map<String, Object> indexParameters = indexInfo.getParameters();
-        QuantizationService quantizationHandler = QuantizationService.getInstance();
-        QuantizationState quantizationState = indexInfo.getQuantizationState();
-        QuantizationOutput quantizationOutput = null;
-
-        int bytesPerVector;
-        int dimensions;
-
-        // Handle quantization state if present
-        if (quantizationState != null) {
-            bytesPerVector = quantizationState.getBytesPerVector();
-            dimensions = quantizationState.getDimensions();
-            quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
-        } else {
-            bytesPerVector = knnVectorValues.bytesPerVector();
-            dimensions = knnVectorValues.dimension();
-        }
+        IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);
 
         // Initialize the index
         long indexMemoryAddress = AccessController.doPrivileged(
             (PrivilegedAction<Long>) () -> JNIService.initIndex(
                 knnVectorValues.totalLiveDocs(),
-                knnVectorValues.dimension(),
+                indexBuildSetup.getDimensions(),
                 indexParameters,
                 engine
             )
         );
 
-        int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
+        int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
         try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
             final List<Integer> transferredDocIds = new ArrayList<>(transferLimit);
 
             while (knnVectorValues.docId() != NO_MORE_DOCS) {
                 // append is false to be able to reuse the memory location
-                boolean transferred;
-                if (quantizationState != null && quantizationOutput != null) {
-                    quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
-                    transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false);
-                } else {
-                    transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false);
-                }
+                boolean transferred = IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, false); // append is false to be able to reuse the memory location
 
                 transferredDocIds.add(knnVectorValues.docId());
                 if (transferred) {
@@ -110,7 +86,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
                     JNIService.insertToIndex(
                         intListToArray(transferredDocIds),
                         vectorAddress,
-                        dimensions,
+                        indexBuildSetup.getDimensions(),
                         indexParameters,
                         indexMemoryAddress,
                         engine
@@ -130,7 +106,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
                     JNIService.insertToIndex(
                         intListToArray(transferredDocIds),
                         vectorAddress,
-                        dimensions,
+                        indexBuildSetup.getDimensions(),
                         indexParameters,
                         indexMemoryAddress,
                         engine
diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java
index a7f3b84f5..aec2eefef 100644
--- a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java
+++ b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java
@@ -39,15 +39,13 @@
     @Override
     public T getVectorByDocId(int docId) {
         try {
-            int index = lastIndex;
-            while (index <= docId) {
+            while (lastIndex <= docId) {
                 knnVectorValues.nextDoc();
-                index++;
+                lastIndex++;
             }
             if (knnVectorValues.docId() == NO_MORE_DOCS) {
                 return null;
             }
-            lastIndex = index;
             // Return the vector and the updated index
             return knnVectorValues.getVector();
         } catch (Exception e) {
diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java
index f971f70f9..f399c37ab 100644
--- a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java
+++ b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java
@@ -9,8 +9,8 @@
 import lombok.NoArgsConstructor;
 import org.apache.lucene.index.FieldInfo;
 import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
-import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
 import org.opensearch.knn.quantization.factory.QuantizerFactory;
 import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
 import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
@@ -20,6 +20,8 @@
 import org.opensearch.knn.quantization.quantizer.Quantizer;
 
 import java.io.IOException;
+import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;
+
 /**
  * A singleton class responsible for handling the quantization process, including training a quantizer
  * and applying quantization to vectors. This class is designed to be thread-safe.
@@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin
      * Retrieves quantization parameters from the FieldInfo.
      */
     public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
-        // TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
-        if (fieldInfo.getAttribute("QuantizationConfig") != null) {
-            return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+        QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
+        if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
+            return new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
         }
         return null;
     }
@@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
      * @return The {@link VectorDataType} to be used during the vector transfer process
      */
     public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
-        // TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
-        return VectorDataType.BINARY;
+        QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
+        if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
+            return VectorDataType.BINARY;
+        }
+        return null;
     }
 
     /**
diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java
index d80d398a2..85b5d07e6 100644
--- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java
@@ -47,8 +47,11 @@
 import org.opensearch.knn.common.KNNConstants;
 import org.opensearch.knn.index.SpaceType;
 import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
+import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser;
 import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
 import org.opensearch.knn.index.engine.KNNEngine;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -113,7 +116,8 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc
         float[] floatVectorForBinaryQuantization_2 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f };
         FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT);
         fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}");
-        fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }");
+        QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
+        fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig));
         fieldTypeForBinaryQuantization.freeze();
 
         addFieldToIndex(
@@ -187,7 +191,8 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce
         float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f };
         FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT);
         fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}");
-        fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }");
+        QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
+        fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig));
 
         addFieldToIndex(
             new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization),