Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1066: Implement VoyageAI embedding client #1068

Merged
merged 8 commits into from
May 8, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,20 @@ public static ServiceConfig custom(Optional<Class<?>> implementationClass) {
}
}

record RequestProperties(int maxRetries, int retryDelayInMillis, int timeoutInMillis) {
record RequestProperties(
int maxRetries,
int retryDelayInMillis,
int timeoutInMillis,
Optional<String> requestTypeQuery,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to map 2 recently added request-properties to pass to embedding providers.

Optional<String> requestTypeIndex) {
public static RequestProperties of(
int maxRetries, int retryDelayInMillis, int timeoutInMillis) {
return new RequestProperties(maxRetries, retryDelayInMillis, timeoutInMillis);
int maxRetries,
int retryDelayInMillis,
int timeoutInMillis,
Optional<String> requestTypeQuery,
Optional<String> requestTypeIndex) {
return new RequestProperties(
maxRetries, retryDelayInMillis, timeoutInMillis, requestTypeQuery, requestTypeIndex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration(
|| !config.providers().get(serviceName).enabled()) {
throw ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException(serviceName);
}
EmbeddingProvidersConfig.EmbeddingProviderConfig.RequestProperties properties =
config.providers().get(serviceName).properties();
return ServiceConfig.provider(
serviceName,
serviceName,
config.providers().get(serviceName).url().toString(),
RequestProperties.of(
config.providers().get(serviceName).properties().maxRetries(),
config.providers().get(serviceName).properties().retryDelayMillis(),
config.providers().get(serviceName).properties().requestTimeoutMillis()));
properties.maxRetries(),
properties.retryDelayMillis(),
properties.requestTimeoutMillis(),
properties.taskTypeRead(),
properties.taskTypeStore()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public final class ProviderConstants {
public static final String VERTEXAI = "vertexai";
public static final String COHERE = "cohere";
public static final String NVIDIA = "nvidia";
public static final String VOYAGE_AI = "voyageAI";
public static final String CUSTOM = "custom";

// Private constructor to prevent instantiation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
import jakarta.inject.Inject;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;

@ApplicationScoped
public class EmbeddingProviderFactory {

private static Logger logger = org.slf4j.LoggerFactory.getLogger(EmbeddingProviderFactory.class);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was not being used, removed.

@Inject Instance<EmbeddingProviderConfigStore> embeddingProviderConfigStore;

@Inject OperationsConfig config;
Expand All @@ -41,6 +38,7 @@ EmbeddingProvider create(
Map.entry(ProviderConstants.AZURE_OPENAI, AzureOpenAIEmbeddingClient::new),
Map.entry(ProviderConstants.HUGGINGFACE, HuggingFaceEmbeddingClient::new),
Map.entry(ProviderConstants.VERTEXAI, VertexAIEmbeddingClient::new),
Map.entry(ProviderConstants.VOYAGE_AI, VoyageAIEmbeddingClient::new),
Map.entry(ProviderConstants.COHERE, CohereEmbeddingClient::new),
Map.entry(ProviderConstants.NVIDIA, NvidiaEmbeddingClient::new));

Expand All @@ -49,18 +47,18 @@ public EmbeddingProvider getConfiguration(
String serviceName,
String modelName,
int dimension,
Map<String, Object> vectorizeServiceParameter,
Map<String, Object> vectorizeServiceParameters,
String commandName) {
return addService(
tenant, serviceName, modelName, dimension, vectorizeServiceParameter, commandName);
tenant, serviceName, modelName, dimension, vectorizeServiceParameters, commandName);
}

private synchronized EmbeddingProvider addService(
Optional<String> tenant,
String serviceName,
String modelName,
int dimension,
Map<String, Object> vectorizeServiceParameter,
Map<String, Object> vectorizeServiceParameters,
String commandName) {
final EmbeddingProviderConfigStore.ServiceConfig configuration =
embeddingProviderConfigStore.get().getConfiguration(tenant, serviceName);
Expand All @@ -73,7 +71,7 @@ private synchronized EmbeddingProvider addService(
configuration.baseUrl(),
modelName,
embeddingService,
vectorizeServiceParameter,
vectorizeServiceParameters,
commandName);
}

Expand Down Expand Up @@ -105,6 +103,6 @@ private synchronized EmbeddingProvider addService(
configuration.baseUrl(),
modelName,
dimension,
vectorizeServiceParameter);
vectorizeServiceParameters);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package io.stargate.sgv2.jsonapi.service.embedding.operation;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Uni;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation;
import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;

public class VoyageAIEmbeddingClient implements EmbeddingProvider {
private EmbeddingProviderConfigStore.RequestProperties requestProperties;
private String modelName;
private final VoyageAIEmbeddingProvider embeddingProvider;

private final String requestTypeQuery, requestTypeIndex;

public VoyageAIEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
String baseUrl,
String modelName,
int dimension,
Map<String, Object> vectorizeServiceParameters) {
this.requestProperties = requestProperties;
this.modelName = modelName;
requestTypeQuery = requestProperties.requestTypeQuery().orElse(null);
requestTypeIndex = requestProperties.requestTypeIndex().orElse(null);

embeddingProvider =
QuarkusRestClientBuilder.newBuilder()
.baseUri(URI.create(baseUrl))
.readTimeout(requestProperties.timeoutInMillis(), TimeUnit.MILLISECONDS)
.build(VoyageAIEmbeddingProvider.class);
}

@RegisterRestClient
@RegisterProvider(EmbeddingProviderResponseValidation.class)
public interface VoyageAIEmbeddingProvider {
@POST
// no path specified, as it is already included in the baseUri
@ClientHeaderParam(name = "Content-Type", value = "application/json")
Uni<EmbeddingResponse> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}

record EmbeddingRequest(
@JsonInclude(JsonInclude.Include.NON_EMPTY) String input_type,
String[] input,
String model,
@JsonInclude(JsonInclude.Include.NON_NULL) Boolean truncation) {}

@JsonIgnoreProperties({"object"})
record EmbeddingResponse(Data[] data, String model, Usage usage) {
@JsonIgnoreProperties({"object"})
record Data(int index, float[] embedding) {}

record Usage(int total_tokens) {}
}

@Override
public Uni<List<float[]>> vectorize(
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
final String inputType =
(embeddingRequestType == EmbeddingRequestType.SEARCH) ? requestTypeQuery : requestTypeIndex;
String[] textArray = new String[texts.size()];
// !!! TODO: bind "truncation" from "autoTruncate" service parameter
EmbeddingRequest request =
new EmbeddingRequest(inputType, texts.toArray(textArray), modelName, true);
Uni<EmbeddingResponse> response =
embeddingProvider
.embed("Bearer " + apiKeyOverride.get(), request)
.onFailure(
throwable -> {
return (throwable.getCause() != null
&& throwable.getCause() instanceof JsonApiException jae
&& jae.getErrorCode() == ErrorCode.EMBEDDING_PROVIDER_TIMEOUT);
})
.retry()
.withBackOff(Duration.ofMillis(requestProperties.retryDelayInMillis()))
.atMost(requestProperties.maxRetries());
return response
.onItem()
.transform(
resp -> {
if (resp.data() == null) {
return Collections.emptyList();
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
return Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import org.junit.jupiter.api.Test;

@QuarkusTest
Expand All @@ -29,7 +28,7 @@ public class EmbeddingGatewayClientTest {
public static final String TESTING_COMMAND_NAME = "test_command";

@Test
void handleValidResponse() throws ExecutionException, InterruptedException {
void handleValidResponse() {
EmbeddingService embeddingService = mock(EmbeddingService.class);
final EmbeddingGateway.EmbeddingResponse.Builder builder =
EmbeddingGateway.EmbeddingResponse.newBuilder();
Expand All @@ -48,7 +47,8 @@ void handleValidResponse() throws ExecutionException, InterruptedException {
when(embeddingService.embed(any())).thenReturn(Uni.createFrom().item(builder.build()));
EmbeddingGatewayClient embeddingGatewayClient =
new EmbeddingGatewayClient(
EmbeddingProviderConfigStore.RequestProperties.of(5, 5, 5),
EmbeddingProviderConfigStore.RequestProperties.of(
5, 5, 5, Optional.empty(), Optional.empty()),
"openai",
1536,
Optional.of("default"),
Expand Down Expand Up @@ -76,7 +76,7 @@ void handleValidResponse() throws ExecutionException, InterruptedException {
}

@Test
void handleError() throws ExecutionException, InterruptedException {
void handleError() {
EmbeddingService embeddingService = mock(EmbeddingService.class);
final EmbeddingGateway.EmbeddingResponse.Builder builder =
EmbeddingGateway.EmbeddingResponse.newBuilder();
Expand All @@ -92,7 +92,8 @@ void handleError() throws ExecutionException, InterruptedException {
when(embeddingService.embed(any())).thenReturn(Uni.createFrom().item(builder.build()));
EmbeddingGatewayClient embeddingGatewayClient =
new EmbeddingGatewayClient(
EmbeddingProviderConfigStore.RequestProperties.of(5, 5, 5),
EmbeddingProviderConfigStore.RequestProperties.of(
5, 5, 5, Optional.empty(), Optional.empty()),
"openai",
1536,
Optional.of("default"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class NvidiaEmbeddingClientTest {
public void test429() throws Exception {
Throwable exception =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -51,7 +52,8 @@ public void test429() throws Exception {
public void test4xx() throws Exception {
Throwable exception =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -74,7 +76,8 @@ public void test4xx() throws Exception {
public void test5xx() throws Exception {
Throwable exception =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -97,7 +100,8 @@ public void test5xx() throws Exception {
public void testRetryError() throws Exception {
Throwable exception =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -118,7 +122,8 @@ public void testRetryError() throws Exception {
public void testCorrectHeaderAndBody() {
List<float[]> result =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -138,7 +143,8 @@ public void testCorrectHeaderAndBody() {
public void testIncorrectContentType() {
Throwable exception =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -163,7 +169,8 @@ public void testIncorrectContentType() {
public void testNoJsonResponse() {
Throwable exception =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand All @@ -188,7 +195,8 @@ public void testNoJsonResponse() {
public void testEmptyJsonResponse() {
List<float[]> result =
new NvidiaEmbeddingClient(
EmbeddingProviderConfigStore.RequestProperties.of(2, 100, 3000),
EmbeddingProviderConfigStore.RequestProperties.of(
2, 100, 3000, Optional.empty(), Optional.empty()),
config.providers().get("nvidia").url(),
"test",
DEFAULT_DIMENSIONS,
Expand Down