Skip to content

Commit

Permalink
Micro batching for embedding clients (#1122)
Browse files Browse the repository at this point in the history
Co-authored-by: Tatu Saloranta <tatu.saloranta@datastax.com>
  • Loading branch information
maheshrajamani and tatu-at-datastax authored May 29, 2024
1 parent 73b8528 commit 2b1d285
Show file tree
Hide file tree
Showing 26 changed files with 349 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) {
collectionSettings.collectionName());
}
Uni<List<float[]>> vectors =
embeddingProvider.vectorize(
vectorizeTexts, embeddingApiKey, EmbeddingProvider.EmbeddingRequestType.INDEX);
embeddingProvider
.vectorize(
1,
vectorizeTexts,
embeddingApiKey,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(res -> res.embeddings());
return vectors
.onItem()
.transform(
Expand Down Expand Up @@ -169,8 +174,13 @@ public Uni<Boolean> vectorize(SortClause sortClause) {
collectionSettings.collectionName());
}
Uni<List<float[]>> vectors =
embeddingProvider.vectorize(
List.of(text), embeddingApiKey, EmbeddingProvider.EmbeddingRequestType.SEARCH);
embeddingProvider
.vectorize(
1,
List.of(text),
embeddingApiKey,
EmbeddingProvider.EmbeddingRequestType.SEARCH)
.map(res -> res.embeddings());
return vectors
.onItem()
.transform(
Expand Down Expand Up @@ -250,8 +260,13 @@ private Uni<Boolean> updateVectorize(ObjectNode node) {
node.putNull(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD);
} else {
final Uni<List<float[]>> vectors =
embeddingProvider.vectorize(
List.of(text), embeddingApiKey, EmbeddingProvider.EmbeddingRequestType.INDEX);
embeddingProvider
.vectorize(
1,
List.of(text),
embeddingApiKey,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(res -> res.embeddings());
return vectors
.onItem()
.transform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,24 @@ record RequestProperties(
int retryDelayInMillis,
int timeoutInMillis,
Optional<String> requestTypeQuery,
Optional<String> requestTypeIndex) {
Optional<String> requestTypeIndex,
// `maxBatchSize` is the maximum number of documents to be sent in a single request to the
// embedding provider
int maxBatchSize) {
public static RequestProperties of(
int maxRetries,
int retryDelayInMillis,
int timeoutInMillis,
Optional<String> requestTypeQuery,
Optional<String> requestTypeIndex) {
Optional<String> requestTypeIndex,
int maxBatchSize) {
return new RequestProperties(
maxRetries, retryDelayInMillis, timeoutInMillis, requestTypeQuery, requestTypeIndex);
maxRetries,
retryDelayInMillis,
timeoutInMillis,
requestTypeQuery,
requestTypeIndex,
maxBatchSize);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ interface RequestProperties {

@Nullable
Optional<String> taskTypeRead();

/** Maximum batch size supported by the provider. */
int maxBatchSize();
}

enum ParameterType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ public record RequestPropertiesImpl(
int requestTimeoutMillis,
Optional<String> maxInputLength,
Optional<String> taskTypeStore,
Optional<String> taskTypeRead)
Optional<String> taskTypeRead,
int maxBatchSize)
implements RequestProperties {
public RequestPropertiesImpl(
EmbeddingGateway.GetSupportedProvidersResponse.ProviderConfig.RequestProperties
Expand All @@ -84,9 +85,10 @@ public RequestPropertiesImpl(
grpcProviderConfigProperties.getMaxRetries(),
grpcProviderConfigProperties.getRetryDelayMillis(),
grpcProviderConfigProperties.getRequestTimeoutMillis(),
Optional.of(grpcProviderConfigProperties.getMaxInputLength()),
Optional.of(grpcProviderConfigProperties.getTaskTypeStore()),
Optional.of(grpcProviderConfigProperties.getTaskTypeRead()));
Optional.ofNullable(grpcProviderConfigProperties.getMaxInputLength()),
Optional.ofNullable(grpcProviderConfigProperties.getTaskTypeStore()),
Optional.ofNullable(grpcProviderConfigProperties.getTaskTypeRead()),
grpcProviderConfigProperties.getMaxBatchSize());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration(
properties.retryDelayMillis(),
properties.requestTimeoutMillis(),
properties.taskTypeRead(),
properties.taskTypeStore()));
properties.taskTypeStore(),
properties.maxBatchSize()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public EmbeddingGatewayClient(
* @return
*/
@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand Down Expand Up @@ -161,18 +162,30 @@ else if (value instanceof Boolean)
resp.getError().getErrorMessage());
}
if (resp.getEmbeddingsList() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
return resp.getEmbeddingsList().stream()
.map(
data -> {
float[] embedding = new float[data.getEmbeddingCount()];
for (int i = 0; i < data.getEmbeddingCount(); i++) {
embedding[i] = data.getEmbedding(i);
}
return embedding;
})
.toList();
final List<float[]> vectors =
resp.getEmbeddingsList().stream()
.map(
data -> {
float[] embedding = new float[data.getEmbeddingCount()];
for (int i = 0; i < data.getEmbeddingCount(); i++) {
embedding[i] = data.getEmbedding(i);
}
return embedding;
})
.toList();
return Response.of(batchId, vectors);
});
}

/**
 * Returns {@link Integer#MAX_VALUE} because the batching is done inside EGW (presumably the
 * embedding gateway this client talks to), so no batch-size limit is applied here.
 *
 * @return {@code Integer.MAX_VALUE}
 */
@Override
public int maxBatchSize() {
return Integer.MAX_VALUE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io.stargate.sgv2.jsonapi.service.embedding.util.EmbeddingUtil;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
Expand Down Expand Up @@ -66,7 +65,7 @@ Uni<EmbeddingResponse> embed(
@HeaderParam("api-key") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand All @@ -83,7 +82,8 @@ private record Usage(int prompt_tokens, int total_tokens) {}
}

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -107,10 +107,17 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp.data() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
return Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
List<float[]> vectors =
Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
return Response.of(batchId, vectors);
});
}

