http stream response via http 1.1 chunked encoding #2233

Merged: 13 commits, Apr 21, 2023
33 changes: 30 additions & 3 deletions docs/inference_api.md
@@ -75,7 +75,7 @@ To get predictions from a specific version of each loaded model, make a REST call

* POST /predictions/{model_name}/{version}

## curl Example
### curl Example

```bash
curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
@@ -95,6 +95,34 @@ The result is JSON that tells you that the image is most likely a tabby cat. The
"probability": 0.42514491081237793
}
```
## Streaming response via HTTP 1.1 chunked encoding

TorchServe's inference API supports streaming responses, allowing a sequence of inference results to be sent over HTTP 1.1 chunked encoding. This feature is recommended only for use cases where the inference latency of the full response is high and intermediate results are useful to the client. An example is LLMs for generative applications, where generating "n" tokens can have high latency; here the client can receive each generated token as soon as it is ready, until the full response completes. To produce a streaming response, the backend handler calls "send_intermediate_predict_response" to send each intermediate result to the frontend, and returns the last result in the existing style. For example:
```python
from ts.protocol.otf_message_handler import send_intermediate_predict_response

def handle(data, context):
    if type(data) is list:
        for i in range(3):
            send_intermediate_predict_response(["intermediate_response"], context.request_ids, "Intermediate Prediction success", 200, context)
        return ["hello world "]
```
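
The echo handler above returns a fixed string. For the LLM use case mentioned earlier, the same mechanism can stream one generated token per chunk. The sketch below is illustrative only and is not part of this PR; `generate_tokens` is a hypothetical helper that yields decoded tokens one at a time.

```python
from ts.protocol.otf_message_handler import send_intermediate_predict_response

def handle(data, context):
    # generate_tokens is a hypothetical stand-in for real token-by-token generation.
    tokens = list(generate_tokens(data))

    # Send every token except the last as an intermediate chunk as soon as it is ready.
    for token in tokens[:-1]:
        send_intermediate_predict_response(
            [token], context.request_ids, "Intermediate Prediction success", 200, context
        )

    # Returning from handle() sends the final result and ends the chunked response.
    return [tokens[-1]]
```

The client can consume these chunks exactly as in the test below, using `stream=True` and `iter_content`.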
The client side receives the chunked data:
```python
def test_echo_stream_inference():
    test_utils.start_torchserve(no_config_snapshots=True, gen_mar=False)
    test_utils.register_model('echo_stream',
                              'https://torchserve.pytorch.org/mar_files/echo_stream.mar')

    response = requests.post(TF_INFERENCE_API + '/predictions/echo_stream', data="foo", stream=True)
    assert response.headers['Transfer-Encoding'] == 'chunked'

    prediction = []
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            prediction.append(chunk.decode("utf-8"))

    assert str(" ".join(prediction)) == "hello hello hello hello world "
    test_utils.unregister_model('echo_stream')
```

Review thread on this test:

Member: can we just promote these test_utils functions to the core library? they're very useful

Collaborator Author: which core library? who is the user of the core library?

Member: I mean the ts namespace
## Explanations API

Torchserve makes use of Captum's functionality to return the explanations of the models that are served.
@@ -181,10 +209,9 @@ The result is a json that gives you the explanations for the input json
0.007599905146155397,
,
,
,
,
]
]
]
]
}

@@ -109,7 +109,8 @@ public void response(
@Override
public void sendError(int status, String error) {
Status responseStatus = GRPCUtils.getGRPCStatusCode(status);
if (this.getCmd() == WorkerCommands.PREDICT) {
if (this.getCmd() == WorkerCommands.PREDICT
|| this.getCmd() == WorkerCommands.STREAMPREDICT) {
predictionResponseObserver.onError(
responseStatus
.withDescription(error)
77 changes: 56 additions & 21 deletions frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java
@@ -1,13 +1,20 @@
package org.pytorch.serve.job;

import static org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;
import java.util.ArrayList;
import java.util.Map;
@@ -38,6 +45,11 @@ public class RestJob extends Job {

private ChannelHandlerContext ctx;
private CompletableFuture<byte[]> responsePromise;
/**
 * numStreams is used to track 4 cases:
 *   -1: stream end
 *    0: non-stream response (default use case)
 *    1: the first stream response
 *   [2, Integer.MAX_VALUE]: the 2nd and subsequent stream responses
 */
private int numStreams;

public RestJob(
ChannelHandlerContext ctx,
@@ -47,6 +59,7 @@ public RestJob(
RequestInput input) {
super(modelName, version, cmd, input);
this.ctx = ctx;
this.numStreams = 0;
}

@Override
@@ -117,7 +130,14 @@ private void responseInference(
(statusPhrase == null)
? HttpResponseStatus.valueOf(statusCode)
: new HttpResponseStatus(statusCode, statusPhrase);
FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false);
HttpResponse resp;

if (responseHeaders != null && responseHeaders.containsKey(TS_STREAM_NEXT)) {
resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status, false);
numStreams = responseHeaders.get(TS_STREAM_NEXT).equals("true") ? numStreams + 1 : -1;
} else {
resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false);
}

if (contentType != null && contentType.length() > 0) {
resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
@@ -127,7 +147,9 @@ private void responseInference(
resp.headers().set(e.getKey(), e.getValue());
}
}
resp.content().writeBytes(body);
if (resp instanceof DefaultFullHttpResponse) {
((DefaultFullHttpResponse) resp).content().writeBytes(body);
}

/*
* We can load the models based on the configuration file.Since this Job is
@@ -136,29 +158,42 @@ private void responseInference(
* by external clients.
*/
if (ctx != null) {
MetricAggregator.handleInferenceMetric(
getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime);
NettyUtils.sendHttpResponse(ctx, resp, true);
if (numStreams == 0) { // non-stream response
MetricAggregator.handleInferenceMetric(
getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime);
NettyUtils.sendHttpResponse(ctx, resp, true);
} else if (numStreams == -1) { // the last response in a stream
MetricAggregator.handleInferenceMetric(
getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime);
ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(body)));
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
Comment on lines +168 to +169

Member: Minor/Curious: Instead of writing twice, can you just write once and tell Netty that this is the last chunk? Also, assuming you can't do that, do you need to writeAndFlush here, or can you just write, knowing that you are immediately writeAndFlush'ing the LastHttpContent chunk after?

Collaborator Author: It needs to call writeAndFlush twice, kind of similar to gRPC (onComplete).

} else if (numStreams == 1) { // the first response in a stream
NettyUtils.sendHttpResponse(ctx, resp, true);
} else if (numStreams > 1) { // the 2nd+ response in a stream
ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(body)));
}
} else if (responsePromise != null) {
responsePromise.complete(body);
}

logger.debug(
"Waiting time ns: {}, Backend time ns: {}",
getScheduled() - getBegin(),
System.nanoTime() - getScheduled());
String queueTime =
String.valueOf(
TimeUnit.MILLISECONDS.convert(
getScheduled() - getBegin(), TimeUnit.NANOSECONDS));
loggerTsMetrics.info(
"{}",
new Metric(
"QueueTime",
queueTime,
"ms",
ConfigManager.getInstance().getHostName(),
DIMENSION));
if (numStreams <= 0) {
logger.debug(
"Waiting time ns: {}, Backend time ns: {}",
getScheduled() - getBegin(),
System.nanoTime() - getScheduled());
String queueTime =
String.valueOf(
TimeUnit.MILLISECONDS.convert(
getScheduled() - getBegin(), TimeUnit.NANOSECONDS));
loggerTsMetrics.info(
"{}",
new Metric(
"QueueTime",
queueTime,
"ms",
ConfigManager.getInstance().getHostName(),
DIMENSION));
}
}

@Override
@@ -12,6 +12,7 @@
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
@@ -142,7 +143,7 @@ public static void sendError(
* @param keepAlive if keep the connection
*/
public static void sendHttpResponse(
ChannelHandlerContext ctx, FullHttpResponse resp, boolean keepAlive) {
ChannelHandlerContext ctx, HttpResponse resp, boolean keepAlive) {
// Send the response and close the connection if necessary.
Channel channel = ctx.channel();
Session session = channel.attr(SESSION_KEY).getAndSet(null);
@@ -189,7 +190,11 @@ public static void sendHttpResponse(
headers.set("Cache-Control", "no-cache; no-store, must-revalidate, private");
headers.set("Expires", "Thu, 01 Jan 1970 00:00:00 UTC");

HttpUtil.setContentLength(resp, resp.content().readableBytes());
if (resp instanceof FullHttpResponse) {
HttpUtil.setContentLength(resp, ((FullHttpResponse) resp).content().readableBytes());
} else {
HttpUtil.setTransferEncodingChunked(resp, true);
}
if (!keepAlive || code >= 400) {
headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
ChannelFuture f = channel.writeAndFlush(resp);
@@ -1,7 +1,5 @@
package org.pytorch.serve.wlm;

import static org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT;

import java.util.LinkedHashMap;
import java.util.Map;
import org.pytorch.serve.job.Job;
@@ -61,13 +59,18 @@ public BaseModelRequest getRequest(String threadName, WorkerState state)
return req;
}

public void sendResponse(ModelWorkerResponse message) {
/**
 * @param message a response to a batch of inference requests
 * @return true if either a non-stream response or the last stream response has been sent;
 *     false if an intermediate stream response (not the last one) has been sent
 */
public boolean sendResponse(ModelWorkerResponse message) {
boolean jobDone = true;
// TODO: Handle prediction level code
if (message.getCode() == 200) {
if (jobs.isEmpty()) {
// this is from initial load.
return;
return true;
}
for (Predictions prediction : message.getPredictions()) {
String jobId = prediction.getRequestId();
@@ -77,9 +80,16 @@ public void sendResponse(ModelWorkerResponse message) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with 200 status code: " + jobId);
}
if (job.getCmd() == WorkerCommands.STREAMPREDICT
&& prediction.getHeaders().get(TS_STREAM_NEXT).equals("true")) {
jobDone = false;
if (jobDone) {

Member: Do all the Java-side changes need unit tests?

Collaborator Author: There is tech debt in the frontend: it lacks unit tests and only has integration tests. This PR uses the regression suite to test end to end.

String streamNext =
prediction
.getHeaders()
.get(
org.pytorch.serve.util.messages.RequestInput
.TS_STREAM_NEXT);
if (streamNext != null && streamNext.equals("true")) {
jobDone = false;
}
}
job.response(
prediction.getResp(),
Expand All @@ -103,6 +113,7 @@ public void sendResponse(ModelWorkerResponse message) {
if (jobDone) {
jobs.clear();
}
return jobDone;
}

public void sendError(BaseModelRequest message, String error, int status) {
@@ -36,7 +36,6 @@
import org.pytorch.serve.util.codec.ModelResponseDecoder;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.InputParameter;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
@@ -187,7 +186,9 @@ public void run() {
logger.info("Flushing req.cmd {} to backend at: {}", req.getCommand(), wtStartTime);
int repeats =
(req.getCommand() == WorkerCommands.LOAD)
|| (req.getCommand() == WorkerCommands.PREDICT
|| ((req.getCommand() == WorkerCommands.PREDICT
|| req.getCommand()
== WorkerCommands.STREAMPREDICT)
&& model.getParallelLevel() > 1
&& model.getParallelType()
!= ModelConfig.ParallelType.PP)
@@ -200,46 +201,29 @@ public void run() {
boolean isStreaming =
req.getCommand() == WorkerCommands.STREAMPREDICT ? true : false;
ModelWorkerResponse reply = null;
long duration = 0;
long begin = System.currentTimeMillis();

if (!isStreaming) {
boolean jobDone = false;
long totalDuration = 0;
do {
long begin = System.currentTimeMillis();
for (int i = 0; i < repeats; i++) {
reply = replies.poll(responseTimeout, TimeUnit.SECONDS);
}

duration = System.currentTimeMillis() - begin;
logger.info("Backend response time: {}", duration);
long duration = System.currentTimeMillis() - begin;

if (reply != null) {
aggregator.sendResponse(reply);
jobDone = aggregator.sendResponse(reply);
logger.debug("sent a reply, jobdone: {}", jobDone);
} else if (req.getCommand() != WorkerCommands.DESCRIBE) {
int val = model.incrFailedInfReqs();
logger.error("Number or consecutive unsuccessful inference {}", val);
throw new WorkerInitializationException(
"Backend worker did not respond in given time");
}
} else {
ModelInferenceRequest inferReq = (ModelInferenceRequest) req;
boolean streamNext = true;
while (streamNext) {
for (int i = 0; i < repeats; i++) {
reply = replies.poll(responseTimeout, TimeUnit.SECONDS);
}
if (reply.getPredictions()
.get(0)
.getHeaders()
.get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT)
.equals("false")) {
duration = System.currentTimeMillis() - begin;
logger.info("Backend response time: {}", duration);
streamNext = false;
}
if (reply != null) {
aggregator.sendResponse(reply);
}
}
}
totalDuration += duration;
} while (!jobDone);
logger.info("Backend response time: {}", totalDuration);

switch (req.getCommand()) {
case PREDICT:
@@ -272,7 +256,8 @@ public void run() {
}
req = null;
String workerThreadTime =
String.valueOf(((System.currentTimeMillis() - wtStartTime) - duration));
String.valueOf(
((System.currentTimeMillis() - wtStartTime) - totalDuration));
loggerTsMetrics.info(
"{}",
new Metric(