Skip to content

Commit

Permalink
Expose source_model parameter for vector-enabled collections (#1606)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hazel-Datastax authored Oct 31, 2024
1 parent 45986a1 commit 69cb961
Show file tree
Hide file tree
Showing 28 changed files with 576 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public boolean handleUnknownProperty(
}
if (typeStr.endsWith("CreateCollectionCommand$Options$VectorSearchConfig")) {
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Unrecognized field \"%s\" for `createCollection.options.vector` (known fields: \"dimension\", \"metric\", \"service\")",
"Unrecognized field \"%s\" for `createCollection.options.vector` (known fields: \"dimension\", \"metric\", \"service\", \"sourceModel\",)",
propertyName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ public record VectorSearchConfig(
@JsonProperty("metric")
@JsonAlias("function") // old name
String metric,
@Nullable
@Pattern(
regexp =
"(ada002|bert|cohere-v3|gecko|nv-qa-4|openai-v3-large|openai-v3-small|other)",
message =
"sourceModel options are 'ada002', 'bert', 'cohere-v3', 'gecko', 'nv-qa-4', 'openai-v3-large', 'openai-v3-small', and 'other'")
@Schema(
description =
"The 'sourceModel' option configures the index with the fastest settings for a given source of embeddings vectors",
defaultValue = "other",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty("sourceModel")
String sourceModel,
@Valid
@Nullable
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -105,9 +119,11 @@ public record VectorSearchConfig(
@JsonProperty("service")
VectorizeConfig vectorizeConfig) {

public VectorSearchConfig(Integer dimension, String metric, VectorizeConfig vectorizeConfig) {
public VectorSearchConfig(
Integer dimension, String metric, String sourceModel, VectorizeConfig vectorizeConfig) {
this.dimension = dimension;
this.metric = metric == null ? "cosine" : metric;
this.metric = metric;
this.sourceModel = sourceModel;
this.vectorizeConfig = vectorizeConfig;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ interface Fields {
/** Document field name that will have text value for which vectorize method in called */
String VECTOR_EMBEDDING_TEXT_FIELD = "$vectorize";

/** Key for vector function name definition in cql index. */
String VECTOR_INDEX_FUNCTION_NAME = "similarity_function";

/** Field name used in projection clause to get similarity score in response. */
String VECTOR_FUNCTION_SIMILARITY_FIELD = "$similarity";

Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
package io.stargate.sgv2.jsonapi.config.constants;

import io.smallrye.config.ConfigMapping;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import java.util.Map;

@ConfigMapping(prefix = "stargate.jsonapi.vector")
public interface VectorConstants {
/*
Supported Source Models and suggested function for Vector Index in Cassandra
*/
Map<String, SimilarityFunction> SUPPORTED_SOURCES =
Map.of(
"ada002",
SimilarityFunction.DOT_PRODUCT,
"openai_v3_small",
SimilarityFunction.DOT_PRODUCT,
"openai_v3_large",
SimilarityFunction.DOT_PRODUCT,
"bert",
SimilarityFunction.DOT_PRODUCT,
"gecko",
SimilarityFunction.DOT_PRODUCT,
"nv_qa_4",
SimilarityFunction.DOT_PRODUCT,
"cohere_v3",
SimilarityFunction.DOT_PRODUCT,
"other",
SimilarityFunction.COSINE);

interface VectorColumn {
String DIMENSION = "dimension";
String METRIC = "metric";
String SOURCE_MODEL = "sourceModel";
String SERVICE = "service";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ public enum ErrorCodeV1 {

VECTOR_SEARCH_INVALID_FUNCTION_NAME("Invalid vector search function name"),

VECTOR_SEARCH_UNRECOGNIZED_SOURCE_MODEL_NAME("Unrecognized vector search source model name"),

VECTOR_SEARCH_TOO_BIG_VALUE("Vector embedding property '$vector' length too big"),
VECTOR_SIZE_MISMATCH("Length of vector parameter different from declared '$vector' dimension"),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;

/**
* Configuration vector column with the extra info we need for vectors
*
* @param fieldName Is still a string because this is also used by collections
* @param vectorSize
* @param similarityFunction
* @param sourceModel
* @param vectorizeDefinition
*/
public record VectorColumnDefinition(
String fieldName,
int vectorSize,
SimilarityFunction similarityFunction,
SourceModel sourceModel,
VectorizeDefinition vectorizeDefinition) {

/**
Expand All @@ -34,11 +37,16 @@ public static VectorColumnDefinition fromJson(JsonNode jsonNode, ObjectMapper ob
int dimension = jsonNode.get(VectorConstants.VectorColumn.DIMENSION).asInt();
var similarityFunction =
SimilarityFunction.fromString(jsonNode.get(VectorConstants.VectorColumn.METRIC).asText());
// sourceModel doesn't exist if the collection was created before supporting sourceModel; if
// missing, it will be an empty string and sourceModel becomes OTHER.
var sourceModel =
SourceModel.fromString(jsonNode.path(VectorConstants.VectorColumn.SOURCE_MODEL).asText());

return fromJson(
DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD,
dimension,
similarityFunction,
sourceModel,
jsonNode,
objectMapper);
}
Expand All @@ -48,6 +56,7 @@ public static VectorColumnDefinition fromJson(
String fieldName,
int dimension,
SimilarityFunction similarityFunction,
SourceModel sourceModel,
JsonNode jsonNode,
ObjectMapper objectMapper) {

Expand All @@ -58,6 +67,6 @@ public static VectorColumnDefinition fromJson(
: VectorizeDefinition.fromJson(vectorizeServiceNode, objectMapper);

return new VectorColumnDefinition(
fieldName, dimension, similarityFunction, vectorizeDefinition);
fieldName, dimension, similarityFunction, sourceModel, vectorizeDefinition);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -43,6 +44,7 @@ public static VectorConfig fromColumnDefinitions(
return new VectorConfig(vectorColumnDefinitions);
}

/** Get table schema object from table metadata */
public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper objectMapper) {

Map<String, VectorizeDefinition> vectorizeDefs =
Expand All @@ -67,22 +69,32 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object
var indexFunction =
columnIndex.map(
index -> {
String similarityFunction =
String similarityFunctionStr =
index.getOptions().get(VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION);
if (similarityFunction != null) {
return SimilarityFunction.fromString(similarityFunction);
}
// if similarity function is set, use it
SimilarityFunction similarityFunction =
similarityFunctionStr != null
? SimilarityFunction.fromString(similarityFunctionStr)
: null;

String sourceModel =
String sourceModelStr =
index.getOptions().get(VectorConstants.CQLAnnIndex.SOURCE_MODEL);
if (sourceModel != null) {
return VectorConstants.SUPPORTED_SOURCES.get(sourceModel);
SourceModel sourceModel = null;
if (sourceModelStr != null) {
// if similarity function is not set, use the source model to determine it
similarityFunction =
similarityFunction == null
? SourceModel.getSimilarityFunction(sourceModelStr)
: similarityFunction;
sourceModel = SourceModel.fromString(sourceModelStr);
}
return null;
return new AbstractMap.SimpleEntry<>(similarityFunction, sourceModel);
});

// if now index, or we could not work out the function, default
var similarityFunction = indexFunction.orElse(SimilarityFunction.COSINE);
var similarityFunction =
indexFunction.map(Map.Entry::getKey).orElse(SimilarityFunction.COSINE);
var sourceModel = indexFunction.map(Map.Entry::getValue).orElse(SourceModel.OTHER);
int dimensions = ((VectorType) column.getType()).getDimensions();

// NOTE: need to keep the column name as a string in the VectorColumnDefinition
Expand All @@ -93,6 +105,7 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object
cqlIdentifierToJsonKey(column.getName()),
dimensions,
similarityFunction,
sourceModel,
vectorizeDefs.get(column.getName().asInternal())));
}
return VectorConfig.fromColumnDefinitions(columnDefs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor;
import io.stargate.sgv2.jsonapi.service.operation.Operation;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionTableMatcher;
import java.time.Duration;
Expand All @@ -41,6 +42,7 @@ public record CreateCollectionOperation(
boolean vectorSearch,
int vectorSize,
String vectorFunction,
String sourceModel,
String comment,
int ddlDelayMillis,
boolean tooManyIndexesRollbackEnabled,
Expand All @@ -60,6 +62,7 @@ public static CreateCollectionOperation withVectorSearch(
String name,
int vectorSize,
String vectorFunction,
String sourceModel,
String comment,
int ddlDelayMillis,
boolean tooManyIndexesRollbackEnabled,
Expand All @@ -73,6 +76,7 @@ public static CreateCollectionOperation withVectorSearch(
true,
vectorSize,
vectorFunction,
sourceModel,
comment,
ddlDelayMillis,
tooManyIndexesRollbackEnabled,
Expand All @@ -98,6 +102,7 @@ public static CreateCollectionOperation withoutVectorSearch(
false,
0,
null,
null,
comment,
ddlDelayMillis,
tooManyIndexesRollbackEnabled,
Expand Down Expand Up @@ -142,6 +147,7 @@ public Uni<Supplier<CommandResult>> execute(
vectorSearch,
vectorSize,
SimilarityFunction.fromString(vectorFunction),
SourceModel.fromString(sourceModel),
comment,
objectMapper);
// if table exists we have to choices:
Expand Down Expand Up @@ -492,6 +498,8 @@ public List<SimpleStatement> getIndexStatements(
appender
+ " \"%s_query_vector_value\" ON \"%s\".\"%s\" (query_vector_value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': '"
+ vectorFunction()
+ "', 'source_model': '"
+ sourceModel()
+ "'}";
statements.add(
SimpleStatement.newInstance(String.format(vectorSearch, table, keyspace, table)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexOnTable;
import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexStart;
import com.datastax.oss.driver.internal.querybuilder.schema.DefaultCreateIndex;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.override.ExtendedCreateIndex;
import io.stargate.sgv2.jsonapi.service.operation.SchemaAttempt;
Expand Down Expand Up @@ -90,10 +91,11 @@ public record VectorIndexOptions(SimilarityFunction similarityFunction, String s
public Map<String, Object> getOptions() {
Map<String, Object> options = new HashMap<>();
if (similarityFunction != null) {
options.put("similarity_function", similarityFunction.getMetric());
options.put(
VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION, similarityFunction.getMetric());
}
if (sourceModel != null) {
options.put("source_model", sourceModel);
options.put(VectorConstants.CQLAnnIndex.SOURCE_MODEL, sourceModel);
}
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject;
import io.stargate.sgv2.jsonapi.service.operation.Operation;
import io.stargate.sgv2.jsonapi.service.operation.collections.CreateCollectionOperation;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

Expand Down Expand Up @@ -106,6 +108,7 @@ public Operation resolveKeyspaceCommand(
command.name(),
vector.dimension(),
vector.metric(),
vector.sourceModel(),
comment,
operationsConfig.databaseConfig().ddlDelayMillis(),
operationsConfig.tooManyIndexesRollbackEnabled(),
Expand Down Expand Up @@ -192,13 +195,35 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions

Integer vectorDimension = vector.dimension();
VectorizeConfig service = vector.vectorizeConfig();
String sourceModel = vector.sourceModel();
String metric = vector.metric();

// decide sourceModel and metric value
if (sourceModel != null) {
if (metric == null) {
// (1) sourceModel is provided but metric is not - set metric to cosine or dot_product based
// on the map
metric = SourceModel.getSimilarityFunction(sourceModel).getMetric();
}
// (2) both sourceModel and metric are provided - do nothing
} else {
if (metric != null) {
// (3) sourceModel is not provided but metric is - set sourceModel to 'other'
sourceModel = SourceModel.OTHER.getName();
} else {
// (4) both sourceModel and metric are not provided - set sourceModel to 'other' and metric
// to 'cosine'
sourceModel = SourceModel.DEFAULT_SOURCE_MODEL.getName();
metric = SimilarityFunction.DEFAULT_SIMILARITY_FUNCTION.getMetric();
}
}

if (service != null) {
// Validate service configuration and auto populate vector dimension.
vectorDimension = validateVectorize.validateService(service, vectorDimension);
vector =
new CreateCollectionCommand.Options.VectorSearchConfig(
vectorDimension, vector.metric(), vector.vectorizeConfig());
vectorDimension, metric, sourceModel, vector.vectorizeConfig());
} else {
// Ensure vector dimension is provided when service configuration is absent.
if (vectorDimension == null) {
Expand All @@ -209,6 +234,9 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions
throw ErrorCodeV1.VECTOR_SEARCH_TOO_BIG_VALUE.toApiException(
"%d (max %d)", vectorDimension, documentLimitsConfig.maxVectorEmbeddingLength());
}
vector =
new CreateCollectionCommand.Options.VectorSearchConfig(
vectorDimension, metric, sourceModel, null);
}
return vector;
}
Expand Down
Loading

0 comments on commit 69cb961

Please sign in to comment.