Skip to content

Commit

Permalink
Implemented Serialization using Writeable
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
  • Loading branch information
Vikasht34 committed Aug 12, 2024
1 parent 0ebbef0 commit b71c94e
Show file tree
Hide file tree
Showing 26 changed files with 462 additions and 508 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,20 @@ public enum ScalarQuantizationType {
/**
 * Creates an enum constant and stores the supplied numeric identifier in its {@code id} field.
 *
 * @param id the numeric identifier assigned to this constant.
 */
ScalarQuantizationType(int id) {
this.id = id;
}

/**
 * Looks up the ScalarQuantizationType constant whose ID equals the supplied value.
 *
 * @param id the numeric ID to resolve.
 * @return the ScalarQuantizationType constant carrying that ID.
 * @throws IllegalArgumentException if none of the constants carries the given ID.
 */
public static ScalarQuantizationType fromId(int id) {
    final ScalarQuantizationType[] candidates = ScalarQuantizationType.values();
    for (int i = 0; i < candidates.length; i++) {
        final ScalarQuantizationType candidate = candidates[i];
        if (candidate.getId() != id) {
            continue;
        }
        return candidate;
    }
    // No constant matched: report the offending ID to the caller.
    throw new IllegalArgumentException("Unknown ScalarQuantizationType ID: " + id);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
final class QuantizerRegistrar {

/**
* Registers default quantizers if not already registered.
* Registers default quantizers
* <p>
* This method is synchronized to ensure that registration occurs only once,
* even in a multi-threaded environment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ final class QuantizerRegistry {
* @param quantizer an instance of the quantizer
*/
static void register(final String paramIdentifier, final Quantizer<?, ?> quantizer) {
    // putIfAbsent must be invoked exactly once: a second invocation would observe the mapping
    // created by the first one and report a duplicate even for a brand-new identifier.
    final Quantizer<?, ?> alreadyRegistered = registry.putIfAbsent(paramIdentifier, quantizer);
    if (alreadyRegistered != null) {
        // A previous registration exists for this identifier; refuse to silently keep or replace it.
        throw new IllegalArgumentException("Quantizer already registered for identifier: " + paramIdentifier);
    }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

package org.opensearch.knn.quantization.models.quantizationOutput;

import lombok.NoArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;

/**
* The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
Expand All @@ -18,23 +17,62 @@
@NoArgsConstructor
public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
@Getter
private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
private byte[] quantizedVector;

/**
* Updates the quantized vector with a new byte array.
* Prepares the quantized vector array based on the provided parameters and returns it for direct modification.
* This method ensures that the internal byte array is appropriately sized and cleared before being used.
*
* <p>
* The method accepts two parameters:
* <ul>
* <li><b>bitsPerCoordinate:</b> The number of bits used per coordinate. This determines the granularity of the quantization.</li>
* <li><b>vectorLength:</b> The length of the original vector that needs to be quantized. This helps in calculating the required byte array size.</li>
* </ul>
* </p>
*
* <p>
* If the existing quantized vector is either null or not the same size as the required byte array,
* a new byte array is allocated. Otherwise, the existing array is cleared (i.e., all bytes are set to zero).
* </p>
*
* @param newQuantizedVector the new quantized vector represented as a byte array.
* <p>
* This method is designed to be used in conjunction with a bit-packing utility that writes quantized values directly
* into the returned byte array.
* </p>
*
* @param params an array of parameters, where the first parameter is the number of bits per coordinate (int),
* and the second parameter is the length of the vector (int).
* @return the prepared and writable quantized vector as a byte array.
* @throws IllegalArgumentException if the parameters are not as expected (e.g., missing or not integers).
*/
public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException {
if (newQuantizedVector == null || newQuantizedVector.length == 0) {
throw new IllegalArgumentException("Quantized vector cannot be null or empty");
@Override
public byte[] prepareAndGetWritableQuantizedVector(Object... params) {
if (params.length != 2 || !(params[0] instanceof Integer) || !(params[1] instanceof Integer)) {
throw new IllegalArgumentException("Expected two integer parameters: bitsPerCoordinate and vectorLength");
}
byteArrayOutputStream.reset();
byteArrayOutputStream.write(newQuantizedVector);
int bitsPerCoordinate = (int) params[0];
int vectorLength = (int) params[1];
int totalBits = bitsPerCoordinate * vectorLength;
int byteLength = (totalBits + 7) >> 3;

if (this.quantizedVector == null || this.quantizedVector.length != byteLength) {
this.quantizedVector = new byte[byteLength];
} else {
Arrays.fill(this.quantizedVector, (byte) 0);
}

return this.quantizedVector;
}


/**
* Returns the quantized vector.
*
* @return the quantized vector byte array.
*/
@Override
public byte[] getQuantizedVector() {
return byteArrayOutputStream.toByteArray();
return quantizedVector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public interface QuantizationOutput<T> {
T getQuantizedVector();

/**
* Updates the quantized vector with new data.
* Prepares and returns the writable quantized vector for direct modification.
*
* @param newQuantizedVector the new quantized vector data.
* @throws IOException if an I/O error occurs during the update.
* @param params the parameters needed for preparing the quantized vector.
* @return the prepared and writable quantized vector.
*/
void updateQuantizedVector(T newQuantizedVector) throws IOException;
T prepareAndGetWritableQuantizedVector(Object... params);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.quantization.models.quantizationParams;

import java.io.Externalizable;
import org.opensearch.core.common.io.stream.Writeable;

/**
* Interface for quantization parameters.
Expand All @@ -14,7 +14,7 @@
* Implementations of this interface are expected to provide specific configurations
* for various quantization strategies.
*/
public interface QuantizationParams extends Externalizable {
public interface QuantizationParams extends Writeable {
/**
* Provides a unique identifier for the quantization parameters.
* This identifier is typically a combination of the quantization type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Locale;

/**
* The SQParams class represents the parameters specific to scalar quantization (SQ).
* The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ).
* This class implements the QuantizationParams interface and includes the type of scalar quantization.
*/
@Getter
Expand All @@ -26,7 +25,6 @@
@EqualsAndHashCode
public class ScalarQuantizationParams implements QuantizationParams {
private ScalarQuantizationType sqType;
private static final long serialVersionUID = 1L; // Version ID for serialization

/**
* Static method to generate type identifier based on ScalarQuantizationType.
Expand All @@ -39,67 +37,41 @@ public static String generateTypeIdentifier(ScalarQuantizationType sqType) {
}

/**
* Serializes the SQParams object to an external output.
* This method writes the scalar quantization type to the output stream.
* Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type.
* This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
*
* @param out the ObjectOutput to write the object to.
* @throws IOException if an I/O error occurs during serialization.
* @return A string representing the unique type identifier.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
// The version is already written by the parent state class, no need to write it here again
// Retrieve the current version from VersionContext
// This context will be used by other classes involved in the serialization process.
// Example:
// int version = VersionContext.getVersion(); // Get the current version from VersionContext
// Any Version Specific logic can be wriiten based on Version
out.writeObject(sqType);
public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return "ScalarQuantizationParams_" + id;
}

/**
* Deserializes the SQParams object from an external input with versioning.
* This method reads the scalar quantization type and new field from the input stream based on the version.
* Writes the object to the output stream.
* This method is part of the Writeable interface and is used to serialize the object.
*
* @param in the ObjectInput to read the object from.
* @throws IOException if an I/O error occurs during deserialization.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
* @param out the output stream to write the object to.
* @throws IOException if an I/O error occurs.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
// The version is already read by the parent state class and set in VersionContext
// Retrieve the current version from VersionContext to handle version-specific deserialization logic
// int versionId = VersionContext.getVersion();
// Version version = Version.fromId(versionId);

sqType = (ScalarQuantizationType) in.readObject();

// Add version-specific deserialization logic
// For example, if new fields are added in a future version, handle them here
// This section contains conditional logic to handle different versions appropriately.
// Example:
// if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) {
// // Handle logic for versions between 1.0.0 and 2.0.0
// // Example: Read additional fields introduced in version 1.0.0
// // newField = in.readInt();
// } else if (version.onOrAfter(Version.V_2_0_0)) {
// // Handle logic for versions 2.0.0 and above
// // Example: Read additional fields introduced in version 2.0.0
// // anotherNewField = in.readFloat();
// }
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(sqType.getId());
}

/**
* Provides a unique type identifier for the SQParams, combining the SQ type.
* This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
* Reads the object from the input stream.
* This method is part of the Writeable interface and is used to deserialize the object.
*
* @return A string representing the unique type identifier.
* @param in the input stream to read the object from.
* @throws IOException if an I/O error occurs.
*/
@Override
public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id);
public ScalarQuantizationParams(StreamInput in, int version) throws IOException {
int typeId = in.readVInt();
this.sqType = ScalarQuantizationType.fromId(typeId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/**
* DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
Expand All @@ -25,18 +25,23 @@
@AllArgsConstructor
public class DefaultQuantizationState implements QuantizationState {
private QuantizationParams params;
private static final long serialVersionUID = 1L; // Version ID for serialization

/**
 * Exposes the quantization parameters held by this state.
 *
 * @return the value of the {@code params} field.
 */
@Override
public QuantizationParams getQuantizationParams() {
    return this.params;
}

/**
 * Serializes this state to the given stream: first the ID of the current version as an int,
 * then the quantization parameters via their own {@code writeTo}.
 *
 * @param out the stream to write to.
 * @throws IOException if writing to the stream fails.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    final int currentVersionId = Version.CURRENT.id;
    out.writeInt(currentVersionId); // version header comes first
    this.params.writeTo(out);
}

public DefaultQuantizationState(StreamInput in) throws IOException {
int version = in.readInt(); // Read the version
this.params = new ScalarQuantizationParams(in, version);
}

/**
* Serializes the quantization state to a byte array.
*
Expand All @@ -45,7 +50,7 @@ public QuantizationParams getQuantizationParams() {
*/
@Override
public byte[] toByteArray() throws IOException {
    // Single delegation to the serializer; a second return statement here would be unreachable
    // code and is rejected by the compiler.
    return QuantizationStateSerializer.serialize(this);
}

/**
Expand All @@ -57,36 +62,6 @@ public byte[] toByteArray() throws IOException {
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
new DefaultQuantizationState(),
(parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams)
);
}

/**
* Writes the object to the output stream.
* This method is part of the Externalizable interface and is used to serialize the object.
*
* @param out the output stream to write the object to.
* @throws IOException if an I/O error occurs.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(Version.CURRENT.id); // Write the version
out.writeObject(params);
}

/**
* Reads the object from the input stream.
* This method is part of the Externalizable interface and is used to deserialize the object.
*
* @param in the input stream to read the object from.
* @throws IOException if an I/O error occurs.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.params = (QuantizationParams) in.readObject();
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new);
}
}
Loading

0 comments on commit b71c94e

Please sign in to comment.