Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #633: prevent use of "$similarity" in projection #635

Merged
merged 13 commits into from
Nov 13, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ public enum ErrorCode {

VECTOR_SEARCH_INVALID_FUNCTION_NAME("Invalid vector search function name: "),

VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED(
"$similarity projection is not supported for this command"),

VECTOR_SEARCH_FIELD_TOO_BIG("Vector embedding field '$vector' length too big"),
VECTORIZE_SERVICE_NOT_REGISTERED("Vectorize service name provided is not registered : "),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,14 @@ private static class PathCollector {
private Boolean idInclusion = null;

/** Whether similarity score is needed. */
private boolean includeSimilarityScore;
private final boolean includeSimilarityScore;

private PathCollector() {}
private PathCollector(boolean includeSimilarityScore) {
this.includeSimilarityScore = includeSimilarityScore;
}

static PathCollector collectPaths(JsonNode def, boolean includeSimilarity) {
return new PathCollector().collectFromObject(def, null, includeSimilarity);
return new PathCollector(includeSimilarity).collectFromObject(def, null);
}

public DocumentProjector buildProjector() {
Expand Down Expand Up @@ -171,7 +173,7 @@ boolean isIdentityProjection() {
return paths.isEmpty() && slices.isEmpty() && !Boolean.FALSE.equals(idInclusion);
}

PathCollector collectFromObject(JsonNode ob, String parentPath, boolean includeSimilarity) {
PathCollector collectFromObject(JsonNode ob, String parentPath) {
var it = ob.fields();
while (it.hasNext()) {
var entry = it.next();
Expand All @@ -185,8 +187,7 @@ PathCollector collectFromObject(JsonNode ob, String parentPath, boolean includeS
}
if (path.charAt(0) == '$'
&& !(path.equals(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD)
|| DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD.equals(path)
|| DocumentConstants.Fields.VECTOR_FUNCTION_PROJECTION_FIELD.equals(path))) {
|| DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD.equals(path))) {
// First: no operators allowed at root level
if (parentPath == null) {
throw new JsonApiException(
Expand All @@ -208,20 +209,6 @@ PathCollector collectFromObject(JsonNode ob, String parentPath, boolean includeS
addSlice(parentPath, entry.getValue());
continue;
}
// This `or` is needed because the method is called in loop
includeSimilarityScore = includeSimilarityScore || includeSimilarity;
if (parentPath == null
&& DocumentConstants.Fields.VECTOR_FUNCTION_PROJECTION_FIELD.equals(path)) {
JsonNode value = entry.getValue();
if (BigDecimal.ZERO.equals(value.decimalValue()) && includeSimilarity) {
throw new JsonApiException(
ErrorCode.UNSUPPORTED_PROJECTION_PARAM,
ErrorCode.UNSUPPORTED_PROJECTION_PARAM.getMessage()
+ ": Cannot exclude $similarity when `includeSimilarity` option is set `true`");
} else {
includeSimilarityScore = true;
}
}

if (parentPath != null) {
path = parentPath + "." + path;
Expand All @@ -242,7 +229,7 @@ PathCollector collectFromObject(JsonNode ob, String parentPath, boolean includeS
}
} else if (value.isObject()) {
// Nested definitions allowed, too
collectFromObject(value, path, includeSimilarity);
collectFromObject(value, path);
} else {
// Unknown JSON node type; error
throw new JsonApiException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.api.model.command.impl.FindOneAndDeleteCommand;
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.operation.model.Operation;
import io.stargate.sgv2.jsonapi.service.operation.model.ReadType;
import io.stargate.sgv2.jsonapi.service.operation.model.impl.DeleteOperation;
Expand Down Expand Up @@ -47,11 +45,6 @@ public Class<FindOneAndDeleteCommand> getCommandClass() {
public Operation resolveCommand(CommandContext commandContext, FindOneAndDeleteCommand command) {
FindOperation findOperation = getFindOperation(commandContext, command);
final DocumentProjector documentProjector = command.buildProjector();
if (documentProjector.doIncludeSimilarityScore()) {
throw new JsonApiException(
ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED,
ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED.getMessage());
}
// return
return DeleteOperation.deleteOneAndReturn(
commandContext, findOperation, operationsConfig.lwt().retries(), documentProjector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.api.model.command.impl.FindOneAndReplaceCommand;
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.operation.model.Operation;
import io.stargate.sgv2.jsonapi.service.operation.model.ReadType;
import io.stargate.sgv2.jsonapi.service.operation.model.impl.FindOperation;
Expand Down Expand Up @@ -49,11 +47,6 @@ public Operation resolveCommand(CommandContext commandContext, FindOneAndReplace
FindOperation findOperation = getFindOperation(commandContext, command);

final DocumentProjector documentProjector = command.buildProjector();
if (documentProjector.doIncludeSimilarityScore()) {
throw new JsonApiException(
ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED,
ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED.getMessage());
}
// Vectorize replacement document
commandContext.tryVectorize(
objectMapper.getNodeFactory(), List.of(command.replacementDocument()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.api.model.command.impl.FindOneAndUpdateCommand;
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.operation.model.Operation;
import io.stargate.sgv2.jsonapi.service.operation.model.ReadType;
import io.stargate.sgv2.jsonapi.service.operation.model.impl.FindOperation;
Expand Down Expand Up @@ -49,12 +47,6 @@ public Operation resolveCommand(CommandContext commandContext, FindOneAndUpdateC
FindOperation findOperation = getFindOperation(commandContext, command);

final DocumentProjector documentProjector = command.buildProjector();
if (documentProjector.doIncludeSimilarityScore()) {
throw new JsonApiException(
ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED,
ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED.getMessage());
}

// Vectorize update clause
commandContext.tryVectorize(objectMapper.getNodeFactory(), command.updateClause());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1525,35 +1525,9 @@ public void findVectorWithUnmatchedSize() {
@Order(7)
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class VectorSearchSimilarityProjection {
@Test
@Order(1)
public void findOne() {
insertVectorDocuments();
String json =
"""
{
"findOne": {
"sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]},
"projection" : {"_id" : 1, "$similarity" : 1}
}
}
""";

given()
.header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken())
.contentType(ContentType.JSON)
.body(json)
.when()
.post(CollectionResource.BASE_PATH, namespaceName, collectionName)
.then()
.statusCode(200)
.body("data.document._id", is("3"))
.body("data.document.$similarity", notNullValue())
.body("errors", is(nullValue()));
}

@Test
@Order(2)
@Order(1)
public void findOneSimilarityOption() {
insertVectorDocuments();
String json =
Expand All @@ -1580,39 +1554,7 @@ public void findOneSimilarityOption() {
}

@Test
@Order(3)
public void find() {
insertVectorDocuments();
String json =
"""
{
"find": {
"sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]},
"projection" : {"_id" : 1, "$similarity" : 1}
}
}
""";

given()
.header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken())
.contentType(ContentType.JSON)
.body(json)
.when()
.post(CollectionResource.BASE_PATH, namespaceName, collectionName)
.then()
.statusCode(200)
.body("errors", is(nullValue()))
.body("data.documents", hasSize(3))
.body("data.documents[0]._id", is("3"))
.body("data.documents[0].$similarity", notNullValue())
.body("data.documents[1]._id", is("2"))
.body("data.documents[1].$similarity", notNullValue())
.body("data.documents[2]._id", is("1"))
.body("data.documents[2].$similarity", notNullValue());
}

@Test
@Order(4)
@Order(2)
public void findSimilarityOption() {
insertVectorDocuments();
String json =
Expand Down Expand Up @@ -1642,131 +1584,6 @@ public void findSimilarityOption() {
.body("data.documents[2]._id", is("1"))
.body("data.documents[2].$similarity", notNullValue());
}

@Test
@Order(5)
public void findSimilarityInvalidProjection() {
insertVectorDocuments();
String json =
"""
{
"find": {
"sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]},
"projection" : {"_id" : 1, "$similarity" : 0},
"options" : {"includeSimilarity" : true}
}
}
""";

given()
.header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken())
.contentType(ContentType.JSON)
.body(json)
.when()
.post(CollectionResource.BASE_PATH, namespaceName, collectionName)
.then()
.statusCode(200)
.body("errors", is(notNullValue()))
.body("errors[0].exceptionClass", is("JsonApiException"))
.body("errors[0].errorCode", is("UNSUPPORTED_PROJECTION_PARAM"))
.body(
"errors[0].message",
is(
ErrorCode.UNSUPPORTED_PROJECTION_PARAM.getMessage()
+ ": Cannot exclude $similarity when `includeSimilarity` option is set `true`"));
}

@Test
@Order(6)
public void findOneAndUpdate() {
String json =
"""
{
"findOneAndUpdate": {
"filter" : {"_id" : "1"},
"sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]},
"update" : {"$set" : {"name" : "Vision Vector Frame"}},
"projection" : {"_id" : 1, "$similarity" : 1}
}
}
""";

given()
.header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken())
.contentType(ContentType.JSON)
.body(json)
.when()
.post(CollectionResource.BASE_PATH, namespaceName, collectionName)
.then()
.statusCode(200)
.body("errors", is(notNullValue()))
.body("errors[0].exceptionClass", is("JsonApiException"))
.body("errors[0].errorCode", is("VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED"))
.body(
"errors[0].message",
is(ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED.getMessage()));
}

@Test
@Order(7)
public void findOneAndDelete() {
String json =
"""
{
"findOneAndDelete": {
"filter" : {"_id" : "1"},
"sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]},
"projection" : {"_id" : 1, "$similarity" : 1}
}
}
""";

given()
.header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken())
.contentType(ContentType.JSON)
.body(json)
.when()
.post(CollectionResource.BASE_PATH, namespaceName, collectionName)
.then()
.statusCode(200)
.body("errors", is(notNullValue()))
.body("errors[0].exceptionClass", is("JsonApiException"))
.body("errors[0].errorCode", is("VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED"))
.body(
"errors[0].message",
is(ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED.getMessage()));
}

@Test
@Order(8)
public void findOneAndReplace() {
String json =
"""
{
"findOneAndReplace": {
"filter" : {"_id" : "1"},
"sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]},
"replacement" : {"_id" : "1", "name" : "Vision Vector Frame"},
"projection" : {"_id" : 1, "$similarity" : 1}
}
}
""";

given()
.header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken())
.contentType(ContentType.JSON)
.body(json)
.when()
.post(CollectionResource.BASE_PATH, namespaceName, collectionName)
.then()
.statusCode(200)
.body("errors", is(notNullValue()))
.body("errors[0].exceptionClass", is("JsonApiException"))
.body("errors[0].errorCode", is("VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED"))
.body(
"errors[0].message",
is(ErrorCode.VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED.getMessage()));
}
}

private static void createVectorCollection(
Expand Down Expand Up @@ -1862,21 +1679,12 @@ public void checkInsertOneMetrics() {
"FindOneCommand", JsonApiMetricsConfig.SortType.SIMILARITY_SORT.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneCommand", JsonApiMetricsConfig.SortType.SIMILARITY_SORT_WITH_FILTERS.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneAndUpdateCommand",
JsonApiMetricsConfig.SortType.SIMILARITY_SORT_WITH_FILTERS.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneAndUpdateCommand", JsonApiMetricsConfig.SortType.NONE.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneAndUpdateCommand", JsonApiMetricsConfig.SortType.SIMILARITY_SORT.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneAndDeleteCommand",
JsonApiMetricsConfig.SortType.SIMILARITY_SORT_WITH_FILTERS.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneAndDeleteCommand", JsonApiMetricsConfig.SortType.SIMILARITY_SORT.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"FindOneAndReplaceCommand",
JsonApiMetricsConfig.SortType.SIMILARITY_SORT_WITH_FILTERS.name());
VectorSearchIntegrationTest.checkVectorMetrics(
"UpdateOneCommand", JsonApiMetricsConfig.SortType.SIMILARITY_SORT.name());
}
Expand Down
Loading