Skip to content

Commit

Permalink
Implemented Serlization using Writable
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
  • Loading branch information
Vikasht34 committed Aug 12, 2024
1 parent 0ebbef0 commit 12f3d72
Show file tree
Hide file tree
Showing 20 changed files with 385 additions and 393 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@

package org.opensearch.knn.quantization.models.quantizationOutput;

import lombok.NoArgsConstructor;
import lombok.Getter;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import lombok.NoArgsConstructor;

/**
* The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
Expand All @@ -18,23 +15,27 @@
@NoArgsConstructor
public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
@Getter
private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
private byte[] quantizedVector;

/**
* Updates the quantized vector with a new byte array.
*
* @param newQuantizedVector the new quantized vector represented as a byte array.
*/
public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException {
public void updateQuantizedVector(final byte[] newQuantizedVector) {
if (newQuantizedVector == null || newQuantizedVector.length == 0) {
throw new IllegalArgumentException("Quantized vector cannot be null or empty");
}
byteArrayOutputStream.reset();
byteArrayOutputStream.write(newQuantizedVector);
this.quantizedVector = newQuantizedVector;
}

/**
* Returns the quantized vector.
*
* @return the quantized vector byte array.
*/
@Override
public byte[] getQuantizedVector() {
return byteArrayOutputStream.toByteArray();
return quantizedVector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.quantization.models.quantizationParams;

import java.io.Externalizable;
import org.opensearch.core.common.io.stream.Writeable;

/**
* Interface for quantization parameters.
Expand All @@ -14,7 +14,7 @@
* Implementations of this interface are expected to provide specific configurations
* for various quantization strategies.
*/
public interface QuantizationParams extends Externalizable {
public interface QuantizationParams extends Writeable {
/**
* Provides a unique identifier for the quantization parameters.
* This identifier is typically a combination of the quantization type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Locale;

/**
* The SQParams class represents the parameters specific to scalar quantization (SQ).
* The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ).
* This class implements the QuantizationParams interface and includes the type of scalar quantization.
*/
@Getter
Expand All @@ -39,67 +39,40 @@ public static String generateTypeIdentifier(ScalarQuantizationType sqType) {
}

/**
* Serializes the SQParams object to an external output.
* This method writes the scalar quantization type to the output stream.
* Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type.
* This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
*
* @param out the ObjectOutput to write the object to.
* @throws IOException if an I/O error occurs during serialization.
* @return A string representing the unique type identifier.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
// The version is already written by the parent state class, no need to write it here again
// Retrieve the current version from VersionContext
// This context will be used by other classes involved in the serialization process.
// Example:
// int version = VersionContext.getVersion(); // Get the current version from VersionContext
// Any Version Specific logic can be wriiten based on Version
out.writeObject(sqType);
public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id);
}

/**
* Deserializes the SQParams object from an external input with versioning.
* This method reads the scalar quantization type and new field from the input stream based on the version.
* Writes the object to the output stream.
* This method is part of the Writeable interface and is used to serialize the object.
*
* @param in the ObjectInput to read the object from.
* @throws IOException if an I/O error occurs during deserialization.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
* @param out the output stream to write the object to.
* @throws IOException if an I/O error occurs.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
// The version is already read by the parent state class and set in VersionContext
// Retrieve the current version from VersionContext to handle version-specific deserialization logic
// int versionId = VersionContext.getVersion();
// Version version = Version.fromId(versionId);

sqType = (ScalarQuantizationType) in.readObject();

// Add version-specific deserialization logic
// For example, if new fields are added in a future version, handle them here
// This section contains conditional logic to handle different versions appropriately.
// Example:
// if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) {
// // Handle logic for versions between 1.0.0 and 2.0.0
// // Example: Read additional fields introduced in version 1.0.0
// // newField = in.readInt();
// } else if (version.onOrAfter(Version.V_2_0_0)) {
// // Handle logic for versions 2.0.0 and above
// // Example: Read additional fields introduced in version 2.0.0
// // anotherNewField = in.readFloat();
// }
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(sqType);
}

/**
* Provides a unique type identifier for the SQParams, combining the SQ type.
* This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
* Reads the object from the input stream.
* This method is part of the Writeable interface and is used to deserialize the object.
*
* @return A string representing the unique type identifier.
* @param in the input stream to read the object from.
* @throws IOException if an I/O error occurs.
*/
@Override
public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id);
public ScalarQuantizationParams(StreamInput in, int version) throws IOException {
this.sqType = in.readEnum(ScalarQuantizationType.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/**
* DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
Expand All @@ -27,16 +27,22 @@ public class DefaultQuantizationState implements QuantizationState {
private QuantizationParams params;
private static final long serialVersionUID = 1L; // Version ID for serialization

/**
* Returns the quantization parameters associated with this state.
*
* @return the quantization parameters.
*/
@Override
public QuantizationParams getQuantizationParams() {
return params;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(Version.CURRENT.id); // Write the version
params.writeTo(out);
}

public DefaultQuantizationState(StreamInput in) throws IOException {
int version = in.readInt(); // Read the version
this.params = new ScalarQuantizationParams(in, version);
}

/**
* Serializes the quantization state to a byte array.
*
Expand All @@ -45,7 +51,7 @@ public QuantizationParams getQuantizationParams() {
*/
@Override
public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, null);
return QuantizationStateSerializer.serialize(this);
}

/**
Expand All @@ -57,36 +63,6 @@ public byte[] toByteArray() throws IOException {
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
new DefaultQuantizationState(),
(parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams)
);
}

/**
* Writes the object to the output stream.
* This method is part of the Externalizable interface and is used to serialize the object.
*
* @param out the output stream to write the object to.
* @throws IOException if an I/O error occurs.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(Version.CURRENT.id); // Write the version
out.writeObject(params);
}

/**
* Reads the object from the input stream.
* This method is part of the Externalizable interface and is used to deserialize the object.
*
* @param in the input stream to read the object from.
* @throws IOException if an I/O error occurs.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.params = (QuantizationParams) in.readObject();
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new);
}
}
Loading

0 comments on commit 12f3d72

Please sign in to comment.