Skip to content

Commit

Permalink
Integration With Quantization Config
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
  • Loading branch information
Vikasht34 committed Aug 22, 2024
1 parent c310f72 commit baccf83
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -57,34 +54,14 @@ public static DefaultIndexBuildStrategy getInstance() {
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());

while (knnVectorValues.docId() != NO_MORE_DOCS) {
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true);
} else {
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
}
IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, true);
// append is true here so off heap memory buffer isn't overwritten
transferredDocIds.add(knnVectorValues.docId());
knnVectorValues.nextDoc();
Expand All @@ -100,7 +77,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndexFromTemplate(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
params,
Expand All @@ -113,7 +90,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
params,
indexInfo.getKnnEngine()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.experimental.UtilityClass;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

/**
 * Package-private helpers shared by the native index build strategies.
 * <p>
 * {@link #prepareIndexBuild} resolves the per-build sizing values and quantization objects once,
 * and {@link #processAndTransferVector} uses them to hand each vector to the off-heap transfer,
 * quantizing it first when a quantization state is present.
 */
@UtilityClass
class IndexBuildHelper {

    /**
     * Transfers the current vector of {@code knnVectorValues} through {@code vectorTransfer}.
     * <p>
     * When the setup carries both a quantization state and a quantization output, the current
     * vector is quantized into that output and the quantized vector is transferred. Otherwise the
     * result of {@code conditionalCloneVector()} is transferred as is.
     *
     * @param knnVectorValues the vector values positioned on the document to transfer.
     * @param indexBuildSetup the setup produced by {@link #prepareIndexBuild}.
     * @param vectorTransfer  the off-heap vector transfer that receives the vector.
     * @param append          forwarded unchanged to {@code vectorTransfer.transfer}.
     * @return the value returned by {@code vectorTransfer.transfer}.
     * @throws IOException if reading the vector or transferring it fails.
     */
    static boolean processAndTransferVector(
        KNNVectorValues<?> knnVectorValues,
        IndexBuildSetup indexBuildSetup,
        OffHeapVectorTransfer vectorTransfer,
        boolean append
    ) throws IOException {
        final QuantizationService quantizer = QuantizationService.getInstance();
        final QuantizationState state = indexBuildSetup.getQuantizationState();
        final QuantizationOutput output = indexBuildSetup.getQuantizationOutput();

        // No quantization configured for this build: pass the vector through untouched.
        if (state == null || output == null) {
            return vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), append);
        }

        // Quantize into the shared output object, then transfer what it now holds.
        quantizer.quantize(state, knnVectorValues.getVector(), output);
        return vectorTransfer.transfer(output.getQuantizedVector(), append);
    }

    /**
     * Builds the {@link IndexBuildSetup} for an index build.
     * <p>
     * If {@code indexInfo} has a quantization state, bytes-per-vector and dimensions are taken from
     * that state and a quantization output is created from its quantization params. Otherwise both
     * values are read from {@code knnVectorValues} and the setup carries no quantization output.
     *
     * @param knnVectorValues the KNN vector values being indexed.
     * @param indexInfo       the index build parameters.
     * @return an {@link IndexBuildSetup} holding bytes per vector, dimensions, the quantization
     *         output (or {@code null}) and the quantization state (or {@code null}).
     */
    static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, BuildIndexParams indexInfo) {
        final QuantizationState state = indexInfo.getQuantizationState();
        final QuantizationService quantizer = QuantizationService.getInstance();

        // Without a quantization state the raw vector values define the sizing.
        if (state == null) {
            return new IndexBuildSetup(knnVectorValues.bytesPerVector(), knnVectorValues.dimension(), null, state);
        }

        final int quantizedBytesPerVector = state.getBytesPerVector();
        final int quantizedDimensions = state.getDimensions();
        final QuantizationOutput output = quantizer.createQuantizationOutput(state.getQuantizationParams());
        return new IndexBuildSetup(quantizedBytesPerVector, quantizedDimensions, output, state);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
 * IndexBuildSetup encapsulates the configuration and parameters required for building an index.
 * This includes the size of each vector, the dimensions of the vectors, and any quantization-related
 * settings such as the output and state of quantization.
 * <p>
 * All fields are final. Lombok's {@code @Getter} generates one accessor per field and
 * {@code @AllArgsConstructor} generates a constructor taking the fields in declaration order:
 * {@code (bytesPerVector, dimensions, quantizationOutput, quantizationState)}. Reordering the
 * fields therefore changes the constructor signature.
 */
@Getter
@AllArgsConstructor
final class IndexBuildSetup {
/**
 * The number of bytes per vector. The build strategies divide the vector streaming memory
 * limit by this value to compute how many vectors to transfer per batch.
 */
private final int bytesPerVector;

/**
 * The number of dimensions in the vector, as passed to the native index creation calls.
 */
private final int dimensions;

/**
 * The quantization output that will hold the quantized vector.
 * {@code null} when the setup was built without a quantization state
 * (see {@code IndexBuildHelper.prepareIndexBuild}).
 */
private final QuantizationOutput quantizationOutput;

/**
 * The state of quantization, which may include parameters and trained models.
 * {@code null} when the build parameters carry no quantization state.
 */
private final QuantizationState quantizationState;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -60,47 +57,26 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
iterateVectorValuesOnce(knnVectorValues);
KNNEngine engine = indexInfo.getKnnEngine();
Map<String, Object> indexParameters = indexInfo.getParameters();
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

// Initialize the index
long indexMemoryAddress = AccessController.doPrivileged(
(PrivilegedAction<Long>) () -> JNIService.initIndex(
knnVectorValues.totalLiveDocs(),
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexParameters,
engine
)
);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {

final List<Integer> transferredDocIds = new ArrayList<>(transferLimit);

while (knnVectorValues.docId() != NO_MORE_DOCS) {
// append is false to be able to reuse the memory location
boolean transferred;
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false);
} else {
transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false);
}
boolean transferred = IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, false);
// append is false to be able to reuse the memory location
transferredDocIds.add(knnVectorValues.docId());
if (transferred) {
Expand All @@ -110,7 +86,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
Expand All @@ -130,7 +106,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ static KNNLibraryIndexingContext adjustPrefix(
// We need to update the prefix used to create the faiss index if we are using the quantization
// framework
if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) {
// TODO: Uncomment to use Quantization framework
// leaving commented now just so it wont fail creating faiss indices.
// prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}

if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class.
* It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID.
*/
class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {

private final KNNVectorValues<T> knnVectorValues;
private int lastIndex;
Expand All @@ -39,15 +39,13 @@ class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
@Override
public T getVectorByDocId(int docId) {
try {
int index = lastIndex;
while (index <= docId) {
while (lastIndex <= docId) {
knnVectorValues.nextDoc();
index++;
lastIndex++;
}
if (knnVectorValues.docId() == NO_MORE_DOCS) {
return null;
}
lastIndex = index;
// Return the vector and the updated index
return knnVectorValues.getVector();
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
Expand All @@ -20,6 +20,8 @@
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

/**
* A singleton class responsible for handling the quantization process, including training a quantizer
* and applying quantization to vectors. This class is designed to be thread-safe.
Expand All @@ -28,7 +30,7 @@
* @param <R> The type of the quantized output vectors.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class QuantizationService<T, R> {
public final class QuantizationService<T, R> {

/**
* The singleton instance of the {@link QuantizationService} class.
Expand Down Expand Up @@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin
* Retrieves quantization parameters from the FieldInfo.
*/
public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
if (fieldInfo.getAttribute("QuantizationConfig") != null) {
return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
}
return null;
}
Expand All @@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
* @return The {@link VectorDataType} to be used during the vector transfer process
*/
public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
return VectorDataType.BINARY;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return VectorDataType.BINARY;
}
return null;
}

/**
Expand Down
Loading

0 comments on commit baccf83

Please sign in to comment.