Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Embedding Providers Config #1048

Merged
merged 15 commits into from
May 2, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ public interface EmbeddingProviderConfigStore {
record ServiceConfig(
String serviceName,
String serviceProvider,
String apiKey,
String baseUrl,
// `implementationClass` is the custom class that implements the EmbeddingProvider interface
Optional<Class<?>> implementationClass,
Expand All @@ -16,21 +15,14 @@ record ServiceConfig(
public static ServiceConfig provider(
String serviceName,
String serviceProvider,
String apiKey,
String baseUrl,
RequestProperties requestConfiguration) {
return new ServiceConfig(
serviceName, serviceProvider, apiKey, baseUrl, null, requestConfiguration);
return new ServiceConfig(serviceName, serviceProvider, baseUrl, null, requestConfiguration);
}

public static ServiceConfig custom(Optional<Class<?>> implementationClass) {
return new ServiceConfig(
ProviderConstants.CUSTOM,
ProviderConstants.CUSTOM,
null,
null,
implementationClass,
null);
ProviderConstants.CUSTOM, ProviderConstants.CUSTOM, null, implementationClass, null);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,52 @@ interface EmbeddingProviderConfig {
@JsonProperty
String url();

/**
* A map of supported authentications. HEADER, SHARED_SECRET and NONE are the only techniques
* the DataAPI supports (i.e. the key of map can only be HEADER, SHARED_SECRET or NONE).
*
* @return
*/
@JsonProperty
String apiKey();
Map<AuthenticationType, AuthenticationConfig> supportedAuthentications();

@JsonProperty
List<String> supportedAuthentication();
enum AuthenticationType {
NONE,
HEADER,
SHARED_SECRET
}

/**
* enabled() is a JSON boolean to flag if this technique is supported. If false the rest of the
* object has no impact. Any technique not listed is also not supported for the provider.
*
* <p>tokens() is a list of token mappings, that map from the name accepted by the Data API to
* how they are forwarded to the provider. The provider information is included for the code,
* and to allow users to see what we do with the values.
*/
interface AuthenticationConfig {
@JsonProperty
boolean enabled();

@JsonProperty
List<TokenConfig> tokens();
}

/**
* For the HEADER technique the `accepted` value is the name of the header the client should
tatu-at-datastax marked this conversation as resolved.
Show resolved Hide resolved
* send, and `forwarded` is the name of the header the Data API will send to the provider.
*
* <p>For the SHARED_SECRET technique the `accepted` value is the name used in the
* authentication option with createCollection that maps to the name of a shared secret, and
* `forwarded` is the name of the header the Data API will send to the provider.
*/
interface TokenConfig {
@JsonProperty
String accepted();

@JsonProperty
String forwarded();
}

/**
* A list of parameters for user customization. Parameters are used to construct the URL or to
Expand Down Expand Up @@ -60,8 +101,15 @@ interface ModelConfig {
@JsonProperty
String name();

/**
* vectorDimension is not null if the model supports a single dimension value. It will be null
* if the model supports different dimensions. A parameter called vectorDimension is included.
*
* @return
*/
@Nullable
@JsonProperty
Integer vectorDimension();
Optional<Integer> vectorDimension();

@JsonProperty
List<ParameterConfig> parameters();
Expand All @@ -84,11 +132,46 @@ interface ParameterConfig {
@JsonProperty
Optional<String> defaultValue();

/**
* validation is an object that describes how the Data API will validate the parameters, and
* how the UI may want to provide data entry hints. Only one of the validation methods will be
* specified for each parameter.
*
* <p>`numericRange` if present is an array of two numbers that represent the inclusive value
* range for a number parameter. E.g. the dimensions for the text-embedding-3
*
* <p>`options` if present is an array of valid options the user must select from, for example
* if a model supports 3 different dimensions. If options are present the only allowed values
* for the parameter are those in the options list. If not present, null, or an empty array
* any value of the correct type is accepted.
*
* @return
*/
@Nullable
@JsonProperty
Map<ValidationType, List<Integer>> validation();

@Nullable
@JsonProperty
Optional<String> help();
}

enum ValidationType {
NUMERIC_RANGE("numericRange"),
OPTIONS("options");

private final String type;

ValidationType(final String type) {
this.type = type;
}

@Override
public String toString() {
return type;
}
}

/** A set of http properties used for request to the embedding providers. */
interface RequestProperties {

Expand All @@ -115,12 +198,32 @@ interface RequestProperties {
*/
@WithDefault("10000")
int requestTimeoutMillis();

@Nullable
Optional<String> maxInputLength();

@Nullable
Optional<String> taskTypeStore();

@Nullable
Optional<String> taskTypeRead();
}

enum ParameterType {
STRING,
NUMBER,
BOOLEAN
STRING("string"),
NUMBER("number"),
BOOLEAN("boolean");

private final String type;

ParameterType(final String type) {
this.type = type;
}

@Override
public String toString() {
return type;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration(
return ServiceConfig.provider(
serviceName,
serviceName,
config.providers().get(serviceName).apiKey(),
config.providers().get(serviceName).url().toString(),
RequestProperties.of(
config.providers().get(serviceName).properties().maxRetries(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public class EmbeddingGatewayClient implements EmbeddingProvider {
* @param dimension - Dimension of the embedding to be returned
* @param tenant - Tenant id {aka database id}
* @param baseUrl - base url of the embedding client
* @param apiKey - Api key for the embedding provider
* @param modelName - Model name for the embedding provider
* @param embeddingService - Embedding service client
* @param vectorizeServiceParameter - Additional parameters for the vectorize service
Expand All @@ -51,15 +50,13 @@ public EmbeddingGatewayClient(
int dimension,
Optional<String> tenant,
String baseUrl,
String apiKey,
String modelName,
EmbeddingService embeddingService,
Map<String, Object> vectorizeServiceParameter) {
this.requestProperties = requestProperties;
this.provider = provider;
this.dimension = dimension;
this.tenant = tenant;
this.apiKey = apiKey;
this.modelName = modelName;
this.baseUrl = baseUrl;
this.embeddingService = embeddingService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
*/
public class CohereEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private String apiKey;
private String modelName;
private String baseUrl;
private final CohereEmbeddingProvider embeddingProvider;
Expand All @@ -39,11 +38,9 @@ public class CohereEmbeddingClient implements EmbeddingProvider {
public CohereEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String apiKey,
String modelName,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.apiKey = apiKey;
this.modelName = modelName;
this.baseUrl = baseUrl;
this.vectorizeServiceParameters = vectorizeServiceParameters;
Expand Down Expand Up @@ -103,8 +100,7 @@ public Uni<List<float[]>> vectorize(
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);
Uni<EmbeddingResponse> response =
embeddingProvider
.embed(
"Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), request)
.embed("Bearer " + apiKeyOverride.get(), request)
.onFailure(
throwable -> {
return (throwable.getCause() != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ private interface ProviderConstructor {
EmbeddingProvider create(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String apiKey,
String modelName,
Map<String, Object> vectorizeServiceParameter);
}
Expand Down Expand Up @@ -66,7 +65,6 @@ private synchronized EmbeddingProvider addService(
dimension,
tenant,
configuration.baseUrl(),
configuration.apiKey(),
modelName,
embeddingService,
vectorizeServiceParameter);
Expand Down Expand Up @@ -98,7 +96,6 @@ private synchronized EmbeddingProvider addService(
.create(
configuration.requestConfiguration(),
configuration.baseUrl(),
configuration.apiKey(),
modelName,
vectorizeServiceParameter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,17 @@

public class HuggingFaceEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private String apiKey;
private String modelName;

private String baseUrl;
private final HuggingFaceEmbeddingProvider embeddingProvider;
private Map<String, Object> vectorizeServiceParameters;

public HuggingFaceEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String apiKey,
String modelName,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.apiKey = apiKey;
this.modelName = modelName;
this.baseUrl = baseUrl;
this.vectorizeServiceParameters = vectorizeServiceParameters;
Expand Down Expand Up @@ -79,10 +75,7 @@ public Uni<List<float[]>> vectorize(
EmbeddingRequestType embeddingRequestType) {
EmbeddingRequest request = new EmbeddingRequest(texts, new EmbeddingRequest.Options(true));
return embeddingProvider
.embed(
"Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey),
modelName,
request)
.embed("Bearer " + apiKeyOverride.get(), modelName, request)
.onFailure(
throwable -> {
return (throwable.getCause() != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
*/
public class NVidiaEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private String apiKey;
private String modelName;
private String baseUrl;
private final NVidiaEmbeddingProvider embeddingProvider;
Expand All @@ -40,11 +39,9 @@ public class NVidiaEmbeddingClient implements EmbeddingProvider {
public NVidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String apiKey,
String modelName,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.apiKey = apiKey;
this.modelName = modelName;
this.baseUrl = baseUrl;
this.vectorizeServiceParameters = vectorizeServiceParameters;
Expand Down Expand Up @@ -95,8 +92,7 @@ public Uni<List<float[]>> vectorize(
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);
Uni<EmbeddingResponse> response =
embeddingProvider
.embed(
"Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), request)
.embed("Bearer " + apiKeyOverride.get(), request)
.onFailure(
throwable -> {
return (throwable.getCause() != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

public class OpenAiEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private String apiKey;
private String modelName;
private String baseUrl;
private final OpenAiEmbeddingProvider embeddingProvider;
Expand All @@ -35,11 +34,9 @@ public class OpenAiEmbeddingClient implements EmbeddingProvider {
public OpenAiEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String apiKey,
String modelName,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.apiKey = apiKey;
this.modelName = modelName;
this.baseUrl = baseUrl;
this.vectorizeServiceParameters = vectorizeServiceParameters;
Expand Down Expand Up @@ -82,8 +79,7 @@ public Uni<List<float[]>> vectorize(
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName);
Uni<EmbeddingResponse> response =
embeddingProvider
.embed(
"Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey), request)
.embed("Bearer " + apiKeyOverride.get(), request)
.onFailure(
throwable -> {
return (throwable.getCause() != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

public class VertexAIEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private String apiKey;
private String modelName;
private final VertexAIEmbeddingProvider embeddingProvider;

Expand All @@ -39,11 +38,9 @@ public class VertexAIEmbeddingClient implements EmbeddingProvider {
public VertexAIEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String apiKey,
String modelName,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.apiKey = apiKey;
this.modelName = modelName;
this.vectorizeServiceParameters = vectorizeServiceParameters;
baseUrl = baseUrl.replace(PROJECT_ID, vectorizeServiceParameters.get(PROJECT_ID).toString());
Expand Down Expand Up @@ -146,10 +143,7 @@ public Uni<List<float[]>> vectorize(
new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList());
Uni<EmbeddingResponse> serviceResponse =
embeddingProvider
.embed(
"Bearer " + (apiKeyOverride.isPresent() ? apiKeyOverride.get() : apiKey),
modelName,
request)
.embed("Bearer " + apiKeyOverride.get(), modelName, request)
.onFailure(
throwable -> {
return (throwable.getCause() != null
Expand Down
Loading