diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/configuration/CommandObjectMapperHandler.java b/src/main/java/io/stargate/sgv2/jsonapi/api/configuration/CommandObjectMapperHandler.java index 8bce4ff667..551e121d2f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/configuration/CommandObjectMapperHandler.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/configuration/CommandObjectMapperHandler.java @@ -38,7 +38,7 @@ public boolean handleUnknownProperty( } if (typeStr.endsWith("CreateCollectionCommand$Options$VectorSearchConfig")) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Unrecognized field \"%s\" for `createCollection.options.vector` (known fields: \"dimension\", \"metric\", \"service\")", + "Unrecognized field \"%s\" for `createCollection.options.vector` (known fields: \"dimension\", \"metric\", \"service\", \"sourceModel\")", propertyName); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java index 64ab3f999a..3cf75593bb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java @@ -95,6 +95,20 @@ public record VectorSearchConfig( @JsonProperty("metric") @JsonAlias("function") // old name String metric, + @Nullable + @Pattern( + regexp = + "(ada002|bert|cohere-v3|gecko|nv-qa-4|openai-v3-large|openai-v3-small|other)", + message = + "sourceModel options are 'ada002', 'bert', 'cohere-v3', 'gecko', 'nv-qa-4', 'openai-v3-large', 'openai-v3-small', and 'other'") + @Schema( + description = + "The 'sourceModel' option configures the index with the fastest settings for a given source of embeddings vectors", + defaultValue = "other", + type = SchemaType.STRING, + implementation = String.class) + @JsonProperty("sourceModel") + String 
sourceModel, @Valid @Nullable @JsonInclude(JsonInclude.Include.NON_NULL) @@ -105,9 +119,11 @@ public record VectorSearchConfig( @JsonProperty("service") VectorizeConfig vectorizeConfig) { - public VectorSearchConfig(Integer dimension, String metric, VectorizeConfig vectorizeConfig) { + public VectorSearchConfig( + Integer dimension, String metric, String sourceModel, VectorizeConfig vectorizeConfig) { this.dimension = dimension; - this.metric = metric == null ? "cosine" : metric; + this.metric = metric; + this.sourceModel = sourceModel; this.vectorizeConfig = vectorizeConfig; } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java index 037fbfbe05..374a68574e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java @@ -27,9 +27,6 @@ interface Fields { /** Document field name that will have text value for which vectorize method in called */ String VECTOR_EMBEDDING_TEXT_FIELD = "$vectorize"; - /** Key for vector function name definition in cql index. */ - String VECTOR_INDEX_FUNCTION_NAME = "similarity_function"; - /** Field name used in projection clause to get similarity score in response. 
*/ String VECTOR_FUNCTION_SIMILARITY_FIELD = "$similarity"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/VectorConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/VectorConstants.java index cddabe50ba..0f5aa86b13 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/VectorConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/VectorConstants.java @@ -1,36 +1,13 @@ package io.stargate.sgv2.jsonapi.config.constants; import io.smallrye.config.ConfigMapping; -import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; -import java.util.Map; @ConfigMapping(prefix = "stargate.jsonapi.vector") public interface VectorConstants { - /* - Supported Source Models and suggested function for Vector Index in Cassandra - */ - Map SUPPORTED_SOURCES = - Map.of( - "ada002", - SimilarityFunction.DOT_PRODUCT, - "openai_v3_small", - SimilarityFunction.DOT_PRODUCT, - "openai_v3_large", - SimilarityFunction.DOT_PRODUCT, - "bert", - SimilarityFunction.DOT_PRODUCT, - "gecko", - SimilarityFunction.DOT_PRODUCT, - "nv_qa_4", - SimilarityFunction.DOT_PRODUCT, - "cohere_v3", - SimilarityFunction.DOT_PRODUCT, - "other", - SimilarityFunction.COSINE); - interface VectorColumn { String DIMENSION = "dimension"; String METRIC = "metric"; + String SOURCE_MODEL = "sourceModel"; String SERVICE = "service"; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java index 91dc3427ed..954f6d1eea 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java @@ -141,6 +141,8 @@ public enum ErrorCodeV1 { VECTOR_SEARCH_INVALID_FUNCTION_NAME("Invalid vector search function name"), + VECTOR_SEARCH_UNRECOGNIZED_SOURCE_MODEL_NAME("Unrecognized vector search source model name"), + VECTOR_SEARCH_TOO_BIG_VALUE("Vector embedding property '$vector' length 
too big"), VECTOR_SIZE_MISMATCH("Length of vector parameter different from declared '$vector' dimension"), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorColumnDefinition.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorColumnDefinition.java index 2f13c11dd5..d6a8ee13eb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorColumnDefinition.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorColumnDefinition.java @@ -5,6 +5,7 @@ import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; /** * Configuration vector column with the extra info we need for vectors @@ -12,12 +13,14 @@ * @param fieldName Is still a string because this is also used by collections * @param vectorSize * @param similarityFunction + * @param sourceModel * @param vectorizeDefinition */ public record VectorColumnDefinition( String fieldName, int vectorSize, SimilarityFunction similarityFunction, + SourceModel sourceModel, VectorizeDefinition vectorizeDefinition) { /** @@ -34,11 +37,16 @@ public static VectorColumnDefinition fromJson(JsonNode jsonNode, ObjectMapper ob int dimension = jsonNode.get(VectorConstants.VectorColumn.DIMENSION).asInt(); var similarityFunction = SimilarityFunction.fromString(jsonNode.get(VectorConstants.VectorColumn.METRIC).asText()); + // sourceModel doesn't exist if the collection was created before supporting sourceModel; if + // missing, it will be an empty string and sourceModel becomes OTHER. 
+ var sourceModel = + SourceModel.fromString(jsonNode.path(VectorConstants.VectorColumn.SOURCE_MODEL).asText()); return fromJson( DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, dimension, similarityFunction, + sourceModel, jsonNode, objectMapper); } @@ -48,6 +56,7 @@ public static VectorColumnDefinition fromJson( String fieldName, int dimension, SimilarityFunction similarityFunction, + SourceModel sourceModel, JsonNode jsonNode, ObjectMapper objectMapper) { @@ -58,6 +67,6 @@ public static VectorColumnDefinition fromJson( : VectorizeDefinition.fromJson(vectorizeServiceNode, objectMapper); return new VectorColumnDefinition( - fieldName, dimension, similarityFunction, vectorizeDefinition); + fieldName, dimension, similarityFunction, sourceModel, vectorizeDefinition); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java index 1e00e9a027..f247b81eda 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,6 +44,7 @@ public static VectorConfig fromColumnDefinitions( return new VectorConfig(vectorColumnDefinitions); } + /** Get table schema object from table metadata */ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper objectMapper) { Map vectorizeDefs = @@ -67,22 +69,32 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object var indexFunction = columnIndex.map( index -> { - String 
similarityFunction = + String similarityFunctionStr = index.getOptions().get(VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION); - if (similarityFunction != null) { - return SimilarityFunction.fromString(similarityFunction); - } + // if similarity function is set, use it + SimilarityFunction similarityFunction = + similarityFunctionStr != null + ? SimilarityFunction.fromString(similarityFunctionStr) + : null; - String sourceModel = + String sourceModelStr = index.getOptions().get(VectorConstants.CQLAnnIndex.SOURCE_MODEL); - if (sourceModel != null) { - return VectorConstants.SUPPORTED_SOURCES.get(sourceModel); + SourceModel sourceModel = null; + if (sourceModelStr != null) { + // if similarity function is not set, use the source model to determine it + similarityFunction = + similarityFunction == null + ? SourceModel.getSimilarityFunction(sourceModelStr) + : similarityFunction; + sourceModel = SourceModel.fromString(sourceModelStr); } - return null; + return new AbstractMap.SimpleEntry<>(similarityFunction, sourceModel); }); // if now index, or we could not work out the function, default - var similarityFunction = indexFunction.orElse(SimilarityFunction.COSINE); + var similarityFunction = + indexFunction.map(Map.Entry::getKey).orElse(SimilarityFunction.COSINE); + var sourceModel = indexFunction.map(Map.Entry::getValue).orElse(SourceModel.OTHER); int dimensions = ((VectorType) column.getType()).getDimensions(); // NOTE: need to keep the column name as a string in the VectorColumnDefinition @@ -93,6 +105,7 @@ public static VectorConfig from(TableMetadata tableMetadata, ObjectMapper object cqlIdentifierToJsonKey(column.getName()), dimensions, similarityFunction, + sourceModel, vectorizeDefs.get(column.getName().asInternal()))); } return VectorConfig.fromColumnDefinitions(columnDefs); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java index 4a4bad95f3..b8742900cb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java @@ -20,6 +20,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.operation.Operation; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionTableMatcher; import java.time.Duration; @@ -41,6 +42,7 @@ public record CreateCollectionOperation( boolean vectorSearch, int vectorSize, String vectorFunction, + String sourceModel, String comment, int ddlDelayMillis, boolean tooManyIndexesRollbackEnabled, @@ -60,6 +62,7 @@ public static CreateCollectionOperation withVectorSearch( String name, int vectorSize, String vectorFunction, + String sourceModel, String comment, int ddlDelayMillis, boolean tooManyIndexesRollbackEnabled, @@ -73,6 +76,7 @@ public static CreateCollectionOperation withVectorSearch( true, vectorSize, vectorFunction, + sourceModel, comment, ddlDelayMillis, tooManyIndexesRollbackEnabled, @@ -98,6 +102,7 @@ public static CreateCollectionOperation withoutVectorSearch( false, 0, null, + null, comment, ddlDelayMillis, tooManyIndexesRollbackEnabled, @@ -142,6 +147,7 @@ public Uni> execute( vectorSearch, vectorSize, SimilarityFunction.fromString(vectorFunction), + SourceModel.fromString(sourceModel), comment, objectMapper); // if table exists we have to choices: @@ -492,6 +498,8 @@ public List getIndexStatements( appender + " \"%s_query_vector_value\" ON \"%s\".\"%s\" (query_vector_value) USING 'StorageAttachedIndex' WITH OPTIONS = { 
'similarity_function': '" + vectorFunction() + + "', 'source_model': '" + + sourceModel() + "'}"; statements.add( SimpleStatement.newInstance(String.format(vectorSearch, table, keyspace, table))); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexAttempt.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexAttempt.java index 6ad74ef237..c8f9aa7c52 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexAttempt.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexAttempt.java @@ -13,6 +13,7 @@ import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexOnTable; import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexStart; import com.datastax.oss.driver.internal.querybuilder.schema.DefaultCreateIndex; +import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; import io.stargate.sgv2.jsonapi.service.cqldriver.override.ExtendedCreateIndex; import io.stargate.sgv2.jsonapi.service.operation.SchemaAttempt; @@ -90,10 +91,11 @@ public record VectorIndexOptions(SimilarityFunction similarityFunction, String s public Map getOptions() { Map options = new HashMap<>(); if (similarityFunction != null) { - options.put("similarity_function", similarityFunction.getMetric()); + options.put( + VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION, similarityFunction.getMetric()); } if (sourceModel != null) { - options.put("source_model", sourceModel); + options.put(VectorConstants.CQLAnnIndex.SOURCE_MODEL, sourceModel); } return options; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java index 040f6f524f..4db858ab1d 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java @@ -15,6 +15,8 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; import io.stargate.sgv2.jsonapi.service.operation.Operation; import io.stargate.sgv2.jsonapi.service.operation.collections.CreateCollectionOperation; +import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; @@ -106,6 +108,7 @@ public Operation resolveKeyspaceCommand( command.name(), vector.dimension(), vector.metric(), + vector.sourceModel(), comment, operationsConfig.databaseConfig().ddlDelayMillis(), operationsConfig.tooManyIndexesRollbackEnabled(), @@ -192,13 +195,35 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions Integer vectorDimension = vector.dimension(); VectorizeConfig service = vector.vectorizeConfig(); + String sourceModel = vector.sourceModel(); + String metric = vector.metric(); + + // decide sourceModel and metric value + if (sourceModel != null) { + if (metric == null) { + // (1) sourceModel is provided but metric is not - set metric to cosine or dot_product based + // on the map + metric = SourceModel.getSimilarityFunction(sourceModel).getMetric(); + } + // (2) both sourceModel and metric are provided - do nothing + } else { + if (metric != null) { + // (3) sourceModel is not provided but metric is - set sourceModel to 'other' + sourceModel = SourceModel.OTHER.getName(); + } else { + // (4) both sourceModel and metric are not provided - set sourceModel to 'other' and metric + // to 'cosine' + sourceModel = SourceModel.DEFAULT_SOURCE_MODEL.getName(); + metric = SimilarityFunction.DEFAULT_SIMILARITY_FUNCTION.getMetric(); + } + } if (service != null) { // Validate service 
configuration and auto populate vector dimension. vectorDimension = validateVectorize.validateService(service, vectorDimension); vector = new CreateCollectionCommand.Options.VectorSearchConfig( - vectorDimension, vector.metric(), vector.vectorizeConfig()); + vectorDimension, metric, sourceModel, vector.vectorizeConfig()); } else { // Ensure vector dimension is provided when service configuration is absent. if (vectorDimension == null) { @@ -209,6 +234,9 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions throw ErrorCodeV1.VECTOR_SEARCH_TOO_BIG_VALUE.toApiException( "%d (max %d)", vectorDimension, documentLimitsConfig.maxVectorEmbeddingLength()); } + vector = + new CreateCollectionCommand.Options.VectorSearchConfig( + vectorDimension, metric, sourceModel, null); } return vector; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateVectorIndexCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateVectorIndexCommandResolver.java index a43eab4023..5e60670ddc 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateVectorIndexCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateVectorIndexCommandResolver.java @@ -6,7 +6,6 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateVectorIndexCommand; import io.stargate.sgv2.jsonapi.config.DebugModeConfig; import io.stargate.sgv2.jsonapi.config.OperationsConfig; -import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; import io.stargate.sgv2.jsonapi.service.operation.GenericOperation; @@ -17,11 +16,10 @@ import io.stargate.sgv2.jsonapi.service.operation.tables.CreateIndexAttemptBuilder; import io.stargate.sgv2.jsonapi.service.operation.tables.TableDriverExceptionHandler; import 
io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil; import jakarta.enterprise.context.ApplicationScoped; import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -62,15 +60,12 @@ public Operation resolveTableCommand( } if (definitionOptions != null) { - if (sourceModel != null && VectorConstants.SUPPORTED_SOURCES.get(sourceModel) == null) { - List supportedSourceModel = - new ArrayList<>(VectorConstants.SUPPORTED_SOURCES.keySet()); - Collections.sort(supportedSourceModel); + if (sourceModel != null && SourceModel.getSimilarityFunction(sourceModel) == null) { throw SchemaException.Code.INVALID_INDEX_DEFINITION.get( Map.of( "reason", "sourceModel `%s` used in request is invalid. Supported source models are: %s" - .formatted(sourceModel, supportedSourceModel))); + .formatted(sourceModel, SourceModel.getAllSourceModelNames()))); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SimilarityFunction.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SimilarityFunction.java index c120bfaf64..0bc83e9062 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SimilarityFunction.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SimilarityFunction.java @@ -16,6 +16,7 @@ public enum SimilarityFunction { private String metric; private static Map FUNCTIONS_MAP = new HashMap<>(); + public static final SimilarityFunction DEFAULT_SIMILARITY_FUNCTION = COSINE; static { for (SimilarityFunction similarityFunction : SimilarityFunction.values()) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SourceModel.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SourceModel.java new file mode 100644 index 0000000000..8ebcd03d35 --- /dev/null +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/SourceModel.java @@ -0,0 +1,113 @@ +package io.stargate.sgv2.jsonapi.service.schema; + +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * The source model used for the vector index. This is only applicable if the vector index is + * enabled. + */ +public enum SourceModel { + ADA002("ada002"), + BERT("bert"), + COHERE_V3("cohere-v3"), + GECKO("gecko"), + NV_QA_4("nv-qa-4"), + OPENAI_V3_LARGE("openai-v3-large"), + OPENAI_V3_SMALL("openai-v3-small"), + OTHER("other"), + UNDEFINED("undefined"); + + private final String name; + private static final Map<String, SourceModel> SOURCE_MODEL_NAME_MAP = + Stream.of(SourceModel.values()) + .collect(Collectors.toMap(SourceModel::getName, sourceModel -> sourceModel)); + + /** Supported Source Models and suggested similarity function for Vector Index in Cassandra */ + private static final Map<SourceModel, SimilarityFunction> SOURCE_MODEL_METRIC_MAP = + Map.of( + ADA002, + SimilarityFunction.DOT_PRODUCT, + BERT, + SimilarityFunction.DOT_PRODUCT, + COHERE_V3, + SimilarityFunction.DOT_PRODUCT, + GECKO, + SimilarityFunction.DOT_PRODUCT, + NV_QA_4, + SimilarityFunction.DOT_PRODUCT, + OPENAI_V3_LARGE, + SimilarityFunction.DOT_PRODUCT, + OPENAI_V3_SMALL, + SimilarityFunction.DOT_PRODUCT, + OTHER, + SimilarityFunction.COSINE); + + public static final SourceModel DEFAULT_SOURCE_MODEL = OTHER; + + SourceModel(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public static String getAllSourceModelNames() { + return SOURCE_MODEL_NAME_MAP.keySet().stream().sorted().collect(Collectors.joining(", ")); + } + + /** + * Get the recommended similarity function for the given source model. 
+ * + * @param sourceModel The source model + * @return The similarity function + */ + public static SimilarityFunction getSimilarityFunction(SourceModel sourceModel) { + return SOURCE_MODEL_METRIC_MAP.get(sourceModel); + } + + /** + * Get the recommended similarity function for the given source model name. + * + * @param sourceModelName The source model name + * @return The similarity function + */ + public static SimilarityFunction getSimilarityFunction(String sourceModelName) { + return SOURCE_MODEL_NAME_MAP.get(sourceModelName) == null + ? null + : getSimilarityFunction(SOURCE_MODEL_NAME_MAP.get(sourceModelName)); + } + + /** + * Converts a string representation of a source model name to its corresponding {@link + * SourceModel} enum. + * + * <p>
If the provided name is {@code null}, returns {@link SourceModel#UNDEFINED}, used for + * non-vector collections. If the name is an empty string, returns the default {@link + * SourceModel#OTHER}, indicating the collection was created before source models were supported. + * Throws a {@link JsonApiException} if the name is unrecognized. + * + * @param name + * @return + * @throws JsonApiException + */ + public static SourceModel fromString(String name) throws JsonApiException { + if (name == null) return UNDEFINED; + // The string may be empty if the collection was created before supporting source models + if (name.isEmpty()) return OTHER; + SourceModel model = SOURCE_MODEL_NAME_MAP.get(name); + if (model == null) { + String acceptedModels = + SOURCE_MODEL_NAME_MAP.keySet().stream() + .filter(key -> !key.equals(UNDEFINED.getName())) + .collect(Collectors.joining(", ")); + throw ErrorCodeV1.VECTOR_SEARCH_UNRECOGNIZED_SOURCE_MODEL_NAME.toApiException( + "Received: '%s'; Accepted: %s", name, acceptedModels); + } + return model; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java index 7a6425d8c5..8526c5a9fe 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java @@ -15,10 +15,12 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.constants.TableCommentConstants; +import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; import io.stargate.sgv2.jsonapi.service.projection.IndexingProjector; import 
io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import java.util.List; import java.util.Map; import java.util.Objects; @@ -159,18 +161,32 @@ public static CollectionSchemaObject getCollectionSettings( break; } } - // default function + // default function and source model SimilarityFunction function = SimilarityFunction.COSINE; + SourceModel sourceModel = SourceModel.OTHER; if (vectorIndex != null) { final String functionName = - vectorIndex.getOptions().get(DocumentConstants.Fields.VECTOR_INDEX_FUNCTION_NAME); + vectorIndex.getOptions().get(VectorConstants.CQLAnnIndex.SIMILARITY_FUNCTION); + final String sourceModelName = + vectorIndex.getOptions().get(VectorConstants.CQLAnnIndex.SOURCE_MODEL); if (functionName != null) { function = SimilarityFunction.fromString(functionName); } + if (sourceModelName != null) { + sourceModel = SourceModel.fromString(sourceModelName); + } } final String comment = (String) table.getOptions().get(CqlIdentifier.fromInternal("comment")); return createCollectionSettings( - keyspaceName, collectionName, table, true, vectorSize, function, comment, objectMapper); + keyspaceName, + collectionName, + table, + true, + vectorSize, + function, + sourceModel, + comment, + objectMapper); } else { // if not vector collection // handling comment so get the indexing config from comment final String comment = (String) table.getOptions().get(CqlIdentifier.fromInternal("comment")); @@ -181,6 +197,7 @@ public static CollectionSchemaObject getCollectionSettings( false, 0, SimilarityFunction.UNDEFINED, + SourceModel.UNDEFINED, comment, objectMapper); } @@ -193,6 +210,7 @@ public static CollectionSchemaObject getCollectionSettings( boolean vectorEnabled, int vectorSize, SimilarityFunction similarityFunction, + SourceModel sourceModel, String comment, ObjectMapper objectMapper) { return createCollectionSettings( @@ -202,6 +220,7 @@ public static CollectionSchemaObject 
getCollectionSettings( vectorEnabled, vectorSize, similarityFunction, + sourceModel, comment, objectMapper); } @@ -213,6 +232,7 @@ private static CollectionSchemaObject createCollectionSettings( boolean vectorEnabled, int vectorSize, SimilarityFunction function, + SourceModel sourceModel, String comment, ObjectMapper objectMapper) { @@ -229,6 +249,7 @@ private static CollectionSchemaObject createCollectionSettings( DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, vectorSize, function, + sourceModel, null))), null); } else { @@ -276,26 +297,12 @@ private static CollectionSchemaObject createCollectionSettings( tableMetadata, vectorEnabled, vectorSize, - function); + function, + sourceModel); } } } - // convert a vector jsonNode from cql table comment to vectorConfig, used for collection - private static VectorColumnDefinition fromJson(JsonNode jsonNode, ObjectMapper objectMapper) { - // dimension, similarityFunction, must exist - int dimension = jsonNode.get("dimension").asInt(); - SimilarityFunction similarityFunction = - SimilarityFunction.fromString(jsonNode.get("metric").asText()); - - return VectorColumnDefinition.fromJson( - DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, - dimension, - similarityFunction, - jsonNode, - objectMapper); - } - public static CreateCollectionCommand collectionSettingToCreateCollectionCommand( CollectionSchemaObject collectionSetting) { @@ -330,6 +337,7 @@ public static CreateCollectionCommand collectionSettingToCreateCollectionCommand new CreateCollectionCommand.Options.VectorSearchConfig( vectorColumnDefinition.vectorSize(), vectorColumnDefinition.similarityFunction().name().toLowerCase(), + vectorColumnDefinition.sourceModel().getName(), vectorizeConfig); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java index 76448ed917..a3afd2d16f 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java @@ -8,6 +8,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import java.util.List; /** @@ -27,7 +28,8 @@ public CollectionSchemaObject readCollectionSettings( TableMetadata tableMetadata, boolean vectorEnabled, int vectorSize, - SimilarityFunction function) { + SimilarityFunction function, + SourceModel sourceModel) { VectorConfig vectorConfig = vectorEnabled @@ -37,6 +39,7 @@ public CollectionSchemaObject readCollectionSettings( DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, vectorSize, function, + sourceModel, null))) : VectorConfig.NOT_ENABLED_CONFIG; CollectionIndexingConfig indexingConfig = null; diff --git a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java index 23690d0c8b..71a2060051 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java @@ -5,6 +5,7 @@ import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import java.util.List; @@ -42,6 +43,7 @@ public final class TestConstants { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, -1, SimilarityFunction.COSINE, + SourceModel.OTHER, null))), null); diff --git 
a/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java index f19532817a..9090b2582a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java @@ -631,7 +631,8 @@ public void happyPathVectorSearchDefaultFunction() throws Exception { assertThat(createCollection.options()).isNotNull(); assertThat(createCollection.options().vector()).isNotNull(); assertThat(createCollection.options().vector().dimension()).isEqualTo(5); - assertThat(createCollection.options().vector().metric()).isEqualTo("cosine"); + assertThat(createCollection.options().vector().metric()).isNull(); + assertThat(createCollection.options().vector().sourceModel()).isNull(); }); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java index bd9fa8b71d..fe4f86ea14 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java @@ -139,6 +139,7 @@ public void caseSensitive() { .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); + deleteCollection("testcollection"); deleteCollection("testCollection"); } @@ -1751,6 +1752,311 @@ public void failWithWrongModelParameterType() { } } + @Nested + @Order(7) + class CreateCollectionWithSourceModel { + @Test + public void happyWithSourceModelAndMetrics() { + // create a collection with source model and metric + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "createCollection": { + "name": "collection_with_sourceModel_metric", + "options": { + "vector": { + "metric": "cosine", + 
"sourceModel": "openai-v3-small", + "dimension": 1536, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + } + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.ok", is(1)); + + // verify the collection using FindCollection + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body("status.collections[0].options.vector.metric", is("cosine")) + .body("status.collections[0].options.vector.sourceModel", is("openai-v3-small")); + + deleteCollection("collection_with_sourceModel_metric"); + } + + @Test + public void happyWithSourceModelOnly() { + // create a collection with source model - metric will be auto-populated to 'dot_product' + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "createCollection": { + "name": "collection_with_sourceModel", + "options": { + "vector": { + "sourceModel": "openai-v3-small", + "dimension": 1536, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + } + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.ok", is(1)); + + // verify the collection using FindCollection + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body("status.collections[0].options.vector.metric", 
is("dot_product")) + .body("status.collections[0].options.vector.sourceModel", is("openai-v3-small")); + + deleteCollection("collection_with_sourceModel"); + } + + @Test + public void happyWithMetricOnly() { + // create a collection with metric - source model will be auto-populated to 'other' + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "createCollection": { + "name": "collection_with_metric", + "options": { + "vector": { + "metric": "cosine", + "dimension": 1536, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + } + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.ok", is(1)); + + // verify the collection using FindCollection + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body("status.collections[0].options.vector.metric", is("cosine")) + .body("status.collections[0].options.vector.sourceModel", is("other")); + + deleteCollection("collection_with_metric"); + } + + @Test + public void happyNoSourceModelAndMetric() { + // create a collection without sourceModel and metric - source model will be auto-populated to + // 'other' and metric to 'cosine' + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "createCollection": { + "name": "collection_with_no_sourceModel_metric", + "options": { + "vector": { + "dimension": 1536, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + } + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + 
.body("status.ok", is(1)); + + // verify the collection using FindCollection + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body("status.collections[0].options.vector.metric", is("cosine")) + .body("status.collections[0].options.vector.sourceModel", is("other")); + + deleteCollection("collection_with_no_sourceModel_metric"); + } + + @Test + public void failWithInvalidSourceModel() { + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "createCollection": { + "name": "collection_with_sourceModel", + "options": { + "vector": { + "sourceModel": "invalidName", + "dimension": 1536, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + } + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsError()) + .body("errors[0].exceptionClass", is("JsonApiException")) + .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) + .body( + "errors[0].message", + startsWith( + "Request invalid: field 'command.options.vector.sourceModel' value \"invalidName\" not valid.")); + } + + @Test + public void failWithInvalidSourceModelObject() { + given() + .headers(getHeaders()) + .contentType(ContentType.JSON) + .body( + """ + { + "createCollection": { + "name": "collection_with_sourceModel", + "options": { + "vector": { + "sourceModel": "invalidName", + "dimension": 1536, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + } + } + } + } + """) + .when() + .post(KeyspaceResource.BASE_PATH, keyspaceName) + .then() + .statusCode(200) + .body("$", responseIsError()) + .body("errors[0].exceptionClass", is("JsonApiException")) + 
.body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) + .body( + "errors[0].message", + startsWith( + "Request invalid: field 'command.options.vector.sourceModel' value \"invalidName\" not valid.")); + } + } + private void deleteCollection(String collectionName) { given() .headers(getHeaders()) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java index 253f813306..9a6bb55ee9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java @@ -152,7 +152,8 @@ public void happyPathWithExplain() { "options": { "vector": { "dimension": 5, - "metric": "cosine" + "metric": "cosine", + "sourceModel": "other" },"indexing": { "deny" : ["comment"] } @@ -375,7 +376,7 @@ public void happyPathIndexingWithExplain() { """; String expected3 = """ - {"name":"collection2", "options": {"vector": {"dimension":5, "metric":"cosine"}, "indexing":{"deny":["comment"]}}} + {"name":"collection2", "options": {"vector": {"dimension":5, "metric":"cosine", "sourceModel": "other"}, "indexing":{"deny":["comment"]}}} """; String expected4 = """ diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java index 95b01e6dbd..7ae81a56f0 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java @@ -7,12 +7,9 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; import io.stargate.sgv2.jsonapi.api.v1.util.DataApiCommandSenders; -import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import 
io.stargate.sgv2.jsonapi.exception.SchemaException; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Map; import org.junit.jupiter.api.*; @@ -279,7 +276,7 @@ public void createVectorIndexWithSourceModel() { "definition": { "column": "vector_type_2", "options": { - "sourceModel": "openai_v3_small" + "sourceModel": "openai-v3-small" } } } @@ -316,7 +313,7 @@ public void createVectorIndexWithMetricAndSourceModel() { "column": "vector_type_4", "options": { "metric": "cosine", - "sourceModel": "openai_v3_small" + "sourceModel": "openai-v3-small" } } } @@ -403,15 +400,12 @@ public void tryCreateIndexMissingColumn() { @Test public void invalidSourceModel() { - List supportedSourceModel = - new ArrayList<>(VectorConstants.SUPPORTED_SOURCES.keySet()); - Collections.sort(supportedSourceModel); final SchemaException schemaException = SchemaException.Code.INVALID_INDEX_DEFINITION.get( Map.of( "reason", "sourceModel `%s` used in request is invalid. 
Supported source models are: %s" - .formatted("invalid_source_model", supportedSourceModel))); + .formatted("invalid_source_model", SourceModel.getAllSourceModelNames()))); DataApiCommandSenders.assertTableCommand(keyspaceName, testTableName) .postCreateVectorIndex( """ diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index 4d8c127b33..db5bf4fdd8 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -21,6 +21,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorizeDefinition; import io.stargate.sgv2.jsonapi.service.embedding.DataVectorizer; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import jakarta.inject.Inject; @@ -239,6 +240,7 @@ public void testWithUnmatchedVectorSize() { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, 4, SimilarityFunction.COSINE, + SourceModel.OTHER, new VectorizeDefinition("custom", "custom", null, null)))), null); List documents = new ArrayList<>(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java index b74e15b601..e36f3137c0 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java @@ -9,6 +9,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import 
io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorizeDefinition; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import java.util.ArrayList; @@ -28,6 +29,7 @@ public class TestEmbeddingProvider extends EmbeddingProvider { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, 3, SimilarityFunction.COSINE, + SourceModel.OTHER, new VectorizeDefinition("custom", "custom", null, null)))), null), new TestEmbeddingProvider(), diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperationTest.java index 28141dc9af..c63f1f1323 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperationTest.java @@ -173,6 +173,7 @@ public void createCollectionVector() { 5, "cosine", "", + "", 10, false, false); @@ -291,6 +292,7 @@ public void denyAllCollectionVector() { 5, "cosine", "", + "", 10, false, true); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java index 679cca1e28..84efe78156 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/FindCollectionOperationTest.java @@ -35,6 +35,7 @@ import io.stargate.sgv2.jsonapi.service.operation.query.DBLogicalExpression; import io.stargate.sgv2.jsonapi.service.projection.DocumentProjector; import 
io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import io.stargate.sgv2.jsonapi.service.shredding.collections.DocValueHasher; @@ -89,6 +90,7 @@ public void init() { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, -1, SimilarityFunction.COSINE, + SourceModel.OTHER, null))), null), null, diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java index 8e230f17ee..9610f0587f 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/InsertCollectionOperationTest.java @@ -28,6 +28,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.cqldriver.serializer.CQLBindValues; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import io.stargate.sgv2.jsonapi.service.shredding.collections.DocumentId; @@ -106,6 +107,7 @@ public void init() { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, 4, SimilarityFunction.COSINE, + SourceModel.OTHER, null))), null), null, diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java index 4c539e9cc3..efb5373d30 100644 --- 
a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/collections/ReadAndUpdateCollectionOperationTest.java @@ -35,6 +35,7 @@ import io.stargate.sgv2.jsonapi.service.operation.query.DBLogicalExpression; import io.stargate.sgv2.jsonapi.service.projection.DocumentProjector; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import io.stargate.sgv2.jsonapi.service.shredding.collections.DocValueHasher; @@ -125,6 +126,7 @@ public void init() { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, -1, SimilarityFunction.COSINE, + SourceModel.OTHER, null))), null), null, diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java index b3b591cfe6..cd151fded1 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolverWithVectorizerTest.java @@ -35,6 +35,7 @@ import io.stargate.sgv2.jsonapi.service.operation.filters.collection.TextCollectionFilter; import io.stargate.sgv2.jsonapi.service.projection.DocumentProjector; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; +import io.stargate.sgv2.jsonapi.service.schema.SourceModel; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import io.stargate.sgv2.jsonapi.service.shredding.collections.DocumentShredder; @@ -92,6 +93,7 @@ class Resolve { DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, -1, 
SimilarityFunction.COSINE, + SourceModel.OTHER, null))), null), null, diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolverTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolverTest.java index 5c75587a61..46ed71c7de 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolverTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolverTest.java @@ -128,7 +128,7 @@ public void happyPathVectorizeSearch() throws Exception { assertThat(op.vectorFunction()).isEqualTo("cosine"); assertThat(op.comment()) .isEqualTo( - "{\"collection\":{\"name\":\"my_collection\",\"schema_version\":1,\"options\":{\"vector\":{\"dimension\":768,\"metric\":\"cosine\",\"service\":{\"provider\":\"azureOpenAI\",\"modelName\":\"text-embedding-3-small\",\"parameters\":{\"resourceName\":\"test\",\"deploymentId\":\"test\"}}},\"defaultId\":{\"type\":\"\"}}}}", + "{\"collection\":{\"name\":\"my_collection\",\"schema_version\":1,\"options\":{\"vector\":{\"dimension\":768,\"metric\":\"cosine\",\"sourceModel\":\"other\",\"service\":{\"provider\":\"azureOpenAI\",\"modelName\":\"text-embedding-3-small\",\"parameters\":{\"resourceName\":\"test\",\"deploymentId\":\"test\"}}},\"defaultId\":{\"type\":\"\"}}}}", TableCommentConstants.SCHEMA_VERSION_VALUE); }); } @@ -167,7 +167,7 @@ public void happyPathIndexing() throws Exception { assertThat(op.vectorFunction()).isEqualTo("cosine"); assertThat(op.comment()) .isEqualTo( - "{\"collection\":{\"name\":\"my_collection\",\"schema_version\":%s,\"options\":{\"indexing\":{\"deny\":[\"comment\"]},\"vector\":{\"dimension\":4,\"metric\":\"cosine\"},\"defaultId\":{\"type\":\"\"}}}}", + 
"{\"collection\":{\"name\":\"my_collection\",\"schema_version\":%s,\"options\":{\"indexing\":{\"deny\":[\"comment\"]},\"vector\":{\"dimension\":4,\"metric\":\"cosine\",\"sourceModel\":\"other\"},\"defaultId\":{\"type\":\"\"}}}}", TableCommentConstants.SCHEMA_VERSION_VALUE); }); }