Merge pull request #1421 from pytorch/issue_1406

extend describeAPI to support customized metadata

lxning authored Feb 24, 2022
2 parents fa8cffa + a61a3fb commit 162af4f

Showing 17 changed files with 444 additions and 65 deletions.

121 changes: 121 additions & 0 deletions docs/management_api.md
@@ -285,6 +285,127 @@ curl http://localhost:8081/models/noop/all

`GET /models/{model_name}/{model_version}?customized=true`
or
`GET /models/{model_name}?customized=true`

Use the Describe Model API to get the detailed runtime status and customized metadata of a specific version of a model:
* Implement the function `describe_handle` in the custom handler (a runnable sketch follows this list). E.g.:
```python
def describe_handle(self):
"""Customized describe handler
Returns:
dict : A dictionary response.
"""
output_describe = None
logger.info("Collect customized metadata")
return output_describe
```

* Implement the function `_is_describe` if the handler does not inherit from `BaseHandler`, then call `_is_describe` and `describe_handle` in `handle`. The frontend converts the `customized=true` query parameter into a `describe` request header (see the Java changes below), which `_is_describe` checks.
```python
def _is_describe(self):
if self.context and self.context.get_request_header(0, "describe"):
if self.context.get_request_header(0, "describe") == "True":
return True
    return False

def handle(self, data, context):
if self._is_describe():
output = [self.describe_handle()]
else:
data_preprocess = self.preprocess(data)
if not self._is_explain():
output = self.inference(data_preprocess)
output = self.postprocess(output)
else:
output = self.explain_handle(data_preprocess, data)
return output
```

* Call the functions `_is_describe` and `describe_handle` in `handle`. E.g.:
```python
def handle(self, data, context):
"""Entry point for default handler. It takes the data from the input request and returns
the predicted outcome for the input.
Args:
data (list): The input data that needs to be made a prediction request on.
context (Context): It is a JSON Object containing information pertaining to
the model artefacts parameters.
Returns:
        list: Returns a list of dictionaries with the predicted response.
"""
# It can be used for pre or post processing if needed as additional request
# information is available in context
start_time = time.time()
self.context = context
metrics = self.context.metrics
is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None)
if is_profiler_enabled:
output, _ = self._infer_with_profiler(data=data)
else:
if self._is_describe():
output = [self.describe_handle()]
else:
data_preprocess = self.preprocess(data)
if not self._is_explain():
output = self.inference(data_preprocess)
output = self.postprocess(output)
else:
output = self.explain_handle(data_preprocess, data)
stop_time = time.time()
metrics.add_time('HandlerTime', round(
(stop_time - start_time) * 1000, 2), None, 'ms')
return output
```
* Here is an example. `customizedMetadata` shows the metadata from the user's model. This metadata can be decoded into a dictionary.
```bash
curl http://localhost:8081/models/noop-customized/1.0?customized=true
[
{
"modelName": "noop-customized",
"modelVersion": "1.0",
"modelUrl": "noop-customized.mar",
"runtime": "python",
"minWorkers": 1,
"maxWorkers": 1,
"batchSize": 1,
"maxBatchDelay": 100,
"loadedAtStartup": false,
"workers": [
{
"id": "9010",
"startTime": "2022-02-08T11:03:20.974Z",
"status": "READY",
"memoryUsage": 0,
"pid": 98972,
"gpu": false,
"gpuUsage": "N/A"
}
],
"customizedMetadata": "{\n \"data1\": \"1\",\n \"data2\": \"2\"\n}"
}
]
```
* Decode `customizedMetadata` on the client side. For example:
```python
import json

import requests

response = requests.get('http://localhost:8081/models/noop-customized/?customized=true').json()
# customizedMetadata arrives as a JSON string; decode it into a dictionary
customizedMetadata = json.loads(response[0]['customizedMetadata'])
print(customizedMetadata)
```
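
The `describe_handle` stub in the first bullet returns `None`. Below is a minimal runnable sketch of a version that returns real metadata; the keys `data1`/`data2` are purely illustrative and simply mirror the `customizedMetadata` output in the curl example above:
```python
import logging

logger = logging.getLogger(__name__)

# A method of the custom handler class, shown standalone like the stubs above.
def describe_handle(self):
    """Customized describe handler that returns model metadata.

    Returns:
        dict: A JSON-serializable dictionary; TorchServe returns it to the
        client as the "customizedMetadata" field of the describe response.
    """
    logger.info("Collect customized metadata")
    # Any JSON-serializable dictionary works; these keys are illustrative.
    return {"data1": "1", "data2": "2"}
```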

## Unregister a model

This API follows the [ManagementAPIsService.UnregisterModel](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/management.proto) gRPC API. It returns the status of a model in the ModelServer.
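
For example, a specific model version can be unregistered over the HTTP management API (a sketch in the same `requests` style as the snippet above; the model name `noop` and version `1.0` are placeholders):

```python
import requests

# Unregister version 1.0 of the "noop" model via the management API.
response = requests.delete('http://localhost:8081/models/noop/1.0')
print(response.json())
```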
```diff
@@ -2,6 +2,7 @@

 import io.grpc.Status;
 import io.grpc.stub.StreamObserver;
+import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import org.pytorch.serve.archive.DownloadArchiveException;
 import org.pytorch.serve.archive.model.ModelException;
@@ -18,27 +19,48 @@
 import org.pytorch.serve.http.BadRequestException;
 import org.pytorch.serve.http.InternalServerException;
 import org.pytorch.serve.http.StatusResponse;
+import org.pytorch.serve.job.GRPCJob;
+import org.pytorch.serve.job.Job;
 import org.pytorch.serve.util.ApiUtils;
 import org.pytorch.serve.util.GRPCUtils;
 import org.pytorch.serve.util.JsonUtils;
+import org.pytorch.serve.util.messages.RequestInput;
 import org.pytorch.serve.wlm.ModelManager;

 public class ManagementImpl extends ManagementAPIsServiceImplBase {

     @Override
     public void describeModel(
             DescribeModelRequest request, StreamObserver<ManagementResponse> responseObserver) {

+        String requestId = UUID.randomUUID().toString();
+        RequestInput input = new RequestInput(requestId);
         String modelName = request.getModelName();
         String modelVersion = request.getModelVersion();

-        String resp;
-        try {
-            resp =
-                    JsonUtils.GSON_PRETTY.toJson(
-                            ApiUtils.getModelDescription(modelName, modelVersion));
-            sendResponse(responseObserver, resp);
-        } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
-            sendErrorResponse(responseObserver, Status.NOT_FOUND, e);
+        boolean customized = request.getCustomized();
+
+        if ("all".equals(modelVersion) || !customized) {
+            String resp;
+            try {
+                resp =
+                        JsonUtils.GSON_PRETTY.toJson(
+                                ApiUtils.getModelDescription(modelName, modelVersion));
+                sendResponse(responseObserver, resp);
+            } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
+                sendErrorResponse(responseObserver, Status.NOT_FOUND, e);
+            }
+        } else {
+            input.updateHeaders("describe", "True");
+            Job job = new GRPCJob(responseObserver, modelName, modelVersion, input);
+
+            try {
+                if (!ModelManager.getInstance().addJob(job)) {
+                    String responseMessage = ApiUtils.getDescribeErrorResponseMessage(modelName);
+                    InternalServerException e = new InternalServerException(responseMessage);
+                    sendException(responseObserver, e, "InternalServerException.()");
+                }
+            } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
+                sendErrorResponse(responseObserver, Status.INTERNAL, e);
+            }
         }
     }

@@ -161,7 +183,7 @@ private void sendErrorResponse(
                         .asRuntimeException());
     }

-    private void sendErrorResponse(
+    public static void sendErrorResponse(
             StreamObserver<ManagementResponse> responseObserver, Status status, Exception e) {
         responseObserver.onError(
                 status.withDescription(e.getMessage())
```
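
The gRPC path above reads the new `customized` flag from `DescribeModelRequest`. A client-side sketch of calling it, assuming Python stubs generated from `management.proto` (the module names `management_pb2`/`management_pb2_grpc` are the default protoc output, the `msg` response field and the management port 7071 are assumptions based on TorchServe defaults):

```python
import grpc
import management_pb2
import management_pb2_grpc

channel = grpc.insecure_channel('localhost:7071')
stub = management_pb2_grpc.ManagementAPIsServiceStub(channel)

# "customized" is the flag this commit adds (read via request.getCustomized()).
request = management_pb2.DescribeModelRequest(
    model_name='noop-customized', model_version='1.0', customized=True)
response = stub.DescribeModel(request)
print(response.msg)  # assumed field carrying the JSON describe payload
```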
```diff
@@ -22,15 +22,19 @@
 import org.pytorch.serve.http.MethodNotAllowedException;
 import org.pytorch.serve.http.RequestTimeoutException;
 import org.pytorch.serve.http.ResourceNotFoundException;
+import org.pytorch.serve.http.ServiceUnavailableException;
 import org.pytorch.serve.http.StatusResponse;
 import org.pytorch.serve.http.messages.DescribeModelResponse;
 import org.pytorch.serve.http.messages.KFV1ModelReadyResponse;
 import org.pytorch.serve.http.messages.ListModelsResponse;
 import org.pytorch.serve.http.messages.RegisterModelRequest;
+import org.pytorch.serve.job.RestJob;
 import org.pytorch.serve.servingsdk.ModelServerEndpoint;
 import org.pytorch.serve.util.ApiUtils;
 import org.pytorch.serve.util.JsonUtils;
 import org.pytorch.serve.util.NettyUtils;
+import org.pytorch.serve.util.messages.RequestInput;
+import org.pytorch.serve.util.messages.WorkerCommands;
 import org.pytorch.serve.wlm.Model;
 import org.pytorch.serve.wlm.ModelManager;
 import org.pytorch.serve.wlm.WorkerThread;
@@ -79,7 +83,7 @@ public void handleRequest(
             modelVersion = segments[3];
         }
         if (HttpMethod.GET.equals(method)) {
-            handleDescribeModel(ctx, segments[2], modelVersion);
+            handleDescribeModel(ctx, req, segments[2], modelVersion, decoder);
         } else if (HttpMethod.PUT.equals(method)) {
             if (segments.length == 5 && "set-default".equals(segments[4])) {
                 setDefaultModelVersion(ctx, segments[2], segments[3]);
@@ -127,12 +131,31 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
     }

     private void handleDescribeModel(
-            ChannelHandlerContext ctx, String modelName, String modelVersion)
+            ChannelHandlerContext ctx,
+            FullHttpRequest req,
+            String modelName,
+            String modelVersion,
+            QueryStringDecoder decoder)
             throws ModelNotFoundException, ModelVersionNotFoundException {

-        ArrayList<DescribeModelResponse> resp =
-                ApiUtils.getModelDescription(modelName, modelVersion);
-        NettyUtils.sendJsonResponse(ctx, resp);
+        boolean customizedMetadata =
+                Boolean.parseBoolean(NettyUtils.getParameter(decoder, "customized", "false"));
+        if ("all".equals(modelVersion) || !customizedMetadata) {
+            ArrayList<DescribeModelResponse> resp =
+                    ApiUtils.getModelDescription(modelName, modelVersion);
+            NettyUtils.sendJsonResponse(ctx, resp);
+        } else {
+            String requestId = NettyUtils.getRequestId(ctx.channel());
+            RequestInput input = new RequestInput(requestId);
+            for (Map.Entry<String, String> entry : req.headers().entries()) {
+                input.updateHeaders(entry.getKey(), entry.getValue());
+            }
+            input.updateHeaders("describe", "True");
+            RestJob job = new RestJob(ctx, modelName, modelVersion, WorkerCommands.DESCRIBE, input);
+            if (!ModelManager.getInstance().addJob(job)) {
+                String responseMessage = ApiUtils.getDescribeErrorResponseMessage(modelName);
+                throw new ServiceUnavailableException(responseMessage);
+            }
+        }
     }

     private void handleKF1ModelReady(
```
```diff
@@ -1,5 +1,6 @@
 package org.pytorch.serve.http.messages;

+import java.nio.charset.Charset;
 import java.util.ArrayList;
 import java.util.Date;
 import java.util.List;
@@ -20,6 +21,7 @@ public class DescribeModelResponse {

     private List<Worker> workers;
     private Metrics metrics;
+    private String customizedMetadata;

     public DescribeModelResponse() {
         workers = new ArrayList<>();
@@ -148,6 +150,14 @@ public void setMetrics(Metrics metrics) {
         this.metrics = metrics;
     }

+    public void setCustomizedMetadata(byte[] customizedMetadata) {
+        this.customizedMetadata = new String(customizedMetadata, Charset.forName("UTF-8"));
+    }
+
+    public String getCustomizedMetadata() {
+        return customizedMetadata;
+    }
+
     public static final class Worker {

         private String id;
```
