Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add huggingface dedicated provider support #1157

Merged
merged 12 commits into from
Jun 12, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.stargate.sgv2.jsonapi.api.model.command.NamespaceCommand;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants;
import jakarta.validation.Valid;
import jakarta.validation.constraints.*;
import java.util.HashSet;
Expand Down Expand Up @@ -122,8 +123,7 @@ public record VectorizeConfig(
implementation = String.class)
@JsonProperty("provider")
String provider,
@NotNull
@Schema(
@Schema(
description = "Registered Embedding service model",
type = SchemaType.STRING,
implementation = String.class)
Expand All @@ -144,7 +144,25 @@ public record VectorizeConfig(
type = SchemaType.OBJECT)
@JsonProperty("parameters")
@JsonInclude(JsonInclude.Include.NON_NULL)
Map<String, Object> parameters) {}
Map<String, Object> parameters) {

public VectorizeConfig(
String provider,
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {
this.provider = provider;
// HuggingfaceDedicated does not need user to specify model
// use endpoint-defined-model as placeholder
if (provider.equals(ProviderConstants.HUGGINGFACE_DEDICATED)) {
this.modelName = "endpoint-defined-model";
tatu-at-datastax marked this conversation as resolved.
Show resolved Hide resolved
} else {
this.modelName = modelName;
}
this.authentication = authentication;
this.parameters = parameters;
}
}
}

public record IndexingConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public final class ProviderConstants {
public static final String OPENAI = "openai";
public static final String AZURE_OPENAI = "azureOpenAI";
public static final String HUGGINGFACE = "huggingface";
public static final String HUGGINGFACE_DEDICATED = "huggingfaceDedicated";
public static final String VERTEXAI = "vertexai";
public static final String COHERE = "cohere";
public static final String NVIDIA = "nvidia";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ EmbeddingProvider create(
Map.entry(ProviderConstants.AZURE_OPENAI, AzureOpenAIEmbeddingClient::new),
Map.entry(ProviderConstants.COHERE, CohereEmbeddingClient::new),
Map.entry(ProviderConstants.HUGGINGFACE, HuggingFaceEmbeddingClient::new),
Map.entry(
ProviderConstants.HUGGINGFACE_DEDICATED, HuggingFaceDedicatedEmbeddingClient::new),
Map.entry(ProviderConstants.JINA_AI, JinaAIEmbeddingClient::new),
Map.entry(ProviderConstants.MISTRAL, MistralEmbeddingClient::new),
Map.entry(ProviderConstants.NVIDIA, NvidiaEmbeddingClient::new),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package io.stargate.sgv2.jsonapi.service.embedding.operation;

import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Uni;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation;
import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper;
import io.stargate.sgv2.jsonapi.service.embedding.util.EmbeddingUtil;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import java.net.URI;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;

public class HuggingFaceDedicatedEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private final HuggingFaceDedicatedEmbeddingProvider embeddingProvider;
private Map<String, Object> vectorizeServiceParameters;

public HuggingFaceDedicatedEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String modelName,
int dimension,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.vectorizeServiceParameters = vectorizeServiceParameters;
// replace placeholders: endPointName, regionName, cloudName
String dedicatedApiUrl = EmbeddingUtil.replaceParameters(baseUrl, vectorizeServiceParameters);
embeddingProvider =
QuarkusRestClientBuilder.newBuilder()
.baseUri(URI.create(dedicatedApiUrl))
.readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS)
.build(HuggingFaceDedicatedEmbeddingProvider.class);
}

@RegisterRestClient
@RegisterProvider(EmbeddingProviderResponseValidation.class)
public interface HuggingFaceDedicatedEmbeddingProvider {
@POST
@ClientHeaderParam(name = "Content-Type", value = "application/json")
Uni<EmbeddingResponse> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}

// huggingfaceDedicated, Test Embeddings Inference, openAI compatible route
// https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/openai_embed
private record EmbeddingRequest(String[] input) {}

private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) {
private record Data(String object, int index, float[] embedding) {}

private record Usage(int prompt_tokens, int total_tokens) {}
}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray));
Uni<EmbeddingResponse> response =
embeddingProvider
.embed("Bearer " + apiKeyOverride.get(), request)
.onFailure(
throwable -> {
return ((throwable.getCause() != null
&& throwable.getCause() instanceof JsonApiException jae
&& jae.getErrorCode() == ErrorCode.EMBEDDING_PROVIDER_TIMEOUT)
|| throwable instanceof TimeoutException);
})
.retry()
.withBackOff(
Duration.ofMillis(requestProperties.initialBackOffMillis()),
Duration.ofMillis(requestProperties.maxBackOffMillis()))
.withJitter(requestProperties.jitter())
.atMost(requestProperties.atMostRetries());
return response
.onItem()
.transform(
resp -> {
if (resp.data() == null) {
return Response.of(batchId, Collections.emptyList());
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
List<float[]> vectors =
Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
return Response.of(batchId, vectors);
});
}

