Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Micro batching for embedding clients #1122

Merged
merged 13 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) {
collectionSettings.collectionName());
}
Uni<List<float[]>> vectors =
embeddingProvider.vectorize(
vectorizeTexts, embeddingApiKey, EmbeddingProvider.EmbeddingRequestType.INDEX);
embeddingProvider
.vectorize(
1,
vectorizeTexts,
embeddingApiKey,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(res -> res.embeddings());
return vectors
.onItem()
.transform(
Expand Down Expand Up @@ -169,8 +174,13 @@ public Uni<Boolean> vectorize(SortClause sortClause) {
collectionSettings.collectionName());
}
Uni<List<float[]>> vectors =
embeddingProvider.vectorize(
List.of(text), embeddingApiKey, EmbeddingProvider.EmbeddingRequestType.SEARCH);
embeddingProvider
.vectorize(
1,
List.of(text),
embeddingApiKey,
EmbeddingProvider.EmbeddingRequestType.SEARCH)
.map(res -> res.embeddings());
return vectors
.onItem()
.transform(
Expand Down Expand Up @@ -250,8 +260,13 @@ private Uni<Boolean> updateVectorize(ObjectNode node) {
node.putNull(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD);
} else {
final Uni<List<float[]>> vectors =
embeddingProvider.vectorize(
List.of(text), embeddingApiKey, EmbeddingProvider.EmbeddingRequestType.INDEX);
embeddingProvider
.vectorize(
1,
List.of(text),
embeddingApiKey,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(res -> res.embeddings());
return vectors
.onItem()
.transform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,24 @@ record RequestProperties(
int retryDelayInMillis,
int timeoutInMillis,
Optional<String> requestTypeQuery,
Optional<String> requestTypeIndex) {
Optional<String> requestTypeIndex,
// `maxBatchSize` is the maximum number of documents to be sent in a single request to be
// embedding provider
int maxBatchSize) {
public static RequestProperties of(
int maxRetries,
int retryDelayInMillis,
int timeoutInMillis,
Optional<String> requestTypeQuery,
Optional<String> requestTypeIndex) {
Optional<String> requestTypeIndex,
int maxBatchSize) {
return new RequestProperties(
maxRetries, retryDelayInMillis, timeoutInMillis, requestTypeQuery, requestTypeIndex);
maxRetries,
retryDelayInMillis,
timeoutInMillis,
requestTypeQuery,
requestTypeIndex,
maxBatchSize);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ interface RequestProperties {

@Nullable
Optional<String> taskTypeRead();

/** Maximum batch size supported by the provider. */
int maxBatchSize();
}

enum ParameterType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ public record RequestPropertiesImpl(
int requestTimeoutMillis,
Optional<String> maxInputLength,
Optional<String> taskTypeStore,
Optional<String> taskTypeRead)
Optional<String> taskTypeRead,
int maxBatchSize)
implements RequestProperties {
public RequestPropertiesImpl(
EmbeddingGateway.GetSupportedProvidersResponse.ProviderConfig.RequestProperties
Expand All @@ -84,9 +85,10 @@ public RequestPropertiesImpl(
grpcProviderConfigProperties.getMaxRetries(),
grpcProviderConfigProperties.getRetryDelayMillis(),
grpcProviderConfigProperties.getRequestTimeoutMillis(),
Optional.of(grpcProviderConfigProperties.getMaxInputLength()),
Optional.of(grpcProviderConfigProperties.getTaskTypeStore()),
Optional.of(grpcProviderConfigProperties.getTaskTypeRead()));
Optional.ofNullable(grpcProviderConfigProperties.getMaxInputLength()),
Optional.ofNullable(grpcProviderConfigProperties.getTaskTypeStore()),
Optional.ofNullable(grpcProviderConfigProperties.getTaskTypeRead()),
grpcProviderConfigProperties.getMaxBatchSize());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration(
properties.retryDelayMillis(),
properties.requestTimeoutMillis(),
properties.taskTypeRead(),
properties.taskTypeStore()));
properties.taskTypeStore(),
properties.maxBatchSize()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public EmbeddingGatewayClient(
* @return
*/
@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand Down Expand Up @@ -161,18 +162,30 @@ else if (value instanceof Boolean)
resp.getError().getErrorMessage());
}
if (resp.getEmbeddingsList() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
return resp.getEmbeddingsList().stream()
.map(
data -> {
float[] embedding = new float[data.getEmbeddingCount()];
for (int i = 0; i < data.getEmbeddingCount(); i++) {
embedding[i] = data.getEmbedding(i);
}
return embedding;
})
.toList();
final List<float[]> vectors =
resp.getEmbeddingsList().stream()
.map(
data -> {
float[] embedding = new float[data.getEmbeddingCount()];
for (int i = 0; i < data.getEmbeddingCount(); i++) {
embedding[i] = data.getEmbedding(i);
}
return embedding;
})
.toList();
return Response.of(batchId, vectors);
});
}

/**
* Return MAX_VALUE because the batching is done inside EGW
*
* @return
*/
@Override
public int maxBatchSize() {
return Integer.MAX_VALUE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io.stargate.sgv2.jsonapi.service.embedding.util.EmbeddingUtil;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
Expand Down Expand Up @@ -66,7 +65,7 @@ Uni<EmbeddingResponse> embed(
@HeaderParam("api-key") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand All @@ -83,7 +82,8 @@ private record Usage(int prompt_tokens, int total_tokens) {}
}

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -107,10 +107,17 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp.data() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
return Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
List<float[]> vectors =
Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
return Response.of(batchId, vectors);
});
}

@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Collections;
Expand Down Expand Up @@ -62,7 +61,7 @@ Uni<EmbeddingResponse> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand Down Expand Up @@ -90,7 +89,8 @@ public void setEmbeddings(List<float[]> embeddings) {
private static final String SEARCH_DOCUMENT = "search_document";

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -117,9 +117,14 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp.getEmbeddings() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
return resp.getEmbeddings();
return Response.of(batchId, resp.getEmbeddings());
});
}

@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,33 @@ public interface EmbeddingProvider {
* @param apiKeyOverride Optional API key to be used for this request. If not provided, the
* default API key will be used.
* @param embeddingRequestType Type of request (INDEX or SEARCH)
* @return List of embeddings for the given texts
* @return VectorResponse
*/
Uni<List<float[]>> vectorize(
Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType);

/**
* returns the maximum batch size supported by the provider
*
* @return
*/
int maxBatchSize();

/**
* Record to hold the batchId and embedding vectors
*
* @param batchId - Sequence number for the batch to order the vectors.
* @param embeddings - Embedding vectors for the given text inputs.
*/
record Response(int batchId, List<float[]> embeddings) {
public static Response of(int batchId, List<float[]> embeddings) {
return new Response(batchId, embeddings);
}
}

enum EmbeddingRequestType {
/** This is used when vectorizing data in write operation for indexing */
INDEX,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Collections;
Expand Down Expand Up @@ -60,7 +59,7 @@ Uni<List<float[]>> embed(
EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand All @@ -70,7 +69,8 @@ public record Options(boolean waitForModel) {}
}

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -90,9 +90,14 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
return resp;
return Response.of(batchId, resp);
});
}

@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.*;
Expand Down Expand Up @@ -53,7 +52,7 @@ Uni<EmbeddingResponse> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand All @@ -68,7 +67,8 @@ private record Usage(int prompt_tokens, int total_tokens) {}
}

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -94,10 +94,17 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp.data() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
return Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList();
List<float[]> vectors =
Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList();
return Response.of(batchId, vectors);
});
}

@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Loading