Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Optimized text embedding post processing performance #3459

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
Expand All @@ -34,7 +33,7 @@
import java.util.List;

/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
public class CrossEncoderServingTranslator implements NoBatchifyTranslator<Input, Output> {
public class CrossEncoderServingTranslator implements Translator<Input, Output> {

private Translator<StringPair, float[]> translator;

Expand All @@ -56,74 +55,13 @@ public void prepare(TranslatorContext ctx) throws Exception {
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
PairList<String, BytesSupplier> content = input.getContent();
if (content.isEmpty()) {
throw new TranslateException("Input data is empty.");
ReRankingInput in = ReRankingInput.parseInput(input);
if (in.batch != null) {
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, in.batch);
}

String contentType = input.getProperty("Content-Type", null);
if (contentType != null) {
int pos = contentType.indexOf(';');
if (pos > 0) {
contentType = contentType.substring(0, pos);
}
}
StringPair pair = null;
if ("application/json".equals(contentType)) {
String json = input.getData().getAsString();
try {
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
if (element.isJsonArray()) {
ctx.setAttachment("batch", Boolean.TRUE);
JsonArray array = element.getAsJsonArray();
int size = array.size();
List<StringPair> inputs = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
JsonObject obj = array.get(i).getAsJsonObject();
inputs.add(parseStringPair(obj));
}
return translator.batchProcessInput(ctx, inputs);
} else if (element.isJsonObject()) {
JsonObject obj = element.getAsJsonObject();
JsonElement query = obj.get("query");
if (query != null) {
String key = query.getAsString();
JsonArray texts = obj.get("texts").getAsJsonArray();
int size = texts.size();
List<StringPair> inputs = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
String value = texts.get(i).getAsString();
inputs.add(new StringPair(key, value));
}
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, inputs);
} else {
pair = parseStringPair(obj);
}
} else {
throw new TranslateException("Unexpected json type");
}
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
} else {
String text = input.getAsString("text");
String textPair = input.getAsString("text_pair");
if (text != null && textPair != null) {
pair = new StringPair(text, textPair);
}
String key = input.getAsString("key");
String value = input.getAsString("value");
if (key != null && value != null) {
pair = new StringPair(key, value);
}
}

if (pair == null) {
throw new TranslateException("Missing key or value in input.");
}

NDList ret = translator.processInput(ctx, pair);
NDList ret = translator.processInput(ctx, in.pair);
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
NDList[] batch = {ret};
Expand All @@ -132,6 +70,27 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
return ret;
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) throws Exception {
int[] mapping = new int[inputs.size()];
List<StringPair> prompts = new ArrayList<>(mapping.length);
for (int i = 0; i < mapping.length; ++i) {
ReRankingInput in = ReRankingInput.parseInput(inputs.get(i));
if (in.batch != null) {
List<StringPair> batch = in.batch;
mapping[i] = batch.size();
prompts.addAll(batch);
} else {
mapping[i] = -1;
prompts.add(in.pair);
}
}
ctx.setAttachment("mapping", mapping);
return translator.batchProcessInput(ctx, prompts);
}

/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Expand All @@ -149,17 +108,126 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
return output;
}

