diff --git a/pom.xml b/pom.xml index 2a37ae8cdc..99a35a8697 100644 --- a/pom.xml +++ b/pom.xml @@ -73,6 +73,11 @@ org.apache.commons commons-lang3 + + com.bpodgursky + jbool_expressions + 1.24 + io.quarkus quarkus-rest-client-reactive-jackson diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializer.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializer.java index 6a7e853f4b..653bad81b8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializer.java @@ -65,7 +65,6 @@ public FilterClause deserialize( entry.getKey(), jsonNodeValue(entry.getKey(), entry.getValue()))); } } - validate(expressionList); return new FilterClause(expressionList); } @@ -90,11 +89,6 @@ private void validate(String path, FilterOperation filterOperation) { if (filterOperation.operator() instanceof ValueComparisonOperator valueComparisonOperator) { switch (valueComparisonOperator) { case IN -> { - if (!path.equals(DocumentConstants.Fields.DOC_ID)) { - throw new JsonApiException( - ErrorCode.INVALID_FILTER_EXPRESSION, "Can use $in operator only on _id field"); - } - if (filterOperation.operand().value() instanceof List list) { if (list.size() > operationsConfig.defaultPageSize()) { throw new JsonApiException( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java index 985fa16b94..449b1c71c0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java @@ -9,6 +9,12 @@ interface Fields { /** Primary key for Documents stored; has special handling for many operations. 
*/ String DOC_ID = "_id"; + /** + * Atomic values are added to the array_contains field to support $eq on both atomic value and + * array element + */ + String DATA_CONTAINS = "array_contains"; + /** Physical table column name that stores the vector field. */ String VECTOR_SEARCH_INDEX_COLUMN_NAME = "query_vector_value"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java index d342d2e921..b34f0564fb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java @@ -97,7 +97,7 @@ public enum ErrorCode { VECTOR_SEARCH_NOT_SUPPORTED("Vector search is not enabled for the collection "), - VECTOR_SEARCH_INVALID_FUCTION_NAME("Invalid vector search function name "), + VECTOR_SEARCH_INVALID_FUNCTION_NAME("Invalid vector search function name: "), VECTOR_SEARCH_SIMILARITY_PROJECTION_NOT_SUPPORTED( "$similarity projection is not supported for this command"), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/bridge/executor/NamespaceCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/bridge/executor/NamespaceCache.java index b67118c4db..8349c8daf6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/bridge/executor/NamespaceCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/bridge/executor/NamespaceCache.java @@ -164,8 +164,8 @@ public static SimilarityFunction fromString(String similarityFunction) { case "euclidean" -> EUCLIDEAN; case "dot_product" -> DOT_PRODUCT; default -> throw new JsonApiException( - ErrorCode.VECTOR_SEARCH_INVALID_FUCTION_NAME, - ErrorCode.VECTOR_SEARCH_INVALID_FUCTION_NAME.getMessage() + similarityFunction); + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME, + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME.getMessage() + similarityFunction); }; } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java index d4fb6a141d..8c3a5d6911 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java @@ -1,5 +1,7 @@ package io.stargate.sgv2.jsonapi.service.operation.model.impl; +import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.DATA_CONTAINS; + import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; @@ -78,12 +80,6 @@ public enum Operator { protected final DBFilterBase.MapFilterBase.Operator operator; private final T value; - /** - * Atomic values are added to the array_contains field to support $eq on both atomic value and - * array element - */ - private static final String DATA_CONTAINS = "array_contains"; - protected MapFilterBase( String columnName, String key, MapFilterBase.Operator operator, T value) { super(key); @@ -288,10 +284,75 @@ boolean canAddField() { } } } + /** - * DB filter / condition for testing a set value Note: we can only do CONTAINS until SAI indexes - * are updated + * based on values of fields other than document id: for filtering on non-id field use InFilter. 
*/ + public static class InFilter extends DBFilterBase { + private final List arrayValue; + protected final InFilter.Operator operator; + + @Override + JsonNode asJson(JsonNodeFactory nodeFactory) { + return DBFilterBase.getJsonNode(nodeFactory, arrayValue); + } + + @Override + boolean canAddField() { + return false; + } + // IN operator for non-id field filtering + public enum Operator { + IN; + } + + public InFilter(InFilter.Operator operator, String path, List arrayValue) { + super(path); + this.arrayValue = arrayValue; + this.operator = operator; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InFilter inFilter = (InFilter) o; + return operator == inFilter.operator && Objects.equals(arrayValue, inFilter.arrayValue); + } + + @Override + public int hashCode() { + return Objects.hash(arrayValue, operator); + } + + @Override + public BuiltCondition get() { + throw new UnsupportedOperationException("For IN filter we always use getALL() method"); + } + + public List getAll() { + List values = arrayValue; + switch (operator) { + case IN: + if (values.isEmpty()) return List.of(); + return values.stream() + .map( + v -> + BuiltCondition.of( + DATA_CONTAINS, + Predicate.CONTAINS, + getGrpcValue(getHashValue(new DocValueHasher(), getPath(), v)))) + .collect(Collectors.toList()); + + default: + throw new JsonApiException( + ErrorCode.UNSUPPORTED_FILTER_OPERATION, + String.format("Unsupported %s column operation %s", getPath(), operator)); + } + } + } + + /** DB filter / condition for testing a set value */ public abstract static class SetFilterBase extends DBFilterBase { public enum Operator { CONTAINS; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java index d302bc8fa0..ebfee0a6fb 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java @@ -1,5 +1,9 @@ package io.stargate.sgv2.jsonapi.service.operation.model.impl; +import com.bpodgursky.jbool_expressions.And; +import com.bpodgursky.jbool_expressions.Expression; +import com.bpodgursky.jbool_expressions.Or; +import com.bpodgursky.jbool_expressions.Variable; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; @@ -23,6 +27,7 @@ import io.stargate.sgv2.jsonapi.service.shredding.model.DocumentId; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -371,8 +376,6 @@ public ReadDocument getNewDocument() { return ReadDocument.from(documentId, null, rootNode); } - // builds select query - /** * Builds select query based on filters and additionalIdFilter overrides. * @@ -381,34 +384,75 @@ public ReadDocument getNewDocument() { * buildConditions method. */ private List buildSelectQueries(DBFilterBase.IDFilter additionalIdFilter) { - List> conditions = buildConditions(additionalIdFilter); - if (conditions == null) { - return List.of(); - } - List queries = new ArrayList<>(conditions.size()); - conditions.forEach( - condition -> { - if (vector() == null) { - queries.add( - new QueryBuilder() - .select() - .column(ReadType.DOCUMENT == readType ? 
documentColumns : documentKeyColumns) - .from(commandContext.namespace(), commandContext.collection()) - .where(condition) - .limit(limit) - .build()); - } else { - QueryOuterClass.Query builtQuery = getVectorSearchQuery(condition); - final List valuesList = - builtQuery.getValuesOrBuilder().getValuesList(); - final QueryOuterClass.Values.Builder builder = QueryOuterClass.Values.newBuilder(); - valuesList.forEach(builder::addValues); - builder.addValues(CustomValueSerializers.getVectorValue(vector())); - queries.add(QueryOuterClass.Query.newBuilder(builtQuery).setValues(builder).build()); - } - }); + // if the query has "$in" operator for non-id field, buildCondition should return List of + // Expression + // that is the reason having this boolean + // TODO queryBuilder change for where(Expression> instead of + // TODO List> + final Optional inFilter = + filters.stream().filter(filter -> filter instanceof DBFilterBase.InFilter).findFirst(); + if (inFilter.isPresent()) { + // This if block handles filter with "$in" for non-id field + List> expressions = buildConditionExpressions(additionalIdFilter); + if (expressions == null) { + return List.of(); + } + List queries = new ArrayList<>(expressions.size()); + expressions.forEach( + expression -> { + if (vector() == null) { + queries.add( + new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? 
documentColumns : documentKeyColumns) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(limit) + .build()); + } else { + QueryOuterClass.Query builtQuery = getVectorSearchQueryByExpression(expression); + final List valuesList = + builtQuery.getValuesOrBuilder().getValuesList(); + final QueryOuterClass.Values.Builder builder = QueryOuterClass.Values.newBuilder(); + valuesList.forEach(builder::addValues); + builder.addValues(CustomValueSerializers.getVectorValue(vector())); + queries.add(QueryOuterClass.Query.newBuilder(builtQuery).setValues(builder).build()); + } + }); + + return queries; + } else { + // This else block handles filter with no "$in" for non-id field + List> conditions = buildConditions(additionalIdFilter); + if (conditions == null) { + return List.of(); + } + List queries = new ArrayList<>(conditions.size()); + conditions.forEach( + condition -> { + if (vector() == null) { + queries.add( + new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? 
documentColumns : documentKeyColumns) + .from(commandContext.namespace(), commandContext.collection()) + .where(condition) + .limit(limit) + .build()); + } else { + QueryOuterClass.Query builtQuery = getVectorSearchQuery(condition); + final List valuesList = + builtQuery.getValuesOrBuilder().getValuesList(); + final QueryOuterClass.Values.Builder builder = QueryOuterClass.Values.newBuilder(); + valuesList.forEach(builder::addValues); + builder.addValues(CustomValueSerializers.getVectorValue(vector())); + queries.add(QueryOuterClass.Query.newBuilder(builtQuery).setValues(builder).build()); + } + }); - return queries; + return queries; + } } /** Making it a separate method to build vector search query as there are many options */ @@ -457,8 +501,9 @@ private QueryOuterClass.Query getVectorSearchQuery(List conditio } default -> { throw new JsonApiException( - ErrorCode.VECTOR_SEARCH_INVALID_FUCTION_NAME, - ErrorCode.VECTOR_SEARCH_INVALID_FUCTION_NAME.getMessage()); + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME, + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME.getMessage() + + commandContext().similarityFunction()); } } } else { @@ -472,7 +517,72 @@ private QueryOuterClass.Query getVectorSearchQuery(List conditio .build(); } } - + /** + * A separate method to build vector search query by using expression, expression can contain + * logic operations like 'or','and'.. + */ + private QueryOuterClass.Query getVectorSearchQueryByExpression( + Expression expression) { + QueryOuterClass.Query builtQuery = null; + if (projection().doIncludeSimilarityScore()) { + switch (commandContext().similarityFunction()) { + case COSINE, UNDEFINED -> { + return new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? 
documentColumns : documentKeyColumns) + .similarityCosine( + DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, + CustomValueSerializers.getVectorValue(vector())) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(limit) + .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) + .build(); + } + case EUCLIDEAN -> { + return new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) + .similarityEuclidean( + DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, + CustomValueSerializers.getVectorValue(vector())) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(limit) + .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) + .build(); + } + case DOT_PRODUCT -> { + return new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) + .similarityDotProduct( + DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, + CustomValueSerializers.getVectorValue(vector())) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(limit) + .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) + .build(); + } + default -> { + throw new JsonApiException( + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME, + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME.getMessage() + + commandContext().similarityFunction()); + } + } + } else { + return new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(limit) + .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) + .build(); + } + } /** * Builds select query based on filters, sort fields and additionalIdFilter overrides. 
* @@ -482,30 +592,69 @@ private QueryOuterClass.Query getVectorSearchQuery(List conditio */ private List buildSortedSelectQueries( DBFilterBase.IDFilter additionalIdFilter) { - List> conditions = buildConditions(additionalIdFilter); - if (conditions == null) { - return List.of(); - } - String[] columns = sortedDataColumns; - if (orderBy() != null) { - List sortColumns = Lists.newArrayList(columns); - orderBy().forEach(order -> sortColumns.addAll(order.getOrderingColumns())); - columns = new String[sortColumns.size()]; - sortColumns.toArray(columns); + // if the query has "$in" operator for non-id field, buildCondition should return List of + // Expression + // that is the reason having this boolean + // TODO queryBuilder change for where(Expression> instead of + // TODO List> + + final Optional inFilter = + filters.stream().filter(filter -> filter instanceof DBFilterBase.InFilter).findFirst(); + if (inFilter.isPresent()) { + // This if block handles filter with "$in" for non-id field + List> expressions = buildConditionExpressions(additionalIdFilter); + if (expressions == null) { + return List.of(); + } + String[] columns = sortedDataColumns; + if (orderBy() != null) { + List sortColumns = Lists.newArrayList(columns); + orderBy().forEach(order -> sortColumns.addAll(order.getOrderingColumns())); + columns = new String[sortColumns.size()]; + sortColumns.toArray(columns); + } + final String[] columnsToAdd = columns; + List queries = new ArrayList<>(expressions.size()); + expressions.forEach( + expression -> + queries.add( + new QueryBuilder() + .select() + .column(columnsToAdd) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(maxSortReadLimit()) + .build())); + return queries; + + } else { + // This else block handles filter with no "$in" for non-id field + List> conditions = buildConditions(additionalIdFilter); + if (conditions == null) { + return List.of(); + } + String[] columns = sortedDataColumns; + if (orderBy() != null) 
{ + List sortColumns = Lists.newArrayList(columns); + orderBy().forEach(order -> sortColumns.addAll(order.getOrderingColumns())); + columns = new String[sortColumns.size()]; + sortColumns.toArray(columns); + } + final String[] columnsToAdd = columns; + List queries = new ArrayList<>(conditions.size()); + conditions.forEach( + condition -> + queries.add( + new QueryBuilder() + .select() + .column(columnsToAdd) + .from(commandContext.namespace(), commandContext.collection()) + .where(condition) + .limit(maxSortReadLimit()) + .build())); + return queries; } - final String[] columnsToAdd = columns; - List queries = new ArrayList<>(conditions.size()); - conditions.forEach( - condition -> - queries.add( - new QueryBuilder() - .select() - .column(columnsToAdd) - .from(commandContext.namespace(), commandContext.collection()) - .where(condition) - .limit(maxSortReadLimit()) - .build())); - return queries; } /** @@ -552,6 +701,69 @@ private List> buildConditions(DBFilterBase.IDFilter additio } } + /** + * Builds select query based on filters and additionalIdFilter overrides. 
return expression to + * pass logic operation information, eg 'or' + */ + private List> buildConditionExpressions( + DBFilterBase.IDFilter additionalIdFilter) { + Expression conditionExpression = null; + // for (DBFilterBase filter : filters) { + // // all filters will be in And.of + // // if the filter is DBFilterBase.INFilter + // // inside of this filter, it has getAll method to return a list of BuiltCondition, + // these + // // BuiltCondition will be in Or.of + // } + DBFilterBase.IDFilter idFilterToUse = additionalIdFilter; + // if we have id filter overwrite ignore existing IDFilter + boolean idFilterOverwrite = additionalIdFilter != null; + for (DBFilterBase filter : filters) { + if (filter instanceof DBFilterBase.InFilter inFilter) { + List conditions = inFilter.getAll(); + if (!conditions.isEmpty()) { + List> variableConditions = + conditions.stream().map(Variable::of).toList(); + conditionExpression = + conditionExpression == null + ? Or.of(variableConditions) + : And.of(Or.of(variableConditions), conditionExpression); + } + } else if (filter instanceof DBFilterBase.IDFilter idFilter) { + if (!idFilterOverwrite) { + idFilterToUse = idFilter; + } + } else { + conditionExpression = + conditionExpression == null + ? Variable.of(filter.get()) + : And.of(Variable.of(filter.get()), conditionExpression); + } + } + + if (idFilterToUse != null) { + final List inSplit = idFilterToUse.getAll(); + if (inSplit.isEmpty()) { + return null; + } else { + // split n queries by id + Expression tempExpression = conditionExpression; + return inSplit.stream() + .map( + idCondition -> { + Expression newExpression = + tempExpression == null + ? Variable.of(idCondition) + : And.of(Variable.of((idCondition)), tempExpression); + return newExpression; + }) + .collect(Collectors.toList()); + } + } else { + return conditionExpression == null ? 
null : List.of(conditionExpression); // only one query + } + } + /** * Represents sort field name and option to be sorted ascending/descending. * diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterableResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterableResolver.java index 668ac00514..83470a5067 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterableResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterableResolver.java @@ -34,7 +34,7 @@ public abstract class FilterableResolver { private static final Object ID_GROUP = new Object(); private static final Object ID_GROUP_IN = new Object(); - + private static final Object DYNAMIC_GROUP_IN = new Object(); private static final Object DYNAMIC_TEXT_GROUP = new Object(); private static final Object DYNAMIC_NUMBER_GROUP = new Object(); private static final Object DYNAMIC_BOOL_GROUP = new Object(); @@ -62,6 +62,12 @@ public FilterableResolver() { .capture(ID_GROUP_IN) .compareValues("_id", EnumSet.of(ValueComparisonOperator.IN), JsonType.ARRAY); + // matchRules + // .addMatchRule(this::findDynamic, FilterMatcher.MatchStrategy.STRICT) + // .matcher() + // .capture(DYNAMIC_GROUP_IN) + // .compareValues("*", EnumSet.of(ValueComparisonOperator.IN), JsonType.ARRAY); + // NOTE - can only do eq ops on fields until SAI changes matchRules .addMatchRule(this::findDynamic, FilterMatcher.MatchStrategy.GREEDY) @@ -70,6 +76,8 @@ public FilterableResolver() { .compareValues("_id", EnumSet.of(ValueComparisonOperator.EQ), JsonType.DOCUMENT_ID) .capture(ID_GROUP_IN) .compareValues("_id", EnumSet.of(ValueComparisonOperator.IN), JsonType.ARRAY) + .capture(DYNAMIC_GROUP_IN) + .compareValues("*", EnumSet.of(ValueComparisonOperator.IN), JsonType.ARRAY) .capture(DYNAMIC_NUMBER_GROUP) .compareValues("*", EnumSet.of(ValueComparisonOperator.EQ), JsonType.NUMBER) 
.capture(DYNAMIC_TEXT_GROUP) @@ -148,6 +156,18 @@ private List findDynamic(CommandContext commandContext, CaptureGro DBFilterBase.IDFilter.Operator.IN, expression.value()))); } + final CaptureGroup> dynamicGroups = + (CaptureGroup>) captures.getGroupIfPresent(DYNAMIC_GROUP_IN); + if (dynamicGroups != null) { + dynamicGroups.consumeAllCaptures( + expression -> { + final DocValueHasher docValueHasher = new DocValueHasher(); + filters.add( + new DBFilterBase.InFilter( + DBFilterBase.InFilter.Operator.IN, expression.path(), expression.value())); + }); + } + final CaptureGroup textGroup = (CaptureGroup) captures.getGroupIfPresent(DYNAMIC_TEXT_GROUP); if (textGroup != null) { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializerTest.java index 3fa1e15af8..b425a75432 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/FilterClauseDeserializerTest.java @@ -290,7 +290,7 @@ public void mustHandleSubDocEq() throws Exception { } @Test - public void mustHandleIn() throws Exception { + public void mustHandleIdFieldIn() throws Exception { String json = """ {"_id" : {"$in": ["2", "3"]}} """; @@ -308,32 +308,34 @@ public void mustHandleIn() throws Exception { } @Test - public void mustHandleInArrayNonEmpty() throws Exception { + public void mustHandleNonIdFieldIn() throws Exception { String json = """ - {"_id" : {"$in": []}} + {"name" : {"$in": ["name1", "name2"]}} """; final ComparisonExpression expectedResult = new ComparisonExpression( - "_id", + "name", List.of( new ValueComparisonOperation( - ValueComparisonOperator.IN, new JsonLiteral(List.of(), JsonType.ARRAY)))); + ValueComparisonOperator.IN, + new JsonLiteral(List.of("name1", "name2"), JsonType.ARRAY)))); FilterClause 
filterClause = objectMapper.readValue(json, FilterClause.class); assertThat(filterClause.comparisonExpressions()).hasSize(1).contains(expectedResult); } @Test - public void mustHandleInIdFieldOnly() throws Exception { + public void mustHandleInArrayNonEmpty() throws Exception { String json = """ - {"name" : {"$in": ["aaa"]}} + {"_id" : {"$in": []}} """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, FilterClause.class)); - assertThat(throwable) - .isInstanceOf(JsonApiException.class) - .satisfies( - t -> { - assertThat(t.getMessage()).isEqualTo("Can use $in operator only on _id field"); - }); + final ComparisonExpression expectedResult = + new ComparisonExpression( + "_id", + List.of( + new ValueComparisonOperation( + ValueComparisonOperator.IN, new JsonLiteral(List.of(), JsonType.ARRAY)))); + FilterClause filterClause = objectMapper.readValue(json, FilterClause.class); + assertThat(filterClause.comparisonExpressions()).hasSize(1).contains(expectedResult); } @Test diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java index d8312b57c4..e9c5a50cd8 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java @@ -87,6 +87,7 @@ public void setUp() { "insertOne": { "document": { "_id": "doc4", + "username" : "user4", "indexedObject" : { "0": "value_0", "1": "value_1" } } } @@ -263,7 +264,7 @@ public void inCondition() { """; String expected2 = """ - {"_id":"doc4", "indexedObject":{"0":"value_0","1":"value_1"}} + {"_id":"doc4", "username":"user4", "indexedObject":{"0":"value_0","1":"value_1"}} """; given() @@ -387,11 +388,48 @@ public void inConditionNonIdField() { """ { "find": { - "filter" : {"non_id" : {"$in": ["a", "b", "c"]}} - } + "filter" : { + "username" : {"$in" : ["user1", "user10"]} + } + } } """; + String expected1 = + 
"{\"_id\":\"doc1\", \"username\":\"user1\", \"active_user\":true, \"date\" : {\"$date\": 1672531200000}}"; + given() + .header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken()) + .contentType(ContentType.JSON) + .body(json) + .when() + .post(CollectionResource.BASE_PATH, namespaceName, collectionName) + .then() + .statusCode(200) + .body("data.documents", hasSize(1)) + .body("status", is(nullValue())) + .body("errors", is(nullValue())) + .body("data.documents[0]", jsonEquals(expected1)); + } + @Test + public void inConditionNonIdFieldMulti() { + String json = + """ + { + "find": { + "filter" : { + "username" : {"$in" : ["user1", "user4"]} + } + } + } + """; + String expected1 = + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}} + """; + String expected2 = + """ + {"_id":"doc4", "username":"user4", "indexedObject":{"0":"value_0","1":"value_1"}} + """; given() .header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken()) .contentType(ContentType.JSON) @@ -400,12 +438,69 @@ public void inConditionNonIdField() { .post(CollectionResource.BASE_PATH, namespaceName, collectionName) .then() .statusCode(200) + .body("data.documents", hasSize(2)) .body("status", is(nullValue())) - .body("data", is(nullValue())) - .body("errors", is(notNullValue())) - .body("errors[1].message", is("Can use $in operator only on _id field")) - .body("errors[1].exceptionClass", is("JsonApiException")) - .body("errors[1].errorCode", is("INVALID_FILTER_EXPRESSION")); + .body("errors", is(nullValue())) + .body("data.documents", containsInAnyOrder(jsonEquals(expected1), jsonEquals(expected2))); + } + + @Test + public void inConditionNonIdFieldIdField() { + String json = + """ + { + "find": { + "filter" : { + "username" : {"$in" : ["user1", "user10"]}, + "_id" : {"$in" : ["doc1", "???"]} + } + } + } + """; + String expected1 = + "{\"_id\":\"doc1\", \"username\":\"user1\", \"active_user\":true, \"date\" : {\"$date\": 
1672531200000}}"; + given() + .header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken()) + .contentType(ContentType.JSON) + .body(json) + .when() + .post(CollectionResource.BASE_PATH, namespaceName, collectionName) + .then() + .statusCode(200) + .body("data.documents", hasSize(1)) + .body("status", is(nullValue())) + .body("errors", is(nullValue())) + .body("data.documents[0]", jsonEquals(expected1)); + } + + @Test + public void inConditionNonIdFieldIdFieldSort() { + String json = + """ + { + "find": { + "filter" : { + "username" : {"$in" : ["user1", "user10"]}, + "_id" : {"$in" : ["doc1", "???"]} + }, + "sort": { "username": -1 } + } + } + """; + String expected1 = + "{\"_id\":\"doc1\", \"username\":\"user1\", \"active_user\":true, \"date\" : {\"$date\": 1672531200000}}"; + given() + .header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken()) + .contentType(ContentType.JSON) + .body(json) + .when() + .post(CollectionResource.BASE_PATH, namespaceName, collectionName) + .then() + .statusCode(200) + .body("data.documents", hasSize(1)) + .body("status", is(nullValue())) + .body("errors", is(nullValue())) + .body("data.documents[0]", jsonEquals(expected1)); } @Test @@ -571,6 +666,7 @@ public void withEqSubDocWithIndex() { """ { "_id": "doc4", + "username":"user4", "indexedObject" : { "0": "value_0", "1": "value_1" } } """; diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java index 1ced23d030..be811883ff 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java @@ -347,10 +347,7 @@ public void inConditionNonIdField() { .post(CollectionResource.BASE_PATH, namespaceName, collectionName) .then() .statusCode(200) - .body("errors", is(notNullValue())) - .body("errors[1].message", is("Can use $in operator only on _id field")) - 
.body("errors[1].exceptionClass", is("JsonApiException")) - .body("errors[1].errorCode", is("INVALID_FILTER_EXPRESSION")); + .body("errors", is(nullValue())); } @Test diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java index d0bebd6dee..4d405b1b73 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java @@ -572,6 +572,64 @@ public void happyPathWithFilter() { .body("errors", is(nullValue())); } + @Test + @Order(3) + public void happyPathWithInFilter() { + String json = + """ + { + "insertOne": { + "document": { + "_id": "xx", + "name": "Logic Layers", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.25, 0.25, 0.25, 0.25, 0.25] + } + } + } + """; + + given() + .header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken()) + .contentType(ContentType.JSON) + .body(json) + .when() + .post(CollectionResource.BASE_PATH, namespaceName, collectionName) + .then() + .statusCode(200) + .body("status.insertedIds[0]", is("xx")) + .body("data", is(nullValue())) + .body("errors", is(nullValue())); + json = + """ + { + "find": { + "filter" : { + "_id" : {"$in" : ["??", "xx"]}, + "name": {"$in" : ["Logic Layers","???"]} + }, + "projection" : {"_id" : 1, "$vector" : 0}, + "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, + "options" : { + "limit" : 5 + } + } + } + """; + + given() + .header(HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, getAuthToken()) + .contentType(ContentType.JSON) + .body(json) + .when() + .post(CollectionResource.BASE_PATH, namespaceName, collectionName) + .then() + .statusCode(200) + .body("data.documents[0]._id", is("xx")) + .body("data.documents[0].$vector", is(nullValue())) + .body("errors", is(nullValue())); + } + @Test @Order(4) public void happyPathWithEmptyVector() { diff 
--git a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/FindCommandResolverTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/FindCommandResolverTest.java
index 28e046e80e..c1d8812556 100644
--- a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/FindCommandResolverTest.java
+++ b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/FindCommandResolverTest.java
@@ -721,4 +721,310 @@ public void noFilterConditionWithProjection() throws Exception {
               });
     }
   }
