Skip to content

Commit

Permalink
Small fix: Improve error message when not providing provider key thro…
Browse files Browse the repository at this point in the history
…ugh `x-embedding-api-key` (#1251)
  • Loading branch information
Hazel-Datastax authored Jul 10, 2024
1 parent dd652c3 commit 197bfd2
Show file tree
Hide file tree
Showing 16 changed files with 62 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public enum ErrorCode {
EMBEDDING_PROVIDER_RATE_LIMITED("The Embedding Provider rate limited the request"),
EMBEDDING_PROVIDER_TIMEOUT("The Embedding Provider timed out"),
EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE("The Embedding Provider returned an unexpected response"),
EMBEDDING_PROVIDER_API_KEY_MISSING("The Embedding Provider API key is missing"),

FILTER_MULTIPLE_ID_FILTER(
"Cannot have more than one _id equals filter clause: use $in operator instead"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ public class EmbeddingGatewayClient extends EmbeddingProvider {

private Optional<String> tenant;
private Optional<String> authToken;

private String apiKey;
private String modelName;
private String baseUrl;
private EmbeddingService embeddingService;
Expand Down Expand Up @@ -77,15 +75,15 @@ public EmbeddingGatewayClient(
* Vectorize the given list of texts
*
* @param texts List of texts to be vectorized
* @param apiKeyOverride API key sent as header
* @param apiKey API key sent as header
* @param embeddingRequestType Type of request (INDEX or SEARCH)
* @return
*/
@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
Map<String, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue>
grpcVectorizeServiceParameter = new HashMap<>();
Expand Down Expand Up @@ -136,8 +134,8 @@ else if (value instanceof Boolean)
.setProviderName(provider)
.setTenantId(tenant.orElse(DEFAULT_TENANT_ID));
builder.putAuthTokens(DATA_API_KEY, authToken.orElse(""));
if (apiKeyOverride.isPresent()) {
builder.putAuthTokens(API_KEY, apiKeyOverride.orElse(apiKey));
if (apiKey.isPresent()) {
builder.putAuthTokens(API_KEY, apiKey.get());
}
if (authentication != null) {
builder.putAllAuthTokens(authentication);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,15 @@ private record Usage(int prompt_tokens, int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension);

// NOTE: NO "Bearer " prefix with API key for Azure OpenAI
Uni<EmbeddingResponse> response =
applyRetry(openAIEmbeddingProviderClient.embed(apiKeyOverride.get(), request));
applyRetry(openAIEmbeddingProviderClient.embed(apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,18 @@ public void setEmbeddings(List<float[]> embeddings) {
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

String[] textArray = new String[texts.size()];
String input_type =
embeddingRequestType == EmbeddingRequestType.INDEX ? SEARCH_DOCUMENT : SEARCH_QUERY;
EmbeddingRequest request =
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);

Uni<EmbeddingResponse> response =
applyRetry(cohereEmbeddingProviderClient.embed("Bearer " + apiKeyOverride.get(), request));
applyRetry(cohereEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package io.stargate.sgv2.jsonapi.service.embedding.operation;

import static io.stargate.sgv2.jsonapi.config.constants.HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME;
import static io.stargate.sgv2.jsonapi.exception.ErrorCode.EMBEDDING_PROVIDER_API_KEY_MISSING;

import io.smallrye.mutiny.Uni;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
Expand Down Expand Up @@ -70,15 +73,15 @@ protected <T> Uni<T> applyRetry(Uni<T> uni) {
* Vectorizes the given list of texts and returns the embeddings.
*
* @param texts List of texts to be vectorized
* @param apiKeyOverride Optional API key to be used for this request. If not provided, the
* default API key will be used.
* @param apiKey Optional API key to be used for this request. If not provided, the default API
* key will be used.
* @param embeddingRequestType Type of request (INDEX or SEARCH)
* @return VectorResponse
*/
public abstract Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType);

/**
Expand Down Expand Up @@ -134,6 +137,15 @@ protected String replaceParameters(String template, Map<String, Object> paramete
return baseUrl.toString();
}

/** Helper method to check if the API key is present in the header */
protected void checkEmbeddingApiKeyHeader(String providerId, Optional<String> apiKey) {
if (apiKey.isEmpty()) {
throw EMBEDDING_PROVIDER_API_KEY_MISSING.toApiException(
"header value `%s` is missing for embedding provider: %s",
EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, providerId);
}
}

/**
* Record to hold the batchId and embedding vectors
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,16 @@ private record Usage(int prompt_tokens, int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray));

Uni<EmbeddingResponse> response =
applyRetry(
huggingFaceDedicatedEmbeddingProviderClient.embed(
"Bearer " + apiKeyOverride.get(), request));
huggingFaceDedicatedEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,12 @@ public record Options(boolean waitForModel) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
EmbeddingRequest request = new EmbeddingRequest(texts, new EmbeddingRequest.Options(true));

return applyRetry(
huggingFaceEmbeddingProviderClient.embed(
"Bearer " + apiKeyOverride.get(), modelName, request))
huggingFaceEmbeddingProviderClient.embed("Bearer " + apiKey.get(), modelName, request))
.onItem()
.transform(
resp -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ private record Usage(int prompt_tokens, int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

EmbeddingRequest request = new EmbeddingRequest(texts, modelName);

Uni<EmbeddingResponse> response =
applyRetry(jinaAIEmbeddingProviderClient.embed("Bearer " + apiKeyOverride.get(), request));
applyRetry(jinaAIEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ public MeteredEmbeddingProvider(
* call and the size of the input texts.
*
* @param texts the list of texts to vectorize.
* @param apiKeyOverride optional API key to override any default authentication mechanism.
* @param apiKey optional API key to override any default authentication mechanism.
* @param embeddingRequestType the type of embedding request, influencing how texts are processed.
* @return a {@link Uni} that will provide the list of vectorized texts, as arrays of floats.
*/
@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
// String bytes metrics for vectorize
DistributionSummary ds =
Expand All @@ -79,7 +79,7 @@ public Uni<Response> vectorize(
batch -> {
// call vectorize by the batch id
return embeddingProvider.vectorize(
batch.getLeft(), batch.getRight(), apiKeyOverride, embeddingRequestType);
batch.getLeft(), batch.getRight(), apiKey, embeddingRequestType);
})
.merge()
.collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ private record Usage(int prompt_tokens, int total_tokens, int completion_tokens)
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

EmbeddingRequest request = new EmbeddingRequest(texts, modelName, "float");

Uni<EmbeddingResponse> response =
applyRetry(mistralEmbeddingProviderClient.embed("Bearer " + apiKeyOverride.get(), request));
applyRetry(mistralEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,18 @@ private record Usage(int prompt_tokens, int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

String[] textArray = new String[texts.size()];
String input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? PASSAGE : QUERY;

EmbeddingRequest request =
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);

Uni<EmbeddingResponse> response =
applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer " + apiKeyOverride.get(), request));
applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ private record Usage(int prompt_tokens, int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension);
String organizationId = (String) vectorizeServiceParameters.get("organizationId");
Expand All @@ -124,7 +125,7 @@ public Uni<Response> vectorize(
Uni<EmbeddingResponse> response =
applyRetry(
openAIEmbeddingProviderClient.embed(
"Bearer " + apiKeyOverride.get(), organizationId, projectId, request));
"Bearer " + apiKey.get(), organizationId, projectId, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ record Usage(int prompt_tokens, int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
// Oddity: Implementation does not support batching, so we only accept "batches"
// of 1 String, fail for others
if (texts.size() != 1) {
Expand All @@ -139,8 +140,7 @@ public Uni<Response> vectorize(
EmbeddingRequest request = new EmbeddingRequest(texts.get(0), modelName);

Uni<EmbeddingResponse> response =
applyRetry(
upstageAIEmbeddingProviderClient.embed("Bearer " + apiKeyOverride.get(), request));
applyRetry(upstageAIEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ public void setStatistics(Object statistics) {
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
EmbeddingRequest request =
new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList());

Uni<EmbeddingResponse> serviceResponse =
applyRetry(
vertexAIEmbeddingProviderClient.embed(
"Bearer " + apiKeyOverride.get(), modelName, request));
vertexAIEmbeddingProviderClient.embed("Bearer " + apiKey.get(), modelName, request));

return serviceResponse
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,17 @@ record Usage(int total_tokens) {}
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
final String inputType =
(embeddingRequestType == EmbeddingRequestType.SEARCH) ? requestTypeQuery : requestTypeIndex;
String[] textArray = new String[texts.size()];
EmbeddingRequest request =
new EmbeddingRequest(inputType, texts.toArray(textArray), modelName, autoTruncate);

Uni<EmbeddingResponse> response =
applyRetry(
voyageAIEmbeddingProviderClient.embed("Bearer " + apiKeyOverride.get(), request));
applyRetry(voyageAIEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ public CustomITEmbeddingProvider(int dimension) {
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
List<float[]> response = new ArrayList<>(texts.size());
if (texts.size() == 0) return Uni.createFrom().item(Response.of(batchId, response));
if (!apiKeyOverride.isPresent() || !apiKeyOverride.get().equals(TEST_API_KEY))
if (!apiKey.isPresent() || !apiKey.get().equals(TEST_API_KEY))
return Uni.createFrom().failure(new RuntimeException("Invalid API Key"));
for (String text : texts) {
if (dimension == 5) {
Expand Down

0 comments on commit 197bfd2

Please sign in to comment.