Skip to content

Commit

Permalink
Add parameters verification in createCollection (#1089)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hazel-Datastax authored May 16, 2024
1 parent 2bfe623 commit f59d9ca
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,19 @@ public static String generateComment(
return tableCommentNode.toString();
}

/**
* Validates the vector search options provided in a create collection command. It checks if
* vector search is enabled globally, and validates the specific vectorization service
* configuration provided by the user. It also ensures the specified vector dimension complies
* with config limits.
*
* @param vector The vector search configuration provided by the user in the create collection
* command.
* @return The validated vector search configuration, possibly modified to include the default
* vector dimension.
* @throws JsonApiException If vector search is disabled globally or the user configuration is
* invalid.
*/
private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions(
CreateCollectionCommand.Options.VectorSearchConfig vector) {
if (!dataStoreConfig.vectorSearchEnabled()) {
Expand Down Expand Up @@ -247,13 +260,25 @@ private Integer validateService(
// Check secret name for shared secret authentication, if applicable
validateAuthentication(userConfig, providerConfig);

// Validate the model and its vector dimension; if userVectorDimension is null, the returned
// value will be the config/default value
Integer vectorDimension =
validateModelAndDimension(userConfig, providerConfig, userVectorDimension);

// Validate user-provided parameters against internal expectations
validateUserParameters(userConfig, providerConfig);

// Validate the model and its vector dimension
return validateModelAndDimension(userConfig, providerConfig, userVectorDimension);
return vectorDimension;
}

/**
* Retrieves and validates the provider configuration for vector search based on user input. This
* method ensures that the specified service provider is configured and enabled in the system.
*
* @param userConfig The configuration provided by the user specifying the vector search provider.
* @return The configuration for the embedding provider, if valid.
* @throws JsonApiException If the provider is not supported or not enabled.
*/
private EmbeddingProvidersConfig.EmbeddingProviderConfig getAndValidateProviderConfig(
CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig) {
EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig =
Expand All @@ -270,7 +295,7 @@ private EmbeddingProvidersConfig.EmbeddingProviderConfig getAndValidateProviderC
*
* @param userConfig The vectorize configuration provided by the user.
* @param providerConfig The embedding provider configuration.
* @throws ApiException If the user authentication is invalid.
* @throws JsonApiException If the user authentication is invalid.
*/
private void validateAuthentication(
CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig,
Expand Down Expand Up @@ -319,25 +344,53 @@ private void validateAuthentication(
}
}

/**
* Validates the parameters provided by the user against the expected parameters from both the
* provider and the model configurations. This method ensures that only configured parameters are
* provided, all required parameters are included, and every provided parameter has the expected
* type.
*
* @param userConfig The vector search configuration provided by the user.
* @param providerConfig The configuration of the embedding provider which includes model and
* provider-level parameters.
* @throws JsonApiException if the user provides parameters that are not configured (including
* when no parameters are configured at all), or if a required parameter is missing.
*/
private void validateUserParameters(
CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig,
EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) {
// 0. Combine provider level and model level parameters
List<EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig> allParameters =
new ArrayList<>();
// Add all provider level parameters
allParameters.addAll(providerConfig.parameters());
// Get all the parameters except "vectorDimension" for the model -- model has been validated in
// the previous step
List<EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig> modelParameters =
providerConfig.models().stream()
.filter(m -> m.name().equals(userConfig.modelName()))
.findFirst()
.map(EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig::parameters)
.map(
params ->
params.stream()
.filter(
param ->
!param
.name()
.equals(
"vectorDimension")) // Exclude 'vectorDimension' parameter
.collect(Collectors.toList()))
.get();
// Add all model level parameters
allParameters.addAll(modelParameters);

// 1. Error if the user provided unconfigured parameters
if (providerConfig.parameters() == null || providerConfig.parameters().isEmpty()) {
// If providerConfig.parameters() is null or empty but the user still provides parameters,
// it's an error
if (userConfig.parameters() != null && !userConfig.parameters().isEmpty()) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Parameters provided but the provider '%s' expects none", userConfig.provider());
}
// Exit early if no parameters are configured
return;
}
// Provider-level and model-level parameters have unique names, so collecting them into one set is safe
Set<String> expectedParamNames =
providerConfig.parameters().stream()
allParameters.stream()
.map(EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig::name)
.collect(Collectors.toSet());

Map<String, Object> userParameters =
(userConfig.parameters() != null) ? userConfig.parameters() : Collections.emptyMap();
// Check for unconfigured parameters provided by the user
Expand All @@ -356,20 +409,18 @@ private void validateUserParameters(
// Check for missing required parameters and collect them for type validation
List<EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig> parametersToValidate =
new ArrayList<>();
providerConfig
.parameters()
.forEach(
expectedParamConfig -> {
if (expectedParamConfig.required()
&& !userParameters.containsKey(expectedParamConfig.name())) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Required parameter '%s' for the provider '%s' missing",
expectedParamConfig.name(), userConfig.provider());
}
if (userParameters.containsKey(expectedParamConfig.name())) {
parametersToValidate.add(expectedParamConfig);
}
});
allParameters.forEach(
expectedParamConfig -> {
if (expectedParamConfig.required()
&& !userParameters.containsKey(expectedParamConfig.name())) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Required parameter '%s' for the provider '%s' missing",
expectedParamConfig.name(), userConfig.provider());
}
if (userParameters.containsKey(expectedParamConfig.name())) {
parametersToValidate.add(expectedParamConfig);
}
});

// 3. Validate parameter types if no errors occurred in previous steps
parametersToValidate.forEach(
Expand All @@ -378,6 +429,17 @@ private void validateUserParameters(
expectedParamConfig, userParameters.get(expectedParamConfig.name())));
}

/**
* Validates the type of parameter provided by the user against the expected type defined in the
* provider's configuration. This method checks if the type of the user-provided parameter matches
* the expected type, throwing an exception if there is a mismatch.
*
* @param expectedParamConfig The expected configuration for the parameter which includes its
* expected type.
* @param userParamValue The value of the parameter provided by the user.
* @throws JsonApiException if the type of the parameter provided by the user does not match the
* expected type.
*/
private void validateParameterType(
EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig expectedParamConfig,
Object userParamValue) {
Expand Down Expand Up @@ -409,9 +471,8 @@ private void validateParameterType(
* @param providerConfig the configuration of the embedding provider
* @param userVectorDimension the vector dimension provided by the user, or null if not provided
* @return the validated vector dimension to be used for the model
* @throws ApiException if the model name is not found, or if the dimension is invalid
* @throws JsonApiException if the model name is not found, or if the dimension is invalid
*/
// TODO: check model parameters provided by the user, will support in the future
private Integer validateModelAndDimension(
CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig,
EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig,
Expand Down Expand Up @@ -455,7 +516,7 @@ private Integer validateModelAndDimension(
* @param param the parameter configuration containing validation constraints
* @param userVectorDimension the vector dimension provided by the user
* @return the appropriate vector dimension based on parameter configuration
* @throws ApiException if the user-provided dimension is not valid
* @throws JsonApiException if the user-provided dimension is not valid
*/
private Integer validateRangeDimension(
EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig param,
Expand Down
Loading

0 comments on commit f59d9ca

Please sign in to comment.