+
+  /** Resolution of {@code find} commands whose filter uses the {@code $in} operator. */
+  @Nested
+  class FindCommandResolveWithINOperator {
+    CommandContext commandContext = CommandContext.empty();
+
+    /** A lone {@code $in} on a non-_id field resolves to a single {@code InFilter}. */
+    @Test
+    public void NonIdIn() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "filter" : {"name" : { "$in" : ["test1", "test2"]}}
+            }
+          }
+          """;
+
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                DBFilterBase filter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.limit()).isEqualTo(Integer.MAX_VALUE);
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit()).isZero();
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.orderBy()).isNull();
+                assertThat(find.filters()).containsOnly(filter);
+              });
+    }
+
+    /** {@code $in} on a non-_id field plus an {@code _id} equality yields both filters. */
+    @Test
+    public void NonIdInIdEq() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "filter" : {
+                "_id" : "id1",
+                "name" : { "$in" : ["test1", "test2"]}
+              }
+            }
+          }
+          """;
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                DBFilterBase inFilter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                DBFilterBase idFilter =
+                    new DBFilterBase.IDFilter(
+                        DBFilterBase.IDFilter.Operator.EQ, DocumentId.fromString("id1"));
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.limit()).isEqualTo(Integer.MAX_VALUE);
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit()).isZero();
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.orderBy()).isNull();
+                assertThat(find.filters()).containsOnly(inFilter, idFilter);
+              });
+    }
+
+    /** {@code $in} on both {@code _id} and a non-_id field yields one filter for each. */
+    @Test
+    public void NonIdInIdIn() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "filter" : {
+                "_id" : { "$in" : ["id1", "id2"]},
+                "name" : { "$in" : ["test1", "test2"]}
+              }
+            }
+          }
+          """;
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                DBFilterBase inFilter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                DBFilterBase idFilter =
+                    new DBFilterBase.IDFilter(
+                        DBFilterBase.IDFilter.Operator.IN,
+                        List.of(DocumentId.fromString("id1"), DocumentId.fromString("id2")));
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.limit()).isEqualTo(Integer.MAX_VALUE);
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit()).isZero();
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.orderBy()).isNull();
+                assertThat(find.filters()).containsOnly(inFilter, idFilter);
+              });
+    }
+
+    /** Non-_id {@code $in} with a {@code $vector} sort keeps the filter and the query vector. */
+    @Test
+    public void NonIdInVSearch() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "filter" : {
+                "name" : { "$in" : ["test1", "test2"]}
+              },
+              "sort" : {"$vector" : [0.15, 0.1, 0.1]}
+            }
+          }
+          """;
+
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                DBFilterBase inFilter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                float[] vector = new float[] {0.15f, 0.1f, 0.1f};
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.limit()).isEqualTo(operationsConfig.maxVectorSearchLimit());
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit()).isZero();
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.vector()).containsExactly(vector);
+                assertThat(find.filters()).containsOnly(inFilter);
+              });
+    }
+
+    /** Both {@code $in} filters ({@code _id} and non-_id) survive a {@code $vector} sort. */
+    @Test
+    public void NonIdInIdInVSearch() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "filter" : {
+                "_id" : { "$in" : ["id1", "id2"]},
+                "name" : { "$in" : ["test1", "test2"]}
+              },
+              "sort" : {"$vector" : [0.15, 0.1, 0.1]}
+            }
+          }
+          """;
+
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                DBFilterBase inFilter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                DBFilterBase idFilter =
+                    new DBFilterBase.IDFilter(
+                        DBFilterBase.IDFilter.Operator.IN,
+                        List.of(DocumentId.fromString("id1"), DocumentId.fromString("id2")));
+                float[] vector = new float[] {0.15f, 0.1f, 0.1f};
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.limit()).isEqualTo(operationsConfig.maxVectorSearchLimit());
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit()).isZero();
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.vector()).containsExactly(vector);
+                assertThat(find.filters()).containsOnly(inFilter, idFilter);
+              });
+    }
+
+    /** Descending field sort with a non-_id {@code $in} resolves to a sorted-document read. */
+    @Test
+    public void descendingSortNonIdIn() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "sort": {
+                "name": -1
+              },
+              "filter" : {
+                "name" : {"$in" : ["test1", "test2"]}
+              }
+            }
+          }
+          """;
+
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                FindOperation.OrderBy orderBy = new FindOperation.OrderBy("name", false);
+                DBFilterBase inFilter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultSortPageSize());
+                assertThat(find.limit()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.SORTED_DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit())
+                    .isEqualTo(operationsConfig.maxDocumentSortCount());
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.orderBy()).containsOnly(orderBy);
+                assertThat(find.filters()).containsOnly(inFilter);
+              });
+    }
+
+    /** Ascending field sort with {@code $in} on both {@code _id} and a non-_id field. */
+    @Test
+    public void ascendingSortNonIdInIdIn() throws Exception {
+      String json =
+          """
+          {
+            "find": {
+              "sort": {
+                "name": 1
+              },
+              "filter" : {
+                "name" : {"$in" : ["test1", "test2"]},
+                "_id" : {"$in" : ["id1","id2"]}
+              }
+            }
+          }
+          """;
+
+      FindCommand findCommand = objectMapper.readValue(json, FindCommand.class);
+      Operation operation = resolver.resolveCommand(commandContext, findCommand);
+
+      assertThat(operation)
+          .isInstanceOfSatisfying(
+              FindOperation.class,
+              find -> {
+                FindOperation.OrderBy orderBy = new FindOperation.OrderBy("name", true);
+                DBFilterBase inFilter =
+                    new DBFilterBase.InFilter(
+                        DBFilterBase.InFilter.Operator.IN, "name", List.of("test1", "test2"));
+                DBFilterBase idFilter =
+                    new DBFilterBase.IDFilter(
+                        DBFilterBase.IDFilter.Operator.IN,
+                        List.of(DocumentId.fromString("id1"), DocumentId.fromString("id2")));
+                assertThat(find.objectMapper()).isEqualTo(objectMapper);
+                assertThat(find.commandContext()).isEqualTo(commandContext);
+                assertThat(find.projection()).isEqualTo(DocumentProjector.identityProjector());
+                assertThat(find.pageSize()).isEqualTo(operationsConfig.defaultSortPageSize());
+                assertThat(find.limit()).isEqualTo(operationsConfig.defaultPageSize());
+                assertThat(find.pagingState()).isNull();
+                assertThat(find.readType()).isEqualTo(ReadType.SORTED_DOCUMENT);
+                assertThat(find.skip()).isZero();
+                assertThat(find.maxSortReadLimit())
+                    .isEqualTo(operationsConfig.maxDocumentSortCount());
+                assertThat(find.singleResponse()).isFalse();
+                assertThat(find.orderBy()).containsOnly(orderBy);
+                assertThat(find.filters()).containsOnly(inFilter, idFilter);
+              });
+    }
+  }
 }
diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterMatchRuleTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterMatchRuleTest.java
index 77d556a4d3..3637d0abd7 100644
--- a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterMatchRuleTest.java
+++ b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/matcher/FilterMatchRuleTest.java
@@ -66,5 +66,37 @@ public void apply() throws Exception {
           filterMatchRule.apply(new CommandContext("namespace", "collection"), findOneCommand);
       assertThat(response).isEmpty();
     }
+
+    /** A GREEDY matcher capturing {@code $in} array comparisons on path "*" produces a result. */
+    @Test
+    public void testDynamicIn() throws Exception {
+      String json =
+          """
+          {
+            "findOne": {
+              "filter" : {"name" : {"$in" : ["testname1", "testname2"]}}
+            }
+          }
+          """;
+      FindOneCommand findOneCommand = objectMapper.readValue(json, FindOneCommand.class);
+      FilterMatcher<FindOneCommand> matcher =
+          new FilterMatcher<>(FilterMatcher.MatchStrategy.GREEDY);
+
+      BiFunction<CommandContext, CaptureGroups<FindOneCommand>, List<DBFilterBase>>
+          resolveFunction = (commandContext, captures) -> filters;
+
+      FilterMatchRule<FindOneCommand> filterMatchRule =
+          new FilterMatchRule(matcher, resolveFunction);
+
+      filterMatchRule
+          .matcher()
+          .capture("capture marker")
+          .compareValues("*", EnumSet.of(ValueComparisonOperator.IN), JsonType.ARRAY);
+
+      Optional<List<DBFilterBase>> response =
+          filterMatchRule.apply(
+              new CommandContext("testNamespace", "testCollection"), findOneCommand);
+      assertThat(response).isPresent();
+    }
   }
 }