Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling for multi part credentials in createCollection #1164

Merged
merged 12 commits into from
Jun 13, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,7 @@ public enum ErrorCode {

VECTORIZE_FEATURE_NOT_AVAILABLE("Vectorize feature is not available in the environment"),
VECTORIZE_SERVICE_NOT_REGISTERED("Vectorize service name provided is not registered : "),

VECTORIZE_SERVICE_TYPE_UNSUPPORTED("Vectorize service type unsupported "),

VECTORIZE_SERVICE_TYPE_UNAVAILABLE("Vectorize service unavailable : "),
VECTORIZE_INVALID_SHARED_KEY_VALUE_FORMAT("Invalid authentication value format"),
VECTORIZE_USAGE_ERROR("Vectorize search can't be used with other sort clause"),
VECTORIZE_INVALID_AUTHENTICATION_TYPE("Invalid vectorize authentication type"),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ private Integer validateService(
validateAuthentication(userConfig, providerConfig);

// Validate the model and its vector dimension:
// huggingFaceDedicated: must have vectorDimension specified
// other providers: must have model specified, and default dimension when dimension not
// huggingFaceDedicated: must have vectorDimension specified
// other providers: must have model specified, and default dimension when dimension not
// specified
Integer vectorDimension =
validateModelAndDimension(userConfig, providerConfig, userVectorDimension);
Expand Down Expand Up @@ -294,36 +294,74 @@ private EmbeddingProvidersConfig.EmbeddingProviderConfig getAndValidateProviderC
/**
* Validates user authentication for creating a collection using the specified configurations.
*
* <ol>
* <li>Validate that all keys (member names) in the authentication stanza (e.g. providerKey) are
* listed in the configuration for the provider as accepted keys.
* <li>For each key-value member of the authentication stanza:
* <ol type="a">
* <li>If the value does not contain the period character "." it assumes the value is the
* name of the credential without specifying the key.
* <ol type="i">
* <li>The credential name is appended with .&lt;key&gt; and the secret service
* called to validate that a credential with that name exists and it has the
* named key.
* </ol>
* <li>If the value does contain a period character "." it assumes the first part is the
* name of the credential and the second the name of the key within it.
* <ol type="i">
* <li>The secret service called to validate that a credential with that name exists
* and it has the named key.
* </ol>
* </ol>
* </ol>
*
* @param userConfig The vectorize configuration provided by the user.
* @param providerConfig The embedding provider configuration.
* @throws JsonApiException If the user authentication is invalid.
*/
private void validateAuthentication(
CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig,
EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) {
// Get all the accepted keys in auth
List<String> acceptedKeys =
providerConfig.supportedAuthentications().values().stream()
.filter(config -> config.enabled() && config.tokens() != null)
.flatMap(config -> config.tokens().stream())
// Get all the accepted keys in SHARED_SECRET
Set<String> acceptedKeys =
providerConfig.supportedAuthentications().entrySet().stream()
.filter(
config ->
config
.getKey()
.equals(
EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType
.SHARED_SECRET))
.filter(config -> config.getValue().enabled() && config.getValue().tokens() != null)
.flatMap(config -> config.getValue().tokens().stream())
.map(EmbeddingProvidersConfig.EmbeddingProviderConfig.TokenConfig::accepted)
.toList();
.collect(Collectors.toSet());

// If the user hasn't provided authentication details, verify that either the 'NONE' or 'HEADER'
// authentication type is enabled.
if (userConfig.authentication() == null || userConfig.authentication().isEmpty()) {
EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig noneAuthConfig =
providerConfig
.supportedAuthentications()
.get(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE);
EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig headerAuthConfig =
providerConfig
.supportedAuthentications()
.get(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.HEADER);

// Check if either 'NONE' or 'HEADER' authentication type is enabled
boolean noneEnabled = (noneAuthConfig != null && noneAuthConfig.enabled());
boolean headerEnabled = (headerAuthConfig != null && headerAuthConfig.enabled());
// Check if 'NONE' authentication type is enabled
boolean noneEnabled =
Optional.ofNullable(
providerConfig
.supportedAuthentications()
.get(
EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE))
.map(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig::enabled)
.orElse(false);

// Check if 'HEADER' authentication type is enabled
boolean headerEnabled =
Optional.ofNullable(
providerConfig
.supportedAuthentications()
.get(
EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType
.HEADER))
.map(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig::enabled)
.orElse(false);

// If neither 'NONE' nor 'HEADER' authentication type is enabled, throw an exception
if (!noneEnabled && !headerEnabled) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Service provider '%s' does not support either 'NONE' or 'HEADER' authentication types.",
Expand All @@ -333,34 +371,41 @@ private void validateAuthentication(
// User has provided authentication details. Validate each key against the provider's accepted
// list.
for (Map.Entry<String, String> userAuth : userConfig.authentication().entrySet()) {
// Check if the key is accepted by the provider
if (!acceptedKeys.contains(userAuth.getKey())) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"Service provider '%s' does not support authentication key '%s'",
userConfig.provider(), userAuth.getKey());
} else {
if (userAuth.getKey().equals("providerKey")) {
// sharedKeyValue must be in the format of "keyName.providerKey"
String sharedKeyValue = userAuth.getValue();
if (sharedKeyValue == null || sharedKeyValue.isEmpty()) {
throw ErrorCode.VECTORIZE_INVALID_SHARED_KEY_VALUE_FORMAT.toApiException(
"missing value");
}
int dotIndex = sharedKeyValue.lastIndexOf('.');
if (dotIndex <= 0) {
throw ErrorCode.VECTORIZE_INVALID_SHARED_KEY_VALUE_FORMAT.toApiException(
"providerKey value should be formatted as '[keyName].providerKey'");
}

String providerKeyString = sharedKeyValue.substring(dotIndex + 1);
if (!"providerKey".equals(providerKeyString)) {
throw ErrorCode.VECTORIZE_INVALID_SHARED_KEY_VALUE_FORMAT.toApiException(
"providerKey value should be formatted as '[keyName].providerKey'");
}
}
if (operationsConfig.enableEmbeddingGateway()) {
validateCredentials.validate(userConfig.provider(), userAuth.getValue());
}

// Get the full credential name by either appending the key(no dot) to the value or using
// the value(has dot)
String sharedKeyValue = userAuth.getValue();
String credentialName =
sharedKeyValue.lastIndexOf('.') <= 0
? sharedKeyValue + "." + userAuth.getKey()
: sharedKeyValue;

// If the value contains a period character, validate the key name
if (sharedKeyValue.lastIndexOf('.') > 0) {
String keyName = sharedKeyValue.substring(sharedKeyValue.lastIndexOf('.') + 1);
if (!keyName.equals(userAuth.getKey())) {
// If the key name does not match, throw an exception
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
String.format(
"Unexpected credential name '%s'. The format should be '%s' or '%s'.",
sharedKeyValue,
sharedKeyValue.substring(0, sharedKeyValue.lastIndexOf('.')),
sharedKeyValue.substring(0, sharedKeyValue.lastIndexOf('.'))
+ "."
+ userAuth.getKey()));
}
}

// Validate the credential name from secret service
if (operationsConfig.enableEmbeddingGateway()) {
validateCredentials.validate(userConfig.provider(), credentialName);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,6 @@ public void happyPathVectorizeSearch() throws Exception {
"service": {
"provider": "openai",
"modelName": "text-embedding-ada-002",
"authentication": {
"x-embedding-api-key": "user_key"
},
"parameters": {
"projectId": "test project"
}
Expand All @@ -491,9 +488,6 @@ public void happyPathVectorizeSearch() throws Exception {
Map<String, Object> parameterMap = new HashMap<>();
parameterMap.put("projectId", "test project");

Map<String, Object> authenticationMap = new HashMap<>();
authenticationMap.put("x-embedding-api-key", "user_key");

assertThat(result)
.isInstanceOfSatisfying(
CreateCollectionCommand.class,
Expand All @@ -509,10 +503,6 @@ public void happyPathVectorizeSearch() throws Exception {
.isEqualTo("openai");
assertThat(createCollection.options().vector().vectorizeConfig().modelName())
.isEqualTo("text-embedding-ada-002");
assertThat(createCollection.options().vector().vectorizeConfig().authentication())
.isNotNull();
assertThat(createCollection.options().vector().vectorizeConfig().authentication())
.isEqualTo(authenticationMap);
assertThat(createCollection.options().vector().vectorizeConfig().parameters())
.isNotNull();
assertThat(createCollection.options().vector().vectorizeConfig().parameters())
Expand Down
Loading