diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/ColumnDefinitionDeserializer.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/ColumnDefinitionDeserializer.java index db7c13a68..fff555cb0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/ColumnDefinitionDeserializer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/ColumnDefinitionDeserializer.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype.ColumnType; import io.stargate.sgv2.jsonapi.exception.SchemaException; import java.io.IOException; @@ -25,13 +26,14 @@ public ColumnType deserialize( throws IOException, JacksonException { JsonNode definition = deserializationContext.readTree(jsonParser); if (definition.isTextual()) { - return ColumnType.fromString(definition.asText(), null, null, -1); + return ColumnType.fromString(definition.asText(), null, null, -1, null); } if (definition.isObject() && definition.has("type")) { String type = definition.path("type").asText(); String keyType = null; String valueType = null; int dimension = -1; + VectorizeConfig vectorConfig = null; if (definition.has("keyType")) { keyType = definition.path("keyType").asText(); } @@ -41,7 +43,15 @@ public ColumnType deserialize( if (definition.has("dimension")) { dimension = definition.path("dimension").asInt(); } - return ColumnType.fromString(type, keyType, valueType, dimension); + if (definition.has("service")) { + JsonNode service = definition.path("service"); + try { + vectorConfig = deserializationContext.readTreeAsValue(service, VectorizeConfig.class); + } catch (JacksonException je) { + throw 
SchemaException.Code.VECTOR_TYPE_INCORRECT_DEFINITION.get(); + } + } + return ColumnType.fromString(type, keyType, valueType, dimension, vectorConfig); } throw SchemaException.Code.COLUMN_TYPE_INCORRECT.get(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java index 604cee7c2..64ab3f999 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java @@ -7,7 +7,6 @@ import io.stargate.sgv2.jsonapi.api.model.command.CollectionOnlyCommand; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import jakarta.validation.Valid; import jakarta.validation.constraints.*; import java.util.*; @@ -111,84 +110,6 @@ public VectorSearchConfig(Integer dimension, String metric, VectorizeConfig vect this.metric = metric == null ? 
"cosine" : metric; this.vectorizeConfig = vectorizeConfig; } - - public record VectorizeConfig( - @NotNull - @Schema( - description = "Registered Embedding service provider", - type = SchemaType.STRING, - implementation = String.class) - @JsonProperty("provider") - String provider, - @Schema( - description = "Registered Embedding service model", - type = SchemaType.STRING, - implementation = String.class) - @JsonProperty("modelName") - String modelName, - @Valid - @Nullable - @Schema( - description = "Authentication config for chosen embedding service", - type = SchemaType.OBJECT) - @JsonProperty("authentication") - @JsonInclude(JsonInclude.Include.NON_NULL) - Map authentication, - @Nullable - @Schema( - description = - "Optional parameters that match the messageTemplate provided for the provider", - type = SchemaType.OBJECT) - @JsonProperty("parameters") - @JsonInclude(JsonInclude.Include.NON_NULL) - Map parameters) { - - public VectorizeConfig( - String provider, - String modelName, - Map authentication, - Map parameters) { - this.provider = provider; - // HuggingfaceDedicated does not need user to specify model explicitly - // If user specifies modelName other than endpoint-defined-model, will error out - // By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder - if (provider.equals(ProviderConstants.HUGGINGFACE_DEDICATED)) { - if (modelName != null - && !modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'modelName' is not needed for provider %s explicitly, only '%s' is accepted", - ProviderConstants.HUGGINGFACE_DEDICATED, - ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL); - } - this.modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL; - } else { - this.modelName = modelName; - } - if (authentication != null && !authentication.isEmpty()) { - Map updatedAuth = new HashMap<>(); - for (Map.Entry userAuth : 
authentication.entrySet()) { - // Determine the full credential name based on the sharedKeyValue pair - // If the sharedKeyValue does not contain a dot (e.g. myKey) or the part after the dot - // does not match the key (e.g. myKey.test), append the key to the sharedKeyValue with - // a dot (e.g. myKey.providerKey or myKey.test.providerKey). Otherwise, use the - // sharedKeyValue (e.g. myKey.providerKey) as is. - String sharedKeyValue = userAuth.getValue(); - String credentialName = - sharedKeyValue.lastIndexOf('.') <= 0 - || !sharedKeyValue - .substring(sharedKeyValue.lastIndexOf('.') + 1) - .equals(userAuth.getKey()) - ? sharedKeyValue + "." + userAuth.getKey() - : sharedKeyValue; - updatedAuth.put(userAuth.getKey(), credentialName); - } - this.authentication = updatedAuth; - } else { - this.authentication = authentication; - } - this.parameters = parameters; - } - } } public record IndexingConfig( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java new file mode 100644 index 000000000..abe0e2952 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java @@ -0,0 +1,90 @@ +package io.stargate.sgv2.jsonapi.api.model.command.impl; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import jakarta.validation.Valid; +import jakarta.validation.constraints.*; +import java.util.*; +import javax.annotation.Nullable; +import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; +import org.eclipse.microprofile.openapi.annotations.media.Schema; + +public record VectorizeConfig( + @NotNull + @Schema( + description = "Registered Embedding service provider", + type = SchemaType.STRING, + implementation 
= String.class) + @JsonProperty("provider") + String provider, + @Schema( + description = "Registered Embedding service model", + type = SchemaType.STRING, + implementation = String.class) + @JsonProperty("modelName") + String modelName, + @Valid + @Nullable + @Schema( + description = "Authentication config for chosen embedding service", + type = SchemaType.OBJECT) + @JsonProperty("authentication") + @JsonInclude(JsonInclude.Include.NON_NULL) + Map authentication, + @Nullable + @Schema( + description = + "Optional parameters that match the messageTemplate provided for the provider", + type = SchemaType.OBJECT) + @JsonProperty("parameters") + @JsonInclude(JsonInclude.Include.NON_NULL) + Map parameters) { + + public VectorizeConfig( + String provider, + String modelName, + Map authentication, + Map parameters) { + this.provider = provider; + // HuggingfaceDedicated does not need user to specify model explicitly + // If user specifies modelName other than endpoint-defined-model, will error out + // By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder + if (provider.equals(ProviderConstants.HUGGINGFACE_DEDICATED)) { + if (modelName != null + && !modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "'modelName' is not needed for provider %s explicitly, only '%s' is accepted", + ProviderConstants.HUGGINGFACE_DEDICATED, + ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL); + } + this.modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL; + } else { + this.modelName = modelName; + } + if (authentication != null && !authentication.isEmpty()) { + Map updatedAuth = new HashMap<>(); + for (Map.Entry userAuth : authentication.entrySet()) { + // Determine the full credential name based on the sharedKeyValue pair + // If the sharedKeyValue does not contain a dot (e.g. myKey) or the part after the dot + // does not match the key (e.g. 
myKey.test), append the key to the sharedKeyValue with + // a dot (e.g. myKey.providerKey or myKey.test.providerKey). Otherwise, use the + // sharedKeyValue (e.g. myKey.providerKey) as is. + String sharedKeyValue = userAuth.getValue(); + String credentialName = + sharedKeyValue.lastIndexOf('.') <= 0 + || !sharedKeyValue + .substring(sharedKeyValue.lastIndexOf('.') + 1) + .equals(userAuth.getKey()) + ? sharedKeyValue + "." + userAuth.getKey() + : sharedKeyValue; + updatedAuth.put(userAuth.getKey(), credentialName); + } + this.authentication = updatedAuth; + } else { + this.authentication = authentication; + } + this.parameters = parameters; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ColumnType.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ColumnType.java index dfaf0c25c..d2da91411 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ColumnType.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ColumnType.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.stargate.sgv2.jsonapi.api.model.command.deserializers.ColumnDefinitionDeserializer; +import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiDataType; import java.util.List; @@ -40,7 +41,8 @@ static List getSupportedTypes() { } // Returns the column type from the string. 
- static ColumnType fromString(String type, String keyType, String valueType, int dimension) { + static ColumnType fromString( + String type, String keyType, String valueType, int dimension, VectorizeConfig vectorConfig) { // TODO: the name of the type should be a part of the ColumnType interface, and use a map for // the lookup switch (type) { @@ -87,8 +89,8 @@ static ColumnType fromString(String type, String keyType, String valueType, int } try { return new ComplexTypes.MapType( - fromString(keyType, null, null, dimension), - fromString(valueType, null, null, dimension)); + fromString(keyType, null, null, dimension, vectorConfig), + fromString(valueType, null, null, dimension, vectorConfig)); } catch (SchemaException se) { throw SchemaException.Code.MAP_TYPE_INCORRECT_DEFINITION.get(); } @@ -99,7 +101,8 @@ static ColumnType fromString(String type, String keyType, String valueType, int throw SchemaException.Code.LIST_TYPE_INCORRECT_DEFINITION.get(); } try { - return new ComplexTypes.ListType(fromString(valueType, null, null, dimension)); + return new ComplexTypes.ListType( + fromString(valueType, null, null, dimension, vectorConfig)); } catch (SchemaException se) { throw SchemaException.Code.LIST_TYPE_INCORRECT_DEFINITION.get(); } @@ -111,7 +114,8 @@ static ColumnType fromString(String type, String keyType, String valueType, int throw SchemaException.Code.SET_TYPE_INCORRECT_DEFINITION.get(); } try { - return new ComplexTypes.SetType(fromString(valueType, null, null, dimension)); + return new ComplexTypes.SetType( + fromString(valueType, null, null, dimension, vectorConfig)); } catch (SchemaException se) { throw SchemaException.Code.SET_TYPE_INCORRECT_DEFINITION.get(); } @@ -123,7 +127,7 @@ static ColumnType fromString(String type, String keyType, String valueType, int throw SchemaException.Code.VECTOR_TYPE_INCORRECT_DEFINITION.get(); } try { - return new ComplexTypes.VectorType(PrimitiveTypes.FLOAT, dimension); + return new 
ComplexTypes.VectorType(PrimitiveTypes.FLOAT, dimension, vectorConfig); } catch (SchemaException se) { throw SchemaException.Code.VECTOR_TYPE_INCORRECT_DEFINITION.get(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ComplexTypes.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ComplexTypes.java index cc3e474f8..4bb640961 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ComplexTypes.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/datatype/ComplexTypes.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype; +import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiDataType; import io.stargate.sgv2.jsonapi.service.schema.tables.ComplexApiDataType; import io.stargate.sgv2.jsonapi.service.schema.tables.PrimitiveApiDataType; @@ -58,10 +59,12 @@ public static class VectorType implements ColumnType { // Float will be default type for vector private final ColumnType valueType; private final int vectorSize; + private final VectorizeConfig vectorConfig; - public VectorType(ColumnType valueType, int vectorSize) { + public VectorType(ColumnType valueType, int vectorSize, VectorizeConfig vectorConfig) { this.valueType = valueType; this.vectorSize = vectorSize; + this.vectorConfig = vectorConfig; } @Override @@ -69,5 +72,13 @@ public ApiDataType getApiDataType() { return new ComplexApiDataType.VectorType( (PrimitiveApiDataType) valueType.getApiDataType(), vectorSize); } + + public VectorizeConfig getVectorConfig() { + return vectorConfig; + } + + public int getDimension() { + return vectorSize; + } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttempt.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttempt.java 
index 7d04ba5a8..2ef81900e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttempt.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttempt.java @@ -5,6 +5,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.data.ByteUtils; import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder; import com.datastax.oss.driver.api.core.type.DataType; import com.datastax.oss.driver.api.querybuilder.schema.CreateTable; @@ -23,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; public class CreateTableAttempt extends SchemaAttempt { @@ -30,7 +32,7 @@ public class CreateTableAttempt extends SchemaAttempt { private final Map columnTypes; private final List partitionKeys; private final List clusteringKeys; - private final String comment; + private final Map customProperties; private final boolean ifNotExists; protected CreateTableAttempt( @@ -43,7 +45,7 @@ protected CreateTableAttempt( List partitionKeys, List clusteringKeys, boolean ifNotExists, - String comment) { + Map customProperties) { super( position, schemaObject, @@ -54,7 +56,7 @@ protected CreateTableAttempt( this.partitionKeys = partitionKeys; this.clusteringKeys = clusteringKeys; this.ifNotExists = ifNotExists; - this.comment = comment; + this.customProperties = customProperties; setStatus(OperationStatus.READY); } @@ -73,8 +75,17 @@ protected SimpleStatement buildStatement() { // Add all primary keys and colunms CreateTable createTable = addColumnsAndKeys(create); - // Add comment which has table properties for vectorize - CreateTableWithOptions createWithOptions = createTable.withComment(comment); + // Add customProperties which has table properties for vectorize + // Convert value to hex string using the ByteUtils.toHexString + // This needs to use 
`createTable.withExtensions()` method in driver when PR + // (https://github.com/apache/cassandra-java-driver/pull/1964) is released + final Map extensions = + customProperties.entrySet().stream() + .collect( + Collectors.toMap( + e -> e.getKey(), e -> ByteUtils.toHexString(e.getValue().getBytes()))); + + CreateTableWithOptions createWithOptions = createTable.withOption("extensions", extensions); // Add the clustering key order createWithOptions = addClusteringOrder(createWithOptions); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttemptBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttemptBuilder.java index 275d1d06c..f6a63cb20 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttemptBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateTableAttemptBuilder.java @@ -16,7 +16,7 @@ public class CreateTableAttemptBuilder { private Map columnTypes; private List partitionKeys; private List clusteringKeys; - private String comment; + private Map customProperties; private boolean ifNotExists; public CreateTableAttemptBuilder(int position, KeyspaceSchemaObject schemaObject) { @@ -55,8 +55,8 @@ public CreateTableAttemptBuilder clusteringKeys(List clu return this; } - public CreateTableAttemptBuilder comment(String comment) { - this.comment = comment; + public CreateTableAttemptBuilder customProperties(Map customProperties) { + this.customProperties = customProperties; return this; } @@ -77,6 +77,6 @@ public CreateTableAttempt build() { partitionKeys, clusteringKeys, ifNotExists, - comment); + customProperties); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java index 97e58e3a8..040f6f524 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateCollectionCommand; +import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.config.DatabaseLimitsConfig; import io.stargate.sgv2.jsonapi.config.DocumentLimitsConfig; import io.stargate.sgv2.jsonapi.config.OperationsConfig; @@ -12,14 +13,10 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import io.stargate.sgv2.jsonapi.service.operation.Operation; import io.stargate.sgv2.jsonapi.service.operation.collections.CreateCollectionOperation; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import java.util.*; -import java.util.stream.Collectors; @ApplicationScoped public class CreateCollectionCommandResolver implements CommandResolver { @@ -29,8 +26,7 @@ public class CreateCollectionCommandResolver implements CommandResolver - *
  • Validate that all keys (member names) in the authentication stanza (e.g. providerKey) are - * listed in the configuration for the provider as accepted keys. - *
  • For each key-value member of the authentication stanza: - *
      - *
    1. If the value does not contain the period character "." it assumes the value is the - * name of the credential without specifying the key. - *
        - *
      1. The credential name is appended with .<key> and the secret service - * called to validate that a credential with that name exists and it has the - * named key. - *
      - *
    2. If the value does contain a period character "." it assumes the first part is the - * name of the credential and the second the name of the key within it. - *
        - *
      1. The secret service called to validate that a credential with that name exists - * and it has the named key. - *
      - *
    - * - * - * @param userConfig The vectorize configuration provided by the user. - * @param providerConfig The embedding provider configuration. - * @throws JsonApiException If the user authentication is invalid. - */ - private void validateAuthentication( - CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - // Get all the accepted keys in SHARED_SECRET - Set acceptedKeys = - providerConfig.supportedAuthentications().entrySet().stream() - .filter( - config -> - config - .getKey() - .equals( - EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType - .SHARED_SECRET)) - .filter(config -> config.getValue().enabled() && config.getValue().tokens() != null) - .flatMap(config -> config.getValue().tokens().stream()) - .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.TokenConfig::accepted) - .collect(Collectors.toSet()); - - // If the user hasn't provided authentication details, verify that either the 'NONE' or 'HEADER' - // authentication type is enabled. 
- if (userConfig.authentication() == null || userConfig.authentication().isEmpty()) { - // Check if 'NONE' authentication type is enabled - boolean noneEnabled = - Optional.ofNullable( - providerConfig - .supportedAuthentications() - .get( - EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE)) - .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig::enabled) - .orElse(false); - - // Check if 'HEADER' authentication type is enabled - boolean headerEnabled = - Optional.ofNullable( - providerConfig - .supportedAuthentications() - .get( - EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType - .HEADER)) - .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig::enabled) - .orElse(false); - - // If neither 'NONE' nor 'HEADER' authentication type is enabled, throw an exception - if (!noneEnabled && !headerEnabled) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Service provider '%s' does not support either 'NONE' or 'HEADER' authentication types.", - userConfig.provider()); - } - } else { - // User has provided authentication details. Validate each key against the provider's accepted - // list. - for (Map.Entry userAuth : userConfig.authentication().entrySet()) { - // Check if the key is accepted by the provider - if (!acceptedKeys.contains(userAuth.getKey())) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Service provider '%s' does not support authentication key '%s'", - userConfig.provider(), userAuth.getKey()); - } - - // Validate the credential name from secret service - // already append the .providerKey to the value in CreateCollectionCommand - if (operationsConfig.enableEmbeddingGateway()) { - validateCredentials.validate(userConfig.provider(), userAuth.getValue()); - } - } - } - } - - /** - * Validates the parameters provided by the user against the expected parameters from both the - * provider and the model configurations. 
This method ensures that only configured parameters are - * provided, all required parameters are included, and no unexpected parameters are passed. - * - * @param userConfig The vector search configuration provided by the user. - * @param providerConfig The configuration of the embedding provider which includes model and - * provider-level parameters. - * @throws JsonApiException if any unconfigured parameters are provided, required parameters are - * missing, or if an error occurs due to no parameters being configured but some are provided - * by the user. - */ - private void validateUserParameters( - CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - // 0. Combine provider level and model level parameters - List allParameters = - new ArrayList<>(); - // Add all provider level parameters - allParameters.addAll(providerConfig.parameters()); - // Get all the parameters except "vectorDimension" for the model -- model has been validated in - // the previous step, huggingfaceDedicated uses endpoint-defined-model - List modelParameters = - providerConfig.models().stream() - .filter(m -> m.name().equals(userConfig.modelName())) - .findFirst() - .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig::parameters) - .map( - params -> - params.stream() - .filter( - param -> - !param - .name() - .equals( - "vectorDimension")) // Exclude 'vectorDimension' parameter - .collect(Collectors.toList())) - .get(); - // Add all model level parameters - allParameters.addAll(modelParameters); - // 1. Error if the user provided un-configured parameters - // Two level parameters have unique names, should be fine here - Set expectedParamNames = - allParameters.stream() - .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig::name) - .collect(Collectors.toSet()); - Map userParameters = - (userConfig.parameters() != null) ? 
userConfig.parameters() : Collections.emptyMap(); - // Check for unconfigured parameters provided by the user - userParameters - .keySet() - .forEach( - userParamName -> { - if (!expectedParamNames.contains(userParamName)) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Unexpected parameter '%s' for the provider '%s' provided", - userParamName, userConfig.provider()); - } - }); - - // 2. Error if the user doesn't provide required parameters - // Check for missing required parameters and collect them for type validation - List parametersToValidate = - new ArrayList<>(); - allParameters.forEach( - expectedParamConfig -> { - if (expectedParamConfig.required() - && !userParameters.containsKey(expectedParamConfig.name())) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Required parameter '%s' for the provider '%s' missing", - expectedParamConfig.name(), userConfig.provider()); - } - if (userParameters.containsKey(expectedParamConfig.name())) { - parametersToValidate.add(expectedParamConfig); - } - }); - - // 3. Validate parameter types if no errors occurred in previous steps - parametersToValidate.forEach( - expectedParamConfig -> - validateParameterType( - expectedParamConfig, userParameters.get(expectedParamConfig.name()))); - } - - /** - * Validates the type of parameter provided by the user against the expected type defined in the - * provider's configuration. This method checks if the type of the user-provided parameter matches - * the expected type, throwing an exception if there is a mismatch. - * - * @param expectedParamConfig The expected configuration for the parameter which includes its - * expected type. - * @param userParamValue The value of the parameter provided by the user. - * @throws JsonApiException if the type of the parameter provided by the user does not match the - * expected type. 
- */ - private void validateParameterType( - EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig expectedParamConfig, - Object userParamValue) { - - EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType expectedParamType = - expectedParamConfig.type(); - boolean typeMismatch = - expectedParamType == EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType.STRING - && !(userParamValue instanceof String) - || expectedParamType - == EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType.NUMBER - && !(userParamValue instanceof Number) - || expectedParamType - == EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType.BOOLEAN - && !(userParamValue instanceof Boolean); - - if (typeMismatch) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "The provided parameter '%s' type is incorrect. Expected: '%s'", - expectedParamConfig.name(), expectedParamType); - } - } - - /** - * Validates the model name and vector dimension provided in the user configuration against the - * specified embedding provider configuration. - * - * @param userConfig the user-specified vectorization configuration - * @param providerConfig the configuration of the embedding provider - * @param userVectorDimension the vector dimension provided by the user, or null if not provided - * @return the validated vector dimension to be used for the model - * @throws JsonApiException if the model name is not found, or if the dimension is invalid - */ - private Integer validateModelAndDimension( - CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, - Integer userVectorDimension) { - // Find the model configuration by matching the model name - // 1. 
huggingfaceDedicated does not require model, but requires dimension - if (userConfig.provider().equals(ProviderConstants.HUGGINGFACE_DEDICATED)) { - if (userVectorDimension == null) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'dimension' is needed for provider %s", ProviderConstants.HUGGINGFACE_DEDICATED); - } - } - // 2. other providers do require model - if (userConfig.modelName() == null) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'modelName' is needed for provider %s", userConfig.provider()); - } - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model = - providerConfig.models().stream() - .filter(m -> m.name().equals(userConfig.modelName())) - .findFirst() - .orElseThrow( - () -> - ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Model name '%s' for provider '%s' is not supported", - userConfig.modelName(), userConfig.provider())); - - // Handle models with a fixed vector dimension - if (model.vectorDimension().isPresent() && model.vectorDimension().get() != 0) { - Integer configVectorDimension = model.vectorDimension().get(); - if (userVectorDimension == null) { - return configVectorDimension; // Use model's dimension if user hasn't specified any - } else if (!configVectorDimension.equals(userVectorDimension)) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "The provided dimension value '%s' doesn't match the model's supported dimension value '%s'", - userVectorDimension, configVectorDimension); - } - return configVectorDimension; - } - - // Handle models with a range of acceptable dimensions - return model.parameters().stream() - .filter(param -> param.name().equals("vectorDimension")) - .findFirst() - .map(param -> validateRangeDimension(param, userVectorDimension)) - .orElse(userVectorDimension); // should not go here - } - - /** - * Validates the user-provided vector dimension against the dimension parameter's validation - * constraints. 
- * - * @param param the parameter configuration containing validation constraints - * @param userVectorDimension the vector dimension provided by the user - * @return the appropriate vector dimension based on parameter configuration - * @throws JsonApiException if the user-provided dimension is not valid - */ - private Integer validateRangeDimension( - EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig param, - Integer userVectorDimension) { - // Use the default value if the user has not provided a dimension - if (userVectorDimension == null) { - return Integer.valueOf(param.defaultValue().get()); - } - - // Extract validation type and values for comparison - Map.Entry> - entry = param.validation().entrySet().iterator().next(); - EmbeddingProvidersConfig.EmbeddingProviderConfig.ValidationType validationType = entry.getKey(); - List validationValues = entry.getValue(); - - // Perform validation based on the validation type - switch (validationType) { - case NUMERIC_RANGE -> { - if (userVectorDimension < validationValues.get(0) - || userVectorDimension > validationValues.get(1)) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "The provided dimension value (%d) is not within the supported numeric range [%d, %d]", - userVectorDimension, validationValues.get(0), validationValues.get(1)); - } - } - case OPTIONS -> { - if (!validationValues.contains(userVectorDimension)) { - String validatedValuesStr = - String.join( - ", ", validationValues.stream().map(Object::toString).toArray(String[]::new)); - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "The provided dimension value '%s' is not within the supported options [%s]", - userVectorDimension, validatedValuesStr); - } - } - } - return userVectorDimension; - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java index 
3b61e0a7e..7f25c46ed 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTableCommandResolver.java @@ -1,8 +1,12 @@ package io.stargate.sgv2.jsonapi.service.resolver; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateTableCommand; +import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.api.model.command.table.definition.PrimaryKey; +import io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype.ComplexTypes; import io.stargate.sgv2.jsonapi.config.DebugModeConfig; import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.SchemaException; @@ -13,13 +17,18 @@ import io.stargate.sgv2.jsonapi.service.operation.tables.KeyspaceDriverExceptionHandler; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiDataType; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @ApplicationScoped public class CreateTableCommandResolver implements CommandResolver { + @Inject ObjectMapper objectMapper; + @Inject VectorizeConfigValidator validateVectorize; + @Override public Operation resolveKeyspaceCommand( CommandContext ctx, CreateTableCommand command) { @@ -30,6 +39,22 @@ public Operation resolveKeyspaceCommand( .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getApiDataType())); List partitionKeys = Arrays.stream(command.definition().primaryKey().keys()).toList(); + Map vectorizeConfigMap = + command.definition().columns().entrySet().stream() + .filter( + e -> + e.getValue() instanceof 
ComplexTypes.VectorType vt + && vt.getVectorConfig() != null) + .collect( + Collectors.toMap( + Map.Entry::getKey, + e -> { + ComplexTypes.VectorType vectorType = ((ComplexTypes.VectorType) e.getValue()); + final VectorizeConfig vectorConfig = vectorType.getVectorConfig(); + validateVectorize.validateService(vectorConfig, vectorType.getDimension()); + return vectorConfig; + })); + if (partitionKeys.isEmpty()) { throw SchemaException.Code.MISSING_PRIMARY_KEYS.get(); } @@ -57,7 +82,17 @@ public Operation resolveKeyspaceCommand( }); // set to empty will be used when vectorize is supported - String comment = ""; + Map customProperties = new HashMap<>(); + try { + customProperties.put("com.datastax.data-api.schema-type", "table"); + // Versioning for schema json. This needs can be adapted in future as needed + customProperties.put("com.datastax.data-api.schema-def-version", "1"); + String vectorizeConfigToStore = objectMapper.writeValueAsString(vectorizeConfigMap); + customProperties.put("com.datastax.data-api.vectorize-config", vectorizeConfigToStore); + } catch (JsonProcessingException e) { + // this should never happen + throw new RuntimeException(e); + } var attempt = new CreateTableAttemptBuilder(0, ctx.schemaObject()) @@ -69,7 +104,7 @@ public Operation resolveKeyspaceCommand( .partitionKeys(partitionKeys) .clusteringKeys(clusteringKeys) .ifNotExists(ifNotExists) - .comment(comment) + .customProperties(customProperties) .build(); var attempts = new OperationAttemptContainer<>(List.of(attempt)); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java new file mode 100644 index 000000000..4e0a38e60 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java @@ -0,0 +1,412 @@ +package io.stargate.sgv2.jsonapi.service.resolver; + +import 
io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; +import io.stargate.sgv2.jsonapi.config.OperationsConfig; +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Class that has validation for vectorize configuration. It is used by both collection and tables + * api. + */ +@ApplicationScoped +public class VectorizeConfigValidator { + private final OperationsConfig operationsConfig; + private final EmbeddingProvidersConfig embeddingProvidersConfig; + private final ValidateCredentials validateCredentials; + + @Inject + public VectorizeConfigValidator( + OperationsConfig operationsConfig, + EmbeddingProvidersConfig embeddingProvidersConfig, + ValidateCredentials validateCredentials) { + this.operationsConfig = operationsConfig; + this.embeddingProvidersConfig = embeddingProvidersConfig; + this.validateCredentials = validateCredentials; + } + + /** + * Validates the user-provided service configuration against internal configurations. It checks + * for the existence and enabled status of the service provider, the necessity of secret names for + * certain authentication types, the validity of provided parameters against expected types, and + * the appropriateness of model dimensions. It ensures that all required and type-specific + * conditions are met for the service to be considered valid. + * + * @param userConfig The user input vectorize service configuration. 
+ * @param userVectorDimension The dimension specified by the user, may be null. + * @return The dimension to be used for the vector, should be from the internal configuration. It + * will be used for auto populate the vector dimension + * @throws JsonApiException If the service configuration is invalid or unsupported. + */ + public Integer validateService(VectorizeConfig userConfig, Integer userVectorDimension) { + // Only for internal tests + if (userConfig.provider().equals(ProviderConstants.CUSTOM)) { + return userVectorDimension; + } + // Check if the service provider exists and is enabled + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig = + getAndValidateProviderConfig(userConfig); + + // Check secret name for shared secret authentication, if applicable + validateAuthentication(userConfig, providerConfig); + + // Validate the model and its vector dimension: + // huggingFaceDedicated: must have vectorDimension specified + // other providers: must have model specified, and default dimension when dimension not + // specified + Integer vectorDimension = + validateModelAndDimension(userConfig, providerConfig, userVectorDimension); + + // Validate user-provided parameters against internal expectations + validateUserParameters(userConfig, providerConfig); + + return vectorDimension; + } + + /** + * Retrieves and validates the provider configuration for vector search based on user input. This + * method ensures that the specified service provider is configured and enabled in the system. + * + * @param userConfig The configuration provided by the user specifying the vector search provider. + * @return The configuration for the embedding provider, if valid. + * @throws JsonApiException If the provider is not supported or not enabled. 
+ */ + private EmbeddingProvidersConfig.EmbeddingProviderConfig getAndValidateProviderConfig( + VectorizeConfig userConfig) { + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig = + embeddingProvidersConfig.providers().get(userConfig.provider()); + if (providerConfig == null || !providerConfig.enabled()) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Service provider '%s' is not supported", userConfig.provider()); + } + return providerConfig; + } + + /** + * Validates user authentication for creating a collection using the specified configurations. + * + *
      + *
    1. Validate that all keys (member names) in the authentication stanza (e.g. providerKey) are + * listed in the configuration for the provider as accepted keys. + *
    2. For each key-value member of the authentication stanza: + *
        + *
      1. If the value does not contain the period character "." it assumes the value is the + * name of the credential without specifying the key. + *
          + *
        1. The credential name is appended with .<key> and the secret service is + * called to validate that a credential with that name exists and it has the + * named key. + *
        + *
      2. If the value does contain a period character "." it assumes the first part is the + * name of the credential and the second the name of the key within it. + *
          + *
        1. The secret service is called to validate that a credential with that name exists + * and it has the named key. + *
        + *
      + *
    + * + * @param userConfig The vectorize configuration provided by the user. + * @param providerConfig The embedding provider configuration. + * @throws JsonApiException If the user authentication is invalid. + */ + private void validateAuthentication( + VectorizeConfig userConfig, EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + // Get all the accepted keys in SHARED_SECRET + Set acceptedKeys = + providerConfig.supportedAuthentications().entrySet().stream() + .filter( + config -> + config + .getKey() + .equals( + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType + .SHARED_SECRET)) + .filter(config -> config.getValue().enabled() && config.getValue().tokens() != null) + .flatMap(config -> config.getValue().tokens().stream()) + .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.TokenConfig::accepted) + .collect(Collectors.toSet()); + + // If the user hasn't provided authentication details, verify that either the 'NONE' or 'HEADER' + // authentication type is enabled. 
+ if (userConfig.authentication() == null || userConfig.authentication().isEmpty()) { + // Check if 'NONE' authentication type is enabled + boolean noneEnabled = + Optional.ofNullable( + providerConfig + .supportedAuthentications() + .get( + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE)) + .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig::enabled) + .orElse(false); + + // Check if 'HEADER' authentication type is enabled + boolean headerEnabled = + Optional.ofNullable( + providerConfig + .supportedAuthentications() + .get( + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType + .HEADER)) + .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig::enabled) + .orElse(false); + + // If neither 'NONE' nor 'HEADER' authentication type is enabled, throw an exception + if (!noneEnabled && !headerEnabled) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Service provider '%s' does not support either 'NONE' or 'HEADER' authentication types.", + userConfig.provider()); + } + } else { + // User has provided authentication details. Validate each key against the provider's accepted + // list. + for (Map.Entry userAuth : userConfig.authentication().entrySet()) { + // Check if the key is accepted by the provider + if (!acceptedKeys.contains(userAuth.getKey())) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Service provider '%s' does not support authentication key '%s'", + userConfig.provider(), userAuth.getKey()); + } + + // Validate the credential name from secret service + // already append the .providerKey to the value in CreateCollectionCommand + if (operationsConfig.enableEmbeddingGateway()) { + validateCredentials.validate(userConfig.provider(), userAuth.getValue()); + } + } + } + } + + /** + * Validates the parameters provided by the user against the expected parameters from both the + * provider and the model configurations. 
This method ensures that only configured parameters are + * provided, all required parameters are included, and no unexpected parameters are passed. + * + * @param userConfig The vector search configuration provided by the user. + * @param providerConfig The configuration of the embedding provider which includes model and + * provider-level parameters. + * @throws JsonApiException if any unconfigured parameters are provided, required parameters are + * missing, or if an error occurs due to no parameters being configured but some are provided + * by the user. + */ + private void validateUserParameters( + VectorizeConfig userConfig, EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + // 0. Combine provider level and model level parameters + List allParameters = + new ArrayList<>(); + // Add all provider level parameters + allParameters.addAll(providerConfig.parameters()); + // Get all the parameters except "vectorDimension" for the model -- model has been validated in + // the previous step, huggingfaceDedicated uses endpoint-defined-model + List modelParameters = + providerConfig.models().stream() + .filter(m -> m.name().equals(userConfig.modelName())) + .findFirst() + .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig::parameters) + .map( + params -> + params.stream() + .filter( + param -> + !param + .name() + .equals( + "vectorDimension")) // Exclude 'vectorDimension' parameter + .collect(Collectors.toList())) + .get(); + // Add all model level parameters + allParameters.addAll(modelParameters); + // 1. Error if the user provided un-configured parameters + // Two level parameters have unique names, should be fine here + Set expectedParamNames = + allParameters.stream() + .map(EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig::name) + .collect(Collectors.toSet()); + Map userParameters = + (userConfig.parameters() != null) ? 
userConfig.parameters() : Collections.emptyMap(); + // Check for unconfigured parameters provided by the user + userParameters + .keySet() + .forEach( + userParamName -> { + if (!expectedParamNames.contains(userParamName)) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Unexpected parameter '%s' for the provider '%s' provided", + userParamName, userConfig.provider()); + } + }); + + // 2. Error if the user doesn't provide required parameters + // Check for missing required parameters and collect them for type validation + List parametersToValidate = + new ArrayList<>(); + allParameters.forEach( + expectedParamConfig -> { + if (expectedParamConfig.required() + && !userParameters.containsKey(expectedParamConfig.name())) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Required parameter '%s' for the provider '%s' missing", + expectedParamConfig.name(), userConfig.provider()); + } + if (userParameters.containsKey(expectedParamConfig.name())) { + parametersToValidate.add(expectedParamConfig); + } + }); + + // 3. Validate parameter types if no errors occurred in previous steps + parametersToValidate.forEach( + expectedParamConfig -> + validateParameterType( + expectedParamConfig, userParameters.get(expectedParamConfig.name()))); + } + + /** + * Validates the type of parameter provided by the user against the expected type defined in the + * provider's configuration. This method checks if the type of the user-provided parameter matches + * the expected type, throwing an exception if there is a mismatch. + * + * @param expectedParamConfig The expected configuration for the parameter which includes its + * expected type. + * @param userParamValue The value of the parameter provided by the user. + * @throws JsonApiException if the type of the parameter provided by the user does not match the + * expected type. 
+ */ + private void validateParameterType( + EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig expectedParamConfig, + Object userParamValue) { + + EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType expectedParamType = + expectedParamConfig.type(); + boolean typeMismatch = + expectedParamType == EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType.STRING + && !(userParamValue instanceof String) + || expectedParamType + == EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType.NUMBER + && !(userParamValue instanceof Number) + || expectedParamType + == EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterType.BOOLEAN + && !(userParamValue instanceof Boolean); + + if (typeMismatch) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "The provided parameter '%s' type is incorrect. Expected: '%s'", + expectedParamConfig.name(), expectedParamType); + } + } + + /** + * Validates the model name and vector dimension provided in the user configuration against the + * specified embedding provider configuration. + * + * @param userConfig the user-specified vectorization configuration + * @param providerConfig the configuration of the embedding provider + * @param userVectorDimension the vector dimension provided by the user, or null if not provided + * @return the validated vector dimension to be used for the model + * @throws JsonApiException if the model name is not found, or if the dimension is invalid + */ + private Integer validateModelAndDimension( + VectorizeConfig userConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + Integer userVectorDimension) { + // Find the model configuration by matching the model name + // 1. 
huggingfaceDedicated does not require model, but requires dimension + if (userConfig.provider().equals(ProviderConstants.HUGGINGFACE_DEDICATED)) { + if (userVectorDimension == null) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "'dimension' is needed for provider %s", ProviderConstants.HUGGINGFACE_DEDICATED); + } + } + // 2. other providers do require model + if (userConfig.modelName() == null) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "'modelName' is needed for provider %s", userConfig.provider()); + } + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model = + providerConfig.models().stream() + .filter(m -> m.name().equals(userConfig.modelName())) + .findFirst() + .orElseThrow( + () -> + ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Model name '%s' for provider '%s' is not supported", + userConfig.modelName(), userConfig.provider())); + + // Handle models with a fixed vector dimension + if (model.vectorDimension().isPresent() && model.vectorDimension().get() != 0) { + Integer configVectorDimension = model.vectorDimension().get(); + if (userVectorDimension == null) { + return configVectorDimension; // Use model's dimension if user hasn't specified any + } else if (!configVectorDimension.equals(userVectorDimension)) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "The provided dimension value '%s' doesn't match the model's supported dimension value '%s'", + userVectorDimension, configVectorDimension); + } + return configVectorDimension; + } + + // Handle models with a range of acceptable dimensions + return model.parameters().stream() + .filter(param -> param.name().equals("vectorDimension")) + .findFirst() + .map(param -> validateRangeDimension(param, userVectorDimension)) + .orElse(userVectorDimension); // should not go here + } + + /** + * Validates the user-provided vector dimension against the dimension parameter's validation + * constraints. 
+ * + * @param param the parameter configuration containing validation constraints + * @param userVectorDimension the vector dimension provided by the user + * @return the appropriate vector dimension based on parameter configuration + * @throws JsonApiException if the user-provided dimension is not valid + */ + private Integer validateRangeDimension( + EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig param, + Integer userVectorDimension) { + // Use the default value if the user has not provided a dimension + if (userVectorDimension == null) { + return Integer.valueOf(param.defaultValue().get()); + } + + // Extract validation type and values for comparison + Map.Entry> + entry = param.validation().entrySet().iterator().next(); + EmbeddingProvidersConfig.EmbeddingProviderConfig.ValidationType validationType = entry.getKey(); + List validationValues = entry.getValue(); + + // Perform validation based on the validation type + switch (validationType) { + case NUMERIC_RANGE -> { + if (userVectorDimension < validationValues.get(0) + || userVectorDimension > validationValues.get(1)) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "The provided dimension value (%d) is not within the supported numeric range [%d, %d]", + userVectorDimension, validationValues.get(0), validationValues.get(1)); + } + } + case OPTIONS -> { + if (!validationValues.contains(userVectorDimension)) { + String validatedValuesStr = + String.join( + ", ", validationValues.stream().map(Object::toString).toArray(String[]::new)); + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "The provided dimension value '%s' is not within the supported options [%s]", + userVectorDimension, validatedValuesStr); + } + } + } + return userVectorDimension; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java index 56d30e4dc..3c054d51c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateCollectionCommand; +import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.constants.TableCommentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; @@ -280,14 +281,14 @@ public static CreateCollectionCommand collectionSettingToCreateCollectionCommand CreateCollectionCommand.Options.IndexingConfig indexingConfig = null; // populate the vectorSearchConfig if (collectionSetting.vectorConfig().vectorEnabled()) { - CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig vectorizeConfig = null; + VectorizeConfig vectorizeConfig = null; if (collectionSetting.vectorConfig().vectorizeConfig() != null) { Map authentication = collectionSetting.vectorConfig().vectorizeConfig().authentication(); Map parameters = collectionSetting.vectorConfig().vectorizeConfig().parameters(); vectorizeConfig = - new CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig( + new VectorizeConfig( collectionSetting.vectorConfig().vectorizeConfig().provider(), collectionSetting.vectorConfig().vectorizeConfig().modelName(), authentication == null ? 
null : Map.copyOf(authentication), diff --git a/src/main/resources/errors.yaml b/src/main/resources/errors.yaml index 667afe0da..1c220c847 100644 --- a/src/main/resources/errors.yaml +++ b/src/main/resources/errors.yaml @@ -281,10 +281,15 @@ request-errors: body: |- Vector column data type definition provided in the request is incorrect. Vector type accepts `dimension`. `dimension` is an integer value. + `service` definition is optional for vector type. It's used if the embedding needs to be done using data-api. Example map type definition: "column_name": { "type": "vector", - "dimension": 1536 + "dimension": 1024, + "service": { + "provider": "nvidia", + "modelName": "NV-Embed-QA" + } } # ================================================================================================================ diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java index e200a8220..83d21cf76 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java @@ -3,6 +3,7 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; import io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype.ColumnType; +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.ArrayList; @@ -46,50 +47,50 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "allTypesTable", - "definition": { - "columns": { - "ascii_type": "ascii", - "bigint_type": "bigint", - "blob_type": "blob", - "boolean_type": "boolean", - "date_type": "date", - "decimal_type": "decimal", - "double_type": "double", - "duration_type": 
"duration", - "float_type": "float", - "inet_type": "inet", - "int_type": "int", - "smallint_type": "smallint", - "text_type": "text", - "time_type": "time", - "timestamp_type": "timestamp", - "tinyint_type": "tinyint", - "uuid_type": "uuid", - "varint_type": "varint", - "map_type": { - "type": "map", - "keyType": "text", - "valueType": "int" - }, - "list_type": { - "type": "list", - "valueType": "text" - }, - "set_type": { - "type": "set", - "valueType": "text" - }, - "vector_type": { - "type": "vector", - "dimension": 5 - } - }, - "primaryKey": "text_type" - } - } - """, + { + "name": "allTypesTable", + "definition": { + "columns": { + "ascii_type": "ascii", + "bigint_type": "bigint", + "blob_type": "blob", + "boolean_type": "boolean", + "date_type": "date", + "decimal_type": "decimal", + "double_type": "double", + "duration_type": "duration", + "float_type": "float", + "inet_type": "inet", + "int_type": "int", + "smallint_type": "smallint", + "text_type": "text", + "time_type": "time", + "timestamp_type": "timestamp", + "tinyint_type": "tinyint", + "uuid_type": "uuid", + "varint_type": "varint", + "map_type": { + "type": "map", + "keyType": "text", + "valueType": "int" + }, + "list_type": { + "type": "list", + "valueType": "text" + }, + "set_type": { + "type": "set", + "valueType": "text" + }, + "vector_type": { + "type": "vector", + "dimension": 5 + } + }, + "primaryKey": "text_type" + } + } + """, "allTypesTable", false, null, @@ -99,24 +100,24 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "primaryKeyAsStringTable", - "definition": { - "columns": { - "id": { - "type": "text" - }, - "age": { - "type": "int" - }, - "name": { - "type": "text" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "primaryKeyAsStringTable", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "name": { + "type": "text" + } + }, + "primaryKey": "id" + } + } + """, 
"primaryKeyAsStringTable", false, null, @@ -126,24 +127,24 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "primaryKeyAsStringTable", - "definition": { - "columns": { - "id": { - "type": "text" - }, - "age": { - "type": "int" - }, - "name": { - "type": "text" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "primaryKeyAsStringTable", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "name": { + "type": "text" + } + }, + "primaryKey": "id" + } + } + """, "primaryKeyAsStringTable", false, null, @@ -154,21 +155,21 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "primaryKeyWithQuotable", - "definition": { - "primaryKey": "_id", - "columns": { - "_id": { - "type": "text" - }, - "name": { - "type": "text" - } - } - } - } - """, + { + "name": "primaryKeyWithQuotable", + "definition": { + "primaryKey": "_id", + "columns": { + "_id": { + "type": "text" + }, + "name": { + "type": "text" + } + } + } + } + """, "primaryKeyWithQuotable", false, null, @@ -179,18 +180,18 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "columnTypeusingShortHandTable", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text" - }, - "primaryKey": "id" - } - } - """, + { + "name": "columnTypeusingShortHandTable", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text" + }, + "primaryKey": "id" + } + } + """, "columnTypeusingShortHandTable", false, null, @@ -201,31 +202,31 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "primaryKeyAsJsonObjectTable", - "definition": { - "columns": { - "id": { - "type": "text" - }, - "age": { - "type": "int" - }, - "name": { - "type": "text" - } - }, - "primaryKey": { - "partitionBy": [ - "id" - ], - "partitionSort" : { - "name" : 1, "age" : -1 - } - } - } - } 
- """, + { + "name": "primaryKeyAsJsonObjectTable", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "name": { + "type": "text" + } + }, + "primaryKey": { + "partitionBy": [ + "id" + ], + "partitionSort" : { + "name" : 1, "age" : -1 + } + } + } + } + """, "primaryKeyAsJsonObjectTable", false, null, @@ -238,24 +239,24 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidPrimaryKeyTable", - "definition": { - "columns": { - "id": { - "type": "text" - }, - "age": { - "type": "int" - }, - "name": { - "type": "text" - } - }, - "primaryKey": "error_column" - } - } - """, + { + "name": "invalidPrimaryKeyTable", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "name": { + "type": "text" + } + }, + "primaryKey": "error_column" + } + } + """, "invalidPrimaryKeyTable", true, missingDefinition.code, @@ -266,25 +267,25 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidPartitionByTable", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text" - }, - "primaryKey": { - "partitionBy": [ - "error_column" - ], - "partitionSort" : { - "name" : 1, "age" : -1 - } - } - } - } - """, + { + "name": "invalidPartitionByTable", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text" + }, + "primaryKey": { + "partitionBy": [ + "error_column" + ], + "partitionSort" : { + "name" : 1, "age" : -1 + } + } + } + } + """, "invalidPartitionByTable", true, missingDefinition.code, @@ -295,25 +296,25 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidPartitionSortTable", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text" - }, - "primaryKey": { - "partitionBy": [ - "id" - ], - "partitionSort" : { - "error_column" : 1, "age" : -1 - } - } - } - } - """, + { + 
"name": "invalidPartitionSortTable", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text" + }, + "primaryKey": { + "partitionBy": [ + "id" + ], + "partitionSort" : { + "error_column" : 1, "age" : -1 + } + } + } + } + """, "invalidPartitionSortTable", true, missingDefinition.code, @@ -325,25 +326,25 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidPartitionSortOrderingValueTable", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text" - }, - "primaryKey": { - "partitionBy": [ - "id" - ], - "partitionSort" : { - "id" : 1, "age" : 0 - } - } - } - } - """, + { + "name": "invalidPartitionSortOrderingValueTable", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text" + }, + "primaryKey": { + "partitionBy": [ + "id" + ], + "partitionSort" : { + "id" : 1, "age" : 0 + } + } + } + } + """, "invalidPartitionSortOrderingValueTable", true, se.code, @@ -354,25 +355,25 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidPartitionSortOrderingValueTypeTable", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text" - }, - "primaryKey": { - "partitionBy": [ - "id" - ], - "partitionSort" : { - "id" : 1, "age" : "invalid" - } - } - } - } - """, + { + "name": "invalidPartitionSortOrderingValueTypeTable", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text" + }, + "primaryKey": { + "partitionBy": [ + "id" + ], + "partitionSort" : { + "id" : 1, "age" : "invalid" + } + } + } + } + """, "invalidPartitionSortOrderingValueTypeTable", true, se.code, @@ -391,26 +392,26 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidColumnTypeTable", - "definition": { - "columns": { - "id": "invalid_type", - "age": "int", - "name": "text" - }, - "primaryKey": { - "partitionBy": [ - "id" - ], - 
"partitionSort": { - "id": 1, - "age": -1 - } - } - } - } - """, + { + "name": "invalidColumnTypeTable", + "definition": { + "columns": { + "id": "invalid_type", + "age": "int", + "name": "text" + }, + "primaryKey": { + "partitionBy": [ + "id" + ], + "partitionSort": { + "id": 1, + "age": -1 + } + } + } + } + """, "invalidColumnTypeTable", true, invalidType.code, @@ -422,26 +423,26 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidColumnTypeTable", - "definition": { - "columns": { - "id": null, - "age": "int", - "name": "text" - }, - "primaryKey": { - "partitionBy": [ - "id" - ], - "partitionSort": { - "id": 1, - "age": -1 - } - } - } - } - """, + { + "name": "invalidColumnTypeTable", + "definition": { + "columns": { + "id": null, + "age": "int", + "name": "text" + }, + "primaryKey": { + "partitionBy": [ + "id" + ], + "partitionSort": { + "id": 1, + "age": -1 + } + } + } + } + """, "invalidColumnTypeTable", true, columnTypeInvalid.code, @@ -453,22 +454,22 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidMapType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "map_type": { - "type": "map", - "keyType": "text" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidMapType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "map_type": { + "type": "map", + "keyType": "text" + } + }, + "primaryKey": "id" + } + } + """, "invalidMapType value type not provided", true, invalidMapType.code, @@ -478,22 +479,22 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidMapType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "map_type": { - "type": "map", - "valueType": "text" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidMapType", + "definition": { + "columns": { + "id": 
"text", + "age": "int", + "name": "text", + "map_type": { + "type": "map", + "valueType": "text" + } + }, + "primaryKey": "id" + } + } + """, "invalidMapType key type not provided", true, invalidMapType.code, @@ -503,23 +504,23 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidMapType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "map_type": { - "type": "map", - "valueType": "list", - "keyType": "text" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidMapType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "map_type": { + "type": "map", + "valueType": "list", + "keyType": "text" + } + }, + "primaryKey": "id" + } + } + """, "invalidMapType not primitive type provided", true, invalidMapType.code, @@ -529,23 +530,23 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidMapType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "map_type": { - "type": "map", - "valueType": "text", - "keyType": "list" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidMapType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "map_type": { + "type": "map", + "valueType": "text", + "keyType": "list" + } + }, + "primaryKey": "id" + } + } + """, "invalidMapType not primitive type provided", true, invalidMapType.code, @@ -557,21 +558,21 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidListType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "list_type": { - "type": "list" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidListType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "list_type": { + "type": "list" + } + }, + "primaryKey": "id" + } + } + 
""", "invalidListType value type not provided", true, invalidListType.code, @@ -581,22 +582,22 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidListType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "list_type": { - "type": "list", - "valueType": "list" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidListType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "list_type": { + "type": "list", + "valueType": "list" + } + }, + "primaryKey": "id" + } + } + """, "invalidListType not primitive type provided", true, invalidListType.code, @@ -608,21 +609,21 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidSetType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "set_type": { - "type": "set" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidSetType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "set_type": { + "type": "set" + } + }, + "primaryKey": "id" + } + } + """, "invalidSetType value type not provided", true, invalidSetType.code, @@ -632,22 +633,22 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidSetType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "set_type": { - "type": "set", - "valueType": "list" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidSetType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "set_type": { + "type": "set", + "valueType": "list" + } + }, + "primaryKey": "id" + } + } + """, "invalidSetType not primitive type provided", true, invalidSetType.code, @@ -660,22 +661,22 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidVectorType", - 
"definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "vector_type": { - "type": "vector", - "dimension": -5 - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidVectorType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "vector_type": { + "type": "vector", + "dimension": -5 + } + }, + "primaryKey": "id" + } + } + """, "invalidVectorType value type not provided", true, invalidVectorType.code, @@ -685,27 +686,157 @@ private static Stream allTableData() { Arguments.of( new CreateTableTestData( """ - { - "name": "invalidVectorType", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "vector_type": { - "type": "vector", - "dimension": "aaa" - } - }, - "primaryKey": "id" - } - } - """, + { + "name": "invalidVectorType", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "vector_type": { + "type": "vector", + "dimension": "aaa" + } + }, + "primaryKey": "id" + } + } + """, "invalidVectorType not primitive type provided", true, invalidVectorType.code, invalidVectorType.body))); + // vector type with vectorize + testCases.add( + Arguments.of( + new CreateTableTestData( + """ + { + "name": "vectorizeConfigTest", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "content": { + "type": "vector", + "dimension": 1024, + "service": { + "provider": "nvidia", + "modelName": "NV-Embed-QA" + } + } + }, + "primaryKey": "id" + } + } + """, + "primaryKeyAsStringTable", + false, + null, + null))); + + // vector type with invalid vectorixe config + testCases.add( + Arguments.of( + new CreateTableTestData( + """ + { + "name": "invalidVectorizeServiceNameConfig", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "content": { + "type": "vector", + "dimension": 1024, + "service": { + "provider": "invalid_service", + "modelName": "NV-Embed-QA" + } + 
} + }, + "primaryKey": "id" + } + } + """, + "invalidVectorizeServiceNameConfig", + true, + ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.name(), + "The provided options are invalid: Service provider 'invalid_service' is not supported"))); + + // vector type with invalid model name config + testCases.add( + Arguments.of( + new CreateTableTestData( + """ + { + "name": "invalidVectorizeModelNameConfig", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "content": { + "type": "vector", + "dimension": 1024, + "service": { + "provider": "mistral", + "modelName": "mistral-embed-invalid" + } + } + }, + "primaryKey": "id" + } + } + """, + "invalidVectorizeModelNameConfig", + true, + ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.name(), + "The provided options are invalid: Model name 'mistral-embed-invalid' for provider 'mistral' is not supported"))); + // vector type with dimension mismatch + testCases.add( + Arguments.of( + new CreateTableTestData( + """ + { + "name": "invalidVectorizeModelNameConfig", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "age": { + "type": "int" + }, + "content": { + "type": "vector", + "dimension": 1536, + "service": { + "provider": "mistral", + "modelName": "mistral-embed" + } + } + }, + "primaryKey": "id" + } + } + """, + "invalidVectorizeModelNameConfig", + true, + ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.name(), + "The provided options are invalid: The provided dimension value '1536' doesn't match the model's supported dimension value '1024'"))); return testCases.stream(); } }