/** Returns the maximum batch size taken from this client's {@code requestProperties}. */
@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Collections;
Expand Down Expand Up @@ -62,7 +61,7 @@ Uni<EmbeddingResponse> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand Down Expand Up @@ -90,7 +89,8 @@ public void setEmbeddings(List<float[]> embeddings) {
private static final String SEARCH_DOCUMENT = "search_document";

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -117,9 +117,14 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp.getEmbeddings() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
return resp.getEmbeddings();
return Response.of(batchId, resp.getEmbeddings());
});
}

/** Returns the maximum batch size taken from this client's {@code requestProperties}. */
@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,33 @@ public interface EmbeddingProvider {
* @param apiKeyOverride Optional API key to be used for this request. If not provided, the
* default API key will be used.
* @param embeddingRequestType Type of request (INDEX or SEARCH)
* @return List of embeddings for the given texts
* @return a {@link Response} holding the given batchId and the embeddings for the given texts
*/
Uni<List<float[]>> vectorize(
Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType);

/**
 * Returns the maximum batch size supported by the provider.
 *
 * @return the maximum batch size supported by the provider
 */
int maxBatchSize();

/**
 * Record to hold the batchId and embedding vectors.
 *
 * @param batchId sequence number for the batch to order the vectors
 * @param embeddings embedding vectors for the given text inputs
 */
record Response(int batchId, List<float[]> embeddings) {
/** Static factory: creates a {@code Response} from the given batchId and embeddings. */
public static Response of(int batchId, List<float[]> embeddings) {
return new Response(batchId, embeddings);
}
}

enum EmbeddingRequestType {
/** This is used when vectorizing data in write operation for indexing */
INDEX,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.Collections;
Expand Down Expand Up @@ -60,7 +59,7 @@ Uni<List<float[]>> embed(
EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand All @@ -70,7 +69,8 @@ public record Options(boolean waitForModel) {}
}

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -90,9 +90,14 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
return resp;
return Response.of(batchId, resp);
});
}

/** Returns the maximum batch size taken from this client's {@code requestProperties}. */
@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.time.Duration;
import java.util.*;
Expand Down Expand Up @@ -53,7 +52,7 @@ Uni<EmbeddingResponse> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
static RuntimeException mapException(Response response) {
static RuntimeException mapException(jakarta.ws.rs.core.Response response) {
return HttpResponseErrorMessageMapper.getDefaultException(response);
}
}
Expand All @@ -68,7 +67,8 @@ private record Usage(int prompt_tokens, int total_tokens) {}
}

@Override
public Uni<List<float[]>> vectorize(
public Uni<Response> vectorize(
int batchId,
List<String> texts,
Optional<String> apiKeyOverride,
EmbeddingRequestType embeddingRequestType) {
Expand All @@ -94,10 +94,17 @@ public Uni<List<float[]>> vectorize(
.transform(
resp -> {
if (resp.data() == null) {
return Collections.emptyList();
return Response.of(batchId, Collections.emptyList());
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
return Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList();
List<float[]> vectors =
Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList();
return Response.of(batchId, vectors);
});
}

/** Returns the maximum batch size taken from this client's {@code requestProperties}. */
@Override
public int maxBatchSize() {
return requestProperties.maxBatchSize();
}
}
Loading

0 comments on commit 2b1d285

Please sign in to comment.