From bc2879b5bb9c5c4e4c7e05911b185bc6f9d457b1 Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Thu, 8 Aug 2024 13:18:10 -0700 Subject: [PATCH] Implemented Serialization using Writable Signed-off-by: VIKASH TIWARI --- .../factory/QuantizerRegistrar.java | 2 +- .../factory/QuantizerRegistry.java | 7 +- .../BinaryQuantizationOutput.java | 62 ++++++-- .../QuantizationOutput.java | 8 +- .../QuantizationParams.java | 4 +- .../ScalarQuantizationParams.java | 75 +++------ .../DefaultQuantizationState.java | 55 ++----- .../MultiBitScalarQuantizationState.java | 83 ++++------ .../OneBitScalarQuantizationState.java | 58 ++----- .../quantizationState/QuantizationState.java | 4 +- .../QuantizationStateSerializer.java | 34 ++--- .../knn/quantization/quantizer/BitPacker.java | 143 ++++++++++++++++++ .../quantizer/MultiBitScalarQuantizer.java | 73 ++++++--- .../quantizer/OneBitScalarQuantizer.java | 20 +-- .../quantizer/QuantizerHelper.java | 29 ++-- .../sampler/ReservoirSampler.java | 27 ++-- .../knn/quantization/sampler/Sampler.java | 2 +- .../knn/quantization/util/VersionContext.java | 47 ------ .../enums/ScalarQuantizationTypeTests.java | 11 ++ .../factory/QuantizerFactoryTests.java | 41 +---- .../factory/QuantizerRegistryTests.java | 12 ++ .../QuantizationStateSerializerTests.java | 17 ++- .../QuantizationStateTests.java | 35 ++--- .../quantizer/OneBitScalarQuantizerTests.java | 16 +- .../sampler/ReservoirSamplerTests.java | 36 ++--- 25 files changed, 459 insertions(+), 442 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java delete mode 100644 src/main/java/org/opensearch/knn/quantization/util/VersionContext.java diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java index 889130c9a..7b542aea0 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java +++ 
b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java @@ -20,7 +20,7 @@ final class QuantizerRegistrar { /** - * Registers default quantizers if not already registered. + * Registers default quantizers *

* This method is synchronized to ensure that registration occurs only once, * even in a multi-threaded environment. diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java index 50e1158c2..ac266f547 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -30,8 +30,11 @@ final class QuantizerRegistry { * @param quantizer an instance of the quantizer */ static void register(final String paramIdentifier, final Quantizer quantizer) { - // Ensure that the quantizer for this identifier is registered only once - registry.putIfAbsent(paramIdentifier, quantizer); + // Check if the quantizer is already registered for the given identifier + if (registry.putIfAbsent(paramIdentifier, quantizer) != null) { + // Throw an exception if a quantizer is already registered + throw new IllegalArgumentException("Quantizer already registered for identifier: " + paramIdentifier); + } } /** diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java index dbf8e5bf9..206f4e951 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -5,11 +5,10 @@ package org.opensearch.knn.quantization.models.quantizationOutput; -import lombok.NoArgsConstructor; import lombok.Getter; +import lombok.NoArgsConstructor; -import java.io.ByteArrayOutputStream; -import java.io.IOException; +import java.util.Arrays; /** * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. 
@@ -18,23 +17,62 @@ @NoArgsConstructor public class BinaryQuantizationOutput implements QuantizationOutput { @Getter - private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + private byte[] quantizedVector; /** - * Updates the quantized vector with a new byte array. + * Prepares the quantized vector array based on the provided parameters and returns it for direct modification. + * This method ensures that the internal byte array is appropriately sized and cleared before being used. + * + *

+ * The method accepts two parameters: + *

+ *

+ * + *

+ * If the existing quantized vector is either null or not the same size as the required byte array, + * a new byte array is allocated. Otherwise, the existing array is cleared (i.e., all bytes are set to zero). + *

* - * @param newQuantizedVector the new quantized vector represented as a byte array. + *

+ * This method is designed to be used in conjunction with a bit-packing utility that writes quantized values directly + * into the returned byte array. + *

+ * + * @param params an array of parameters, where the first parameter is the number of bits per coordinate (int), + * and the second parameter is the length of the vector (int). + * @return the prepared and writable quantized vector as a byte array. + * @throws IllegalArgumentException if the parameters are not as expected (e.g., missing or not integers). */ - public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException { - if (newQuantizedVector == null || newQuantizedVector.length == 0) { - throw new IllegalArgumentException("Quantized vector cannot be null or empty"); + @Override + public byte[] prepareAndGetWritableQuantizedVector(Object... params) { + if (params.length != 2 || !(params[0] instanceof Integer) || !(params[1] instanceof Integer)) { + throw new IllegalArgumentException("Expected two integer parameters: bitsPerCoordinate and vectorLength"); } - byteArrayOutputStream.reset(); - byteArrayOutputStream.write(newQuantizedVector); + int bitsPerCoordinate = (int) params[0]; + int vectorLength = (int) params[1]; + int totalBits = bitsPerCoordinate * vectorLength; + int byteLength = (totalBits + 7) >> 3; + + if (this.quantizedVector == null || this.quantizedVector.length != byteLength) { + this.quantizedVector = new byte[byteLength]; + } else { + Arrays.fill(this.quantizedVector, (byte) 0); + } + + return this.quantizedVector; } + + /** + * Returns the quantized vector. + * + * @return the quantized vector byte array. 
+ */ @Override public byte[] getQuantizedVector() { - return byteArrayOutputStream.toByteArray(); + return quantizedVector; } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java index 8f01a0594..c2dd12cc8 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -21,10 +21,10 @@ public interface QuantizationOutput { T getQuantizedVector(); /** - * Updates the quantized vector with new data. + * Prepares and returns the writable quantized vector for direct modification. * - * @param newQuantizedVector the new quantized vector data. - * @throws IOException if an I/O error occurs during the update. + * @param params the parameters needed for preparing the quantized vector. + * @return the prepared and writable quantized vector. */ - void updateQuantizedVector(T newQuantizedVector) throws IOException; + T prepareAndGetWritableQuantizedVector(Object... params); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java index 88b22c4fd..4f2ee36c5 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -5,7 +5,7 @@ package org.opensearch.knn.quantization.models.quantizationParams; -import java.io.Externalizable; +import org.opensearch.core.common.io.stream.Writeable; /** * Interface for quantization parameters. @@ -14,7 +14,7 @@ * Implementations of this interface are expected to provide specific configurations * for various quantization strategies. 
*/ -public interface QuantizationParams extends Externalizable { +public interface QuantizationParams extends Writeable { /** * Provides a unique identifier for the quantization parameters. * This identifier is typically a combination of the quantization type diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java index c7c24062d..d602fb577 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java @@ -9,15 +9,15 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; import java.util.Locale; /** - * The SQParams class represents the parameters specific to scalar quantization (SQ). + * The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ). * This class implements the QuantizationParams interface and includes the type of scalar quantization. */ @Getter @@ -39,67 +39,40 @@ public static String generateTypeIdentifier(ScalarQuantizationType sqType) { } /** - * Serializes the SQParams object to an external output. - * This method writes the scalar quantization type to the output stream. + * Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type. + * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. * - * @param out the ObjectOutput to write the object to. 
- * @throws IOException if an I/O error occurs during serialization. + * @return A string representing the unique type identifier. */ @Override - public void writeExternal(ObjectOutput out) throws IOException { - // The version is already written by the parent state class, no need to write it here again - // Retrieve the current version from VersionContext - // This context will be used by other classes involved in the serialization process. - // Example: - // int version = VersionContext.getVersion(); // Get the current version from VersionContext - // Any Version Specific logic can be wriiten based on Version - out.writeObject(sqType); + public String getTypeIdentifier() { + return generateIdentifier(sqType.getId()); + } + + private static String generateIdentifier(int id) { + return "ScalarQuantizationParams_" + id; } /** - * Deserializes the SQParams object from an external input with versioning. - * This method reads the scalar quantization type and new field from the input stream based on the version. + * Writes the object to the output stream. + * This method is part of the Writeable interface and is used to serialize the object. * - * @param in the ObjectInput to read the object from. - * @throws IOException if an I/O error occurs during deserialization. - * @throws ClassNotFoundException if the class of the serialized object cannot be found. + * @param out the output stream to write the object to. + * @throws IOException if an I/O error occurs. 
*/ @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - // The version is already read by the parent state class and set in VersionContext - // Retrieve the current version from VersionContext to handle version-specific deserialization logic - // int versionId = VersionContext.getVersion(); - // Version version = Version.fromId(versionId); - - sqType = (ScalarQuantizationType) in.readObject(); - - // Add version-specific deserialization logic - // For example, if new fields are added in a future version, handle them here - // This section contains conditional logic to handle different versions appropriately. - // Example: - // if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) { - // // Handle logic for versions between 1.0.0 and 2.0.0 - // // Example: Read additional fields introduced in version 1.0.0 - // // newField = in.readInt(); - // } else if (version.onOrAfter(Version.V_2_0_0)) { - // // Handle logic for versions 2.0.0 and above - // // Example: Read additional fields introduced in version 2.0.0 - // // anotherNewField = in.readFloat(); - // } + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(sqType); } /** - * Provides a unique type identifier for the SQParams, combining the SQ type. - * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. + * Reads the object from the input stream. + * This method is part of the Writeable interface and is used to deserialize the object. * - * @return A string representing the unique type identifier. + * @param in the input stream to read the object from. + * @throws IOException if an I/O error occurs. 
*/ - @Override - public String getTypeIdentifier() { - return generateIdentifier(sqType.getId()); - } - - private static String generateIdentifier(int id) { - return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id); + public ScalarQuantizationParams(StreamInput in, int version) throws IOException { + this.sqType = in.readEnum(ScalarQuantizationType.class); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java index 3e3249c6f..33e775cad 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -9,12 +9,12 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; /** * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. @@ -25,18 +25,23 @@ @AllArgsConstructor public class DefaultQuantizationState implements QuantizationState { private QuantizationParams params; - private static final long serialVersionUID = 1L; // Version ID for serialization - /** - * Returns the quantization parameters associated with this state. - * - * @return the quantization parameters. 
- */ @Override public QuantizationParams getQuantizationParams() { return params; } + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + params.writeTo(out); + } + + public DefaultQuantizationState(StreamInput in) throws IOException { + int version = in.readInt(); // Read the version + this.params = new ScalarQuantizationParams(in, version); + } + /** * Serializes the quantization state to a byte array. * @@ -45,7 +50,7 @@ public QuantizationParams getQuantizationParams() { */ @Override public byte[] toByteArray() throws IOException { - return QuantizationStateSerializer.serialize(this, null); + return QuantizationStateSerializer.serialize(this); } /** @@ -57,36 +62,6 @@ public byte[] toByteArray() throws IOException { * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { - return (DefaultQuantizationState) QuantizationStateSerializer.deserialize( - bytes, - new DefaultQuantizationState(), - (parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams) - ); - } - - /** - * Writes the object to the output stream. - * This method is part of the Externalizable interface and is used to serialize the object. - * - * @param out the output stream to write the object to. - * @throws IOException if an I/O error occurs. - */ - @Override - public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(Version.CURRENT.id); // Write the version - out.writeObject(params); - } - - /** - * Reads the object from the input stream. - * This method is part of the Externalizable interface and is used to deserialize the object. - * - * @param in the input stream to read the object from. - * @throws IOException if an I/O error occurs. 
- * @throws ClassNotFoundException if the class of the serialized object cannot be found. - */ - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.params = (QuantizationParams) in.readObject(); + return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 095d245f2..935b4875f 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -9,12 +9,11 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; -import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; /** * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, @@ -42,7 +41,6 @@ public final class MultiBitScalarQuantizationState implements QuantizationState * Each column represents the threshold for a specific dimension in the vector space. */ private float[][] thresholds; - private static final long serialVersionUID = 1L; // Version ID for serialization @Override public ScalarQuantizationParams getQuantizationParams() { @@ -50,37 +48,31 @@ public ScalarQuantizationParams getQuantizationParams() { } /** - * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. 
+ * This method is responsible for writing the state of the MultiBitScalarQuantizationState object to an external output. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized - * as part of the state to access the version information and implement version-specific logic if needed.

- * - *

The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable, - * ensuring that the version is available to all classes involved in the serialization process within the current thread context.

- * *
      * {@code
      * // Example usage in the writeExternal method:
-     * VersionContext.setVersion(version);
      * out.writeInt(version); // Write the version
-     * quantizationParams.writeExternal(out);
-     * out.writeInt(meanThresholds.length);
-     * for (float mean : meanThresholds) {
-     *     out.writeFloat(mean);
+     * quantizationParams.writeTo(out);
+     * out.writeInt(thresholds.length);
+     * out.writeInt(thresholds[0].length);
+     * for (float[] row : thresholds) {
+     *     for (float value : row) {
+     *         out.writeFloat(value);
+     *     }
      * }
      * }
      * 
* - * @param out the ObjectOutput to write the object to. + * @param out the StreamOutput to write the object to. * @throws IOException if an I/O error occurs during serialization. */ @Override - public void writeExternal(ObjectOutput out) throws IOException { - int version = Version.CURRENT.id; - VersionContext.setVersion(version); - out.writeInt(version); // Write the version - quantizationParams.writeExternal(out); + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); out.writeInt(thresholds.length); out.writeInt(thresholds[0].length); for (float[] row : thresholds) { @@ -91,40 +83,32 @@ public void writeExternal(ObjectOutput out) throws IOException { } /** - * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. + * This method is responsible for reading the state of the MultiBitScalarQuantizationState object from an external input. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method. - * This makes the version information available to all classes involved in the deserialization process within the current thread context.

- * - *

Classes that are part of the deserialization process can retrieve the version information using the - * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.

- * *
      * {@code
      * // Example usage in the readExternal method:
      * int version = in.readInt(); // Read the version
-     * VersionContext.setVersion(version);
-     * quantizationParams = new ScalarQuantizationParams();
-     * quantizationParams.readExternal(in); // Use readExternal of SQParams
-     * int length = in.readInt();
-     * meanThresholds = new float[length];
-     * for (int i = 0; i < length; i++) {
-     *     meanThresholds[i] = in.readFloat();
+     * quantizationParams = new ScalarQuantizationParams(in, version);
+     * int rows = in.readInt();
+     * int cols = in.readInt();
+     * thresholds = new float[rows][cols];
+     * for (int i = 0; i < rows; i++) {
+     *     for (int j = 0; j < cols; j++) {
+     *         thresholds[i][j] = in.readFloat();
+     *     }
      * }
      * }
      * 
* - * @param in the ObjectInput to read the object from. + * @param in the StreamInput to read the object from. * @throws IOException if an I/O error occurs during deserialization. * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public MultiBitScalarQuantizationState(StreamInput in) throws IOException { int version = in.readInt(); // Read the version - VersionContext.setVersion(version); - quantizationParams = new ScalarQuantizationParams(); - quantizationParams.readExternal(in); // Use readExternal of SQParams + this.quantizationParams = new ScalarQuantizationParams(in, version); int rows = in.readInt(); int cols = in.readInt(); thresholds = new float[rows][cols]; @@ -133,7 +117,6 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept thresholds[i][j] = in.readFloat(); } } - VersionContext.clear(); // Clear the version after use } /** @@ -155,7 +138,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept */ @Override public byte[] toByteArray() throws IOException { - return QuantizationStateSerializer.serialize(this, thresholds); + return QuantizationStateSerializer.serialize(this); } /** @@ -175,16 +158,8 @@ public byte[] toByteArray() throws IOException { * @param bytes the byte array containing the serialized state. * @return the deserialized MultiBitScalarQuantizationState object. * @throws IOException if an I/O error occurs during deserialization. - * @throws ClassNotFoundException if the class of a serialized object cannot be found. 
*/ - public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { - return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize( - bytes, - new MultiBitScalarQuantizationState(), - (parentParams, thresholds) -> new MultiBitScalarQuantizationState( - (ScalarQuantizationParams) parentParams, - (float[][]) thresholds - ) - ); + public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { + return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 8ab37955e..13ea4aaba 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -9,12 +9,11 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; -import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; /** * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, @@ -38,7 +37,6 @@ public final class OneBitScalarQuantizationState implements QuantizationState { * The quantized vector will be [0, 1, 1]. 
*/ private float[] meanThresholds; - private static final long serialVersionUID = 1L; // Version ID for serialization @Override public ScalarQuantizationParams getQuantizationParams() { @@ -49,18 +47,12 @@ public ScalarQuantizationParams getQuantizationParams() { * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized - * as part of the state to access the version information and implement version-specific logic if needed.

- * - *

The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable, - * ensuring that the version is available to all classes involved in the serialization process within the current thread context.

* *
      * {@code
      * // Example usage in the writeExternal method:
-     * VersionContext.setVersion(version);
      * out.writeInt(version); // Write the version
-     * quantizationParams.writeExternal(out);
+     * quantizationParams.writeTo(out);
      * out.writeInt(meanThresholds.length);
      * for (float mean : meanThresholds) {
      *     out.writeFloat(mean);
@@ -68,39 +60,29 @@ public ScalarQuantizationParams getQuantizationParams() {
      * }
      * 
* - * @param out the ObjectOutput to write the object to. + * @param out the StreamOutput to write the object to. * @throws IOException if an I/O error occurs during serialization. */ @Override - public void writeExternal(ObjectOutput out) throws IOException { - int version = Version.CURRENT.id; - VersionContext.setVersion(version); - out.writeInt(version); // Write the version - quantizationParams.writeExternal(out); + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); out.writeInt(meanThresholds.length); for (float mean : meanThresholds) { out.writeFloat(mean); } - VersionContext.clear(); // Clear the version after use } /** * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method. - * This makes the version information available to all classes involved in the deserialization process within the current thread context.

- * - *

Classes that are part of the deserialization process can retrieve the version information using the - * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.

* *
      * {@code
      * // Example usage in the readExternal method:
      * int version = in.readInt(); // Read the version
-     * VersionContext.setVersion(version);
-     * quantizationParams = new ScalarQuantizationParams();
-     * quantizationParams.readExternal(in); // Use readExternal of SQParams
+     * quantizationParams = new ScalarQuantizationParams(in, version);
      * int length = in.readInt();
      * meanThresholds = new float[length];
      * for (int i = 0; i < length; i++) {
@@ -109,22 +91,18 @@ public void writeExternal(ObjectOutput out) throws IOException {
      * }
      * 
* - * @param in the ObjectInput to read the object from. + * @param in the StreamInput to read the object from. * @throws IOException if an I/O error occurs during deserialization. * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public OneBitScalarQuantizationState(StreamInput in) throws IOException { int version = in.readInt(); // Read the version - VersionContext.setVersion(version); - quantizationParams = new ScalarQuantizationParams(); - quantizationParams.readExternal(in); // Use readExternal of SQParams + this.quantizationParams = new ScalarQuantizationParams(in, version); int length = in.readInt(); meanThresholds = new float[length]; for (int i = 0; i < length; i++) { meanThresholds[i] = in.readFloat(); } - VersionContext.clear(); // Clear the version after use } /** @@ -146,7 +124,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept */ @Override public byte[] toByteArray() throws IOException { - return QuantizationStateSerializer.serialize(this, meanThresholds); + return QuantizationStateSerializer.serialize(this); } /** @@ -166,16 +144,8 @@ public byte[] toByteArray() throws IOException { * @param bytes the byte array containing the serialized state. * @return the deserialized OneBitScalarQuantizationState object. * @throws IOException if an I/O error occurs during deserialization. - * @throws ClassNotFoundException if the class of a serialized object cannot be found. 
*/ - public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { - return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize( - bytes, - new OneBitScalarQuantizationState(), - (parentParams, meanThresholds) -> new OneBitScalarQuantizationState( - (ScalarQuantizationParams) parentParams, - (float[]) meanThresholds - ) - ); + public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { + return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java index d3778fe29..e32df8bc3 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -5,16 +5,16 @@ package org.opensearch.knn.quantization.models.quantizationState; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import java.io.Externalizable; import java.io.IOException; /** * QuantizationState interface represents the state of a quantization process, including the parameters used. * This interface provides methods for serializing and deserializing the state. */ -public interface QuantizationState extends Externalizable { +public interface QuantizationState extends Writeable { /** * Returns the quantization parameters associated with this state. 
* diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java index 5f00d8e0c..1f378e0dc 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java @@ -6,14 +6,10 @@ package org.opensearch.knn.quantization.models.quantizationState; import lombok.experimental.UtilityClass; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; import java.io.IOException; -import java.io.ByteArrayInputStream; -import java.io.ObjectInputStream; /** * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing @@ -27,23 +23,20 @@ class QuantizationStateSerializer { */ @FunctionalInterface interface SerializableDeserializer { - QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); + QuantizationState deserialize(StreamInput in) throws IOException; } /** * Serializes the QuantizationState and specific data into a byte array. * * @param state The QuantizationState to serialize. - * @param specificData The specific data related to the state, to be serialized. * @return A byte array representing the serialized state and specific data. * @throws IOException If an I/O error occurs during serialization. 
*/ - static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { - state.writeExternal(out); - out.writeObject(specificData); - out.flush(); - return bos.toByteArray(); + static byte[] serialize(QuantizationState state) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + state.writeTo(out); + return org.opensearch.core.common.bytes.BytesReference.toBytes(out.bytes()); // compact array; BytesRef#bytes alone may include unused capacity } } @@ -51,18 +44,13 @@ static byte[] serialize(QuantizationState state, Serializable specificData) thro * Deserializes a QuantizationState and its specific data from a byte array. * * @param bytes The byte array containing the serialized data. - * @param stateInstance An instance of the state to call readExternal on. - * @param specificDataDeserializer The deserializer for the specific data associated with the state. + * @param deserializer The deserializer for the specific data associated with the state. * @return The deserialized QuantizationState including its specific data. * @throws IOException If an I/O error occurs during deserialization. - * @throws ClassNotFoundException If the class of the serialized object cannot be found.
*/ - static QuantizationState deserialize(byte[] bytes, QuantizationState stateInstance, SerializableDeserializer specificDataDeserializer) - throws IOException, ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { - stateInstance.readExternal(in); - Serializable specificData = (Serializable) in.readObject(); // Read the specific data - return specificDataDeserializer.deserialize(stateInstance.getQuantizationParams(), specificData); + static QuantizationState deserialize(byte[] bytes, SerializableDeserializer deserializer) throws IOException { + try (StreamInput in = StreamInput.wrap(bytes)) { + return deserializer.deserialize(in); } } } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java new file mode 100644 index 000000000..c54e4bf71 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import lombok.experimental.UtilityClass; + +/** + * The BitPacker class provides utility methods for quantizing floating-point vectors and packing the resulting bits + * into a pre-allocated byte array. This class supports both single-bit and multi-bit quantization scenarios, + * enabling efficient storage and transmission of quantized vectors. + * + *

+ * The methods in this class are designed to be used by quantizers that need to convert floating-point vectors + * into compact binary representations by comparing them against quantization thresholds. + *

+ * + *

+ * This class is marked as a utility class using Lombok's {@link lombok.experimental.UtilityClass} annotation, + * which makes the class final and all of its members static, and prevents instantiation. + *

+ */ +@UtilityClass +class BitPacker { + + /** + * Quantizes a given floating-point vector and packs the resulting quantized bits into a provided byte array. + * This method operates by comparing each element of the input vector against corresponding thresholds + * and encoding the results into a compact binary format using the specified number of bits per coordinate. + * + *

+ * The method supports multi-bit quantization where each coordinate of the input vector can be represented + * by multiple bits. For example, with 2-bit quantization, each coordinate is encoded into 2 bits, allowing + * for four distinct levels of quantization per coordinate. + *

+ * + *

+ * Example: + *

+ *

+ * Consider a vector with 3 coordinates: [1.2, 3.4, 5.6] and thresholds: + *

+ *
+     * thresholds = {
+     *     {1.0, 3.0, 5.0},  // First bit thresholds
+     *     {1.5, 3.5, 5.5}   // Second bit thresholds
+     * };
+     * 
+ *

+ * If the number of bits per coordinate is 2, the quantization process will proceed as follows: + *

+ *
    + *
  • First bit comparison: + *
      + *
    • 1.2 > 1.0 -> 1
    • + *
    • 3.4 > 3.0 -> 1
    • + *
    • 5.6 > 5.0 -> 1
    • + *
    + *
  • + *
  • Second bit comparison: + *
      + *
    • 1.2 <= 1.5 -> 0
    • + *
    • 3.4 <= 3.5 -> 0
    • + *
    • 5.6 > 5.5 -> 1
    • + *
    + *
  • + *
+ *

+ * The per-coordinate codes (first bit, second bit) are 10, 10 and 11; they are packed one bit plane at a time. + * Bits are only OR-ed into the array, so unused trailing bits are 0 only if the caller supplies a zeroed array. + *

+ * 

+ * Packing Process: + * The bits are packed bit plane by bit plane (bitPosition = i * vectorLength + j): the first-bit results of all + * coordinates fill the most significant positions of the first byte, followed by the second-bit results of all + * coordinates, and so on. In the example above, the resulting byte array has the following binary representation: + *

+ *
+     * packedBits = [11100100] // Planes 111 and 001 use the first 6 bits; the last two bits stay 0.
+     * 
+ * + *

Bitwise Operations Explanation:

+ *
    + *
  • byteIndex: This is calculated using byteIndex = bitPosition >> 3, which is equivalent to bitPosition / 8. It determines which byte in the byte array the current bit should be placed in.
  • + *
  • bitIndex: This is calculated using bitIndex = 7 - (bitPosition & 7), which is equivalent to 7 - (bitPosition % 8). It determines the exact bit position within the byte.
  • + *
  • Setting the bit: The bit is set using packedBits[byteIndex] |= (1 << bitIndex). This shifts a 1 into the correct bit position and ORs it with the existing byte value to set the bit.
  • + *
+ * + * @param vector the floating-point vector to be quantized. + * @param thresholds a 2D array representing the quantization thresholds. The first dimension corresponds to the number of bits per coordinate, and the second dimension corresponds to the vector's length. + * @param bitsPerCoordinate the number of bits used per coordinate, determining the granularity of the quantization. + * @param packedBits the byte array where the quantized bits will be packed. + */ + void quantizeAndPackBits(final float[] vector, final float[][] thresholds, final int bitsPerCoordinate, byte[] packedBits) { + int vectorLength = vector.length; + + for (int i = 0; i < bitsPerCoordinate; i++) { + for (int j = 0; j < vectorLength; j++) { + if (vector[j] > thresholds[i][j]) { + int bitPosition = i * vectorLength + j; + // Calculate the index of the byte in the packedBits array. + int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8 + // Calculate the bit index within the byte. + int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8) + // Set the bit at the calculated position. + packedBits[byteIndex] |= (1 << bitIndex); // Set the bit at bitIndex + } + } + } + } + + /** + * Overloaded method to quantize a vector using single-bit quantization and pack the results into a provided byte array. + * + *

+ * This method is specifically designed for one-bit quantization scenarios, where each coordinate of the + * vector is represented by a single bit indicating whether the value is above or below the threshold. + *

+ * + *

Example:

+ *

+ * If we have a vector [1.2, 3.4, 5.6] and thresholds [2.0, 3.0, 4.0], the quantization process will be: + *

+ *
    + *
  • 1.2 < 2.0 -> 0
  • + *
  • 3.4 > 3.0 -> 1
  • + *
  • 5.6 > 4.0 -> 1
  • + *
+ *

+ * The quantized vector will be [0, 1, 1]. + *

+ * + * @param vector the vector to quantize. + * @param thresholds the thresholds for quantization, where each element represents the threshold for a corresponding coordinate. + * @param packedBits the byte array where the quantized bits will be packed. + */ + void quantizeAndPackBits(final float[] vector, final float[] thresholds, byte[] packedBits) { + quantizeAndPackBits(vector, new float[][] { thresholds }, 1, packedBits); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index c6d366e5b..9de58c249 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -17,11 +17,50 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import java.io.IOException; -import java.util.BitSet; /** * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. - * It supports multiple bits per coordinate, allowing for finer granularity in quantization. + * Unlike the OneBitScalarQuantizer, which uses a single bit per dimension to represent whether a value is above + * or below a mean threshold, the MultiBitScalarQuantizer allows for multiple bits per dimension, enabling more + * granular and precise quantization. + * + *

+ * In a OneBitScalarQuantizer, each dimension of a vector is compared to a single threshold (the mean), and a single + * bit is used to indicate whether the value is above or below that threshold. This results in a very coarse + * representation where each dimension is either "on" or "off." + *

+ * + *

+ * The MultiBitScalarQuantizer, on the other hand, uses multiple thresholds per dimension. For example, in a 2-bit + * quantization scheme, three thresholds are used to divide each dimension into four possible regions. Each region + * is represented by a unique 2-bit value. This allows for a much finer representation of the data, capturing more + * nuances in the variation of each dimension. + *

+ * + *

+ * The thresholds in MultiBitScalarQuantizer are calculated based on the mean and standard deviation of the sampled + * vectors for each dimension. Here's how it works: + *

+ * + *
    + *
  • First, the mean and standard deviation are computed for each dimension across the sampled vectors.
  • + *
  • For each bit used in the quantization (e.g., 2 bits per coordinate), the thresholds are calculated + * using a linear combination of the mean and the standard deviation. The combination coefficients are + * determined by the number of bits, allowing the thresholds to split the data into equal probability regions. + *
  • + *
  • For example, in a 2-bit quantization (which divides data into four regions), the thresholds might be + * set at points corresponding to -1 standard deviation, 0 standard deviations (mean), and +1 standard deviation. + * This ensures that the data is evenly split into four regions, each represented by a 2-bit value. + *
  • + *
+ * + *

+ * The number of bits per coordinate is determined by the type of scalar quantization being applied, such as 2-bit + * or 4-bit quantization. The increased number of bits per coordinate in MultiBitScalarQuantizer allows for better + * preservation of information during the quantization process, making it more suitable for tasks where precision + * is crucial. However, this comes at the cost of increased storage and computational complexity compared to the + * simpler OneBitScalarQuantizer. + *

*/ public class MultiBitScalarQuantizer implements Quantizer { private final int bitsPerCoordinate; // Number of bits used to quantize each dimension @@ -69,19 +108,20 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - BitSet sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - int dimension = trainingRequest.getVectorByDocId(sampledIndices.nextSetBit(0)).length; + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; float[] meanArray = new float[dimension]; float[] stdDevArray = new float[dimension]; // Calculate sum, mean, and standard deviation in one pass - QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); + QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension); ScalarQuantizationParams params = (bitsPerCoordinate == 2) - ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) - : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) + : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); return new MultiBitScalarQuantizationState(params, thresholds); } + /** * Quantizes the provided vector using the provided quantization state, producing a quantized output. * The vector is quantized based on the thresholds in the quantization state. 
@@ -102,22 +142,9 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds[0].length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - // Directly pack bits without intermediate array - int totalBits = bitsPerCoordinate * vector.length; - int byteLength = (totalBits + 7) >> 3; // Calculate byte length needed - byte[] packedBits = new byte[byteLength]; - for (int i = 0; i < bitsPerCoordinate; i++) { - for (int j = 0; j < vector.length; j++) { - if (vector[j] > thresholds[i][j]) { - int bitPosition = i * vector.length + j; - int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8 - int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8) - packedBits[byteIndex] |= (1 << bitIndex); // Set the bit - } - } - } - - output.updateQuantizedVector(packedBits); + // Prepare and get the writable array + byte[] writableArray = output.prepareAndGetWritableQuantizedVector(bitsPerCoordinate, vector.length); + BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, writableArray); } /** diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index eab3d992e..2d39b192f 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -16,7 +16,6 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import java.io.IOException; -import java.util.BitSet; /** * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. 
@@ -60,11 +59,12 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - BitSet sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); } + /** * Quantizes the provided vector using the given quantization state. * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. @@ -85,19 +85,9 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds.length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - // Directly pack bits without intermediate array - int byteLength = (vector.length + 7) >> 3; // Calculate byte length needed - byte[] packedBits = new byte[byteLength]; - - for (int i = 0; i < vector.length; i++) { - if (vector[i] > thresholds[i]) { - int byteIndex = i >> 3; // Equivalent to i / 8 - int bitIndex = 7 - (i & 7); // Equivalent to 7 - (i % 8) - packedBits[byteIndex] |= (1 << bitIndex); // Set the bit - } - } - - output.updateQuantizedVector(packedBits); + // Prepare and get the writable array + byte[] writableArray = output.prepareAndGetWritableQuantizedVector(1, vector.length); + BitPacker.quantizeAndPackBits(vector, thresholds, writableArray); } /** diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index e95632606..3a8ea0c8d 100644 --- 
a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -20,21 +20,16 @@ class QuantizerHelper { /** * Calculates the mean vector from a set of sampled vectors. * - *

This method takes a {@link TrainingRequest} object and an array of sampled indices, - * retrieves the vectors corresponding to these indices, and calculates the mean vector. - * Each element of the mean vector is computed as the average of the corresponding elements - * of the sampled vectors.

- * * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. * @return A float array representing the mean vector of the sampled vectors. * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. */ - static float[] calculateMeanThresholds(TrainingRequest samplingRequest, BitSet sampledIndices) { - int totalSamples = sampledIndices.cardinality(); + static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) { + int totalSamples = sampledIndices.length; float[] mean = null; - for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + for (int docId : sampledIndices) { float[] vector = samplingRequest.getVectorByDocId(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); @@ -56,24 +51,22 @@ static float[] calculateMeanThresholds(TrainingRequest samplingRequest, } /** - * Calculates the sum, sum of squares, mean, and standard deviation for each dimension in a single pass. + * Calculates the mean and StdDev per dimension for sampled vectors. * * @param trainingRequest the request containing the data and parameters for training. * @param sampledIndices the indices of the sampled vectors. * @param meanArray the array to store the sum and then the mean of each dimension. * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension. 
*/ - static void calculateSumMeanAndStdDev( - TrainingRequest trainingRequest, - BitSet sampledIndices, - float[] meanArray, - float[] stdDevArray + static void calculateMeanAndStdDev( + TrainingRequest trainingRequest, + int[] sampledIndices, + float[] meanArray, + float[] stdDevArray ) { - int totalSamples = sampledIndices.cardinality(); + int totalSamples = sampledIndices.length; int dimension = meanArray.length; - - // Single pass to calculate sum and sum of squares - for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + for (int docId : sampledIndices) { float[] vector = trainingRequest.getVectorByDocId(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java index 22a7bae23..4322cedd9 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -7,8 +7,10 @@ import lombok.NoArgsConstructor; +import java.util.Arrays; import java.util.BitSet; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; /** * ReservoirSampler implements the Sampler interface and provides a method for sampling @@ -45,11 +47,9 @@ public static synchronized ReservoirSampler getInstance() { * @return an array of sampled indices. 
*/ @Override - public BitSet sample(final int totalNumberOfVectors, final int sampleSize) { + public int[] sample(final int totalNumberOfVectors, final int sampleSize) { if (totalNumberOfVectors <= sampleSize) { - BitSet bitSet = new BitSet(totalNumberOfVectors); - bitSet.set(0, totalNumberOfVectors); - return bitSet; + return IntStream.range(0, totalNumberOfVectors).toArray(); } return reservoirSampleIndices(totalNumberOfVectors, sampleSize); } @@ -65,24 +65,27 @@ public BitSet sample(final int totalNumberOfVectors, final int sampleSize) { * * @param numVectors the total number of vectors. * @param sampleSize the number of indices to sample. - * @return a BitSet representing the sampled indices. + * @return an array of sampled indices. */ - private BitSet reservoirSampleIndices(final int numVectors, final int sampleSize) { + private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) { int[] indices = new int[sampleSize]; + + // Initialize the reservoir with the first sampleSize elements for (int i = 0; i < sampleSize; i++) { indices[i] = i; } + + // Replace elements with gradually decreasing probability for (int i = sampleSize; i < numVectors; i++) { int j = ThreadLocalRandom.current().nextInt(i + 1); if (j < sampleSize) { indices[j] = i; } } - // Using BitSet to track the presence of indices - BitSet bitSet = new BitSet(numVectors); - for (int i = 0; i < sampleSize; i++) { - bitSet.set(indices[i]); - } - return bitSet; + + // Sort the sampled indices + Arrays.sort(indices); + + return indices; } } diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java index 17834cf04..828d87801 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -23,5 +23,5 @@ public interface Sampler { * @return an array of integers representing the indices of the sampled 
vectors. * @throws IllegalArgumentException if the sample size is greater than the total number of vectors. */ - BitSet sample(int totalNumberOfVectors, int sampleSize); + int[] sample(int totalNumberOfVectors, int sampleSize); } diff --git a/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java b/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java deleted file mode 100644 index 7746305ab..000000000 --- a/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.util; - -import lombok.experimental.UtilityClass; - -/** - * Utility class to manage version information in a thread-safe manner using ThreadLocal storage. - * This class ensures that version information is available within the current thread context. - */ -@UtilityClass -public class VersionContext { - - /** - * ThreadLocal storage for version information. - * This allows each thread to have its own version information without interference. - */ - private final ThreadLocal versionHolder = new ThreadLocal<>(); - - /** - * Sets the version for the current thread. - * - * @param version the version to be set. - */ - public void setVersion(int version) { - versionHolder.set(version); - } - - /** - * Gets the version for the current thread. - * - * @return the version for the current thread. - */ - public int getVersion() { - return versionHolder.get(); - } - - /** - * Clears the version for the current thread. 
- */ - public void clear() { - versionHolder.remove(); - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java index 815f81071..99621a0e5 100644 --- a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java +++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java @@ -7,6 +7,9 @@ import org.opensearch.knn.KNNTestCase; +import java.util.HashSet; +import java.util.Set; + public class ScalarQuantizationTypeTests extends KNNTestCase { public void testSQTypesValues() { ScalarQuantizationType[] expectedValues = { @@ -21,4 +24,12 @@ public void testSQTypesValueOf() { assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT")); assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT")); } + + public void testUniqueSQTypeValues() { + Set uniqueIds = new HashSet<>(); + for (ScalarQuantizationType type : ScalarQuantizationType.values()) { + boolean added = uniqueIds.add(type.getId()); + assertTrue("Duplicate value found: " + type.getId(), added); + } + } } diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java index 34aa907b8..27fc9e901 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -28,30 +28,17 @@ public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessEx public void test_Lazy_Registration() { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams paramsFourBit = new 
ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); assertFalse(isRegisteredFieldAccessible()); Quantizer quantizer = QuantizerFactory.getQuantizer(params); + Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); + Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); + assertTrue(quantizerFourBit instanceof MultiBitScalarQuantizer); + assertTrue(quantizerTwoBit instanceof MultiBitScalarQuantizer); assertTrue(quantizer instanceof OneBitScalarQuantizer); assertTrue(isRegisteredFieldAccessible()); } - - public void testGetQuantizer_withOneBitSQParams() { - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - Quantizer quantizer = QuantizerFactory.getQuantizer(params); - assertTrue(quantizer instanceof OneBitScalarQuantizer); - } - - public void testGetQuantizer_withTwoBitSQParams() { - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - Quantizer quantizer = QuantizerFactory.getQuantizer(params); - assertTrue(quantizer instanceof MultiBitScalarQuantizer); - } - - public void testGetQuantizer_withFourBitSQParams() { - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - Quantizer quantizer = QuantizerFactory.getQuantizer(params); - assertTrue(quantizer instanceof MultiBitScalarQuantizer); - } - public void testGetQuantizer_withNullParams() { try { QuantizerFactory.getQuantizer(null); @@ -61,20 +48,6 @@ public void testGetQuantizer_withNullParams() { } } - public void testConcurrentRegistration() throws InterruptedException { - Runnable task = () -> { - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizerFactory.getQuantizer(params); - }; - - Thread thread1 = new Thread(task); - Thread thread2 = new Thread(task); - thread1.start(); - thread2.start(); - thread1.join(); - thread2.join(); - assertTrue(isRegisteredFieldAccessible()); - } 
private boolean isRegisteredFieldAccessible() { try { @@ -87,4 +60,4 @@ private boolean isRegisteredFieldAccessible() { return false; } } -} +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index 2d743f883..d7c196f77 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -69,4 +69,16 @@ public void testQuantizerRegistryIsSingleton() { Quantizer secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); assertSame(firstFourBitQuantizer, secondFourBitQuantizer); } + + public void testRegisterQuantizerThrowsExceptionWhenAlreadyRegistered() { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + + // Attempt to register the same quantizer again should throw an exception + assertThrows(IllegalArgumentException.class, () -> { + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); + }); + } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java index 974f58637..c6bcfb3a2 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizationState; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import 
org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -15,25 +16,33 @@ public class QuantizationStateSerializerTests extends KNNTestCase { - public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException { + public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = new float[] { 0.1f, 0.2f, 0.3f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + // Serialize byte[] serialized = state.toByteArray(); - OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized); + + // Deserialize + StreamInput in = StreamInput.wrap(serialized); + OneBitScalarQuantizationState deserialized = new OneBitScalarQuantizationState(in); assertArrayEquals(mean, deserialized.getMeanThresholds(), 0.0f); assertEquals(params, deserialized.getQuantizationParams()); } - public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException { + public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } }; MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + // Serialize byte[] serialized = state.toByteArray(); - MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); + + // Deserialize + StreamInput in = StreamInput.wrap(serialized); + MultiBitScalarQuantizationState deserialized = new MultiBitScalarQuantizationState(in); for (int i = 0; i < thresholds.length; i++) { assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f); diff 
--git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 5a6b4b1db..834440bca 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -6,52 +6,44 @@ package org.opensearch.knn.quantization.quantizationState; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; -import org.opensearch.Version; -import org.opensearch.knn.quantization.util.VersionContext; - import java.io.IOException; public class QuantizationStateTests extends KNNTestCase { - public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + public void testOneBitScalarQuantizationStateSerialization() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); - // Set the version for serialization - VersionContext.setVersion(Version.CURRENT.id); - // Serialize byte[] serializedState = state.toByteArray(); // Deserialize - OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); float delta = 0.0001f; 
assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } - public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + public void testMultiBitScalarQuantizationStateSerialization() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } }; MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); - - // Set the version for serialization - VersionContext.setVersion(Version.CURRENT.id); - - // Serialize byte[] serializedState = state.toByteArray(); // Deserialize - MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState); + StreamInput in = StreamInput.wrap(serializedState); + MultiBitScalarQuantizationState deserializedState = new MultiBitScalarQuantizationState(in); float delta = 0.0001f; for (int i = 0; i < thresholds.length; i++) { @@ -60,21 +52,14 @@ public void testMultiBitScalarQuantizationStateSerialization() throws IOExceptio assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } - public void testSerializationWithDifferentVersions() throws IOException, ClassNotFoundException { + public void testSerializationWithDifferentVersions() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); - - // Simulate an older version - VersionContext.setVersion(Version.V_2_0_0.id); - - // Serialize byte[] serializedState = state.toByteArray(); - - // Update to a new version and deserialize - VersionContext.setVersion(Version.CURRENT.id); - OneBitScalarQuantizationState 
deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); float delta = 0.0001f; assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index 8372ac5d2..6c81ae69c 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; @@ -17,8 +18,6 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; import java.util.BitSet; public class OneBitScalarQuantizerTests extends KNNTestCase { @@ -83,13 +82,8 @@ public byte[] toByteArray() { } @Override - public void writeExternal(ObjectOutput out) throws IOException { - // no-op - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - // no-op + public void writeTo(StreamOutput out) throws IOException { + // Empty implementation for test } }; BinaryQuantizationOutput output = new BinaryQuantizationOutput(); @@ -120,7 +114,7 @@ public float[] getVectorByDocId(int docId) { }; Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); - BitSet sampledIndices = sampler.sample(vectors.length, 3); + int[] sampledIndices = sampler.sample(vectors.length, 3); 
float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices); assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); } @@ -137,7 +131,7 @@ public float[] getVectorByDocId(int docId) { }; Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); - BitSet sampledIndices = sampler.sample(vectors.length, 3); + int[] sampledIndices = sampler.sample(vectors.length, 3); expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices)); } } diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java index e930aef04..88d668980 100644 --- a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -7,7 +7,8 @@ import org.opensearch.knn.KNNTestCase; -import java.util.BitSet; +import java.util.Arrays; +import java.util.stream.IntStream; public class ReservoirSamplerTests extends KNNTestCase { @@ -15,20 +16,18 @@ public void testSampleLessThanSampleSize() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 5; int sampleSize = 10; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - BitSet expectedIndices = new BitSet(totalNumberOfVectors); - expectedIndices.set(0, totalNumberOfVectors); - assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); + assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleEqualToSampleSize() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int 
totalNumberOfVectors = 10; int sampleSize = 10; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - BitSet expectedIndices = new BitSet(totalNumberOfVectors); - expectedIndices.set(0, totalNumberOfVectors); - assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); + assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleRandomness() { @@ -37,25 +36,28 @@ public void testSampleRandomness() { int totalNumberOfVectors = 100; int sampleSize = 10; - BitSet sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); - BitSet sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); + int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); + int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); - assertNotEquals(sampledIndices1, sampledIndices2); + // It's unlikely but possible for the two samples to be equal; sort both so the inequality check below ignores element order + Arrays.sort(sampledIndices1); + Arrays.sort(sampledIndices2); + assertFalse("Sampled indices should be different", Arrays.equals(sampledIndices1, sampledIndices2)); } public void testEdgeCaseZeroVectors() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 0; int sampleSize = 10; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.cardinality()); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals("Sampled indices should be empty when there are zero vectors.", 0, sampledIndices.length); } public void testEdgeCaseZeroSampleSize() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 10; int sampleSize =
0; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.cardinality()); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals("Sampled indices should be empty when sample size is zero.", 0, sampledIndices.length); } -} +} \ No newline at end of file