HTTP stream response via HTTP/1.1 chunked encoding #2233
@@ -1,13 +1,20 @@
package org.pytorch.serve.job;

import static org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;
import java.util.ArrayList;
import java.util.Map;
@@ -38,6 +45,11 @@ public class RestJob extends Job {

    private ChannelHandlerContext ctx;
    private CompletableFuture<byte[]> responsePromise;
    /**
     * numStreams is used to track 4 cases:
     *   -1: stream end
     *    0: non-stream response (default use case)
     *    1: the first stream response
     *   [2, Integer.MAX_VALUE]: the 2nd and subsequent stream responses
     */
    private int numStreams;
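The rule that drives these four states appears in responseInference in a later hunk: the counter is bumped while the backend keeps marking responses with TS_STREAM_NEXT=true and is pinned to -1 on the final one. A standalone sketch of that transition rule, with the class, method name, and the tiny demo invented purely for illustration:

```java
final class NumStreamsSketch {
    // Mirrors the update shown in RestJob.responseInference: each intermediate
    // stream response increments the counter, the final response (TS_STREAM_NEXT
    // header not "true") pins it to -1, and a job that never streams stays at 0.
    static int next(int numStreams, String tsStreamNextHeader) {
        return "true".equals(tsStreamNextHeader) ? numStreams + 1 : -1;
    }

    public static void main(String[] args) {
        int state = 0;                 // non-stream / initial state
        state = next(state, "true");   // 1: first stream response
        state = next(state, "true");   // 2: second stream response
        state = next(state, "false");  // -1: stream end
        System.out.println(state);     // prints -1
    }
}
```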
    public RestJob(
            ChannelHandlerContext ctx,

@@ -47,6 +59,7 @@ public RestJob(
            RequestInput input) {
        super(modelName, version, cmd, input);
        this.ctx = ctx;
        this.numStreams = 0;
    }

    @Override
@@ -117,7 +130,14 @@ private void responseInference(
                (statusPhrase == null)
                        ? HttpResponseStatus.valueOf(statusCode)
                        : new HttpResponseStatus(statusCode, statusPhrase);
        FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false);
        HttpResponse resp;

        if (responseHeaders != null && responseHeaders.containsKey(TS_STREAM_NEXT)) {
            resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status, false);
            numStreams = responseHeaders.get(TS_STREAM_NEXT).equals("true") ? numStreams + 1 : -1;
        } else {
            resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false);
        }

        if (contentType != null && contentType.length() > 0) {
            resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
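One thing this hunk does not show is where Transfer-Encoding: chunked ends up on the header-only DefaultHttpResponse; presumably that happens inside NettyUtils.sendHttpResponse, which is not part of this excerpt. For readers unfamiliar with Netty, a minimal sketch of how a response is typically marked chunked; the class and method names below are invented for illustration, not taken from this PR:

```java
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;

final class ChunkedHeaderSketch {
    // Builds a header-only response and marks it chunked, so the body can be
    // streamed afterwards as a sequence of DefaultHttpContent chunks.
    static HttpResponse newChunkedResponse(HttpResponseStatus status) {
        HttpResponse resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status, false);
        // Sets "Transfer-Encoding: chunked" and removes any Content-Length header.
        HttpUtil.setTransferEncodingChunked(resp, true);
        return resp;
    }
}
```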
@@ -127,7 +147,9 @@ private void responseInference(
                resp.headers().set(e.getKey(), e.getValue());
            }
        }
        resp.content().writeBytes(body);
        if (resp instanceof DefaultFullHttpResponse) {
            ((DefaultFullHttpResponse) resp).content().writeBytes(body);
        }

        /*
         * We can load the models based on the configuration file. Since this Job is
@@ -136,29 +158,42 @@ private void responseInference(
         * by external clients.
         */
        if (ctx != null) {
            MetricAggregator.handleInferenceMetric(
                    getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime);
            NettyUtils.sendHttpResponse(ctx, resp, true);
            if (numStreams == 0) { // non-stream response
                MetricAggregator.handleInferenceMetric(
                        getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime);
                NettyUtils.sendHttpResponse(ctx, resp, true);
            } else if (numStreams == -1) { // the last response in a stream
                MetricAggregator.handleInferenceMetric(
                        getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime);
                ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(body)));
                ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
Comment on lines +168 to +169:

Reviewer: Minor/curious: instead of writing twice, can you just write once and tell Netty that this is the last chunk? Also, assuming you can't do that, do you need to …

Author: It needs to call writeAndFlush twice, kind of similar to gRPC (onComplete).
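For what it is worth, Netty does allow the final payload and the chunked-encoding terminator to travel in a single write via DefaultLastHttpContent; whether that would fit here depends on details outside this diff. A minimal sketch, with the helper class and method name invented for illustration:

```java
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultLastHttpContent;

final class LastChunkSketch {
    // Sends the final body bytes and the terminating chunk in one write, instead
    // of DefaultHttpContent followed by LastHttpContent.EMPTY_LAST_CONTENT.
    static void sendFinalChunk(ChannelHandlerContext ctx, byte[] body) {
        ctx.writeAndFlush(new DefaultLastHttpContent(Unpooled.wrappedBuffer(body)));
    }
}
```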
            } else if (numStreams == 1) { // the first response in a stream
                NettyUtils.sendHttpResponse(ctx, resp, true);
            } else if (numStreams > 1) { // the 2nd+ response in a stream
                ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(body)));
            }
        } else if (responsePromise != null) {
            responsePromise.complete(body);
        }
        logger.debug(
                "Waiting time ns: {}, Backend time ns: {}",
                getScheduled() - getBegin(),
                System.nanoTime() - getScheduled());
        String queueTime =
                String.valueOf(
                        TimeUnit.MILLISECONDS.convert(
                                getScheduled() - getBegin(), TimeUnit.NANOSECONDS));
        loggerTsMetrics.info(
                "{}",
                new Metric(
                        "QueueTime",
                        queueTime,
                        "ms",
                        ConfigManager.getInstance().getHostName(),
                        DIMENSION));
        if (numStreams <= 0) {
            logger.debug(
                    "Waiting time ns: {}, Backend time ns: {}",
                    getScheduled() - getBegin(),
                    System.nanoTime() - getScheduled());
            String queueTime =
                    String.valueOf(
                            TimeUnit.MILLISECONDS.convert(
                                    getScheduled() - getBegin(), TimeUnit.NANOSECONDS));
            loggerTsMetrics.info(
                    "{}",
                    new Metric(
                            "QueueTime",
                            queueTime,
                            "ms",
                            ConfigManager.getInstance().getHostName(),
                            DIMENSION));
        }
    }

    @Override
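End to end, a client sees an ordinary HTTP/1.1 response whose body arrives incrementally as the backend emits intermediate results. A minimal Java 11 HttpClient sketch of consuming such a stream; the host, port, model name, and request body below are assumptions for illustration, not taken from this PR:

```java
import java.io.InputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class StreamingClientSketch {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request =
                HttpRequest.newBuilder()
                        // Assumed inference URL shape and model name.
                        .uri(URI.create("http://localhost:8080/predictions/my_stream_model"))
                        .POST(HttpRequest.BodyPublishers.ofString("sample input"))
                        .build();
        // With chunked transfer encoding, bytes become readable as each
        // intermediate response is flushed, before the whole job completes.
        HttpResponse<InputStream> response =
                client.send(request, HttpResponse.BodyHandlers.ofInputStream());
        try (InputStream in = response.body()) {
            byte[] buf = new byte[4096];
            int n;
            while ((n = in.read(buf)) != -1) {
                System.out.write(buf, 0, n);
                System.out.flush();
            }
        }
    }
}
```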
@@ -1,7 +1,5 @@
package org.pytorch.serve.wlm;

import static org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT;

import java.util.LinkedHashMap;
import java.util.Map;
import org.pytorch.serve.job.Job;
@@ -61,13 +59,18 @@ public BaseModelRequest getRequest(String threadName, WorkerState state)
        return req;
    }

    public void sendResponse(ModelWorkerResponse message) {
    /**
     * @param message a response for a batch of inference requests
     * @return true if either a non-stream response or the last stream response was sent; false if
     *     a stream response (not including the last one) was sent
     */
    public boolean sendResponse(ModelWorkerResponse message) {
        boolean jobDone = true;
        // TODO: Handle prediction level code
        if (message.getCode() == 200) {
            if (jobs.isEmpty()) {
                // this is from initial load.
                return;
                return true;
            }
            for (Predictions prediction : message.getPredictions()) {
                String jobId = prediction.getRequestId();
@@ -77,9 +80,16 @@ public void sendResponse(ModelWorkerResponse message) {
                    throw new IllegalStateException(
                            "Unexpected job in sendResponse() with 200 status code: " + jobId);
                }
                if (job.getCmd() == WorkerCommands.STREAMPREDICT
                        && prediction.getHeaders().get(TS_STREAM_NEXT).equals("true")) {
                    jobDone = false;
                if (jobDone) {
Reviewer: Do all the Java-side changes need unit tests?

Author: There is tech debt in the frontend: it lacks unit tests and only includes integration tests. This PR uses the regression suite to test end to end.
                    String streamNext =
                            prediction
                                    .getHeaders()
                                    .get(
                                            org.pytorch.serve.util.messages.RequestInput
                                                    .TS_STREAM_NEXT);
                    if (streamNext != null && streamNext.equals("true")) {
                        jobDone = false;
                    }
                }
                job.response(
                        prediction.getResp(),
@@ -103,6 +113,7 @@ public void sendResponse(ModelWorkerResponse message) {
        if (jobDone) {
            jobs.clear();
        }
        return jobDone;
    }

    public void sendError(BaseModelRequest message, String error, int status) {
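The caller of sendResponse is not shown in this excerpt, but the new boolean return is presumably what lets the worker keep pulling intermediate replies for the same batch until the stream ends. A hypothetical sketch of such a loop; the class, interface, and queue below are stand-ins for illustration, not TorchServe APIs:

```java
import java.util.concurrent.BlockingQueue;

final class StreamDrainSketch {
    /** Minimal stand-in for the aggregator's new boolean-returning sendResponse. */
    interface Aggregator<R> {
        boolean sendResponse(R message);
    }

    // Keep forwarding backend replies for the same batch until sendResponse
    // reports that the final stream response has been sent.
    static <R> void drain(BlockingQueue<R> replies, Aggregator<R> aggregator)
            throws InterruptedException {
        boolean jobDone = false;
        while (!jobDone) {
            jobDone = aggregator.sendResponse(replies.take());
        }
    }
}
```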
Reviewer: Can we just promote these test_utils functions to the core library? They're very useful.

Author: Which core library? Who is the user of the core library?

Reviewer: I mean the ts namespace.