diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 92f9bb831..43f4d7ad6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -204,7 +204,8 @@ private interface VectorValuesRetriever * field information, and additional context (e.g., merge state or field writer). * @param indexOperation A functional interface that performs the indexing operation using the retrieved * {@link KNNVectorValues}. - * @param context The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). + * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). + * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information * @param The type of vectors being processed. * @param The type of the context needed for retrieving the vector values. * @throws IOException If an I/O error occurs during the processing. @@ -213,22 +214,20 @@ private void trainAndIndex( final FieldInfo fieldInfo, final VectorValuesRetriever> vectorValuesRetriever, final IndexOperation indexOperation, - final C context + final C VectorProcessingContext ) throws IOException { final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - KNNVectorValues knnVectorValuesForTraining = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context); - KNNVectorValues knnVectorValuesForIndexing = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context); - + KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; - if (quantizationParams != null) { - quantizationState = quantizationService.train(quantizationParams, knnVectorValuesForTraining); + quantizationState = quantizationService.train(quantizationParams, knnVectorValues); } NativeIndexWriter writer = (quantizationParams != null) ? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState) : NativeIndexWriter.getWriter(fieldInfo, segmentWriteState); - indexOperation.buildAndWrite(writer, knnVectorValuesForIndexing); + knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); + indexOperation.buildAndWrite(writer, knnVectorValues); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index aa3fdf6ae..d2a6027db 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -11,11 +11,8 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; -import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.security.AccessController; @@ -57,35 +54,16 @@ public static DefaultIndexBuildStrategy getInstance() { public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { // Needed to make sure we don't get 0 dimensions while initializing index iterateVectorValuesOnce(knnVectorValues); - QuantizationService quantizationHandler = QuantizationService.getInstance(); - QuantizationState quantizationState = indexInfo.getQuantizationState(); - QuantizationOutput quantizationOutput = null; + IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); - int bytesPerVector; - int dimensions; - - // Handle quantization state if present - if (quantizationState != null) { - bytesPerVector = quantizationState.getBytesPerVector(); - dimensions = quantizationState.getDimensions(); - quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams()); - } else { - bytesPerVector = knnVectorValues.bytesPerVector(); - dimensions = knnVectorValues.dimension(); - } - - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { final List transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs()); while (knnVectorValues.docId() != NO_MORE_DOCS) { - if (quantizationState != null && quantizationOutput != null) { - quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput); - vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true); - } else { - vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true); - } + Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); // append is true here so off heap memory buffer isn't overwritten + vectorTransfer.transfer(vector, true); transferredDocIds.add(knnVectorValues.docId()); knnVectorValues.nextDoc(); } @@ -100,7 +78,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector JNIService.createIndexFromTemplate( intListToArray(transferredDocIds), vectorAddress, - dimensions, + indexBuildSetup.getDimensions(), indexInfo.getIndexPath(), (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), params, @@ -113,7 +91,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector JNIService.createIndex( intListToArray(transferredDocIds), vectorAddress, - dimensions, + indexBuildSetup.getDimensions(), indexInfo.getIndexPath(), params, indexInfo.getKnnEngine() diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java new file mode 100644 index 000000000..c6c999c07 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +/** + * IndexBuildSetup encapsulates the configuration and parameters required for building an index. + * This includes the size of each vector, the dimensions of the vectors, and any quantization-related + * settings such as the output and state of quantization. + */ +@Getter +@AllArgsConstructor +final class IndexBuildSetup { + /** + * The number of bytes per vector. + */ + private final int bytesPerVector; + + /** + * Dimension of Vector for Indexing + */ + private final int dimensions; + + /** + * The quantization output that will hold the quantized vector. + */ + private final QuantizationOutput quantizationOutput; + + /** + * The state of quantization, which may include parameters and trained models. + */ + private final QuantizationState quantizationState; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index 1660a1996..1115bfe05 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -11,11 +11,8 @@ import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.security.AccessController; @@ -60,48 +57,27 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector iterateVectorValuesOnce(knnVectorValues); KNNEngine engine = indexInfo.getKnnEngine(); Map indexParameters = indexInfo.getParameters(); - QuantizationService quantizationHandler = QuantizationService.getInstance(); - QuantizationState quantizationState = indexInfo.getQuantizationState(); - QuantizationOutput quantizationOutput = null; - - int bytesPerVector; - int dimensions; - - // Handle quantization state if present - if (quantizationState != null) { - bytesPerVector = quantizationState.getBytesPerVector(); - dimensions = quantizationState.getDimensions(); - quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams()); - } else { - bytesPerVector = knnVectorValues.bytesPerVector(); - dimensions = knnVectorValues.dimension(); - } + IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); // Initialize the index long indexMemoryAddress = AccessController.doPrivileged( (PrivilegedAction) () -> JNIService.initIndex( knnVectorValues.totalLiveDocs(), - knnVectorValues.dimension(), + indexBuildSetup.getDimensions(), indexParameters, engine ) ); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { final List transferredDocIds = new ArrayList<>(transferLimit); while (knnVectorValues.docId() != NO_MORE_DOCS) { + Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); // append is false to be able to reuse the memory location - boolean transferred; - if (quantizationState != null && quantizationOutput != null) { - quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput); - transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false); - } else { - transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false); - } - // append is false to be able to reuse the memory location + boolean transferred = vectorTransfer.transfer(vector, false); transferredDocIds.add(knnVectorValues.docId()); if (transferred) { // Insert vectors @@ -110,7 +86,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector JNIService.insertToIndex( intListToArray(transferredDocIds), vectorAddress, - dimensions, + indexBuildSetup.getDimensions(), indexParameters, indexMemoryAddress, engine @@ -130,7 +106,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector JNIService.insertToIndex( intListToArray(transferredDocIds), vectorAddress, - dimensions, + indexBuildSetup.getDimensions(), indexParameters, indexMemoryAddress, engine diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java new file mode 100644 index 000000000..8fec1af6d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.experimental.UtilityClass; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; + +@UtilityClass +class QuantizationIndexUtils { + + /** + * Processes and returns the vector based on whether quantization is applied or not. + * + * @param knnVectorValues the KNN vector values to be processed. + * @param indexBuildSetup the setup containing quantization state and output, along with other parameters. + * @return the processed vector, either quantized or original. + * @throws IOException if an I/O error occurs during processing. + */ + static Object processAndReturnVector(KNNVectorValues knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException { + QuantizationService quantizationService = QuantizationService.getInstance(); + if (indexBuildSetup.getQuantizationState() != null && indexBuildSetup.getQuantizationOutput() != null) { + quantizationService.quantize( + indexBuildSetup.getQuantizationState(), + knnVectorValues.getVector(), + indexBuildSetup.getQuantizationOutput() + ); + return indexBuildSetup.getQuantizationOutput().getQuantizedVector(); + } else { + return knnVectorValues.conditionalCloneVector(); + } + } + + /** + * Prepares the quantization setup including bytes per vector and dimensions. + * + * @param knnVectorValues the KNN vector values. + * @param indexInfo the index build parameters. + * @return an instance of QuantizationSetup containing relevant information. + */ + static IndexBuildSetup prepareIndexBuild(KNNVectorValues knnVectorValues, BuildIndexParams indexInfo) { + QuantizationState quantizationState = indexInfo.getQuantizationState(); + QuantizationOutput quantizationOutput = null; + QuantizationService quantizationService = QuantizationService.getInstance(); + + int bytesPerVector; + int dimensions; + + if (quantizationState != null) { + bytesPerVector = quantizationState.getBytesPerVector(); + dimensions = quantizationState.getDimensions(); + quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams()); + } else { + bytesPerVector = knnVectorValues.bytesPerVector(); + dimensions = knnVectorValues.dimension(); + } + + return new IndexBuildSetup(bytesPerVector, dimensions, quantizationOutput, quantizationState); + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 908671a21..a2d51da3e 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -97,9 +97,7 @@ static KNNLibraryIndexingContext adjustPrefix( // We need to update the prefix used to create the faiss index if we are using the quantization // framework if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) { - // TODO: Uncomment to use Quantization framework - // leaving commented now just so it wont fail creating faiss indices. - // prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; } if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java index a7f3b84f5..4cf68d16c 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java +++ b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java @@ -5,16 +5,20 @@ package org.opensearch.knn.index.quantizationService; +import lombok.extern.log4j.Log4j2; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** * KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class. * It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID. */ -class KNNVectorQuantizationTrainingRequest extends TrainingRequest { +@Log4j2 +final class KNNVectorQuantizationTrainingRequest extends TrainingRequest { private final KNNVectorValues knnVectorValues; private int lastIndex; @@ -31,27 +35,21 @@ class KNNVectorQuantizationTrainingRequest extends TrainingRequest { } /** - * Retrieves the float vector associated with the specified document ID. + * Retrieves the vector associated with the specified document ID. * - * @param docId the document ID. + * @param position the document ID. * @return the float vector corresponding to the specified document ID, or null if the docId is invalid. */ @Override - public T getVectorByDocId(int docId) { - try { - int index = lastIndex; - while (index <= docId) { - knnVectorValues.nextDoc(); - index++; - } + public T getVectorAtThePosition(int position) throws IOException { + while (lastIndex <= position) { + lastIndex++; if (knnVectorValues.docId() == NO_MORE_DOCS) { return null; } - lastIndex = index; - // Return the vector and the updated index - return knnVectorValues.getVector(); - } catch (Exception e) { - throw new RuntimeException("Failed to retrieve vector for docId: " + docId, e); + knnVectorValues.nextDoc(); } + // Return the vector and the updated index + return knnVectorValues.getVector(); } } diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java index f971f70f9..a9e3cc715 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java @@ -9,8 +9,8 @@ import lombok.NoArgsConstructor; import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.factory.QuantizerFactory; import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; @@ -20,6 +20,8 @@ import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; +import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig; + /** * A singleton class responsible for handling the quantization process, including training a quantizer * and applying quantization to vectors. This class is designed to be thread-safe. @@ -28,7 +30,7 @@ * @param The type of the quantized output vectors. */ @NoArgsConstructor(access = AccessLevel.PRIVATE) -public class QuantizationService { +public final class QuantizationService { /** * The singleton instance of the {@link QuantizationService} class. @@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin * Retrieves quantization parameters from the FieldInfo. */ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) { - // TODO: Replace this with actual logic to extract quantization parameters from FieldInfo - if (fieldInfo.getAttribute("QuantizationConfig") != null) { - return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); + if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { + return new ScalarQuantizationParams(quantizationConfig.getQuantizationType()); } return null; } @@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) { * @return The {@link VectorDataType} to be used during the vector transfer process */ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { - // TODO: Replace this with actual logic to extract quantization parameters from FieldInfo - return VectorDataType.BINARY; + QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); + if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { + return VectorDataType.BINARY; + } + return null; } /** diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index ba54b60ab..79ce7b955 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -146,6 +146,10 @@ public int getBytesPerVector() { public int getDimensions() { // For multi-bit quantization, the dimension for indexing is the number of rows * columns in the thresholds array. // Where number of column reprensents Dimesion of Original vector and number of rows equals to number of bits + // Check if thresholds are null or have invalid structure + if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) { + throw new IllegalStateException("Error in getting Dimension: The thresholds array is not initialized."); + } return thresholds.length * thresholds[0].length; } diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java index 54ebe311c..d8b0eab10 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -8,6 +8,8 @@ import lombok.AllArgsConstructor; import lombok.Getter; +import java.io.IOException; + /** * TrainingRequest represents a request for training a quantizer. * @@ -24,8 +26,8 @@ public abstract class TrainingRequest { /** * Returns the vector corresponding to the specified document ID. * - * @param docId the document ID. + * @param position the document position. * @return the vector corresponding to the specified document ID. */ - public abstract T getVectorByDocId(int docId); + public abstract T getVectorAtThePosition(int position) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 981e7f00c..a0e6ec402 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -17,6 +17,8 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import oshi.util.tuples.Pair; +import java.io.IOException; + /** * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. * Unlike the OneBitScalarQuantizer, which uses a single bit per dimension to represent whether a value is above @@ -106,7 +108,7 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi * @return a MultiBitScalarQuantizationState containing the computed thresholds. */ @Override - public QuantizationState train(final TrainingRequest trainingRequest) { + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); // Calculate sum, mean, and standard deviation in one pass Pair meanAndStdDev = QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices); diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index a0f6a26b4..ac48a9523 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -15,6 +15,8 @@ import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; +import java.io.IOException; + /** * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. * It computes the mean of each dimension during training and then uses these means as thresholds @@ -56,7 +58,7 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { * @return a OneBitScalarQuantizationState containing the calculated means. */ @Override - public QuantizationState train(final TrainingRequest trainingRequest) { + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index c0b297f5d..521863205 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -9,6 +9,8 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + /** * The Quantizer interface defines the methods required for training and quantizing vectors * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. @@ -27,7 +29,7 @@ public interface Quantizer { * @param trainingRequest the request containing data and parameters for training. * @return a QuantizationState containing the learned parameters. */ - QuantizationState train(TrainingRequest trainingRequest); + QuantizationState train(TrainingRequest trainingRequest) throws IOException; /** * Quantizes the provided vector using the specified quantization state. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index f8f2cffed..bac2067c0 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -9,6 +9,8 @@ import lombok.experimental.UtilityClass; import oshi.util.tuples.Pair; +import java.io.IOException; + /** * Utility class providing common methods for quantizer operations, such as parameter validation and * extraction. This class is designed to be used with various quantizer implementations that require @@ -25,12 +27,12 @@ class QuantizerHelper { * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. */ - static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) { + static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) throws IOException { int totalSamples = sampledIndices.length; float[] mean = null; int lastIndex = 0; for (int docId : sampledIndices) { - float[] vector = samplingRequest.getVectorByDocId(docId); + float[] vector = samplingRequest.getVectorAtThePosition(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } @@ -64,13 +66,14 @@ static float[] calculateMeanThresholds(TrainingRequest samplingRequest, * @throws IllegalArgumentException if any of the vectors at the sampled indices are null. * @throws IllegalStateException if the mean or standard deviation arrays are not initialized after processing. */ - static Pair calculateMeanAndStdDev(TrainingRequest trainingRequest, int[] sampledIndices) { + static Pair calculateMeanAndStdDev(TrainingRequest trainingRequest, int[] sampledIndices) + throws IOException { float[] meanArray = null; float[] stdDevArray = null; int totalSamples = sampledIndices.length; int lastIndex = 0; for (int docId : sampledIndices) { - float[] vector = trainingRequest.getVectorByDocId(docId); + float[] vector = trainingRequest.getVectorAtThePosition(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index d80d398a2..85b5d07e6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -47,8 +47,11 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; import java.util.Arrays; @@ -113,7 +116,8 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc float[] floatVectorForBinaryQuantization_2 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); - fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }"); + QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build(); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); fieldTypeForBinaryQuantization.freeze(); addFieldToIndex( @@ -187,7 +191,8 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); - fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }"); + QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build(); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); addFieldToIndex( new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization), diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java new file mode 100644 index 000000000..30a2098dd --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.junit.Before; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class QuantizationIndexUtilsTests extends KNNTestCase { + + private KNNVectorValues knnVectorValues; + private BuildIndexParams buildIndexParams; + private QuantizationService quantizationService; + + @Before + public void setUp() throws Exception { + super.setUp(); + quantizationService = mock(QuantizationService.class); + + // Predefined float vectors for testing + List floatVectors = List.of( + new float[] { 1.0f, 2.0f, 3.0f }, + new float[] { 4.0f, 5.0f, 6.0f }, + new float[] { 7.0f, 8.0f, 9.0f } + ); + + // Use the predefined vectors to create KNNVectorValues + knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) + ); + + // Mocking BuildIndexParams + buildIndexParams = mock(BuildIndexParams.class); + } + + public void testPrepareIndexBuild_withQuantization_success() { + QuantizationState quantizationState = mock(OneBitScalarQuantizationState.class); + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 0x01 }); + when(quantizationState.getDimensions()).thenReturn(2); + when(quantizationState.getBytesPerVector()).thenReturn(8); + when(quantizationState.getQuantizationParams()).thenReturn(params); + + when(buildIndexParams.getQuantizationState()).thenReturn(quantizationState); + + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + + assertNotNull(setup.getQuantizationState()); + assertEquals(8, setup.getBytesPerVector()); + assertEquals(2, setup.getDimensions()); + } + + public void testPrepareIndexBuild_withoutQuantization_success() throws IOException { + when(buildIndexParams.getQuantizationState()).thenReturn(null); + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + assertNull(setup.getQuantizationState()); + assertEquals(knnVectorValues.bytesPerVector(), setup.getBytesPerVector()); + assertEquals(knnVectorValues.dimension(), setup.getDimensions()); + } + + public void testProcessAndReturnVector_withoutQuantization_success() throws IOException { + // Set up the BuildIndexParams to return no quantization + when(buildIndexParams.getQuantizationState()).thenReturn(null); + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + // Process and return the vector + assertNotNull(QuantizationIndexUtils.processAndReturnVector(knnVectorValues, setup)); + } + + public void testProcessAndReturnVector_withQuantization_success() throws IOException { + // Set up quantization state and output + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + knnVectorValues.nextDoc(); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + when(buildIndexParams.getQuantizationState()).thenReturn(state); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + // Process and return the vector + Object result = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, setup); + assertTrue(result instanceof byte[]); + assertArrayEquals(new byte[] { 0x00 }, (byte[]) result); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java index c9ce50f22..75da6811e 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java @@ -246,7 +246,7 @@ public void testGetKNNLibraryIndexingContext_whenMethodIsHNSWWithQFrame_thenCrea .vectorDataType(VectorDataType.FLOAT) .build(); int m = 88; - String expectedIndexDescription = "HNSW" + m + ",Flat"; + String expectedIndexDescription = "BHNSW" + m + ",Flat"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .field(NAME, METHOD_HNSW) @@ -285,7 +285,7 @@ public void testGetKNNLibraryIndexingContext_whenMethodIsIVFWithQFrame_thenCreat .vectorDataType(VectorDataType.FLOAT) .build(); int nlist = 88; - String expectedIndexDescription = "IVF" + nlist + ",Flat"; + String expectedIndexDescription = "BIVF" + nlist + ",Flat"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .field(NAME, METHOD_IVF) diff --git a/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java b/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java index e3f8b607a..38371d8c3 100644 --- a/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java +++ b/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java @@ -5,7 +5,6 @@ package org.opensearch.knn.integ; -import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; @@ -29,23 +28,24 @@ public class QFrameworkIT extends KNNRestTestCase { public void testBaseCase() throws IOException { createTestIndex(4); - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, TEST_VECTOR); - Response response = searchKNNIndex( - INDEX_NAME, - XContentFactory.jsonBuilder() - .startObject() - .startObject("query") - .startObject("knn") - .startObject(FIELD_NAME) - .field("vector", TEST_VECTOR) - .field("k", K) - .endObject() - .endObject() - .endObject() - .endObject(), - 1 - ); - assertOK(response); + // TODO :- UnComment this once Search is Integrated and KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING is enabled + // addKnnDoc(INDEX_NAME, "1", FIELD_NAME, TEST_VECTOR); + // Response response = searchKNNIndex( + // INDEX_NAME, + // XContentFactory.jsonBuilder() + // .startObject() + // .startObject("query") + // .startObject("knn") + // .startObject(FIELD_NAME) + // .field("vector", TEST_VECTOR) + // .field("k", K) + // .endObject() + // .endObject() + // .endObject() + // .endObject(), + // 1 + // ); + // assertOK(response); } private void createTestIndex(int bitCount) throws IOException { diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java index 45acaf357..de815d8ad 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -17,7 +17,7 @@ public class MultiBitScalarQuantizerTests extends KNNTestCase { - public void testTrain_twoBit() { + public void testTrain_twoBit() throws IOException { float[][] vectors = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, @@ -33,7 +33,7 @@ public void testTrain_twoBit() { assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds } - public void testTrain_fourBit() { + public void testTrain_fourBit() throws IOException { MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); float[][] vectors = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, @@ -220,8 +220,8 @@ public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) { } @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index 7d98daf95..a6b907ccb 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -22,14 +22,14 @@ public class OneBitScalarQuantizerTests extends KNNTestCase { - public void testTrain_withTrainingRequired() { + public void testTrain_withTrainingRequired() throws IOException { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); TrainingRequest originalRequest = new TrainingRequest(vectors.length) { @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } }; OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); @@ -118,14 +118,14 @@ public void testQuantize_withMismatchedDimensions() throws IOException { expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output)); } - public void testCalculateMean() { + public void testCalculateMean() throws IOException { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } }; @@ -141,8 +141,8 @@ public void testCalculateMean_withNullVector() { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } };