Skip to content

Commit

Permalink
Use CQL Order By for non ANN sort (#1635)
Browse files Browse the repository at this point in the history
  • Loading branch information
amorton authored Nov 4, 2024
1 parent a9e7484 commit 0a1db86
Show file tree
Hide file tree
Showing 42 changed files with 1,300 additions and 167 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.stargate.sgv2.jsonapi.api.model.command.clause.sort;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.stargate.sgv2.jsonapi.api.model.command.deserializers.SortClauseDeserializer;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
Expand Down Expand Up @@ -32,12 +33,14 @@ public boolean isEmpty() {
return sortExpressions == null || sortExpressions.isEmpty();
}

/** Get the sort expressions that are trying to vector sort columns on a table */
public List<SortExpression> tableVectorSorts() {
return sortExpressions == null
? List.of()
: sortExpressions.stream().filter(SortExpression::isTableVectorSort).toList();
}

/** Get the sort expressions that are not trying to vector sort columns on a table */
public List<SortExpression> nonTableVectorSorts() {
return sortExpressions == null
? List.of()
Expand All @@ -46,6 +49,12 @@ public List<SortExpression> nonTableVectorSorts() {
.toList();
}

public List<CqlIdentifier> sortColumnIdentifiers() {
return sortExpressions == null
? List.of()
: sortExpressions.stream().map(SortExpression::pathAsCqlIdentifier).toList();
}

public boolean hasVsearchClause() {
return sortExpressions != null
&& !sortExpressions.isEmpty()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD;
import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil;
import jakarta.validation.constraints.NotBlank;
import java.util.Objects;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -42,6 +44,14 @@ public static SortExpression tableVectorSort(String path, float[] vector) {
return new SortExpression(path, false, vector, null);
}

public CqlIdentifier pathAsCqlIdentifier() {
return CqlIdentifierUtil.cqlIdentifierFromUserInput(path);
}

/**
* Check if the sort expression is trying to vector sort columns on a table, the sort is trying to
* do this if it is not using $vector or $vectorize and it has a vector array to sort on
*/
public boolean isTableVectorSort() {
return !pathIs$VectorNames() && vector != null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,29 @@ public SortClause deserialize(JsonParser parser, DeserializationContext ctxt) th
}
if (!inner.getValue().isArray()) {
throw ErrorCodeV1.SHRED_BAD_VECTOR_VALUE.toApiException();
} else {
ArrayNode arrayNode = (ArrayNode) inner.getValue();
float[] arrayVals = new float[arrayNode.size()];
if (arrayNode.size() == 0) {
throw ErrorCodeV1.SHRED_BAD_VECTOR_SIZE.toApiException();
}
for (int i = 0; i < arrayNode.size(); i++) {
JsonNode element = arrayNode.get(i);
if (!element.isNumber()) {
throw ErrorCodeV1.SHRED_BAD_VECTOR_VALUE.toApiException();
}
arrayVals[i] = element.floatValue();
}

SortExpression exp = SortExpression.vsearch(arrayVals);
sortExpressions.clear();
sortExpressions.add(exp);
// TODO: aaron 17-oct-2024 - this break seems unneeded as above it checks if there is only
// 1
// field, leaving for now
break;
}

SortExpression exp =
SortExpression.vsearch(arrayNodeToVector((ArrayNode) inner.getValue()));
sortExpressions.clear();
sortExpressions.add(exp);
// TODO: aaron 17-oct-2024 - this break seems unneeded as above it checks if there is only 1
// field, leaving for now
break;

} else if (DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD.equals(path)) {
// Vector search can't be used with other sort clause
if (totalFields > 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ public SortException(ErrorInstance errorInstance) {
}

public enum Code implements ErrorCode<SortException> {
CANNOT_MIX_VECTOR_AND_NON_VECTOR_SORT,
CANNOT_SORT_UNKNOWN_COLUMNS,
CANNOT_VECTOR_SORT_NON_VECTOR_COLUMNS,
CANNOT_VECTOR_SORT_NON_INDEXED_VECTOR_COLUMNS,
CANNOT_MIX_VECTOR_AND_NON_VECTOR_SORT,
CANNOT_VECTOR_SORT_NON_VECTOR_COLUMNS,
MORE_THAN_ONE_VECTOR_SORT,
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public enum Code implements ErrorCode<WarningException> {
COMPARISON_FILTER_UNSUPPORTED_BY_INDEXING,
DEPRECATED_COMMAND,
INCOMPLETE_PRIMARY_KEY_FILTER,
IN_MEMORY_SORTING_DUE_TO_MISSING_PARTITION_SORTING,
IN_MEMORY_SORTING_DUE_TO_NON_PARTITION_SORTING,
IN_MEMORY_SORTING_DUE_TO_OUT_OF_ORDER_PARTITION_SORTING,
MISSING_INDEX,
NOT_EQUALS_UNSUPPORTED_BY_INDEXING,
NOT_IN_FILTER_UNSUPPORTED_BY_INDEXING,
Expand Down
92 changes: 92 additions & 0 deletions src/main/java/io/stargate/sgv2/jsonapi/exception/WithWarnings.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package io.stargate.sgv2.jsonapi.exception;

import com.google.common.base.Preconditions;
import io.stargate.sgv2.jsonapi.service.operation.OperationAttempt;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/**
* Re-usable class for holding an object and the {@link WarningException}'s that have been generated
* for it.
*
* <p>This is usually used when analysing part of a query and generating warnings for the user.
*
* <p>Is a {@link Consumer} of {@link OperationAttempt} so that it can add the warnings to the
* attempt, and so multiple instances can be chained together using {@link
* Consumer#andThen(Consumer)}
*
* @param <T> Type of the target object the warnings are about.
*/
public class WithWarnings<T> implements Consumer<OperationAttempt<?, ?>> {

private final T target;
private final List<WarningException> warnings;

public WithWarnings(T target, List<WarningException> warnings) {
Preconditions.checkNotNull(target, "target must not be null");
this.target = target;
this.warnings = warnings == null ? new ArrayList<>() : warnings;
}

/**
* The target object the warnings are about.
*
* @return The target object.
*/
public T target() {
return target;
}

/**
* The warnings generated for the target object.
*
* <p>This is a mutable, so you can add more warnings to it.
*
* @return The list of warnings, never null.
*/
public List<WarningException> warnings() {
return warnings;
}

/** Returns true if there are no warnings. */
public boolean isEmpty() {
return warnings.isEmpty();
}

/*
* Constructor an instance with no warnings.
* @param target the target object that has no warnings
* @return an instance with no warnings
*/
public static <T> WithWarnings<T> of(T target) {
return new WithWarnings<>(target, new ArrayList<>());
}

/**
* Constructor an instance with a single warning.
*
* @param target the target object that has the warning
* @param warning the warning to add
* @return An instance with the warning
* @param <T> Type of the target object the warnings are about.
*/
public static <T> WithWarnings<T> of(T target, WarningException warning) {
Objects.requireNonNull(warning, "warning is required");
var warnings = new ArrayList<WarningException>();
warnings.add(warning);
return new WithWarnings<>(target, warnings);
}

/**
* Adds all the warnings to the {@link OperationAttempt}
*
* @param operationAttempt the {@link OperationAttempt} to add the warnings to
*/
@Override
public void accept(OperationAttempt<?, ?> operationAttempt) {
Objects.requireNonNull(operationAttempt, "operationAttempt must not be null");
warnings.forEach(operationAttempt::addWarning);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.datastax.oss.driver.api.querybuilder.update.Update;
import io.stargate.sgv2.jsonapi.exception.FilterException;
import io.stargate.sgv2.jsonapi.exception.WithWarnings;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.operation.query.UpdateValuesCQLClause;
import io.stargate.sgv2.jsonapi.service.operation.query.WhereCQLClause;
Expand Down Expand Up @@ -31,34 +32,40 @@ public UpdateAttemptBuilder(SchemaT tableBasedSchema) {
}

public UpdateAttempt<SchemaT> build(
WhereCQLClause<Update> whereCQLClause, UpdateValuesCQLClause updateCQLClause) {
WhereCQLClause<Update> whereCQLClause, WithWarnings<UpdateValuesCQLClause> updateCQLClause) {

readPosition += 1;

// TODO: this may be common for creating a read / delete / where attempt will look at how to
// refactor once all done
WhereCQLClauseAnalyzer.WhereClauseAnalysis analyzedResult = null;
WhereCQLClauseAnalyzer.WhereClauseWithWarnings whereClauseWithWarnings = null;
Exception exception = null;
try {
analyzedResult = whereCQLClauseAnalyzer.analyse(whereCQLClause);
whereClauseWithWarnings = whereCQLClauseAnalyzer.analyse(whereCQLClause);
} catch (FilterException filterException) {
exception = filterException;
}

var attempt =
new UpdateAttempt<>(readPosition, tableBasedSchema, updateCQLClause, whereCQLClause);
new UpdateAttempt<>(
readPosition, tableBasedSchema, updateCQLClause.target(), whereCQLClause);

// ok to pass null exception, will be ignored
attempt.maybeAddFailure(exception);

// There should not be any warnings, we cannot turn on allow filtering for delete
// and we should not be turning on allow filtering for delete
// sanity check
if (analyzedResult != null && !analyzedResult.isEmpty()) {
if (whereClauseWithWarnings != null
&& (whereClauseWithWarnings.requiresAllowFiltering()
|| !whereClauseWithWarnings.isEmpty())) {
throw new IllegalStateException(
"Where clause analysis for update was not empty, analysis:%s".formatted(analyzedResult));
"Where clause analysis for update was not empty, analysis:%s"
.formatted(whereClauseWithWarnings));
}

// add warnings from the CQL clauses to the attempt
updateCQLClause.accept(attempt);
return attempt;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private CreateTable addColumnsAndKeys(CreateTableStart createTableStart) {
: createTable.withPartitionKey(partitionDef.name(), partitionDef.type().cqlType());
}

for (var clusteringDef : tableDef.clusteringKeys()) {
for (var clusteringDef : tableDef.clusteringDefs()) {
createTable =
createTable.withClusteringColumn(
clusteringDef.columnDef().name(), clusteringDef.columnDef().type().cqlType());
Expand All @@ -92,7 +92,7 @@ private CreateTable addColumnsAndKeys(CreateTableStart createTableStart) {

private CreateTableWithOptions addClusteringOrder(CreateTableWithOptions createTableWithOptions) {

for (var clusteringDef : tableDef.clusteringKeys()) {
for (var clusteringDef : tableDef.clusteringDefs()) {
createTableWithOptions =
createTableWithOptions.withClusteringOrder(
clusteringDef.columnDef().name(), clusteringDef.order().getCqlOrder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public DeleteAttempt<SchemaT> build(WhereCQLClause<Delete> whereCQLClause) {

// TODO: this may be common for creating a read / delete / where attempt will look at how to
// refactor once all done
WhereCQLClauseAnalyzer.WhereClauseAnalysis analyzedResult = null;
WhereCQLClauseAnalyzer.WhereClauseWithWarnings analyzedResult = null;
Exception exception = null;
try {
analyzedResult = whereCQLClauseAnalyzer.analyse(whereCQLClause);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.stargate.sgv2.jsonapi.service.operation.tables;

import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import io.stargate.sgv2.jsonapi.service.operation.query.OrderByCqlClause;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName;
import java.util.Objects;

/**
* A CQL clause that adds an ORDER BY clause to a SELECT statement to ANN sort.
*
* <p>Note: Only supports sorting on vector columns a single column, if there is a secondary sort
* that would be in memory sorting.
*/
public class TableOrderByANNCqlClause implements OrderByCqlClause {

private final ApiColumnDef apiColumnDef;
private final CqlVector<Float> vector;

public TableOrderByANNCqlClause(ApiColumnDef apiColumnDef, CqlVector<Float> vector) {
this.apiColumnDef = Objects.requireNonNull(apiColumnDef, "apiColumnDef must not be null");
this.vector = Objects.requireNonNull(vector, "vector must not be null");

// sanity check
if (apiColumnDef.type().typeName() != ApiTypeName.VECTOR) {
throw new IllegalArgumentException(
"ApiColumnDef must be a vector type, got: %s".formatted(apiColumnDef));
}
}

@Override
public Select apply(Select select) {
return select.orderByAnnOf(apiColumnDef.name(), vector);
}

@Override
public boolean inMemorySortNeeded() {
return false;
}
}
Loading

0 comments on commit 0a1db86

Please sign in to comment.