Skip to content

Commit

Permalink
Add quantization state reader and writer (opensearch-project#1997)
Browse files Browse the repository at this point in the history
Add quantization state reader and writer

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
ryanbogan authored Sep 5, 2024
1 parent 435e417 commit a58a9dc
Show file tree
Hide file tree
Showing 18 changed files with 1,341 additions and 37 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
### Documentation
### Maintenance
### Refactoring
### Refactoring
3 changes: 2 additions & 1 deletion release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ Compatible with OpenSearch 2.17.0
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997)
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class KNNConstants {
public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature";

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
9 changes: 5 additions & 4 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.monitor.jvm.JvmInfo;
import org.opensearch.monitor.os.OsProbe;

Expand Down Expand Up @@ -60,6 +60,7 @@ public class KNNSettings {
private static final OsProbe osProbe = OsProbe.getInstance();

private static final int INDEX_THREAD_QTY_MAX = 32;
private static final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance();

/**
* Settings name
Expand Down Expand Up @@ -379,11 +380,11 @@ private void setSettingsUpdateConsumers() {
NativeMemoryCacheManager.getInstance().rebuildCache(builder.build());
}, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> {
QuantizationStateCache.getInstance().setMaxCacheSizeInKB(it.getKb());
QuantizationStateCache.getInstance().rebuildCache();
quantizationStateCacheManager.setMaxCacheSizeInKB(it.getKb());
quantizationStateCacheManager.rebuildCache();
});
clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> {
QuantizationStateCache.getInstance().rebuildCache();
quantizationStateCacheManager.rebuildCache();
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Reads quantization states
*/
@Log4j2
public final class KNN990QuantizationStateReader {

/**
* Read quantization states and return list of fieldNames and bytes
* File format:
* Header
* QS1 state bytes
* QS2 state bytes
* Number of quantization states
* QS1 field number
* QS1 state bytes length
* QS1 position of state bytes
* QS2 field number
* QS2 state bytes length
* QS2 position of state bytes
* Position of index section (where QS1 field name is located)
* -1 (marker)
* Footer
*
* @param state the read state to read from
*/
public static Map<String, byte[]> read(SegmentReadState state) throws IOException {
String quantizationStateFileName = getQuantizationStateFileName(state);
Map<String, byte[]> readQuantizationStateInfos = null;

try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) {
CodecUtil.retrieveChecksum(input);

int numFields = getNumFields(input);

readQuantizationStateInfos = new HashMap<>();

// Read each field's metadata from the index section and then read bytes
for (int i = 0; i < numFields; i++) {
int fieldNumber = input.readInt();
int length = input.readInt();
long position = input.readVLong();
byte[] stateBytes = readStateBytes(input, position, length);
String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName();
readQuantizationStateInfos.put(fieldName, stateBytes);
}
} catch (Exception e) {
log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e);
return Collections.emptyMap();
}
return readQuantizationStateInfos;
}

/**
* Reads an individual quantization state for a given field
* @param readConfig a config class that contains necessary information for reading the state
* @return quantization state
*/
public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException {
SegmentReadState segmentReadState = readConfig.getSegmentReadState();
String field = readConfig.getField();
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();

try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) {
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);

long position = -1;
int length = 0;

// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}

if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}

byte[] stateBytes = readStateBytes(input, position, length);

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
}
} catch (Exception e) {
log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e);
return null;
}
}

@VisibleForTesting
static int getNumFields(IndexInput input) throws IOException {
long footerStart = input.length() - CodecUtil.footerLength();
long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES;
input.seek(markerAndIndexPosition);
long indexStartPosition = input.readLong();
input.seek(indexStartPosition);
return input.readInt();
}

@VisibleForTesting
static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException {
input.seek(position);
byte[] stateBytes = new byte[length];
input.readBytes(stateBytes, 0, length);
return stateBytes;
}

@VisibleForTesting
static String getQuantizationStateFileName(SegmentReadState state) {
return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import lombok.Setter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
* Writes quantization states to off heap memory
*/
public final class KNN990QuantizationStateWriter {

private final IndexOutput output;
private List<FieldQuantizationState> fieldQuantizationStates = new ArrayList<>();
static final String NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA = "NativeEngines990KnnVectorsFormatQSData";

/**
* Constructor
* Overall file format for writer:
* Header
* QS1 state bytes
* QS2 state bytes
* Number of quantization states
* QS1 field number
* QS1 state bytes length
* QS1 position of state bytes
* QS2 field number
* QS2 state bytes length
* QS2 position of state bytes
* Position of index section (where QS1 field name is located)
* -1 (marker)
* Footer
* @param segmentWriteState segment write state containing segment information
* @throws IOException exception could be thrown while creating the output
*/
public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException {
String quantizationStateFileName = IndexFileNames.segmentFileName(
segmentWriteState.segmentInfo.name,
segmentWriteState.segmentSuffix,
KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX
);

output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context);
}

/**
* Writes an index header
* @param segmentWriteState state containing segment information
* @throws IOException exception could be thrown while writing header
*/
public void writeHeader(SegmentWriteState segmentWriteState) throws IOException {
CodecUtil.writeIndexHeader(
output,
NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA,
0,
segmentWriteState.segmentInfo.getId(),
segmentWriteState.segmentSuffix
);
}

/**
* Writes a quantization state as bytes
*
* @param fieldNumber field number
* @param quantizationState quantization state
* @throws IOException could be thrown while writing
*/
public void writeState(int fieldNumber, QuantizationState quantizationState) throws IOException {
byte[] stateBytes = quantizationState.toByteArray();
long position = output.getFilePointer();
output.writeBytes(stateBytes, stateBytes.length);
fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position));
}

/**
* Writes index footer and other index information for parsing later
* @throws IOException could be thrown while writing
*/
public void writeFooter() throws IOException {
long indexStartPosition = output.getFilePointer();
output.writeInt(fieldQuantizationStates.size());
for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) {
output.writeInt(fieldQuantizationState.fieldNumber);
output.writeInt(fieldQuantizationState.stateBytes.length);
output.writeVLong(fieldQuantizationState.position);
}
output.writeLong(indexStartPosition);
output.writeInt(-1);
CodecUtil.writeFooter(output);
}

@AllArgsConstructor
private static class FieldQuantizationState {
final int fieldNumber;
final byte[] stateBytes;
@Setter
Long position;
}

public void closeOutput() throws IOException {
output.close();
}
}
Loading

0 comments on commit a58a9dc

Please sign in to comment.