diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/OpenInferenceProtocolImpl.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/OpenInferenceProtocolImpl.java index 4ff5d0a0cfc..79cc8468caa 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/OpenInferenceProtocolImpl.java +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/OpenInferenceProtocolImpl.java @@ -247,7 +247,7 @@ public void modelInfer(ModelInferRequest request, StreamObserver responseHeaders) { ByteString output = ByteString.copyFrom(body); WorkerCommands cmd = this.getCmd(); - Gson gson = new Gson(); - String jsonResponse = output.toStringUtf8(); - JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class); switch (cmd) { case PREDICT: case STREAMPREDICT: case STREAMPREDICT2: - // condition for OIP grpc ModelInfer Call - if (ConfigManager.getInstance().isOpenInferenceProtocol() && isResponseStructureOIP(jsonObject)) { - if (((ServerCallStreamObserver) modelInferResponseObserver) - .isCancelled()) { - logger.warn( - "grpc client call already cancelled, not able to send this response for requestId: {}", - getPayload().getRequestId()); - return; - } - ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder(); - responseBuilder.setId(jsonObject.get("id").getAsString()); - responseBuilder.setModelName(jsonObject.get("model_name").getAsString()); - responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString()); - JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray(); - - for (JsonElement element : jsonOutputs) { - InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder(); - outputBuilder.setName(element.getAsJsonObject().get("name").getAsString()); - outputBuilder.setDatatype(element.getAsJsonObject().get("datatype").getAsString()); - JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray(); - shapeArray.forEach(shapeElement -> 
outputBuilder.addShape(shapeElement.getAsLong())); - setOutputContents(element, outputBuilder); - responseBuilder.addOutputs(outputBuilder); - - } - modelInferResponseObserver.onNext(responseBuilder.build()); - modelInferResponseObserver.onCompleted(); - } else { - ServerCallStreamObserver responseObserver = - (ServerCallStreamObserver) predictionResponseObserver; - cancelHandler(responseObserver); - PredictionResponse reply = - PredictionResponse.newBuilder().setPrediction(output).build(); - responseObserver.onNext(reply); - if (cmd == WorkerCommands.PREDICT - || (cmd == WorkerCommands.STREAMPREDICT - && responseHeaders - .get(RequestInput.TS_STREAM_NEXT) - .equals("false"))) { - responseObserver.onCompleted(); - logQueueTime(); - } else if (cmd == WorkerCommands.STREAMPREDICT2 - && (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null - || responseHeaders - .get(RequestInput.TS_STREAM_NEXT) - .equals("false"))) { - logQueueTime(); - } + ServerCallStreamObserver responseObserver = + (ServerCallStreamObserver) predictionResponseObserver; + cancelHandler(responseObserver); + PredictionResponse reply = + PredictionResponse.newBuilder().setPrediction(output).build(); + responseObserver.onNext(reply); + if (cmd == WorkerCommands.PREDICT + || (cmd == WorkerCommands.STREAMPREDICT + && responseHeaders + .get(RequestInput.TS_STREAM_NEXT) + .equals("false"))) { + responseObserver.onCompleted(); + logQueueTime(); + } else if (cmd == WorkerCommands.STREAMPREDICT2 + && (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null + || responseHeaders + .get(RequestInput.TS_STREAM_NEXT) + .equals("false"))) { + logQueueTime(); } break; case DESCRIBE: @@ -193,6 +161,36 @@ public void response( managementResponseObserver, Status.NOT_FOUND, e); } break; + case OIPPREDICT: + Gson gson = new Gson(); + String jsonResponse = output.toStringUtf8(); + JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class); + if (((ServerCallStreamObserver) modelInferResponseObserver) + 
.isCancelled()) { + logger.warn( + "grpc client call already cancelled, not able to send this response for requestId: {}", + getPayload().getRequestId()); + return; + } + ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder(); + responseBuilder.setId(jsonObject.get("id").getAsString()); + responseBuilder.setModelName(jsonObject.get("model_name").getAsString()); + responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString()); + JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray(); + + for (JsonElement element : jsonOutputs) { + InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder(); + outputBuilder.setName(element.getAsJsonObject().get("name").getAsString()); + outputBuilder.setDatatype(element.getAsJsonObject().get("datatype").getAsString()); + JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray(); + shapeArray.forEach(shapeElement -> outputBuilder.addShape(shapeElement.getAsLong())); + setOutputContents(element, outputBuilder); + responseBuilder.addOutputs(outputBuilder); + + } + modelInferResponseObserver.onNext(responseBuilder.build()); + modelInferResponseObserver.onCompleted(); + break; default: break; } @@ -244,6 +242,14 @@ public void sendError(int status, String error) { "org.pytorch.serve.http.InternalServerException") .asRuntimeException()); break; + case OIPPREDICT: + modelInferResponseObserver.onError( + responseStatus + .withDescription(error) + .augmentDescription( + "org.pytorch.serve.http.InternalServerException") + .asRuntimeException()); + break; default: break; } @@ -317,14 +323,4 @@ private void setOutputContents(JsonElement element, InferOutputTensor.Builder ou } outputBuilder.setContents(inferTensorContents); // set output contents } - - private boolean isResponseStructureOIP(JsonObject jsonObject) { - if (jsonObject.has("id") && - jsonObject.has("model_name") && - jsonObject.has("model_version") && - jsonObject.has("outputs")) { - return true; - } 
- return false; - } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/messages/WorkerCommands.java b/frontend/server/src/main/java/org/pytorch/serve/util/messages/WorkerCommands.java index 64266c9f4e4..d4833b237e9 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/messages/WorkerCommands.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/messages/WorkerCommands.java @@ -16,7 +16,9 @@ public enum WorkerCommands { @SerializedName("streampredict") STREAMPREDICT("streampredict"), @SerializedName("streampredict2") - STREAMPREDICT2("streampredict2"); + STREAMPREDICT2("streampredict2"), + @SerializedName("oippredict") // for KServe Open Inference Protocol + OIPPREDICT("oippredict"); private String command;