diff --git a/CHANGELOG.md b/CHANGELOG.md index 92af64ccb..f47f8f20d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,4 +30,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) -* Quantization Framework For Disk Optimized Vector Search and Implementation of Binary 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) +* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java deleted file mode 100644 index 4a2a17a57..000000000 --- a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -/** - * The QuantizationType enum represents the different types of quantization - * that can be applied in the KNN. - * - * - */ -public enum QuantizationType { - /** - * Represents space quantization, typically involving dimensionality reduction - * or space partitioning techniques. - */ - SPACE, - - /** - * Represents value quantization, typically involving the conversion of continuous - * values into discrete ones. 
- */ - VALUE, -} diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java index 88290c6a8..0b5458065 100644 --- a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java +++ b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java @@ -5,30 +5,42 @@ package org.opensearch.knn.quantization.enums; +import lombok.Getter; + /** - * The SQTypes enum defines the various scalar quantization types that can be used - * in the KNN for vector quantization. - * Each type corresponds to a different bit-width representation of the quantized values. + * The ScalarQuantizationType enum defines the various scalar quantization types that can be used + * for vector quantization. Each type corresponds to a different bit-width representation of the quantized values. + * + *

+ * Future Developers: If you change the name of any enum constant, do not change its associated value. + * Serialization and deserialization depend on these values to maintain compatibility. + *

*/ +@Getter public enum ScalarQuantizationType { /** * ONE_BIT quantization uses a single bit per coordinate. */ - ONE_BIT, + ONE_BIT(1), /** * TWO_BIT quantization uses two bits per coordinate. */ - TWO_BIT, + TWO_BIT(2), /** * FOUR_BIT quantization uses four bits per coordinate. */ - FOUR_BIT, + FOUR_BIT(4); + + private final int id; /** - * UNSUPPORTED_TYPE is used to denote quantization types that are not supported. - * This can be used as a placeholder or default value. + * Constructs a ScalarQuantizationType with the specified ID. + * + * @param id the ID representing the quantization type. */ - UNSUPPORTED_TYPE + ScalarQuantizationType(int id) { + this.id = id; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java deleted file mode 100644 index 43db46cf6..000000000 --- a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -/** - * The ValueQuantizationType enum defines the types of value quantization techniques - * that can be applied in the KNN. - */ -public enum ValueQuantizationType { - /** - * SQ (Scalar Quantization) represents a method where each coordinate of the vector is quantized - * independently. 
- */ - SCALAR -} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java index 985efd4cd..b99f6ebdc 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -5,6 +5,8 @@ package org.opensearch.knn.quantization.factory; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -15,12 +17,10 @@ * based on the provided {@link QuantizationParams}. It uses a registry to look up the * appropriate quantizer implementation for the given quantization parameters. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) public final class QuantizerFactory { private static final AtomicBoolean isRegistered = new AtomicBoolean(false); - // Private constructor to prevent instantiation - private QuantizerFactory() {} - /** * Ensures that default quantizers are registered. 
*/ diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java index c8a2eb2bf..889130c9a 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java @@ -5,9 +5,10 @@ package org.opensearch.knn.quantization.factory; -import org.opensearch.knn.quantization.enums.QuantizationType; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; @@ -15,11 +16,9 @@ * The QuantizerRegistrar class is responsible for registering default quantizers. * This class ensures that the registration happens only once in a thread-safe manner. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) final class QuantizerRegistrar { - // Private constructor to prevent instantiation - private QuantizerRegistrar() {} - /** * Registers default quantizers if not already registered. *

@@ -27,22 +26,21 @@ private QuantizerRegistrar() {} * even in a multi-threaded environment. *

*/ - public static synchronized void registerDefaultQuantizers() { + static synchronized void registerDefaultQuantizers() { // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT - QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2 QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.TWO_BIT, - () -> new MultiBitScalarQuantizer(2) + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), + new MultiBitScalarQuantizer(2) ); // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4 QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.FOUR_BIT, - () -> new MultiBitScalarQuantizer(4) + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), + new MultiBitScalarQuantizer(4) ); } } diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java index 1243d79ef..50e1158c2 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -5,46 +5,33 @@ package org.opensearch.knn.quantization.factory; -import org.opensearch.knn.quantization.enums.QuantizationType; -import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import 
org.opensearch.knn.quantization.quantizer.Quantizer; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Supplier; /** * The QuantizerRegistry class is responsible for managing the registration and retrieval * of quantizer instances. Quantizers are registered with specific quantization parameters * and type identifiers, allowing for efficient lookup and instantiation. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) final class QuantizerRegistry { - - // Private constructor to prevent instantiation - private QuantizerRegistry() {} - // ConcurrentHashMap for thread-safe access - private static final Map>> registry = new ConcurrentHashMap<>(); + private static final Map> registry = new ConcurrentHashMap<>(); /** * Registers a quantizer with the registry. * - * @param paramClass the class of the quantization parameters - * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION) - * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT) - * @param quantizerSupplier a supplier that provides instances of the quantizer - * @param

the type of quantization parameters + * @param paramIdentifier the unique identifier for the quantization parameters + * @param quantizer an instance of the quantizer */ - public static

void register( - final Class

paramClass, - final QuantizationType quantizationType, - final ScalarQuantizationType sqType, - final Supplier> quantizerSupplier - ) { - String identifier = createIdentifier(quantizationType, sqType); + static void register(final String paramIdentifier, final Quantizer quantizer) { // Ensure that the quantizer for this identifier is registered only once - registry.computeIfAbsent(identifier, key -> quantizerSupplier); + registry.putIfAbsent(paramIdentifier, quantizer); } /** @@ -56,27 +43,14 @@ public static

void register( * @return an instance of {@link Quantizer} corresponding to the provided parameters * @throws IllegalArgumentException if no quantizer is registered for the given parameters */ - public static

Quantizer getQuantizer(final P params) { + static

Quantizer getQuantizer(final P params) { String identifier = params.getTypeIdentifier(); - Supplier> supplier = registry.get(identifier); - if (supplier == null) { - throw new IllegalArgumentException( - "No quantizer registered for type identifier: " + identifier + ". Available quantizers: " + registry.keySet() - ); + Quantizer quantizer = registry.get(identifier); + if (quantizer == null) { + throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier); } @SuppressWarnings("unchecked") - Quantizer quantizer = (Quantizer) supplier.get(); - return quantizer; - } - - /** - * Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype. - * - * @param quantizationType the quantization type - * @param sqType the specific quantization subtype - * @return a string identifier - */ - private static String createIdentifier(final QuantizationType quantizationType, final ScalarQuantizationType sqType) { - return quantizationType.name() + "_" + sqType.name(); + Quantizer typedQuantizer = (Quantizer) quantizer; + return typedQuantizer; } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java index 18077182f..dbf8e5bf9 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -5,27 +5,36 @@ package org.opensearch.knn.quantization.models.quantizationOutput; +import lombok.NoArgsConstructor; +import lombok.Getter; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + /** * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. 
* It implements the QuantizationOutput interface to handle byte arrays specifically. */ +@NoArgsConstructor public class BinaryQuantizationOutput implements QuantizationOutput { - private final byte[] quantizedVector; + @Getter + private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); /** - * Constructs a BinaryQuantizationOutput instance with the specified quantized vector. + * Updates the quantized vector with a new byte array. * - * @param quantizedVector the quantized vector represented as a byte array. + * @param newQuantizedVector the new quantized vector represented as a byte array. */ - public BinaryQuantizationOutput(final byte[] quantizedVector) { - if (quantizedVector == null) { - throw new IllegalArgumentException("Quantized vector cannot be null"); + public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException { + if (newQuantizedVector == null || newQuantizedVector.length == 0) { + throw new IllegalArgumentException("Quantized vector cannot be null or empty"); } - this.quantizedVector = quantizedVector; + byteArrayOutputStream.reset(); + byteArrayOutputStream.write(newQuantizedVector); } @Override public byte[] getQuantizedVector() { - return quantizedVector; + return byteArrayOutputStream.toByteArray(); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java index c5c5fd21f..8f01a0594 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -5,6 +5,8 @@ package org.opensearch.knn.quantization.models.quantizationOutput; +import java.io.IOException; + /** * The QuantizationOutput interface defines the contract for quantization output data. 
* @@ -17,4 +19,12 @@ public interface QuantizationOutput { * @return the quantized data. */ T getQuantizedVector(); + + /** + * Updates the quantized vector with new data. + * + * @param newQuantizedVector the new quantized vector data. + * @throws IOException if an I/O error occurs during the update. + */ + void updateQuantizedVector(T newQuantizedVector) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java index 2c982a306..88b22c4fd 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -5,9 +5,7 @@ package org.opensearch.knn.quantization.models.quantizationParams; -import org.opensearch.knn.quantization.enums.QuantizationType; - -import java.io.Serializable; +import java.io.Externalizable; /** * Interface for quantization parameters. @@ -16,17 +14,7 @@ * Implementations of this interface are expected to provide specific configurations * for various quantization strategies. */ -public interface QuantizationParams extends Serializable { - - /** - * Gets the quantization type associated with the parameters. - * The quantization type defines the overall strategy or method used - * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION. - * - * @return the {@link QuantizationType} indicating the quantization method. - */ - QuantizationType getQuantizationType(); - +public interface QuantizationParams extends Externalizable { /** * Provides a unique identifier for the quantization parameters. 
* This identifier is typically a combination of the quantization type diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java deleted file mode 100644 index 0b6bbc988..000000000 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.models.quantizationParams; - -import org.opensearch.knn.quantization.enums.QuantizationType; -import org.opensearch.knn.quantization.enums.ScalarQuantizationType; - -import java.util.Objects; - -/** - * The SQParams class represents the parameters specific to scalar quantization (SQ). - * This class implements the QuantizationParams interface and includes the type of scalar quantization. - */ -public class SQParams implements QuantizationParams { - private final ScalarQuantizationType sqType; - - /** - * Constructs an SQParams instance with the specified scalar quantization type. - * - * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT). - */ - public SQParams(final ScalarQuantizationType sqType) { - this.sqType = sqType; - } - - /** - * Returns the quantization type associated with these parameters. - * - * @return The quantization type, always VALUE_QUANTIZATION for SQParams. - */ - @Override - public QuantizationType getQuantizationType() { - return QuantizationType.VALUE; - } - - /** - * Returns the scalar quantization type. - * - * @return The specific scalar quantization type. - */ - public ScalarQuantizationType getSqType() { - return sqType; - } - - /** - * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type. - * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. 
- * - * @return A string representing the unique type identifier. - */ - @Override - public String getTypeIdentifier() { - return getQuantizationType().name() + "_" + sqType.name(); - } - - /** - * Compares this object to the specified object. The result is true if and only if the argument is not null and is - * an SQParams object that represents the same scalar quantization type. - * - * @param o The object to compare this SQParams against. - * @return true if the given object represents an SQParams equivalent to this instance, false otherwise. - */ - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SQParams sqParams = (SQParams) o; - return sqType == sqParams.sqType; - } - - /** - * Returns a hash code value for this SQParams instance. - * - * @return A hash code value for this SQParams instance. - */ - @Override - public int hashCode() { - return Objects.hash(sqType); - } -} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java new file mode 100644 index 000000000..c7c24062d --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Locale; + +/** + * The SQParams class represents the parameters specific to scalar quantization (SQ). 
+ * This class implements the QuantizationParams interface and includes the type of scalar quantization. + */ +@Getter +@AllArgsConstructor +@NoArgsConstructor // No-argument constructor for deserialization +@EqualsAndHashCode +public class ScalarQuantizationParams implements QuantizationParams { + private ScalarQuantizationType sqType; + private static final long serialVersionUID = 1L; // Version ID for serialization + + /** + * Static method to generate type identifier based on ScalarQuantizationType. + * + * @param sqType the scalar quantization type. + * @return A string representing the unique type identifier. + */ + public static String generateTypeIdentifier(ScalarQuantizationType sqType) { + return generateIdentifier(sqType.getId()); + } + + /** + * Serializes the SQParams object to an external output. + * This method writes the scalar quantization type to the output stream. + * + * @param out the ObjectOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public void writeExternal(ObjectOutput out) throws IOException { + // The version is already written by the parent state class, no need to write it here again + // Retrieve the current version from VersionContext + // This context will be used by other classes involved in the serialization process. + // Example: + // int version = VersionContext.getVersion(); // Get the current version from VersionContext + // Any Version Specific logic can be wriiten based on Version + out.writeObject(sqType); + } + + /** + * Deserializes the SQParams object from an external input with versioning. + * This method reads the scalar quantization type and new field from the input stream based on the version. + * + * @param in the ObjectInput to read the object from. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. 
+ */ + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // The version is already read by the parent state class and set in VersionContext + // Retrieve the current version from VersionContext to handle version-specific deserialization logic + // int versionId = VersionContext.getVersion(); + // Version version = Version.fromId(versionId); + + sqType = (ScalarQuantizationType) in.readObject(); + + // Add version-specific deserialization logic + // For example, if new fields are added in a future version, handle them here + // This section contains conditional logic to handle different versions appropriately. + // Example: + // if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) { + // // Handle logic for versions between 1.0.0 and 2.0.0 + // // Example: Read additional fields introduced in version 1.0.0 + // // newField = in.readInt(); + // } else if (version.onOrAfter(Version.V_2_0_0)) { + // // Handle logic for versions 2.0.0 and above + // // Example: Read additional fields introduced in version 2.0.0 + // // anotherNewField = in.readFloat(); + // } + } + + /** + * Provides a unique type identifier for the SQParams, combining the SQ type. + * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. + * + * @return A string representing the unique type identifier. 
+ */ + @Override + public String getTypeIdentifier() { + return generateIdentifier(sqType.getId()); + } + + private static String generateIdentifier(int id) { + return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java index acc8c2f00..3e3249c6f 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -5,28 +5,27 @@ package org.opensearch.knn.quantization.models.quantizationState; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.Version; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.util.QuantizationStateSerializer; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. * It can be utilized by any quantizer to represent a default state. */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor public class DefaultQuantizationState implements QuantizationState { - - private final QuantizationParams params; - - /** - * Constructs a DefaultQuantizationState with the given quantization parameters. - * - * @param params the quantization parameters. 
- */ - public DefaultQuantizationState(final QuantizationParams params) { - this.params = params; - } + private QuantizationParams params; + private static final long serialVersionUID = 1L; // Version ID for serialization /** * Returns the quantization parameters associated with this state. @@ -60,7 +59,34 @@ public byte[] toByteArray() throws IOException { public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (DefaultQuantizationState) QuantizationStateSerializer.deserialize( bytes, - (parentParams, specificData) -> new DefaultQuantizationState((SQParams) parentParams) + new DefaultQuantizationState(), + (parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams) ); } + + /** + * Writes the object to the output stream. + * This method is part of the Externalizable interface and is used to serialize the object. + * + * @param out the output stream to write the object to. + * @throws IOException if an I/O error occurs. + */ + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + out.writeObject(params); + } + + /** + * Reads the object from the input stream. + * This method is part of the Externalizable interface and is used to deserialize the object. + * + * @param in the input stream to read the object from. + * @throws IOException if an I/O error occurs. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. 
+ */ + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.params = (QuantizationParams) in.readObject(); + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 58834dd2c..095d245f2 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -5,54 +5,186 @@ package org.opensearch.knn.quantization.models.quantizationState; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.util.QuantizationStateSerializer; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.Version; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, * including the thresholds used for quantization. */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor public final class MultiBitScalarQuantizationState implements QuantizationState { - private final SQParams quantizationParams; - private final float[][] thresholds; - + private ScalarQuantizationParams quantizationParams; /** - * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds. + * The threshold values for multi-bit quantization, organized as a 2D array + * where each row corresponds to a different bit level. 
+ * + * For example: + * - For 2-bit quantization: + * thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level + * thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level + * - For 4-bit quantization: + * thresholds[0] -> {0.1f, 0.2f, 0.3f} + * thresholds[1] -> {0.4f, 0.5f, 0.6f} + * thresholds[2] -> {0.7f, 0.8f, 0.9f} + * thresholds[3] -> {1.0f, 1.1f, 1.2f} * - * @param quantizationParams the scalar quantization parameters. - * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array - * where each row corresponds to a different bit level. + * Each column represents the threshold for a specific dimension in the vector space. */ - public MultiBitScalarQuantizationState(final SQParams quantizationParams, final float[][] thresholds) { - this.quantizationParams = quantizationParams; - this.thresholds = thresholds; - } + private float[][] thresholds; + private static final long serialVersionUID = 1L; // Version ID for serialization @Override - public SQParams getQuantizationParams() { + public ScalarQuantizationParams getQuantizationParams() { return quantizationParams; } /** - * Returns the thresholds used in the quantization process. + * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * + *

Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized + * as part of the state to access the version information and implement version-specific logic if needed.

+ * + *

The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable, + * ensuring that the version is available to all classes involved in the serialization process within the current thread context.

+ * + *
+     * {@code
+     * // Example usage in the writeExternal method:
+     * VersionContext.setVersion(version);
+     * out.writeInt(version); // Write the version
+     * quantizationParams.writeExternal(out);
+     * out.writeInt(thresholds.length);
+     * out.writeInt(thresholds[0].length);
+     * for (float[] row : thresholds) {
+     *     for (float value : row) {
+     *         out.writeFloat(value);
+     *     }
+     * }
+     * }
+     * 
+ * + * @param out the ObjectOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public void writeExternal(ObjectOutput out) throws IOException { + int version = Version.CURRENT.id; + VersionContext.setVersion(version); + out.writeInt(version); // Write the version + quantizationParams.writeExternal(out); + out.writeInt(thresholds.length); + out.writeInt(thresholds[0].length); + for (float[] row : thresholds) { + for (float value : row) { + out.writeFloat(value); + } + } + } + + /** + * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * + *

The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method. + * This makes the version information available to all classes involved in the deserialization process within the current thread context.

* - * @return a 2D array of threshold values. + *

Classes that are part of the deserialization process can retrieve the version information using the + * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.

+ * + *
+     * {@code
+     * // Example usage in the readExternal method:
+     * int version = in.readInt(); // Read the version
+     * VersionContext.setVersion(version);
+     * quantizationParams = new ScalarQuantizationParams();
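+     * quantizationParams.readExternal(in); // Use readExternal of ScalarQuantizationParams
+     * int rows = in.readInt(); // Number of rows (bit levels)
+     * int cols = in.readInt(); // Number of columns (dimensions)
+     * thresholds = new float[rows][cols];
+     * for (int i = 0; i < rows; i++) {
+     *     for (int j = 0; j < cols; j++) {
+     *         thresholds[i][j] = in.readFloat();
+     *     }
+     * }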
+     * }
+     * 
+ * + * @param in the ObjectInput to read the object from. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ - public float[][] getThresholds() { - return thresholds; + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + int version = in.readInt(); // Read the version + VersionContext.setVersion(version); + quantizationParams = new ScalarQuantizationParams(); + quantizationParams.readExternal(in); // Use readExternal of SQParams + int rows = in.readInt(); + int cols = in.readInt(); + thresholds = new float[rows][cols]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + thresholds[i][j] = in.readFloat(); + } + } + VersionContext.clear(); // Clear the version after use } + /** + * Serializes the current state of this MultiBitScalarQuantizationState object into a byte array. + * This method uses the QuantizationStateSerializer to handle the serialization process. + * + *

The serialized byte array includes all necessary state information, such as the thresholds + * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.

+ * + *
+     * {@code
+     * MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+     * byte[] serializedState = state.toByteArray();
+     * }
+     * 
+ * + * @return a byte array representing the serialized state of this object. + * @throws IOException if an I/O error occurs during serialization. + */ @Override public byte[] toByteArray() throws IOException { return QuantizationStateSerializer.serialize(this, thresholds); } + /** + * Deserializes a MultiBitScalarQuantizationState object from a byte array. + * This method uses the QuantizationStateSerializer to handle the deserialization process. + * + *

The byte array should contain serialized state information, including the thresholds + * and quantization parameters, which are necessary to reconstruct the MultiBitScalarQuantizationState object.

+ * + *
+     * {@code
+     * byte[] serializedState = ...; // obtain the byte array from some source
+     * MultiBitScalarQuantizationState state = MultiBitScalarQuantizationState.fromByteArray(serializedState);
+     * }
+     * 
+ * + * @param bytes the byte array containing the serialized state. + * @return the deserialized MultiBitScalarQuantizationState object. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of a serialized object cannot be found. + */ public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize( bytes, - (parentParams, thresholds) -> new MultiBitScalarQuantizationState((SQParams) parentParams, (float[][]) thresholds) + new MultiBitScalarQuantizationState(), + (parentParams, thresholds) -> new MultiBitScalarQuantizationState( + (ScalarQuantizationParams) parentParams, + (float[][]) thresholds + ) ); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 9b4bad56a..8ab37955e 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -5,53 +5,177 @@ package org.opensearch.knn.quantization.models.quantizationState; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.util.QuantizationStateSerializer; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.Version; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, * 
including the mean values used for quantization. */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor public final class OneBitScalarQuantizationState implements QuantizationState { - private final SQParams quantizationParams; - private final float[] meanThresholds; - + private ScalarQuantizationParams quantizationParams; /** - * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values. + * Mean thresholds used in the quantization process. + * Each threshold value corresponds to a dimension of the vector being quantized. * - * @param quantizationParams the scalar quantization parameters. - * @param mean the mean values for each dimension. + * Example: + * If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0], + * the quantization process will be: + * - 1.2 < 2.0, so the first bit is 0 + * - 3.4 > 3.0, so the second bit is 1 + * - 5.6 > 4.0, so the third bit is 1 + * The quantized vector will be [0, 1, 1]. */ - public OneBitScalarQuantizationState(final SQParams quantizationParams, final float[] mean) { - this.quantizationParams = quantizationParams; - this.meanThresholds = mean; - } + private float[] meanThresholds; + private static final long serialVersionUID = 1L; // Version ID for serialization @Override - public SQParams getQuantizationParams() { + public ScalarQuantizationParams getQuantizationParams() { return quantizationParams; } /** - * Returns the mean values used in the quantization process. + * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * + *

Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized + * as part of the state to access the version information and implement version-specific logic if needed.

+ * + *

The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable, + * ensuring that the version is available to all classes involved in the serialization process within the current thread context.

* - * @return an array of mean values. + *
+     * {@code
+     * // Example usage in the writeExternal method:
+     * VersionContext.setVersion(version);
+     * out.writeInt(version); // Write the version
+     * quantizationParams.writeExternal(out);
+     * out.writeInt(meanThresholds.length);
+     * for (float mean : meanThresholds) {
+     *     out.writeFloat(mean);
+     * }
+     * }
+     * 
+ * + * @param out the ObjectOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. */ - public float[] getMeanThresholds() { - return meanThresholds; + @Override + public void writeExternal(ObjectOutput out) throws IOException { + int version = Version.CURRENT.id; + VersionContext.setVersion(version); + out.writeInt(version); // Write the version + quantizationParams.writeExternal(out); + out.writeInt(meanThresholds.length); + for (float mean : meanThresholds) { + out.writeFloat(mean); + } + VersionContext.clear(); // Clear the version after use } + /** + * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * + *

The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method. + * This makes the version information available to all classes involved in the deserialization process within the current thread context.

+ * + *

Classes that are part of the deserialization process can retrieve the version information using the + * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.

+ * + *
+     * {@code
+     * // Example usage in the readExternal method:
+     * int version = in.readInt(); // Read the version
+     * VersionContext.setVersion(version);
+     * quantizationParams = new ScalarQuantizationParams();
+     * quantizationParams.readExternal(in); // Use readExternal of ScalarQuantizationParams
+     * int length = in.readInt();
+     * meanThresholds = new float[length];
+     * for (int i = 0; i < length; i++) {
+     *     meanThresholds[i] = in.readFloat();
+     * }
+     * }
+     * 
+ * + * @param in the ObjectInput to read the object from. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. + */ + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + int version = in.readInt(); // Read the version + VersionContext.setVersion(version); + quantizationParams = new ScalarQuantizationParams(); + quantizationParams.readExternal(in); // Use readExternal of SQParams + int length = in.readInt(); + meanThresholds = new float[length]; + for (int i = 0; i < length; i++) { + meanThresholds[i] = in.readFloat(); + } + VersionContext.clear(); // Clear the version after use + } + + /** + * Serializes the current state of this OneBitScalarQuantizationState object into a byte array. + * This method uses the QuantizationStateSerializer to handle the serialization process. + * + *

The serialized byte array includes all necessary state information, such as the mean thresholds + * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.

+ * + *
+     * {@code
+     * OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, meanThresholds);
+     * byte[] serializedState = state.toByteArray();
+     * }
+     * 
+ * + * @return a byte array representing the serialized state of this object. + * @throws IOException if an I/O error occurs during serialization. + */ @Override public byte[] toByteArray() throws IOException { return QuantizationStateSerializer.serialize(this, meanThresholds); } + /** + * Deserializes a OneBitScalarQuantizationState object from a byte array. + * This method uses the QuantizationStateSerializer to handle the deserialization process. + * + *

The byte array should contain serialized state information, including the mean thresholds + * and quantization parameters, which are necessary to reconstruct the OneBitScalarQuantizationState object.

+ * + *
+     * {@code
+     * byte[] serializedState = ...; // obtain the byte array from some source
+     * OneBitScalarQuantizationState state = OneBitScalarQuantizationState.fromByteArray(serializedState);
+     * }
+     * 
+ * + * @param bytes the byte array containing the serialized state. + * @return the deserialized OneBitScalarQuantizationState object. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of a serialized object cannot be found. + */ public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize( bytes, - (parentParams, meanThresholds) -> new OneBitScalarQuantizationState((SQParams) parentParams, (float[]) meanThresholds) + new OneBitScalarQuantizationState(), + (parentParams, meanThresholds) -> new OneBitScalarQuantizationState( + (ScalarQuantizationParams) parentParams, + (float[]) meanThresholds + ) ); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java index c17ff0641..d3778fe29 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -7,14 +7,14 @@ import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import java.io.Externalizable; import java.io.IOException; -import java.io.Serializable; /** * QuantizationState interface represents the state of a quantization process, including the parameters used. * This interface provides methods for serializing and deserializing the state. */ -public interface QuantizationState extends Serializable { +public interface QuantizationState extends Externalizable { /** * Returns the quantization parameters associated with this state. 
* diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java new file mode 100644 index 000000000..5f00d8e0c --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.experimental.UtilityClass; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.io.IOException; +import java.io.ByteArrayInputStream; +import java.io.ObjectInputStream; + +/** + * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing + * QuantizationState objects along with their specific data. + */ +@UtilityClass +class QuantizationStateSerializer { + + /** + * A functional interface for deserializing specific data associated with a QuantizationState. + */ + @FunctionalInterface + interface SerializableDeserializer { + QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); + } + + /** + * Serializes the QuantizationState and specific data into a byte array. + * + * @param state The QuantizationState to serialize. + * @param specificData The specific data related to the state, to be serialized. + * @return A byte array representing the serialized state and specific data. + * @throws IOException If an I/O error occurs during serialization. 
+ */ + static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { + state.writeExternal(out); + out.writeObject(specificData); + out.flush(); + return bos.toByteArray(); + } + } + + /** + * Deserializes a QuantizationState and its specific data from a byte array. + * + * @param bytes The byte array containing the serialized data. + * @param stateInstance An instance of the state to call readExternal on. + * @param specificDataDeserializer The deserializer for the specific data associated with the state. + * @return The deserialized QuantizationState including its specific data. + * @throws IOException If an I/O error occurs during deserialization. + * @throws ClassNotFoundException If the class of the serialized object cannot be found. + */ + static QuantizationState deserialize(byte[] bytes, QuantizationState stateInstance, SerializableDeserializer specificDataDeserializer) + throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { + stateInstance.readExternal(in); + Serializable specificData = (Serializable) in.readObject(); // Read the specific data + return specificDataDeserializer.deserialize(stateInstance.getQuantizationParams(), specificData); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java index 14689ea47..54ebe311c 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -5,64 +5,21 @@ package org.opensearch.knn.quantization.models.requests; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; 
+import lombok.AllArgsConstructor; +import lombok.Getter; /** * TrainingRequest represents a request for training a quantizer. * * @param the type of vectors to be trained. */ +@Getter +@AllArgsConstructor public abstract class TrainingRequest { - private final QuantizationParams params; - private final int totalNumberOfVectors; - private int[] sampledIndices; - - /** - * Constructs a TrainingRequest with the given parameters and total number of vectors. - * - * @param params the quantization parameters. - * @param totalNumberOfVectors the total number of vectors. - */ - protected TrainingRequest(final QuantizationParams params, final int totalNumberOfVectors) { - this.params = params; - this.totalNumberOfVectors = totalNumberOfVectors; - } - - /** - * Returns the quantization parameters. - * - * @return the quantization parameters. - */ - public QuantizationParams getParams() { - return params; - } - /** - * Returns the total number of vectors. - * - * @return the total number of vectors. - */ - public int getTotalNumberOfVectors() { - return totalNumberOfVectors; - } - - /** - * Sets the sampled indices for this training request. - * - * @param sampledIndices the sampled indices. + * The total number of vectors in one segment. */ - public void setSampledIndices(int[] sampledIndices) { - this.sampledIndices = sampledIndices; - } - - /** - * Returns the sampled indices for this training request. - * - * @return the sampled indices. - */ - public int[] getSampledIndices() { - return sampledIndices; - } + private final int totalNumberOfVectors; /** * Returns the vector corresponding to the specified document ID. 
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 0143a6614..c6d366e5b 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -6,19 +6,18 @@ package org.opensearch.knn.quantization.quantizer; -import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; -import org.opensearch.knn.quantization.util.BitPacker; -import org.opensearch.knn.quantization.util.QuantizerHelper; -import java.util.ArrayList; -import java.util.List; +import java.io.IOException; +import java.util.BitSet; /** * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. @@ -41,7 +40,7 @@ public class MultiBitScalarQuantizer implements Quantizer { * @param bitsPerCoordinate the number of bits used per coordinate for quantization. 
*/ public MultiBitScalarQuantizer(final int bitsPerCoordinate) { - this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR)); + this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR)); } /** @@ -70,15 +69,16 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); - int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - - int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; + BitSet sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int dimension = trainingRequest.getVectorByDocId(sampledIndices.nextSetBit(0)).length; float[] meanArray = new float[dimension]; float[] stdDevArray = new float[dimension]; // Calculate sum, mean, and standard deviation in one pass QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension); + ScalarQuantizationParams params = (bitsPerCoordinate == 2) + ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) + : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); return new MultiBitScalarQuantizationState(params, thresholds); } @@ -88,10 +88,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) { * * @param vector the vector to quantize. * @param state the quantization state containing threshold information. - * @return a BinaryQuantizationOutput containing the quantized data. + * @param output the QuantizationOutput object to store the quantized representation of the vector. + * @throws IOException if an I/O error occurs during quantization. 
*/ @Override - public QuantizationOutput quantize(final float[] vector, final QuantizationState state) { + public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) throws IOException { if (vector == null) { throw new IllegalArgumentException("Vector to quantize must not be null."); } @@ -101,17 +102,22 @@ public QuantizationOutput quantize(final float[] vector, final Quantizat if (thresholds == null || thresholds[0].length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - - List bitArrays = new ArrayList<>(); + // Directly pack bits without intermediate array + int totalBits = bitsPerCoordinate * vector.length; + int byteLength = (totalBits + 7) >> 3; // Calculate byte length needed + byte[] packedBits = new byte[byteLength]; for (int i = 0; i < bitsPerCoordinate; i++) { - byte[] bitArray = new byte[vector.length]; for (int j = 0; j < vector.length; j++) { - bitArray[j] = (byte) (vector[j] > thresholds[i][j] ? 1 : 0); + if (vector[j] > thresholds[i][j]) { + int bitPosition = i * vector.length + j; + int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8 + int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8) + packedBits[byteIndex] |= (1 << bitIndex); // Set the bit + } } - bitArrays.add(bitArray); } - return new BinaryQuantizationOutput(BitPacker.packBits(bitArrays)); + output.updateQuantizedVector(packedBits); } /** @@ -145,4 +151,13 @@ private void validateState(final QuantizationState state) { throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState."); } } + + /** + * Returns the number of bits per coordinate used by this quantizer. + * + * @return the number of bits per coordinate. 
+ */ + public int getBitsPerCoordinate() { + return bitsPerCoordinate; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index 2eaa07ce0..eab3d992e 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -5,18 +5,18 @@ package org.opensearch.knn.quantization.quantizer; -import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; -import org.opensearch.knn.quantization.util.BitPacker; -import org.opensearch.knn.quantization.util.QuantizerHelper; -import java.util.Collections; +import java.io.IOException; +import java.util.BitSet; /** * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. @@ -37,7 +37,7 @@ public class OneBitScalarQuantizer implements Quantizer { * Constructs a OneBitScalarQuantizer with a default sampling size of 25000. 
*/ public OneBitScalarQuantizer() { - this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR)); + this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR)); } /** @@ -49,7 +49,6 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { this.samplingSize = samplingSize; this.sampler = sampler; - ; } /** @@ -61,10 +60,9 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); - int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices); - return new OneBitScalarQuantizationState(params, mean); + BitSet sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); + return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); } /** @@ -73,10 +71,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) { * * @param vector the vector to quantize. * @param state the quantization state containing the means for each dimension. + * @param output the QuantizationOutput object to store the quantized representation of the vector. * @return a BinaryQuantizationOutput containing the quantized data. 
*/ @Override - public QuantizationOutput quantize(final float[] vector, final QuantizationState state) { + public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) throws IOException { if (vector == null) { throw new IllegalArgumentException("Vector to quantize must not be null."); } @@ -86,11 +85,19 @@ public QuantizationOutput quantize(final float[] vector, final Quantizat if (thresholds == null || thresholds.length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - byte[] quantizedVector = new byte[vector.length]; + // Directly pack bits without intermediate array + int byteLength = (vector.length + 7) >> 3; // Calculate byte length needed + byte[] packedBits = new byte[byteLength]; + for (int i = 0; i < vector.length; i++) { - quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0); + if (vector[i] > thresholds[i]) { + int byteIndex = i >> 3; // Equivalent to i / 8 + int bitIndex = 7 - (i & 7); // Equivalent to 7 - (i % 8) + packedBits[byteIndex] |= (1 << bitIndex); // Set the bit + } } - return new BinaryQuantizationOutput(BitPacker.packBits(Collections.singletonList(quantizedVector))); + + output.updateQuantizedVector(packedBits); } /** diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index 8231a8aa2..beabd8d73 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -1,14 +1,11 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.knn.quantization.quantizer; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import 
org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + /** * The Quantizer interface defines the methods required for training and quantizing vectors * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. @@ -34,7 +31,7 @@ public interface Quantizer { * * @param vector the vector to quantize. * @param state the quantization state containing parameters for quantization. - * @return a QuantizationOutput containing the quantized representation of the vector. + * @param output the QuantizationOutput object to store the quantized representation of the vector. */ - QuantizationOutput quantize(T vector, QuantizationState state); + void quantize(T vector, QuantizationState state, QuantizationOutput output) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java similarity index 67% rename from src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java rename to src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index adc4e34c4..e95632606 100644 --- a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -1,40 +1,22 @@ /* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * */ -package org.opensearch.knn.quantization.util; +package org.opensearch.knn.quantization.quantizer; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import lombok.experimental.UtilityClass; +import java.util.BitSet; + /** * Utility class providing common methods for quantizer operations, such as parameter validation and * extraction. 
This class is designed to be used with various quantizer implementations that require * consistent handling of training requests and sampled indices. */ @UtilityClass -public class QuantizerHelper { - - /** - * Validates the provided training request to ensure it contains non-null quantization parameters. - * Extracts and returns the SQParams from the training request. - * - * @param trainingRequest the training request to validate and extract parameters from. - * @return the extracted SQParams. - * @throws IllegalArgumentException if the SQParams are null. - */ - public static SQParams validateAndExtractParams(TrainingRequest trainingRequest) { - QuantizationParams params = trainingRequest.getParams(); - if (params == null || !(params instanceof SQParams)) { - throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams."); - } - return (SQParams) params; - } - +class QuantizerHelper { /** * Calculates the mean vector from a set of sampled vectors. * @@ -49,13 +31,13 @@ public static SQParams validateAndExtractParams(TrainingRequest trainingReque * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. 
*/ - public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) { - int totalSamples = sampledIndices.length; + static float[] calculateMeanThresholds(TrainingRequest samplingRequest, BitSet sampledIndices) { + int totalSamples = sampledIndices.cardinality(); float[] mean = null; - for (int index : sampledIndices) { - float[] vector = samplingRequest.getVectorByDocId(index); + for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + float[] vector = samplingRequest.getVectorByDocId(docId); if (vector == null) { - throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } if (mean == null) { mean = new float[vector.length]; @@ -81,20 +63,20 @@ public static float[] calculateMean(TrainingRequest samplingRequest, in * @param meanArray the array to store the sum and then the mean of each dimension. * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension. 
*/ - public static void calculateSumMeanAndStdDev( + static void calculateSumMeanAndStdDev( TrainingRequest trainingRequest, - int[] sampledIndices, + BitSet sampledIndices, float[] meanArray, float[] stdDevArray ) { - int totalSamples = sampledIndices.length; + int totalSamples = sampledIndices.cardinality(); int dimension = meanArray.length; // Single pass to calculate sum and sum of squares - for (int index : sampledIndices) { - float[] vector = trainingRequest.getVectorByDocId(index); + for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + float[] vector = trainingRequest.getVectorByDocId(docId); if (vector == null) { - throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } for (int j = 0; j < dimension; j++) { meanArray[j] += vector[j]; diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java index da5327def..22a7bae23 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -5,10 +5,10 @@ package org.opensearch.knn.quantization.sampler; -import java.util.Arrays; -import java.util.Random; +import lombok.NoArgsConstructor; + +import java.util.BitSet; import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.IntStream; /** * ReservoirSampler implements the Sampler interface and provides a method for sampling @@ -16,33 +16,23 @@ * This algorithm is particularly useful for randomly sampling a subset of data from a larger set * when the total size of the dataset is unknown or very large. 
*/ +@NoArgsConstructor final class ReservoirSampler implements Sampler { - - private final Random random; - - /** - * Constructs a ReservoirSampler with a new Random instance. - */ - public ReservoirSampler() { - this(ThreadLocalRandom.current()); - } - /** - * Constructs a ReservoirSampler with a specified random seed for reproducibility. - * - * @param seed the seed for the random number generator. + * Singleton instance holder. */ - public ReservoirSampler(final long seed) { - this(new Random(seed)); - } + private static ReservoirSampler instance; /** - * Constructs a ReservoirSampler with a specified Random instance. + * Provides the singleton instance of ReservoirSampler. * - * @param random the Random instance for generating random numbers. + * @return the singleton instance of ReservoirSampler. */ - public ReservoirSampler(final Random random) { - this.random = random; + public static synchronized ReservoirSampler getInstance() { + if (instance == null) { + instance = new ReservoirSampler(); + } + return instance; } /** @@ -55,9 +45,11 @@ public ReservoirSampler(final Random random) { * @return an array of sampled indices. */ @Override - public int[] sample(final int totalNumberOfVectors, final int sampleSize) { + public BitSet sample(final int totalNumberOfVectors, final int sampleSize) { if (totalNumberOfVectors <= sampleSize) { - return IntStream.range(0, totalNumberOfVectors).toArray(); + BitSet bitSet = new BitSet(totalNumberOfVectors); + bitSet.set(0, totalNumberOfVectors); + return bitSet; } return reservoirSampleIndices(totalNumberOfVectors, sampleSize); } @@ -67,19 +59,30 @@ public int[] sample(final int totalNumberOfVectors, final int sampleSize) { * This method ensures that each index in the range [0, numVectors) has an equal probability * of being included in the sample. * + * Reservoir sampling is particularly useful for selecting a random sample from a large or unknown-sized dataset. 
+ * For more information on the algorithm, see the following link: + * Reservoir Sampling - Wikipedia + * * @param numVectors the total number of vectors. * @param sampleSize the number of indices to sample. - * @return an array of sampled indices. + * @return a BitSet representing the sampled indices. */ - private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) { - int[] indices = IntStream.range(0, sampleSize).toArray(); + private BitSet reservoirSampleIndices(final int numVectors, final int sampleSize) { + int[] indices = new int[sampleSize]; + for (int i = 0; i < sampleSize; i++) { + indices[i] = i; + } for (int i = sampleSize; i < numVectors; i++) { - int j = random.nextInt(i + 1); + int j = ThreadLocalRandom.current().nextInt(i + 1); if (j < sampleSize) { indices[j] = i; } } - Arrays.sort(indices); - return indices; + // Using BitSet to track the presence of indices + BitSet bitSet = new BitSet(numVectors); + for (int i = 0; i < sampleSize; i++) { + bitSet.set(indices[i]); + } + return bitSet; } } diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java index 9021073b4..17834cf04 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -5,6 +5,23 @@ package org.opensearch.knn.quantization.sampler; +import java.util.BitSet; + +/** + * The Sampler interface defines the contract for sampling strategies + * used in various quantization processes. Implementations of this + * interface should provide specific strategies for selecting a sample + * from a given set of vectors. + */ public interface Sampler { - int[] sample(int totalNumberOfVectors, int sampleSize); + + /** + * Samples a subset of indices from the total number of vectors. + * + * @param totalNumberOfVectors the total number of vectors available. 
+ * @param sampleSize the number of vectors to be sampled. + * @return a BitSet whose set bits mark the indices of the sampled vectors. + * @throws IllegalArgumentException if the sample size is greater than the total number of vectors. + */ + BitSet sample(int totalNumberOfVectors, int sampleSize); } diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java new file mode 100644 index 000000000..cd9b301df --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +/** + * SamplerType is an enumeration of the different types of samplers that can be created by the factory. + */ +public enum SamplerType { + RESERVOIR, // Represents a reservoir sampling strategy + // Add more enum values here for additional sampler types +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java index be228fe6f..80fe5bdae 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java @@ -5,28 +5,16 @@ package org.opensearch.knn.quantization.sampler; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + /** * SamplingFactory is a factory class for creating instances of Sampler. * It uses the factory design pattern to encapsulate the creation logic for different types of samplers. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) public final class SamplingFactory { - /** - * Private constructor to prevent instantiation of this class. - * The class is not meant to be instantiated, as it provides static methods only.
- */ - private SamplingFactory() { - - } - - /** - * SamplerType is an enumeration of the different types of samplers that can be created by the factory. - */ - public enum SamplerType { - RESERVOIR, // Represents a reservoir sampling strategy - // Add more enum values here for additional sampler types - } - /** * Creates and returns a Sampler instance based on the specified SamplerType. * @@ -37,7 +25,7 @@ public enum SamplerType { public static Sampler getSampler(final SamplerType samplerType) { switch (samplerType) { case RESERVOIR: - return new ReservoirSampler(); + return ReservoirSampler.getInstance(); // Add more cases for different samplers here default: throw new IllegalArgumentException("Unsupported sampler type: " + samplerType); diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java deleted file mode 100644 index 5d99a892f..000000000 --- a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.knn.quantization.util; - -import lombok.experimental.UtilityClass; - -import java.util.List; - -/** - * Utility class for bit packing operations. - * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission. - */ -@UtilityClass -public class BitPacker { - - /** - * Packs the list of bit arrays into a single byte array. - * Each byte in the resulting array contains up to 8 bits from the bit arrays, packed from left to right. - * - * @param bitArrays the list of bit arrays to be packed. Each bit array should contain only 0s and 1s. - * @return a byte array containing the packed bits. - * @throws IllegalArgumentException if the bitArrays list is empty, if any bit array is null, or if bit arrays have inconsistent lengths. 
- */ - public static byte[] packBits(List bitArrays) { - if (bitArrays.isEmpty()) { - throw new IllegalArgumentException("The list of bit arrays cannot be empty."); - } - - int bitArrayLength = bitArrays.get(0).length; - int bitLength = bitArrays.size() * bitArrayLength; - int byteLength = (bitLength + 7) / 8; - byte[] packedArray = new byte[byteLength]; - - int bitPosition = 0; - for (byte[] bitArray : bitArrays) { - if (bitArray == null) { - throw new IllegalArgumentException("Bit array cannot be null."); - } - if (bitArray.length != bitArrayLength) { - throw new IllegalArgumentException("All bit arrays must have the same length."); - } - - for (byte bit : bitArray) { - int byteIndex = bitPosition / 8; - int bitIndex = 7 - (bitPosition % 8); - if (bit == 1) { - packedArray[byteIndex] |= (1 << bitIndex); - } - bitPosition++; - } - } - - return packedArray; - } -} diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java deleted file mode 100644 index 89b3b67bd..000000000 --- a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.util; - -import lombok.experimental.UtilityClass; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; - -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; -import java.io.IOException; -import java.io.ByteArrayInputStream; -import java.io.ObjectInputStream; - -/** - * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing - * QuantizationState objects along with their specific data. 
- */ -@UtilityClass -public class QuantizationStateSerializer { - - /** - * A functional interface for deserializing specific data associated with a QuantizationState. - */ - @FunctionalInterface - public interface SerializableDeserializer { - QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); - } - - /** - * Serializes the QuantizationState and specific data into a byte array. - * - * @param state The QuantizationState to serialize. - * @param specificData The specific data related to the state, to be serialized. - * @return A byte array representing the serialized state and specific data. - * @throws IOException If an I/O error occurs during serialization. - */ - public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { - byte[] parentBytes = serializeParentParams(state.getQuantizationParams()); - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeInt(parentBytes.length); // Write the length of the parent bytes - out.write(parentBytes); // Write the parent bytes - out.writeObject(specificData); // Write the specific data - out.flush(); - return bos.toByteArray(); - } - } - - /** - * Deserializes a QuantizationState and its specific data from a byte array. - * - * @param bytes The byte array containing the serialized data. - * @param specificDataDeserializer The deserializer for the specific data associated with the state. - * @return The deserialized QuantizationState including its specific data. - * @throws IOException If an I/O error occurs during deserialization. - * @throws ClassNotFoundException If the class of the serialized object cannot be found. 
- */ - public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer) throws IOException, - ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { - int parentLength = in.readInt(); - // Read the length of the parent bytes - byte[] parentBytes = new byte[parentLength]; - in.readFully(parentBytes); // Read the parent bytes - QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params - Serializable specificData = (Serializable) in.readObject(); // Read the specific data - return specificDataDeserializer.deserialize(parentParams, specificData); - } - } - - /** - * Serializes the parent parameters of the QuantizationState into a byte array. - * - * @param params The QuantizationParams to serialize. - * @return A byte array representing the serialized parent parameters. - * @throws IOException If an I/O error occurs during serialization. - */ - private static byte[] serializeParentParams(QuantizationParams params) throws IOException { - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeObject(params); - out.flush(); - return bos.toByteArray(); - } - } - - /** - * Deserializes the parent parameters of the QuantizationState from a byte array. - * - * @param bytes The byte array containing the serialized parent parameters. - * @return The deserialized QuantizationParams. - * @throws IOException If an I/O error occurs during deserialization. - * @throws ClassNotFoundException If the class of the serialized object cannot be found. 
- */ - private static QuantizationParams deserializeParentParams(byte[] bytes) throws IOException, ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { - return (QuantizationParams) in.readObject(); - } - } -} diff --git a/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java b/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java new file mode 100644 index 000000000..7746305ab --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.util; + +import lombok.experimental.UtilityClass; + +/** + * Utility class to manage version information in a thread-safe manner using ThreadLocal storage. + * This class ensures that version information is available within the current thread context. + */ +@UtilityClass +public class VersionContext { + + /** + * ThreadLocal storage for version information. + * This allows each thread to have its own version information without interference. + */ + private final ThreadLocal versionHolder = new ThreadLocal<>(); + + /** + * Sets the version for the current thread. + * + * @param version the version to be set. + */ + public void setVersion(int version) { + versionHolder.set(version); + } + + /** + * Gets the version for the current thread. + * + * @return the version for the current thread. + */ + public int getVersion() { + return versionHolder.get(); + } + + /** + * Clears the version for the current thread. 
+ */ + public void clear() { + versionHolder.remove(); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java deleted file mode 100644 index 598a07867..000000000 --- a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -import org.opensearch.knn.KNNTestCase; - -public class QuantizationTypeTests extends KNNTestCase { - - public void testQuantizationTypeValues() { - QuantizationType[] expectedValues = { QuantizationType.SPACE, QuantizationType.VALUE }; - assertArrayEquals(expectedValues, QuantizationType.values()); - } - - public void testQuantizationTypeValueOf() { - assertEquals(QuantizationType.SPACE, QuantizationType.valueOf("SPACE")); - assertEquals(QuantizationType.VALUE, QuantizationType.valueOf("VALUE")); - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java similarity index 74% rename from src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java rename to src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java index 3498114a6..815f81071 100644 --- a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java +++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java @@ -7,13 +7,12 @@ import org.opensearch.knn.KNNTestCase; -public class SQTypesTests extends KNNTestCase { +public class ScalarQuantizationTypeTests extends KNNTestCase { public void testSQTypesValues() { ScalarQuantizationType[] expectedValues = { ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.TWO_BIT, - ScalarQuantizationType.FOUR_BIT, - 
ScalarQuantizationType.UNSUPPORTED_TYPE }; + ScalarQuantizationType.FOUR_BIT }; assertArrayEquals(expectedValues, ScalarQuantizationType.values()); } @@ -21,6 +20,5 @@ public void testSQTypesValueOf() { assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT")); assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT")); assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT")); - assertEquals(ScalarQuantizationType.UNSUPPORTED_TYPE, ScalarQuantizationType.valueOf("UNSUPPORTED_TYPE")); } } diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java deleted file mode 100644 index 3da665630..000000000 --- a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -import org.opensearch.knn.KNNTestCase; - -public class ValueQuantizationTypeTests extends KNNTestCase { - public void testValueQuantizationTypeValues() { - ValueQuantizationType[] expectedValues = { ValueQuantizationType.SCALAR }; - assertArrayEquals(expectedValues, ValueQuantizationType.values()); - } - - public void testValueQuantizationTypeValueOf() { - assertEquals(ValueQuantizationType.SCALAR, ValueQuantizationType.valueOf("SCALAR")); - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java index 42fc18eba..34aa907b8 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -8,7 +8,7 @@ import org.junit.Before; import org.opensearch.knn.KNNTestCase; 
import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -27,7 +27,7 @@ public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessEx } public void test_Lazy_Registration() { - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); assertFalse(isRegisteredFieldAccessible()); Quantizer quantizer = QuantizerFactory.getQuantizer(params); assertTrue(quantizer instanceof OneBitScalarQuantizer); @@ -35,33 +35,23 @@ public void test_Lazy_Registration() { } public void testGetQuantizer_withOneBitSQParams() { - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); Quantizer quantizer = QuantizerFactory.getQuantizer(params); assertTrue(quantizer instanceof OneBitScalarQuantizer); } public void testGetQuantizer_withTwoBitSQParams() { - SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); Quantizer quantizer = QuantizerFactory.getQuantizer(params); assertTrue(quantizer instanceof MultiBitScalarQuantizer); } public void testGetQuantizer_withFourBitSQParams() { - SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); Quantizer quantizer = QuantizerFactory.getQuantizer(params); assertTrue(quantizer instanceof MultiBitScalarQuantizer); } - public void 
testGetQuantizer_withUnsupportedType() { - SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); - try { - QuantizerFactory.getQuantizer(params); - fail("Expected IllegalArgumentException"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("No quantizer registered for type identifier")); - } - } - public void testGetQuantizer_withNullParams() { try { QuantizerFactory.getQuantizer(null); @@ -73,7 +63,7 @@ public void testGetQuantizer_withNullParams() { public void testConcurrentRegistration() throws InterruptedException { Runnable task = () -> { - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); QuantizerFactory.getQuantizer(params); }; diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index 7f53dae8c..2d743f883 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -7,9 +7,8 @@ import org.junit.BeforeClass; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.quantization.enums.QuantizationType; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -18,49 +17,56 @@ public class QuantizerRegistryTests extends KNNTestCase { @BeforeClass public static void setup() { - // Register the quantizers for testing with enums - 
QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new); QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.TWO_BIT, - () -> new MultiBitScalarQuantizer(2) + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() ); QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.FOUR_BIT, - () -> new MultiBitScalarQuantizer(4) + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), + new MultiBitScalarQuantizer(2) + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), + new MultiBitScalarQuantizer(4) ); } public void testRegisterAndGetQuantizer() { // Test for OneBitScalarQuantizer - SQParams oneBitParams = new SQParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); Quantizer oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer); // Test for MultiBitScalarQuantizer (2-bit) - SQParams twoBitParams = new SQParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); Quantizer twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer); + assertEquals(2, ((MultiBitScalarQuantizer) twoBitQuantizer).getBitsPerCoordinate()); // Test for MultiBitScalarQuantizer (4-bit) - SQParams fourBitParams = new SQParams(ScalarQuantizationType.FOUR_BIT); + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); assertTrue(fourBitQuantizer instanceof 
MultiBitScalarQuantizer); + assertEquals(4, ((MultiBitScalarQuantizer) fourBitQuantizer).getBitsPerCoordinate()); } - public void testGetQuantizer_withUnsupportedTypeIdentifier() { - // Create SQParams with an unsupported type identifier - SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered + public void testQuantizerRegistryIsSingleton() { + // Ensure the same instance is returned for the same type identifier + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Quantizer firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + Quantizer secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + assertSame(firstOneBitQuantizer, secondOneBitQuantizer); - // Expect IllegalArgumentException when requesting a quantizer with unsupported params - IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, - () -> { QuantizerRegistry.getQuantizer(params); } - ); + // Ensure the same instance is returned for the same type identifier (2-bit) + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + Quantizer firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + Quantizer secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + assertSame(firstTwoBitQuantizer, secondTwoBitQuantizer); - assertTrue(exception.getMessage().contains("No quantizer registered for type identifier")); + // Ensure the same instance is returned for the same type identifier (4-bit) + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + Quantizer firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + Quantizer secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + assertSame(firstFourBitQuantizer, secondFourBitQuantizer); } } diff --git 
a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java index ebe6bf6bd..974f58637 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -7,7 +7,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; @@ -16,7 +16,7 @@ public class QuantizationStateSerializerTests extends KNNTestCase { public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException { - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = new float[] { 0.1f, 0.2f, 0.3f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); @@ -28,14 +28,16 @@ public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IO } public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException { - SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } }; MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, 
thresholds); byte[] serialized = state.toByteArray(); MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); - assertArrayEquals(thresholds, deserialized.getThresholds()); + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f); + } assertEquals(params, deserialized.getQuantizationParams()); } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 54e304732..5a6b4b1db 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -8,53 +8,76 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.Version; +import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; public class QuantizationStateTests extends KNNTestCase { public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + // 
Set the version for serialization + VersionContext.setVersion(Version.CURRENT.id); + + // Serialize byte[] serializedState = state.toByteArray(); + + // Deserialize OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); - float delta = 0.0001f; + float delta = 0.0001f; assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); - assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { - SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } }; MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + // Set the version for serialization + VersionContext.setVersion(Version.CURRENT.id); + + // Serialize byte[] serializedState = state.toByteArray(); + + // Deserialize MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState); - float delta = 0.0001f; + float delta = 0.0001f; for (int i = 0; i < thresholds.length; i++) { assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i], delta); } - assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } - public void testDefaultQuantizationStateSerialization() throws IOException, ClassNotFoundException { - SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); + public void testSerializationWithDifferentVersions() throws IOException, ClassNotFoundException { + 
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); - DefaultQuantizationState state = new DefaultQuantizationState(params); + // Simulate an older version + VersionContext.setVersion(Version.V_2_0_0.id); + // Serialize byte[] serializedState = state.toByteArray(); - DefaultQuantizationState deserializedState = DefaultQuantizationState.fromByteArray(serializedState); - assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + // Update to a new version and deserialize + VersionContext.setVersion(Version.CURRENT.id); + OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + + float delta = 0.0001f; + assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java index 231da3dfe..ad6a44686 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -7,12 +7,14 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import 
org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + public class MultiBitScalarQuantizerTests extends KNNTestCase { public void testTrain_twoBit() { @@ -21,10 +23,8 @@ public void testTrain_twoBit() { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); - int[] sampledIndices = { 0, 1, 2 }; - SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); TrainingRequest request = new MockTrainingRequest(params, vectors); - request.setSampledIndices(sampledIndices); QuantizationState state = twoBitQuantizer.train(request); assertTrue(state instanceof MultiBitScalarQuantizationState); @@ -39,10 +39,8 @@ public void testTrain_fourBit() { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; - int[] sampledIndices = { 0, 1, 2 }; - SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); TrainingRequest request = new MockTrainingRequest(params, vectors); - request.setSampledIndices(sampledIndices); QuantizationState state = fourBitQuantizer.train(request); assertTrue(state instanceof MultiBitScalarQuantizationState); @@ -51,19 +49,19 @@ public void testTrain_fourBit() { assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds } - public void testQuantize_twoBit() { + public void testQuantize_twoBit() throws IOException { MultiBitScalarQuantizer twoBitQuantizer = new 
MultiBitScalarQuantizer(2); float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; float[][] thresholds = { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; - SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); - QuantizationOutput output = twoBitQuantizer.quantize(vector, state); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + twoBitQuantizer.quantize(vector, state, output); assertNotNull(output.getQuantizedVector()); - assertEquals(2, output.getQuantizedVector().length); } - public void testQuantize_fourBit() { + public void testQuantize_fourBit() throws IOException { MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; float[][] thresholds = { @@ -71,38 +69,33 @@ public void testQuantize_fourBit() { { 1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f }, { 1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f }, { 1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f } }; - SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); - QuantizationOutput output = fourBitQuantizer.quantize(vector, state); - assertEquals(4, output.getQuantizedVector().length); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + fourBitQuantizer.quantize(vector, state, output); assertNotNull(output.getQuantizedVector()); } - public void testQuantize_withNullVector() { + public void testQuantize_withNullVector() throws IOException { MultiBitScalarQuantizer twoBitQuantizer = new 
MultiBitScalarQuantizer(2); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); expectThrows( IllegalArgumentException.class, () -> twoBitQuantizer.quantize( null, - new MultiBitScalarQuantizationState(new SQParams(ScalarQuantizationType.TWO_BIT), new float[2][8]) + new MultiBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT), new float[2][8]), + output ) ); } - public void testQuantize_withInvalidState() { - MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); - float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; - QuantizationState invalidState = new MockInvalidQuantizationState(); - expectThrows(IllegalArgumentException.class, () -> twoBitQuantizer.quantize(vector, invalidState)); - } - // Mock classes for testing private static class MockTrainingRequest extends TrainingRequest { private final float[][] vectors; - public MockTrainingRequest(SQParams params, float[][] vectors) { - super(params, vectors.length); + public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) { + super(vectors.length); this.vectors = vectors; } @@ -111,16 +104,4 @@ public float[] getVectorByDocId(int docId) { return vectors[docId]; } } - - private static class MockInvalidQuantizationState implements QuantizationState { - @Override - public SQParams getQuantizationParams() { - return new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); - } - - @Override - public byte[] toByteArray() { - return new byte[0]; - } - } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index 12e8a43a8..8372ac5d2 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -7,22 +7,27 @@ import 
org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; -import org.opensearch.knn.quantization.util.QuantizerHelper; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.BitSet; public class OneBitScalarQuantizerTests extends KNNTestCase { public void testTrain_withTrainingRequired() { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); - TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(vectors.length) { @Override public float[] getVectorByDocId(int docId) { return vectors[docId]; @@ -32,88 +37,107 @@ public float[] getVectorByDocId(int docId) { QuantizationState state = quantizer.train(originalRequest); assertTrue(state instanceof OneBitScalarQuantizationState); - float[] mean = ((OneBitScalarQuantizationState) state).getMeanThresholds(); - assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f); + float[] 
meanThresholds = ((OneBitScalarQuantizationState) state).getMeanThresholds(); + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); } - public void testQuantize_withState() { + public void testQuantize_withState() throws IOException { float[] vector = { 3.0f, 6.0f, 9.0f }; float[] thresholds = { 4.0f, 5.0f, 6.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + thresholds + ); OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); - QuantizationOutput output = quantizer.quantize(vector, state); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + quantizer.quantize(vector, state, output); assertNotNull(output); byte[] expectedPackedBits = new byte[] { 0b01100000 }; // or 96 in decimal assertArrayEquals(expectedPackedBits, output.getQuantizedVector()); } - public void testQuantize_withNullVector() { + public void testQuantize_withNullVector() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( - new SQParams(ScalarQuantizationType.ONE_BIT), + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), new float[] { 0.0f } ); - expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state)); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state, output)); } - public void testQuantize_withInvalidState() { + public void testQuantize_withInvalidState() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); float[] vector = { 1.0f, 2.0f, 3.0f }; QuantizationState invalidState = new QuantizationState() { @Override - public SQParams 
getQuantizationParams() { - return new SQParams(ScalarQuantizationType.ONE_BIT); + public ScalarQuantizationParams getQuantizationParams() { + return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); } @Override public byte[] toByteArray() { return new byte[0]; } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + // no-op + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // no-op + } }; - expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState)); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState, output)); } - public void testQuantize_withMismatchedDimensions() { + public void testQuantize_withMismatchedDimensions() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); float[] vector = { 1.0f, 2.0f, 3.0f }; float[] thresholds = { 4.0f, 5.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds); - - expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state)); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + thresholds + ); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output)); } public void testCalculateMean() { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); - TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + TrainingRequest samplingRequest = new 
TrainingRequest(vectors.length) { @Override public float[] getVectorByDocId(int docId) { return vectors[docId]; } }; - Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); - int[] sampledIndices = sampler.sample(vectors.length, 3); - float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices); - assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f); + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + BitSet sampledIndices = sampler.sample(vectors.length, 3); + float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices); + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); } public void testCalculateMean_withNullVector() { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, null, { 7.0f, 8.0f, 9.0f } }; - SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); - TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override public float[] getVectorByDocId(int docId) { return vectors[docId]; } }; - Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); - int[] sampledIndices = sampler.sample(vectors.length, 3); - expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMean(samplingRequest, sampledIndices)); + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + BitSet sampledIndices = sampler.sample(vectors.length, 3); + expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices)); } } diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java index 4d3345289..e930aef04 100644 --- 
a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -7,76 +7,55 @@ import org.opensearch.knn.KNNTestCase; -import java.util.Arrays; -import java.util.stream.IntStream; +import java.util.BitSet; public class ReservoirSamplerTests extends KNNTestCase { public void testSampleLessThanSampleSize() { - ReservoirSampler sampler = new ReservoirSampler(); + ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 5; int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); - assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + BitSet expectedIndices = new BitSet(totalNumberOfVectors); + expectedIndices.set(0, totalNumberOfVectors); + assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleEqualToSampleSize() { - ReservoirSampler sampler = new ReservoirSampler(); + ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 10; int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); - assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); - } - - public void testSampleGreaterThanSampleSize() { - ReservoirSampler sampler = new ReservoirSampler(12345); // Fixed seed for reproducibility - int totalNumberOfVectors = 100; - int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(sampleSize, sampledIndices.length); - assertTrue(Arrays.stream(sampledIndices).allMatch(i 
-> i >= 0 && i < totalNumberOfVectors)); - } - - public void testSampleReproducibility() { - long seed = 12345L; - ReservoirSampler sampler1 = new ReservoirSampler(seed); - ReservoirSampler sampler2 = new ReservoirSampler(seed); - int totalNumberOfVectors = 100; - int sampleSize = 10; - - int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); - int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); - - assertArrayEquals(sampledIndices1, sampledIndices2); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + BitSet expectedIndices = new BitSet(totalNumberOfVectors); + expectedIndices.set(0, totalNumberOfVectors); + assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleRandomness() { - ReservoirSampler sampler1 = new ReservoirSampler(); - ReservoirSampler sampler2 = new ReservoirSampler(); + ReservoirSampler sampler1 = ReservoirSampler.getInstance(); + ReservoirSampler sampler2 = ReservoirSampler.getInstance(); int totalNumberOfVectors = 100; int sampleSize = 10; - int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); - int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); + BitSet sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); + BitSet sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); - assertNotEquals(Arrays.toString(sampledIndices1), Arrays.toString(sampledIndices2)); + assertNotEquals(sampledIndices1, sampledIndices2); } public void testEdgeCaseZeroVectors() { - ReservoirSampler sampler = new ReservoirSampler(); + ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 0; int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.length); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(0, 
sampledIndices.cardinality()); } public void testEdgeCaseZeroSampleSize() { - ReservoirSampler sampler = new ReservoirSampler(); + ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 10; int sampleSize = 0; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.length); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(0, sampledIndices.cardinality()); } } diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java index ca72c1c5e..db8772b70 100644 --- a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java @@ -9,7 +9,7 @@ public class SamplingFactoryTests extends KNNTestCase { public void testGetSampler_withReservoir() { - Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); assertTrue(sampler instanceof ReservoirSampler); } diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java deleted file mode 100644 index c91c7177b..000000000 --- a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.util; - -import org.opensearch.knn.KNNTestCase; - -import java.util.Arrays; -import java.util.List; - -public class BitPackingUtilsTests extends KNNTestCase { - - public void testPackBits() { - List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, new byte[] { 1, 0, 1, 0, 0, 1, 0, 0 }); - - byte[] expectedPackedArray = new 
byte[] { (byte) 0b01011011, (byte) 0b10100100 }; - byte[] packedArray = BitPacker.packBits(bitArrays); - - assertArrayEquals(expectedPackedArray, packedArray); - } - - public void testPackBitsEmptyList() { - IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(Arrays.asList()); }); - assertEquals("The list of bit arrays cannot be empty.", exception.getMessage()); - } - - public void testPackBitsNullBitArray() { - List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, null); - - IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); }); - assertEquals("Bit array cannot be null.", exception.getMessage()); - } - - public void testPackBitsInconsistentLength() { - List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, new byte[] { 1, 0, 1 }); - - IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); }); - assertEquals("All bit arrays must have the same length.", exception.getMessage()); - } - - public void testPackBitsEdgeCaseSingleBitArray() { - List bitArrays = Arrays.asList(new byte[] { 1 }); - - byte[] expectedPackedArray = new byte[] { (byte) 0b10000000 }; - byte[] packedArray = BitPacker.packBits(bitArrays); - - assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray); - } - - public void testPackBitsEdgeCaseSingleBit() { - List bitArrays = Arrays.asList(new byte[] { 1, 0, 1, 0, 1, 0, 1, 0 }, new byte[] { 1, 1, 1, 1, 1, 1, 1, 1 }); - - byte[] expectedPackedArray = new byte[] { (byte) 0b10101010, (byte) 0b11111111 }; - byte[] packedArray = BitPacker.packBits(bitArrays); - - assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray); - } - - public void testPackBits_emptyArray() { - List bitArrays = Arrays.asList(); - expectThrows(IllegalArgumentException.class, () -> 
{ BitPacker.packBits(bitArrays); }); - ; - } -}