diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java
new file mode 100644
index 000000000..254fb1c42
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+/**
+ * The QuantizationType enum represents the different types of quantization
+ * that can be applied in the KNN.
+ *
+ *
+ * - SPACE_QUANTIZATION: This type of quantization focuses on the space
+ * or the representation of the data vectors. It is commonly used for techniques
+ * that reduce the dimensionality or discretize the data space.
+ * - VALUE_QUANTIZATION: This type of quantization focuses on the values
+ * within the data vectors. It involves mapping continuous values into discrete
+ * values, which can be useful for compressing data or reducing the precision
+ * of the representation.
+ *
+ */
+public enum QuantizationType {
+ /**
+ * Represents space quantization, typically involving dimensionality reduction
+ * or space partitioning techniques.
+ */
+ SPACE_QUANTIZATION,
+
+ /**
+ * Represents value quantization, typically involving the conversion of continuous
+ * values into discrete ones.
+ */
+ VALUE_QUANTIZATION,
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java
new file mode 100644
index 000000000..f67a4e5c8
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+/**
+ * The SQTypes enum defines the various scalar quantization types that can be used
+ * in the KNN for vector quantization.
+ * Each type corresponds to a different bit-width representation of the quantized values.
+ */
+public enum SQTypes {
+ /**
+ * FP16 quantization uses 16-bit floating-point representation.
+ * This type offers a good balance between range and precision.
+ */
+ FP16,
+
+ /**
+ * INT8 quantization uses 8-bit integer representation.
+ * It is commonly used for efficient storage and processing.
+ */
+ INT8,
+
+ /**
+ * INT6 quantization uses 6-bit integer representation.
+ * It provides a lower precision than INT8 but with less storage space.
+ */
+ INT6,
+
+ /**
+ * INT4 quantization uses 4-bit integer representation.
+ * This type is suitable for highly compressed storage with significant loss of precision.
+ */
+ INT4,
+
+ /**
+ * ONE_BIT quantization uses a single bit per coordinate.
+ * This type is the most compact, representing only two possible values per dimension.
+ */
+ ONE_BIT,
+
+ /**
+ * TWO_BIT quantization uses two bits per coordinate.
+ * This type represents four possible values per dimension, offering a balance between compression and accuracy.
+ */
+ TWO_BIT,
+
+ /**
+ * FOUR_BIT quantization uses four bits per coordinate.
+ * It allows for sixteen possible values per dimension, providing more detail than lower bit-widths.
+ */
+ FOUR_BIT,
+
+ /**
+ * UNSUPPORTED_TYPE is used to denote quantization types that are not supported.
+ * This can be used as a placeholder or default value.
+ */
+ UNSUPPORTED_TYPE
+}
+
diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java
new file mode 100644
index 000000000..f4915c68a
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+/**
+ * The ValueQuantizationType enum defines the types of value quantization techniques
+ * that can be applied in the KNN.
+ * These types represent different methodologies for quantizing the values of vectors.
+ */
+public enum ValueQuantizationType {
+ /**
+ * SQ (Scalar Quantization) represents a method where each coordinate of the vector is quantized
+ * independently. This technique is widely used for its simplicity and efficiency.
+ */
+ SQ
+}
+
+
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
new file mode 100644
index 000000000..9b99f8c22
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+/**
+ * The QuantizerFactory class is responsible for creating instances of {@link Quantizer}
+ * based on the provided {@link QuantizationParams}. It uses a registry to look up the
+ * appropriate quantizer implementation for the given quantization parameters.
+ */
+public class QuantizerFactory {
+ private static volatile boolean isRegistered = false;
+
+ /**
+ * Retrieves a quantizer instance based on the provided quantization parameters.
+ *
+ * @param params the quantization parameters used to determine the appropriate quantizer
+ * @param the type of quantization parameters, extending {@link QuantizationParams}
+ * @param the type of the quantized output
+ * @return an instance of {@link Quantizer} corresponding to the provided parameters
+ */
+ public static Quantizer
getQuantizer(P params) {
+ if (params == null) {
+ throw new IllegalArgumentException("Quantization parameters must not be null.");
+ }
+ // Lazy Registration instead of static block as class level;
+ if (!isRegistered) {
+ registerDefaultQuantizers();
+ }
+ return QuantizerRegistry.getQuantizer(params);
+ }
+
+ /**
+ * Registers default quantizers if not already registered.
+ */
+ private static synchronized void registerDefaultQuantizers() {
+ if (!isRegistered) {
+ // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE_QUANTIZATION,
+ SQTypes.ONE_BIT,
+ OneBitScalarQuantizer::new
+ );
+ // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE_QUANTIZATION,
+ SQTypes.TWO_BIT,
+ () -> new MultiBitScalarQuantizer(2)
+ );
+ // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE_QUANTIZATION,
+ SQTypes.FOUR_BIT,
+ () -> new MultiBitScalarQuantizer(4)
+ );
+ isRegistered = true;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
new file mode 100644
index 000000000..5262a96ce
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Supplier;
+
+/**
+ * The QuantizerRegistry class is responsible for managing the registration and retrieval
+ * of quantizer instances. Quantizers are registered with specific quantization parameters
+ * and type identifiers, allowing for efficient lookup and instantiation.
+ */
+class QuantizerRegistry {
+
+ // Use ConcurrentHashMap for thread-safe access
+ private static final Map>> registry = new ConcurrentHashMap<>();
+
+ /**
+ * Registers a quantizer with the registry.
+ *
+ * @param paramClass the class of the quantization parameters
+ * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION)
+ * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT)
+ * @param quantizerSupplier a supplier that provides instances of the quantizer
+ * @param the type of quantization parameters
+ */
+ public static
void register(Class
paramClass,
+ QuantizationType quantizationType,
+ SQTypes sqType,
+ Supplier extends Quantizer, ?>> quantizerSupplier) {
+ String identifier = quantizationType.name() + "_" + sqType.name();
+ // Ensure that the quantizer for this identifier is registered only once
+ registry.computeIfAbsent(identifier, key -> {
+ return quantizerSupplier;
+ });
+ }
+
+ /**
+ * Retrieves a quantizer instance based on the provided quantization parameters.
+ *
+ * @param params the quantization parameters used to determine the appropriate quantizer
+ * @param
the type of quantization parameters
+ * @param the type of the quantized output
+ * @return an instance of {@link Quantizer} corresponding to the provided parameters
+ * @throws IllegalArgumentException if no quantizer is registered for the given parameters
+ */
+ public static Quantizer
getQuantizer(P params) {
+ String identifier = params.getTypeIdentifier();
+ Supplier extends Quantizer, ?>> supplier = registry.get(identifier);
+ if (supplier == null) {
+ throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier +
+ ". Available quantizers: " + registry.keySet());
+ }
+ @SuppressWarnings("unchecked")
+ Quantizer
quantizer = (Quantizer
) supplier.get();
+ return quantizer;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
new file mode 100644
index 000000000..cf8e68150
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationOutput;
+
+/**
+ * The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
+ * It implements the QuantizationOutput interface to handle byte arrays specifically.
+ */
+public class BinaryQuantizationOutput implements QuantizationOutput {
+ private final byte[] quantizedVector;
+
+ /**
+ * Constructs a BinaryQuantizationOutput instance with the specified quantized vector.
+ *
+ * @param quantizedVector the quantized vector represented as a byte array.
+ */
+ public BinaryQuantizationOutput(byte[] quantizedVector) {
+ if (quantizedVector == null) {
+ throw new IllegalArgumentException("Quantized vector cannot be null");
+ }
+ this.quantizedVector = quantizedVector;
+ }
+
+ @Override
+ public byte[] getQuantizedVector() {
+ return quantizedVector;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
new file mode 100644
index 000000000..d2fc7dce0
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationOutput;
+
+/**
+ * The QuantizationOutput interface defines the contract for quantization output data.
+ *
+ * @param The type of the quantized data.
+ */
+public interface QuantizationOutput {
+ /**
+ * Returns the quantized vector.
+ *
+ * @return the quantized data.
+ */
+ T getQuantizedVector();
+}
+
+
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
new file mode 100644
index 000000000..db4677985
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+
+import java.io.Serializable;
+
+/**
+ * Interface for quantization parameters.
+ * This interface defines the basic contract for all quantization parameter types.
+ * It provides methods to retrieve the quantization type and a unique type identifier.
+ * Implementations of this interface are expected to provide specific configurations
+ * for various quantization strategies.
+ */
+public interface QuantizationParams extends Serializable{
+
+ /**
+ * Gets the quantization type associated with the parameters.
+ * The quantization type defines the overall strategy or method used
+ * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION.
+ *
+ * @return the {@link QuantizationType} indicating the quantization method.
+ */
+ QuantizationType getQuantizationType();
+
+ /**
+ * Provides a unique identifier for the quantization parameters.
+ * This identifier is typically a combination of the quantization type
+ * and additional specifics, and it serves to distinguish between different
+ * configurations or modes of quantization.
+ *
+ * @return a string representing the unique type identifier.
+ */
+ String getTypeIdentifier();
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
new file mode 100644
index 000000000..3983331ca
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.SQTypes;
+
+import java.util.Objects;
+
+/**
+ * The SQParams class represents the parameters specific to scalar quantization (SQ).
+ * This class implements the QuantizationParams interface and includes the type of scalar quantization.
+ */
+public class SQParams implements QuantizationParams {
+ private final SQTypes sqType;
+
+ /**
+ * Constructs an SQParams instance with the specified scalar quantization type.
+ *
+ * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT).
+ */
+ public SQParams(SQTypes sqType) {
+ this.sqType = sqType;
+ }
+
+ /**
+ * Returns the quantization type associated with these parameters.
+ *
+ * @return The quantization type, always VALUE_QUANTIZATION for SQParams.
+ */
+ @Override
+ public QuantizationType getQuantizationType() {
+ return QuantizationType.VALUE_QUANTIZATION;
+ }
+
+ /**
+ * Returns the scalar quantization type.
+ *
+ * @return The specific scalar quantization type.
+ */
+ public SQTypes getSqType() {
+ return sqType;
+ }
+
+ /**
+ * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type.
+ * This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
+ *
+ * @return A string representing the unique type identifier.
+ */
+ @Override
+ public String getTypeIdentifier() {
+ return getQuantizationType().name() + "_" + sqType.name();
+ }
+
+ /**
+ * Compares this object to the specified object. The result is true if and only if the argument is not null and is
+ * an SQParams object that represents the same scalar quantization type.
+ *
+ * @param o The object to compare this SQParams against.
+ * @return true if the given object represents an SQParams equivalent to this instance, false otherwise.
+ */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ SQParams sqParams = (SQParams) o;
+ return sqType == sqParams.sqType;
+ }
+
+ /**
+ * Returns a hash code value for this SQParams instance. This method is supported for the benefit of hash tables such
+ * as those provided by HashMap.
+ *
+ * @return A hash code value for this SQParams instance.
+ */
+ @Override
+ public int hashCode() {
+ return Objects.hash(sqType);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
new file mode 100644
index 000000000..98668c8f8
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+
+import java.io.IOException;
+
+/**
+ * DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
+ * It can be utilized by any quantizer to represent a default state.
+ */
+public class DefaultQuantizationState implements QuantizationState {
+
+ private final QuantizationParams params;
+
+ public DefaultQuantizationState(QuantizationParams params) {
+ this.params = params;
+ }
+
+ @Override
+ public QuantizationParams getQuantizationParams() {
+ return params;
+ }
+
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this, null);
+ }
+
+ public static DefaultQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
+ return (DefaultQuantizationState)
+ QuantizationStateSerializer.deserialize(bytes, (parentParams, specificData) ->
+ new DefaultQuantizationState(
+ (SQParams) parentParams)
+ );
+ }
+}
+
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
new file mode 100644
index 000000000..7a0261382
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+
+import java.io.IOException;
+
+/**
+ * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization,
+ * including the thresholds used for quantization.
+ */
+public final class MultiBitScalarQuantizationState implements QuantizationState {
+ private final SQParams quantizationParams;
+ private final float[][] thresholds;
+
+ /**
+ * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds.
+ *
+ * @param quantizationParams the scalar quantization parameters.
+ * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array
+ * where each row corresponds to a different bit level.
+ */
+ public MultiBitScalarQuantizationState(SQParams quantizationParams, float[][] thresholds) {
+ this.quantizationParams = quantizationParams;
+ this.thresholds = thresholds;
+ }
+
+ @Override
+ public SQParams getQuantizationParams() {
+ return quantizationParams;
+ }
+
+ /**
+ * Returns the thresholds used in the quantization process.
+ *
+ * @return a 2D array of threshold values.
+ */
+ public float[][] getThresholds() {
+ return thresholds;
+ }
+
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this, thresholds);
+ }
+
+ public static MultiBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
+ return (MultiBitScalarQuantizationState)
+ QuantizationStateSerializer.deserialize(bytes, (parentParams, specificData) ->
+ new MultiBitScalarQuantizationState(
+ (SQParams) parentParams,
+ (float[][]) specificData)
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
new file mode 100644
index 000000000..447272215
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+
+import java.io.IOException;
+
+/**
+ * OneBitScalarQuantizationState represents the state of one-bit scalar quantization,
+ * including the mean values used for quantization.
+ */
+public final class OneBitScalarQuantizationState implements QuantizationState {
+ private final SQParams quantizationParams;
+ private final float[] mean;
+
+ /**
+ * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values.
+ *
+ * @param quantizationParams the scalar quantization parameters.
+ * @param mean the mean values for each dimension.
+ */
+ public OneBitScalarQuantizationState(SQParams quantizationParams, float[] mean) {
+ this.quantizationParams = quantizationParams;
+ this.mean = mean;
+ }
+
+ @Override
+ public SQParams getQuantizationParams() {
+ return quantizationParams;
+ }
+
+ /**
+ * Returns the mean values used in the quantization process.
+ *
+ * @return an array of mean values.
+ */
+ public float[] getMean() {
+ return mean;
+ }
+
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this, mean);
+ }
+
+ public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
+ return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
+ bytes,
+ (parentParams, specificData) ->
+ new OneBitScalarQuantizationState(
+ (SQParams) parentParams,
+ (float[]) specificData)
+ );
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
new file mode 100644
index 000000000..6d19e385c
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * QuantizationState interface represents the state of a quantization process, including the parameters used.
+ * This interface provides methods for serializing and deserializing the state.
+ */
+public interface QuantizationState extends Serializable {
+ /**
+ * Returns the quantization parameters associated with this state.
+ *
+ * @return the quantization parameters.
+ */
+ QuantizationParams getQuantizationParams();
+
+ /**
+ * Serializes the quantization state to a byte array.
+ *
+ * @return a byte array representing the serialized state.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ byte[] toByteArray() throws IOException;
+}
+
diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
new file mode 100644
index 000000000..527878aac
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.requests;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+
+/**
+ * TrainingRequest represents a request for training a quantizer.
+ *
+ * @param the type of vectors to be trained.
+ */
+public abstract class TrainingRequest {
+ private final QuantizationParams params;
+ private final int totalNumberOfVectors;
+ private int[] sampledIndices;
+
+ /**
+ * Constructs a TrainingRequest with the given parameters and total number of vectors.
+ *
+ * @param params the quantization parameters.
+ * @param totalNumberOfVectors the total number of vectors.
+ */
+ protected TrainingRequest(QuantizationParams params, int totalNumberOfVectors) {
+ this.params = params;
+ this.totalNumberOfVectors = totalNumberOfVectors;
+ }
+
+ /**
+ * Returns the quantization parameters.
+ *
+ * @return the quantization parameters.
+ */
+ public QuantizationParams getParams() {
+ return params;
+ }
+
+ /**
+ * Returns the total number of vectors.
+ *
+ * @return the total number of vectors.
+ */
+ public int getTotalNumberOfVectors() {
+ return totalNumberOfVectors;
+ }
+
+ /**
+ * Sets the sampled indices for this training request.
+ *
+ * @param sampledIndices the sampled indices.
+ */
+ public void setSampledIndices(int[] sampledIndices) {
+ this.sampledIndices = sampledIndices;
+ }
+
+ /**
+ * Returns the sampled indices for this training request.
+ *
+ * @return the sampled indices.
+ */
+ public int[] getSampledIndices() {
+ return sampledIndices;
+ }
+
+ /**
+ * Returns the vector corresponding to the specified document ID.
+ *
+ * @param docId the document ID.
+ * @return the vector corresponding to the specified document ID.
+ */
+ public abstract T getVectorByDocId(int docId);
+
+ /**
+ * Returns the next vector in the sequence.
+ *
+ * @return the next vector.
+ */
+ public abstract T getVector();
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
new file mode 100644
index 000000000..f03d5eb5b
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
@@ -0,0 +1,193 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+import org.opensearch.knn.quantization.util.BitPackingUtils;
+import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension.
+ * It supports multiple bits per coordinate, allowing for finer granularity in quantization.
+ */
+public class MultiBitScalarQuantizer implements Quantizer {
+ private final int bitsPerCoordinate; // Number of bits used to quantize each dimension
+ private final int samplingSize = 25000; // Default sampling size for training
+ private static final boolean IS_TRAINING_REQUIRED = true;
+
+ /**
+ * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate.
+ *
+ * @param bitsPerCoordinate the number of bits used per coordinate for quantization.
+ */
+ public MultiBitScalarQuantizer(int bitsPerCoordinate) {
+ if (bitsPerCoordinate < 2) {
+ throw new IllegalArgumentException("bitsPerCoordinate must be greater than 2 for multibit quantizer.");
+ }
+ this.bitsPerCoordinate = bitsPerCoordinate;
+ }
+
+ /**
+ * Trains the quantizer based on the provided training request, which should be of type SamplingTrainingRequest.
+ * The training process calculates the mean and standard deviation for each dimension and then determines
+ * threshold values for quantization based on these statistics.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @return a MultiBitScalarQuantizationState containing the computed thresholds.
+ */
+ @Override
+ public QuantizationState train(TrainingRequest trainingRequest) {
+ if (!IS_TRAINING_REQUIRED) {
+ return new DefaultQuantizationState(trainingRequest.getParams());
+ }
+ SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest);
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+
+ int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length;
+ float[] sum = new float[dimension];
+ float[] sumSq = new float[dimension];
+ calculateSumAndSumSq(trainingRequest, sampledIndices, sum, sumSq);
+
+ float[] mean = calculateMean(sum, sampledIndices.length);
+ float[] stdDev = calculateStandardDeviation(sumSq, mean, sampledIndices.length);
+
+ float[][] thresholds = calculateThresholds(mean, stdDev, dimension);
+ return new MultiBitScalarQuantizationState(params, thresholds);
+ }
+
+ /**
+ * Quantizes the provided vector using the provided quantization state, producing a quantized output.
+ * The vector is quantized based on the thresholds in the quantization state.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing threshold information.
+ * @return a BinaryQuantizationOutput containing the quantized data.
+ */
+ @Override
+ public QuantizationOutput quantize(float[] vector, QuantizationState state) {
+ if (state instanceof DefaultQuantizationState) {
+ return quantize(vector);
+ }
+
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector to quantize must not be null.");
+ }
+ if (!(state instanceof MultiBitScalarQuantizationState)) {
+ throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState.");
+ }
+ MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state;
+ float[][] thresholds = multiBitState.getThresholds();
+ if (thresholds == null || thresholds[0].length != vector.length) {
+ throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
+ }
+
+ List bitArrays = new ArrayList<>();
+ for (int i = 0; i < bitsPerCoordinate; i++) {
+ byte[] bitArray = new byte[vector.length];
+ for (int j = 0; j < vector.length; j++) {
+ bitArray[j] = (byte) (vector[j] > thresholds[i][j] ? 1 : 0);
+ }
+ bitArrays.add(bitArray);
+ }
+
+ return new BinaryQuantizationOutput(BitPackingUtils.packBits(bitArrays));
+ }
+
+ /**
+ * Calculates the sum and sum of squares for each dimension based on sampled vectors.
+ *
+ * @param samplingRequest the sampling training request containing the vectors.
+ * @param sampledIndices the indices of the sampled vectors.
+ * @param sum the array to store the sum of each dimension.
+ * @param sumSq the array to store the sum of squares of each dimension.
+ */
+ private void calculateSumAndSumSq(
+ TrainingRequest samplingRequest,
+ int[] sampledIndices,
+ float[] sum,
+ float[] sumSq
+ ) {
+ for (int index : sampledIndices) {
+ float[] vector = samplingRequest.getVectorByDocId(index);
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector at sampled index " + index + " is null.");
+ }
+ for (int j = 0; j < vector.length; j++) {
+ sum[j] += vector[j];
+ sumSq[j] += vector[j] * vector[j];
+ }
+ }
+ }
+
+ /**
+ * Calculates the mean for each dimension.
+ *
+ * @param sum the array containing the sum of each dimension.
+ * @param totalSamples the total number of samples.
+ * @return the mean for each dimension.
+ */
+ private float[] calculateMean(float[] sum, int totalSamples) {
+ float[] mean = new float[sum.length];
+ for (int j = 0; j < sum.length; j++) {
+ mean[j] = sum[j] / totalSamples;
+ }
+ return mean;
+ }
+
+ /**
+ * Calculates the standard deviation for each dimension.
+ *
+ * @param sumSq the array containing the sum of squares of each dimension.
+ * @param mean the mean for each dimension.
+ * @param totalSamples the total number of samples.
+ * @return the standard deviation for each dimension.
+ */
+ private float[] calculateStandardDeviation(float[] sumSq, float[] mean, int totalSamples) {
+ float[] stdDev = new float[mean.length];
+ for (int j = 0; j < mean.length; j++) {
+ stdDev[j] = (float) Math.sqrt(sumSq[j] / totalSamples - mean[j] * mean[j]);
+ }
+ return stdDev;
+ }
+
+ /**
+ * Calculates the thresholds for quantization based on mean and standard deviation.
+ *
+ * @param mean the mean for each dimension.
+ * @param stdDev the standard deviation for each dimension.
+ * @param dimension the number of dimensions in the vectors.
+ * @return the thresholds for quantization.
+ */
+ private float[][] calculateThresholds(float[] mean, float[] stdDev, int dimension) {
+ float[][] thresholds = new float[bitsPerCoordinate][dimension];
+ float coef = bitsPerCoordinate + 1;
+ for (int i = 0; i < bitsPerCoordinate; i++) {
+ float iCoef = -1 + 2 * (i + 1) / coef;
+ for (int j = 0; j < dimension; j++) {
+ thresholds[i][j] = mean[j] + iCoef * stdDev[j];
+ }
+ }
+ return thresholds;
+ }
+
+ private QuantizationOutput quantize(float[] vector) {
+ throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer.");
+ }
+
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
new file mode 100644
index 000000000..54d1b9c45
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+import org.opensearch.knn.quantization.util.BitPackingUtils;
+import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+import java.util.Collections;
+
+/**
+ * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension.
+ * It computes the mean of each dimension during training and then uses these means as thresholds
+ * for quantizing the vectors.
+ */
+public class OneBitScalarQuantizer implements Quantizer {
+ private static final int SAMPLING_SIZE = 25000;
+ private static final boolean IS_TRAINING_REQUIRED = true;
+
+ /**
+ * Trains the quantizer by calculating the mean of each dimension from the sampled vectors.
+ * These means are used as thresholds in the quantization process.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @return a OneBitScalarQuantizationState containing the calculated means.
+ */
+ @Override
+ public QuantizationState train(TrainingRequest trainingRequest) {
+ if (!IS_TRAINING_REQUIRED) {
+ return new DefaultQuantizationState(trainingRequest.getParams());
+ }
+ SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest);
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), SAMPLING_SIZE);
+ float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices);
+ return new OneBitScalarQuantizationState(params, mean);
+ }
+
+ /**
+ * Quantizes the provided vector using the given quantization state.
+ * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing the means for each dimension.
+ * @return a BinaryQuantizationOutput containing the quantized data.
+ */
+ @Override
+ public QuantizationOutput quantize(float[] vector, QuantizationState state) {
+ if (state instanceof DefaultQuantizationState) {
+ return quantize(vector);
+ }
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector to quantize must not be null.");
+ }
+ if (!(state instanceof OneBitScalarQuantizationState)) {
+ throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState.");
+ }
+ OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
+ float[] thresholds = binaryState.getMean();
+ if (thresholds == null || thresholds.length != vector.length) {
+ throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
+ }
+ byte[] quantizedVector = new byte[vector.length];
+ for (int i = 0; i < vector.length; i++) {
+ quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0);
+ }
+ return new BinaryQuantizationOutput(BitPackingUtils.packBits(Collections.singletonList(quantizedVector)));
+ }
+
+ private QuantizationOutput quantize(float[] vector) {
+ throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer.");
+ }
+}
+
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
new file mode 100644
index 000000000..ef26612ed
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+/**
+ * The Quantizer interface defines the methods required for training and quantizing vectors
+ * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks.
+ * It supports training to determine quantization parameters and quantizing data vectors
+ * based on these parameters.
+ *
+ * @param The type of the vector or data to be quantized.
+ * @param The type of the quantized output, typically a compressed or encoded representation.
+ */
+public interface Quantizer {
+
+ /**
+ * Trains the quantizer based on the provided training request. The training process typically
+ * involves learning parameters that can be used to quantize vectors.
+ *
+ * @param trainingRequest the request containing data and parameters for training.
+ * @return a QuantizationState containing the learned parameters.
+ */
+ QuantizationState train(TrainingRequest trainingRequest);
+
+ /**
+ * Quantizes the provided vector using the specified quantization state.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing parameters for quantization.
+ * @return a QuantizationOutput containing the quantized representation of the vector.
+ */
+ QuantizationOutput quantize(T vector, QuantizationState state);
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
new file mode 100644
index 000000000..f952d3dc7
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import java.util.Arrays;
+import java.util.Random;
+import java.util.stream.IntStream;
+
+/**
+ * ReservoirSampler implements the Sampler interface and provides a method for sampling
+ * a specified number of indices from a total number of vectors using the reservoir sampling algorithm.
+ * This algorithm is particularly useful for randomly sampling a subset of data from a larger set
+ * when the total size of the dataset is unknown or very large.
+ */
+public class ReservoirSampler implements Sampler {
+ private final Random random = new Random();
+
+ /**
+ * Samples indices from the range [0, totalNumberOfVectors).
+ * If the total number of vectors is less than or equal to the sample size, it returns all indices.
+ * Otherwise, it uses the reservoir sampling algorithm to select a random subset.
+ *
+ * @param totalNumberOfVectors the total number of vectors to sample from.
+ * @param sampleSize the number of indices to sample.
+ * @return an array of sampled indices.
+ */
+ @Override
+ public int[] sample(int totalNumberOfVectors, int sampleSize) {
+ if (totalNumberOfVectors <= sampleSize) {
+ return IntStream.range(0, totalNumberOfVectors).toArray();
+ }
+ return reservoirSampleIndices(totalNumberOfVectors, sampleSize);
+ }
+
+ /**
+ * Applies the reservoir sampling algorithm to select a random sample of indices.
+ * This method ensures that each index in the range [0, numVectors) has an equal probability
+ * of being included in the sample.
+ *
+ * @param numVectors the total number of vectors.
+ * @param sampleSize the number of indices to sample.
+ * @return an array of sampled indices.
+ */
+ private int[] reservoirSampleIndices(int numVectors, int sampleSize) {
+ int[] indices = IntStream.range(0, sampleSize).toArray();
+ for (int i = sampleSize; i < numVectors; i++) {
+ int j = random.nextInt(i + 1);
+ if (j < sampleSize) {
+ indices[j] = i;
+ }
+ }
+ Arrays.sort(indices);
+ return indices;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
new file mode 100644
index 000000000..9021073b4
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
@@ -0,0 +1,10 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+public interface Sampler {
+ int[] sample(int totalNumberOfVectors, int sampleSize);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
new file mode 100644
index 000000000..89470bb3f
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+/**
+ * SamplingFactory is a factory class for creating instances of Sampler.
+ * It uses the factory design pattern to encapsulate the creation logic for different types of samplers.
+ */
+public class SamplingFactory {
+
+ /**
+ * SamplerType is an enumeration of the different types of samplers that can be created by the factory.
+ */
+ public enum SamplerType {
+ RESERVOIR, // Represents a reservoir sampling strategy
+ // Add more enum values here for additional sampler types
+ }
+
+ /**
+ * Creates and returns a Sampler instance based on the specified SamplerType.
+ *
+ * @param samplerType the type of sampler to create.
+ * @return a Sampler instance.
+ * @throws IllegalArgumentException if the sampler type is not supported.
+ */
+ public static Sampler getSampler(SamplerType samplerType) {
+ switch (samplerType) {
+ case RESERVOIR:
+ return new ReservoirSampler();
+ // Add more cases for different samplers here
+ default:
+ throw new IllegalArgumentException("Unsupported sampler type: " + samplerType);
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java b/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java
new file mode 100644
index 000000000..4debf9335
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import java.util.List;
+
+/**
+ * Utility class for bit packing operations.
+ * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission.
+ */
+public class BitPackingUtils {
+
+ /**
+ * Packs the list of bit arrays into a single byte array.
+ * Each byte in the resulting array contains up to 8 bits from the bit arrays, packed from left to right.
+ *
+ * @param bitArrays the list of bit arrays to be packed. Each bit array should contain only 0s and 1s.
+ * @return a byte array containing the packed bits.
+ * @throws IllegalArgumentException if the bitArrays list is empty or if any bit array is null.
+ */
+ public static byte[] packBits(List bitArrays) {
+ if (bitArrays.isEmpty()) {
+ throw new IllegalArgumentException("The list of bit arrays cannot be empty.");
+ }
+
+ int bitLength = bitArrays.size() * bitArrays.get(0).length;
+ int byteLength = (bitLength + 7) / 8;
+ byte[] packedArray = new byte[byteLength];
+
+ for (int i = 0; i < bitArrays.size(); i++) {
+ byte[] bitArray = bitArrays.get(i);
+ if (bitArray == null) {
+ throw new IllegalArgumentException("Bit array cannot be null.");
+ }
+ for (int j = 0; j < bitArray.length; j++) {
+ int byteIndex = (i * bitArray.length + j) / 8;
+ int bitIndex = 7 - ((i * bitArray.length + j) % 8);
+ packedArray[byteIndex] |= (bitArray[j] << bitIndex);
+ }
+ }
+
+ return packedArray;
+ }
+}
+
diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java
new file mode 100644
index 000000000..366bf7559
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+
+import java.io.*;
+
+/**
+ * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing
+ * QuantizationState objects along with their specific data.
+ */
+public class QuantizationStateSerializer {
+
+ /**
+ * A functional interface for deserializing specific data associated with a QuantizationState.
+ */
+ @FunctionalInterface
+ public interface SerializableDeserializer {
+ QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData);
+ }
+
+ /**
+ * Serializes the QuantizationState and specific data into a byte array.
+ *
+ * @param state The QuantizationState to serialize.
+ * @param specificData The specific data related to the state, to be serialized.
+ * @return A byte array representing the serialized state and specific data.
+ * @throws IOException If an I/O error occurs during serialization.
+ */
+ public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException {
+ byte[] parentBytes = serializeParentParams(state.getQuantizationParams());
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ out.writeInt(parentBytes.length); // Write the length of the parent bytes
+ out.write(parentBytes); // Write the parent bytes
+ out.writeObject(specificData); // Write the specific data
+ out.flush();
+ return bos.toByteArray();
+ }
+ }
+
+ /**
+ * Deserializes a QuantizationState and its specific data from a byte array.
+ *
+ * @param bytes The byte array containing the serialized data.
+ * @param specificDataDeserializer The deserializer for the specific data associated with the state.
+ * @return The deserialized QuantizationState including its specific data.
+ * @throws IOException If an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException If the class of the serialized object cannot be found.
+ */
+ public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer)
+ throws IOException, ClassNotFoundException {
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
+ ObjectInputStream in = new ObjectInputStream(bis)) {
+ int parentLength = in.readInt(); // Read the length of the
+ // Read the length of the parent bytes
+ byte[] parentBytes = new byte[parentLength];
+ in.readFully(parentBytes); // Read the parent bytes
+ QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params
+ Serializable specificData = (Serializable) in.readObject(); // Read the specific data
+ return specificDataDeserializer.deserialize(parentParams, specificData);
+ }
+ }
+
+ /**
+ * Serializes the parent parameters of the QuantizationState into a byte array.
+ *
+ * @param params The QuantizationParams to serialize.
+ * @return A byte array representing the serialized parent parameters.
+ * @throws IOException If an I/O error occurs during serialization.
+ */
+ private static byte[] serializeParentParams(QuantizationParams params) throws IOException {
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ out.writeObject(params);
+ out.flush();
+ return bos.toByteArray();
+ }
+ }
+
+ /**
+ * Deserializes the parent parameters of the QuantizationState from a byte array.
+ *
+ * @param bytes The byte array containing the serialized parent parameters.
+ * @return The deserialized QuantizationParams.
+ * @throws IOException If an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException If the class of the serialized object cannot be found.
+ */
+ private static QuantizationParams deserializeParentParams(byte[] bytes)
+ throws IOException, ClassNotFoundException {
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
+ ObjectInputStream in = new ObjectInputStream(bis)) {
+ return (QuantizationParams) in.readObject();
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java
new file mode 100644
index 000000000..3ee61a2fc
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+/**
+ * Utility class providing common methods for quantizer operations, such as parameter validation and
+ * extraction. This class is designed to be used with various quantizer implementations that require
+ * consistent handling of training requests and sampled indices.
+ */
+public final class QuantizerHelper {
+
+ /**
+ * Private constructor to prevent instantiation of this utility class.
+ * The class is not meant to be instantiated, as it provides static utility methods only.
+ */
+ private QuantizerHelper() {
+ // Private constructor to prevent instantiation
+ }
+
+ /**
+ * Validates the provided training request to ensure it contains non-null quantization parameters.
+ * Extracts and returns the SQParams from the training request.
+ *
+ * @param trainingRequest the training request to validate and extract parameters from.
+ * @return the extracted SQParams.
+ * @throws IllegalArgumentException if the SQParams are null.
+ */
+ public static SQParams validateAndExtractParams(TrainingRequest> trainingRequest) {
+ QuantizationParams params = trainingRequest.getParams();
+ if (params == null || !(params instanceof SQParams)) {
+ throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams.");
+ }
+ return (SQParams) params;
+ }
+
+ /**
+ * Calculates the mean vector from a set of sampled vectors.
+ *
+ * This method takes a {@link TrainingRequest} object and an array of sampled indices,
+ * retrieves the vectors corresponding to these indices, and calculates the mean vector.
+ * Each element of the mean vector is computed as the average of the corresponding elements
+ * of the sampled vectors.
+ *
+ * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices.
+ * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation.
+ * @return A float array representing the mean vector of the sampled vectors.
+ * @throws IllegalArgumentException If any of the vectors at the sampled indices are null.
+ * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors.
+ */
+ public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) {
+ int totalSamples = sampledIndices.length;
+ float[] mean = null;
+ for (int index : sampledIndices) {
+ float[] vector = samplingRequest.getVectorByDocId(index);
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector at sampled index " + index + " is null.");
+ }
+ if (mean == null) {
+ mean = new float[vector.length];
+ }
+ for (int j = 0; j < vector.length; j++) {
+ mean[j] += vector[j];
+ }
+ }
+ if (mean == null) {
+ throw new IllegalStateException("Mean array should not be null after processing vectors.");
+ }
+ for (int j = 0; j < mean.length; j++) {
+ mean[j] /= totalSamples;
+ }
+ return mean;
+ }
+}
+
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
new file mode 100644
index 000000000..f14babd37
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class QuantizationTypeTests extends KNNTestCase {
+
+ public void testQuantizationTypeValues() {
+ QuantizationType[] expectedValues = {
+ QuantizationType.SPACE_QUANTIZATION,
+ QuantizationType.VALUE_QUANTIZATION
+ };
+ assertArrayEquals(expectedValues, QuantizationType.values());
+ }
+
+ public void testQuantizationTypeValueOf() {
+ assertEquals(QuantizationType.SPACE_QUANTIZATION, QuantizationType.valueOf("SPACE_QUANTIZATION"));
+ assertEquals(QuantizationType.VALUE_QUANTIZATION, QuantizationType.valueOf("VALUE_QUANTIZATION"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
new file mode 100644
index 000000000..b22289b88
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class SQTypesTests extends KNNTestCase {
+ public void testSQTypesValues() {
+ SQTypes[] expectedValues = {
+ SQTypes.FP16,
+ SQTypes.INT8,
+ SQTypes.INT6,
+ SQTypes.INT4,
+ SQTypes.ONE_BIT,
+ SQTypes.TWO_BIT,
+ SQTypes.FOUR_BIT,
+ SQTypes.UNSUPPORTED_TYPE
+ };
+ assertArrayEquals(expectedValues, SQTypes.values());
+ }
+
+ public void testSQTypesValueOf() {
+ assertEquals(SQTypes.FP16, SQTypes.valueOf("FP16"));
+ assertEquals(SQTypes.INT8, SQTypes.valueOf("INT8"));
+ assertEquals(SQTypes.INT6, SQTypes.valueOf("INT6"));
+ assertEquals(SQTypes.INT4, SQTypes.valueOf("INT4"));
+ assertEquals(SQTypes.ONE_BIT, SQTypes.valueOf("ONE_BIT"));
+ assertEquals(SQTypes.TWO_BIT, SQTypes.valueOf("TWO_BIT"));
+ assertEquals(SQTypes.FOUR_BIT, SQTypes.valueOf("FOUR_BIT"));
+ assertEquals(SQTypes.UNSUPPORTED_TYPE, SQTypes.valueOf("UNSUPPORTED_TYPE"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
new file mode 100644
index 000000000..47d7123f6
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class ValueQuantizationTypeTests extends KNNTestCase {
+ public void testValueQuantizationTypeValues() {
+ ValueQuantizationType[] expectedValues = {
+ ValueQuantizationType.SQ
+ };
+ assertArrayEquals(expectedValues, ValueQuantizationType.values());
+ }
+
+ public void testValueQuantizationTypeValueOf() {
+ assertEquals(ValueQuantizationType.SQ, ValueQuantizationType.valueOf("SQ"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
new file mode 100644
index 000000000..2082688ab
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.junit.Before;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.lang.reflect.Field;
+
+public class QuantizerFactoryTests extends KNNTestCase {
+
+ @Before
+ public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException {
+ Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered");
+ isRegisteredField.setAccessible(true);
+ isRegisteredField.setBoolean(null, false);
+ }
+
+ public void test_Lazy_Registration() {
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ assertFalse(isRegisteredFieldAccessible());
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof OneBitScalarQuantizer);
+ assertTrue(isRegisteredFieldAccessible());
+ }
+
+ public void testGetQuantizer_withOneBitSQParams() {
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof OneBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withTwoBitSQParams() {
+ SQParams params = new SQParams(SQTypes.TWO_BIT);
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof MultiBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withFourBitSQParams() {
+ SQParams params = new SQParams(SQTypes.FOUR_BIT);
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof MultiBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withUnsupportedType() {
+ SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE);
+ try {
+ QuantizerFactory.getQuantizer(params);
+ fail("Expected IllegalArgumentException");
+ } catch (IllegalArgumentException e) {
+ assertTrue(e.getMessage().contains("No quantizer registered for type identifier"));
+ }
+ }
+
+ public void testGetQuantizer_withNullParams() {
+ try {
+ QuantizerFactory.getQuantizer(null);
+ fail("Expected IllegalArgumentException");
+ } catch (IllegalArgumentException e) {
+ assertEquals("Quantization parameters must not be null.", e.getMessage());
+ }
+ }
+
+
+ public void test_Concurrent_Registration() throws InterruptedException {
+ Runnable task = () -> {
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ QuantizerFactory.getQuantizer(params);
+ };
+
+ Thread thread1 = new Thread(task);
+ Thread thread2 = new Thread(task);
+ thread1.start();
+ thread2.start();
+ thread1.join();
+ thread2.join();
+
+ assertTrue(isRegisteredFieldAccessible());
+ }
+
+ private boolean isRegisteredFieldAccessible() {
+ try {
+ Field field = QuantizerFactory.class.getDeclaredField("isRegistered");
+ field.setAccessible(true);
+ return field.getBoolean(null);
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ fail("Failed to access isRegistered field.");
+ return false;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
new file mode 100644
index 000000000..a848bc474
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.junit.BeforeClass;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+public class QuantizerRegistryTests extends KNNTestCase {
+
+ @BeforeClass
+ public static void setup() {
+ // Register the quantizers for testing with enums
+ QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.ONE_BIT, OneBitScalarQuantizer::new);
+ QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.TWO_BIT, () -> new MultiBitScalarQuantizer(2));
+ QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.FOUR_BIT, () -> new MultiBitScalarQuantizer(4));
+ }
+
+ public void testRegisterAndGetQuantizer() {
+ // Test for OneBitScalarQuantizer
+ SQParams oneBitParams = new SQParams(SQTypes.ONE_BIT);
+ Quantizer, ?> oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer);
+
+ // Test for MultiBitScalarQuantizer (2-bit)
+ SQParams twoBitParams = new SQParams(SQTypes.TWO_BIT);
+ Quantizer, ?> twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer);
+
+ // Test for MultiBitScalarQuantizer (4-bit)
+ SQParams fourBitParams = new SQParams(SQTypes.FOUR_BIT);
+ Quantizer, ?> fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withUnsupportedTypeIdentifier() {
+ // Create SQParams with an unsupported type identifier
+ SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered
+
+ // Expect IllegalArgumentException when requesting a quantizer with unsupported params
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
+ QuantizerRegistry.getQuantizer(params);
+ });
+
+ assertTrue(exception.getMessage().contains("No quantizer registered for type identifier"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
new file mode 100644
index 000000000..50a8eee60
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizationState;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+
+import java.io.IOException;
+
+public class QuantizationStateSerializerTests extends KNNTestCase {
+
+ public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ float[] mean = new float[]{0.1f, 0.2f, 0.3f};
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+
+ byte[] serialized = state.toByteArray();
+ OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized);
+
+ assertArrayEquals(mean, deserialized.getMean(), 0.0f);
+ assertEquals(params, deserialized.getQuantizationParams());
+ }
+
+ public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(SQTypes.TWO_BIT);
+ float[][] thresholds = new float[][]{
+ {0.1f, 0.2f, 0.3f},
+ {0.4f, 0.5f, 0.6f}
+ };
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ byte[] serialized = state.toByteArray();
+ MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized);
+
+ assertArrayEquals(thresholds, deserialized.getThresholds());
+ assertEquals(params, deserialized.getQuantizationParams());
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
new file mode 100644
index 000000000..e1ab98b07
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.quantizationState;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+
+import java.io.IOException;
+
+public class QuantizationStateTests extends KNNTestCase {
+
+ public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ float[] mean = {1.0f, 2.0f, 3.0f};
+
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+
+ byte[] serializedState = state.toByteArray();
+ OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState);
+ float delta = 0.0001f;
+
+ assertArrayEquals(mean, deserializedState.getMean(), delta);
+ assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ }
+
+
+ public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(SQTypes.TWO_BIT);
+ float[][] thresholds = {
+ {0.5f, 1.5f, 2.5f},
+ {1.0f, 2.0f, 3.0f}
+ };
+
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ byte[] serializedState = state.toByteArray();
+ MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState);
+ float delta = 0.0001f;
+
+ for (int i = 0; i < thresholds.length; i++) {
+ assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i],delta);
+ }
+ assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
new file mode 100644
index 000000000..458df08c5
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+public class MultiBitScalarQuantizerTests extends KNNTestCase {
+
+ public void testTrain_twoBit() {
+ float[][] vectors = {
+ {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
+ {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f}
+ };
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ int[] sampledIndices = {0, 1, 2};
+ SQParams params = new SQParams(SQTypes.TWO_BIT);
+ TrainingRequest request = new MockTrainingRequest(params, vectors);
+ request.setSampledIndices(sampledIndices);
+ QuantizationState state = twoBitQuantizer.train(request);
+
+ assertTrue(state instanceof MultiBitScalarQuantizationState);
+ MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state;
+ assertNotNull(mbState.getThresholds());
+ assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds
+ }
+
+ public void testTrain_fourBit() {
+ MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
+ float[][] vectors = {
+ {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
+ {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f}
+ };
+ int[] sampledIndices = {0, 1, 2};
+ SQParams params = new SQParams(SQTypes.FOUR_BIT);
+ TrainingRequest request = new MockTrainingRequest(params, vectors);
+ request.setSampledIndices(sampledIndices);
+ QuantizationState state = fourBitQuantizer.train(request);
+
+ assertTrue(state instanceof MultiBitScalarQuantizationState);
+ MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state;
+ assertNotNull(mbState.getThresholds());
+ assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds
+ }
+
+ public void testQuantize_twoBit() {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f};
+ float[][] thresholds = {
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
+ {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f}
+ };
+ SQParams params = new SQParams(SQTypes.TWO_BIT);
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ QuantizationOutput output = twoBitQuantizer.quantize(vector, state);
+ assertNotNull(output.getQuantizedVector());
+ assertEquals(2, output.getQuantizedVector().length);
+ }
+
+ public void testQuantize_fourBit() {
+ MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
+ float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f};
+ float[][] thresholds = {
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
+ {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f},
+ {1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f},
+ {1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f}
+ };
+ SQParams params = new SQParams(SQTypes.FOUR_BIT);
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ QuantizationOutput output = fourBitQuantizer.quantize(vector, state);
+ assertEquals(4, output.getQuantizedVector().length);
+ assertNotNull(output.getQuantizedVector());
+ }
+
+ public void testQuantize_withNullVector() {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ expectThrows(IllegalArgumentException.class,
+ () -> twoBitQuantizer.quantize(null, new MultiBitScalarQuantizationState(new SQParams(SQTypes.TWO_BIT),
+ new float[2][8])));
+ }
+
+ public void testQuantize_withInvalidState() {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f};
+ QuantizationState invalidState = new MockInvalidQuantizationState();
+ expectThrows(IllegalArgumentException.class,
+ () -> twoBitQuantizer.quantize(vector, invalidState));
+ }
+
+ public void testQuantize_withDefaultQuantizationState() {
+ MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f};
+ DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(SQTypes.ONE_BIT));
+
+ expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state));
+ }
+
+ // Mock classes for testing
+ private static class MockTrainingRequest extends TrainingRequest {
+ private final float[][] vectors;
+
+ public MockTrainingRequest(SQParams params, float[][] vectors) {
+ super(params, vectors.length);
+ this.vectors = vectors;
+ }
+
+ @Override
+ public float[] getVector() {
+ return new float[0];
+ }
+
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ }
+
+ private static class MockInvalidQuantizationState implements QuantizationState {
+ @Override
+ public SQParams getQuantizationParams() {
+ return new SQParams(SQTypes.UNSUPPORTED_TYPE);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return new byte[0];
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
new file mode 100644
index 000000000..9b384cd1c
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.SQTypes;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+public class OneBitScalarQuantizerTests extends KNNTestCase {
+
+ public void testTrain_withTrainingRequired() {
+ float[][] vectors = {
+ {1.0f, 2.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f},
+ {7.0f, 8.0f, 9.0f}
+ };
+
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) {
+ @Override
+ public float[] getVector() {
+ return null; // Not used in this test
+ }
+
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ QuantizationState state = quantizer.train(originalRequest);
+
+ assertTrue(state instanceof OneBitScalarQuantizationState);
+ float[] mean = ((OneBitScalarQuantizationState) state).getMean();
+ assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f);
+ }
+
+ public void testQuantize_withState() {
+ float[] vector = {3.0f, 6.0f, 9.0f};
+ float[] thresholds = {4.0f, 5.0f, 6.0f};
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds);
+
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ QuantizationOutput output = quantizer.quantize(vector, state);
+
+ assertNotNull(output);
+ byte[] expectedPackedBits = new byte[]{0b01100000}; // or 96 in decimal
+ assertArrayEquals(expectedPackedBits, output.getQuantizedVector());
+ }
+
+ public void testQuantize_withDefaultQuantizationState() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = {3.0f, 6.0f, 9.0f};
+ DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(SQTypes.ONE_BIT));
+
+ expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state));
+ }
+
+ public void testQuantize_withNullVector() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), new float[]{0.0f});
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state));
+ }
+
+ public void testQuantize_withInvalidState() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = {1.0f, 2.0f, 3.0f};
+ QuantizationState invalidState = new QuantizationState() {
+ @Override
+ public SQParams getQuantizationParams() {
+ return new SQParams(SQTypes.ONE_BIT);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return new byte[0];
+ }
+ };
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState));
+ }
+
+ public void testQuantize_withMismatchedDimensions() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = {1.0f, 2.0f, 3.0f};
+ float[] thresholds = {4.0f, 5.0f};
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds);
+
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state));
+ }
+
+ public void testCalculateMean() {
+ float[][] vectors = {
+ {1.0f, 2.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f},
+ {7.0f, 8.0f, 9.0f}
+ };
+
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) {
+ @Override
+ public float[] getVector() {
+ return null; // Not used in this test
+ }
+
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(vectors.length, 3);
+ float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices);
+ assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f);
+ }
+
+ public void testCalculateMean_withNullVector() {
+ float[][] vectors = {
+ {1.0f, 2.0f, 3.0f},
+ null,
+ {7.0f, 8.0f, 9.0f}
+ };
+
+ SQParams params = new SQParams(SQTypes.ONE_BIT);
+ TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) {
+ @Override
+ public float[] getVector() {
+ return null; // Not used in this test
+ }
+
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(vectors.length, 3);
+ expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMean(samplingRequest, sampledIndices));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
new file mode 100644
index 000000000..3a42588a4
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+
+public class ReservoirSamplerTests extends KNNTestCase {
+
+ public void testSample() {
+ Sampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 100;
+ int sampleSize = 10;
+
+ int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(sampleSize, samples.length);
+ for (int index : samples) {
+ assertTrue(index >= 0 && index < totalNumberOfVectors);
+ }
+ }
+
+ public void testSample_withFullSampling() {
+ Sampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 10;
+ int sampleSize = 10;
+
+ int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(sampleSize, samples.length);
+ for (int index : samples) {
+ assertTrue(index >= 0 && index < totalNumberOfVectors);
+ }
+ }
+
+ public void testSample_withLessVectors() {
+ Sampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 5;
+ int sampleSize = 10;
+
+ int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(totalNumberOfVectors, samples.length);
+ for (int index : samples) {
+ assertTrue(index >= 0 && index < totalNumberOfVectors);
+ }
+ }
+
+ public void testSample_withZeroVectors() {
+ Sampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 0;
+ int sampleSize = 10;
+
+ int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(0, samples.length);
+ }
+
+ public void testSample_withOneVector() {
+ Sampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 1;
+ int sampleSize = 10;
+
+ int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(1, samples.length);
+ assertTrue(samples[0] == 0);
+ }
+}
+
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
new file mode 100644
index 000000000..56d496d2f
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class SamplingFactoryTests extends KNNTestCase {
+ public void testGetSampler_withReservoir() {
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ assertTrue(sampler instanceof ReservoirSampler);
+ }
+
+ public void testGetSampler_withUnsupportedType() {
+ expectThrows( NullPointerException.class, ()-> SamplingFactory.getSampler(null)); // This should throw an exception
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
new file mode 100644
index 000000000..4d48b83ee
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.Arrays;
+import java.util.List;
+
+public class BitPackingUtilsTests extends KNNTestCase {
+
+ public void testPackBits_BitArray() {
+ List bitArrays = Arrays.asList(
+ new byte[]{1, 0, 1, 0, 1, 0, 1, 0},
+ new byte[]{1, 1, 0, 0, 1, 1, 0, 0}
+ );
+
+ byte[] packedBits = BitPackingUtils.packBits(bitArrays);
+ byte[] expected = new byte[]{(byte) 0b10101010, (byte) 0b11001100};
+
+ assertArrayEquals(expected, packedBits);
+ }
+
+ public void testPackBits_multipleBitArrays() {
+ List bitArrays = Arrays.asList(
+ new byte[]{1, 0, 1},
+ new byte[]{0, 1, 0},
+ new byte[]{1, 1, 1}
+ );
+
+ byte[] packedBits = BitPackingUtils.packBits(bitArrays);
+ byte[] expected = new byte[]{(byte) 0b10101011, (byte) 0b10000000};
+
+ assertArrayEquals(expected, packedBits);
+ }
+
+ public void testPackBits_emptyArray() {
+ List bitArrays = Arrays.asList();
+ expectThrows(IllegalArgumentException.class, () -> {
+ BitPackingUtils.packBits(bitArrays);
+ });;
+ }
+}