diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java new file mode 100644 index 000000000..254fb1c42 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +/** + * The QuantizationType enum represents the different types of quantization + * that can be applied in the KNN. + * + * + */ +public enum QuantizationType { + /** + * Represents space quantization, typically involving dimensionality reduction + * or space partitioning techniques. + */ + SPACE_QUANTIZATION, + + /** + * Represents value quantization, typically involving the conversion of continuous + * values into discrete ones. + */ + VALUE_QUANTIZATION, +} diff --git a/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java new file mode 100644 index 000000000..f67a4e5c8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +/** + * The SQTypes enum defines the various scalar quantization types that can be used + * in the KNN for vector quantization. + * Each type corresponds to a different bit-width representation of the quantized values. + */ +public enum SQTypes { + /** + * FP16 quantization uses 16-bit floating-point representation. + * This type offers a good balance between range and precision. + */ + FP16, + + /** + * INT8 quantization uses 8-bit integer representation. + * It is commonly used for efficient storage and processing. + */ + INT8, + + /** + * INT6 quantization uses 6-bit integer representation. + * It provides a lower precision than INT8 but with less storage space. + */ + INT6, + + /** + * INT4 quantization uses 4-bit integer representation. + * This type is suitable for highly compressed storage with significant loss of precision. + */ + INT4, + + /** + * ONE_BIT quantization uses a single bit per coordinate. + * This type is the most compact, representing only two possible values per dimension. + */ + ONE_BIT, + + /** + * TWO_BIT quantization uses two bits per coordinate. + * This type represents four possible values per dimension, offering a balance between compression and accuracy. + */ + TWO_BIT, + + /** + * FOUR_BIT quantization uses four bits per coordinate. + * It allows for sixteen possible values per dimension, providing more detail than lower bit-widths. + */ + FOUR_BIT, + + /** + * UNSUPPORTED_TYPE is used to denote quantization types that are not supported. + * This can be used as a placeholder or default value. + */ + UNSUPPORTED_TYPE +} + diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java new file mode 100644 index 000000000..f4915c68a --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +/** + * The ValueQuantizationType enum defines the types of value quantization techniques + * that can be applied in the KNN. + * These types represent different methodologies for quantizing the values of vectors. + */ +public enum ValueQuantizationType { + /** + * SQ (Scalar Quantization) represents a method where each coordinate of the vector is quantized + * independently. This technique is widely used for its simplicity and efficiency. + */ + SQ +} + + diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java new file mode 100644 index 000000000..9b99f8c22 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +/** + * The QuantizerFactory class is responsible for creating instances of {@link Quantizer} + * based on the provided {@link QuantizationParams}. It uses a registry to look up the + * appropriate quantizer implementation for the given quantization parameters. + */ +public class QuantizerFactory { + private static volatile boolean isRegistered = false; + + /** + * Retrieves a quantizer instance based on the provided quantization parameters. + * + * @param params the quantization parameters used to determine the appropriate quantizer + * @param

the type of quantization parameters, extending {@link QuantizationParams} + * @param the type of the quantized output + * @return an instance of {@link Quantizer} corresponding to the provided parameters + */ + public static

Quantizer getQuantizer(P params) { + if (params == null) { + throw new IllegalArgumentException("Quantization parameters must not be null."); + } + // Lazy Registration instead of static block as class level; + if (!isRegistered) { + registerDefaultQuantizers(); + } + return QuantizerRegistry.getQuantizer(params); + } + + /** + * Registers default quantizers if not already registered. + */ + private static synchronized void registerDefaultQuantizers() { + if (!isRegistered) { + // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT + QuantizerRegistry.register( + SQParams.class, + QuantizationType.VALUE_QUANTIZATION, + SQTypes.ONE_BIT, + OneBitScalarQuantizer::new + ); + // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2 + QuantizerRegistry.register( + SQParams.class, + QuantizationType.VALUE_QUANTIZATION, + SQTypes.TWO_BIT, + () -> new MultiBitScalarQuantizer(2) + ); + // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4 + QuantizerRegistry.register( + SQParams.class, + QuantizationType.VALUE_QUANTIZATION, + SQTypes.FOUR_BIT, + () -> new MultiBitScalarQuantizer(4) + ); + isRegistered = true; + } + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java new file mode 100644 index 000000000..5262a96ce --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; + +/** + * The QuantizerRegistry class is responsible for managing the registration and retrieval + * of quantizer instances. Quantizers are registered with specific quantization parameters + * and type identifiers, allowing for efficient lookup and instantiation. + */ +class QuantizerRegistry { + + // Use ConcurrentHashMap for thread-safe access + private static final Map>> registry = new ConcurrentHashMap<>(); + + /** + * Registers a quantizer with the registry. + * + * @param paramClass the class of the quantization parameters + * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION) + * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT) + * @param quantizerSupplier a supplier that provides instances of the quantizer + * @param

the type of quantization parameters + */ + public static

void register(Class

paramClass, + QuantizationType quantizationType, + SQTypes sqType, + Supplier> quantizerSupplier) { + String identifier = quantizationType.name() + "_" + sqType.name(); + // Ensure that the quantizer for this identifier is registered only once + registry.computeIfAbsent(identifier, key -> { + return quantizerSupplier; + }); + } + + /** + * Retrieves a quantizer instance based on the provided quantization parameters. + * + * @param params the quantization parameters used to determine the appropriate quantizer + * @param

the type of quantization parameters + * @param the type of the quantized output + * @return an instance of {@link Quantizer} corresponding to the provided parameters + * @throws IllegalArgumentException if no quantizer is registered for the given parameters + */ + public static

Quantizer getQuantizer(P params) { + String identifier = params.getTypeIdentifier(); + Supplier> supplier = registry.get(identifier); + if (supplier == null) { + throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier + + ". Available quantizers: " + registry.keySet()); + } + @SuppressWarnings("unchecked") + Quantizer quantizer = (Quantizer) supplier.get(); + return quantizer; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java new file mode 100644 index 000000000..cf8e68150 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +/** + * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. + * It implements the QuantizationOutput interface to handle byte arrays specifically. + */ +public class BinaryQuantizationOutput implements QuantizationOutput { + private final byte[] quantizedVector; + + /** + * Constructs a BinaryQuantizationOutput instance with the specified quantized vector. + * + * @param quantizedVector the quantized vector represented as a byte array. + */ + public BinaryQuantizationOutput(byte[] quantizedVector) { + if (quantizedVector == null) { + throw new IllegalArgumentException("Quantized vector cannot be null"); + } + this.quantizedVector = quantizedVector; + } + + @Override + public byte[] getQuantizedVector() { + return quantizedVector; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java new file mode 100644 index 000000000..d2fc7dce0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +/** + * The QuantizationOutput interface defines the contract for quantization output data. + * + * @param The type of the quantized data. + */ +public interface QuantizationOutput { + /** + * Returns the quantized vector. + * + * @return the quantized data. + */ + T getQuantizedVector(); +} + + diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java new file mode 100644 index 000000000..db4677985 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import org.opensearch.knn.quantization.enums.QuantizationType; + +import java.io.Serializable; + +/** + * Interface for quantization parameters. + * This interface defines the basic contract for all quantization parameter types. + * It provides methods to retrieve the quantization type and a unique type identifier. + * Implementations of this interface are expected to provide specific configurations + * for various quantization strategies. + */ +public interface QuantizationParams extends Serializable{ + + /** + * Gets the quantization type associated with the parameters. + * The quantization type defines the overall strategy or method used + * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION. + * + * @return the {@link QuantizationType} indicating the quantization method. + */ + QuantizationType getQuantizationType(); + + /** + * Provides a unique identifier for the quantization parameters. + * This identifier is typically a combination of the quantization type + * and additional specifics, and it serves to distinguish between different + * configurations or modes of quantization. + * + * @return a string representing the unique type identifier. + */ + String getTypeIdentifier(); +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java new file mode 100644 index 000000000..3983331ca --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; + +import java.util.Objects; + +/** + * The SQParams class represents the parameters specific to scalar quantization (SQ). + * This class implements the QuantizationParams interface and includes the type of scalar quantization. + */ +public class SQParams implements QuantizationParams { + private final SQTypes sqType; + + /** + * Constructs an SQParams instance with the specified scalar quantization type. + * + * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT). + */ + public SQParams(SQTypes sqType) { + this.sqType = sqType; + } + + /** + * Returns the quantization type associated with these parameters. + * + * @return The quantization type, always VALUE_QUANTIZATION for SQParams. + */ + @Override + public QuantizationType getQuantizationType() { + return QuantizationType.VALUE_QUANTIZATION; + } + + /** + * Returns the scalar quantization type. + * + * @return The specific scalar quantization type. + */ + public SQTypes getSqType() { + return sqType; + } + + /** + * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type. + * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. + * + * @return A string representing the unique type identifier. + */ + @Override + public String getTypeIdentifier() { + return getQuantizationType().name() + "_" + sqType.name(); + } + + /** + * Compares this object to the specified object. The result is true if and only if the argument is not null and is + * an SQParams object that represents the same scalar quantization type. + * + * @param o The object to compare this SQParams against. + * @return true if the given object represents an SQParams equivalent to this instance, false otherwise. + */ + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SQParams sqParams = (SQParams) o; + return sqType == sqParams.sqType; + } + + /** + * Returns a hash code value for this SQParams instance. This method is supported for the benefit of hash tables such + * as those provided by HashMap. + * + * @return A hash code value for this SQParams instance. + */ + @Override + public int hashCode() { + return Objects.hash(sqType); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java new file mode 100644 index 000000000..98668c8f8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.util.QuantizationStateSerializer; + +import java.io.IOException; + +/** + * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. + * It can be utilized by any quantizer to represent a default state. + */ +public class DefaultQuantizationState implements QuantizationState { + + private final QuantizationParams params; + + public DefaultQuantizationState(QuantizationParams params) { + this.params = params; + } + + @Override + public QuantizationParams getQuantizationParams() { + return params; + } + + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this, null); + } + + public static DefaultQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { + return (DefaultQuantizationState) + QuantizationStateSerializer.deserialize(bytes, (parentParams, specificData) -> + new DefaultQuantizationState( + (SQParams) parentParams) + ); + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java new file mode 100644 index 000000000..7a0261382 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.util.QuantizationStateSerializer; + +import java.io.IOException; + +/** + * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, + * including the thresholds used for quantization. + */ +public final class MultiBitScalarQuantizationState implements QuantizationState { + private final SQParams quantizationParams; + private final float[][] thresholds; + + /** + * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds. + * + * @param quantizationParams the scalar quantization parameters. + * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array + * where each row corresponds to a different bit level. + */ + public MultiBitScalarQuantizationState(SQParams quantizationParams, float[][] thresholds) { + this.quantizationParams = quantizationParams; + this.thresholds = thresholds; + } + + @Override + public SQParams getQuantizationParams() { + return quantizationParams; + } + + /** + * Returns the thresholds used in the quantization process. + * + * @return a 2D array of threshold values. + */ + public float[][] getThresholds() { + return thresholds; + } + + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this, thresholds); + } + + public static MultiBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { + return (MultiBitScalarQuantizationState) + QuantizationStateSerializer.deserialize(bytes, (parentParams, specificData) -> + new MultiBitScalarQuantizationState( + (SQParams) parentParams, + (float[][]) specificData) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java new file mode 100644 index 000000000..447272215 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.util.QuantizationStateSerializer; + +import java.io.IOException; + +/** + * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, + * including the mean values used for quantization. + */ +public final class OneBitScalarQuantizationState implements QuantizationState { + private final SQParams quantizationParams; + private final float[] mean; + + /** + * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values. + * + * @param quantizationParams the scalar quantization parameters. + * @param mean the mean values for each dimension. + */ + public OneBitScalarQuantizationState(SQParams quantizationParams, float[] mean) { + this.quantizationParams = quantizationParams; + this.mean = mean; + } + + @Override + public SQParams getQuantizationParams() { + return quantizationParams; + } + + /** + * Returns the mean values used in the quantization process. + * + * @return an array of mean values. + */ + public float[] getMean() { + return mean; + } + + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this, mean); + } + + public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { + return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize( + bytes, + (parentParams, specificData) -> + new OneBitScalarQuantizationState( + (SQParams) parentParams, + (float[]) specificData) + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java new file mode 100644 index 000000000..6d19e385c --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +import java.io.IOException; +import java.io.Serializable; + +/** + * QuantizationState interface represents the state of a quantization process, including the parameters used. + * This interface provides methods for serializing and deserializing the state. + */ +public interface QuantizationState extends Serializable { + /** + * Returns the quantization parameters associated with this state. + * + * @return the quantization parameters. + */ + QuantizationParams getQuantizationParams(); + + /** + * Serializes the quantization state to a byte array. + * + * @return a byte array representing the serialized state. + * @throws IOException if an I/O error occurs during serialization. + */ + byte[] toByteArray() throws IOException; +} + diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java new file mode 100644 index 000000000..527878aac --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.requests; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +/** + * TrainingRequest represents a request for training a quantizer. + * + * @param the type of vectors to be trained. + */ +public abstract class TrainingRequest { + private final QuantizationParams params; + private final int totalNumberOfVectors; + private int[] sampledIndices; + + /** + * Constructs a TrainingRequest with the given parameters and total number of vectors. + * + * @param params the quantization parameters. + * @param totalNumberOfVectors the total number of vectors. + */ + protected TrainingRequest(QuantizationParams params, int totalNumberOfVectors) { + this.params = params; + this.totalNumberOfVectors = totalNumberOfVectors; + } + + /** + * Returns the quantization parameters. + * + * @return the quantization parameters. + */ + public QuantizationParams getParams() { + return params; + } + + /** + * Returns the total number of vectors. + * + * @return the total number of vectors. + */ + public int getTotalNumberOfVectors() { + return totalNumberOfVectors; + } + + /** + * Sets the sampled indices for this training request. + * + * @param sampledIndices the sampled indices. + */ + public void setSampledIndices(int[] sampledIndices) { + this.sampledIndices = sampledIndices; + } + + /** + * Returns the sampled indices for this training request. + * + * @return the sampled indices. + */ + public int[] getSampledIndices() { + return sampledIndices; + } + + /** + * Returns the vector corresponding to the specified document ID. + * + * @param docId the document ID. + * @return the vector corresponding to the specified document ID. + */ + public abstract T getVectorByDocId(int docId); + + /** + * Returns the next vector in the sequence. + * + * @return the next vector. + */ + public abstract T getVector(); +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java new file mode 100644 index 000000000..f03d5eb5b --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; +import org.opensearch.knn.quantization.util.BitPackingUtils; +import org.opensearch.knn.quantization.util.QuantizerHelper; + +import java.util.ArrayList; +import java.util.List; + +/** + * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. + * It supports multiple bits per coordinate, allowing for finer granularity in quantization. + */ +public class MultiBitScalarQuantizer implements Quantizer { + private final int bitsPerCoordinate; // Number of bits used to quantize each dimension + private final int samplingSize = 25000; // Default sampling size for training + private static final boolean IS_TRAINING_REQUIRED = true; + + /** + * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate. + * + * @param bitsPerCoordinate the number of bits used per coordinate for quantization. + */ + public MultiBitScalarQuantizer(int bitsPerCoordinate) { + if (bitsPerCoordinate < 2) { + throw new IllegalArgumentException("bitsPerCoordinate must be greater than 2 for multibit quantizer."); + } + this.bitsPerCoordinate = bitsPerCoordinate; + } + + /** + * Trains the quantizer based on the provided training request, which should be of type SamplingTrainingRequest. + * The training process calculates the mean and standard deviation for each dimension and then determines + * threshold values for quantization based on these statistics. + * + * @param trainingRequest the request containing the data and parameters for training. + * @return a MultiBitScalarQuantizationState containing the computed thresholds. + */ + @Override + public QuantizationState train(TrainingRequest trainingRequest) { + if (!IS_TRAINING_REQUIRED) { + return new DefaultQuantizationState(trainingRequest.getParams()); + } + SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + + int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; + float[] sum = new float[dimension]; + float[] sumSq = new float[dimension]; + calculateSumAndSumSq(trainingRequest, sampledIndices, sum, sumSq); + + float[] mean = calculateMean(sum, sampledIndices.length); + float[] stdDev = calculateStandardDeviation(sumSq, mean, sampledIndices.length); + + float[][] thresholds = calculateThresholds(mean, stdDev, dimension); + return new MultiBitScalarQuantizationState(params, thresholds); + } + + /** + * Quantizes the provided vector using the provided quantization state, producing a quantized output. + * The vector is quantized based on the thresholds in the quantization state. + * + * @param vector the vector to quantize. + * @param state the quantization state containing threshold information. + * @return a BinaryQuantizationOutput containing the quantized data. + */ + @Override + public QuantizationOutput quantize(float[] vector, QuantizationState state) { + if (state instanceof DefaultQuantizationState) { + return quantize(vector); + } + + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + if (!(state instanceof MultiBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState."); + } + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state; + float[][] thresholds = multiBitState.getThresholds(); + if (thresholds == null || thresholds[0].length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + + List bitArrays = new ArrayList<>(); + for (int i = 0; i < bitsPerCoordinate; i++) { + byte[] bitArray = new byte[vector.length]; + for (int j = 0; j < vector.length; j++) { + bitArray[j] = (byte) (vector[j] > thresholds[i][j] ? 1 : 0); + } + bitArrays.add(bitArray); + } + + return new BinaryQuantizationOutput(BitPackingUtils.packBits(bitArrays)); + } + + /** + * Calculates the sum and sum of squares for each dimension based on sampled vectors. + * + * @param samplingRequest the sampling training request containing the vectors. + * @param sampledIndices the indices of the sampled vectors. + * @param sum the array to store the sum of each dimension. + * @param sumSq the array to store the sum of squares of each dimension. + */ + private void calculateSumAndSumSq( + TrainingRequest samplingRequest, + int[] sampledIndices, + float[] sum, + float[] sumSq + ) { + for (int index : sampledIndices) { + float[] vector = samplingRequest.getVectorByDocId(index); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + } + for (int j = 0; j < vector.length; j++) { + sum[j] += vector[j]; + sumSq[j] += vector[j] * vector[j]; + } + } + } + + /** + * Calculates the mean for each dimension. + * + * @param sum the array containing the sum of each dimension. + * @param totalSamples the total number of samples. + * @return the mean for each dimension. + */ + private float[] calculateMean(float[] sum, int totalSamples) { + float[] mean = new float[sum.length]; + for (int j = 0; j < sum.length; j++) { + mean[j] = sum[j] / totalSamples; + } + return mean; + } + + /** + * Calculates the standard deviation for each dimension. + * + * @param sumSq the array containing the sum of squares of each dimension. + * @param mean the mean for each dimension. + * @param totalSamples the total number of samples. + * @return the standard deviation for each dimension. + */ + private float[] calculateStandardDeviation(float[] sumSq, float[] mean, int totalSamples) { + float[] stdDev = new float[mean.length]; + for (int j = 0; j < mean.length; j++) { + stdDev[j] = (float) Math.sqrt(sumSq[j] / totalSamples - mean[j] * mean[j]); + } + return stdDev; + } + + /** + * Calculates the thresholds for quantization based on mean and standard deviation. + * + * @param mean the mean for each dimension. + * @param stdDev the standard deviation for each dimension. + * @param dimension the number of dimensions in the vectors. + * @return the thresholds for quantization. + */ + private float[][] calculateThresholds(float[] mean, float[] stdDev, int dimension) { + float[][] thresholds = new float[bitsPerCoordinate][dimension]; + float coef = bitsPerCoordinate + 1; + for (int i = 0; i < bitsPerCoordinate; i++) { + float iCoef = -1 + 2 * (i + 1) / coef; + for (int j = 0; j < dimension; j++) { + thresholds[i][j] = mean[j] + iCoef * stdDev[j]; + } + } + return thresholds; + } + + private QuantizationOutput quantize(float[] vector) { + throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer."); + } + +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java new file mode 100644 index 000000000..54d1b9c45 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; +import org.opensearch.knn.quantization.util.BitPackingUtils; +import org.opensearch.knn.quantization.util.QuantizerHelper; + +import java.util.Collections; + +/** + * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. + * It computes the mean of each dimension during training and then uses these means as thresholds + * for quantizing the vectors. + */ +public class OneBitScalarQuantizer implements Quantizer { + private static final int SAMPLING_SIZE = 25000; + private static final boolean IS_TRAINING_REQUIRED = true; + + /** + * Trains the quantizer by calculating the mean of each dimension from the sampled vectors. + * These means are used as thresholds in the quantization process. + * + * @param trainingRequest the request containing the data and parameters for training. + * @return a OneBitScalarQuantizationState containing the calculated means. + */ + @Override + public QuantizationState train(TrainingRequest trainingRequest) { + if (!IS_TRAINING_REQUIRED) { + return new DefaultQuantizationState(trainingRequest.getParams()); + } + SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), SAMPLING_SIZE); + float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices); + return new OneBitScalarQuantizationState(params, mean); + } + + /** + * Quantizes the provided vector using the given quantization state. + * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. + * + * @param vector the vector to quantize. + * @param state the quantization state containing the means for each dimension. + * @return a BinaryQuantizationOutput containing the quantized data. + */ + @Override + public QuantizationOutput quantize(float[] vector, QuantizationState state) { + if (state instanceof DefaultQuantizationState) { + return quantize(vector); + } + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + if (!(state instanceof OneBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState."); + } + OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state; + float[] thresholds = binaryState.getMean(); + if (thresholds == null || thresholds.length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + byte[] quantizedVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0); + } + return new BinaryQuantizationOutput(BitPackingUtils.packBits(Collections.singletonList(quantizedVector))); + } + + private QuantizationOutput quantize(float[] vector) { + throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer."); + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java new file mode 100644 index 000000000..ef26612ed --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +/** + * The Quantizer interface defines the methods required for training and quantizing vectors + * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. + * It supports training to determine quantization parameters and quantizing data vectors + * based on these parameters. + * + * @param The type of the vector or data to be quantized. + * @param The type of the quantized output, typically a compressed or encoded representation. + */ +public interface Quantizer { + + /** + * Trains the quantizer based on the provided training request. The training process typically + * involves learning parameters that can be used to quantize vectors. + * + * @param trainingRequest the request containing data and parameters for training. + * @return a QuantizationState containing the learned parameters. + */ + QuantizationState train(TrainingRequest trainingRequest); + + /** + * Quantizes the provided vector using the specified quantization state. + * + * @param vector the vector to quantize. + * @param state the quantization state containing parameters for quantization. + * @return a QuantizationOutput containing the quantized representation of the vector. + */ + QuantizationOutput quantize(T vector, QuantizationState state); +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java new file mode 100644 index 000000000..f952d3dc7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import java.util.Arrays; +import java.util.Random; +import java.util.stream.IntStream; + +/** + * ReservoirSampler implements the Sampler interface and provides a method for sampling + * a specified number of indices from a total number of vectors using the reservoir sampling algorithm. + * This algorithm is particularly useful for randomly sampling a subset of data from a larger set + * when the total size of the dataset is unknown or very large. + */ +public class ReservoirSampler implements Sampler { + private final Random random = new Random(); + + /** + * Samples indices from the range [0, totalNumberOfVectors). + * If the total number of vectors is less than or equal to the sample size, it returns all indices. + * Otherwise, it uses the reservoir sampling algorithm to select a random subset. + * + * @param totalNumberOfVectors the total number of vectors to sample from. + * @param sampleSize the number of indices to sample. + * @return an array of sampled indices. + */ + @Override + public int[] sample(int totalNumberOfVectors, int sampleSize) { + if (totalNumberOfVectors <= sampleSize) { + return IntStream.range(0, totalNumberOfVectors).toArray(); + } + return reservoirSampleIndices(totalNumberOfVectors, sampleSize); + } + + /** + * Applies the reservoir sampling algorithm to select a random sample of indices. + * This method ensures that each index in the range [0, numVectors) has an equal probability + * of being included in the sample. + * + * @param numVectors the total number of vectors. + * @param sampleSize the number of indices to sample. + * @return an array of sampled indices. + */ + private int[] reservoirSampleIndices(int numVectors, int sampleSize) { + int[] indices = IntStream.range(0, sampleSize).toArray(); + for (int i = sampleSize; i < numVectors; i++) { + int j = random.nextInt(i + 1); + if (j < sampleSize) { + indices[j] = i; + } + } + Arrays.sort(indices); + return indices; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java new file mode 100644 index 000000000..9021073b4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +public interface Sampler { + int[] sample(int totalNumberOfVectors, int sampleSize); +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java new file mode 100644 index 000000000..89470bb3f --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +/** + * SamplingFactory is a factory class for creating instances of Sampler. + * It uses the factory design pattern to encapsulate the creation logic for different types of samplers. + */ +public class SamplingFactory { + + /** + * SamplerType is an enumeration of the different types of samplers that can be created by the factory. + */ + public enum SamplerType { + RESERVOIR, // Represents a reservoir sampling strategy + // Add more enum values here for additional sampler types + } + + /** + * Creates and returns a Sampler instance based on the specified SamplerType. + * + * @param samplerType the type of sampler to create. + * @return a Sampler instance. + * @throws IllegalArgumentException if the sampler type is not supported. + */ + public static Sampler getSampler(SamplerType samplerType) { + switch (samplerType) { + case RESERVOIR: + return new ReservoirSampler(); + // Add more cases for different samplers here + default: + throw new IllegalArgumentException("Unsupported sampler type: " + samplerType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java b/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java new file mode 100644 index 000000000..4debf9335 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.util; + +import java.util.List; + +/** + * Utility class for bit packing operations. + * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission. + */ +public class BitPackingUtils { + + /** + * Packs the list of bit arrays into a single byte array. + * Each byte in the resulting array contains up to 8 bits from the bit arrays, packed from left to right. + * + * @param bitArrays the list of bit arrays to be packed. Each bit array should contain only 0s and 1s. + * @return a byte array containing the packed bits. + * @throws IllegalArgumentException if the bitArrays list is empty or if any bit array is null. + */ + public static byte[] packBits(List bitArrays) { + if (bitArrays.isEmpty()) { + throw new IllegalArgumentException("The list of bit arrays cannot be empty."); + } + + int bitLength = bitArrays.size() * bitArrays.get(0).length; + int byteLength = (bitLength + 7) / 8; + byte[] packedArray = new byte[byteLength]; + + for (int i = 0; i < bitArrays.size(); i++) { + byte[] bitArray = bitArrays.get(i); + if (bitArray == null) { + throw new IllegalArgumentException("Bit array cannot be null."); + } + for (int j = 0; j < bitArray.length; j++) { + int byteIndex = (i * bitArray.length + j) / 8; + int bitIndex = 7 - ((i * bitArray.length + j) % 8); + packedArray[byteIndex] |= (bitArray[j] << bitIndex); + } + } + + return packedArray; + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java new file mode 100644 index 000000000..366bf7559 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.util; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.*; + +/** + * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing + * QuantizationState objects along with their specific data. + */ +public class QuantizationStateSerializer { + + /** + * A functional interface for deserializing specific data associated with a QuantizationState. + */ + @FunctionalInterface + public interface SerializableDeserializer { + QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); + } + + /** + * Serializes the QuantizationState and specific data into a byte array. + * + * @param state The QuantizationState to serialize. + * @param specificData The specific data related to the state, to be serialized. + * @return A byte array representing the serialized state and specific data. + * @throws IOException If an I/O error occurs during serialization. + */ + public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { + byte[] parentBytes = serializeParentParams(state.getQuantizationParams()); + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeInt(parentBytes.length); // Write the length of the parent bytes + out.write(parentBytes); // Write the parent bytes + out.writeObject(specificData); // Write the specific data + out.flush(); + return bos.toByteArray(); + } + } + + /** + * Deserializes a QuantizationState and its specific data from a byte array. + * + * @param bytes The byte array containing the serialized data. + * @param specificDataDeserializer The deserializer for the specific data associated with the state. + * @return The deserialized QuantizationState including its specific data. + * @throws IOException If an I/O error occurs during deserialization. + * @throws ClassNotFoundException If the class of the serialized object cannot be found. + */ + public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer) + throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis)) { + int parentLength = in.readInt(); // Read the length of the + // Read the length of the parent bytes + byte[] parentBytes = new byte[parentLength]; + in.readFully(parentBytes); // Read the parent bytes + QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params + Serializable specificData = (Serializable) in.readObject(); // Read the specific data + return specificDataDeserializer.deserialize(parentParams, specificData); + } + } + + /** + * Serializes the parent parameters of the QuantizationState into a byte array. + * + * @param params The QuantizationParams to serialize. + * @return A byte array representing the serialized parent parameters. + * @throws IOException If an I/O error occurs during serialization. + */ + private static byte[] serializeParentParams(QuantizationParams params) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(params); + out.flush(); + return bos.toByteArray(); + } + } + + /** + * Deserializes the parent parameters of the QuantizationState from a byte array. + * + * @param bytes The byte array containing the serialized parent parameters. + * @return The deserialized QuantizationParams. + * @throws IOException If an I/O error occurs during deserialization. + * @throws ClassNotFoundException If the class of the serialized object cannot be found. + */ + private static QuantizationParams deserializeParentParams(byte[] bytes) + throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis)) { + return (QuantizationParams) in.readObject(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java new file mode 100644 index 000000000..3ee61a2fc --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.util; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +/** + * Utility class providing common methods for quantizer operations, such as parameter validation and + * extraction. This class is designed to be used with various quantizer implementations that require + * consistent handling of training requests and sampled indices. + */ +public final class QuantizerHelper { + + /** + * Private constructor to prevent instantiation of this utility class. + * The class is not meant to be instantiated, as it provides static utility methods only. + */ + private QuantizerHelper() { + // Private constructor to prevent instantiation + } + + /** + * Validates the provided training request to ensure it contains non-null quantization parameters. + * Extracts and returns the SQParams from the training request. + * + * @param trainingRequest the training request to validate and extract parameters from. + * @return the extracted SQParams. + * @throws IllegalArgumentException if the SQParams are null. + */ + public static SQParams validateAndExtractParams(TrainingRequest trainingRequest) { + QuantizationParams params = trainingRequest.getParams(); + if (params == null || !(params instanceof SQParams)) { + throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams."); + } + return (SQParams) params; + } + + /** + * Calculates the mean vector from a set of sampled vectors. + * + *

This method takes a {@link TrainingRequest} object and an array of sampled indices, + * retrieves the vectors corresponding to these indices, and calculates the mean vector. + * Each element of the mean vector is computed as the average of the corresponding elements + * of the sampled vectors.

+ * + * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. + * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. + * @return A float array representing the mean vector of the sampled vectors. + * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. + * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. + */ + public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) { + int totalSamples = sampledIndices.length; + float[] mean = null; + for (int index : sampledIndices) { + float[] vector = samplingRequest.getVectorByDocId(index); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + } + if (mean == null) { + mean = new float[vector.length]; + } + for (int j = 0; j < vector.length; j++) { + mean[j] += vector[j]; + } + } + if (mean == null) { + throw new IllegalStateException("Mean array should not be null after processing vectors."); + } + for (int j = 0; j < mean.length; j++) { + mean[j] /= totalSamples; + } + return mean; + } +} + diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java new file mode 100644 index 000000000..f14babd37 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class QuantizationTypeTests extends KNNTestCase { + + public void testQuantizationTypeValues() { + QuantizationType[] expectedValues = { + QuantizationType.SPACE_QUANTIZATION, + QuantizationType.VALUE_QUANTIZATION + }; + assertArrayEquals(expectedValues, QuantizationType.values()); + } + + public void testQuantizationTypeValueOf() { + assertEquals(QuantizationType.SPACE_QUANTIZATION, QuantizationType.valueOf("SPACE_QUANTIZATION")); + assertEquals(QuantizationType.VALUE_QUANTIZATION, QuantizationType.valueOf("VALUE_QUANTIZATION")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java new file mode 100644 index 000000000..b22289b88 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class SQTypesTests extends KNNTestCase { + public void testSQTypesValues() { + SQTypes[] expectedValues = { + SQTypes.FP16, + SQTypes.INT8, + SQTypes.INT6, + SQTypes.INT4, + SQTypes.ONE_BIT, + SQTypes.TWO_BIT, + SQTypes.FOUR_BIT, + SQTypes.UNSUPPORTED_TYPE + }; + assertArrayEquals(expectedValues, SQTypes.values()); + } + + public void testSQTypesValueOf() { + assertEquals(SQTypes.FP16, SQTypes.valueOf("FP16")); + assertEquals(SQTypes.INT8, SQTypes.valueOf("INT8")); + assertEquals(SQTypes.INT6, SQTypes.valueOf("INT6")); + assertEquals(SQTypes.INT4, SQTypes.valueOf("INT4")); + assertEquals(SQTypes.ONE_BIT, SQTypes.valueOf("ONE_BIT")); + assertEquals(SQTypes.TWO_BIT, SQTypes.valueOf("TWO_BIT")); + assertEquals(SQTypes.FOUR_BIT, SQTypes.valueOf("FOUR_BIT")); + assertEquals(SQTypes.UNSUPPORTED_TYPE, SQTypes.valueOf("UNSUPPORTED_TYPE")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java new file mode 100644 index 000000000..47d7123f6 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class ValueQuantizationTypeTests extends KNNTestCase { + public void testValueQuantizationTypeValues() { + ValueQuantizationType[] expectedValues = { + ValueQuantizationType.SQ + }; + assertArrayEquals(expectedValues, ValueQuantizationType.values()); + } + + public void testValueQuantizationTypeValueOf() { + assertEquals(ValueQuantizationType.SQ, ValueQuantizationType.valueOf("SQ")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java new file mode 100644 index 000000000..2082688ab --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.junit.Before; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.lang.reflect.Field; + +public class QuantizerFactoryTests extends KNNTestCase { + + @Before + public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException { + Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); + isRegisteredField.setAccessible(true); + isRegisteredField.setBoolean(null, false); + } + + public void test_Lazy_Registration() { + SQParams params = new SQParams(SQTypes.ONE_BIT); + assertFalse(isRegisteredFieldAccessible()); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + assertTrue(isRegisteredFieldAccessible()); + } + + public void testGetQuantizer_withOneBitSQParams() { + SQParams params = new SQParams(SQTypes.ONE_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + } + + public void testGetQuantizer_withTwoBitSQParams() { + SQParams params = new SQParams(SQTypes.TWO_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof MultiBitScalarQuantizer); + } + + public void testGetQuantizer_withFourBitSQParams() { + SQParams params = new SQParams(SQTypes.FOUR_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof MultiBitScalarQuantizer); + } + + public void testGetQuantizer_withUnsupportedType() { + SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE); + try { + QuantizerFactory.getQuantizer(params); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("No quantizer registered for type identifier")); + } + } + + public void testGetQuantizer_withNullParams() { + try { + QuantizerFactory.getQuantizer(null); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + assertEquals("Quantization parameters must not be null.", e.getMessage()); + } + } + + + public void test_Concurrent_Registration() throws InterruptedException { + Runnable task = () -> { + SQParams params = new SQParams(SQTypes.ONE_BIT); + QuantizerFactory.getQuantizer(params); + }; + + Thread thread1 = new Thread(task); + Thread thread2 = new Thread(task); + thread1.start(); + thread2.start(); + thread1.join(); + thread2.join(); + + assertTrue(isRegisteredFieldAccessible()); + } + + private boolean isRegisteredFieldAccessible() { + try { + Field field = QuantizerFactory.class.getDeclaredField("isRegistered"); + field.setAccessible(true); + return field.getBoolean(null); + } catch (NoSuchFieldException | IllegalAccessException e) { + fail("Failed to access isRegistered field."); + return false; + } + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java new file mode 100644 index 000000000..a848bc474 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.junit.BeforeClass; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +public class QuantizerRegistryTests extends KNNTestCase { + + @BeforeClass + public static void setup() { + // Register the quantizers for testing with enums + QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.ONE_BIT, OneBitScalarQuantizer::new); + QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.TWO_BIT, () -> new MultiBitScalarQuantizer(2)); + QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.FOUR_BIT, () -> new MultiBitScalarQuantizer(4)); + } + + public void testRegisterAndGetQuantizer() { + // Test for OneBitScalarQuantizer + SQParams oneBitParams = new SQParams(SQTypes.ONE_BIT); + Quantizer oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer); + + // Test for MultiBitScalarQuantizer (2-bit) + SQParams twoBitParams = new SQParams(SQTypes.TWO_BIT); + Quantizer twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer); + + // Test for MultiBitScalarQuantizer (4-bit) + SQParams fourBitParams = new SQParams(SQTypes.FOUR_BIT); + Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer); + } + + public void testGetQuantizer_withUnsupportedTypeIdentifier() { + // Create SQParams with an unsupported type identifier + SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered + + // Expect IllegalArgumentException when requesting a quantizer with unsupported params + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + QuantizerRegistry.getQuantizer(params); + }); + + assertTrue(exception.getMessage().contains("No quantizer registered for type identifier")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java new file mode 100644 index 000000000..50a8eee60 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizationState; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; + +import java.io.IOException; + +public class QuantizationStateSerializerTests extends KNNTestCase { + + public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.ONE_BIT); + float[] mean = new float[]{0.1f, 0.2f, 0.3f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + byte[] serialized = state.toByteArray(); + OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized); + + assertArrayEquals(mean, deserialized.getMean(), 0.0f); + assertEquals(params, deserialized.getQuantizationParams()); + } + + public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.TWO_BIT); + float[][] thresholds = new float[][]{ + {0.1f, 0.2f, 0.3f}, + {0.4f, 0.5f, 0.6f} + }; + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + byte[] serialized = state.toByteArray(); + MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); + + assertArrayEquals(thresholds, deserialized.getThresholds()); + assertEquals(params, deserialized.getQuantizationParams()); + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java new file mode 100644 index 000000000..e1ab98b07 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.quantizationState; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; + +import java.io.IOException; + +public class QuantizationStateTests extends KNNTestCase { + + public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.ONE_BIT); + float[] mean = {1.0f, 2.0f, 3.0f}; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + byte[] serializedState = state.toByteArray(); + OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + float delta = 0.0001f; + + assertArrayEquals(mean, deserializedState.getMean(), delta); + assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + } + + + public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.TWO_BIT); + float[][] thresholds = { + {0.5f, 1.5f, 2.5f}, + {1.0f, 2.0f, 3.0f} + }; + + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + byte[] serializedState = state.toByteArray(); + MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState); + float delta = 0.0001f; + + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i],delta); + } + assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java new file mode 100644 index 000000000..458df08c5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +public class MultiBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain_twoBit() { + float[][] vectors = { + {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f} + }; + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + int[] sampledIndices = {0, 1, 2}; + SQParams params = new SQParams(SQTypes.TWO_BIT); + TrainingRequest request = new MockTrainingRequest(params, vectors); + request.setSampledIndices(sampledIndices); + QuantizationState state = twoBitQuantizer.train(request); + + assertTrue(state instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state; + assertNotNull(mbState.getThresholds()); + assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds + } + + public void testTrain_fourBit() { + MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); + float[][] vectors = { + {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f} + }; + int[] sampledIndices = {0, 1, 2}; + SQParams params = new SQParams(SQTypes.FOUR_BIT); + TrainingRequest request = new MockTrainingRequest(params, vectors); + request.setSampledIndices(sampledIndices); + QuantizationState state = fourBitQuantizer.train(request); + + assertTrue(state instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state; + assertNotNull(mbState.getThresholds()); + assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds + } + + public void testQuantize_twoBit() { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + float[][] thresholds = { + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f} + }; + SQParams params = new SQParams(SQTypes.TWO_BIT); + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + QuantizationOutput output = twoBitQuantizer.quantize(vector, state); + assertNotNull(output.getQuantizedVector()); + assertEquals(2, output.getQuantizedVector().length); + } + + public void testQuantize_fourBit() { + MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + float[][] thresholds = { + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f}, + {1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f}, + {1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f} + }; + SQParams params = new SQParams(SQTypes.FOUR_BIT); + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + QuantizationOutput output = fourBitQuantizer.quantize(vector, state); + assertEquals(4, output.getQuantizedVector().length); + assertNotNull(output.getQuantizedVector()); + } + + public void testQuantize_withNullVector() { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + expectThrows(IllegalArgumentException.class, + () -> twoBitQuantizer.quantize(null, new MultiBitScalarQuantizationState(new SQParams(SQTypes.TWO_BIT), + new float[2][8]))); + } + + public void testQuantize_withInvalidState() { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + QuantizationState invalidState = new MockInvalidQuantizationState(); + expectThrows(IllegalArgumentException.class, + () -> twoBitQuantizer.quantize(vector, invalidState)); + } + + public void testQuantize_withDefaultQuantizationState() { + MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(SQTypes.ONE_BIT)); + + expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state)); + } + + // Mock classes for testing + private static class MockTrainingRequest extends TrainingRequest { + private final float[][] vectors; + + public MockTrainingRequest(SQParams params, float[][] vectors) { + super(params, vectors.length); + this.vectors = vectors; + } + + @Override + public float[] getVector() { + return new float[0]; + } + + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + } + + private static class MockInvalidQuantizationState implements QuantizationState { + @Override + public SQParams getQuantizationParams() { + return new SQParams(SQTypes.UNSUPPORTED_TYPE); + } + + @Override + public byte[] toByteArray() { + return new byte[0]; + } + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java new file mode 100644 index 000000000..9b384cd1c --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; +import org.opensearch.knn.quantization.util.QuantizerHelper; + +public class OneBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain_withTrainingRequired() { + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVector() { + return null; // Not used in this test + } + + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationState state = quantizer.train(originalRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + float[] mean = ((OneBitScalarQuantizationState) state).getMean(); + assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); + } + + public void testQuantize_withState() { + float[] vector = {3.0f, 6.0f, 9.0f}; + float[] thresholds = {4.0f, 5.0f, 6.0f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationOutput output = quantizer.quantize(vector, state); + + assertNotNull(output); + byte[] expectedPackedBits = new byte[]{0b01100000}; // or 96 in decimal + assertArrayEquals(expectedPackedBits, output.getQuantizedVector()); + } + + public void testQuantize_withDefaultQuantizationState() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = {3.0f, 6.0f, 9.0f}; + DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(SQTypes.ONE_BIT)); + + expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state)); + } + + public void testQuantize_withNullVector() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), new float[]{0.0f}); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state)); + } + + public void testQuantize_withInvalidState() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = {1.0f, 2.0f, 3.0f}; + QuantizationState invalidState = new QuantizationState() { + @Override + public SQParams getQuantizationParams() { + return new SQParams(SQTypes.ONE_BIT); + } + + @Override + public byte[] toByteArray() { + return new byte[0]; + } + }; + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState)); + } + + public void testQuantize_withMismatchedDimensions() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = {1.0f, 2.0f, 3.0f}; + float[] thresholds = {4.0f, 5.0f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds); + + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state)); + } + + public void testCalculateMean() { + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVector() { + return null; // Not used in this test + } + + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(vectors.length, 3); + float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices); + assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); + } + + public void testCalculateMean_withNullVector() { + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + null, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVector() { + return null; // Not used in this test + } + + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(vectors.length, 3); + expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMean(samplingRequest, sampledIndices)); + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java new file mode 100644 index 000000000..3a42588a4 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import org.opensearch.knn.KNNTestCase; + + +public class ReservoirSamplerTests extends KNNTestCase { + + public void testSample() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 100; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(sampleSize, samples.length); + for (int index : samples) { + assertTrue(index >= 0 && index < totalNumberOfVectors); + } + } + + public void testSample_withFullSampling() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 10; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(sampleSize, samples.length); + for (int index : samples) { + assertTrue(index >= 0 && index < totalNumberOfVectors); + } + } + + public void testSample_withLessVectors() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 5; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(totalNumberOfVectors, samples.length); + for (int index : samples) { + assertTrue(index >= 0 && index < totalNumberOfVectors); + } + } + + public void testSample_withZeroVectors() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 0; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(0, samples.length); + } + + public void testSample_withOneVector() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 1; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(1, samples.length); + assertTrue(samples[0] == 0); + } +} + diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java new file mode 100644 index 000000000..56d496d2f --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import org.opensearch.knn.KNNTestCase; + +public class SamplingFactoryTests extends KNNTestCase { + public void testGetSampler_withReservoir() { + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + assertTrue(sampler instanceof ReservoirSampler); + } + + public void testGetSampler_withUnsupportedType() { + expectThrows( NullPointerException.class, ()-> SamplingFactory.getSampler(null)); // This should throw an exception + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java new file mode 100644 index 000000000..4d48b83ee --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.util; + +import org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.List; + +public class BitPackingUtilsTests extends KNNTestCase { + + public void testPackBits_BitArray() { + List bitArrays = Arrays.asList( + new byte[]{1, 0, 1, 0, 1, 0, 1, 0}, + new byte[]{1, 1, 0, 0, 1, 1, 0, 0} + ); + + byte[] packedBits = BitPackingUtils.packBits(bitArrays); + byte[] expected = new byte[]{(byte) 0b10101010, (byte) 0b11001100}; + + assertArrayEquals(expected, packedBits); + } + + public void testPackBits_multipleBitArrays() { + List bitArrays = Arrays.asList( + new byte[]{1, 0, 1}, + new byte[]{0, 1, 0}, + new byte[]{1, 1, 1} + ); + + byte[] packedBits = BitPackingUtils.packBits(bitArrays); + byte[] expected = new byte[]{(byte) 0b10101011, (byte) 0b10000000}; + + assertArrayEquals(expected, packedBits); + } + + public void testPackBits_emptyArray() { + List bitArrays = Arrays.asList(); + expectThrows(IllegalArgumentException.class, () -> { + BitPackingUtils.packBits(bitArrays); + });; + } +}