Integration With Quantization Config
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
Vikasht34 committed Aug 22, 2024
1 parent c310f72 commit 448ffb9
Showing 7 changed files with 149 additions and 70 deletions.
@@ -11,11 +11,8 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
@@ -57,34 +54,14 @@ public static DefaultIndexBuildStrategy getInstance() {
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());

while (knnVectorValues.docId() != NO_MORE_DOCS) {
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true);
} else {
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
}
IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, true);
// append is true here so the off-heap memory buffer isn't overwritten
transferredDocIds.add(knnVectorValues.docId());
knnVectorValues.nextDoc();
@@ -100,7 +77,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndexFromTemplate(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
params,
@@ -113,7 +90,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
params,
indexInfo.getKnnEngine()
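The transfer limit computed above batches vectors so that the off-heap transfer buffer stays within the vector streaming memory limit; with this change the divisor comes from the IndexBuildSetup, so a quantized field is batched by its quantized width rather than by the raw float width. A self-contained sketch of the arithmetic (the memory limit and dimension count are assumed values for illustration, not figures taken from this commit):

public class TransferLimitSketch {
    public static void main(String[] args) {
        long streamingLimitBytes = 10L * 1024 * 1024;  // assumed 10 MB streaming limit
        int dimensions = 128;                          // assumed vector dimension

        // Without quantization: full-precision float vectors.
        int rawBytesPerVector = dimensions * Float.BYTES;  // 512 bytes
        int rawTransferLimit = (int) Math.max(1, streamingLimitBytes / rawBytesPerVector);

        // With one-bit scalar quantization: one bit per dimension.
        int quantizedBytesPerVector = dimensions / 8;      // 16 bytes
        int quantizedTransferLimit = (int) Math.max(1, streamingLimitBytes / quantizedBytesPerVector);

        System.out.println("raw vectors per batch: " + rawTransferLimit);             // 20480
        System.out.println("quantized vectors per batch: " + quantizedTransferLimit); // 655360
    }
}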
@@ -0,0 +1,78 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.experimental.UtilityClass;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

@UtilityClass
class IndexBuildHelper {

/**
* Processes and transfers the vector, applying quantization when a quantization state and output are present.
*
* @param knnVectorValues the KNN vector values to be processed.
* @param indexBuildSetup the setup containing quantization state and output, along with other parameters.
* @param vectorTransfer the off-heap vector transfer utility.
* @param append flag indicating whether to append or overwrite the transfer buffer.
* @return true if the buffered vectors were flushed to off-heap memory during this call, false otherwise.
* @throws IOException if an I/O error occurs during vector transfer.
*/
static boolean processAndTransferVector(
KNNVectorValues<?> knnVectorValues,
IndexBuildSetup indexBuildSetup,
OffHeapVectorTransfer vectorTransfer,
boolean append
) throws IOException {
QuantizationService quantizationService = QuantizationService.getInstance();
if (indexBuildSetup.getQuantizationState() != null && indexBuildSetup.getQuantizationOutput() != null) {
quantizationService.quantize(
indexBuildSetup.getQuantizationState(),
knnVectorValues.getVector(),
indexBuildSetup.getQuantizationOutput()
);
return vectorTransfer.transfer(indexBuildSetup.getQuantizationOutput().getQuantizedVector(), append);
} else {
return vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), append);
}
}

/**
* Prepares the index build setup, including bytes per vector, dimensions, and any quantization state and output.
*
* @param knnVectorValues the KNN vector values.
* @param indexInfo the index build parameters.
* @return an {@link IndexBuildSetup} carrying the bytes per vector, the dimensions, and the quantization state and output (both null when quantization is not configured).
*/
static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, BuildIndexParams indexInfo) {
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
QuantizationService quantizationService = QuantizationService.getInstance();

int bytesPerVector;
int dimensions;

if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

return new IndexBuildSetup(bytesPerVector, dimensions, quantizationOutput, quantizationState);
}
}
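For context, the helper above is expected to be driven from the build strategies in this commit roughly as follows. This is a simplified sketch rather than the exact strategy code: knnVectorValues, indexInfo, and vectorTransfer are assumed to be set up as in the surrounding diffs, and batching and cleanup are omitted.

// Prepare the per-field setup once, then quantize (if configured) and transfer each vector.
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);
while (knnVectorValues.docId() != NO_MORE_DOCS) {
    // append=false reuses the off-heap buffer; a true return value signals that a batch was flushed.
    boolean transferred = IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, false);
    if (transferred) {
        // hand the flushed batch to the native index via JNIService, as in the strategies above
    }
    knnVectorValues.nextDoc();
}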
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
* IndexBuildSetup encapsulates the configuration and parameters required for building an index.
* This includes the size of each vector, the dimensions of the vectors, and any quantization-related
* settings such as the output and state of quantization.
*/
@Getter
@AllArgsConstructor
class IndexBuildSetup {
/**
* The number of bytes per vector.
*/
private final int bytesPerVector;

/**
* The number of dimensions in the vector.
*/
private final int dimensions;

/**
* The quantization output that will hold the quantized vector.
*/
private final QuantizationOutput quantizationOutput;

/**
* The state of quantization, which may include parameters and trained models.
*/
private final QuantizationState quantizationState;
}
@@ -11,11 +11,8 @@
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
@@ -60,47 +57,26 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
iterateVectorValuesOnce(knnVectorValues);
KNNEngine engine = indexInfo.getKnnEngine();
Map<String, Object> indexParameters = indexInfo.getParameters();
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

// Initialize the index
long indexMemoryAddress = AccessController.doPrivileged(
(PrivilegedAction<Long>) () -> JNIService.initIndex(
knnVectorValues.totalLiveDocs(),
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexParameters,
engine
)
);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {

final List<Integer> transferredDocIds = new ArrayList<>(transferLimit);

while (knnVectorValues.docId() != NO_MORE_DOCS) {
// append is false to be able to reuse the memory location
boolean transferred;
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false);
} else {
transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false);
}
boolean transferred = IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, false);
// append is false to be able to reuse the memory location
transferredDocIds.add(knnVectorValues.docId());
if (transferred) {
@@ -110,7 +86,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
@@ -130,7 +106,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
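The two strategies drive the helper differently. DefaultIndexBuildStrategy transfers with append set to true, so transferred vectors accumulate off heap until a single createIndex or createIndexFromTemplate call, while the memory-optimized strategy passes false so the same buffer can be reused and each flushed batch is pushed through insertToIndex. A condensed illustration using the identifiers from the diffs above:

// DefaultIndexBuildStrategy: append so earlier vectors in the off-heap buffer are preserved.
IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, true);

// Memory-optimized strategy: overwrite the buffer per batch and index whenever a flush occurs.
boolean flushed = IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, false);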
@@ -39,15 +39,13 @@ class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
@Override
public T getVectorByDocId(int docId) {
try {
int index = lastIndex;
while (index <= docId) {
while (lastIndex <= docId) {
knnVectorValues.nextDoc();
index++;
lastIndex++;
}
if (knnVectorValues.docId() == NO_MORE_DOCS) {
return null;
}
lastIndex = index;
// Return the vector for the requested document
return knnVectorValues.getVector();
} catch (Exception e) {
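The loop change above advances lastIndex while stepping the iterator, which keeps the method consistent with its forward-only contract: the underlying KNNVectorValues cursor cannot move backwards, so training must request docIds in increasing order. A hedged sketch of the expected call pattern (the trainingRequest variable and the sampled docIds are assumptions for illustration):

// trainingRequest is an assumed KNNVectorQuantizationTrainingRequest<float[]> instance.
int[] sampledDocIds = { 3, 17, 42 };  // must already be sorted in ascending order
for (int docId : sampledDocIds) {
    float[] vector = trainingRequest.getVectorByDocId(docId);
    if (vector == null) {
        break;  // the iterator reached NO_MORE_DOCS before this docId
    }
    // feed 'vector' into the quantizer training step
}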
@@ -9,8 +9,8 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
@@ -20,6 +20,8 @@
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

/**
* A singleton class responsible for handling the quantization process, including training a quantizer
* and applying quantization to vectors. This class is designed to be thread-safe.
@@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin
* Retrieves quantization parameters from the FieldInfo.
*/
public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
if (fieldInfo.getAttribute("QuantizationConfig") != null) {
return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
}
return null;
}
@@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
* @return The {@link VectorDataType} to be used during the vector transfer process
*/
public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
return VectorDataType.BINARY;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return VectorDataType.BINARY;
}
return null;
}

/**
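With the TODO placeholders removed, the service now derives both the quantization parameters and the transfer data type from the field's quantization config, and returns null when no config is present. A rough sketch of the intended caller flow (names are taken from this diff; fieldInfo is assumed to carry the qframework config attribute written at index time):

QuantizationService quantizationService = QuantizationService.getInstance();
QuantizationParams params = quantizationService.getQuantizationParams(fieldInfo);
if (params != null) {
    // A quantization config is present, so vectors are transferred in BINARY form.
    VectorDataType transferType = quantizationService.getVectorDataTypeForTransfer(fieldInfo);
    // ... train a quantization state from 'params' and stream quantized vectors ...
} else {
    // No quantization config: fall back to the field's own vector data type.
}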
@@ -47,8 +47,11 @@
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.util.Arrays;
@@ -113,7 +116,8 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc
float[] floatVectorForBinaryQuantization_2 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f };
FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT);
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}");
fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }");
QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig));
fieldTypeForBinaryQuantization.freeze();

addFieldToIndex(
@@ -187,7 +191,8 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce
float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f };
FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT);
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}");
fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }");
QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig));

addFieldToIndex(
new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization),
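The tests now serialize the quantization settings through the qframe config parser and store them under KNNConstants.QFRAMEWORK_CONFIG, which is the attribute that extractQuantizationConfig reads in QuantizationService. A minimal sketch of the round-trip this implies; fromCsv is assumed to be the parser's read-side counterpart of toCsv and does not appear in this diff:

// Serialize a one-bit scalar quantization config the way the updated tests do.
QuantizationConfig quantizationConfig = QuantizationConfig.builder()
    .quantizationType(ScalarQuantizationType.ONE_BIT)
    .build();
String csv = QuantizationConfigParser.toCsv(quantizationConfig);
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, csv);

// Assumed read-side counterpart (not shown in this commit):
// QuantizationConfig parsed = QuantizationConfigParser.fromCsv(csv);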
