Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow up for PR #1251: Remove Optional and centralize validation #1259

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package io.stargate.sgv2.jsonapi.api.request;

import static io.stargate.sgv2.jsonapi.config.constants.HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME;
import static io.stargate.sgv2.jsonapi.exception.ErrorCode.EMBEDDING_PROVIDER_API_KEY_MISSING;

import io.stargate.sgv2.jsonapi.api.request.tenant.DataApiTenantResolver;
import io.stargate.sgv2.jsonapi.api.request.token.DataApiTokenResolver;
import io.vertx.ext.web.RoutingContext;
Expand Down Expand Up @@ -54,4 +57,14 @@ public Optional<String> getCassandraToken() {
public Optional<String> getEmbeddingApiKey() {
return this.embeddingApiKey;
}

public String getAndValidateEmbeddingApiKey() {
Optional<String> apiKey = getEmbeddingApiKey();
if (apiKey.isEmpty()) {
throw EMBEDDING_PROVIDER_API_KEY_MISSING.toApiException(
"header value `%s` is missing in the request.",
EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME);
}
return apiKey.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Utility class to execute embedding serive to get vector embeddings for the text fields in the
Expand All @@ -30,7 +29,7 @@
public class DataVectorizer {
private final EmbeddingProvider embeddingProvider;
private final JsonNodeFactory nodeFactory;
private final Optional<String> embeddingApiKey;
private final String embeddingApiKey;
private final CollectionSettings collectionSettings;

/**
Expand All @@ -45,7 +44,7 @@ public class DataVectorizer {
public DataVectorizer(
EmbeddingProvider embeddingProvider,
JsonNodeFactory nodeFactory,
Optional<String> embeddingApiKey,
String embeddingApiKey,
CollectionSettings collectionSettings) {
this.embeddingProvider = embeddingProvider;
this.nodeFactory = nodeFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public Uni<Command> vectorize(
new DataVectorizer(
embeddingProvider,
objectMapper.getNodeFactory(),
dataApiRequestInfo.getEmbeddingApiKey(),
dataApiRequestInfo.getAndValidateEmbeddingApiKey(),
commandContext.collectionSettings());
return vectorizeSortClause(dataVectorizer, commandContext, command)
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ public EmbeddingGatewayClient(
*/
@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
Map<String, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue>
grpcVectorizeServiceParameter = new HashMap<>();
if (vectorizeServiceParameter != null) {
Expand Down Expand Up @@ -136,8 +133,8 @@ else if (value instanceof Boolean)
.setProviderName(provider)
.setTenantId(tenant.orElse(DEFAULT_TENANT_ID));
builder.putAuthTokens(DATA_API_KEY, authToken.orElse(""));
if (apiKey.isPresent()) {
builder.putAuthTokens(API_KEY, apiKey.get());
if (null != apiKey) {
builder.putAuthTokens(API_KEY, apiKey);
}
if (authentication != null) {
builder.putAllAuthTokens(authentication);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
Expand Down Expand Up @@ -111,17 +110,13 @@ private record Usage(int prompt_tokens, int total_tokens) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension);

// NOTE: NO "Bearer " prefix with API key for Azure OpenAI
Uni<EmbeddingResponse> response =
applyRetry(openAIEmbeddingProviderClient.embed(apiKey.get(), request));
applyRetry(openAIEmbeddingProviderClient.embed(apiKey, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
Expand Down Expand Up @@ -123,20 +122,15 @@ public void setEmbeddings(List<float[]> embeddings) {

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
String[] textArray = new String[texts.size()];
String input_type =
embeddingRequestType == EmbeddingRequestType.INDEX ? SEARCH_DOCUMENT : SEARCH_QUERY;
EmbeddingRequest request =
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);

Uni<EmbeddingResponse> response =
applyRetry(cohereEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));
applyRetry(cohereEmbeddingProviderClient.embed("Bearer " + apiKey, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package io.stargate.sgv2.jsonapi.service.embedding.operation;

import static io.stargate.sgv2.jsonapi.config.constants.HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME;
import static io.stargate.sgv2.jsonapi.exception.ErrorCode.EMBEDDING_PROVIDER_API_KEY_MISSING;

import io.smallrye.mutiny.Uni;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -79,10 +75,7 @@ protected <T> Uni<T> applyRetry(Uni<T> uni) {
* @return VectorResponse
*/
public abstract Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType);
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType);

/**
* returns the maximum batch size supported by the provider
Expand Down Expand Up @@ -137,15 +130,6 @@ protected String replaceParameters(String template, Map<String, Object> paramete
return baseUrl.toString();
}

/** Helper method to check if the API key is present in the header */
protected void checkEmbeddingApiKeyHeader(String providerId, Optional<String> apiKey) {
if (apiKey.isEmpty()) {
throw EMBEDDING_PROVIDER_API_KEY_MISSING.toApiException(
"header value `%s` is missing for embedding provider: %s",
EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, providerId);
}
}

/**
* Record to hold the batchId and embedding vectors
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,12 @@ private record Usage(int prompt_tokens, int total_tokens) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray));

Uni<EmbeddingResponse> response =
applyRetry(
huggingFaceDedicatedEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));
applyRetry(huggingFaceDedicatedEmbeddingProviderClient.embed("Bearer " + apiKey, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
Expand Down Expand Up @@ -91,14 +90,11 @@ public record Options(boolean waitForModel) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
EmbeddingRequest request = new EmbeddingRequest(texts, new EmbeddingRequest.Options(true));

return applyRetry(
huggingFaceEmbeddingProviderClient.embed("Bearer " + apiKey.get(), modelName, request))
huggingFaceEmbeddingProviderClient.embed("Bearer " + apiKey, modelName, request))
.onItem()
.transform(
resp -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,11 @@ private record Usage(int prompt_tokens, int total_tokens) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
EmbeddingRequest request = new EmbeddingRequest(texts, modelName);

Uni<EmbeddingResponse> response =
applyRetry(jinaAIEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));
applyRetry(jinaAIEmbeddingProviderClient.embed("Bearer " + apiKey, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.commons.lang3.tuple.Pair;

/**
Expand Down Expand Up @@ -50,10 +49,7 @@ public MeteredEmbeddingProvider(
*/
@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
// String bytes metrics for vectorize
DistributionSummary ds =
DistributionSummary.builder(jsonApiMetricsConfig.vectorizeInputBytesMetrics())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,11 @@ private record Usage(int prompt_tokens, int total_tokens, int completion_tokens)

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
EmbeddingRequest request = new EmbeddingRequest(texts, modelName, "float");

Uni<EmbeddingResponse> response =
applyRetry(mistralEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));
applyRetry(mistralEmbeddingProviderClient.embed("Bearer " + apiKey, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
Expand Down Expand Up @@ -103,20 +102,15 @@ private record Usage(int prompt_tokens, int total_tokens) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);

int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
String[] textArray = new String[texts.size()];
String input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? PASSAGE : QUERY;

EmbeddingRequest request =
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);

Uni<EmbeddingResponse> response =
applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));
applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer " + apiKey, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
Expand Down Expand Up @@ -112,11 +111,7 @@ private record Usage(int prompt_tokens, int total_tokens) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension);
String organizationId = (String) vectorizeServiceParameters.get("organizationId");
Expand All @@ -125,7 +120,7 @@ public Uni<Response> vectorize(
Uni<EmbeddingResponse> response =
applyRetry(
openAIEmbeddingProviderClient.embed(
"Bearer " + apiKey.get(), organizationId, projectId, request));
"Bearer " + apiKey, organizationId, projectId, request));

return response
.onItem()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
Expand Down Expand Up @@ -118,11 +117,7 @@ record Usage(int prompt_tokens, int total_tokens) {}

@Override
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKey,
EmbeddingRequestType embeddingRequestType) {
checkEmbeddingApiKeyHeader(providerId, apiKey);
int batchId, List<String> texts, String apiKey, EmbeddingRequestType embeddingRequestType) {
// Oddity: Implementation does not support batching, so we only accept "batches"
// of 1 String, fail for others
if (texts.size() != 1) {
Expand All @@ -140,7 +135,7 @@ public Uni<Response> vectorize(
EmbeddingRequest request = new EmbeddingRequest(texts.get(0), modelName);

Uni<EmbeddingResponse> response =
applyRetry(upstageAIEmbeddingProviderClient.embed("Bearer " + apiKey.get(), request));
applyRetry(upstageAIEmbeddingProviderClient.embed("Bearer " + apiKey, request));

return response
.onItem()
Expand Down
Loading