Skip to content

Commit

Permalink
Table, includeSimilarityScore, includeSortVector (#1656)
Browse files Browse the repository at this point in the history
Co-authored-by: Aaron Morton <aaron.morton@datastax.com>
Co-authored-by: maheshrajamani <99678631+maheshrajamani@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 7, 2024
1 parent 39e9d50 commit 5df72c3
Show file tree
Hide file tree
Showing 16 changed files with 426 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.stargate.sgv2.jsonapi.api.model.command;

import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression;
import java.util.Optional;

/** Interface for commands that can have vector sort clause */
public interface VectorSortable extends Sortable {

default Optional<Boolean> includeSimilarityScore() {
return Optional.empty();
}

default Optional<Boolean> includeSortVector() {
return Optional.empty();
}

/**
* Returns the first SortExpression that has {@link SortExpression#vector()} not null, if there is
* more than one raises {@link IllegalStateException}.
*
* @return the vector sort expression if it exists.
*/
default Optional<SortExpression> vectorSortExpression() {
if (sortClause() != null && sortClause().sortExpressions() != null) {
var vectorSorts =
sortClause().sortExpressions().stream()
.filter(expression -> expression.vector() != null)
.toList();
if (vectorSorts.size() > 1) {
throw new IllegalStateException("Only one vector sort expression is allowed");
}
return vectorSorts.isEmpty() ? Optional.empty() : Optional.of(vectorSorts.getFirst());
}
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.databind.JsonNode;
import io.stargate.sgv2.jsonapi.api.model.command.Filterable;
import io.stargate.sgv2.jsonapi.api.model.command.Projectable;
import io.stargate.sgv2.jsonapi.api.model.command.ReadCommand;
import io.stargate.sgv2.jsonapi.api.model.command.Sortable;
import io.stargate.sgv2.jsonapi.api.model.command.Windowable;
import io.stargate.sgv2.jsonapi.api.model.command.*;
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterClause;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.api.model.command.validation.CheckFindOption;
Expand All @@ -28,7 +24,7 @@ public record FindCommand(
@JsonProperty("projection") JsonNode projectionDefinition,
@Valid @JsonProperty("sort") SortClause sortClause,
@Valid @Nullable Options options)
implements ReadCommand, Filterable, Projectable, Sortable, Windowable {
implements ReadCommand, Filterable, Projectable, Sortable, Windowable, VectorSortable {

public record Options(

Expand Down Expand Up @@ -85,4 +81,14 @@ public Optional<Integer> limit() {
public Optional<Integer> skip() {
return Optional.ofNullable(options()).map(Options::skip).filter(skip -> skip > 0);
}

@Override
public Optional<Boolean> includeSimilarityScore() {
return options() == null ? Optional.empty() : Optional.of(options().includeSimilarity);
}

@Override
public Optional<Boolean> includeSortVector() {
return options() == null ? Optional.empty() : Optional.of(options().includeSortVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.databind.JsonNode;
import io.stargate.sgv2.jsonapi.api.model.command.Filterable;
import io.stargate.sgv2.jsonapi.api.model.command.Projectable;
import io.stargate.sgv2.jsonapi.api.model.command.ReadCommand;
import io.stargate.sgv2.jsonapi.api.model.command.Sortable;
import io.stargate.sgv2.jsonapi.api.model.command.Windowable;
import io.stargate.sgv2.jsonapi.api.model.command.*;
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterClause;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import jakarta.validation.Valid;
Expand All @@ -23,7 +19,7 @@ public record FindOneCommand(
@JsonProperty("projection") JsonNode projectionDefinition,
@Valid @JsonProperty("sort") SortClause sortClause,
@Valid @Nullable Options options)
implements ReadCommand, Filterable, Projectable, Sortable, Windowable {
implements ReadCommand, Filterable, Projectable, Sortable, Windowable, VectorSortable {

public record Options(

Expand All @@ -49,4 +45,14 @@ public CommandName commandName() {
public Optional<Integer> limit() {
return Optional.of(1);
}

@Override
public Optional<Boolean> includeSimilarityScore() {
return options() == null ? Optional.empty() : Optional.of(options().includeSimilarity);
}

@Override
public Optional<Boolean> includeSortVector() {
return options() == null ? Optional.empty() : Optional.of(options().includeSortVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object
return new AbstractMap.SimpleEntry<>(similarityFunction, sourceModel);
});

// if now index, or we could not work out the function, default
// if no index, or we could not work out the function, default
var similarityFunction =
indexFunction.map(Map.Entry::getKey).orElse(SimilarityFunction.COSINE);
var sourceModel = indexFunction.map(Map.Entry::getValue).orElse(SourceModel.OTHER);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import io.stargate.sgv2.jsonapi.api.model.command.CommandResult;
import io.stargate.sgv2.jsonapi.api.model.command.CommandResultBuilder;
import io.stargate.sgv2.jsonapi.api.model.command.CommandStatus;
import io.stargate.sgv2.jsonapi.api.model.command.VectorSortable;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.CqlPagingState;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import java.util.*;
Expand Down Expand Up @@ -92,16 +94,29 @@ public Builder<SchemaT> singleResponse(boolean singleResponse) {
return this;
}

public Builder<SchemaT> includeSortVector(boolean includeSortVector) {
private Builder<SchemaT> includeSortVector(boolean includeSortVector) {
this.includeSortVector = includeSortVector;
return this;
}

public Builder<SchemaT> sortVector(float[] sortVector) {
private Builder<SchemaT> sortVector(float[] sortVector) {
this.sortVector = sortVector;
return this;
}

public <CmdT extends VectorSortable> Builder<SchemaT> mayReturnVector(CmdT command) {
var includeVector = command.includeSortVector().orElse(false);
if (includeVector) {
var requestedVector =
command.vectorSortExpression().map(SortExpression::vector).orElse(null);
if (requestedVector != null) {
this.includeSortVector = true;
this.sortVector = requestedVector;
}
}
return this;
}

@Override
public ReadAttemptPage<SchemaT> getOperationPage() {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.stargate.sgv2.jsonapi.service.operation.tables;

import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_FUNCTION_SIMILARITY_FIELD;
import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errFmtApiColumnDef;

import com.datastax.oss.driver.api.core.CqlIdentifier;
Expand All @@ -10,6 +11,7 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.stargate.sgv2.jsonapi.api.model.command.*;
import io.stargate.sgv2.jsonapi.api.model.command.table.definition.ColumnsDescContainer;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.exception.checked.MissingJSONCodecException;
Expand All @@ -18,7 +20,6 @@
import io.stargate.sgv2.jsonapi.service.operation.OperationProjection;
import io.stargate.sgv2.jsonapi.service.operation.filters.table.codecs.*;
import io.stargate.sgv2.jsonapi.service.operation.query.SelectCQLClause;
import io.stargate.sgv2.jsonapi.service.projection.TableProjectionDefinition;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -39,26 +40,28 @@ public class TableProjection implements SelectCQLClause, OperationProjection {
private TableSchemaObject table;
private List<ColumnMetadata> columns;
private ColumnsDescContainer columnsDesc;
private TableSimilarityFunction tableSimilarityFunction;

private TableProjection(
ObjectMapper objectMapper,
TableSchemaObject table,
List<ColumnMetadata> columns,
ColumnsDescContainer columnsDesc) {
ColumnsDescContainer columnsDesc,
TableSimilarityFunction tableSimilarityFunction) {

this.objectMapper = objectMapper;
this.table = table;
this.columns = columns;
this.columnsDesc = columnsDesc;
this.tableSimilarityFunction = tableSimilarityFunction;
}

/**
* Factory method for construction projection instance, given a projection definition and table
* schema.
*/
public static TableProjection fromDefinition(
ObjectMapper objectMapper,
TableProjectionDefinition projectionDefinition,
TableSchemaObject table) {
public static <CmdT extends Projectable> TableProjection fromDefinition(
ObjectMapper objectMapper, CmdT command, TableSchemaObject table) {

Map<String, ColumnMetadata> columnsByName = new HashMap<>();
// TODO: This can also be cached as part of TableSchemaObject than resolving it for every query.
Expand All @@ -67,7 +70,8 @@ public static TableProjection fromDefinition(
.getColumns()
.forEach((id, column) -> columnsByName.put(id.asInternal(), column));

List<ColumnMetadata> columns = projectionDefinition.extractSelectedColumns(columnsByName);
List<ColumnMetadata> columns =
command.tableProjectionDefinition().extractSelectedColumns(columnsByName);

// TODO: A table can't be with empty columns. Think a redundant check.
if (columns.isEmpty()) {
Expand All @@ -88,14 +92,22 @@ public static TableProjection fromDefinition(
.formatted(errFmtApiColumnDef(readApiColumns.filterByUnsupported())));
}

return new TableProjection(objectMapper, table, columns, readApiColumns.toColumnsDesc());
return new TableProjection(
objectMapper,
table,
columns,
readApiColumns.toColumnsDesc(),
TableSimilarityFunction.from(command, table));
}

@Override
public Select apply(OngoingSelection ongoingSelection) {
Set<CqlIdentifier> readColumns = new LinkedHashSet<>();
readColumns.addAll(columns.stream().map(ColumnMetadata::getName).toList());
return ongoingSelection.columnsIds(readColumns);
Select select = ongoingSelection.columnsIds(readColumns);

// may apply similarity score function
return tableSimilarityFunction.apply(select);
}

@Override
Expand Down Expand Up @@ -147,6 +159,17 @@ public JsonNode projectRow(Row row) {
nonNullCount,
skippedNullCount);
}

// If user specify includeSimilarity, but no ANN sort clause, then we won't generate
// similarity_score function in the cql statement
if (tableSimilarityFunction.canProjectSimilarity()) {
try {
final float aFloat = row.getFloat(TableSimilarityFunction.SIMILARITY_SCORE_ALIAS);
result.put(VECTOR_FUNCTION_SIMILARITY_FIELD, aFloat);
// Should not happen, but keep it caught, in case it breaks the query
} catch (IllegalArgumentException ignored) {
}
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package io.stargate.sgv2.jsonapi.service.operation.tables;

import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.api.querybuilder.select.Selector;
import io.stargate.sgv2.jsonapi.api.model.command.Projectable;
import io.stargate.sgv2.jsonapi.api.model.command.VectorSortable;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName;
import io.stargate.sgv2.jsonapi.util.CqlVectorUtil;
import java.util.function.Function;

public interface TableSimilarityFunction extends Function<Select, Select> {

TableSimilarityFunction NO_OP = new TableSimilarityFunctionNoOp();

// Make a unique constant string as similarity score function alias in cql statement
// E.G. SELECT id,similarity_euclidean(vector_type,[0.2, 0.15, 0.3]) AS
// similarityScore1699123456789 from xxx;
String SIMILARITY_SCORE_ALIAS = "similarityScore" + System.currentTimeMillis();

static <CmdT extends Projectable> TableSimilarityFunction from(
CmdT command, TableSchemaObject table) {

if (!(command instanceof VectorSortable)) {
return NO_OP;
}
var vectorSortable = (VectorSortable) command;

var sortExpressionOptional = vectorSortable.vectorSortExpression();
if (sortExpressionOptional.isEmpty()) {
// nothing to sort on, so nothing to return even if they asked for the similarity score
return NO_OP;
}
var sortExpression = sortExpressionOptional.get();

var includeSimilarityScore = vectorSortable.includeSimilarityScore().orElse(false);
if (!includeSimilarityScore) {
// user does not ask for similarityScore
return NO_OP;
}

var requestedVectorColumnPath = sortExpression.pathAsCqlIdentifier();
var apiColumnDef =
table
.apiTableDef()
.allColumns()
.filterBy(ApiTypeName.VECTOR)
.get(requestedVectorColumnPath);
if (apiColumnDef == null) {
// column does not exist or is not a vector, ignore because sort will fail
return NO_OP;
}

var vectorColDefinition = table.vectorConfig().getColumnDefinition(requestedVectorColumnPath);
if (vectorColDefinition.isEmpty()) {
// no requested vector column on the table
return NO_OP;
}

// similarityFunction is from index, default to cosine. In projection,
// we do not care about if the vector column in indexed or not, capture by vector sort.
var similarityFunction = vectorColDefinition.get().similarityFunction().getFunction();

return new TableSimilarityFunctionImpl(
requestedVectorColumnPath,
CqlVectorUtil.floatsToCqlVector(sortExpression.vector()),
similarityFunction);
}

boolean canProjectSimilarity();

class TableSimilarityFunctionImpl implements TableSimilarityFunction {
private final CqlIdentifier requestedVectorColumnPath;
private final CqlVector<Float> vector;
private final String function;

public TableSimilarityFunctionImpl(
CqlIdentifier requestedVectorColumnPath, CqlVector<Float> vector, String function) {
this.requestedVectorColumnPath = requestedVectorColumnPath;
this.vector = vector;
this.function = function;
}

@Override
public Select apply(Select select) {
return select
.function(function, Selector.column(requestedVectorColumnPath), literal(vector))
.as(SIMILARITY_SCORE_ALIAS);
}

@Override
public boolean canProjectSimilarity() {
return true;
}
}

public class TableSimilarityFunctionNoOp implements TableSimilarityFunction {
@Override
public Select apply(Select select) {
return select;
}

@Override
public boolean canProjectSimilarity() {
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public Operation resolveTableCommand(CommandContext<TableSchemaObject> ctx, Find
: CqlPagingState.from(command.options().pageState());

var pageBuilder =
ReadAttemptPage.<TableSchemaObject>builder().singleResponse(false).includeSortVector(false);
ReadAttemptPage.<TableSchemaObject>builder().singleResponse(false).mayReturnVector(command);

return readCommandResolver.buildReadOperation(ctx, command, cqlPageState, pageBuilder);

Expand Down
Loading

0 comments on commit 5df72c3

Please sign in to comment.