@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,10 @@ private Integer validateService(
// Check secret name for shared secret authentication, if applicable
validateAuthentication(userConfig, providerConfig);

// Validate the model and its vector dimension, if userVectorDimension is null, return value
// will be the config/default value
// Validate the model and its vector dimension:
// huggingFaceDedicated: must have vectorDimension specified
// other providers: must have model specified, and default dimension when dimension not
// specified
Integer vectorDimension =
validateModelAndDimension(userConfig, providerConfig, userVectorDimension);

Expand Down Expand Up @@ -384,7 +386,7 @@ private void validateUserParameters(
// Add all provider level parameters
allParameters.addAll(providerConfig.parameters());
// Get all the parameters except "vectorDimension" for the model -- model has been validated in
// the previous step
// the previous step, huggingfaceDedicated uses endpoint-defined-model
List<EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig> modelParameters =
providerConfig.models().stream()
.filter(m -> m.name().equals(userConfig.modelName()))
Expand All @@ -403,8 +405,7 @@ private void validateUserParameters(
.get();
// Add all model level parameters
allParameters.addAll(modelParameters);

// 1. Error if the user provided unconfigured parameters
// 1. Error if the user provided un-configured parameters
// Two level parameters have unique names, should be fine here
Set<String> expectedParamNames =
allParameters.stream()
Expand Down Expand Up @@ -497,6 +498,18 @@ private Integer validateModelAndDimension(
EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig,
Integer userVectorDimension) {
// Find the model configuration by matching the model name
// 1. huggingfaceDedicated does not require model, but requires dimension
Yuqi-Du marked this conversation as resolved.
Show resolved Hide resolved
if (userConfig.provider().equals(ProviderConstants.HUGGINGFACE_DEDICATED)) {
if (userVectorDimension == null) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
Yuqi-Du marked this conversation as resolved.
Show resolved Hide resolved
"dimension is needed for huggingfaceDedicated provider");
}
}
// 2. other providers do require model
if (userConfig.modelName() == null) {
throw ErrorCode.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"model can not be null for provider '%s'", userConfig.provider());
}
EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model =
providerConfig.models().stream()
.filter(m -> m.name().equals(userConfig.modelName()))
Expand Down
45 changes: 43 additions & 2 deletions src/main/resources/embedding-providers-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ stargate:
vector-dimension: 1536
huggingface:
# see https://huggingface.co/blog/getting-started-with-embeddings
display-name: Hugging Face
display-name: Hugging Face - Serverless
enabled: true
url: https://api-inference.huggingface.co/pipeline/feature-extraction/
supported-authentications:
Expand Down Expand Up @@ -137,7 +137,48 @@ stargate:
vector-dimension: 768
- name: BAAI/bge-large-en-v1.5
vector-dimension: 1024

huggingfaceDedicated:
# see https://huggingface.co/docs/inference-endpoints/en/supported_tasks#sentence-embeddings
display-name: Hugging Face - Dedicated
enabled: true
url: https://{endpointName}.{regionName}.{cloudName}.endpoints.huggingface.cloud/embeddings
supported-authentications:
NONE:
enabled: false
HEADER:
enabled: true
tokens:
- accepted: x-embedding-api-key
forwarded: Authorization
SHARED_SECRET:
enabled: true
tokens:
- accepted: providerKey
forwarded: Authorization
properties:
max-batch-size: 32
models:
Yuqi-Du marked this conversation as resolved.
Show resolved Hide resolved
- name: endpoint-defined-model
parameters:
- name: vectorDimension
type: number
required: true
validation:
numeric-range: [2, 3072]
help: "Vector dimension to use in the database, should be the same as the model used by the endpoint."
parameters:
- name: "endpointName"
type: string
required: true
help: "The name of your Hugging Face dedicated endpoint, the first part of the Endpoint URL."
- name: "regionName"
type: string
required: true
help: "The region your Hugging Face dedicated endpoint is deployed to, the second part of the Endpoint URL."
- name: "cloudName"
type: string
required: true
help: "The cloud your Hugging Face dedicated endpoint is deployed to, the third part of the Endpoint URL."
# OUT OF SCOPE FOR INITIAL PREVIEW
vertexai:
# see https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#get_text_embeddings_for_a_snippet_of_text
Expand Down