-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Table, includeSimilarityScore, includeSortVector (#1656)
Co-authored-by: Aaron Morton <aaron.morton@datastax.com> Co-authored-by: maheshrajamani <99678631+maheshrajamani@users.noreply.github.com>
- Loading branch information
1 parent
39e9d50
commit 5df72c3
Showing
16 changed files
with
426 additions
and
42 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorSortable.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package io.stargate.sgv2.jsonapi.api.model.command; | ||
|
||
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; | ||
import java.util.Optional; | ||
|
||
/** Interface for commands that can have vector sort clause */ | ||
public interface VectorSortable extends Sortable { | ||
|
||
default Optional<Boolean> includeSimilarityScore() { | ||
return Optional.empty(); | ||
} | ||
|
||
default Optional<Boolean> includeSortVector() { | ||
return Optional.empty(); | ||
} | ||
|
||
/** | ||
* Returns the first SortExpression that has {@link SortExpression#vector()} not null, if there is | ||
* more than one raises {@link IllegalStateException}. | ||
* | ||
* @return the vector sort expression if it exists. | ||
*/ | ||
default Optional<SortExpression> vectorSortExpression() { | ||
if (sortClause() != null && sortClause().sortExpressions() != null) { | ||
var vectorSorts = | ||
sortClause().sortExpressions().stream() | ||
.filter(expression -> expression.vector() != null) | ||
.toList(); | ||
if (vectorSorts.size() > 1) { | ||
throw new IllegalStateException("Only one vector sort expression is allowed"); | ||
} | ||
return vectorSorts.isEmpty() ? Optional.empty() : Optional.of(vectorSorts.getFirst()); | ||
} | ||
return Optional.empty(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
112 changes: 112 additions & 0 deletions
112
src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableSimilarityFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
package io.stargate.sgv2.jsonapi.service.operation.tables; | ||
|
||
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal; | ||
|
||
import com.datastax.oss.driver.api.core.CqlIdentifier; | ||
import com.datastax.oss.driver.api.core.data.CqlVector; | ||
import com.datastax.oss.driver.api.querybuilder.select.Select; | ||
import com.datastax.oss.driver.api.querybuilder.select.Selector; | ||
import io.stargate.sgv2.jsonapi.api.model.command.Projectable; | ||
import io.stargate.sgv2.jsonapi.api.model.command.VectorSortable; | ||
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; | ||
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName; | ||
import io.stargate.sgv2.jsonapi.util.CqlVectorUtil; | ||
import java.util.function.Function; | ||
|
||
public interface TableSimilarityFunction extends Function<Select, Select> { | ||
|
||
TableSimilarityFunction NO_OP = new TableSimilarityFunctionNoOp(); | ||
|
||
// Make a unique constant string as similarity score function alias in cql statement | ||
// E.G. SELECT id,similarity_euclidean(vector_type,[0.2, 0.15, 0.3]) AS | ||
// similarityScore1699123456789 from xxx; | ||
String SIMILARITY_SCORE_ALIAS = "similarityScore" + System.currentTimeMillis(); | ||
|
||
static <CmdT extends Projectable> TableSimilarityFunction from( | ||
CmdT command, TableSchemaObject table) { | ||
|
||
if (!(command instanceof VectorSortable)) { | ||
return NO_OP; | ||
} | ||
var vectorSortable = (VectorSortable) command; | ||
|
||
var sortExpressionOptional = vectorSortable.vectorSortExpression(); | ||
if (sortExpressionOptional.isEmpty()) { | ||
// nothing to sort on, so nothing to return even if they asked for the similarity score | ||
return NO_OP; | ||
} | ||
var sortExpression = sortExpressionOptional.get(); | ||
|
||
var includeSimilarityScore = vectorSortable.includeSimilarityScore().orElse(false); | ||
if (!includeSimilarityScore) { | ||
// user does not ask for similarityScore | ||
return NO_OP; | ||
} | ||
|
||
var requestedVectorColumnPath = sortExpression.pathAsCqlIdentifier(); | ||
var apiColumnDef = | ||
table | ||
.apiTableDef() | ||
.allColumns() | ||
.filterBy(ApiTypeName.VECTOR) | ||
.get(requestedVectorColumnPath); | ||
if (apiColumnDef == null) { | ||
// column does not exist or is not a vector, ignore because sort will fail | ||
return NO_OP; | ||
} | ||
|
||
var vectorColDefinition = table.vectorConfig().getColumnDefinition(requestedVectorColumnPath); | ||
if (vectorColDefinition.isEmpty()) { | ||
// no requested vector column on the table | ||
return NO_OP; | ||
} | ||
|
||
// similarityFunction is from index, default to cosine. In projection, | ||
// we do not care about if the vector column in indexed or not, capture by vector sort. | ||
var similarityFunction = vectorColDefinition.get().similarityFunction().getFunction(); | ||
|
||
return new TableSimilarityFunctionImpl( | ||
requestedVectorColumnPath, | ||
CqlVectorUtil.floatsToCqlVector(sortExpression.vector()), | ||
similarityFunction); | ||
} | ||
|
||
boolean canProjectSimilarity(); | ||
|
||
class TableSimilarityFunctionImpl implements TableSimilarityFunction { | ||
private final CqlIdentifier requestedVectorColumnPath; | ||
private final CqlVector<Float> vector; | ||
private final String function; | ||
|
||
public TableSimilarityFunctionImpl( | ||
CqlIdentifier requestedVectorColumnPath, CqlVector<Float> vector, String function) { | ||
this.requestedVectorColumnPath = requestedVectorColumnPath; | ||
this.vector = vector; | ||
this.function = function; | ||
} | ||
|
||
@Override | ||
public Select apply(Select select) { | ||
return select | ||
.function(function, Selector.column(requestedVectorColumnPath), literal(vector)) | ||
.as(SIMILARITY_SCORE_ALIAS); | ||
} | ||
|
||
@Override | ||
public boolean canProjectSimilarity() { | ||
return true; | ||
} | ||
} | ||
|
||
public class TableSimilarityFunctionNoOp implements TableSimilarityFunction { | ||
@Override | ||
public Select apply(Select select) { | ||
return select; | ||
} | ||
|
||
@Override | ||
public boolean canProjectSimilarity() { | ||
return false; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.