Skip to content

Commit

Permalink
Integration With Qunatization Config
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
  • Loading branch information
Vikasht34 committed Aug 23, 2024
1 parent c310f72 commit 62d9448
Show file tree
Hide file tree
Showing 20 changed files with 326 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ private interface VectorValuesRetriever<DataType, FieldInfo, MergeState, Result>
* field information, and additional context (e.g., merge state or field writer).
* @param indexOperation A functional interface that performs the indexing operation using the retrieved
* {@link KNNVectorValues}.
* @param context The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
* @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
* From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information
* @param <T> The type of vectors being processed.
* @param <C> The type of the context needed for retrieving the vector values.
* @throws IOException If an I/O error occurs during the processing.
Expand All @@ -213,22 +214,20 @@ private <T, C> void trainAndIndex(
final FieldInfo fieldInfo,
final VectorValuesRetriever<VectorDataType, FieldInfo, C, KNNVectorValues<T>> vectorValuesRetriever,
final IndexOperation<T> indexOperation,
final C context
final C VectorProcessingContext
) throws IOException {
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
KNNVectorValues<T> knnVectorValuesForTraining = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);
KNNVectorValues<T> knnVectorValuesForIndexing = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context);

KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;

if (quantizationParams != null) {
quantizationState = quantizationService.train(quantizationParams, knnVectorValuesForTraining);
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
}
NativeIndexWriter writer = (quantizationParams != null)
? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)
: NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);

indexOperation.buildAndWrite(writer, knnVectorValuesForIndexing);
knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
indexOperation.buildAndWrite(writer, knnVectorValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -57,35 +54,16 @@ public static DefaultIndexBuildStrategy getInstance() {
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo);

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());

while (knnVectorValues.docId() != NO_MORE_DOCS) {
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true);
} else {
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
}
Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup);
// append is true here so off heap memory buffer isn't overwritten
vectorTransfer.transfer(vector, true);
transferredDocIds.add(knnVectorValues.docId());
knnVectorValues.nextDoc();
}
Expand All @@ -100,7 +78,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndexFromTemplate(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
params,
Expand All @@ -113,7 +91,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
params,
indexInfo.getKnnEngine()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
* IndexBuildSetup encapsulates the configuration and parameters required for building an index.
* This includes the size of each vector, the dimensions of the vectors, and any quantization-related
* settings such as the output and state of quantization.
*/
@Getter
@AllArgsConstructor
final class IndexBuildSetup {
/**
* The number of bytes per vector.
*/
private final int bytesPerVector;

/**
* Dimension of Vector for Indexing
*/
private final int dimensions;

/**
* The quantization output that will hold the quantized vector.
*/
private final QuantizationOutput quantizationOutput;

/**
* The state of quantization, which may include parameters and trained models.
*/
private final QuantizationState quantizationState;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -60,48 +57,27 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
iterateVectorValuesOnce(knnVectorValues);
KNNEngine engine = indexInfo.getKnnEngine();
Map<String, Object> indexParameters = indexInfo.getParameters();
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}
IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo);

// Initialize the index
long indexMemoryAddress = AccessController.doPrivileged(
(PrivilegedAction<Long>) () -> JNIService.initIndex(
knnVectorValues.totalLiveDocs(),
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexParameters,
engine
)
);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {

final List<Integer> transferredDocIds = new ArrayList<>(transferLimit);

while (knnVectorValues.docId() != NO_MORE_DOCS) {
Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup);
// append is false to be able to reuse the memory location
boolean transferred;
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false);
} else {
transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false);
}
// append is false to be able to reuse the memory location
boolean transferred = vectorTransfer.transfer(vector, false);
transferredDocIds.add(knnVectorValues.docId());
if (transferred) {
// Insert vectors
Expand All @@ -110,7 +86,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
Expand All @@ -130,7 +106,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.experimental.UtilityClass;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

@UtilityClass
class QuantizationIndexUtils {

/**
* Processes and returns the vector based on whether quantization is applied or not.
*
* @param knnVectorValues the KNN vector values to be processed.
* @param indexBuildSetup the setup containing quantization state and output, along with other parameters.
* @return the processed vector, either quantized or original.
* @throws IOException if an I/O error occurs during processing.
*/
static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException {
QuantizationService quantizationService = QuantizationService.getInstance();
if (indexBuildSetup.getQuantizationState() != null && indexBuildSetup.getQuantizationOutput() != null) {
quantizationService.quantize(
indexBuildSetup.getQuantizationState(),
knnVectorValues.getVector(),
indexBuildSetup.getQuantizationOutput()
);
return indexBuildSetup.getQuantizationOutput().getQuantizedVector();
} else {
return knnVectorValues.conditionalCloneVector();
}
}

/**
* Prepares the quantization setup including bytes per vector and dimensions.
*
* @param knnVectorValues the KNN vector values.
* @param indexInfo the index build parameters.
* @return an instance of QuantizationSetup containing relevant information.
*/
static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, BuildIndexParams indexInfo) {
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
QuantizationService quantizationService = QuantizationService.getInstance();

int bytesPerVector;
int dimensions;

if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

return new IndexBuildSetup(bytesPerVector, dimensions, quantizationOutput, quantizationState);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ static KNNLibraryIndexingContext adjustPrefix(
// We need to update the prefix used to create the faiss index if we are using the quantization
// framework
if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) {
// TODO: Uncomment to use Quantization framework
// leaving commented now just so it wont fail creating faiss indices.
// prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}

if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@

package org.opensearch.knn.index.quantizationService;

import lombok.extern.log4j.Log4j2;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;

import java.io.IOException;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class.
* It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID.
*/
class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
@Log4j2
final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {

private final KNNVectorValues<T> knnVectorValues;
private int lastIndex;
Expand All @@ -31,27 +35,21 @@ class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
}

/**
* Retrieves the float vector associated with the specified document ID.
* Retrieves the vector associated with the specified document ID.
*
* @param docId the document ID.
* @param position the document ID.
* @return the float vector corresponding to the specified document ID, or null if the docId is invalid.
*/
@Override
public T getVectorByDocId(int docId) {
try {
int index = lastIndex;
while (index <= docId) {
knnVectorValues.nextDoc();
index++;
}
public T getVectorAtThePosition(int position) throws IOException {
while (lastIndex <= position) {
lastIndex++;
if (knnVectorValues.docId() == NO_MORE_DOCS) {
return null;
}
lastIndex = index;
// Return the vector and the updated index
return knnVectorValues.getVector();
} catch (Exception e) {
throw new RuntimeException("Failed to retrieve vector for docId: " + docId, e);
knnVectorValues.nextDoc();
}
// Return the vector and the updated index
return knnVectorValues.getVector();
}
}
Loading

0 comments on commit 62d9448

Please sign in to comment.