private StringPair parseStringPair(JsonObject json) throws TranslateException {
JsonElement text = json.get("text");
JsonElement textPair = json.get("text_pair");
if (text != null && textPair != null) {
return new StringPair(text.getAsString(), textPair.getAsString());
/** {@inheritDoc} */
@Override
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception {
List<float[]> outputs = translator.batchProcessOutput(ctx, list);
int[] mapping = (int[]) ctx.getAttachment("mapping");
List<Output> ret = new ArrayList<>(mapping.length);
int index = 0;
for (int size : mapping) {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (size == -1) {
// non-batching
output.add(BytesSupplier.wrapAsJson(outputs.get(index++)));
} else {
// client side batching
float[][] embeddings = new float[size][];
for (int j = 0; j < size; ++j) {
embeddings[j] = outputs.get(index++);
}
output.add(BytesSupplier.wrapAsJson(embeddings));
}
ret.add(output);
}
return ret;
}

private static final class ReRankingInput {

private StringPair pair;
private List<StringPair> batch;

ReRankingInput(StringPair pair) {
this.pair = pair;
}

ReRankingInput(List<StringPair> batch) {
this.batch = batch;
}
JsonElement key = json.get("key");
JsonElement value = json.get("value");
if (key != null && value != null) {
return new StringPair(key.getAsString(), value.getAsString());

static ReRankingInput parseInput(Input input) throws TranslateException {
PairList<String, BytesSupplier> content = input.getContent();
if (content.isEmpty()) {
throw new TranslateException("Input data is empty.");
}

String contentType = input.getProperty("Content-Type", null);
if (contentType != null) {
int pos = contentType.indexOf(';');
if (pos > 0) {
contentType = contentType.substring(0, pos);
}
}
StringPair pair = null;
if ("application/json".equals(contentType)) {
String json = input.getData().getAsString();
try {
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
if (element.isJsonArray()) {
JsonArray array = element.getAsJsonArray();
int size = array.size();
List<StringPair> batch = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
JsonObject obj = array.get(i).getAsJsonObject();
batch.add(parseStringPair(obj));
}
return new ReRankingInput(batch);
} else if (element.isJsonObject()) {
JsonObject obj = element.getAsJsonObject();
JsonElement query = obj.get("query");
if (query != null) {
String key = query.getAsString();
JsonArray texts = obj.get("texts").getAsJsonArray();
int size = texts.size();
List<StringPair> batch = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
String value = texts.get(i).getAsString();
batch.add(new StringPair(key, value));
}
return new ReRankingInput(batch);
} else {
pair = parseStringPair(obj);
}
} else {
throw new TranslateException("Unexpected json type");
}
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
} else {
String text = input.getAsString("text");
String textPair = input.getAsString("text_pair");
if (text != null && textPair != null) {
pair = new StringPair(text, textPair);
}
String key = input.getAsString("key");
String value = input.getAsString("value");
if (key != null && value != null) {
pair = new StringPair(key, value);
}
}

if (pair == null) {
throw new TranslateException("Missing key or value in input.");
}
return new ReRankingInput(pair);
}

private static StringPair parseStringPair(JsonObject json) throws TranslateException {
JsonElement text = json.get("text");
JsonElement textPair = json.get("text_pair");
if (text != null && textPair != null) {
return new StringPair(text.getAsString(), textPair.getAsString());
}
JsonElement key = json.get("key");
JsonElement value = json.get("value");
if (key != null && value != null) {
return new StringPair(key.getAsString(), value.getAsString());
}
throw new TranslateException("Missing text or text_pair in json.");
}
throw new TranslateException("Missing text or text_pair in json.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception

/** {@inheritDoc} */
@Override
@SuppressWarnings({"PMD.SignatureDeclareThrowsException", "unchecked"})
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception {
List<float[]> outputs = translator.batchProcessOutput(ctx, list);
int[] mapping = (int[]) ctx.getAttachment("mapping");
Expand Down
2 changes: 1 addition & 1 deletion extensions/tokenizers/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ tasks {
downloadPath gzipInto file
}

if ("text_embedding" != task)
if (task !in arrayOf("text_embedding", "text_classification"))
continue

file = prefix / task / "ai.djl.huggingface.rust.json"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -88,14 +89,29 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
/** {@inheritDoc} */
@Override
public List<float[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
NDList[] batches = batchifier.unbatchify(list);
List<float[]> ret = new ArrayList<>(batches.length);
for (NDList batch : batches) {
NDArray result = batch.get(0);
if (sigmoid) {
if (sigmoid) {
NDList[] batches = batchifier.unbatchify(list);
List<float[]> ret = new ArrayList<>(batches.length);
for (NDList batch : batches) {
NDArray result = batch.get(0);
result = result.getNDArrayInternal().sigmoid();
ret.add(result.toFloatArray());
}
ret.add(result.toFloatArray());
return ret;
}
NDArray array = list.get(0);
int batchSize = Math.toIntExact(array.size(0));
float[] buf = list.get(0).toFloatArray();
if (batchSize == 1) {
return Collections.singletonList(buf);
}

int length = buf.length / batchSize;
List<float[]> ret = new ArrayList<>(batchSize);
for (int i = 0; i < batchSize; ++i) {
float[] f = new float[length];
System.arraycopy(buf, i * length, f, 0, length);
ret.add(f);
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -155,14 +156,20 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
/** {@inheritDoc} */
@Override
public List<float[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
int batchSize = Math.toIntExact(list.head().size(0));
NDArray attentionMask = (NDArray) ctx.getAttachment("attentionMask");
NDArray output = processEmbedding(list, attentionMask);
int batchSize = Math.toIntExact(output.size(0));
float[] buf = output.toFloatArray();
if (batchSize == 1) {
return Collections.singletonList(buf);
}

int length = buf.length / batchSize;
List<float[]> ret = new ArrayList<>(batchSize);
NDList splitList = output.split(batchSize);
for (int i = 0; i < batchSize; i++) {
NDArray array = splitList.get(i);
ret.add(array.toFloatArray());
for (int i = 0; i < batchSize; ++i) {
float[] f = new float[length];
System.arraycopy(buf, i * length, f, 0, length);
ret.add(f);
}
return ret;
}
Expand Down
Loading
Loading