diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java index 624a32671a..5b825ba1d3 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java @@ -7,7 +7,6 @@ public interface EmbeddingProviderConfigStore { record ServiceConfig( String serviceName, String serviceProvider, - String apiKey, String baseUrl, // `implementationClass` is the custom class that implements the EmbeddingProvider interface Optional> implementationClass, @@ -16,21 +15,14 @@ record ServiceConfig( public static ServiceConfig provider( String serviceName, String serviceProvider, - String apiKey, String baseUrl, RequestProperties requestConfiguration) { - return new ServiceConfig( - serviceName, serviceProvider, apiKey, baseUrl, null, requestConfiguration); + return new ServiceConfig(serviceName, serviceProvider, baseUrl, null, requestConfiguration); } public static ServiceConfig custom(Optional> implementationClass) { return new ServiceConfig( - ProviderConstants.CUSTOM, - ProviderConstants.CUSTOM, - null, - null, - implementationClass, - null); + ProviderConstants.CUSTOM, ProviderConstants.CUSTOM, null, implementationClass, null); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfig.java index ab2eb79313..010eea773a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfig.java @@ -26,11 +26,52 @@ interface EmbeddingProviderConfig { @JsonProperty String url(); + /** + * A map of supported authentications. HEADER, SHARED_SECRET and NONE are the only techniques + * the DataAPI supports (i.e. the key of map can only be HEADER, SHARED_SECRET or NONE). + * + * @return + */ @JsonProperty - String apiKey(); + Map supportedAuthentications(); - @JsonProperty - List supportedAuthentication(); + enum AuthenticationType { + NONE, + HEADER, + SHARED_SECRET + } + + /** + * enabled() is a JSON boolean to flag if this technique is supported. If false the rest of the + * object has no impact. Any technique not listed is also not supported for the provider. + * + *

tokens() is a list of token mappings, that map from the name accepted by the Data API to + * how they are forwarded to the provider. The provider information is included for the code, + * and to allow users to see what we do with the values. + */ + interface AuthenticationConfig { + @JsonProperty + boolean enabled(); + + @JsonProperty + List tokens(); + } + + /** + * For the HEADER technique the `accepted` value is the name of the header the client should + * send, and `forwarded` is the name of the header the Data API will send to the provider. + * + *

For the SHARED_SECRET technique the `accepted` value is the name used in the + * authentication option with createCollection that maps to the name of a shared secret, and + * `forwarded` is the name of the header the Data API will send to the provider. + */ + interface TokenConfig { + @JsonProperty + String accepted(); + + @JsonProperty + String forwarded(); + } /** * A list of parameters for user customization. Parameters are used to construct the URL or to @@ -60,8 +101,15 @@ interface ModelConfig { @JsonProperty String name(); + /** + * vectorDimension is not null if the model supports a single dimension value. It will be null + * if the model supports different dimensions. A parameter called vectorDimension is included. + * + * @return + */ + @Nullable @JsonProperty - Integer vectorDimension(); + Optional vectorDimension(); @JsonProperty List parameters(); @@ -84,11 +132,46 @@ interface ParameterConfig { @JsonProperty Optional defaultValue(); + /** + * validation is an object that describes how the Data API will validate the parameters, and + * how the UI may want to provide data entry hints. Only one of the validation methods will be + * specified for each parameter. + * + *

`numericRange` if present is an array of two numbers that represent the inclusive value + * range for a number parameter. E.g. the dimensions for the text-embedding-3 + * + *

`options` if present is an array of valid options the user must select from, for example + * if a model supports 3 different dimensions. If options are present the only allowed values + * for the parameter are those in the options list. If not present, null, or an empty array + * any value of the correct type is accepted. + * + * @return + */ + @Nullable + @JsonProperty + Map> validation(); + @Nullable @JsonProperty Optional help(); } + enum ValidationType { + NUMERIC_RANGE("numericRange"), + OPTIONS("options"); + + private final String type; + + ValidationType(final String type) { + this.type = type; + } + + @Override + public String toString() { + return type; + } + } + /** A set of http properties used for request to the embedding providers. */ interface RequestProperties { @@ -115,12 +198,32 @@ interface RequestProperties { */ @WithDefault("10000") int requestTimeoutMillis(); + + @Nullable + Optional maxInputLength(); + + @Nullable + Optional taskTypeStore(); + + @Nullable + Optional taskTypeRead(); } enum ParameterType { - STRING, - NUMBER, - BOOLEAN + STRING("string"), + NUMBER("number"), + BOOLEAN("boolean"); + + private final String type; + + ParameterType(final String type) { + this.type = type; + } + + @Override + public String toString() { + return type; + } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java index d4e67f96c4..19f94abd07 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java @@ -29,7 +29,6 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration( return ServiceConfig.provider( serviceName, serviceName, - config.providers().get(serviceName).apiKey(), config.providers().get(serviceName).url().toString(), RequestProperties.of( config.providers().get(serviceName).properties().maxRetries(), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java index 2af2db95ee..1860b22d78 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java @@ -40,7 +40,6 @@ public class EmbeddingGatewayClient implements EmbeddingProvider { * @param dimension - Dimension of the embedding to be returned * @param tenant - Tenant id {aka database id} * @param baseUrl - base url of the embedding client - * @param apiKey - Api key for the embedding provider * @param modelName - Model name for the embedding provider * @param embeddingService - Embedding service client * @param vectorizeServiceParameter - Additional parameters for the vectorize service @@ -51,7 +50,6 @@ public EmbeddingGatewayClient( int dimension, Optional tenant, String baseUrl, - String apiKey, String modelName, EmbeddingService embeddingService, Map vectorizeServiceParameter) { @@ -59,7 +57,6 @@ public EmbeddingGatewayClient( this.provider = provider; this.dimension = dimension; this.tenant = tenant; - this.apiKey = apiKey; this.modelName = modelName; this.baseUrl = baseUrl; this.embeddingService = embeddingService; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingClient.java index 69b55684b1..669771742d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingClient.java @@ -30,7 +30,6 @@ */ public class CohereEmbeddingClient implements EmbeddingProvider { private EmbeddingProviderConfigStore.RequestProperties requestProperties; - private String apiKey; private String modelName; private String baseUrl; private final CohereEmbeddingProvider embeddingProvider; @@ -39,11 +38,9 @@ public class CohereEmbeddingClient implements EmbeddingProvider { public CohereEmbeddingClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String apiKey, String modelName, Map vectorizeServiceParameters) { this.requestProperties = requestProperties; - this.apiKey = apiKey; this.modelName = modelName; this.baseUrl = baseUrl; this.vectorizeServiceParameters = vectorizeServiceParameters; @@ -103,8 +100,7 @@ public Uni> vectorize( new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); Uni response = embeddingProvider - .embed( - "Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), request) + .embed("Bearer " + apiKeyOverride.get(), request) .onFailure( throwable -> { return (throwable.getCause() != null diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java index 2300f3acfd..715e7bf115 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java @@ -28,7 +28,6 @@ private interface ProviderConstructor { EmbeddingProvider create( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String apiKey, String modelName, Map vectorizeServiceParameter); } @@ -66,7 +65,6 @@ private synchronized EmbeddingProvider addService( dimension, tenant, configuration.baseUrl(), - configuration.apiKey(), modelName, embeddingService, vectorizeServiceParameter); @@ -98,7 +96,6 @@ private synchronized EmbeddingProvider addService( .create( configuration.requestConfiguration(), configuration.baseUrl(), - configuration.apiKey(), modelName, vectorizeServiceParameter); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingClient.java index abbdbac543..d14cdfebe2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingClient.java @@ -26,9 +26,7 @@ public class HuggingFaceEmbeddingClient implements EmbeddingProvider { private EmbeddingProviderConfigStore.RequestProperties requestProperties; - private String apiKey; private String modelName; - private String baseUrl; private final HuggingFaceEmbeddingProvider embeddingProvider; private Map vectorizeServiceParameters; @@ -36,11 +34,9 @@ public class HuggingFaceEmbeddingClient implements EmbeddingProvider { public HuggingFaceEmbeddingClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String apiKey, String modelName, Map vectorizeServiceParameters) { this.requestProperties = requestProperties; - this.apiKey = apiKey; this.modelName = modelName; this.baseUrl = baseUrl; this.vectorizeServiceParameters = vectorizeServiceParameters; @@ -79,10 +75,7 @@ public Uni> vectorize( EmbeddingRequestType embeddingRequestType) { EmbeddingRequest request = new EmbeddingRequest(texts, new EmbeddingRequest.Options(true)); return embeddingProvider - .embed( - "Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), - modelName, - request) + .embed("Bearer " + apiKeyOverride.get(), modelName, request) .onFailure( throwable -> { return (throwable.getCause() != null diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NVidiaEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NVidiaEmbeddingClient.java index a6eccef6a2..f09d4b9bc0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NVidiaEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NVidiaEmbeddingClient.java @@ -30,7 +30,6 @@ */ public class NVidiaEmbeddingClient implements EmbeddingProvider { private EmbeddingProviderConfigStore.RequestProperties requestProperties; - private String apiKey; private String modelName; private String baseUrl; private final NVidiaEmbeddingProvider embeddingProvider; @@ -40,11 +39,9 @@ public class NVidiaEmbeddingClient implements EmbeddingProvider { public NVidiaEmbeddingClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String apiKey, String modelName, Map vectorizeServiceParameters) { this.requestProperties = requestProperties; - this.apiKey = apiKey; this.modelName = modelName; this.baseUrl = baseUrl; this.vectorizeServiceParameters = vectorizeServiceParameters; @@ -95,8 +92,7 @@ public Uni> vectorize( new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); Uni response = embeddingProvider - .embed( - "Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), request) + .embed("Bearer " + apiKeyOverride.get(), request) .onFailure( throwable -> { return (throwable.getCause() != null diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClient.java index c426aedbd8..fb2cf05d8c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClient.java @@ -26,7 +26,6 @@ public class OpenAiEmbeddingClient implements EmbeddingProvider { private EmbeddingProviderConfigStore.RequestProperties requestProperties; - private String apiKey; private String modelName; private String baseUrl; private final OpenAiEmbeddingProvider embeddingProvider; @@ -35,11 +34,9 @@ public class OpenAiEmbeddingClient implements EmbeddingProvider { public OpenAiEmbeddingClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String apiKey, String modelName, Map vectorizeServiceParameters) { this.requestProperties = requestProperties; - this.apiKey = apiKey; this.modelName = modelName; this.baseUrl = baseUrl; this.vectorizeServiceParameters = vectorizeServiceParameters; @@ -82,8 +79,7 @@ public Uni> vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName); Uni response = embeddingProvider - .embed( - "Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), request) + .embed("Bearer " + apiKeyOverride.get(), request) .onFailure( throwable -> { return (throwable.getCause() != null diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java index 4911944c5f..847e60f8df 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java @@ -28,7 +28,6 @@ public class VertexAIEmbeddingClient implements EmbeddingProvider { private EmbeddingProviderConfigStore.RequestProperties requestProperties; - private String apiKey; private String modelName; private final VertexAIEmbeddingProvider embeddingProvider; @@ -39,11 +38,9 @@ public class VertexAIEmbeddingClient implements EmbeddingProvider { public VertexAIEmbeddingClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String apiKey, String modelName, Map vectorizeServiceParameters) { this.requestProperties = requestProperties; - this.apiKey = apiKey; this.modelName = modelName; this.vectorizeServiceParameters = vectorizeServiceParameters; baseUrl = baseUrl.replace(PROJECT_ID, vectorizeServiceParameters.get(PROJECT_ID).toString()); @@ -146,10 +143,7 @@ public Uni> vectorize( new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList()); Uni serviceResponse = embeddingProvider - .embed( - "Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), - modelName, - request) + .embed("Bearer " + apiKeyOverride.get(), modelName, request) .onFailure( throwable -> { return (throwable.getCause() != null diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindEmbeddingProvidersOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindEmbeddingProvidersOperation.java index 0c4d2f715f..0d424b31fc 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindEmbeddingProvidersOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/model/impl/FindEmbeddingProvidersOperation.java @@ -10,11 +10,12 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import java.util.stream.Collectors; /** - * Operation that list all available vector providers into the {@link + * Operation that list all available and enabled vector providers into the {@link * CommandStatus#EXISTING_VECTOR_PROVIDERS} command status. */ public record FindEmbeddingProvidersOperation(PropertyBasedEmbeddingProviderConfig config) @@ -60,7 +61,10 @@ public CommandResult get() { */ private record EmbeddingProviderResponse( String url, - List supportedAuthentication, + Map< + PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig.AuthenticationType, + PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig.AuthenticationConfig> + supportedAuthentication, List parameters, List models) { private static EmbeddingProviderResponse provider( @@ -69,12 +73,13 @@ private static EmbeddingProviderResponse provider( for (PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig.ModelConfig model : embeddingProviderConfig.models()) { ModelConfigResponse returnModel = - new ModelConfigResponse(model.name(), model.vectorDimension(), model.parameters()); + ModelConfigResponse.returnModelConfigResponse( + model.name(), model.vectorDimension(), model.parameters()); modelsRemoveProperties.add(returnModel); } return new EmbeddingProviderResponse( embeddingProviderConfig.url(), - embeddingProviderConfig.supportedAuthentication(), + embeddingProviderConfig.supportedAuthentications(), embeddingProviderConfig.parameters(), modelsRemoveProperties); } @@ -90,8 +95,52 @@ private static EmbeddingProviderResponse provider( * @param parameters Parameters for customizing the model. */ private record ModelConfigResponse( + String name, Optional vectorDimension, List parameters) { + private static ModelConfigResponse returnModelConfigResponse( + String name, + Optional vectorDimension, + List + parameters) { + // reconstruct each parameter for lowercase parameter type + ArrayList parametersResponse = new ArrayList<>(); + for (PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig.ParameterConfig parameter : + parameters) { + ParameterConfigResponse returnParameter = + new ParameterConfigResponse( + parameter.name(), + parameter.type().toString(), + parameter.required(), + parameter.defaultValue(), + parameter.validation(), + parameter.help()); + parametersResponse.add(returnParameter); + } + + return new ModelConfigResponse(name, vectorDimension, parametersResponse); + } + } + + /** + * This is used to reconstruct the {@code + * PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig.ParameterConfig} body for + * parameter type by not directly using the enum class (uppercase) but instead using the value + * (lowercase) in the enum class. + * + * @param name + * @param type + * @param required + * @param defaultValue + * @param validation + * @param help + */ + private record ParameterConfigResponse( String name, - Integer vectorDimension, - List - parameters) {} + String type, + boolean required, + Optional defaultValue, + Map< + PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig.ValidationType, + List> + validation, + Optional help) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolver.java index a4c7b630ff..efaa22d509 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolver.java @@ -262,7 +262,9 @@ private PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig getAndValid } // TODO: 1. remove the first if statement when fully support validateAuthentication - // 2. validate the 'secretName' in the future + // 2. Check if user authentication type is support + // 3. Check if required token is provided + // 4. Check if token is valid private void validateAuthentication( CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig, PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig providerConfig) { @@ -270,14 +272,14 @@ private void validateAuthentication( return; } // Check if user authentication type is support - userConfig.vectorizeServiceAuthentication().type().stream() - .filter(type -> !providerConfig.supportedAuthentication().contains(type)) - .findFirst() - .ifPresent( - type -> { - throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Authentication type '%s' is not supported", type); - }); + // userConfig.vectorizeServiceAuthentication().type().stream() + // .filter(type -> !providerConfig.supportedAuthentication().contains(type)) + // .findFirst() + // .ifPresent( + // type -> { + // throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + // "Authentication type '%s' is not supported", type); + // }); // Check if 'secretName' is provided if authentication type is 'SHARED_SECRET' if (userConfig.vectorizeServiceAuthentication().type().contains("SHARED_SECRET") && (userConfig.vectorizeServiceAuthentication().secretName() == null @@ -378,6 +380,7 @@ private void validateParameterType( } // TODO: check model parameters provided by the user, will support in the future + // TODO: fix code 396-408 private Integer validateModelAndDimension( CreateCollectionCommand.Options.VectorSearchConfig.VectorizeConfig userConfig, PropertyBasedEmbeddingProviderConfig.EmbeddingProviderConfig providerConfig, @@ -392,14 +395,18 @@ private Integer validateModelAndDimension( "Model name '%s' for provider '%s' is not supported", userConfig.modelName(), userConfig.provider())); - Integer configVectorDimension = model.vectorDimension(); - if (userVectorDimension == null) { - return configVectorDimension; // Use config dimension if user didn't provide one - } else if (!configVectorDimension.equals(userVectorDimension)) { - throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "The provided dimension value '%s' doesn't match the model supports dimension value '%s'", - userVectorDimension, configVectorDimension); + // TODO: is dimension required? do we still auto populate the dimension? + if (model.vectorDimension().isPresent()) { + Integer configVectorDimension = model.vectorDimension().get(); + if (userVectorDimension == null) { + return configVectorDimension; // Use config dimension if user didn't provide one + } else if (!configVectorDimension.equals(userVectorDimension)) { + throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "The provided dimension value '%s' doesn't match the model supports dimension value '%s'", + userVectorDimension, configVectorDimension); + } + return configVectorDimension; } - return configVectorDimension; + return 0; } } diff --git a/src/main/resources/embedding-providers-config.yaml b/src/main/resources/embedding-providers-config.yaml index 49ef7bfab8..c96df0c6be 100644 --- a/src/main/resources/embedding-providers-config.yaml +++ b/src/main/resources/embedding-providers-config.yaml @@ -3,86 +3,268 @@ stargate: jsonapi: embedding: providers: - # Open AI embedding service configuration openai: + #see https://platform.openai.com/docs/api-reference/embeddings/create enabled: true url: https://api.openai.com/v1/ - api-key: YOUR_API_KEY - supported-authentication: - - "HEADER" + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization models: - name: text-embedding-3-small - vector-dimension: 1536 + parameters: + - name: vectorDimension + type: number + required: true + default-value: 512 + validation: + numeric-range: [2, 1536] + help: "Vector dimension to use in the database and when calling Open AI." - name: text-embedding-3-large - vector-dimension: 3072 + parameters: + - name: vectorDimension + type: number + required: true + default-value: 1024 + validation: + numeric-range: [256, 3072] + help: "Vector dimension to use in the database and when calling Open AI." - name: text-embedding-ada-002 vector-dimension: 1536 - - # Hugging face embedding service configuration + azureOpenAI: + # see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference + enabled: true + url: https://{resourceName}.openai.azure.com/openai/deployments/{deploymentId}/embeddings?api-version={apiVersion} + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: api-key + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: api-key + parameters: + - name: "resourceName" + type: string + required: true + help: "The name of your Azure OpenAI Resource." + - name: "deploymentId" + type: string + required: true + help: "The name of your model deployment. You're required to first deploy a model before you can make calls." + - name: "apiVersion" + type: string + required: true + default-value: 2024-02-01 + help: "The API version to use for this operation. This follows the YYYY-MM-DD format." + properties: + max-input-length: 16 huggingface: + # see https://huggingface.co/blog/getting-started-with-embeddings enabled: true url: https://api-inference.huggingface.co/pipeline/feature-extraction/ - api-key: YOUR_API_KEY - supported-authentication: - - "HEADER" + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization models: - name: sentence-transformers/all-MiniLM-L6-v2 vector-dimension: 384 - - # Vertex AI embedding service configuration + # OUT OF SCOPE FOR INITIAL PREVIEW vertexai: + # see https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#get_text_embeddings_for_a_snippet_of_text enabled: true - url: "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models" - api-key: YOUR_API_KEY - supported-authentication: - - "HEADER" + url: "https://us-central1-aiplatform.googleapis.com/v1/projects/{projectId}/locations/us-central1/publishers/google/models" + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization parameters: - - name: PROJECT_ID - type: STRING + - name: projectId + type: string required: true + help: "The Google Cloud Project ID to use when calling" properties: max-retries: 3 request-timeout-millis: 1000 retry-delay-millis: 100 + task-type-store: RETRIEVAL_DOCUMENT # see https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023 + task-type-read: QUESTION_ANSWERING + max-input-length: 5 models: - name: textembedding-gecko@003 vector-dimension: 768 parameters: - name: "autoTruncate" - type: BOOLEAN + type: boolean required: false default-value: true - properties: - max-tokens: 3072 - - - - # Cohere embedding service configuration + help: "If set to false, text that exceeds the token limit causes the request to fail. The default value is true." + # OUT OF SCOPE FOR INITIAL PREVIEW cohere: enabled: true url: https://api.cohere.ai/v1/ - api-key: YOUR_API_KEY - supported-authentication: - - "HEADER" + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization models: - name: embed-english-v3.0 vector-dimension: 1024 - name: embed-english-v2.0 vector-dimension: 4096 - - # NVidia embedding service configuration nvidia: enabled: true url: https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/091a03bb-7364-4087-8090-bd71e9277520 - api-key: YOUR_API_KEY - supported-authentication: - - "HEADER" + supported-authentications: + NONE: + enabled: true + HEADER: + enabled: false + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization models: - - name: NVIDIA Retrieval QA Embedding Model-1.0 - vector-dimension: 1024 - properties: - max-tokens: 512 - name: NV-Embed-QA vector-dimension: 1024 properties: max-tokens: 512 + jinaAI: + #see https://api.jina.ai/redoc#tag/embeddings + enabled: true + url: https://api.jina.ai/v1/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + models: + - name: jina-embedding-t-en-v1 + vector-dimension: 312 + - name: jina-embedding-s-en-v1 + vector-dimension: 512 + - name: jina-embedding-b-en-v1 + vector-dimension: 768 + - name: jina-embedding-l-en-v1 + vector-dimension: 1024 + voyageAI: + # see https://docs.voyageai.com/reference/embeddings-api + # see https://docs.voyageai.com/docs/embeddings + enabled: true + url: https://api.voyageai.com/v1/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + - name: "autoTruncate" + type: BOOLEAN + required: false + default-value: true + help: "Whether to truncate the input texts to fit within the context length. Defaults to true." + properties: + max-input-length: 128 + task-type-store: document + task-type-read: query + models: + - name: voyage-large-2-instruct + vector-dimension: 1024 + - name: voyage-law-2 + vector-dimension: 1024 + - name: voyage-code-2 + vector-dimension: 1536 + - name: voyage-large-2 + vector-dimension: 1536 + - name: voyage-2 + vector-dimension: 1024 + mistral: + # see https://docs.mistral.ai/api/#operation/createEmbedding + enabled: true + url: https://api.mistral.ai/v1/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-provider-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + properties: + models: + - name: mistral-embed + vector-dimension: 1024 + # upstage: + # see https://developers.upstage.ai/docs/apis/embeddings + # NOTE: they have a model for storing and a diff one for reading, this is different to everyone else + # holding on the config / implementation until we confirm if this can change + diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java index 58793d7e2d..0ecc9fc368 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java @@ -584,7 +584,7 @@ public void happyCreateCollectionWithEmbeddingService() { "secretName": "test" }, "parameters": { - "PROJECT_ID": "test" + "projectId": "test" } } } @@ -638,7 +638,7 @@ public void happyCreateCollectionWithEmbeddingServiceAutoPopulateDimension() { "secretName": "test" }, "parameters": { - "PROJECT_ID": "test" + "projectId": "test" } } } @@ -665,7 +665,7 @@ public void happyCreateCollectionWithEmbeddingServiceAutoPopulateDimension() { "secretName": "test" }, "parameters": { - "PROJECT_ID": "test" + "projectId": "test" } } } @@ -802,101 +802,6 @@ public void failCreateCollectionWithEmbeddingServiceProviderNotSupport() { .body("errors[0].exceptionClass", is("JsonApiException")); } - @Test - public void failCreateCollectionWithEmbeddingServiceAuthenticationTypeUnsupported() { - // create a collection with authentication type not support - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( - """ - { - "createCollection": { - "name": "collection_with_vector_service", - "options": { - "vector": { - "metric": "cosine", - "dimension": 768, - "service": { - "provider": "openai", - "modelName": "text-embedding-3-small", - "authentication": { - "type": [ - "HEADER","SHARED_SECRET" - ], - "secretName": "test" - }, - "parameters": { - "PROJECT_ID": "test" - } - } - } - } - } - } - """) - .when() - .post(NamespaceResource.BASE_PATH, namespaceName) - .then() - .statusCode(200) - .body("status", is(nullValue())) - .body("data", is(nullValue())) - .body( - "errors[0].message", - startsWith( - "The provided options are invalid: Authentication type 'SHARED_SECRET' is not supported")) - .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) - .body("errors[0].exceptionClass", is("JsonApiException")); - } - - // TODO: Enable it when support SHARED_SECRET - @Disabled - @Test - public void failCreateCollectionWithEmbeddingServiceNoSecretName() { - // create a collection with "SHARED_SECRET" authentication type but no 'secretName' - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( - """ - { - "createCollection": { - "name": "collection_with_vector_service", - "options": { - "vector": { - "metric": "cosine", - "dimension": 768, - "service": { - "provider": "vertexai", - "modelName": "text-embedding-3-small", - "authentication": { - "type": [ - "HEADER","SHARED_SECRET" - ] - }, - "parameters": { - "PROJECT_ID": "test" - } - } - } - } - } - } - """) - .when() - .post(NamespaceResource.BASE_PATH, namespaceName) - .then() - .statusCode(200) - .body("status", is(nullValue())) - .body("data", is(nullValue())) - .body( - "errors[0].message", - startsWith( - "The provided options are invalid: 'secretName' must be provided for 'SHARED_SECRET' authentication type")) - .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) - .body("errors[0].exceptionClass", is("JsonApiException")); - } - @Test public void failCreateCollectionWithEmbeddingServiceNotProvideRequiredParameters() { // create a collection without providing required parameters @@ -936,7 +841,7 @@ public void failCreateCollectionWithEmbeddingServiceNotProvideRequiredParameters .body( "errors[0].message", startsWith( - "The provided options are invalid: Required parameter 'PROJECT_ID' for the provider 'vertexai' missing")) + "The provided options are invalid: Required parameter 'projectId' for the provider 'vertexai' missing")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) .body("errors[0].exceptionClass", is("JsonApiException")); } @@ -1055,7 +960,7 @@ public void failCreateCollectionWithEmbeddingServiceWrongParameterType() { "secretName": "test" }, "parameters": { - "PROJECT_ID": 123 + "projectId": 123 } } } @@ -1072,7 +977,7 @@ public void failCreateCollectionWithEmbeddingServiceWrongParameterType() { .body( "errors[0].message", startsWith( - "The provided options are invalid: The provided parameter 'PROJECT_ID' type is incorrect. Expected: 'STRING'")) + "The provided options are invalid: The provided parameter 'projectId' type is incorrect. Expected: 'string'")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) .body("errors[0].exceptionClass", is("JsonApiException")); } @@ -1102,7 +1007,7 @@ public void failCreateCollectionWithEmbeddingServiceUnsupportedModel() { "secretName": "test" }, "parameters": { - "PROJECT_ID": "123" + "projectId": "123" } } } @@ -1149,7 +1054,7 @@ public void failCreateCollectionWithEmbeddingServiceUnmatchedVectorDimension() { "secretName": "test" }, "parameters": { - "PROJECT_ID": "123" + "projectId": "123" } } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index f425ed88df..222a67cf84 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -316,7 +316,7 @@ public void updateClauseSetOnInsertValues() throws Exception { UpdateClause updateClause = command.updateClause(); DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), Optional.empty(), collectionSettings); + testService, objectMapper.getNodeFactory(), Optional.of("test"), collectionSettings); try { dataVectorizer.vectorizeUpdateClause(updateClause).subscribe().asCompletionStage().get(); } catch (Exception e) { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java index b2285e489b..c6e28ee820 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java @@ -50,7 +50,6 @@ void handleValidResponse() throws ExecutionException, InterruptedException { 1536, Optional.of("default"), "https://api.openai.com/v1/", - "api-key", "text-embedding-3-small", embeddingService, Map.of()); @@ -94,7 +93,6 @@ void handleError() throws ExecutionException, InterruptedException { 1536, Optional.of("default"), "https://api.openai.com/v1/", - "api-key", "text-embedding-3-small", embeddingService, Map.of()); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java index 45e8f28352..cd831f0436 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java @@ -30,10 +30,9 @@ public void test429() throws Exception { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( - List.of("429"), Optional.empty(), EmbeddingProvider.EmbeddingRequestType.INDEX) + List.of("429"), Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) .awaitFailure() @@ -53,10 +52,9 @@ public void test4xx() throws Exception { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( - List.of("400"), Optional.empty(), EmbeddingProvider.EmbeddingRequestType.INDEX) + List.of("400"), Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) .awaitFailure() @@ -76,10 +74,9 @@ public void test5xx() throws Exception { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( - List.of("503"), Optional.empty(), EmbeddingProvider.EmbeddingRequestType.INDEX) + List.of("503"), Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) .awaitFailure() @@ -99,10 +96,9 @@ public void testRetryError() throws Exception { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( - List.of("408"), Optional.empty(), EmbeddingProvider.EmbeddingRequestType.INDEX) + List.of("408"), Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) .awaitFailure() @@ -120,11 +116,10 @@ public void testCorrectHeaderAndBody() { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( List.of("application/json"), - Optional.empty(), + Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) @@ -140,11 +135,10 @@ public void testIncorrectContentType() { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( List.of("application/xml"), - Optional.empty(), + Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) @@ -165,11 +159,10 @@ public void testNoJsonResponse() { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( List.of("no json body"), - Optional.empty(), + Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) @@ -190,11 +183,10 @@ public void testEmptyJsonResponse() { EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000), config.providers().get("nvidia").url(), "test", - "test", null) .vectorize( List.of("empty json body"), - Optional.empty(), + Optional.of("test"), EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/PropertyBasedOverrideProfile.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/PropertyBasedOverrideProfile.java index 38276fe234..f31d49c042 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/PropertyBasedOverrideProfile.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/PropertyBasedOverrideProfile.java @@ -14,10 +14,8 @@ public boolean disableGlobalTestResources() { public Map getConfigOverrides() { return ImmutableMap.builder() .put("stargate.jsonapi.embedding.providers.openai.enabled", "true") - .put("stargate.jsonapi.embedding.providers.openai.api-key", "openai-api-key") .put("stargate.jsonapi.embedding.providers.openai.url", "https://api.openai.com/v1/") .put("stargate.jsonapi.embedding.providers.huggingface.enabled", "true") - .put("stargate.jsonapi.embedding.providers.huggingface.api-key", "hf-api-key") .put( "stargate.jsonapi.embedding.providers.huggingface.url", "https://api-inference.huggingface.co") diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolverTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolverTest.java index cd40b95b19..eebbf73046 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolverTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/resolver/model/impl/CreateCollectionCommandResolverTest.java @@ -108,7 +108,7 @@ public void happyPathVectorizeSearch() throws Exception { ] }, "parameters": { - "PROJECT_ID": "test project" + "projectId": "test project" } } } @@ -131,7 +131,7 @@ public void happyPathVectorizeSearch() throws Exception { assertThat(op.vectorFunction()).isEqualTo("cosine"); assertThat(op.comment()) .isEqualTo( - "{\"collection\":{\"name\":\"my_collection\",\"schema_version\":%s,\"options\":{\"vector\":{\"dimension\":768,\"metric\":\"cosine\",\"service\":{\"provider\":\"vertexai\",\"modelName\":\"textembedding-gecko@003\",\"authentication\":{\"type\":[\"HEADER\"]},\"parameters\":{\"PROJECT_ID\":\"test project\"}}},\"defaultId\":{\"type\":\"\"}}}}", + "{\"collection\":{\"name\":\"my_collection\",\"schema_version\":%s,\"options\":{\"vector\":{\"dimension\":768,\"metric\":\"cosine\",\"service\":{\"provider\":\"vertexai\",\"modelName\":\"textembedding-gecko@003\",\"authentication\":{\"type\":[\"HEADER\"]},\"parameters\":{\"projectId\":\"test project\"}}},\"defaultId\":{\"type\":\"\"}}}}", TableCommentConstants.SCHEMA_VERSION_VALUE); }); }