Skip to content

Commit

Permalink
Quantization Framework Implementation with 1bit and MultiBit Binary Q…
Browse files Browse the repository at this point in the history
…uantizer

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
  • Loading branch information
Vikasht34 committed Aug 8, 2024
1 parent 4fda2c5 commit 0ebbef0
Show file tree
Hide file tree
Showing 41 changed files with 967 additions and 894 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
* Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931)
* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
* Quantization Framework For Disk Optimized Vector Search and Implementation of Binary 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,42 @@

package org.opensearch.knn.quantization.enums;

import lombok.Getter;

/**
* The SQTypes enum defines the various scalar quantization types that can be used
* in the KNN for vector quantization.
* Each type corresponds to a different bit-width representation of the quantized values.
* The ScalarQuantizationType enum defines the various scalar quantization types that can be used
* for vector quantization. Each type corresponds to a different bit-width representation of the quantized values.
*
* <p>
* Future Developers: If you change the name of any enum constant, do not change its associated value.
* Serialization and deserialization depend on these values to maintain compatibility.
* </p>
*/
@Getter
public enum ScalarQuantizationType {
/**
* ONE_BIT quantization uses a single bit per coordinate.
*/
ONE_BIT,
ONE_BIT(1),

/**
* TWO_BIT quantization uses two bits per coordinate.
*/
TWO_BIT,
TWO_BIT(2),

/**
* FOUR_BIT quantization uses four bits per coordinate.
*/
FOUR_BIT,
FOUR_BIT(4);

private final int id;

/**
* UNSUPPORTED_TYPE is used to denote quantization types that are not supported.
* This can be used as a placeholder or default value.
* Constructs a ScalarQuantizationType with the specified ID.
*
* @param id the ID representing the quantization type.
*/
UNSUPPORTED_TYPE
ScalarQuantizationType(int id) {
this.id = id;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.quantization.factory;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

Expand All @@ -15,12 +17,10 @@
* based on the provided {@link QuantizationParams}. It uses a registry to look up the
* appropriate quantizer implementation for the given quantization parameters.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class QuantizerFactory {
private static final AtomicBoolean isRegistered = new AtomicBoolean(false);

// Private constructor to prevent instantiation
private QuantizerFactory() {}

/**
* Ensures that default quantizers are registered.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,42 @@

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.enums.QuantizationType;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;

/**
* The QuantizerRegistrar class is responsible for registering default quantizers.
* This class ensures that the registration happens only once in a thread-safe manner.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class QuantizerRegistrar {

// Private constructor to prevent instantiation
private QuantizerRegistrar() {}

/**
* Registers default quantizers if not already registered.
* <p>
* This method is synchronized to ensure that registration occurs only once,
* even in a multi-threaded environment.
* </p>
*/
public static synchronized void registerDefaultQuantizers() {
static synchronized void registerDefaultQuantizers() {
// Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT
QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new);
QuantizerRegistry.register(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT),
new OneBitScalarQuantizer()
);
// Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2
QuantizerRegistry.register(
SQParams.class,
QuantizationType.VALUE,
ScalarQuantizationType.TWO_BIT,
() -> new MultiBitScalarQuantizer(2)
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT),
new MultiBitScalarQuantizer(2)
);
// Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4
QuantizerRegistry.register(
SQParams.class,
QuantizationType.VALUE,
ScalarQuantizationType.FOUR_BIT,
() -> new MultiBitScalarQuantizer(4)
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT),
new MultiBitScalarQuantizer(4)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,33 @@

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.enums.QuantizationType;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

/**
* The QuantizerRegistry class is responsible for managing the registration and retrieval
* of quantizer instances. Quantizers are registered with specific quantization parameters
* and type identifiers, allowing for efficient lookup and instantiation.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class QuantizerRegistry {

// Private constructor to prevent instantiation
private QuantizerRegistry() {}

// ConcurrentHashMap for thread-safe access
private static final Map<String, Supplier<? extends Quantizer<?, ?>>> registry = new ConcurrentHashMap<>();
private static final Map<String, Quantizer<?, ?>> registry = new ConcurrentHashMap<>();

/**
* Registers a quantizer with the registry.
*
* @param paramClass the class of the quantization parameters
* @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION)
* @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT)
* @param quantizerSupplier a supplier that provides instances of the quantizer
* @param <P> the type of quantization parameters
* @param paramIdentifier the unique identifier for the quantization parameters
* @param quantizer an instance of the quantizer
*/
public static <P extends QuantizationParams> void register(
final Class<P> paramClass,
final QuantizationType quantizationType,
final ScalarQuantizationType sqType,
final Supplier<? extends Quantizer<?, ?>> quantizerSupplier
) {
String identifier = createIdentifier(quantizationType, sqType);
static void register(final String paramIdentifier, final Quantizer<?, ?> quantizer) {
// Ensure that the quantizer for this identifier is registered only once
registry.computeIfAbsent(identifier, key -> quantizerSupplier);
registry.putIfAbsent(paramIdentifier, quantizer);
}

/**
Expand All @@ -56,27 +43,14 @@ public static <P extends QuantizationParams> void register(
* @return an instance of {@link Quantizer} corresponding to the provided parameters
* @throws IllegalArgumentException if no quantizer is registered for the given parameters
*/
public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(final P params) {
static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(final P params) {
String identifier = params.getTypeIdentifier();
Supplier<? extends Quantizer<?, ?>> supplier = registry.get(identifier);
if (supplier == null) {
throw new IllegalArgumentException(
"No quantizer registered for type identifier: " + identifier + ". Available quantizers: " + registry.keySet()
);
Quantizer<?, ?> quantizer = registry.get(identifier);
if (quantizer == null) {
throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier);
}
@SuppressWarnings("unchecked")
Quantizer<P, Q> quantizer = (Quantizer<P, Q>) supplier.get();
return quantizer;
}

/**
* Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype.
*
* @param quantizationType the quantization type
* @param sqType the specific quantization subtype
* @return a string identifier
*/
private static String createIdentifier(final QuantizationType quantizationType, final ScalarQuantizationType sqType) {
return quantizationType.name() + "_" + sqType.name();
Quantizer<P, Q> typedQuantizer = (Quantizer<P, Q>) quantizer;
return typedQuantizer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,36 @@

package org.opensearch.knn.quantization.models.quantizationOutput;

import lombok.NoArgsConstructor;
import lombok.Getter;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

/**
* The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
* It implements the QuantizationOutput interface to handle byte arrays specifically.
*/
@NoArgsConstructor
public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
private final byte[] quantizedVector;
@Getter
private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();

/**
* Constructs a BinaryQuantizationOutput instance with the specified quantized vector.
* Updates the quantized vector with a new byte array.
*
* @param quantizedVector the quantized vector represented as a byte array.
* @param newQuantizedVector the new quantized vector represented as a byte array.
*/
public BinaryQuantizationOutput(final byte[] quantizedVector) {
if (quantizedVector == null) {
throw new IllegalArgumentException("Quantized vector cannot be null");
public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException {
if (newQuantizedVector == null || newQuantizedVector.length == 0) {
throw new IllegalArgumentException("Quantized vector cannot be null or empty");
}
this.quantizedVector = quantizedVector;
byteArrayOutputStream.reset();
byteArrayOutputStream.write(newQuantizedVector);
}

@Override
public byte[] getQuantizedVector() {
return quantizedVector;
return byteArrayOutputStream.toByteArray();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.quantization.models.quantizationOutput;

import java.io.IOException;

/**
* The QuantizationOutput interface defines the contract for quantization output data.
*
Expand All @@ -17,4 +19,12 @@ public interface QuantizationOutput<T> {
* @return the quantized data.
*/
T getQuantizedVector();

/**
* Updates the quantized vector with new data.
*
* @param newQuantizedVector the new quantized vector data.
* @throws IOException if an I/O error occurs during the update.
*/
void updateQuantizedVector(T newQuantizedVector) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

package org.opensearch.knn.quantization.models.quantizationParams;

import org.opensearch.knn.quantization.enums.QuantizationType;

import java.io.Serializable;
import java.io.Externalizable;

/**
* Interface for quantization parameters.
Expand All @@ -16,17 +14,7 @@
* Implementations of this interface are expected to provide specific configurations
* for various quantization strategies.
*/
public interface QuantizationParams extends Serializable {

/**
* Gets the quantization type associated with the parameters.
* The quantization type defines the overall strategy or method used
* for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION.
*
* @return the {@link QuantizationType} indicating the quantization method.
*/
QuantizationType getQuantizationType();

public interface QuantizationParams extends Externalizable {
/**
* Provides a unique identifier for the quantization parameters.
* This identifier is typically a combination of the quantization type
Expand Down
Loading

0 comments on commit 0ebbef0

Please sign in to comment.