diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java index f7ade7e014..dab6126efb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java @@ -165,7 +165,13 @@ public enum ErrorCode { "Collection creation failure (unable to create table). Recommend re-creating the collection"), INVALID_SCHEMA_VERSION( "Collection has invalid schema version. Recommend re-creating the collection"), - INVALID_ID_TYPE("Invalid Id type"); + INVALID_ID_TYPE("Invalid Id type"), + + UNSUPPORTED_CQL_QUERY_TYPE("Unsupported cql query type"), + + MISSING_VECTOR_VALUE("Missing the vector value when building cql"), + + INVALID_LOGIC_OPERATOR("Invalid logical operator"); private final String message; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cql/ColumnUtils.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/ColumnUtils.java new file mode 100644 index 0000000000..e7452aa1d8 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/ColumnUtils.java @@ -0,0 +1,44 @@ +/* + * Copyright DataStax, Inc. and/or The Stargate Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.stargate.sgv2.jsonapi.service.cql; + +import io.stargate.sgv2.api.common.cql.ReservedKeywords; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class ColumnUtils { + + private static final Pattern PATTERN_DOUBLE_QUOTE = Pattern.compile("\"", Pattern.LITERAL); + private static final String ESCAPED_DOUBLE_QUOTE = Matcher.quoteReplacement("\"\""); + /** + * Updated regex pattern to support selecting collection entry lime map_column['entry_key'], + * set_column['set_value'] + */ + private static final Pattern UNQUOTED_IDENTIFIER = + Pattern.compile("[a-z][a-z0-9_]*(\\['.*'\\])?"); + + /** + * Given the raw (as stored internally) text of an identifier, return its CQL representation. That + * is, unless the text is full lowercase and use only characters allowed in unquoted identifiers, + * the result is double-quoted. + */ + public static String maybeQuote(String text) { + if (UNQUOTED_IDENTIFIER.matcher(text).matches() && !ReservedKeywords.isReserved(text)) { + return text; + } + return '"' + PATTERN_DOUBLE_QUOTE.matcher(text).replaceAll(ESCAPED_DOUBLE_QUOTE) + '"'; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cql/ExpressionUtils.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/ExpressionUtils.java new file mode 100644 index 0000000000..f811ba50c5 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/ExpressionUtils.java @@ -0,0 +1,71 @@ +/* + * Copyright The Stargate Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.stargate.sgv2.jsonapi.service.cql; + +import static io.stargate.sgv2.jsonapi.exception.ErrorCode.INVALID_LOGIC_OPERATOR; + +import com.bpodgursky.jbool_expressions.And; +import com.bpodgursky.jbool_expressions.Expression; +import com.bpodgursky.jbool_expressions.Or; +import java.util.List; + +/** + * Convenience expression builder + * + *

when construct jbool expression without specifying a comparator, it will use hashComparator by + * default which will cause the order of expression indeterminate, and cause JSONAPI unit tests + * failure By using this ExpressionUtils class, we pass a default comparator to keep expression + * order as it is + */ +public class ExpressionUtils { + + public static And andOf(Expression... expressions) { + // expression as creation order + return And.of(expressions, (e1, e2) -> 1); + } + + public static And andOf(List> expressions) { + // expression as creation order + return And.of(expressions.toArray(new Expression[expressions.size()]), (e1, e2) -> 1); + } + + public static Or orOf(List> expressions) { + // expression as creation order + return Or.of(expressions.toArray(new Expression[expressions.size()]), (e1, e2) -> 1); + } + + public static Or orOf(Expression... expressions) { + // expression as creation order + return Or.of(expressions, (e1, e2) -> 1); + } + + public static Expression buildExpression( + List> expressions, String logicOperator) { + switch (logicOperator) { + case "$and" -> { + return andOf(expressions); + } + case "$or" -> { + return orOf(expressions); + } + default -> throw INVALID_LOGIC_OPERATOR.toApiException(); + } + } + + public static Expression[] getAsArray(Expression... expressions) { + return expressions; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/BuiltCondition.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/BuiltCondition.java new file mode 100644 index 0000000000..5ac8c6e6b5 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/BuiltCondition.java @@ -0,0 +1,137 @@ +package io.stargate.sgv2.jsonapi.service.cql.builder; + +import io.stargate.sgv2.api.common.cql.ColumnUtils; +import io.stargate.sgv2.jsonapi.service.operation.model.impl.JsonTerm; +import java.util.Objects; + +public final class BuiltCondition { + + public LHS lhs; + + public Predicate predicate; + + public JsonTerm jsonTerm; + + public BuiltCondition(LHS lhs, Predicate predicate, JsonTerm jsonTerm) { + this.lhs = lhs; + this.predicate = predicate; + this.jsonTerm = jsonTerm; + } + + public static BuiltCondition of(LHS lhs, Predicate predicate, JsonTerm jsonTerm) { + return new BuiltCondition(lhs, predicate, jsonTerm); + } + + public static BuiltCondition of(String columnName, Predicate predicate, JsonTerm jsonTerm) { + return of(LHS.column(columnName), predicate, jsonTerm); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + // Append the LHS part of the condition + if (lhs != null) { + lhs.appendToBuilder(builder); + } else { + builder.append("null"); + } + // Append the predicate part of the condition + if (predicate != null) { + builder.append(" ").append(predicate); + } else { + builder.append(" null"); + } + // Append the JSON term part of the condition + if (jsonTerm != null) { + builder.append(" ").append(jsonTerm); + } else { + builder.append(" null"); + } + return builder.toString(); + } + + /** + * Represents the left hand side of a condition. + * + *

This is usually a column name, but technically can be: + * + *

+ */ + public abstract static class LHS { + public static LHS column(String columnName) { + return new ColumnName(columnName); + } + + public static LHS mapAccess(String columnName, String key) { + return new MapElement(columnName, key); + } + + abstract void appendToBuilder(StringBuilder builder); + + static final class ColumnName extends LHS { + private final String columnName; + + private ColumnName(String columnName) { + this.columnName = columnName; + } + + void appendToBuilder(StringBuilder builder) { + builder.append(ColumnUtils.maybeQuote(columnName)); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof ColumnName) { + ColumnName that = (ColumnName) other; + return Objects.equals(this.columnName, that.columnName); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(columnName); + } + } + + static final class MapElement extends LHS { + private final String columnName; + private final String key; + + MapElement(String columnName, String key) { + this.columnName = columnName; + this.key = key; + } + + void appendToBuilder(StringBuilder builder) { + builder.append(ColumnUtils.maybeQuote(columnName)).append("[?]"); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof MapElement) { + MapElement that = (MapElement) other; + return Objects.equals(this.columnName, that.columnName) + && Objects.equals(this.key, that.key); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(columnName, key); + } + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/Predicate.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/Predicate.java new file mode 100644 index 0000000000..31c08fc890 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/Predicate.java @@ -0,0 +1,26 @@ +package io.stargate.sgv2.jsonapi.service.cql.builder; + +public enum Predicate { + EQ("="), + NEQ("!="), + LT("<"), + GT(">"), + LTE("<="), + GTE(">="), + IN("IN"), + CONTAINS("CONTAINS"), + NOT_CONTAINS("NOT CONTAINS"), + CONTAINS_KEY("CONTAINS KEY"), + ; + + private final String cql; + + Predicate(String cql) { + this.cql = cql; + } + + @Override + public String toString() { + return cql; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/Query.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/Query.java new file mode 100644 index 0000000000..d694421406 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/Query.java @@ -0,0 +1,17 @@ +package io.stargate.sgv2.jsonapi.service.cql.builder; + +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import java.util.List; + +/** + * @param cql The query string. It can contain anonymous placeholders identified by a question mark + * (?), or named placeholders prefixed by a column (:name). + * @param values The values to fill the placeholders in the query string. + */ +public record Query(String cql, List values) { + + public SimpleStatement queryToStatement() { + SimpleStatement simpleStatement = SimpleStatement.newInstance(cql); + return simpleStatement.setPositionalValues(values); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/QueryBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/QueryBuilder.java new file mode 100644 index 0000000000..28098ea78a --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cql/builder/QueryBuilder.java @@ -0,0 +1,331 @@ +package io.stargate.sgv2.jsonapi.service.cql.builder; + +import com.bpodgursky.jbool_expressions.Expression; +import com.bpodgursky.jbool_expressions.Variable; +import com.datastax.oss.driver.api.core.data.CqlVector; +import io.stargate.sgv2.jsonapi.exception.ErrorCode; +import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.cql.ColumnUtils; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.CollectionSettings; +import io.stargate.sgv2.jsonapi.service.cqldriver.serializer.CQLBindValues; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class QueryBuilder { + private String keyspaceName; + private String tableName; + private boolean isInsert; + private boolean isUpdate; + private boolean isDelete; + private boolean isSelect; + private Integer limitInt; + private String orderByAnn; + private final List functionCalls = new ArrayList<>(); + + /** The vectorValue used to compute similarityScore or process an ANN search */ + private CqlVector vectorValue; + + /** Column names for a SELECT or DELETE. */ + private final List selection = new ArrayList<>(); + + /** The where expression which contains conditions and logic operation for a SELECT or UPDATE. */ + private Expression whereExpression = null; + + private static final String COUNT_FUNCTION_NAME = "COUNT"; + + public void keyspace(String keyspace) { + this.keyspaceName = keyspace; + } + + public void table(String table) { + this.tableName = table; + } + + public QueryBuilder from(String keyspace, String table) { + this.keyspaceName = keyspace; + table(table); + return this; + } + + public QueryBuilder select() { + isSelect = true; + return this; + } + + public QueryBuilder column(String... columns) { + for (String c : columns) { + column(c); + } + return this; + } + + public QueryBuilder column(String column) { + if (isSelect || isDelete) { + selection.add(column); + } + return this; + } + + public QueryBuilder count() { + count(null); + return this; + } + + public QueryBuilder count(String columnName) { + functionCalls.add(FunctionCall.count(columnName)); + return this; + } + + public QueryBuilder as(String alias) { + if (functionCalls.isEmpty()) { + throw new IllegalStateException( + "The as() method cannot be called without a preceding function call."); + } + // the alias is set for the last function call + FunctionCall functionCall = functionCalls.get(functionCalls.size() - 1); + functionCall.setAlias(alias); + return this; + } + + private void setWhereExpression(Expression whereExpression) { + this.whereExpression = whereExpression; + } + + public QueryBuilder where(Expression whereExpression) { + if (whereExpression != null) { + setWhereExpression(whereExpression); + } + return this; + } + + public QueryBuilder limit(Integer limit) { + this.limitInt = limit; + return this; + } + + public QueryBuilder limit() { + this.limitInt = -1; + return this; + } + + public Query build() { + if (isSelect) { + return selectQuery(); + } + throw ErrorCode.UNSUPPORTED_CQL_QUERY_TYPE.toApiException(); + } + + private Query selectQuery() { + List values = new ArrayList<>(); + StringBuilder builder = new StringBuilder("SELECT "); + // Data API has 3 sets of selection columns: DOCUMENT, SORTED_DOCUMENT, KEY + if (selection.isEmpty() && functionCalls.isEmpty()) { + builder.append('*'); + } else { + builder.append( + Stream.concat( + selection.stream().map(QueryBuilder::cqlName), + functionCalls.stream() + .map(functionCall -> formatFunctionCall(functionCall, values))) + .collect(Collectors.joining(", "))); + } + builder.append(" FROM ").append(maybeQualify(tableName)); + + appendWheres(builder, values); + + if (orderByAnn != null) { + if (vectorValue == null) { + throw ErrorCode.MISSING_VECTOR_VALUE.toApiException(); + } + builder.append(" ORDER BY ").append(orderByAnn).append(" ANN OF ?"); + values.add(vectorValue); + } + + if (limitInt != null) { + builder.append(" LIMIT ").append(limitInt == -1 ? "?" : limitInt); + } + + return new Query(builder.toString(), values); + } + + private void appendWheres(StringBuilder builder, List values) { + // Data API fully rely on Expression instead of List + if (this.whereExpression != null) { + appendConditions(this.whereExpression, " WHERE ", builder, values); + } + } + + private void appendConditions( + Expression whereExpression, + String initialPrefix, + StringBuilder builder, + List values) { + builder.append(initialPrefix); + addExpressionCql(builder, whereExpression, values); + } + + private void addExpressionCql( + StringBuilder sb, Expression outerExpression, List values) { + List> innerExpressions = outerExpression.getChildren(); + switch (outerExpression.getExprType()) { + case "and" -> { + // have parenthesis only when having more than one innerExpression + if (innerExpressions.size() > 1) { + sb.append("("); + } + for (int i = 0; i < innerExpressions.size(); i++) { + addExpressionCql(sb, innerExpressions.get(i), values); + if (i == innerExpressions.size() - 1) { + break; + } + sb.append(" AND "); + } + if (innerExpressions.size() > 1) { + sb.append(")"); + } + } + case "or" -> { + // have parenthesis only when having more than one innerExpression + if (innerExpressions.size() > 1) { + sb.append("("); + } + for (int i = 0; i < innerExpressions.size(); i++) { + addExpressionCql(sb, innerExpressions.get(i), values); + if (i == innerExpressions.size() - 1) { + break; + } + sb.append(" OR "); + } + if (innerExpressions.size() > 1) { + sb.append(")"); + } + } + case "variable" -> { + Variable variable = (Variable) outerExpression; + BuiltCondition condition = variable.getValue(); + condition.lhs.appendToBuilder(sb); + condition.jsonTerm.addToCqlValues(values); + sb.append(" ").append(condition.predicate.toString()).append(" ?"); + } + default -> throw new IllegalArgumentException( + String.format("Unsupported expression type %s", outerExpression.getExprType())); + } + } + + private static String cqlName(String name) { + return ColumnUtils.maybeQuote(name); + } + + private String maybeQualify(String elementName) { + if (keyspaceName == null) { + return cqlName(elementName); + } else { + return cqlName(keyspaceName) + '.' + cqlName(elementName); + } + } + + /** + * @param functionCall functionCall such as similarityScore + * @param values values list to be populated + * @return + */ + private String formatFunctionCall(QueryBuilder.FunctionCall functionCall, List values) { + StringBuilder builder = new StringBuilder(); + if (functionCall.getColumnName() == null + && COUNT_FUNCTION_NAME.equals(functionCall.getFunctionName())) { + // count function call and no column name + builder.append(functionCall.getFunctionName()).append("(1)"); + } else { + builder + .append(functionCall.getFunctionName()) + .append('(') + .append(cqlName(functionCall.getColumnName())); + if (functionCall.isSimilarityFunction) { + if (vectorValue == null) { + throw ErrorCode.MISSING_VECTOR_VALUE.toApiException(); + } + builder.append(", ").append('?'); + values.add(vectorValue); + } + builder.append(')'); + } + if (functionCall.getAlias() != null) { + builder.append(" AS ").append(cqlName(functionCall.getAlias())); + } + + return builder.toString(); + } + + public QueryBuilder similarityFunction( + String columnName, CollectionSettings.SimilarityFunction similarityFunction) { + switch (similarityFunction) { + case COSINE, UNDEFINED -> functionCalls.add( + FunctionCall.similarityFunctionCall(columnName, "SIMILARITY_COSINE")); + case EUCLIDEAN -> functionCalls.add( + FunctionCall.similarityFunctionCall(columnName, "SIMILARITY_EUCLIDEAN")); + case DOT_PRODUCT -> functionCalls.add( + FunctionCall.similarityFunctionCall(columnName, "SIMILARITY_DOT_PRODUCT")); + default -> throw new JsonApiException( + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME, + ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME.getMessage() + similarityFunction); + } + return this; + } + + public QueryBuilder vsearch(String column, float[] vectorValue) { + this.orderByAnn = column; + this.vectorValue = CQLBindValues.getVectorValue(vectorValue); + return this; + } + + public static class FunctionCall { + final String columnName; + String alias; + final String functionName; + + boolean isSimilarityFunction; + + private FunctionCall( + String columnName, String alias, String functionName, boolean isSimilarityFunction) { + this.columnName = columnName; + this.alias = alias; + this.functionName = functionName; + this.isSimilarityFunction = isSimilarityFunction; + } + + public static FunctionCall function(String name, String alias, String functionName) { + return new FunctionCall(name, alias, functionName, false); + } + + public static FunctionCall similarityFunctionCall( + String columnName, String similarityFunction) { + return new FunctionCall(columnName, null, similarityFunction, true); + } + + public static FunctionCall count(String columnName) { + return count(columnName, null); + } + + public static FunctionCall count(String columnName, String alias) { + return function(columnName, alias, COUNT_FUNCTION_NAME); + } + + public String getColumnName() { + return columnName; + } + + public String getFunctionName() { + return functionName; + } + + public String getAlias() { + return alias; + } + + public void setAlias(String alias) { + this.alias = alias; + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/CountOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/CountOperation.java index 912eae48c3..9322bb0c44 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/CountOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/CountOperation.java @@ -3,16 +3,15 @@ import com.bpodgursky.jbool_expressions.Expression; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import io.smallrye.mutiny.Uni; -import io.stargate.bridge.proto.QueryOuterClass; -import io.stargate.sgv2.api.common.cql.builder.BuiltCondition; -import io.stargate.sgv2.api.common.cql.builder.QueryBuilder; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.LogicalExpression; +import io.stargate.sgv2.jsonapi.service.cql.builder.BuiltCondition; +import io.stargate.sgv2.jsonapi.service.cql.builder.Query; +import io.stargate.sgv2.jsonapi.service.cql.builder.QueryBuilder; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.operation.model.impl.CountOperationPage; import io.stargate.sgv2.jsonapi.service.operation.model.impl.ExpressionBuilder; -import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; @@ -48,11 +47,7 @@ public Uni> execute(QueryExecutor queryExecutor) { private SimpleStatement buildSelectQuery() { final List> expressions = ExpressionBuilder.buildExpressions(logicalExpression, null); - List collect = new ArrayList<>(); - if (expressions != null && !expressions.isEmpty() && expressions.get(0) != null) { - collect = ExpressionBuilder.getExpressionValuesInOrder(expressions.get(0)); - } - QueryOuterClass.Query query = null; + Query query = null; if (limit == -1) { query = new QueryBuilder() @@ -72,9 +67,8 @@ private SimpleStatement buildSelectQuery() { .limit(limit + 1) .build(); } - - final SimpleStatement simpleStatement = SimpleStatement.newInstance(query.getCql()); + SimpleStatement simpleStatement = query.queryToStatement(); simpleStatement.setPageSize(pageSize()); - return simpleStatement.setPositionalValues(collect); + return simpleStatement; } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java index b6e1c12833..4b8db81756 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/DBFilterBase.java @@ -6,11 +6,10 @@ import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; -import io.stargate.bridge.grpc.Values; -import io.stargate.sgv2.api.common.cql.builder.BuiltCondition; -import io.stargate.sgv2.api.common.cql.builder.Predicate; import io.stargate.sgv2.jsonapi.exception.ErrorCode; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.cql.builder.BuiltCondition; +import io.stargate.sgv2.jsonapi.service.cql.builder.Predicate; import io.stargate.sgv2.jsonapi.service.cqldriver.serializer.CQLBindValues; import io.stargate.sgv2.jsonapi.service.shredding.model.DocValueHasher; import io.stargate.sgv2.jsonapi.service.shredding.model.DocumentId; @@ -135,42 +134,42 @@ public BuiltCondition get() { switch (operator) { case EQ: return BuiltCondition.of( - DATA_CONTAINS, + BuiltCondition.LHS.column(DATA_CONTAINS), Predicate.CONTAINS, new JsonTerm(getHashValue(new DocValueHasher(), key, value))); case NE: return BuiltCondition.of( - DATA_CONTAINS, + BuiltCondition.LHS.column(DATA_CONTAINS), Predicate.NOT_CONTAINS, new JsonTerm(getHashValue(new DocValueHasher(), key, value))); case MAP_EQUALS: return BuiltCondition.of( - BuiltCondition.LHS.mapAccess(columnName, Values.NULL), + BuiltCondition.LHS.mapAccess(columnName, key), Predicate.EQ, new JsonTerm(key, value)); case MAP_NOT_EQUALS: return BuiltCondition.of( - BuiltCondition.LHS.mapAccess(columnName, Values.NULL), + BuiltCondition.LHS.mapAccess(columnName, key), Predicate.NEQ, new JsonTerm(key, value)); case GT: return BuiltCondition.of( - BuiltCondition.LHS.mapAccess(columnName, Values.NULL), + BuiltCondition.LHS.mapAccess(columnName, key), Predicate.GT, new JsonTerm(key, value)); case GTE: return BuiltCondition.of( - BuiltCondition.LHS.mapAccess(columnName, Values.NULL), + BuiltCondition.LHS.mapAccess(columnName, key), Predicate.GTE, new JsonTerm(key, value)); case LT: return BuiltCondition.of( - BuiltCondition.LHS.mapAccess(columnName, Values.NULL), + BuiltCondition.LHS.mapAccess(columnName, key), Predicate.LT, new JsonTerm(key, value)); case LTE: return BuiltCondition.of( - BuiltCondition.LHS.mapAccess(columnName, Values.NULL), + BuiltCondition.LHS.mapAccess(columnName, key), Predicate.LTE, new JsonTerm(key, value)); default: @@ -314,13 +313,13 @@ public List getAll() { if (documentId.value() instanceof BigDecimal numberId) { return List.of( BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_dbl_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_dbl_values", DOC_ID), Predicate.NEQ, new JsonTerm(DOC_ID, numberId))); } else if (documentId.value() instanceof String strId) { return List.of( BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_text_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_text_values", DOC_ID), Predicate.NEQ, new JsonTerm(DOC_ID, strId))); } else { @@ -415,20 +414,20 @@ public List getAll() { // array element is sub_doc inResult.add( BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_text_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_text_values", this.getPath()), Predicate.EQ, new JsonTerm(this.getPath(), getHash(new DocValueHasher(), value)))); } else if (value instanceof List) { // array element is array inResult.add( BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_text_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_text_values", this.getPath()), Predicate.EQ, new JsonTerm(this.getPath(), getHash(new DocValueHasher(), value)))); } else { inResult.add( BuiltCondition.of( - DATA_CONTAINS, + BuiltCondition.LHS.column(DATA_CONTAINS), Predicate.CONTAINS, new JsonTerm(getHashValue(new DocValueHasher(), getPath(), value)))); } @@ -443,20 +442,20 @@ public List getAll() { // array element is sub_doc ninResults.add( BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_text_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_text_values", this.getPath()), Predicate.NEQ, new JsonTerm(this.getPath(), getHash(new DocValueHasher(), value)))); } else if (value instanceof List) { // array element is array ninResults.add( BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_text_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_text_values", this.getPath()), Predicate.NEQ, new JsonTerm(this.getPath(), getHash(new DocValueHasher(), value)))); } else { ninResults.add( BuiltCondition.of( - DATA_CONTAINS, + BuiltCondition.LHS.column(DATA_CONTAINS), Predicate.NOT_CONTAINS, new JsonTerm(getHashValue(new DocValueHasher(), getPath(), value)))); } @@ -471,14 +470,14 @@ public List getAll() { if (docIdValue instanceof BigDecimal numberId) { BuiltCondition condition = BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_dbl_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_dbl_values", DOC_ID), Predicate.NEQ, new JsonTerm(DOC_ID, numberId)); conditions.add(condition); } else if (docIdValue instanceof String strId) { BuiltCondition condition = BuiltCondition.of( - BuiltCondition.LHS.mapAccess("query_text_values", Values.NULL), + BuiltCondition.LHS.mapAccess("query_text_values", DOC_ID), Predicate.NEQ, new JsonTerm(DOC_ID, strId)); conditions.add(condition); @@ -537,9 +536,11 @@ public int hashCode() { public BuiltCondition get() { switch (operator) { case CONTAINS: - return BuiltCondition.of(columnName, Predicate.CONTAINS, new JsonTerm(value)); + return BuiltCondition.of( + BuiltCondition.LHS.column(columnName), Predicate.CONTAINS, new JsonTerm(value)); case NOT_CONTAINS: - return BuiltCondition.of(columnName, Predicate.NOT_CONTAINS, new JsonTerm(value)); + return BuiltCondition.of( + BuiltCondition.LHS.column(columnName), Predicate.NOT_CONTAINS, new JsonTerm(value)); default: throw new JsonApiException( ErrorCode.UNSUPPORTED_FILTER_OPERATION, @@ -620,7 +621,7 @@ public List getAll() { for (Object value : arrayValue) { result.add( BuiltCondition.of( - DATA_CONTAINS, + BuiltCondition.LHS.column(DATA_CONTAINS), negation ? Predicate.NOT_CONTAINS : Predicate.CONTAINS, new JsonTerm(getHashValue(new DocValueHasher(), getPath(), value)))); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/ExpressionBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/ExpressionBuilder.java index 92ae3806e6..f44fec3f70 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/ExpressionBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/ExpressionBuilder.java @@ -2,11 +2,11 @@ import com.bpodgursky.jbool_expressions.Expression; import com.bpodgursky.jbool_expressions.Variable; -import io.stargate.sgv2.api.common.cql.ExpressionUtils; -import io.stargate.sgv2.api.common.cql.builder.BuiltCondition; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.ComparisonExpression; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.LogicalExpression; import io.stargate.sgv2.jsonapi.exception.ErrorCode; +import io.stargate.sgv2.jsonapi.service.cql.ExpressionUtils; +import io.stargate.sgv2.jsonapi.service.cql.builder.BuiltCondition; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -191,35 +191,4 @@ private static Expression buildExpressionRecursive( return ExpressionUtils.buildExpression( conditionExpressions, logicalExpression.getLogicalRelation().getOperator()); } - - /** - * Get all positional cql values from express recursively. Result order is in consistent of the - * expression structure - */ - public static List getExpressionValuesInOrder(Expression expression) { - List values = new ArrayList<>(); - if (expression != null) { - populateValuesRecursive(values, expression); - } - return values; - } - - private static void populateValuesRecursive( - List values, Expression outerExpression) { - if (outerExpression.getExprType().equals("variable")) { - Variable var = (Variable) outerExpression; - JsonTerm term = ((JsonTerm) var.getValue().value()); - if (term.getKey() != null) { - values.add(term.getKey()); - } - values.add(term.getValue()); - return; - } - if (outerExpression.getExprType().equals("and") || outerExpression.getExprType().equals("or")) { - List> innerExpressions = outerExpression.getChildren(); - for (Expression innerExpression : innerExpressions) { - populateValuesRecursive(values, innerExpression); - } - } - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java index 1ea42a2f53..6a2c34e7dd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperation.java @@ -7,10 +7,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.collect.Lists; import io.smallrye.mutiny.Uni; -import io.stargate.bridge.grpc.Values; -import io.stargate.bridge.proto.QueryOuterClass; -import io.stargate.sgv2.api.common.cql.builder.BuiltCondition; -import io.stargate.sgv2.api.common.cql.builder.QueryBuilder; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.ComparisonExpression; @@ -19,8 +15,10 @@ import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCode; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.cql.builder.BuiltCondition; +import io.stargate.sgv2.jsonapi.service.cql.builder.Query; +import io.stargate.sgv2.jsonapi.service.cql.builder.QueryBuilder; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; -import io.stargate.sgv2.jsonapi.service.cqldriver.serializer.CQLBindValues; import io.stargate.sgv2.jsonapi.service.operation.model.ChainedComparator; import io.stargate.sgv2.jsonapi.service.operation.model.ReadOperation; import io.stargate.sgv2.jsonapi.service.operation.model.ReadType; @@ -414,9 +412,9 @@ private List buildSelectQueries(DBFilterBase.IDFilter additiona List queries = new ArrayList<>(expressions.size()); expressions.forEach( expression -> { - List collect = ExpressionBuilder.getExpressionValuesInOrder(expression); + final Query query; if (vector() == null) { - final QueryOuterClass.Query query = + query = new QueryBuilder() .select() .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) @@ -424,20 +422,10 @@ private List buildSelectQueries(DBFilterBase.IDFilter additiona .where(expression) .limit(limit) .build(); - final SimpleStatement simpleStatement = SimpleStatement.newInstance(query.getCql()); - queries.add(simpleStatement.setPositionalValues(collect)); } else { - QueryOuterClass.Query query = getVectorSearchQueryByExpression(expression); - collect.add(CQLBindValues.getVectorValue(vector())); - final SimpleStatement simpleStatement = SimpleStatement.newInstance(query.getCql()); - if (projection().doIncludeSimilarityScore()) { - List appendedCollect = new ArrayList<>(); - appendedCollect.add(collect.get(collect.size() - 1)); - appendedCollect.addAll(collect); - collect = appendedCollect; - } - queries.add(simpleStatement.setPositionalValues(collect)); + query = getVectorSearchQueryByExpression(expression); } + queries.add(query.queryToStatement()); }); return queries; @@ -447,54 +435,19 @@ private List buildSelectQueries(DBFilterBase.IDFilter additiona * A separate method to build vector search query by using expression, expression can contain * logic operations like 'or','and'.. */ - private QueryOuterClass.Query getVectorSearchQueryByExpression( - Expression expression) { - QueryOuterClass.Query builtQuery = null; + private Query getVectorSearchQueryByExpression(Expression expression) { if (projection().doIncludeSimilarityScore()) { - switch (commandContext().similarityFunction()) { - case COSINE, UNDEFINED -> { - return new QueryBuilder() - .select() - .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) - .similarityCosine( - DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, Values.NULL) - .from(commandContext.namespace(), commandContext.collection()) - .where(expression) - .limit(limit) - .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) - .build(); - } - case EUCLIDEAN -> { - return new QueryBuilder() - .select() - .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) - .similarityEuclidean( - DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, Values.NULL) - .from(commandContext.namespace(), commandContext.collection()) - .where(expression) - .limit(limit) - .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) - .build(); - } - case DOT_PRODUCT -> { - return new QueryBuilder() - .select() - .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) - .similarityDotProduct( - DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, Values.NULL) - .from(commandContext.namespace(), commandContext.collection()) - .where(expression) - .limit(limit) - .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) - .build(); - } - default -> { - throw new JsonApiException( - ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME, - ErrorCode.VECTOR_SEARCH_INVALID_FUNCTION_NAME.getMessage() - + commandContext().similarityFunction()); - } - } + return new QueryBuilder() + .select() + .column(ReadType.DOCUMENT == readType ? documentColumns : documentKeyColumns) + .similarityFunction( + DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, + commandContext().similarityFunction()) + .from(commandContext.namespace(), commandContext.collection()) + .where(expression) + .limit(limit) + .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, vector()) + .build(); } else { return new QueryBuilder() .select() @@ -502,7 +455,7 @@ private QueryOuterClass.Query getVectorSearchQueryByExpression( .from(commandContext.namespace(), commandContext.collection()) .where(expression) .limit(limit) - .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME) + .vsearch(DocumentConstants.Fields.VECTOR_SEARCH_INDEX_COLUMN_NAME, vector()) .build(); } } @@ -531,8 +484,7 @@ private List buildSortedSelectQueries(DBFilterBase.IDFilter add List queries = new ArrayList<>(expressions.size()); expressions.forEach( expression -> { - List collect = ExpressionBuilder.getExpressionValuesInOrder(expression); - final QueryOuterClass.Query query = + final Query query = new QueryBuilder() .select() .column(columnsToAdd) @@ -540,8 +492,7 @@ private List buildSortedSelectQueries(DBFilterBase.IDFilter add .where(expression) .limit(maxSortReadLimit()) .build(); - final SimpleStatement simpleStatement = SimpleStatement.newInstance(query.getCql()); - queries.add(simpleStatement.setPositionalValues(collect)); + queries.add(query.queryToStatement()); }); return queries; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/JsonTerm.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/JsonTerm.java index 9d043f5254..3136e54746 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/JsonTerm.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/JsonTerm.java @@ -1,7 +1,6 @@ package io.stargate.sgv2.jsonapi.service.operation.model.impl; -import io.stargate.bridge.grpc.Values; -import io.stargate.sgv2.api.common.cql.builder.Literal; +import java.util.List; import java.util.Objects; /** @@ -10,7 +9,7 @@ * required as a placeholder to set values in query builder and extracted out to set the value in * SimpleStatement positional values */ -public class JsonTerm extends Literal { +public class JsonTerm { static final String NULL_ERROR_MESSAGE = "Use Values.NULL to bind a null CQL value"; private final Object key; private final Object value; @@ -20,7 +19,6 @@ public JsonTerm(Object value) { } public JsonTerm(Object key, Object value) { - super(Values.NULL); this.key = key; this.value = value; } @@ -47,4 +45,18 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(new Object[] {this.value, this.key}); } + + /** + * This method is used for populate positional cql value list e.g. select * from table where + * map[?] = ? limit 1; For this case, we populate as key and value + * + *

e.g. select * from table where array_contains contains ? limit 1; * For this case, we + * populate positional cql value + */ + public void addToCqlValues(List values) { + if (this.key != null) { + values.add(this.key); + } + values.add(this.value); + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cql/builder/QueryBuilderTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cql/builder/QueryBuilderTest.java new file mode 100644 index 0000000000..d3e876e02d --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cql/builder/QueryBuilderTest.java @@ -0,0 +1,219 @@ +package io.stargate.sgv2.jsonapi.service.cql.builder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +import com.bpodgursky.jbool_expressions.Expression; +import com.bpodgursky.jbool_expressions.Variable; +import com.datastax.oss.driver.api.core.data.CqlVector; +import io.stargate.sgv2.jsonapi.service.cql.ExpressionUtils; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.CollectionSettings; +import io.stargate.sgv2.jsonapi.service.cqldriver.serializer.CQLBindValues; +import io.stargate.sgv2.jsonapi.service.operation.model.impl.JsonTerm; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class QueryBuilderTest { + + public static final String VECTOR_COLUMN = "query_vector_value"; + public static final float[] TEST_VECTOR = new float[] {0.1f, 0.2f, 0.3f}; + + public static final CqlVector TEST_CQL_VECTOR = CQLBindValues.getVectorValue(TEST_VECTOR); + + public static final List EMPTY_VALUES = new ArrayList<>(); + + @ParameterizedTest + @MethodSource("sampleQueries") + @DisplayName("Should generate expected CQL string and values") + public void generateQuery(Query query, String expectedCql, List expectedValues) { + assertThat(query.cql()).isEqualTo(expectedCql); + assertThat(query.values()).isEqualTo(expectedValues); + } + + public static Arguments[] sampleQueries() { + return new Arguments[] { + arguments( + new QueryBuilder().select().from("ks", "tbl").build(), + "SELECT * FROM ks.tbl", + EMPTY_VALUES), + arguments( + new QueryBuilder().select().column("a", "b", "c").from("ks", "tbl").build(), + "SELECT a, b, c FROM ks.tbl", + EMPTY_VALUES), + arguments( + new QueryBuilder().select().count("a").from("ks", "tbl").build(), + "SELECT COUNT(a) FROM ks.tbl", + EMPTY_VALUES), + arguments( + new QueryBuilder().select().count().from("ks", "tbl").build(), + "SELECT COUNT(1) FROM ks.tbl", + EMPTY_VALUES), + arguments( + new QueryBuilder().select().count("a").from("ks", "tbl").limit(1).build(), + "SELECT COUNT(a) FROM ks.tbl LIMIT 1", + EMPTY_VALUES), + arguments( + new QueryBuilder().select().count("a").from("ks", "tbl").limit().build(), + "SELECT COUNT(a) FROM ks.tbl LIMIT ?", + EMPTY_VALUES), + arguments( + new QueryBuilder() + .select() + .column("FirstName", "b", "c") + .from("ks", "tbl") + .limit(1) + .build(), + "SELECT \"FirstName\", b, c FROM ks.tbl LIMIT 1", + EMPTY_VALUES, + arguments( + new QueryBuilder() + .select() + .column("a", "b", "c") + .from("ks", "tbl") + .limit(1) + .vsearch(VECTOR_COLUMN, TEST_VECTOR) + .build(), + "SELECT a, b, c FROM ks.tbl ORDER BY query_vector_value ANN OF ? LIMIT 1", + List.of(TEST_CQL_VECTOR)), + arguments( + new QueryBuilder() + .select() + .column("a", "b", "c") + .similarityFunction( + "query_vector_value", CollectionSettings.SimilarityFunction.COSINE) + .from("ks", "tbl") + .limit(1) + .vsearch(VECTOR_COLUMN, TEST_VECTOR) + .build(), + "SELECT a, b, c, SIMILARITY_COSINE(query_vector_value, ?) FROM ks.tbl ORDER BY query_vector_value ANN OF ? LIMIT 1", + List.of(TEST_CQL_VECTOR, TEST_CQL_VECTOR)), + arguments( + new QueryBuilder() + .select() + .column("a", "b", "c") + .similarityFunction( + "query_vector_value", CollectionSettings.SimilarityFunction.DOT_PRODUCT) + .from("ks", "tbl") + .limit(1) + .vsearch(VECTOR_COLUMN, TEST_VECTOR) + .build(), + "SELECT a, b, c, SIMILARITY_DOT_PRODUCT(query_vector_value, ?) FROM ks.tbl ORDER BY query_vector_value ANN OF ? LIMIT 1", + List.of(TEST_CQL_VECTOR, TEST_CQL_VECTOR)), + arguments( + new QueryBuilder() + .select() + .column("a", "b", "c") + .similarityFunction( + VECTOR_COLUMN, CollectionSettings.SimilarityFunction.EUCLIDEAN) + .from("ks", "tbl") + .limit(1) + .vsearch("query_vector_value", TEST_VECTOR) + .build(), + "SELECT a, b, c, SIMILARITY_EUCLIDEAN(query_vector_value, ?) FROM ks.tbl ORDER BY query_vector_value ANN OF ? LIMIT 1", + List.of(TEST_CQL_VECTOR, TEST_CQL_VECTOR))) + }; + } + + @Nested + public class expressionToCqlBuilderTest { + @Test + public void simpleAnd() { + Expression expression = + ExpressionUtils.andOf( + Variable.of(BuiltCondition.of("Name", Predicate.EQ, new JsonTerm("testName"))), + Variable.of(BuiltCondition.of("age", Predicate.EQ, new JsonTerm("testAge")))); + Query query = new QueryBuilder().select().from("ks", "tbl").where(expression).build(); + assertThat(query.cql()).isEqualTo("SELECT * FROM ks.tbl WHERE (\"Name\" = ? AND age = ?)"); + assertThat(query.values()).contains("testName", "testAge"); + } + + @Test + public void vsearch() { + Expression expression = null; + Query query = + new QueryBuilder() + .select() + .from("ks", "tbl") + .vsearch(VECTOR_COLUMN, TEST_VECTOR) + .where(expression) + .build(); + assertThat(query.cql()) + .isEqualTo("SELECT * FROM ks.tbl ORDER BY query_vector_value ANN OF ?"); + assertThat(query.values()).contains(TEST_CQL_VECTOR); + } + + @Test + public void vsearchWithFilter() { + Expression expression = + ExpressionUtils.andOf( + Variable.of(BuiltCondition.of("name", Predicate.EQ, new JsonTerm("testName")))); + Query query = + new QueryBuilder() + .select() + .column("a", "b") + .from("ks", "tbl") + .similarityFunction(VECTOR_COLUMN, CollectionSettings.SimilarityFunction.EUCLIDEAN) + .vsearch(VECTOR_COLUMN, TEST_VECTOR) + .where(expression) + .limit(10) + .build(); + assertThat(query.cql()) + .isEqualTo( + "SELECT a, b, SIMILARITY_EUCLIDEAN(query_vector_value, ?) FROM ks.tbl WHERE name = ? ORDER BY query_vector_value ANN OF ? LIMIT 10"); + assertThat(query.values()).isEqualTo(List.of(TEST_CQL_VECTOR, "testName", TEST_CQL_VECTOR)); + } + + @Test + public void simpleOr() { + Expression expression = + ExpressionUtils.orOf( + Variable.of(BuiltCondition.of("name", Predicate.EQ, new JsonTerm("testName"))), + Variable.of(BuiltCondition.of("age", Predicate.EQ, new JsonTerm("testAge")))); + Query query = new QueryBuilder().select().from("ks", "tbl").where(expression).build(); + assertThat(query.cql()).isEqualTo("SELECT * FROM ks.tbl WHERE (name = ? OR age = ?)"); + assertThat(query.values()).contains("testName", "testAge"); + } + + @Test + public void singleVariableWithoutParenthesis() { + Expression expression1 = + ExpressionUtils.andOf( + Variable.of(BuiltCondition.of("name", Predicate.EQ, new JsonTerm("testName")))); + Query query1 = new QueryBuilder().select().from("ks", "tbl").where(expression1).build(); + assertThat(query1.cql()).isEqualTo("SELECT * FROM ks.tbl WHERE name = ?"); + assertThat(query1.values()).containsOnly("testName"); + } + + @Test + public void nestedAndOr() { + Expression expression2 = + ExpressionUtils.orOf( + Variable.of(BuiltCondition.of("address", Predicate.EQ, new JsonTerm("testAddress"))), + ExpressionUtils.andOf( + Variable.of(BuiltCondition.of("name", Predicate.EQ, new JsonTerm("testName"))), + Variable.of(BuiltCondition.of("age", Predicate.EQ, new JsonTerm("testAge"))))); + Query query2 = new QueryBuilder().select().from("ks", "tbl").where(expression2).build(); + assertThat(query2.cql()) + .isEqualTo("SELECT * FROM ks.tbl WHERE (address = ? OR (name = ? AND age = ?))"); + assertThat(query2.values()).contains("testName", "testAge", "testAddress"); + } + + @Test + public void singleVariableExpression() { + Expression expression2 = + ExpressionUtils.orOf( + Variable.of(BuiltCondition.of("address", Predicate.EQ, new JsonTerm("testAddress"))), + ExpressionUtils.andOf( + Variable.of(BuiltCondition.of("age", Predicate.EQ, new JsonTerm("testAge"))))); + Query query2 = new QueryBuilder().select().from("ks", "tbl").where(expression2).build(); + assertThat(query2.cql()).isEqualTo("SELECT * FROM ks.tbl WHERE (address = ? OR age = ?)"); + assertThat(query2.values()).contains("testAge", "testAddress"); + } + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperationTest.java index 4a012863bb..703ca68de0 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindOperationTest.java @@ -22,13 +22,13 @@ import io.quarkus.test.junit.TestProfile; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; -import io.stargate.sgv2.api.common.cql.builder.BuiltCondition; import io.stargate.sgv2.common.testprofiles.NoGlobalResourcesTestProfile; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.ComparisonExpression; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.LogicalExpression; import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableToErrorMapper; +import io.stargate.sgv2.jsonapi.service.cql.builder.BuiltCondition; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.CollectionSettings; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.operation.model.ReadType;