Skip to content

Commit

Permalink
Refactor method structure and definitions (opensearch-project#1920)
Browse files Browse the repository at this point in the history
* Remove getMethod from KNNLibrary interface

Removes the getMethod from KNNLibrary interface. Users should not
directly retrieve methods from the engines as this creates additional
complexity. Instead, they should get everything they need about a method
from the KNNLibrary itself. This change will allow us to better maintain
the different methods.

* Move EngineSpecificMethodContext into KNNMethod

Moving the EngineSpecificMethodContext into the KNNMethod class.
EngineSpecificMethodContext is mainly used during search to provide
parameter validation/configure dynamic updates. Moving into KNNMethod
will encapsulate the structure of the method into one class.

* Change KNNMethod to interface

Changes KNNMethod to interface and creates methods per engine/algo
combination. For one, this will make code much more maintainable as we
wont have to deal with big maps and builders. For the other, we can
implement more complex logic between engines.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Add simple encoder interface

Adds simple encoder interface to clean up definitions of the encoders.

---------

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 authored Aug 2, 2024
1 parent ec6451c commit adaf150
Show file tree
Hide file tree
Showing 36 changed files with 1,031 additions and 795 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897)
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.engine.KNNMethod;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down Expand Up @@ -56,7 +55,6 @@ public class FaissSQIT extends AbstractRestartUpgradeTestCase {

public void testHNSWSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Exception {
if (!isRunningAgainstOldCluster()) {
KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT };
Random random = new Random();
SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)];
Expand Down Expand Up @@ -97,7 +95,7 @@ public void testHNSWSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws E
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
.field(KNNConstants.NAME, KNNConstants.METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(KNNConstants.PARAMETERS)
Expand Down Expand Up @@ -133,8 +131,6 @@ public void testHNSWSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws E

public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenSucceed() throws Exception {
if (!isRunningAgainstOldCluster()) {
KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
new Random();

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
Expand Down Expand Up @@ -175,7 +171,7 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R
.field("type", "knn_vector")
.field("dimension", dimension)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
.field(KNNConstants.NAME, KNNConstants.METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(PARAMETERS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,56 +20,49 @@
public abstract class AbstractKNNLibrary implements KNNLibrary {

protected final Map<String, KNNMethod> methods;
protected final Map<String, EngineSpecificMethodContext> engineMethods;
@Getter
protected final String version;

@Override
public KNNMethod getMethod(String methodName) {
KNNMethod method = methods.get(methodName);
if (method == null) {
throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName));
}
return method;
}

@Override
public EngineSpecificMethodContext getMethodContext(String methodName) {
EngineSpecificMethodContext method = engineMethods.get(methodName);
if (method == null) {
throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName));
}
return method;
validateMethodExists(methodName);
KNNMethod method = methods.get(methodName);
return method.getEngineSpecificMethodContext();
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
return getMethod(methodName).validate(knnMethodContext);
validateMethodExists(methodName);
return methods.get(methodName).validate(knnMethodContext);
}

@Override
public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
return getMethod(methodName).validateWithData(knnMethodContext, vectorSpaceInfo);
validateMethodExists(methodName);
return methods.get(methodName).validateWithData(knnMethodContext, vectorSpaceInfo);
}

@Override
public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
return getMethod(methodName).isTrainingRequired(knnMethodContext);
validateMethodExists(methodName);
return methods.get(methodName).isTrainingRequired(knnMethodContext);
}

@Override
public Map<String, Object> getMethodAsMap(KNNMethodContext knnMethodContext) {
KNNMethod knnMethod = methods.get(knnMethodContext.getMethodComponentContext().getName());
String method = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(method);
KNNMethod knnMethod = methods.get(method);
return knnMethod.getAsMap(knnMethodContext);
}

if (knnMethod == null) {
throw new IllegalArgumentException(
String.format("Invalid method name: %s", knnMethodContext.getMethodComponentContext().getName())
);
private void validateMethodExists(String methodName) {
KNNMethod method = methods.get(methodName);
if (method == null) {
throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName));
}

