Skip to content

Commit

Permalink
Fix a flaky unit test:testMultiFieldsKnnIndex, which was failing due …
Browse files Browse the repository at this point in the history
…to inconsistent merge behaviors (opensearch-project#1924)

Signed-off-by: Navneet Verma <navneev@amazon.com>

Quantization Framework Implementation with 1bit, 2bit and 4bit Binary Quantizer
  • Loading branch information
navneet1v authored and Vikasht34 committed Aug 2, 2024
1 parent 2e31bcd commit 65b14e9
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private static void ensureRegistered() {
* @param <Q> the type of the quantized output
* @return an instance of {@link Quantizer} corresponding to the provided parameters
*/
public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(P params) {
public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(final P params) {
if (params == null) {
throw new IllegalArgumentException("Quantization parameters must not be null.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ private QuantizerRegistry() {}
* @param quantizerSupplier a supplier that provides instances of the quantizer
* @param <P> the type of quantization parameters
*/
public static <P extends QuantizationParams> void register(Class<P> paramClass,
QuantizationType quantizationType,
SQTypes sqType,
Supplier<? extends Quantizer<?, ?>> quantizerSupplier) {
public static <P extends QuantizationParams> void register( final Class<P> paramClass,
final QuantizationType quantizationType,
final SQTypes sqType,
final Supplier<? extends Quantizer<?, ?>> quantizerSupplier) {
String identifier = quantizationType.name() + "_" + sqType.name();
// Ensure that the quantizer for this identifier is registered only once
registry.computeIfAbsent(identifier, key -> quantizerSupplier);
Expand All @@ -54,7 +54,7 @@ public static <P extends QuantizationParams> void register(Class<P> paramClass,
* @return an instance of {@link Quantizer} corresponding to the provided parameters
* @throws IllegalArgumentException if no quantizer is registered for the given parameters
*/
public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(P params) {
public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(final P params) {
String identifier = params.getTypeIdentifier();
Supplier<? extends Quantizer<?, ?>> supplier = registry.get(identifier);
if (supplier == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
*
* @param quantizedVector the quantized vector represented as a byte array.
*/
public BinaryQuantizationOutput(byte[] quantizedVector) {
public BinaryQuantizationOutput(final byte[] quantizedVector) {
if (quantizedVector == null) {
throw new IllegalArgumentException("Quantized vector cannot be null");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class SQParams implements QuantizationParams {
*
* @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT).
*/
public SQParams(SQTypes sqType) {
public SQParams(final SQTypes sqType) {
// Stored as given: no null check or other validation happens in this constructor.
this.sqType = sqType;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class DefaultQuantizationState implements QuantizationState {

private final QuantizationParams params;

public DefaultQuantizationState(QuantizationParams params) {
public DefaultQuantizationState(final QuantizationParams params) {
// Keeps the supplied params reference as-is; no null check is performed here.
this.params = params;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public final class MultiBitScalarQuantizationState implements QuantizationState
* @param thresholds the threshold values for multi-bit quantization, organized as a 2D array
* where each row corresponds to a different bit level.
*/
public MultiBitScalarQuantizationState(SQParams quantizationParams, float[][] thresholds) {
public MultiBitScalarQuantizationState(final SQParams quantizationParams, final float[][] thresholds) {
// Both arguments are stored by reference without validation.
this.quantizationParams = quantizationParams;
// The 2D array is not copied, so later changes to the caller's array are visible through this state.
this.thresholds = thresholds;
}
Expand All @@ -49,7 +49,7 @@ public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, thresholds);
}

public static MultiBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (MultiBitScalarQuantizationState)
QuantizationStateSerializer.deserialize(bytes, (parentParams, specificData) ->
new MultiBitScalarQuantizationState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public final class OneBitScalarQuantizationState implements QuantizationState {
* @param quantizationParams the scalar quantization parameters.
* @param mean the mean values for each dimension.
*/
public OneBitScalarQuantizationState(SQParams quantizationParams, float[] mean) {
public OneBitScalarQuantizationState(final SQParams quantizationParams, final float[] mean) {
// Both arguments are stored by reference without validation.
this.quantizationParams = quantizationParams;
// The array is not copied, so later changes to the caller's array are visible through this state.
this.mean = mean;
}
Expand All @@ -48,7 +48,7 @@ public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, mean);
}

public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
(parentParams, specificData) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class MultiBitScalarQuantizer implements Quantizer<float[], byte[]> {
*
* @param bitsPerCoordinate the number of bits used per coordinate for quantization.
*/
public MultiBitScalarQuantizer(int bitsPerCoordinate) {
public MultiBitScalarQuantizer(final int bitsPerCoordinate) {
if (bitsPerCoordinate < 2) {
throw new IllegalArgumentException("bitsPerCoordinate must be greater than 2 for multibit quantizer.");
}
Expand All @@ -51,7 +51,7 @@ public MultiBitScalarQuantizer(int bitsPerCoordinate) {
* @return a MultiBitScalarQuantizationState containing the computed thresholds.
*/
@Override
public QuantizationState train(TrainingRequest<float[]> trainingRequest) {
public QuantizationState train(final TrainingRequest<float[]> trainingRequest) {
if (!IS_TRAINING_REQUIRED) {
return new DefaultQuantizationState(trainingRequest.getParams());
}
Expand All @@ -77,7 +77,7 @@ public QuantizationState train(TrainingRequest<float[]> trainingRequest) {
* @return a BinaryQuantizationOutput containing the quantized data.
*/
@Override
public QuantizationOutput<byte[]> quantize(float[] vector, QuantizationState state) {
public QuantizationOutput<byte[]> quantize(final float[] vector, final QuantizationState state) {
if (state instanceof DefaultQuantizationState) {
return quantize(vector);
}
Expand Down Expand Up @@ -113,7 +113,11 @@ public QuantizationOutput<byte[]> quantize(float[] vector, QuantizationState sta
* @param dimension the number of dimensions in the vectors.
* @return the thresholds for quantization.
*/
private float[][] calculateThresholds(float[] mean, float[] stdDev, int dimension) {
private float[][] calculateThresholds(
final float[] mean,
final float[] stdDev,
final int dimension
) {
float[][] thresholds = new float[bitsPerCoordinate][dimension];
float coef = bitsPerCoordinate + 1;
for (int i = 0; i < bitsPerCoordinate; i++) {
Expand All @@ -135,7 +139,7 @@ private float[][] calculateThresholds(float[] mean, float[] stdDev, int dimensio
* @return a {@link QuantizationOutput} containing the byte array representation of the quantized vector
* @throws UnsupportedOperationException if the quantization state is not available
*/
private QuantizationOutput<byte[]> quantize(float[] vector) {
private QuantizationOutput<byte[]> quantize(final float[] vector) {
    // Quantizing without a trained state is not supported, so this overload always throws.
    // The message names this quantizer; it previously said "OneBitScalar", copied from the sibling class.
    throw new UnsupportedOperationException("Quantization state is required for MultiBitScalar Quantizer.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class OneBitScalarQuantizer implements Quantizer<float[], byte[]> {
* @return a OneBitScalarQuantizationState containing the calculated means.
*/
@Override
public QuantizationState train(TrainingRequest<float[]> trainingRequest) {
public QuantizationState train(final TrainingRequest<float[]> trainingRequest) {
if (!IS_TRAINING_REQUIRED) {
return new DefaultQuantizationState(trainingRequest.getParams());
}
Expand All @@ -56,7 +56,7 @@ public QuantizationState train(TrainingRequest<float[]> trainingRequest) {
* @return a BinaryQuantizationOutput containing the quantized data.
*/
@Override
public QuantizationOutput<byte[]> quantize(float[] vector, QuantizationState state) {
public QuantizationOutput<byte[]> quantize(final float[] vector, final QuantizationState state) {
if (state instanceof DefaultQuantizationState) {
return quantize(vector);
}
Expand Down Expand Up @@ -88,7 +88,7 @@ public QuantizationOutput<byte[]> quantize(float[] vector, QuantizationState sta
* @return a {@link QuantizationOutput} containing the byte array representation of the quantized vector
* @throws UnsupportedOperationException if the quantization state is not available
*/
private QuantizationOutput<byte[]> quantize(float[] vector) {
private QuantizationOutput<byte[]> quantize(final float[] vector) {
    // Quantizing without a quantization state is not supported; this overload always throws.
    final String reason = "Quantization state is required for OneBitScalar Quantizer.";
    throw new UnsupportedOperationException(reason);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ final class ReservoirSampler implements Sampler {
* @return an array of sampled indices.
*/
@Override
public int[] sample(int totalNumberOfVectors, int sampleSize) {
public int[] sample(final int totalNumberOfVectors, final int sampleSize) {
if (totalNumberOfVectors <= sampleSize) {
return IntStream.range(0, totalNumberOfVectors).toArray();
}
Expand All @@ -45,7 +45,7 @@ public int[] sample(int totalNumberOfVectors, int sampleSize) {
* @param sampleSize the number of indices to sample.
* @return an array of sampled indices.
*/
private int[] reservoirSampleIndices(int numVectors, int sampleSize) {
private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) {
int[] indices = IntStream.range(0, sampleSize).toArray();
for (int i = sampleSize; i < numVectors; i++) {
int j = random.nextInt(i + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public enum SamplerType {
* @return a Sampler instance.
* @throws IllegalArgumentException if the sampler type is not supported.
*/
public static Sampler getSampler(SamplerType samplerType) {
public static Sampler getSampler(final SamplerType samplerType) {
switch (samplerType) {
case RESERVOIR:
return new ReservoirSampler();
Expand Down

0 comments on commit 65b14e9

Please sign in to comment.