feat: add KServe gRPC v2 support #2176
Merged
16 commits
3416ccf feat: add KServe gRPC v2 support
88c2117 feat: add utils to convert kserve pb to ts pb
ae00b78 add ts pb to kserve pb conversion method
0c0fd22 Add pb python file generation step at docker build
456d476 fix: readme doc
ef09c5f update readme
f8a800f fix lint errors
73a9f53 fix kserve_v2 service envelop and test data
a838ad7 re-test
f39a367 Merge branch 'master' into feat/kserve-grpc (chauhang)
97883eb Merge branch 'master' into feat/kserve-grpc (chauhang)
fef6692 re-test
2feb461 Merge branch 'master' into feat/kserve-grpc (chauhang)
912c528 Merge branch 'master' into feat/kserve-grpc (agunapal)
9d7c786 Merge branch 'master' of https://github.com/pytorch/serve into feat/k…
374f958 Merge branch 'master' into feat/kserve-grpc (msaroufim)
3 changes: 2 additions & 1 deletion
kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_bytes.json
11 changes: 11 additions & 0 deletions
kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_bytes_grpc.json
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"model_name": "mnist", | ||
"inputs": [{ | ||
"name": "312a4eb0-0ca7-4803-a101-a6d2c18486fe", | ||
"shape": [-1], | ||
"datatype": "BYTES", | ||
"contents": { | ||
"bytes_contents": ["iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII="] | ||
} | ||
}] | ||
} |
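The `bytes_contents` entry in this request is a base64 string, which is how the protobuf JSON mapping represents `bytes` fields. A minimal sketch, using only the standard library, of decoding such a payload and checking for the PNG signature. The helper name `decode_bytes_contents`, the input name `example-input`, and the shortened sample payload (just the 8-byte PNG signature) are illustrative assumptions, not part of this PR:

```python
import base64
import json

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def decode_bytes_contents(request: dict) -> list:
    """Return the raw bytes for every BYTES input in a v2 gRPC-style request dict."""
    decoded = []
    for tensor in request.get("inputs", []):
        if tensor.get("datatype") != "BYTES":
            continue
        for item in tensor["contents"]["bytes_contents"]:
            # validate=True rejects characters outside the base64 alphabet
            decoded.append(base64.b64decode(item, validate=True))
    return decoded


# Shortened stand-in for the request file: only the PNG signature, not a full image.
sample = json.loads(
    """
    {
      "model_name": "mnist",
      "inputs": [{
        "name": "example-input",
        "shape": [-1],
        "datatype": "BYTES",
        "contents": {"bytes_contents": ["iVBORw0KGgo="]}
      }]
    }
    """
)

images = decode_bytes_contents(sample)
print(images[0].startswith(PNG_SIGNATURE))
```

A check like this can catch a payload that was pasted as raw bytes or double-encoded before it ever reaches the model handler.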
12 changes: 12 additions & 0 deletions
kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor_grpc.json
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298", | ||
"model_name": "mnist", | ||
"inputs": [{ | ||
"name": "input-0", | ||
"shape": [1, 28, 28], | ||
"datatype": "FP32", | ||
"contents": { | ||
"fp32_contents| ||
} | ||
}] | ||
} |
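For a tensor request like this one, the flat `fp32_contents` list is expected to hold one value per element of the declared shape, so `[1, 28, 28]` implies 784 values. A rough sketch of that consistency check follows. The helper names and the all-zero image are hypothetical, since the real values are elided in the diff:

```python
import math


def expected_element_count(shape):
    """Number of elements implied by a fully specified tensor shape."""
    if any(dim < 0 for dim in shape):
        raise ValueError("shape contains a dynamic dimension; count is not fixed")
    return math.prod(shape)


def check_fp32_tensor(tensor: dict) -> bool:
    """True when the flat fp32_contents list matches the declared shape."""
    values = tensor["contents"]["fp32_contents"]
    return len(values) == expected_element_count(tensor["shape"])


# Hypothetical all-zero image standing in for the elided fp32_contents values.
tensor = {
    "name": "input-0",
    "shape": [1, 28, 28],
    "datatype": "FP32",
    "contents": {"fp32_contents": [0.0] * 784},
}

print(expected_element_count(tensor["shape"]))  # 784
print(check_fp32_tensor(tensor))  # True
```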
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,20 +2,37 @@ | |
return a KServe side response """ | ||
import logging | ||
import pathlib | ||
from enum import Enum | ||
from typing import Dict, Union | ||
|
||
import grpc | ||
import inference_pb2_grpc | ||
import kserve | ||
from gprc_utils import from_ts_grpc, to_ts_grpc | ||
from inference_pb2 import PredictionResponse | ||
from kserve.errors import ModelMissingError | ||
from kserve.model import Model as Model | ||
from kserve.protocol.grpc.grpc_predict_v2_pb2 import ( | ||
ModelInferRequest, | ||
ModelInferResponse, | ||
) | ||
from kserve.protocol.infer_type import InferRequest, InferResponse | ||
from kserve.storage import Storage | ||
|
||
logging.basicConfig(level=kserve.constants.KSERVE_LOGLEVEL) | ||
|
||
PREDICTOR_URL_FORMAT = PREDICTOR_V2_URL_FORMAT = "http://{0}/predictions/{1}" | ||
EXPLAINER_URL_FORMAT = EXPLAINER_V2_URL_FORMAT = "http://{0}/explanations/{1}" | ||
EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}" | ||
REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}" | ||
UNREGISTER_URL_FORMAT = "{0}/models/{1}" | ||
|
||
|
||
class PredictorProtocol(Enum): | ||
REST_V1 = "v1" | ||
REST_V2 = "v2" | ||
GRPC_V2 = "grpc-v2" | ||
|
||
|
||
class TorchserveModel(Model): | ||
"""The torchserve side inference and explain end-points requests are handled to | ||
return a KServe side response | ||
|
@@ -25,7 +42,15 @@ class TorchserveModel(Model): | |
side predict and explain http requests. | ||
""" | ||
|
||
def __init__(self, name, inference_address, management_address, model_dir): | ||
def __init__( | ||
self, | ||
name, | ||
inference_address, | ||
management_address, | ||
grpc_inference_address, | ||
protocol, | ||
model_dir, | ||
): | ||
"""The Model Name, Inference Address, Management Address and the model directory | ||
are specified. | ||
|
||
|
@@ -45,10 +70,74 @@ def __init__(self, name, inference_address, management_address, model_dir): | |
self.inference_address = inference_address | ||
self.management_address = management_address | ||
self.model_dir = model_dir | ||
self.protocol = protocol | ||
|
||
if self.protocol == PredictorProtocol.GRPC_V2.value: | ||
self.predictor_host = grpc_inference_address | ||
|
||
logging.info("Predict URL set to %s", self.predictor_host) | ||
self.explainer_host = self.predictor_host | ||
logging.info("Explain URL set to %s", self.explainer_host) | ||
logging.info("Protocol version is %s", self.protocol) | ||
|
||
def grpc_client(self): | ||
if self._grpc_client_stub is None: | ||
self.channel = grpc.aio.insecure_channel(self.predictor_host) | ||
# Store the stub on the same attribute the check above reads so it is reused | ||
self._grpc_client_stub = inference_pb2_grpc.InferenceAPIsServiceStub( | ||
self.channel | ||
) | ||
return self._grpc_client_stub | ||
|
||
async def _grpc_predict( | ||
self, | ||
payload: Union[ModelInferRequest, InferRequest], | ||
headers: Dict[str, str] = None, | ||
) -> ModelInferResponse: | ||
"""Overrides the `_grpc_predict` method in Model class. The predict method calls | ||
the `_grpc_predict` method if self.protocol is "grpc-v2". | ||
|
||
Args: | ||
payload (InferRequest|ModelInferRequest): The request passed from the ``predict`` handler. | ||
headers (Dict): Request headers. | ||
|
||
Returns: | ||
PredictionResponse: TorchServe gRPC response from the ``Predictions`` call. | ||
""" | ||
payload = to_ts_grpc(payload) | ||
grpc_stub = self.grpc_client() | ||
async_result = await grpc_stub.Predictions(payload) | ||
return async_result | ||
|
||
def postprocess( | ||
Review comment: I don't see where this function is called. Could you add some comments (e.g., which function calls this one)? |
||
self, | ||
response: Union[Dict, InferResponse, ModelInferResponse, PredictionResponse], | ||
headers: Dict[str, str] = None, | ||
) -> Union[Dict, ModelInferResponse]: | ||
"""This method converts the v2 infer response types to gRPC or REST. | ||
For gRPC request it converts InferResponse to gRPC message or directly returns ModelInferResponse from | ||
predictor call or converts TS PredictionResponse to ModelInferResponse. | ||
For REST request it converts ModelInferResponse to Dict or directly returns from predictor call. | ||
|
||
Args: | ||
response (Dict|InferResponse|ModelInferResponse|PredictionResponse): The response passed from ``predict`` handler. | ||
headers (Dict): Request headers. | ||
|
||
Returns: | ||
Dict: post-processed response. | ||
""" | ||
if headers: | ||
if "grpc" in headers.get("user-agent", ""): | ||
if isinstance(response, ModelInferResponse): | ||
return response | ||
elif isinstance(response, InferResponse): | ||
return response.to_grpc() | ||
elif isinstance(response, PredictionResponse): | ||
return from_ts_grpc(response) | ||
if "application/json" in headers.get("content-type", ""): | ||
# If the original request is REST, convert the gRPC predict response to dict | ||
if isinstance(response, ModelInferResponse): | ||
return InferResponse.from_grpc(response).to_rest() | ||
elif isinstance(response, InferResponse): | ||
return response.to_rest() | ||
return response | ||
|
||
def load(self) -> bool: | ||
"""This method validates model availability in the model directory | ||
|
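The header-based dispatch in `postprocess` can be exercised in isolation. The following is a simplified sketch, not the PR's implementation: `FakeGrpcResponse` and `FakeInferResponse` are hypothetical stand-ins for `ModelInferResponse` and `InferResponse` so the example runs without `kserve` installed, and the header values are illustrative.

```python
from typing import Dict, Optional


class FakeGrpcResponse:
    """Hypothetical stand-in for ModelInferResponse."""


class FakeInferResponse:
    """Hypothetical stand-in for kserve's InferResponse."""

    def to_grpc(self):
        return FakeGrpcResponse()

    def to_rest(self):
        return {"outputs": []}


def pick_format(headers: Optional[Dict[str, str]]) -> str:
    """Mirror the header checks: gRPC user agent first, then JSON content type."""
    if headers:
        if "grpc" in headers.get("user-agent", ""):
            return "grpc"
        if "application/json" in headers.get("content-type", ""):
            return "rest"
    return "passthrough"


def postprocess_sketch(response, headers=None):
    """Convert the stand-in response to the format the caller asked for."""
    fmt = pick_format(headers)
    if fmt == "grpc" and isinstance(response, FakeInferResponse):
        return response.to_grpc()
    if fmt == "rest" and isinstance(response, FakeInferResponse):
        return response.to_rest()
    # Anything else (including no headers at all) is returned unchanged.
    return response


print(pick_format({"user-agent": "grpc-python/1.50"}))
print(postprocess_sketch(FakeInferResponse(), {"content-type": "application/json"}))
```

Separating the header check from the conversion makes the two branches easy to unit test without standing up a TorchServe backend.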