Integration With Quantization Config
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
Vikasht34 committed Aug 22, 2024
1 parent c310f72 commit 048e1ff
Showing 11 changed files with 167 additions and 101 deletions.
@@ -216,19 +216,17 @@ private <T, C> void trainAndIndex(
final C context
) throws IOException {
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
KNNVectorValues<T> knnVectorValuesForTraining = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);
KNNVectorValues<T> knnVectorValuesForIndexing = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);

KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;

if (quantizationParams != null) {
quantizationState = quantizationService.train(quantizationParams, knnVectorValuesForTraining);
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
}
NativeIndexWriter writer = (quantizationParams != null)
? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)
: NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);

indexOperation.buildAndWrite(writer, knnVectorValuesForIndexing);
knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);
indexOperation.buildAndWrite(writer, knnVectorValues);
}
}
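The point of this hunk is that training and indexing can no longer share one KNNVectorValues instance: the values behave as a forward-only iterator, and training walks them to NO_MORE_DOCS. A hedged sketch of the resulting pattern (variable names follow the surrounding method; not part of the commit):

// Illustrative sketch only.
KNNVectorValues<T> values = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);
quantizationState = quantizationService.train(quantizationParams, values); // consumes the iterator
// values.docId() is now NO_MORE_DOCS, so a fresh instance is fetched before building the index
values = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);
indexOperation.buildAndWrite(writer, values);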
@@ -11,11 +11,8 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
@@ -57,35 +54,16 @@ public static DefaultIndexBuildStrategy getInstance() {
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());

while (knnVectorValues.docId() != NO_MORE_DOCS) {
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true);
} else {
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
}
Object vector = IndexBuildHelper.processAndReturnVector(knnVectorValues, indexBuildSetup);
// append is true here so the off-heap memory buffer isn't overwritten
vectorTransfer.transfer(vector, true);
transferredDocIds.add(knnVectorValues.docId());
knnVectorValues.nextDoc();
}
@@ -100,7 +78,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndexFromTemplate(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
params,
@@ -113,7 +91,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
params,
indexInfo.getKnnEngine()
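The transferLimit computed above is just the vector-streaming memory budget divided by the per-vector footprint reported by IndexBuildSetup, floored at one vector per batch. A quick arithmetic sketch with made-up numbers:

// Example values only; the real budget comes from KNNSettings.getVectorStreamingMemoryLimit().
long streamingLimitBytes = 1L << 30;                                          // assume a 1 GiB budget
int bytesPerVector = 512;                                                     // e.g. a 128-dimension float32 vector
int transferLimit = (int) Math.max(1, streamingLimitBytes / bytesPerVector);  // 2_097_152 vectors per transfer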
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.experimental.UtilityClass;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

@UtilityClass
class IndexBuildHelper {

/**
* Processes and returns the vector based on whether quantization is applied or not.
*
* @param knnVectorValues the KNN vector values to be processed.
* @param indexBuildSetup the setup containing quantization state and output, along with other parameters.
* @return the processed vector, either quantized or original.
* @throws IOException if an I/O error occurs during processing.
*/
static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException {
QuantizationService quantizationService = QuantizationService.getInstance();
if (indexBuildSetup.getQuantizationState() != null && indexBuildSetup.getQuantizationOutput() != null) {
quantizationService.quantize(
indexBuildSetup.getQuantizationState(),
knnVectorValues.getVector(),
indexBuildSetup.getQuantizationOutput()
);
return indexBuildSetup.getQuantizationOutput().getQuantizedVector();
} else {
return knnVectorValues.conditionalCloneVector();
}
}

/**
* Prepares the index build setup, including bytes per vector, dimensions, and any quantization state and output.
*
* @param knnVectorValues the KNN vector values.
* @param indexInfo the index build parameters.
* @return an instance of {@link IndexBuildSetup} containing the relevant build information.
*/
static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, BuildIndexParams indexInfo) {
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
QuantizationService quantizationService = QuantizationService.getInstance();

int bytesPerVector;
int dimensions;

if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

return new IndexBuildSetup(bytesPerVector, dimensions, quantizationOutput, quantizationState);
}
}
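Roughly how a build strategy consumes these helpers, sketched under the assumption that knnVectorValues, indexInfo, and an OffHeapVectorTransfer named vectorTransfer are already in scope:

IndexBuildSetup setup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);
while (knnVectorValues.docId() != NO_MORE_DOCS) {
    // Quantized bytes when a QuantizationState is present, otherwise the (conditionally cloned) raw vector.
    Object vector = IndexBuildHelper.processAndReturnVector(knnVectorValues, setup);
    vectorTransfer.transfer(vector, true);
    knnVectorValues.nextDoc();
}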
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
* IndexBuildSetup encapsulates the configuration and parameters required for building an index.
* This includes the size of each vector, the dimensions of the vectors, and any quantization-related
* settings such as the output and state of quantization.
*/
@Getter
@AllArgsConstructor
final class IndexBuildSetup {
/**
* The number of bytes per vector.
*/
private final int bytesPerVector;

/**
* The number of dimensions in the vector.
*/
private final int dimensions;

/**
* The quantization output that will hold the quantized vector.
*/
private final QuantizationOutput quantizationOutput;

/**
* The state of quantization, which may include parameters and trained models.
*/
private final QuantizationState quantizationState;
}
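For a field without quantization the setup carries only the raw vector geometry; a minimal construction sketch, mirroring what prepareIndexBuild does in its else branch (illustrative only):

// No quantization: output and state are simply null.
IndexBuildSetup rawSetup = new IndexBuildSetup(knnVectorValues.bytesPerVector(), knnVectorValues.dimension(), null, null);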
@@ -11,11 +11,8 @@
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
@@ -60,48 +57,27 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
iterateVectorValuesOnce(knnVectorValues);
KNNEngine engine = indexInfo.getKnnEngine();
Map<String, Object> indexParameters = indexInfo.getParameters();
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

// Initialize the index
long indexMemoryAddress = AccessController.doPrivileged(
(PrivilegedAction<Long>) () -> JNIService.initIndex(
knnVectorValues.totalLiveDocs(),
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexParameters,
engine
)
);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {

final List<Integer> transferredDocIds = new ArrayList<>(transferLimit);

while (knnVectorValues.docId() != NO_MORE_DOCS) {
Object vector = IndexBuildHelper.processAndReturnVector(knnVectorValues, indexBuildSetup);
// append is false to be able to reuse the memory location
boolean transferred;
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false);
} else {
transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false);
}
// append is false to be able to reuse the memory location
boolean transferred = vectorTransfer.transfer(vector, false);
transferredDocIds.add(knnVectorValues.docId());
if (transferred) {
// Insert vectors
@@ -110,7 +86,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
@@ -130,7 +106,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
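The memory-optimized strategy differs from the default one mainly in the transfer mode: transfer(vector, false) reuses the off-heap buffer and returns true once it is full, at which point the accumulated batch is inserted into the already-initialized native index. A simplified, hedged outline of that loop (privileged-action wrapping omitted; the getVectorAddress() accessor is an assumption about OffHeapVectorTransfer):

while (knnVectorValues.docId() != NO_MORE_DOCS) {
    Object vector = IndexBuildHelper.processAndReturnVector(knnVectorValues, indexBuildSetup);
    boolean transferred = vectorTransfer.transfer(vector, false); // false: reuse the off-heap slot
    transferredDocIds.add(knnVectorValues.docId());
    if (transferred) {
        // Buffer is full: insert this batch into the native index, then start a new batch.
        JNIService.insertToIndex(intListToArray(transferredDocIds), vectorTransfer.getVectorAddress(),
            indexBuildSetup.getDimensions(), indexParameters, indexMemoryAddress, engine);
        transferredDocIds.clear();
    }
    knnVectorValues.nextDoc();
}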
@@ -97,9 +97,7 @@ static KNNLibraryIndexingContext adjustPrefix(
// We need to update the prefix used to create the faiss index if we are using the quantization
// framework
if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) {
// TODO: Uncomment to use Quantization framework
// leaving commented now just so it wont fail creating faiss indices.
// prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}

if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
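With the quantization framework wired in, a field encoded with QFrameBitEncoder now gets the binary Faiss description prefix rather than the float one. Assuming, purely for illustration, that the binary prefix constant resolves to "B":

// Illustrative only; the actual value comes from the FAISS_BINARY_INDEX_DESCRIPTION_PREFIX constant.
String description = "HNSW16,Flat";
String adjusted = "B" + description; // -> "BHNSW16,Flat", the binary form of the index description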
@@ -14,7 +14,7 @@
* KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class.
* It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID.
*/
class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {

private final KNNVectorValues<T> knnVectorValues;
private int lastIndex;
@@ -39,15 +39,13 @@ class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
@Override
public T getVectorByDocId(int docId) {
try {
int index = lastIndex;
while (index <= docId) {
while (lastIndex <= docId) {
knnVectorValues.nextDoc();
index++;
lastIndex++;
}
if (knnVectorValues.docId() == NO_MORE_DOCS) {
return null;
}
lastIndex = index;
// Return the vector; lastIndex has already been advanced to the requested docId
return knnVectorValues.getVector();
} catch (Exception e) {
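Because the shared cursor only moves forward, getVectorByDocId can only serve doc IDs in non-decreasing order; each call advances lastIndex up to the requested ID. A hedged usage sketch (the sampled doc IDs and the float[] cast are illustrative):

// Doc IDs must be requested in ascending order.
int[] sampledDocIds = {0, 7, 42};
for (int docId : sampledDocIds) {
    float[] vector = (float[]) trainingRequest.getVectorByDocId(docId);
    if (vector == null) {
        break; // the underlying KNNVectorValues hit NO_MORE_DOCS
    }
    // feed the vector into quantizer training ...
}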
@@ -9,8 +9,8 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
@@ -20,6 +20,8 @@
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

/**
* A singleton class responsible for handling the quantization process, including training a quantizer
* and applying quantization to vectors. This class is designed to be thread-safe.
@@ -28,7 +30,7 @@
* @param <R> The type of the quantized output vectors.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class QuantizationService<T, R> {
public final class QuantizationService<T, R> {

/**
* The singleton instance of the {@link QuantizationService} class.
@@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin
* Retrieves quantization parameters from the FieldInfo.
*/
public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
if (fieldInfo.getAttribute("QuantizationConfig") != null) {
return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
}
return null;
}
@@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
* @return The {@link VectorDataType} to be used during the vector transfer process
*/
public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
return VectorDataType.BINARY;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return VectorDataType.BINARY;
}
return null;
}

/**
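Taken together, these two accessors let the writer decide at flush time whether to train a quantizer and which data type to stream to the native engine. A hedged sketch of that decision, assuming fieldInfo and knnVectorValues are in scope and using raw generics for brevity:

QuantizationService quantizationService = QuantizationService.getInstance();
QuantizationParams params = quantizationService.getQuantizationParams(fieldInfo);
if (params != null) {
    // A quantization config (e.g. one-bit scalar quantization) is attached to the field:
    // train a state and transfer the vectors as BINARY.
    QuantizationState state = quantizationService.train(params, knnVectorValues);
    VectorDataType transferType = quantizationService.getVectorDataTypeForTransfer(fieldInfo);
} else {
    // No quantization config: vectors keep their original data type for transfer.
}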