return knnMethod.getAsMap(knnMethodContext);
}
}
119 changes: 119 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
* It defines the common attributes and methods that all KNN methods should implement.
*/
@AllArgsConstructor
public abstract class AbstractKNNMethod implements KNNMethod {

protected final MethodComponent methodComponent;
protected final Set<SpaceType> spaces;
protected final EngineSpecificMethodContext engineSpecificMethodContext;

@Override
public boolean isSpaceTypeSupported(SpaceType space) {
return spaces.contains(space);
}

@Override
public ValidationException validate(KNNMethodContext knnMethodContext) {
List<String> errorMessages = new ArrayList<>();
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
Locale.ROOT,
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
knnMethodContext.getSpaceType().getValue()
)
);
}

ValidationException methodValidation = methodComponent.validate(knnMethodContext.getMethodComponentContext());
if (methodValidation != null) {
errorMessages.addAll(methodValidation.validationErrors());
}

if (errorMessages.isEmpty()) {
return null;
}

ValidationException validationException = new ValidationException();
validationException.addValidationErrors(errorMessages);
return validationException;
}

@Override
public ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) {
List<String> errorMessages = new ArrayList<>();
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
Locale.ROOT,
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
knnMethodContext.getSpaceType().getValue()
)
);
}

ValidationException methodValidation = methodComponent.validateWithData(
knnMethodContext.getMethodComponentContext(),
vectorSpaceInfo
);
if (methodValidation != null) {
errorMessages.addAll(methodValidation.validationErrors());
}

if (errorMessages.isEmpty()) {
return null;
}

ValidationException validationException = new ValidationException();
validationException.addValidationErrors(errorMessages);
return validationException;
}

@Override
public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
return methodComponent.isTrainingRequired(knnMethodContext.getMethodComponentContext());
}

@Override
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
return methodComponent.estimateOverheadInKB(knnMethodContext.getMethodComponentContext(), dimension);
}

@Override
public Map<String, Object> getAsMap(KNNMethodContext knnMethodContext) {
Map<String, Object> parameterMap = new HashMap<>(methodComponent.getAsMap(knnMethodContext.getMethodComponentContext()));
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
return parameterMap;
}

@Override
public EngineSpecificMethodContext getEngineSpecificMethodContext() {
return engineSpecificMethodContext;
}
}
27 changes: 27 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/Encoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

/**
* Interface representing an encoder. An encoder generally refers to a vector quantizer.
*/
public interface Encoder {
/**
* The name of the encoder does not have to be unique. Howevwer, when using within a method, there cannot be
* 2 encoders with the same name.
*
* @return Name of the encoder
*/
default String getName() {
return getMethodComponent().getName();
}

/**
*
* @return Method component associated with the encoder
*/
MethodComponent getMethodComponent();
}
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/engine/JVMLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ public abstract class JVMLibrary extends AbstractKNNLibrary {
* @param methods Map of k-NN methods that the library supports
* @param version String representing version of library
*/
public JVMLibrary(Map<String, KNNMethod> methods, Map<String, EngineSpecificMethodContext> engineMethodMetadataMap, String version) {
super(methods, engineMethodMetadataMap, version);
public JVMLibrary(Map<String, KNNMethod> methods, String version) {
super(methods, version);
}

@Override
Expand Down
5 changes: 0 additions & 5 deletions src/main/java/org/opensearch/knn/index/engine/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,6 @@ public String getCompoundExtension() {
return knnLibrary.getCompoundExtension();
}

@Override
public KNNMethod getMethod(String methodName) {
return knnLibrary.getMethod(methodName);
}

@Override
public EngineSpecificMethodContext getMethodContext(String methodName) {
return knnLibrary.getMethodContext(methodName);
Expand Down
9 changes: 0 additions & 9 deletions src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,6 @@ public interface KNNLibrary {
*/
String getCompoundExtension();

/**
* Gets a particular KNN method that the library supports. This should throw an exception if the method is not
* supported by the library.
*
* @param methodName name of the method to be looked up
* @return KNNMethod in the library corresponding to the method name
*/
KNNMethod getMethod(String methodName);

/**
* Gets metadata related to methods supported by the library
* @param methodName
Expand Down
Loading

0 comments on commit adaf150

Please sign in to comment.