> registry = new ConcurrentHashMap<>();
/**
* Registers a quantizer with the registry.
*
- * @param paramClass the class of the quantization parameters
- * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION)
- * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT)
- * @param quantizerSupplier a supplier that provides instances of the quantizer
- * @param the type of quantization parameters
+ * @param paramIdentifier the unique identifier for the quantization parameters
+ * @param quantizer an instance of the quantizer
*/
- public static
void register(
- final Class
paramClass,
- final QuantizationType quantizationType,
- final ScalarQuantizationType sqType,
- final Supplier extends Quantizer, ?>> quantizerSupplier
- ) {
- String identifier = createIdentifier(quantizationType, sqType);
+ static void register(final String paramIdentifier, final Quantizer, ?> quantizer) {
// Ensure that the quantizer for this identifier is registered only once
- registry.computeIfAbsent(identifier, key -> quantizerSupplier);
+ registry.putIfAbsent(paramIdentifier, quantizer);
}
/**
@@ -56,27 +43,14 @@ public static
void register(
* @return an instance of {@link Quantizer} corresponding to the provided parameters
* @throws IllegalArgumentException if no quantizer is registered for the given parameters
*/
- public static
Quantizer
getQuantizer(final P params) {
+ static
Quantizer
getQuantizer(final P params) {
String identifier = params.getTypeIdentifier();
- Supplier extends Quantizer, ?>> supplier = registry.get(identifier);
- if (supplier == null) {
- throw new IllegalArgumentException(
- "No quantizer registered for type identifier: " + identifier + ". Available quantizers: " + registry.keySet()
- );
+ Quantizer, ?> quantizer = registry.get(identifier);
+ if (quantizer == null) {
+ throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier);
}
@SuppressWarnings("unchecked")
- Quantizer
quantizer = (Quantizer
) supplier.get();
- return quantizer;
- }
-
- /**
- * Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype.
- *
- * @param quantizationType the quantization type
- * @param sqType the specific quantization subtype
- * @return a string identifier
- */
- private static String createIdentifier(final QuantizationType quantizationType, final ScalarQuantizationType sqType) {
- return quantizationType.name() + "_" + sqType.name();
+ Quantizer
typedQuantizer = (Quantizer
) quantizer;
+ return typedQuantizer;
}
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
index 18077182f..dbf8e5bf9 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
@@ -5,27 +5,36 @@
package org.opensearch.knn.quantization.models.quantizationOutput;
+import lombok.NoArgsConstructor;
+import lombok.Getter;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+
/**
* The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
* It implements the QuantizationOutput interface to handle byte arrays specifically.
*/
+@NoArgsConstructor
public class BinaryQuantizationOutput implements QuantizationOutput {
- private final byte[] quantizedVector;
+ @Getter
+ private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
/**
- * Constructs a BinaryQuantizationOutput instance with the specified quantized vector.
+ * Updates the quantized vector with a new byte array.
*
- * @param quantizedVector the quantized vector represented as a byte array.
+ * @param newQuantizedVector the new quantized vector represented as a byte array.
*/
- public BinaryQuantizationOutput(final byte[] quantizedVector) {
- if (quantizedVector == null) {
- throw new IllegalArgumentException("Quantized vector cannot be null");
+ public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException {
+ if (newQuantizedVector == null || newQuantizedVector.length == 0) {
+ throw new IllegalArgumentException("Quantized vector cannot be null or empty");
}
- this.quantizedVector = quantizedVector;
+ byteArrayOutputStream.reset();
+ byteArrayOutputStream.write(newQuantizedVector);
}
@Override
public byte[] getQuantizedVector() {
- return quantizedVector;
+ return byteArrayOutputStream.toByteArray();
}
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
index c5c5fd21f..8f01a0594 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
@@ -5,6 +5,8 @@
package org.opensearch.knn.quantization.models.quantizationOutput;
+import java.io.IOException;
+
/**
* The QuantizationOutput interface defines the contract for quantization output data.
*
@@ -17,4 +19,12 @@ public interface QuantizationOutput {
* @return the quantized data.
*/
T getQuantizedVector();
+
+ /**
+ * Updates the quantized vector with new data.
+ *
+ * @param newQuantizedVector the new quantized vector data.
+ * @throws IOException if an I/O error occurs during the update.
+ */
+ void updateQuantizedVector(T newQuantizedVector) throws IOException;
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
index 2c982a306..88b22c4fd 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
@@ -5,9 +5,7 @@
package org.opensearch.knn.quantization.models.quantizationParams;
-import org.opensearch.knn.quantization.enums.QuantizationType;
-
-import java.io.Serializable;
+import java.io.Externalizable;
/**
* Interface for quantization parameters.
@@ -16,17 +14,7 @@
* Implementations of this interface are expected to provide specific configurations
* for various quantization strategies.
*/
-public interface QuantizationParams extends Serializable {
-
- /**
- * Gets the quantization type associated with the parameters.
- * The quantization type defines the overall strategy or method used
- * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION.
- *
- * @return the {@link QuantizationType} indicating the quantization method.
- */
- QuantizationType getQuantizationType();
-
+public interface QuantizationParams extends Externalizable {
/**
* Provides a unique identifier for the quantization parameters.
* This identifier is typically a combination of the quantization type
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
deleted file mode 100644
index 0b6bbc988..000000000
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.knn.quantization.models.quantizationParams;
-
-import org.opensearch.knn.quantization.enums.QuantizationType;
-import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-
-import java.util.Objects;
-
-/**
- * The SQParams class represents the parameters specific to scalar quantization (SQ).
- * This class implements the QuantizationParams interface and includes the type of scalar quantization.
- */
-public class SQParams implements QuantizationParams {
- private final ScalarQuantizationType sqType;
-
- /**
- * Constructs an SQParams instance with the specified scalar quantization type.
- *
- * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT).
- */
- public SQParams(final ScalarQuantizationType sqType) {
- this.sqType = sqType;
- }
-
- /**
- * Returns the quantization type associated with these parameters.
- *
- * @return The quantization type, always VALUE_QUANTIZATION for SQParams.
- */
- @Override
- public QuantizationType getQuantizationType() {
- return QuantizationType.VALUE;
- }
-
- /**
- * Returns the scalar quantization type.
- *
- * @return The specific scalar quantization type.
- */
- public ScalarQuantizationType getSqType() {
- return sqType;
- }
-
- /**
- * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type.
- * This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
- *
- * @return A string representing the unique type identifier.
- */
- @Override
- public String getTypeIdentifier() {
- return getQuantizationType().name() + "_" + sqType.name();
- }
-
- /**
- * Compares this object to the specified object. The result is true if and only if the argument is not null and is
- * an SQParams object that represents the same scalar quantization type.
- *
- * @param o The object to compare this SQParams against.
- * @return true if the given object represents an SQParams equivalent to this instance, false otherwise.
- */
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- SQParams sqParams = (SQParams) o;
- return sqType == sqParams.sqType;
- }
-
- /**
- * Returns a hash code value for this SQParams instance.
- *
- * @return A hash code value for this SQParams instance.
- */
- @Override
- public int hashCode() {
- return Objects.hash(sqType);
- }
-}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java
new file mode 100644
index 000000000..c7c24062d
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import lombok.AllArgsConstructor;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Locale;
+
+/**
+ * The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ).
+ * This class implements the QuantizationParams interface and includes the type of scalar quantization.
+ */
+@Getter
+@AllArgsConstructor
+@NoArgsConstructor // No-argument constructor for deserialization
+@EqualsAndHashCode
+public class ScalarQuantizationParams implements QuantizationParams {
+ private ScalarQuantizationType sqType;
+ private static final long serialVersionUID = 1L; // Version ID for serialization
+
+ /**
+ * Generates the type identifier for a given ScalarQuantizationType, in the form {@code ScalarQuantizationParams_<id>}.
+ *
+ * @param sqType the scalar quantization type; its {@code getId()} value is embedded in the identifier.
+ * @return A string representing the unique type identifier.
+ */
+ public static String generateTypeIdentifier(ScalarQuantizationType sqType) {
+ return generateIdentifier(sqType.getId());
+ }
+
+ /**
+ * Serializes this ScalarQuantizationParams object to an external output.
+ * This method writes the scalar quantization type to the output stream.
+ *
+ * @param out the ObjectOutput to write the object to.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // The version is already written by the parent state class, so it is not written again here.
+ // If version-specific serialization is ever needed, the current version can be retrieved
+ // from VersionContext, which shares it with the other classes involved in serialization.
+ // Example:
+ // int version = VersionContext.getVersion(); // Get the current version from VersionContext
+ // Any version-specific logic can be written based on that version.
+ out.writeObject(sqType);
+ }
+
+ /**
+ * Deserializes this ScalarQuantizationParams object from an external input.
+ * This method currently reads only the scalar quantization type; version-specific fields are only sketched in the comments below.
+ *
+ * @param in the ObjectInput to read the object from.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of the serialized object cannot be found.
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ // The version is already read by the parent state class and set in VersionContext.
+ // If needed, retrieve it from VersionContext to handle version-specific deserialization logic:
+ // int versionId = VersionContext.getVersion();
+ // Version version = Version.fromId(versionId);
+
+ sqType = (ScalarQuantizationType) in.readObject();
+
+ // Add version-specific deserialization logic here.
+ // For example, if new fields are added in a future version, handle them here.
+ // This section would contain conditional logic to handle different versions appropriately.
+ // Example:
+ // if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) {
+ // // Handle logic for versions between 1.0.0 and 2.0.0
+ // // Example: Read additional fields introduced in version 1.0.0
+ // // newField = in.readInt();
+ // } else if (version.onOrAfter(Version.V_2_0_0)) {
+ // // Handle logic for versions 2.0.0 and above
+ // // Example: Read additional fields introduced in version 2.0.0
+ // // anotherNewField = in.readFloat();
+ // }
+ }
+
+ /**
+ * Provides a unique type identifier for these ScalarQuantizationParams, derived from the id of the SQ type.
+ * This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
+ *
+ * @return A string representing the unique type identifier; {@code sqType} must already be set, since it is dereferenced here.
+ */
+ @Override
+ public String getTypeIdentifier() {
+ return generateIdentifier(sqType.getId());
+ }
+
+ private static String generateIdentifier(int id) {
+ return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id); // Locale.ROOT: the identifier must not vary with the default locale
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
index acc8c2f00..3e3249c6f 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
@@ -5,28 +5,27 @@
package org.opensearch.knn.quantization.models.quantizationState;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.Version;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
-import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
/**
* DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
* It can be utilized by any quantizer to represent a default state.
*/
+@Getter
+@NoArgsConstructor // No-argument constructor for deserialization
+@AllArgsConstructor
public class DefaultQuantizationState implements QuantizationState {
-
- private final QuantizationParams params;
-
- /**
- * Constructs a DefaultQuantizationState with the given quantization parameters.
- *
- * @param params the quantization parameters.
- */
- public DefaultQuantizationState(final QuantizationParams params) {
- this.params = params;
- }
+ private QuantizationParams params;
+ private static final long serialVersionUID = 1L; // Version ID for serialization
/**
* Returns the quantization parameters associated with this state.
@@ -60,7 +59,34 @@ public byte[] toByteArray() throws IOException {
public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
- (parentParams, specificData) -> new DefaultQuantizationState((SQParams) parentParams)
+ new DefaultQuantizationState(),
+ (parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams)
);
}
+
+ /**
+ * Writes the object to the output stream.
+ * This method is part of the Externalizable interface and is used to serialize the object.
+ *
+ * @param out the output stream to write the object to.
+ * @throws IOException if an I/O error occurs.
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+     // Stream layout: the current OpenSearch version id first, then the parameters object.
+     final int versionId = Version.CURRENT.id;
+     out.writeInt(versionId);
+     out.writeObject(params);
+ }
+
+ /**
+ * Reads the object from the input stream.
+ * This method is part of the Externalizable interface and is used to deserialize the object.
+ *
+ * @param in the input stream to read the object from.
+ * @throws IOException if an I/O error occurs.
+ * @throws ClassNotFoundException if the class of the serialized object cannot be found.
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+     // writeExternal emits the version id before the parameters object. Consume it first so
+     // the stream is positioned on the parameters; the value itself is not needed here yet.
+     in.readInt();
+     this.params = (QuantizationParams) in.readObject();
+ }
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
index 58834dd2c..095d245f2 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
@@ -5,54 +5,186 @@
package org.opensearch.knn.quantization.models.quantizationState;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
-import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.Version;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.util.VersionContext;
import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
/**
* MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization,
* including the thresholds used for quantization.
*/
+@Getter
+@NoArgsConstructor // No-argument constructor for deserialization
+@AllArgsConstructor
public final class MultiBitScalarQuantizationState implements QuantizationState {
- private final SQParams quantizationParams;
- private final float[][] thresholds;
-
+ private ScalarQuantizationParams quantizationParams;
/**
- * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds.
+ * The threshold values for multi-bit quantization, organized as a 2D array
+ * where each row corresponds to a different bit level.
+ *
+ * For example:
+ * - For 2-bit quantization:
+ * thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level
+ * thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level
+ * - For 4-bit quantization:
+ * thresholds[0] -> {0.1f, 0.2f, 0.3f}
+ * thresholds[1] -> {0.4f, 0.5f, 0.6f}
+ * thresholds[2] -> {0.7f, 0.8f, 0.9f}
+ * thresholds[3] -> {1.0f, 1.1f, 1.2f}
*
- * @param quantizationParams the scalar quantization parameters.
- * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array
- * where each row corresponds to a different bit level.
+ * Each column represents the threshold for a specific dimension in the vector space.
*/
- public MultiBitScalarQuantizationState(final SQParams quantizationParams, final float[][] thresholds) {
- this.quantizationParams = quantizationParams;
- this.thresholds = thresholds;
- }
+ private float[][] thresholds;
+ private static final long serialVersionUID = 1L; // Version ID for serialization
@Override
- public SQParams getQuantizationParams() {
+ public ScalarQuantizationParams getQuantizationParams() {
return quantizationParams;
}
/**
- * Returns the thresholds used in the quantization process.
+ * This method is responsible for writing the state of the MultiBitScalarQuantizationState object to an external output.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ *
+ * Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized
+ * as part of the state to access the version information and implement version-specific logic if needed.
+ *
+ * The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable,
+ * ensuring that the version is available to all classes involved in the serialization process within the current thread context.
+ *
+ *
+ * {@code
+ * // Example usage in the writeExternal method:
+ * VersionContext.setVersion(version);
+ * out.writeInt(version); // Write the version
+ * quantizationParams.writeExternal(out);
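+ * out.writeInt(thresholds.length); // number of rows
+ * out.writeInt(thresholds[0].length); // number of columns
+ * for (float[] row : thresholds) {
+ * for (float value : row) {
+ * out.writeFloat(value);
+ * }
+ * }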
+ * }
+ *
+ *
+ * @param out the ObjectOutput to write the object to.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+     final int version = Version.CURRENT.id;
+     // Publish the version through VersionContext so nested writeExternal calls can read it.
+     VersionContext.setVersion(version);
+     try {
+         out.writeInt(version); // Write the version
+         quantizationParams.writeExternal(out);
+         // Matrix header: row count, then column count. An empty matrix is written as 0 x 0
+         // instead of failing on thresholds[0]. readExternal rebuilds a rows x cols array, so
+         // every row is expected to have the same length as the first one.
+         final int rows = thresholds.length;
+         final int cols = rows == 0 ? 0 : thresholds[0].length;
+         out.writeInt(rows);
+         out.writeInt(cols);
+         for (float[] row : thresholds) {
+             for (float value : row) {
+                 out.writeFloat(value);
+             }
+         }
+     } finally {
+         // Clear the version after use, even when a write fails, so it does not linger
+         // in the context for later serialization calls.
+         VersionContext.clear();
+     }
+ }
+
+ /**
+ * This method is responsible for reading the state of the MultiBitScalarQuantizationState object from an external input.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ *
+ * The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method.
+ * This makes the version information available to all classes involved in the deserialization process within the current thread context.
*
- * @return a 2D array of threshold values.
+ * Classes that are part of the deserialization process can retrieve the version information using the
+ * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.
+ *
+ *
+ * {@code
+ * // Example usage in the readExternal method:
+ * int version = in.readInt(); // Read the version
+ * VersionContext.setVersion(version);
+ * quantizationParams = new ScalarQuantizationParams();
+ * quantizationParams.readExternal(in); // Use readExternal of ScalarQuantizationParams
+ * int rows = in.readInt();
+ * int cols = in.readInt();
+ * thresholds = new float[rows][cols];
+ * for (int i = 0; i < rows; i++) {
+ * for (int j = 0; j < cols; j++) {
+ * thresholds[i][j] = in.readFloat();
+ * }
+ * }
+ * }
+ *
+ *
+ * @param in the ObjectInput to read the object from.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
- public float[][] getThresholds() {
- return thresholds;
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+     final int version = in.readInt(); // Read the version
+     // Publish the version so nested readExternal calls can apply version-specific logic.
+     VersionContext.setVersion(version);
+     try {
+         quantizationParams = new ScalarQuantizationParams();
+         quantizationParams.readExternal(in); // Use readExternal of ScalarQuantizationParams
+         // Matrix header written by writeExternal: row count, then column count.
+         final int rows = in.readInt();
+         final int cols = in.readInt();
+         thresholds = new float[rows][cols];
+         for (int i = 0; i < rows; i++) {
+             for (int j = 0; j < cols; j++) {
+                 thresholds[i][j] = in.readFloat();
+             }
+         }
+     } finally {
+         // Clear the version after use, even when the stream is truncated or corrupt.
+         VersionContext.clear();
+     }
}
+ /**
+ * Serializes the current state of this MultiBitScalarQuantizationState object into a byte array.
+ * This method uses the QuantizationStateSerializer to handle the serialization process.
+ *
+ * The serialized byte array includes all necessary state information, such as the thresholds
+ * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.
+ *
+ *
+ * {@code
+ * MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+ * byte[] serializedState = state.toByteArray();
+ * }
+ *
+ *
+ * @return a byte array representing the serialized state of this object.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
@Override
public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, thresholds);
}
+ /**
+ * Deserializes a MultiBitScalarQuantizationState object from a byte array.
+ * This method uses the QuantizationStateSerializer to handle the deserialization process.
+ *
+ * The byte array should contain serialized state information, including the thresholds
+ * and quantization parameters, which are necessary to reconstruct the MultiBitScalarQuantizationState object.
+ *
+ *
+ * {@code
+ * byte[] serializedState = ...; // obtain the byte array from some source
+ * MultiBitScalarQuantizationState state = MultiBitScalarQuantizationState.fromByteArray(serializedState);
+ * }
+ *
+ *
+ * @param bytes the byte array containing the serialized state.
+ * @return the deserialized MultiBitScalarQuantizationState object.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of a serialized object cannot be found.
+ */
public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
- (parentParams, thresholds) -> new MultiBitScalarQuantizationState((SQParams) parentParams, (float[][]) thresholds)
+ new MultiBitScalarQuantizationState(),
+ (parentParams, thresholds) -> new MultiBitScalarQuantizationState(
+ (ScalarQuantizationParams) parentParams,
+ (float[][]) thresholds
+ )
);
}
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
index 9b4bad56a..8ab37955e 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
@@ -5,53 +5,177 @@
package org.opensearch.knn.quantization.models.quantizationState;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
-import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.Version;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.util.VersionContext;
import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
/**
* OneBitScalarQuantizationState represents the state of one-bit scalar quantization,
* including the mean values used for quantization.
*/
+@Getter
+@NoArgsConstructor // No-argument constructor for deserialization
+@AllArgsConstructor
public final class OneBitScalarQuantizationState implements QuantizationState {
- private final SQParams quantizationParams;
- private final float[] meanThresholds;
-
+ private ScalarQuantizationParams quantizationParams;
/**
- * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values.
+ * Mean thresholds used in the quantization process.
+ * Each threshold value corresponds to a dimension of the vector being quantized.
*
- * @param quantizationParams the scalar quantization parameters.
- * @param mean the mean values for each dimension.
+ * Example:
+ * If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0],
+ * the quantization process will be:
+ * - 1.2 < 2.0, so the first bit is 0
+ * - 3.4 > 3.0, so the second bit is 1
+ * - 5.6 > 4.0, so the third bit is 1
+ * The quantized vector will be [0, 1, 1].
*/
- public OneBitScalarQuantizationState(final SQParams quantizationParams, final float[] mean) {
- this.quantizationParams = quantizationParams;
- this.meanThresholds = mean;
- }
+ private float[] meanThresholds;
+ private static final long serialVersionUID = 1L; // Version ID for serialization
@Override
- public SQParams getQuantizationParams() {
+ public ScalarQuantizationParams getQuantizationParams() {
return quantizationParams;
}
/**
- * Returns the mean values used in the quantization process.
+ * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ *
+ * Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized
+ * as part of the state to access the version information and implement version-specific logic if needed.
+ *
+ * The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable,
+ * ensuring that the version is available to all classes involved in the serialization process within the current thread context.
*
- * @return an array of mean values.
+ *
+ * {@code
+ * // Example usage in the writeExternal method:
+ * VersionContext.setVersion(version);
+ * out.writeInt(version); // Write the version
+ * quantizationParams.writeExternal(out);
+ * out.writeInt(meanThresholds.length);
+ * for (float mean : meanThresholds) {
+ * out.writeFloat(mean);
+ * }
+ * }
+ *
+ *
+ * @param out the ObjectOutput to write the object to.
+ * @throws IOException if an I/O error occurs during serialization.
*/
- public float[] getMeanThresholds() {
- return meanThresholds;
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+     final int version = Version.CURRENT.id;
+     // Publish the version through VersionContext so nested writeExternal calls can read it.
+     VersionContext.setVersion(version);
+     try {
+         out.writeInt(version); // Write the version
+         quantizationParams.writeExternal(out);
+         // Length-prefixed array: element count followed by each mean threshold.
+         out.writeInt(meanThresholds.length);
+         for (float mean : meanThresholds) {
+             out.writeFloat(mean);
+         }
+     } finally {
+         // Clear the version after use, even when a write fails.
+         VersionContext.clear();
+     }
+ }
+ /**
+ * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ *
+ * The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method.
+ * This makes the version information available to all classes involved in the deserialization process within the current thread context.
+ *
+ * Classes that are part of the deserialization process can retrieve the version information using the
+ * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.
+ *
+ *
+ * {@code
+ * // Example usage in the readExternal method:
+ * int version = in.readInt(); // Read the version
+ * VersionContext.setVersion(version);
+ * quantizationParams = new ScalarQuantizationParams();
+ * quantizationParams.readExternal(in); // Use readExternal of ScalarQuantizationParams
+ * int length = in.readInt();
+ * meanThresholds = new float[length];
+ * for (int i = 0; i < length; i++) {
+ * meanThresholds[i] = in.readFloat();
+ * }
+ * }
+ *
+ *
+ * @param in the ObjectInput to read the object from.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of the serialized object cannot be found.
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+     final int version = in.readInt(); // Read the version
+     // Publish the version so nested readExternal calls can apply version-specific logic.
+     VersionContext.setVersion(version);
+     try {
+         quantizationParams = new ScalarQuantizationParams();
+         quantizationParams.readExternal(in); // Use readExternal of ScalarQuantizationParams
+         // Length-prefixed array written by writeExternal: element count, then the values.
+         final int length = in.readInt();
+         meanThresholds = new float[length];
+         for (int i = 0; i < length; i++) {
+             meanThresholds[i] = in.readFloat();
+         }
+     } finally {
+         // Clear the version after use, even when the stream is truncated or corrupt.
+         VersionContext.clear();
+     }
+ }
+
+ /**
+ * Serializes the current state of this OneBitScalarQuantizationState object into a byte array.
+ * This method uses the QuantizationStateSerializer to handle the serialization process.
+ *
+ * The serialized byte array includes all necessary state information, such as the mean thresholds
+ * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.
+ *
+ *
+ * {@code
+ * OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, meanThresholds);
+ * byte[] serializedState = state.toByteArray();
+ * }
+ *
+ *
+ * @return a byte array representing the serialized state of this object.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
@Override
public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, meanThresholds);
}
+ /**
+ * Deserializes a OneBitScalarQuantizationState object from a byte array.
+ * This method uses the QuantizationStateSerializer to handle the deserialization process.
+ *
+ * The byte array should contain serialized state information, including the mean thresholds
+ * and quantization parameters, which are necessary to reconstruct the OneBitScalarQuantizationState object.
+ *
+ *
+ * {@code
+ * byte[] serializedState = ...; // obtain the byte array from some source
+ * OneBitScalarQuantizationState state = OneBitScalarQuantizationState.fromByteArray(serializedState);
+ * }
+ *
+ *
+ * @param bytes the byte array containing the serialized state.
+ * @return the deserialized OneBitScalarQuantizationState object.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of a serialized object cannot be found.
+ */
public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
- (parentParams, meanThresholds) -> new OneBitScalarQuantizationState((SQParams) parentParams, (float[]) meanThresholds)
+ new OneBitScalarQuantizationState(),
+ (parentParams, meanThresholds) -> new OneBitScalarQuantizationState(
+ (ScalarQuantizationParams) parentParams,
+ (float[]) meanThresholds
+ )
);
}
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
index c17ff0641..d3778fe29 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
@@ -7,14 +7,14 @@
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import java.io.Externalizable;
import java.io.IOException;
-import java.io.Serializable;
/**
* QuantizationState interface represents the state of a quantization process, including the parameters used.
* This interface provides methods for serializing and deserializing the state.
*/
-public interface QuantizationState extends Serializable {
+public interface QuantizationState extends Externalizable {
/**
* Returns the quantization parameters associated with this state.
*
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java
new file mode 100644
index 000000000..5f00d8e0c
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import lombok.experimental.UtilityClass;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.io.IOException;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
+
+/**
+ * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing
+ * QuantizationState objects along with their specific data.
+ */
+@UtilityClass
+class QuantizationStateSerializer {
+
+ /**
+ * A functional interface for deserializing specific data associated with a QuantizationState.
+ */
+ @FunctionalInterface
+ interface SerializableDeserializer {
+ QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData);
+ }
+
+ /**
+ * Serializes the QuantizationState and specific data into a byte array.
+ *
+ * @param state The QuantizationState to serialize.
+ * @param specificData The specific data related to the state, to be serialized.
+ * @return A byte array representing the serialized state and specific data.
+ * @throws IOException If an I/O error occurs during serialization.
+ */
+ static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException {
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ state.writeExternal(out);
+ out.writeObject(specificData);
+ out.flush();
+ return bos.toByteArray();
+ }
+ }
+
+ /**
+ * Deserializes a QuantizationState and its specific data from a byte array.
+ *
+ * @param bytes The byte array containing the serialized data.
+ * @param stateInstance An instance of the state to call readExternal on.
+ * @param specificDataDeserializer The deserializer for the specific data associated with the state.
+ * @return The deserialized QuantizationState including its specific data.
+ * @throws IOException If an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException If the class of the serialized object cannot be found.
+ */
+ static QuantizationState deserialize(byte[] bytes, QuantizationState stateInstance, SerializableDeserializer specificDataDeserializer)
+ throws IOException, ClassNotFoundException {
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) {
+ stateInstance.readExternal(in);
+ Serializable specificData = (Serializable) in.readObject(); // Read the specific data
+ return specificDataDeserializer.deserialize(stateInstance.getQuantizationParams(), specificData);
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
index 14689ea47..54ebe311c 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
@@ -5,64 +5,21 @@
package org.opensearch.knn.quantization.models.requests;
-import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
/**
* TrainingRequest represents a request for training a quantizer.
*
* @param the type of vectors to be trained.
*/
+@Getter
+@AllArgsConstructor
public abstract class TrainingRequest {
- private final QuantizationParams params;
- private final int totalNumberOfVectors;
- private int[] sampledIndices;
-
- /**
- * Constructs a TrainingRequest with the given parameters and total number of vectors.
- *
- * @param params the quantization parameters.
- * @param totalNumberOfVectors the total number of vectors.
- */
- protected TrainingRequest(final QuantizationParams params, final int totalNumberOfVectors) {
- this.params = params;
- this.totalNumberOfVectors = totalNumberOfVectors;
- }
-
- /**
- * Returns the quantization parameters.
- *
- * @return the quantization parameters.
- */
- public QuantizationParams getParams() {
- return params;
- }
-
/**
- * Returns the total number of vectors.
- *
- * @return the total number of vectors.
- */
- public int getTotalNumberOfVectors() {
- return totalNumberOfVectors;
- }
-
- /**
- * Sets the sampled indices for this training request.
- *
- * @param sampledIndices the sampled indices.
+ * The total number of vectors in one segment.
*/
- public void setSampledIndices(int[] sampledIndices) {
- this.sampledIndices = sampledIndices;
- }
-
- /**
- * Returns the sampled indices for this training request.
- *
- * @return the sampled indices.
- */
- public int[] getSampledIndices() {
- return sampledIndices;
- }
+ private final int totalNumberOfVectors;
/**
* Returns the vector corresponding to the specified document ID.
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
index 0143a6614..c6d366e5b 100644
--- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
@@ -6,19 +6,18 @@
package org.opensearch.knn.quantization.quantizer;
-import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;
-import org.opensearch.knn.quantization.util.BitPacker;
-import org.opensearch.knn.quantization.util.QuantizerHelper;
-import java.util.ArrayList;
-import java.util.List;
+import java.io.IOException;
+import java.util.BitSet;
/**
* MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension.
@@ -41,7 +40,7 @@ public class MultiBitScalarQuantizer implements Quantizer {
* @param bitsPerCoordinate the number of bits used per coordinate for quantization.
*/
public MultiBitScalarQuantizer(final int bitsPerCoordinate) {
- this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR));
+ this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR));
}
/**
@@ -70,15 +69,16 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi
*/
@Override
public QuantizationState train(final TrainingRequest trainingRequest) {
- SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest);
- int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
-
- int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length;
+ BitSet sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+ int dimension = trainingRequest.getVectorByDocId(sampledIndices.nextSetBit(0)).length;
float[] meanArray = new float[dimension];
float[] stdDevArray = new float[dimension];
// Calculate sum, mean, and standard deviation in one pass
QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray);
float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension);
+ ScalarQuantizationParams params = (bitsPerCoordinate == 2)
+ ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT)
+ : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
return new MultiBitScalarQuantizationState(params, thresholds);
}
@@ -88,10 +88,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) {
*
* @param vector the vector to quantize.
* @param state the quantization state containing threshold information.
- * @return a BinaryQuantizationOutput containing the quantized data.
+ * @param output the QuantizationOutput object to store the quantized representation of the vector.
+ * @throws IOException if an I/O error occurs during quantization.
*/
@Override
- public QuantizationOutput quantize(final float[] vector, final QuantizationState state) {
+ public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) throws IOException {
if (vector == null) {
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
@@ -101,17 +102,22 @@ public QuantizationOutput quantize(final float[] vector, final Quantizat
if (thresholds == null || thresholds[0].length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
-
- List bitArrays = new ArrayList<>();
+ // Directly pack bits without intermediate array
+ int totalBits = bitsPerCoordinate * vector.length;
+ int byteLength = (totalBits + 7) >> 3; // Calculate byte length needed
+ byte[] packedBits = new byte[byteLength];
for (int i = 0; i < bitsPerCoordinate; i++) {
- byte[] bitArray = new byte[vector.length];
for (int j = 0; j < vector.length; j++) {
- bitArray[j] = (byte) (vector[j] > thresholds[i][j] ? 1 : 0);
+ if (vector[j] > thresholds[i][j]) {
+ int bitPosition = i * vector.length + j;
+ int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8
+ int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8)
+ packedBits[byteIndex] |= (1 << bitIndex); // Set the bit
+ }
}
- bitArrays.add(bitArray);
}
- return new BinaryQuantizationOutput(BitPacker.packBits(bitArrays));
+ output.updateQuantizedVector(packedBits);
}
/**
@@ -145,4 +151,13 @@ private void validateState(final QuantizationState state) {
throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState.");
}
}
+
+ /**
+ * Returns the number of bits per coordinate used by this quantizer.
+ *
+ * @return the number of bits per coordinate.
+ */
+ public int getBitsPerCoordinate() {
+ return bitsPerCoordinate;
+ }
}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
index 2eaa07ce0..eab3d992e 100644
--- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
@@ -5,18 +5,18 @@
package org.opensearch.knn.quantization.quantizer;
-import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;
-import org.opensearch.knn.quantization.util.BitPacker;
-import org.opensearch.knn.quantization.util.QuantizerHelper;
-import java.util.Collections;
+import java.io.IOException;
+import java.util.BitSet;
/**
* OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension.
@@ -37,7 +37,7 @@ public class OneBitScalarQuantizer implements Quantizer {
* Constructs a OneBitScalarQuantizer with a default sampling size of 25000.
*/
public OneBitScalarQuantizer() {
- this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR));
+ this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR));
}
/**
@@ -49,7 +49,6 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) {
this.samplingSize = samplingSize;
this.sampler = sampler;
- ;
}
/**
@@ -61,10 +60,9 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) {
*/
@Override
public QuantizationState train(final TrainingRequest trainingRequest) {
- SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest);
- int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
- float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices);
- return new OneBitScalarQuantizationState(params, mean);
+ BitSet sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+ float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds);
+ return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds);
}
/**
@@ -73,10 +71,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) {
*
* @param vector the vector to quantize.
* @param state the quantization state containing the means for each dimension.
+ * @param output the QuantizationOutput object to store the quantized representation of the vector.
* @return a BinaryQuantizationOutput containing the quantized data.
*/
@Override
- public QuantizationOutput quantize(final float[] vector, final QuantizationState state) {
+ public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) throws IOException {
if (vector == null) {
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
@@ -86,11 +85,19 @@ public QuantizationOutput quantize(final float[] vector, final Quantizat
if (thresholds == null || thresholds.length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
- byte[] quantizedVector = new byte[vector.length];
+ // Directly pack bits without intermediate array
+ int byteLength = (vector.length + 7) >> 3; // Calculate byte length needed
+ byte[] packedBits = new byte[byteLength];
+
for (int i = 0; i < vector.length; i++) {
- quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0);
+ if (vector[i] > thresholds[i]) {
+ int byteIndex = i >> 3; // Equivalent to i / 8
+ int bitIndex = 7 - (i & 7); // Equivalent to 7 - (i % 8)
+ packedBits[byteIndex] |= (1 << bitIndex); // Set the bit
+ }
}
- return new BinaryQuantizationOutput(BitPacker.packBits(Collections.singletonList(quantizedVector)));
+
+ output.updateQuantizedVector(packedBits);
}
/**
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
index 8231a8aa2..beabd8d73 100644
--- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
@@ -1,14 +1,11 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
package org.opensearch.knn.quantization.quantizer;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import java.io.IOException;
+
/**
* The Quantizer interface defines the methods required for training and quantizing vectors
* in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks.
@@ -34,7 +31,7 @@ public interface Quantizer {
*
* @param vector the vector to quantize.
* @param state the quantization state containing parameters for quantization.
- * @return a QuantizationOutput containing the quantized representation of the vector.
+ * @param output the QuantizationOutput object to store the quantized representation of the vector.
*/
- QuantizationOutput quantize(T vector, QuantizationState state);
+ void quantize(T vector, QuantizationState state, QuantizationOutput output) throws IOException;
}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java
similarity index 67%
rename from src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java
rename to src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java
index adc4e34c4..e95632606 100644
--- a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java
@@ -1,40 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
- *
*/
-package org.opensearch.knn.quantization.util;
+package org.opensearch.knn.quantization.quantizer;
-import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import lombok.experimental.UtilityClass;
+import java.util.BitSet;
+
/**
* Utility class providing common methods for quantizer operations, such as parameter validation and
* extraction. This class is designed to be used with various quantizer implementations that require
* consistent handling of training requests and sampled indices.
*/
@UtilityClass
-public class QuantizerHelper {
-
- /**
- * Validates the provided training request to ensure it contains non-null quantization parameters.
- * Extracts and returns the SQParams from the training request.
- *
- * @param trainingRequest the training request to validate and extract parameters from.
- * @return the extracted SQParams.
- * @throws IllegalArgumentException if the SQParams are null.
- */
- public static SQParams validateAndExtractParams(TrainingRequest> trainingRequest) {
- QuantizationParams params = trainingRequest.getParams();
- if (params == null || !(params instanceof SQParams)) {
- throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams.");
- }
- return (SQParams) params;
- }
-
+class QuantizerHelper {
/**
* Calculates the mean vector from a set of sampled vectors.
*
@@ -49,13 +31,13 @@ public static SQParams validateAndExtractParams(TrainingRequest> trainingReque
* @throws IllegalArgumentException If any of the vectors at the sampled indices are null.
* @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors.
*/
- public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) {
- int totalSamples = sampledIndices.length;
+ static float[] calculateMeanThresholds(TrainingRequest samplingRequest, BitSet sampledIndices) {
+ int totalSamples = sampledIndices.cardinality();
float[] mean = null;
- for (int index : sampledIndices) {
- float[] vector = samplingRequest.getVectorByDocId(index);
+ for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) {
+ float[] vector = samplingRequest.getVectorByDocId(docId);
if (vector == null) {
- throw new IllegalArgumentException("Vector at sampled index " + index + " is null.");
+ throw new IllegalArgumentException("Vector at sampled index " + docId + " is null.");
}
if (mean == null) {
mean = new float[vector.length];
@@ -81,20 +63,20 @@ public static float[] calculateMean(TrainingRequest samplingRequest, in
* @param meanArray the array to store the sum and then the mean of each dimension.
* @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension.
*/
- public static void calculateSumMeanAndStdDev(
+ static void calculateSumMeanAndStdDev(
TrainingRequest trainingRequest,
- int[] sampledIndices,
+ BitSet sampledIndices,
float[] meanArray,
float[] stdDevArray
) {
- int totalSamples = sampledIndices.length;
+ int totalSamples = sampledIndices.cardinality();
int dimension = meanArray.length;
// Single pass to calculate sum and sum of squares
- for (int index : sampledIndices) {
- float[] vector = trainingRequest.getVectorByDocId(index);
+ for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) {
+ float[] vector = trainingRequest.getVectorByDocId(docId);
if (vector == null) {
- throw new IllegalArgumentException("Vector at sampled index " + index + " is null.");
+ throw new IllegalArgumentException("Vector at sampled index " + docId + " is null.");
}
for (int j = 0; j < dimension; j++) {
meanArray[j] += vector[j];
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
index da5327def..22a7bae23 100644
--- a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
@@ -5,10 +5,10 @@
package org.opensearch.knn.quantization.sampler;
-import java.util.Arrays;
-import java.util.Random;
+import lombok.NoArgsConstructor;
+
+import java.util.BitSet;
import java.util.concurrent.ThreadLocalRandom;
-import java.util.stream.IntStream;
/**
* ReservoirSampler implements the Sampler interface and provides a method for sampling
@@ -16,33 +16,23 @@
* This algorithm is particularly useful for randomly sampling a subset of data from a larger set
* when the total size of the dataset is unknown or very large.
*/
+@NoArgsConstructor
final class ReservoirSampler implements Sampler {
-
- private final Random random;
-
- /**
- * Constructs a ReservoirSampler with a new Random instance.
- */
- public ReservoirSampler() {
- this(ThreadLocalRandom.current());
- }
-
/**
- * Constructs a ReservoirSampler with a specified random seed for reproducibility.
- *
- * @param seed the seed for the random number generator.
+ * Singleton instance holder.
*/
- public ReservoirSampler(final long seed) {
- this(new Random(seed));
- }
+ private static ReservoirSampler instance;
/**
- * Constructs a ReservoirSampler with a specified Random instance.
+ * Provides the singleton instance of ReservoirSampler.
*
- * @param random the Random instance for generating random numbers.
+ * @return the singleton instance of ReservoirSampler.
*/
- public ReservoirSampler(final Random random) {
- this.random = random;
+ public static synchronized ReservoirSampler getInstance() {
+ if (instance == null) {
+ instance = new ReservoirSampler();
+ }
+ return instance;
}
/**
@@ -55,9 +45,11 @@ public ReservoirSampler(final Random random) {
* @return an array of sampled indices.
*/
@Override
- public int[] sample(final int totalNumberOfVectors, final int sampleSize) {
+ public BitSet sample(final int totalNumberOfVectors, final int sampleSize) {
if (totalNumberOfVectors <= sampleSize) {
- return IntStream.range(0, totalNumberOfVectors).toArray();
+ BitSet bitSet = new BitSet(totalNumberOfVectors);
+ bitSet.set(0, totalNumberOfVectors);
+ return bitSet;
}
return reservoirSampleIndices(totalNumberOfVectors, sampleSize);
}
@@ -67,19 +59,30 @@ public int[] sample(final int totalNumberOfVectors, final int sampleSize) {
* This method ensures that each index in the range [0, numVectors) has an equal probability
* of being included in the sample.
*
+ * Reservoir sampling is particularly useful for selecting a random sample from a large or unknown-sized dataset.
+ * For more information on the algorithm, see the following link:
+ * <a href="https://en.wikipedia.org/wiki/Reservoir_sampling">Reservoir Sampling - Wikipedia</a>
+ *
* @param numVectors the total number of vectors.
* @param sampleSize the number of indices to sample.
- * @return an array of sampled indices.
+ * @return a BitSet representing the sampled indices.
*/
- private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) {
- int[] indices = IntStream.range(0, sampleSize).toArray();
+ private BitSet reservoirSampleIndices(final int numVectors, final int sampleSize) {
+ int[] indices = new int[sampleSize];
+ for (int i = 0; i < sampleSize; i++) {
+ indices[i] = i;
+ }
for (int i = sampleSize; i < numVectors; i++) {
- int j = random.nextInt(i + 1);
+ int j = ThreadLocalRandom.current().nextInt(i + 1);
if (j < sampleSize) {
indices[j] = i;
}
}
- Arrays.sort(indices);
- return indices;
+ // Using BitSet to track the presence of indices
+ BitSet bitSet = new BitSet(numVectors);
+ for (int i = 0; i < sampleSize; i++) {
+ bitSet.set(indices[i]);
+ }
+ return bitSet;
}
}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
index 9021073b4..17834cf04 100644
--- a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
@@ -5,6 +5,23 @@
package org.opensearch.knn.quantization.sampler;
+import java.util.BitSet;
+
+/**
+ * The Sampler interface defines the contract for sampling strategies
+ * used in various quantization processes. Implementations of this
+ * interface should provide specific strategies for selecting a sample
+ * from a given set of vectors.
+ */
public interface Sampler {
- int[] sample(int totalNumberOfVectors, int sampleSize);
+
+ /**
+ * Samples a subset of indices from the total number of vectors.
+ *
+ * @param totalNumberOfVectors the total number of vectors available.
+ * @param sampleSize the number of vectors to be sampled.
+ * @return a BitSet whose set bits are the indices of the sampled vectors.
+ * @throws IllegalArgumentException if an implementation rejects the arguments; note that ReservoirSampler does not throw when sampleSize >= totalNumberOfVectors and returns every index instead.
+ */
+ BitSet sample(int totalNumberOfVectors, int sampleSize);
}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java
new file mode 100644
index 000000000..cd9b301df
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+/**
+ * SamplerType is an enumeration of the different types of samplers that can be created by the factory.
+ */
+public enum SamplerType {
+ RESERVOIR, // Represents a reservoir sampling strategy
+ // Add more enum values here for additional sampler types
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
index be228fe6f..80fe5bdae 100644
--- a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
@@ -5,28 +5,16 @@
package org.opensearch.knn.quantization.sampler;
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+
/**
* SamplingFactory is a factory class for creating instances of Sampler.
* It uses the factory design pattern to encapsulate the creation logic for different types of samplers.
*/
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class SamplingFactory {
- /**
- * Private constructor to prevent instantiation of this class.
- * The class is not meant to be instantiated, as it provides static methods only.
- */
- private SamplingFactory() {
-
- }
-
- /**
- * SamplerType is an enumeration of the different types of samplers that can be created by the factory.
- */
- public enum SamplerType {
- RESERVOIR, // Represents a reservoir sampling strategy
- // Add more enum values here for additional sampler types
- }
-
/**
* Creates and returns a Sampler instance based on the specified SamplerType.
*
@@ -37,7 +25,7 @@ public enum SamplerType {
public static Sampler getSampler(final SamplerType samplerType) {
switch (samplerType) {
case RESERVOIR:
- return new ReservoirSampler();
+ return ReservoirSampler.getInstance();
// Add more cases for different samplers here
default:
throw new IllegalArgumentException("Unsupported sampler type: " + samplerType);
diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java
deleted file mode 100644
index 5d99a892f..000000000
--- a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- *
- */
-
-package org.opensearch.knn.quantization.util;
-
-import lombok.experimental.UtilityClass;
-
-import java.util.List;
-
-/**
- * Utility class for bit packing operations.
- * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission.
- */
-@UtilityClass
-public class BitPacker {
-
- /**
- * Packs the list of bit arrays into a single byte array.
- * Each byte in the resulting array contains up to 8 bits from the bit arrays, packed from left to right.
- *
- * @param bitArrays the list of bit arrays to be packed. Each bit array should contain only 0s and 1s.
- * @return a byte array containing the packed bits.
- * @throws IllegalArgumentException if the bitArrays list is empty, if any bit array is null, or if bit arrays have inconsistent lengths.
- */
- public static byte[] packBits(List bitArrays) {
- if (bitArrays.isEmpty()) {
- throw new IllegalArgumentException("The list of bit arrays cannot be empty.");
- }
-
- int bitArrayLength = bitArrays.get(0).length;
- int bitLength = bitArrays.size() * bitArrayLength;
- int byteLength = (bitLength + 7) / 8;
- byte[] packedArray = new byte[byteLength];
-
- int bitPosition = 0;
- for (byte[] bitArray : bitArrays) {
- if (bitArray == null) {
- throw new IllegalArgumentException("Bit array cannot be null.");
- }
- if (bitArray.length != bitArrayLength) {
- throw new IllegalArgumentException("All bit arrays must have the same length.");
- }
-
- for (byte bit : bitArray) {
- int byteIndex = bitPosition / 8;
- int bitIndex = 7 - (bitPosition % 8);
- if (bit == 1) {
- packedArray[byteIndex] |= (1 << bitIndex);
- }
- bitPosition++;
- }
- }
-
- return packedArray;
- }
-}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java
deleted file mode 100644
index 89b3b67bd..000000000
--- a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.knn.quantization.util;
-
-import lombok.experimental.UtilityClass;
-import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
-import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
-
-import java.io.ByteArrayOutputStream;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
-import java.io.IOException;
-import java.io.ByteArrayInputStream;
-import java.io.ObjectInputStream;
-
-/**
- * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing
- * QuantizationState objects along with their specific data.
- */
-@UtilityClass
-public class QuantizationStateSerializer {
-
- /**
- * A functional interface for deserializing specific data associated with a QuantizationState.
- */
- @FunctionalInterface
- public interface SerializableDeserializer {
- QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData);
- }
-
- /**
- * Serializes the QuantizationState and specific data into a byte array.
- *
- * @param state The QuantizationState to serialize.
- * @param specificData The specific data related to the state, to be serialized.
- * @return A byte array representing the serialized state and specific data.
- * @throws IOException If an I/O error occurs during serialization.
- */
- public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException {
- byte[] parentBytes = serializeParentParams(state.getQuantizationParams());
- try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) {
- out.writeInt(parentBytes.length); // Write the length of the parent bytes
- out.write(parentBytes); // Write the parent bytes
- out.writeObject(specificData); // Write the specific data
- out.flush();
- return bos.toByteArray();
- }
- }
-
- /**
- * Deserializes a QuantizationState and its specific data from a byte array.
- *
- * @param bytes The byte array containing the serialized data.
- * @param specificDataDeserializer The deserializer for the specific data associated with the state.
- * @return The deserialized QuantizationState including its specific data.
- * @throws IOException If an I/O error occurs during deserialization.
- * @throws ClassNotFoundException If the class of the serialized object cannot be found.
- */
- public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer) throws IOException,
- ClassNotFoundException {
- try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) {
- int parentLength = in.readInt();
- // Read the length of the parent bytes
- byte[] parentBytes = new byte[parentLength];
- in.readFully(parentBytes); // Read the parent bytes
- QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params
- Serializable specificData = (Serializable) in.readObject(); // Read the specific data
- return specificDataDeserializer.deserialize(parentParams, specificData);
- }
- }
-
- /**
- * Serializes the parent parameters of the QuantizationState into a byte array.
- *
- * @param params The QuantizationParams to serialize.
- * @return A byte array representing the serialized parent parameters.
- * @throws IOException If an I/O error occurs during serialization.
- */
- private static byte[] serializeParentParams(QuantizationParams params) throws IOException {
- try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) {
- out.writeObject(params);
- out.flush();
- return bos.toByteArray();
- }
- }
-
- /**
- * Deserializes the parent parameters of the QuantizationState from a byte array.
- *
- * @param bytes The byte array containing the serialized parent parameters.
- * @return The deserialized QuantizationParams.
- * @throws IOException If an I/O error occurs during deserialization.
- * @throws ClassNotFoundException If the class of the serialized object cannot be found.
- */
- private static QuantizationParams deserializeParentParams(byte[] bytes) throws IOException, ClassNotFoundException {
- try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) {
- return (QuantizationParams) in.readObject();
- }
- }
-}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java b/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java
new file mode 100644
index 000000000..7746305ab
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import lombok.experimental.UtilityClass;
+
+/**
+ * Utility class that keeps a version id in ThreadLocal storage, so each thread sees only the
+ * version it set itself.
+ *
+ * <p>Lombok's {@code @UtilityClass} makes the field and methods below static. A value stays attached
+ * to its thread until {@link #clear()} is called, so code running on reused (pooled) threads should
+ * clear it when done.
+ */
+@UtilityClass
+public class VersionContext {
+
+    /**
+     * ThreadLocal storage for version information; empty on a thread until
+     * {@link #setVersion(int)} has been called there.
+     */
+    private final ThreadLocal<Integer> versionHolder = new ThreadLocal<>();
+
+    /**
+     * Sets the version for the current thread.
+     *
+     * @param version the version to be set.
+     */
+    public void setVersion(int version) {
+        versionHolder.set(version);
+    }
+
+    /**
+     * Gets the version for the current thread.
+     *
+     * @return the version previously set on the current thread.
+     * @throws IllegalStateException if no version has been set on the current thread.
+     */
+    public int getVersion() {
+        final Integer version = versionHolder.get();
+        if (version == null) {
+            // Fail with a descriptive error instead of a NullPointerException from unboxing null.
+            throw new IllegalStateException("No version has been set for the current thread");
+        }
+        return version;
+    }
+
+    /**
+     * Clears the version for the current thread.
+     */
+    public void clear() {
+        versionHolder.remove();
+    }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
deleted file mode 100644
index 598a07867..000000000
--- a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.knn.quantization.enums;
-
-import org.opensearch.knn.KNNTestCase;
-
-public class QuantizationTypeTests extends KNNTestCase {
-
- public void testQuantizationTypeValues() {
- QuantizationType[] expectedValues = { QuantizationType.SPACE, QuantizationType.VALUE };
- assertArrayEquals(expectedValues, QuantizationType.values());
- }
-
- public void testQuantizationTypeValueOf() {
- assertEquals(QuantizationType.SPACE, QuantizationType.valueOf("SPACE"));
- assertEquals(QuantizationType.VALUE, QuantizationType.valueOf("VALUE"));
- }
-}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java
similarity index 74%
rename from src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
rename to src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java
index 3498114a6..815f81071 100644
--- a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java
@@ -7,13 +7,12 @@
import org.opensearch.knn.KNNTestCase;
-public class SQTypesTests extends KNNTestCase {
+public class ScalarQuantizationTypeTests extends KNNTestCase {
public void testSQTypesValues() {
ScalarQuantizationType[] expectedValues = {
ScalarQuantizationType.ONE_BIT,
ScalarQuantizationType.TWO_BIT,
- ScalarQuantizationType.FOUR_BIT,
- ScalarQuantizationType.UNSUPPORTED_TYPE };
+ ScalarQuantizationType.FOUR_BIT };
assertArrayEquals(expectedValues, ScalarQuantizationType.values());
}
@@ -21,6 +20,5 @@ public void testSQTypesValueOf() {
assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT"));
assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT"));
assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT"));
- assertEquals(ScalarQuantizationType.UNSUPPORTED_TYPE, ScalarQuantizationType.valueOf("UNSUPPORTED_TYPE"));
}
}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
deleted file mode 100644
index 3da665630..000000000
--- a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
+++ /dev/null
@@ -1,19 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.knn.quantization.enums;
-
-import org.opensearch.knn.KNNTestCase;
-
-public class ValueQuantizationTypeTests extends KNNTestCase {
- public void testValueQuantizationTypeValues() {
- ValueQuantizationType[] expectedValues = { ValueQuantizationType.SCALAR };
- assertArrayEquals(expectedValues, ValueQuantizationType.values());
- }
-
- public void testValueQuantizationTypeValueOf() {
- assertEquals(ValueQuantizationType.SCALAR, ValueQuantizationType.valueOf("SCALAR"));
- }
-}
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
index 42fc18eba..34aa907b8 100644
--- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
@@ -8,7 +8,7 @@
import org.junit.Before;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;
@@ -27,7 +27,7 @@ public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessEx
}
public void test_Lazy_Registration() {
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
assertFalse(isRegisteredFieldAccessible());
Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
assertTrue(quantizer instanceof OneBitScalarQuantizer);
@@ -35,33 +35,23 @@ public void test_Lazy_Registration() {
}
public void testGetQuantizer_withOneBitSQParams() {
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
assertTrue(quantizer instanceof OneBitScalarQuantizer);
}
public void testGetQuantizer_withTwoBitSQParams() {
- SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
assertTrue(quantizer instanceof MultiBitScalarQuantizer);
}
public void testGetQuantizer_withFourBitSQParams() {
- SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
assertTrue(quantizer instanceof MultiBitScalarQuantizer);
}
- public void testGetQuantizer_withUnsupportedType() {
- SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE);
- try {
- QuantizerFactory.getQuantizer(params);
- fail("Expected IllegalArgumentException");
- } catch (IllegalArgumentException e) {
- assertTrue(e.getMessage().contains("No quantizer registered for type identifier"));
- }
- }
-
public void testGetQuantizer_withNullParams() {
try {
QuantizerFactory.getQuantizer(null);
@@ -73,7 +63,7 @@ public void testGetQuantizer_withNullParams() {
public void testConcurrentRegistration() throws InterruptedException {
Runnable task = () -> {
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizerFactory.getQuantizer(params);
};
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
index 7f53dae8c..2d743f883 100644
--- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
@@ -7,9 +7,8 @@
import org.junit.BeforeClass;
import org.opensearch.knn.KNNTestCase;
-import org.opensearch.knn.quantization.enums.QuantizationType;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;
@@ -18,49 +17,56 @@ public class QuantizerRegistryTests extends KNNTestCase {
@BeforeClass
public static void setup() {
- // Register the quantizers for testing with enums
- QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new);
QuantizerRegistry.register(
- SQParams.class,
- QuantizationType.VALUE,
- ScalarQuantizationType.TWO_BIT,
- () -> new MultiBitScalarQuantizer(2)
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT),
+ new OneBitScalarQuantizer()
);
QuantizerRegistry.register(
- SQParams.class,
- QuantizationType.VALUE,
- ScalarQuantizationType.FOUR_BIT,
- () -> new MultiBitScalarQuantizer(4)
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT),
+ new MultiBitScalarQuantizer(2)
+ );
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT),
+ new MultiBitScalarQuantizer(4)
);
}
public void testRegisterAndGetQuantizer() {
// Test for OneBitScalarQuantizer
- SQParams oneBitParams = new SQParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
Quantizer, ?> oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer);
// Test for MultiBitScalarQuantizer (2-bit)
- SQParams twoBitParams = new SQParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
Quantizer, ?> twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer);
+ assertEquals(2, ((MultiBitScalarQuantizer) twoBitQuantizer).getBitsPerCoordinate());
// Test for MultiBitScalarQuantizer (4-bit)
- SQParams fourBitParams = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
Quantizer, ?> fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer);
+ assertEquals(4, ((MultiBitScalarQuantizer) fourBitQuantizer).getBitsPerCoordinate());
}
- public void testGetQuantizer_withUnsupportedTypeIdentifier() {
- // Create SQParams with an unsupported type identifier
- SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered
+ public void testQuantizerRegistryIsSingleton() {
+ // Ensure the same instance is returned for the same type identifier
+ ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ Quantizer, ?> firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ Quantizer, ?> secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ assertSame(firstOneBitQuantizer, secondOneBitQuantizer);
- // Expect IllegalArgumentException when requesting a quantizer with unsupported params
- IllegalArgumentException exception = assertThrows(
- IllegalArgumentException.class,
- () -> { QuantizerRegistry.getQuantizer(params); }
- );
+ // Ensure the same instance is returned for the same type identifier (2-bit)
+ ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ Quantizer, ?> firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ Quantizer, ?> secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ assertSame(firstTwoBitQuantizer, secondTwoBitQuantizer);
- assertTrue(exception.getMessage().contains("No quantizer registered for type identifier"));
+ // Ensure the same instance is returned for the same type identifier (4-bit)
+ ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ Quantizer, ?> firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ Quantizer, ?> secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ assertSame(firstFourBitQuantizer, secondFourBitQuantizer);
}
}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
index ebe6bf6bd..974f58637 100644
--- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
@@ -7,7 +7,7 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
@@ -16,7 +16,7 @@
public class QuantizationStateSerializerTests extends KNNTestCase {
public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException {
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
float[] mean = new float[] { 0.1f, 0.2f, 0.3f };
OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
@@ -28,14 +28,16 @@ public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IO
}
public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException {
- SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } };
MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
byte[] serialized = state.toByteArray();
MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized);
- assertArrayEquals(thresholds, deserialized.getThresholds());
+ for (int i = 0; i < thresholds.length; i++) {
+ assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f);
+ }
assertEquals(params, deserialized.getQuantizationParams());
}
}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
index 54e304732..5a6b4b1db 100644
--- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
@@ -8,53 +8,76 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
-import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.Version;
+import org.opensearch.knn.quantization.util.VersionContext;
import java.io.IOException;
public class QuantizationStateTests extends KNNTestCase {
public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException {
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
float[] mean = { 1.0f, 2.0f, 3.0f };
OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+ // Set the version for serialization
+ VersionContext.setVersion(Version.CURRENT.id);
+
+ // Serialize
byte[] serializedState = state.toByteArray();
+
+ // Deserialize
OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState);
- float delta = 0.0001f;
+ float delta = 0.0001f;
assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta);
- assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType());
}
public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException {
- SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } };
MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+ // Set the version for serialization
+ VersionContext.setVersion(Version.CURRENT.id);
+
+ // Serialize
byte[] serializedState = state.toByteArray();
+
+ // Deserialize
MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState);
- float delta = 0.0001f;
+ float delta = 0.0001f;
for (int i = 0; i < thresholds.length; i++) {
assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i], delta);
}
- assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType());
}
- public void testDefaultQuantizationStateSerialization() throws IOException, ClassNotFoundException {
- SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE);
+ public void testSerializationWithDifferentVersions() throws IOException, ClassNotFoundException {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ float[] mean = { 1.0f, 2.0f, 3.0f };
+
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
- DefaultQuantizationState state = new DefaultQuantizationState(params);
+ // Simulate an older version
+ VersionContext.setVersion(Version.V_2_0_0.id);
+ // Serialize
byte[] serializedState = state.toByteArray();
- DefaultQuantizationState deserializedState = DefaultQuantizationState.fromByteArray(serializedState);
- assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ // Update to a new version and deserialize
+ VersionContext.setVersion(Version.CURRENT.id);
+ OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState);
+
+ float delta = 0.0001f;
+ assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta);
+ assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType());
}
}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
index 231da3dfe..ad6a44686 100644
--- a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
@@ -7,12 +7,14 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import java.io.IOException;
+
public class MultiBitScalarQuantizerTests extends KNNTestCase {
public void testTrain_twoBit() {
@@ -21,10 +23,8 @@ public void testTrain_twoBit() {
{ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
{ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
- int[] sampledIndices = { 0, 1, 2 };
- SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
TrainingRequest request = new MockTrainingRequest(params, vectors);
- request.setSampledIndices(sampledIndices);
QuantizationState state = twoBitQuantizer.train(request);
assertTrue(state instanceof MultiBitScalarQuantizationState);
@@ -39,10 +39,8 @@ public void testTrain_fourBit() {
{ 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
{ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
{ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
- int[] sampledIndices = { 0, 1, 2 };
- SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
TrainingRequest request = new MockTrainingRequest(params, vectors);
- request.setSampledIndices(sampledIndices);
QuantizationState state = fourBitQuantizer.train(request);
assertTrue(state instanceof MultiBitScalarQuantizationState);
@@ -51,19 +49,19 @@ public void testTrain_fourBit() {
assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds
}
- public void testQuantize_twoBit() {
+ public void testQuantize_twoBit() throws IOException {
MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
float[][] thresholds = { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
- SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
- QuantizationOutput output = twoBitQuantizer.quantize(vector, state);
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ twoBitQuantizer.quantize(vector, state, output);
assertNotNull(output.getQuantizedVector());
- assertEquals(2, output.getQuantizedVector().length);
}
- public void testQuantize_fourBit() {
+ public void testQuantize_fourBit() throws IOException {
MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
float[][] thresholds = {
@@ -71,38 +69,33 @@ public void testQuantize_fourBit() {
{ 1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f },
{ 1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f },
{ 1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f } };
- SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
- QuantizationOutput output = fourBitQuantizer.quantize(vector, state);
- assertEquals(4, output.getQuantizedVector().length);
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ fourBitQuantizer.quantize(vector, state, output);
assertNotNull(output.getQuantizedVector());
}
- public void testQuantize_withNullVector() {
+ public void testQuantize_withNullVector() throws IOException {
MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
expectThrows(
IllegalArgumentException.class,
() -> twoBitQuantizer.quantize(
null,
- new MultiBitScalarQuantizationState(new SQParams(ScalarQuantizationType.TWO_BIT), new float[2][8])
+ new MultiBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT), new float[2][8]),
+ output
)
);
}
- public void testQuantize_withInvalidState() {
- MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
- float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
- QuantizationState invalidState = new MockInvalidQuantizationState();
- expectThrows(IllegalArgumentException.class, () -> twoBitQuantizer.quantize(vector, invalidState));
- }
-
// Mock classes for testing
private static class MockTrainingRequest extends TrainingRequest {
private final float[][] vectors;
- public MockTrainingRequest(SQParams params, float[][] vectors) {
- super(params, vectors.length);
+ public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) {
+ super(vectors.length);
this.vectors = vectors;
}
@@ -111,16 +104,4 @@ public float[] getVectorByDocId(int docId) {
return vectors[docId];
}
}
-
- private static class MockInvalidQuantizationState implements QuantizationState {
- @Override
- public SQParams getQuantizationParams() {
- return new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE);
- }
-
- @Override
- public byte[] toByteArray() {
- return new byte[0];
- }
- }
}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
index 12e8a43a8..8372ac5d2 100644
--- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
@@ -7,22 +7,27 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
-import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
-import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;
-import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.BitSet;
public class OneBitScalarQuantizerTests extends KNNTestCase {
public void testTrain_withTrainingRequired() {
float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
- TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest originalRequest = new TrainingRequest(vectors.length) {
@Override
public float[] getVectorByDocId(int docId) {
return vectors[docId];
@@ -32,88 +37,107 @@ public float[] getVectorByDocId(int docId) {
QuantizationState state = quantizer.train(originalRequest);
assertTrue(state instanceof OneBitScalarQuantizationState);
- float[] mean = ((OneBitScalarQuantizationState) state).getMeanThresholds();
- assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f);
+ float[] meanThresholds = ((OneBitScalarQuantizationState) state).getMeanThresholds();
+ assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f);
}
- public void testQuantize_withState() {
+ public void testQuantize_withState() throws IOException {
float[] vector = { 3.0f, 6.0f, 9.0f };
float[] thresholds = { 4.0f, 5.0f, 6.0f };
- OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds);
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
+ new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
+ thresholds
+ );
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
- QuantizationOutput output = quantizer.quantize(vector, state);
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ quantizer.quantize(vector, state, output);
assertNotNull(output);
byte[] expectedPackedBits = new byte[] { 0b01100000 }; // or 96 in decimal
assertArrayEquals(expectedPackedBits, output.getQuantizedVector());
}
- public void testQuantize_withNullVector() {
+ public void testQuantize_withNullVector() throws IOException {
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
- new SQParams(ScalarQuantizationType.ONE_BIT),
+ new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
new float[] { 0.0f }
);
- expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state));
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state, output));
}
- public void testQuantize_withInvalidState() {
+ public void testQuantize_withInvalidState() throws IOException {
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
float[] vector = { 1.0f, 2.0f, 3.0f };
QuantizationState invalidState = new QuantizationState() {
@Override
- public SQParams getQuantizationParams() {
- return new SQParams(ScalarQuantizationType.ONE_BIT);
+ public ScalarQuantizationParams getQuantizationParams() {
+ return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
}
@Override
public byte[] toByteArray() {
return new byte[0];
}
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // no-op
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ // no-op
+ }
};
- expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState));
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState, output));
}
- public void testQuantize_withMismatchedDimensions() {
+ public void testQuantize_withMismatchedDimensions() throws IOException {
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
float[] vector = { 1.0f, 2.0f, 3.0f };
float[] thresholds = { 4.0f, 5.0f };
- OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds);
-
- expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state));
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
+ new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
+ thresholds
+ );
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output));
}
public void testCalculateMean() {
float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
- TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest samplingRequest = new TrainingRequest(vectors.length) {
@Override
public float[] getVectorByDocId(int docId) {
return vectors[docId];
}
};
- Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
- int[] sampledIndices = sampler.sample(vectors.length, 3);
- float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices);
- assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f);
+ Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
+ BitSet sampledIndices = sampler.sample(vectors.length, 3);
+ float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices);
+ assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f);
}
public void testCalculateMean_withNullVector() {
float[][] vectors = { { 1.0f, 2.0f, 3.0f }, null, { 7.0f, 8.0f, 9.0f } };
- SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
- TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest samplingRequest = new TrainingRequest(vectors.length) {
@Override
public float[] getVectorByDocId(int docId) {
return vectors[docId];
}
};
- Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
- int[] sampledIndices = sampler.sample(vectors.length, 3);
- expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMean(samplingRequest, sampledIndices));
+ Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
+ BitSet sampledIndices = sampler.sample(vectors.length, 3);
+ expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices));
}
}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
index 4d3345289..e930aef04 100644
--- a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
@@ -7,76 +7,55 @@
import org.opensearch.knn.KNNTestCase;
-import java.util.Arrays;
-import java.util.stream.IntStream;
+import java.util.BitSet;
public class ReservoirSamplerTests extends KNNTestCase {
public void testSampleLessThanSampleSize() {
- ReservoirSampler sampler = new ReservoirSampler();
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
int totalNumberOfVectors = 5;
int sampleSize = 10;
- int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
- int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
- assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+ BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ BitSet expectedIndices = new BitSet(totalNumberOfVectors);
+ expectedIndices.set(0, totalNumberOfVectors);
+ assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
}
public void testSampleEqualToSampleSize() {
- ReservoirSampler sampler = new ReservoirSampler();
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
int totalNumberOfVectors = 10;
int sampleSize = 10;
- int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
- int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
- assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
- }
-
- public void testSampleGreaterThanSampleSize() {
- ReservoirSampler sampler = new ReservoirSampler(12345); // Fixed seed for reproducibility
- int totalNumberOfVectors = 100;
- int sampleSize = 10;
- int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
- assertEquals(sampleSize, sampledIndices.length);
- assertTrue(Arrays.stream(sampledIndices).allMatch(i -> i >= 0 && i < totalNumberOfVectors));
- }
-
- public void testSampleReproducibility() {
- long seed = 12345L;
- ReservoirSampler sampler1 = new ReservoirSampler(seed);
- ReservoirSampler sampler2 = new ReservoirSampler(seed);
- int totalNumberOfVectors = 100;
- int sampleSize = 10;
-
- int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
- int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
-
- assertArrayEquals(sampledIndices1, sampledIndices2);
+ BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ BitSet expectedIndices = new BitSet(totalNumberOfVectors);
+ expectedIndices.set(0, totalNumberOfVectors);
+ assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
}
public void testSampleRandomness() {
- ReservoirSampler sampler1 = new ReservoirSampler();
- ReservoirSampler sampler2 = new ReservoirSampler();
+ ReservoirSampler sampler1 = ReservoirSampler.getInstance();
+ ReservoirSampler sampler2 = ReservoirSampler.getInstance();
int totalNumberOfVectors = 100;
int sampleSize = 10;
- int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
- int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
+ BitSet sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
+ BitSet sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
- assertNotEquals(Arrays.toString(sampledIndices1), Arrays.toString(sampledIndices2));
+ assertNotEquals(sampledIndices1, sampledIndices2);
}
public void testEdgeCaseZeroVectors() {
- ReservoirSampler sampler = new ReservoirSampler();
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
int totalNumberOfVectors = 0;
int sampleSize = 10;
- int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
- assertEquals(0, sampledIndices.length);
+ BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(0, sampledIndices.cardinality());
}
public void testEdgeCaseZeroSampleSize() {
- ReservoirSampler sampler = new ReservoirSampler();
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
int totalNumberOfVectors = 10;
int sampleSize = 0;
- int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
- assertEquals(0, sampledIndices.length);
+ BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(0, sampledIndices.cardinality());
}
}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
index ca72c1c5e..db8772b70 100644
--- a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
@@ -9,7 +9,7 @@
public class SamplingFactoryTests extends KNNTestCase {
public void testGetSampler_withReservoir() {
- Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
assertTrue(sampler instanceof ReservoirSampler);
}
diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
deleted file mode 100644
index c91c7177b..000000000
--- a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.knn.quantization.util;
-
-import org.opensearch.knn.KNNTestCase;
-
-import java.util.Arrays;
-import java.util.List;
-
-public class BitPackingUtilsTests extends KNNTestCase {
-
- public void testPackBits() {
- List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, new byte[] { 1, 0, 1, 0, 0, 1, 0, 0 });
-
- byte[] expectedPackedArray = new byte[] { (byte) 0b01011011, (byte) 0b10100100 };
- byte[] packedArray = BitPacker.packBits(bitArrays);
-
- assertArrayEquals(expectedPackedArray, packedArray);
- }
-
- public void testPackBitsEmptyList() {
- IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(Arrays.asList()); });
- assertEquals("The list of bit arrays cannot be empty.", exception.getMessage());
- }
-
- public void testPackBitsNullBitArray() {
- List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, null);
-
- IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); });
- assertEquals("Bit array cannot be null.", exception.getMessage());
- }
-
- public void testPackBitsInconsistentLength() {
- List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, new byte[] { 1, 0, 1 });
-
- IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); });
- assertEquals("All bit arrays must have the same length.", exception.getMessage());
- }
-
- public void testPackBitsEdgeCaseSingleBitArray() {
- List bitArrays = Arrays.asList(new byte[] { 1 });
-
- byte[] expectedPackedArray = new byte[] { (byte) 0b10000000 };
- byte[] packedArray = BitPacker.packBits(bitArrays);
-
- assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray);
- }
-
- public void testPackBitsEdgeCaseSingleBit() {
- List bitArrays = Arrays.asList(new byte[] { 1, 0, 1, 0, 1, 0, 1, 0 }, new byte[] { 1, 1, 1, 1, 1, 1, 1, 1 });
-
- byte[] expectedPackedArray = new byte[] { (byte) 0b10101010, (byte) 0b11111111 };
- byte[] packedArray = BitPacker.packBits(bitArrays);
-
- assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray);
- }
-
- public void testPackBits_emptyArray() {
- List bitArrays = Arrays.asList();
- expectThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); });
- ;
- }
-}