forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add quantization state reader and writer (opensearch-project#1997)
Add quantization state reader and writer Signed-off-by: Ryan Bogan <rbogan@amazon.com>
- Loading branch information
Showing
18 changed files
with
1,341 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
155 changes: 155 additions & 0 deletions
155
src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.KNN990Codec; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
import lombok.extern.log4j.Log4j2; | ||
import org.apache.lucene.codecs.CodecUtil; | ||
import org.apache.lucene.index.IndexFileNames; | ||
import org.apache.lucene.index.SegmentReadState; | ||
import org.apache.lucene.store.IOContext; | ||
import org.apache.lucene.store.IndexInput; | ||
import org.opensearch.knn.common.KNNConstants; | ||
import org.opensearch.knn.quantization.enums.ScalarQuantizationType; | ||
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; | ||
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; | ||
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
/** | ||
* Reads quantization states | ||
*/ | ||
@Log4j2 | ||
public final class KNN990QuantizationStateReader { | ||
|
||
/** | ||
* Read quantization states and return list of fieldNames and bytes | ||
* File format: | ||
* Header | ||
* QS1 state bytes | ||
* QS2 state bytes | ||
* Number of quantization states | ||
* QS1 field number | ||
* QS1 state bytes length | ||
* QS1 position of state bytes | ||
* QS2 field number | ||
* QS2 state bytes length | ||
* QS2 position of state bytes | ||
* Position of index section (where QS1 field name is located) | ||
* -1 (marker) | ||
* Footer | ||
* | ||
* @param state the read state to read from | ||
*/ | ||
public static Map<String, byte[]> read(SegmentReadState state) throws IOException { | ||
String quantizationStateFileName = getQuantizationStateFileName(state); | ||
Map<String, byte[]> readQuantizationStateInfos = null; | ||
|
||
try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { | ||
CodecUtil.retrieveChecksum(input); | ||
|
||
int numFields = getNumFields(input); | ||
|
||
readQuantizationStateInfos = new HashMap<>(); | ||
|
||
// Read each field's metadata from the index section and then read bytes | ||
for (int i = 0; i < numFields; i++) { | ||
int fieldNumber = input.readInt(); | ||
int length = input.readInt(); | ||
long position = input.readVLong(); | ||
byte[] stateBytes = readStateBytes(input, position, length); | ||
String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); | ||
readQuantizationStateInfos.put(fieldName, stateBytes); | ||
} | ||
} catch (Exception e) { | ||
log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e); | ||
return Collections.emptyMap(); | ||
} | ||
return readQuantizationStateInfos; | ||
} | ||
|
||
/** | ||
* Reads an individual quantization state for a given field | ||
* @param readConfig a config class that contains necessary information for reading the state | ||
* @return quantization state | ||
*/ | ||
public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { | ||
SegmentReadState segmentReadState = readConfig.getSegmentReadState(); | ||
String field = readConfig.getField(); | ||
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); | ||
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); | ||
|
||
try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { | ||
CodecUtil.retrieveChecksum(input); | ||
int numFields = getNumFields(input); | ||
|
||
long position = -1; | ||
int length = 0; | ||
|
||
// Read each field's metadata from the index section, break when correct field is found | ||
for (int i = 0; i < numFields; i++) { | ||
int tempFieldNumber = input.readInt(); | ||
int tempLength = input.readInt(); | ||
long tempPosition = input.readVLong(); | ||
if (tempFieldNumber == fieldNumber) { | ||
position = tempPosition; | ||
length = tempLength; | ||
break; | ||
} | ||
} | ||
|
||
if (position == -1 || length == 0) { | ||
throw new IllegalArgumentException(String.format("Field %s not found", field)); | ||
} | ||
|
||
byte[] stateBytes = readStateBytes(input, position, length); | ||
|
||
// Deserialize the byte array to a quantization state object | ||
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); | ||
switch (scalarQuantizationType) { | ||
case ONE_BIT: | ||
return OneBitScalarQuantizationState.fromByteArray(stateBytes); | ||
case TWO_BIT: | ||
case FOUR_BIT: | ||
return MultiBitScalarQuantizationState.fromByteArray(stateBytes); | ||
default: | ||
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); | ||
} | ||
} catch (Exception e) { | ||
log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e); | ||
return null; | ||
} | ||
} | ||
|
||
@VisibleForTesting | ||
static int getNumFields(IndexInput input) throws IOException { | ||
long footerStart = input.length() - CodecUtil.footerLength(); | ||
long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; | ||
input.seek(markerAndIndexPosition); | ||
long indexStartPosition = input.readLong(); | ||
input.seek(indexStartPosition); | ||
return input.readInt(); | ||
} | ||
|
||
@VisibleForTesting | ||
static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException { | ||
input.seek(position); | ||
byte[] stateBytes = new byte[length]; | ||
input.readBytes(stateBytes, 0, length); | ||
return stateBytes; | ||
} | ||
|
||
@VisibleForTesting | ||
static String getQuantizationStateFileName(SegmentReadState state) { | ||
return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); | ||
} | ||
} |
116 changes: 116 additions & 0 deletions
116
src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.KNN990Codec; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Setter; | ||
import org.apache.lucene.codecs.CodecUtil; | ||
import org.apache.lucene.index.IndexFileNames; | ||
import org.apache.lucene.index.SegmentWriteState; | ||
import org.apache.lucene.store.IndexOutput; | ||
import org.opensearch.knn.common.KNNConstants; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
/** | ||
* Writes quantization states to off heap memory | ||
*/ | ||
public final class KNN990QuantizationStateWriter { | ||
|
||
private final IndexOutput output; | ||
private List<FieldQuantizationState> fieldQuantizationStates = new ArrayList<>(); | ||
static final String NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA = "NativeEngines990KnnVectorsFormatQSData"; | ||
|
||
/** | ||
* Constructor | ||
* Overall file format for writer: | ||
* Header | ||
* QS1 state bytes | ||
* QS2 state bytes | ||
* Number of quantization states | ||
* QS1 field number | ||
* QS1 state bytes length | ||
* QS1 position of state bytes | ||
* QS2 field number | ||
* QS2 state bytes length | ||
* QS2 position of state bytes | ||
* Position of index section (where QS1 field name is located) | ||
* -1 (marker) | ||
* Footer | ||
* @param segmentWriteState segment write state containing segment information | ||
* @throws IOException exception could be thrown while creating the output | ||
*/ | ||
public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { | ||
String quantizationStateFileName = IndexFileNames.segmentFileName( | ||
segmentWriteState.segmentInfo.name, | ||
segmentWriteState.segmentSuffix, | ||
KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX | ||
); | ||
|
||
output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context); | ||
} | ||
|
||
/** | ||
* Writes an index header | ||
* @param segmentWriteState state containing segment information | ||
* @throws IOException exception could be thrown while writing header | ||
*/ | ||
public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { | ||
CodecUtil.writeIndexHeader( | ||
output, | ||
NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA, | ||
0, | ||
segmentWriteState.segmentInfo.getId(), | ||
segmentWriteState.segmentSuffix | ||
); | ||
} | ||
|
||
/** | ||
* Writes a quantization state as bytes | ||
* | ||
* @param fieldNumber field number | ||
* @param quantizationState quantization state | ||
* @throws IOException could be thrown while writing | ||
*/ | ||
public void writeState(int fieldNumber, QuantizationState quantizationState) throws IOException { | ||
byte[] stateBytes = quantizationState.toByteArray(); | ||
long position = output.getFilePointer(); | ||
output.writeBytes(stateBytes, stateBytes.length); | ||
fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); | ||
} | ||
|
||
/** | ||
* Writes index footer and other index information for parsing later | ||
* @throws IOException could be thrown while writing | ||
*/ | ||
public void writeFooter() throws IOException { | ||
long indexStartPosition = output.getFilePointer(); | ||
output.writeInt(fieldQuantizationStates.size()); | ||
for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { | ||
output.writeInt(fieldQuantizationState.fieldNumber); | ||
output.writeInt(fieldQuantizationState.stateBytes.length); | ||
output.writeVLong(fieldQuantizationState.position); | ||
} | ||
output.writeLong(indexStartPosition); | ||
output.writeInt(-1); | ||
CodecUtil.writeFooter(output); | ||
} | ||
|
||
@AllArgsConstructor | ||
private static class FieldQuantizationState { | ||
final int fieldNumber; | ||
final byte[] stateBytes; | ||
@Setter | ||
Long position; | ||
} | ||
|
||
public void closeOutput() throws IOException { | ||
output.close(); | ||
} | ||
} |
Oops, something went wrong.