Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose source_model parameter for vector-enabled collections #1606

Merged
merged 24 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bc05add
remove hardcoded value
Hazel-Datastax Oct 24, 2024
cfefcf3
Add source model support in collection
Hazel-Datastax Oct 24, 2024
dea9077
Merge branch 'refs/heads/main' into hazel/collection_source_model
Hazel-Datastax Oct 24, 2024
489606d
Remove unused method (was moved to VectorConfig)
Hazel-Datastax Oct 25, 2024
9f612f6
Add support for older collections
Hazel-Datastax Oct 25, 2024
fc2342a
Fix typo
Hazel-Datastax Oct 25, 2024
931dd7a
rename
Hazel-Datastax Oct 25, 2024
da96598
Add validation
Hazel-Datastax Oct 28, 2024
af1d2e6
Add ITs
Hazel-Datastax Oct 28, 2024
538bfcc
Merge branch 'main' into hazel/collection_source_model
Hazel-Datastax Oct 28, 2024
c19a8cc
Fix unit tests
Hazel-Datastax Oct 28, 2024
6d1d5a0
Fix ITs
Hazel-Datastax Oct 28, 2024
2311ee6
Merge remote-tracking branch 'refs/remotes/origin/main' into hazel/co…
Hazel-Datastax Oct 30, 2024
11d068b
resolve conflicts
Hazel-Datastax Oct 30, 2024
59367d1
Merge remote-tracking branch 'refs/remotes/origin/main' into hazel/co…
Hazel-Datastax Oct 31, 2024
983a754
update the naming
Hazel-Datastax Oct 31, 2024
5f36442
Merge branch 'main' into hazel/collection_source_model
Hazel-Datastax Oct 31, 2024
04b2327
Add IT
Hazel-Datastax Oct 31, 2024
51e7fc4
Add default
Hazel-Datastax Oct 31, 2024
402a047
move `SUPPORTED_SOURCE_MODELS` map from `VectorConstants` to `SourceM…
Hazel-Datastax Oct 31, 2024
ee7ce7b
Merge branch 'main' into hazel/collection_source_model
Hazel-Datastax Oct 31, 2024
ec06bd7
create a method to get all the source model names
Hazel-Datastax Oct 31, 2024
5f6f21c
fix
Hazel-Datastax Oct 31, 2024
834ec17
Merge remote-tracking branch 'origin/hazel/collection_source_model' i…
Hazel-Datastax Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public boolean handleUnknownProperty(
}
if (typeStr.endsWith("CreateCollectionCommand$Options$VectorSearchConfig")) {
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Unrecognized field \"%s\" for `createCollection.options.vector` (known fields: \"dimension\", \"metric\", \"service\")",
"Unrecognized field \"%s\" for `createCollection.options.vector` (known fields: \"dimension\", \"metric\", \"service\", \"sourceModel\",)",
propertyName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ public record VectorSearchConfig(
@JsonProperty("metric")
@JsonAlias("function") // old name
String metric,
@Nullable
@Pattern(
regexp =
"(ada002|bert|cohere-v3|gecko|nv-qa-4|openai-v3-large|openai-v3-small|other)",
message =
"sourceModel options are 'ada002', 'bert', 'cohere-v3', 'gecko', 'nv-qa-4', 'openai-v3-large', 'openai-v3-small', and 'other'")
@Schema(
description =
"The 'sourceModel' option configures the index with the fastest settings for a given source of embeddings vectors",
defaultValue = "other",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty("sourceModel")
String sourceModel,
@Valid
@Nullable
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -105,9 +119,11 @@ public record VectorSearchConfig(
@JsonProperty("service")
VectorizeConfig vectorizeConfig) {

public VectorSearchConfig(Integer dimension, String metric, VectorizeConfig vectorizeConfig) {
public VectorSearchConfig(
Integer dimension, String metric, String sourceModel, VectorizeConfig vectorizeConfig) {
this.dimension = dimension;
this.metric = metric == null ? "cosine" : metric;
tatu-at-datastax marked this conversation as resolved.
Show resolved Hide resolved
this.metric = metric;
this.sourceModel = sourceModel;
this.vectorizeConfig = vectorizeConfig;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ interface Fields {
/** Document field name that will have text value for which vectorize method is called */
String VECTOR_EMBEDDING_TEXT_FIELD = "$vectorize";

/** Key for vector function name definition in cql index. */
String VECTOR_INDEX_FUNCTION_NAME = "similarity_function";

Comment on lines -30 to -32
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplicates the constant in VectorConstants. I think it belongs there, so I removed the one here.

/** Field name used in projection clause to get similarity score in response. */
String VECTOR_FUNCTION_SIMILARITY_FIELD = "$similarity";

Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
package io.stargate.sgv2.jsonapi.config.constants;

import io.smallrye.config.ConfigMapping;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import java.util.Map;

@ConfigMapping(prefix = "stargate.jsonapi.vector")
public interface VectorConstants {
/*
Supported Source Models and suggested function for Vector Index in Cassandra
*/
Map<String, SimilarityFunction> SUPPORTED_SOURCES =
Map.of(
"ada002",
SimilarityFunction.DOT_PRODUCT,
"openai_v3_small",
SimilarityFunction.DOT_PRODUCT,
"openai_v3_large",
SimilarityFunction.DOT_PRODUCT,
"bert",
SimilarityFunction.DOT_PRODUCT,
"gecko",
SimilarityFunction.DOT_PRODUCT,
"nv_qa_4",
SimilarityFunction.DOT_PRODUCT,
"cohere_v3",
SimilarityFunction.DOT_PRODUCT,
"other",
SimilarityFunction.COSINE);

interface VectorColumn {
String DIMENSION = "dimension";
String METRIC = "metric";
String SOURCE_MODEL = "sourceModel";
String SERVICE = "service";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ public enum ErrorCodeV1 {

VECTOR_SEARCH_INVALID_FUNCTION_NAME("Invalid vector search function name"),

VECTOR_SEARCH_UNRECOGNIZED_SOURCE_MODEL_NAME("Unrecognized vector search source model name"),

VECTOR_SEARCH_TOO_BIG_VALUE("Vector embedding property '$vector' length too big"),
VECTOR_SIZE_MISMATCH("Length of vector parameter different from declared '$vector' dimension"),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;

/**
* Configuration vector column with the extra info we need for vectors
*
* @param fieldName Is still a string because this is also used by collections
* @param vectorSize
* @param similarityFunction
* @param sourceModel
* @param vectorizeDefinition
*/
public record VectorColumnDefinition(
String fieldName,
int vectorSize,
SimilarityFunction similarityFunction,
SourceModel sourceModel,
VectorizeDefinition vectorizeDefinition) {

/**
Expand All @@ -34,11 +37,16 @@ public static VectorColumnDefinition fromJson(JsonNode jsonNode, ObjectMapper ob
int dimension = jsonNode.get(VectorConstants.VectorColumn.DIMENSION).asInt();
var similarityFunction =
SimilarityFunction.fromString(jsonNode.get(VectorConstants.VectorColumn.METRIC).asText());
// sourceModel doesn't exist if the collection was created before supporting sourceModel; if
// missing, it will be an empty string and sourceModel becomes OTHER.
var sourceModel =
SourceModel.fromString(jsonNode.path(VectorConstants.VectorColumn.SOURCE_MODEL).asText());

return fromJson(
DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD,
dimension,
similarityFunction,
sourceModel,
jsonNode,
objectMapper);
}
Expand All @@ -48,6 +56,7 @@ public static VectorColumnDefinition fromJson(
String fieldName,
int dimension,
SimilarityFunction similarityFunction,
SourceModel sourceModel,
JsonNode jsonNode,
ObjectMapper objectMapper) {

Expand All @@ -58,6 +67,6 @@ public static VectorColumnDefinition fromJson(
: VectorizeDefinition.fromJson(vectorizeServiceNode, objectMapper);

return new VectorColumnDefinition(
fieldName, dimension, similarityFunction, vectorizeDefinition);
fieldName, dimension, similarityFunction, sourceModel, vectorizeDefinition);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -43,6 +44,7 @@ public static VectorConfig fromColumnDefinitions(
return new VectorConfig(vectorColumnDefinitions);
}

/** Get table schema object from table metadata */
public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper objectMapper) {

Map<String, VectorizeDefinition> vectorizeDefs =
Expand All @@ -67,22 +69,32 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object
var indexFunction =
columnIndex.map(
index -> {
String similarityFunction =
String similarityFunctionStr =
index.getOptions().get(VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION);
if (similarityFunction != null) {
return SimilarityFunction.fromString(similarityFunction);
}
// if similarity function is set, use it
SimilarityFunction similarityFunction =
similarityFunctionStr != null
? SimilarityFunction.fromString(similarityFunctionStr)
: null;

String sourceModel =
String sourceModelStr =
index.getOptions().get(VectorConstants.CQLAnnIndex.SOURCE_MODEL);
if (sourceModel != null) {
return VectorConstants.SUPPORTED_SOURCES.get(sourceModel);
SourceModel sourceModel = null;
if (sourceModelStr != null) {
// if similarity function is not set, use the source model to determine it
similarityFunction =
similarityFunction == null
? SourceModel.getSimilarityFunction(sourceModelStr)
: similarityFunction;
sourceModel = SourceModel.fromString(sourceModelStr);
}
return null;
return new AbstractMap.SimpleEntry<>(similarityFunction, sourceModel);
});

// if no index, or we could not work out the function, default
var similarityFunction = indexFunction.orElse(SimilarityFunction.COSINE);
var similarityFunction =
indexFunction.map(Map.Entry::getKey).orElse(SimilarityFunction.COSINE);
var sourModel = indexFunction.map(Map.Entry::getValue).orElse(SourceModel.OTHER);
Hazel-Datastax marked this conversation as resolved.
Show resolved Hide resolved
int dimensions = ((VectorType) column.getType()).getDimensions();

// NOTE: need to keep the column name as a string in the VectorColumnDefinition
Expand All @@ -93,6 +105,7 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object
cqlIdentifierToJsonKey(column.getName()),
dimensions,
similarityFunction,
sourModel,
vectorizeDefs.get(column.getName().asInternal())));
}
return VectorConfig.fromColumnDefinitions(columnDefs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor;
import io.stargate.sgv2.jsonapi.service.operation.Operation;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionTableMatcher;
import java.time.Duration;
Expand All @@ -41,6 +42,7 @@ public record CreateCollectionOperation(
boolean vectorSearch,
int vectorSize,
String vectorFunction,
String sourceModel,
String comment,
int ddlDelayMillis,
boolean tooManyIndexesRollbackEnabled,
Expand All @@ -60,6 +62,7 @@ public static CreateCollectionOperation withVectorSearch(
String name,
int vectorSize,
String vectorFunction,
String sourceModel,
String comment,
int ddlDelayMillis,
boolean tooManyIndexesRollbackEnabled,
Expand All @@ -73,6 +76,7 @@ public static CreateCollectionOperation withVectorSearch(
true,
vectorSize,
vectorFunction,
sourceModel,
comment,
ddlDelayMillis,
tooManyIndexesRollbackEnabled,
Expand All @@ -98,6 +102,7 @@ public static CreateCollectionOperation withoutVectorSearch(
false,
0,
null,
null,
comment,
ddlDelayMillis,
tooManyIndexesRollbackEnabled,
Expand Down Expand Up @@ -142,6 +147,7 @@ public Uni<Supplier<CommandResult>> execute(
vectorSearch,
vectorSize,
SimilarityFunction.fromString(vectorFunction),
SourceModel.fromString(sourceModel),
comment,
objectMapper);
// if table exists we have two choices:
Expand Down Expand Up @@ -492,6 +498,8 @@ public List<SimpleStatement> getIndexStatements(
appender
+ " \"%s_query_vector_value\" ON \"%s\".\"%s\" (query_vector_value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': '"
+ vectorFunction()
+ "', 'source_model': '"
+ sourceModel()
+ "'}";
statements.add(
SimpleStatement.newInstance(String.format(vectorSearch, table, keyspace, table)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexOnTable;
import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexStart;
import com.datastax.oss.driver.internal.querybuilder.schema.DefaultCreateIndex;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.override.ExtendedCreateIndex;
import io.stargate.sgv2.jsonapi.service.operation.SchemaAttempt;
Expand Down Expand Up @@ -90,10 +91,11 @@ public record VectorIndexOptions(SimilarityFunction similarityFunction, String s
public Map<String, Object> getOptions() {
Map<String, Object> options = new HashMap<>();
if (similarityFunction != null) {
options.put("similarity_function", similarityFunction.getMetric());
options.put(
VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION, similarityFunction.getMetric());
}
if (sourceModel != null) {
options.put("source_model", sourceModel);
options.put(VectorConstants.CQLAnnIndex.SOURCE_MODEL, sourceModel);
}
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject;
import io.stargate.sgv2.jsonapi.service.operation.Operation;
import io.stargate.sgv2.jsonapi.service.operation.collections.CreateCollectionOperation;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

Expand Down Expand Up @@ -106,6 +108,7 @@ public Operation resolveKeyspaceCommand(
command.name(),
vector.dimension(),
vector.metric(),
vector.sourceModel(),
comment,
operationsConfig.databaseConfig().ddlDelayMillis(),
operationsConfig.tooManyIndexesRollbackEnabled(),
Expand Down Expand Up @@ -192,13 +195,35 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions

Integer vectorDimension = vector.dimension();
VectorizeConfig service = vector.vectorizeConfig();
String sourceModel = vector.sourceModel();
String metric = vector.metric();

// decide sourceModel and metric value
if (sourceModel != null) {
if (metric == null) {
// (1) sourceModel is provided but metric is not - set metric to cosine or dot_product based
// on the map
metric = SourceModel.getSimilarityFunction(sourceModel).getMetric();
}
// (2) both sourceModel and metric are provided - do nothing
} else {
if (metric != null) {
// (3) sourceModel is not provided but metric is - set sourceModel to 'other'
sourceModel = SourceModel.OTHER.getName();
} else {
// (4) both sourceModel and metric are not provided - set sourceModel to 'other' and metric
// to 'cosine'
sourceModel = SourceModel.DEFAULT_SOURCE_MODEL.getName();
metric = SimilarityFunction.DEFAULT_SIMILARITY_FUNCTION.getMetric();
}
}

if (service != null) {
// Validate service configuration and auto populate vector dimension.
vectorDimension = validateVectorize.validateService(service, vectorDimension);
vector =
new CreateCollectionCommand.Options.VectorSearchConfig(
vectorDimension, vector.metric(), vector.vectorizeConfig());
vectorDimension, metric, sourceModel, vector.vectorizeConfig());
} else {
// Ensure vector dimension is provided when service configuration is absent.
if (vectorDimension == null) {
Expand All @@ -209,6 +234,9 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions
throw ErrorCodeV1.VECTOR_SEARCH_TOO_BIG_VALUE.toApiException(
"%d (max %d)", vectorDimension, documentLimitsConfig.maxVectorEmbeddingLength());
}
vector =
new CreateCollectionCommand.Options.VectorSearchConfig(
vectorDimension, metric, sourceModel, null);
}
return vector;
}
Expand Down
Loading
Loading